diff --git a/app/MindWork AI Studio/Pages/About.razor b/app/MindWork AI Studio/Pages/About.razor index cfc70a93..046704fd 100644 --- a/app/MindWork AI Studio/Pages/About.razor +++ b/app/MindWork AI Studio/Pages/About.razor @@ -50,6 +50,7 @@ + diff --git a/app/MindWork AI Studio/Program.cs b/app/MindWork AI Studio/Program.cs index b55aa32d..bddcae3e 100644 --- a/app/MindWork AI Studio/Program.cs +++ b/app/MindWork AI Studio/Program.cs @@ -23,16 +23,7 @@ internal sealed class Program { if(args.Length == 0) { - Console.WriteLine("Please provide the port of the runtime API."); - return; - } - - var rustApiPort = args[0]; - using var rust = new RustService(rustApiPort); - var appPort = await rust.GetAppPort(); - if(appPort == 0) - { - Console.WriteLine("Failed to get the app port from Rust."); + Console.WriteLine("Error: Please provide the port of the runtime API."); return; } @@ -40,7 +31,7 @@ internal sealed class Program var secretPasswordEncoded = Environment.GetEnvironmentVariable("AI_STUDIO_SECRET_PASSWORD"); if(string.IsNullOrWhiteSpace(secretPasswordEncoded)) { - Console.WriteLine("The AI_STUDIO_SECRET_PASSWORD environment variable is not set."); + Console.WriteLine("Error: The AI_STUDIO_SECRET_PASSWORD environment variable is not set."); return; } @@ -48,12 +39,28 @@ internal sealed class Program var secretKeySaltEncoded = Environment.GetEnvironmentVariable("AI_STUDIO_SECRET_KEY_SALT"); if(string.IsNullOrWhiteSpace(secretKeySaltEncoded)) { - Console.WriteLine("The AI_STUDIO_SECRET_KEY_SALT environment variable is not set."); + Console.WriteLine("Error: The AI_STUDIO_SECRET_KEY_SALT environment variable is not set."); return; } var secretKeySalt = Convert.FromBase64String(secretKeySaltEncoded); + var certificateFingerprint = Environment.GetEnvironmentVariable("AI_STUDIO_CERTIFICATE_FINGERPRINT"); + if(string.IsNullOrWhiteSpace(certificateFingerprint)) + { + Console.WriteLine("Error: The AI_STUDIO_CERTIFICATE_FINGERPRINT environment variable is not set."); + return; + } + + var rustApiPort = args[0]; + using var rust = new RustService(rustApiPort, certificateFingerprint); + var appPort = await rust.GetAppPort(); + if(appPort == 0) + { + Console.WriteLine("Error: Failed to get the app port from Rust."); + return; + } + var builder = WebApplication.CreateBuilder(); builder.WebHost.ConfigureKestrel(kestrelServerOptions => @@ -62,10 +69,6 @@ internal sealed class Program { listenOptions.Protocols = HttpProtocols.Http1AndHttp2AndHttp3; }); - - kestrelServerOptions.ConfigureHttpsDefaults(adapterOptions => - { - }); }); builder.Logging.ClearProviders(); diff --git a/app/MindWork AI Studio/Tools/RustService.cs b/app/MindWork AI Studio/Tools/RustService.cs index 1bf7745e..f382575e 100644 --- a/app/MindWork AI Studio/Tools/RustService.cs +++ b/app/MindWork AI Studio/Tools/RustService.cs @@ -1,3 +1,4 @@ +using System.Security.Cryptography; using System.Text.Json; using AIStudio.Provider; @@ -10,12 +11,9 @@ namespace AIStudio.Tools; /// /// Calling Rust functions. /// -public sealed class RustService(string apiPort) : IDisposable +public sealed class RustService : IDisposable { - private readonly HttpClient http = new() - { - BaseAddress = new Uri($"http://127.0.0.1:{apiPort}"), - }; + private readonly HttpClient http; private readonly JsonSerializerOptions jsonRustSerializerOptions = new() { @@ -24,6 +22,33 @@ public sealed class RustService(string apiPort) : IDisposable private ILogger? logger; private Encryption? encryptor; + + private readonly string apiPort; + private readonly string certificateFingerprint; + + public RustService(string apiPort, string certificateFingerprint) + { + this.apiPort = apiPort; + this.certificateFingerprint = certificateFingerprint; + var certificateValidationHandler = new HttpClientHandler + { + ServerCertificateCustomValidationCallback = (_, certificate, _, _) => + { + if(certificate is null) + return false; + + var currentCertificateFingerprint = certificate.GetCertHashString(HashAlgorithmName.SHA256); + return currentCertificateFingerprint == certificateFingerprint; + }, + }; + + this.http = new HttpClient(certificateValidationHandler) + { + BaseAddress = new Uri($"https://127.0.0.1:{apiPort}"), + DefaultRequestVersion = Version.Parse("2.0"), + DefaultVersionPolicy = HttpVersionPolicy.RequestVersionOrHigher, + }; + } public void SetLogger(ILogger logService) { @@ -48,7 +73,7 @@ public sealed class RustService(string apiPort) : IDisposable const int MAX_TRIES = 160; var tris = 0; var wait4Try = TimeSpan.FromMilliseconds(250); - var url = new Uri($"http://127.0.0.1:{apiPort}/system/dotnet/port"); + var url = new Uri($"https://127.0.0.1:{this.apiPort}/system/dotnet/port"); while (tris++ < MAX_TRIES) { // @@ -57,19 +82,47 @@ public sealed class RustService(string apiPort) : IDisposable // instance, we would always get the same result (403 forbidden), // without even trying to connect to the Rust server. // - using var initialHttp = new HttpClient(); - var response = await initialHttp.GetAsync(url); - if (!response.IsSuccessStatusCode) + + using var initialHttp = new HttpClient(new HttpClientHandler { - Console.WriteLine($"Try {tris}/{MAX_TRIES} to get the app port from Rust runtime"); - await Task.Delay(wait4Try); - continue; - } + // + // Note III: We have to create also a new HttpClientHandler instance + // for each try to avoid .NET is caching the result. This is necessary + // because it gets disposed when the HttpClient instance gets disposed. + // + ServerCertificateCustomValidationCallback = (_, certificate, _, _) => + { + if(certificate is null) + return false; - var appPortContent = await response.Content.ReadAsStringAsync(); - var appPort = int.Parse(appPortContent); - Console.WriteLine($"Received app port from Rust runtime: '{appPort}'"); - return appPort; + var currentCertificateFingerprint = certificate.GetCertHashString(HashAlgorithmName.SHA256); + return currentCertificateFingerprint == this.certificateFingerprint; + } + }); + initialHttp.DefaultRequestVersion = Version.Parse("2.0"); + initialHttp.DefaultVersionPolicy = HttpVersionPolicy.RequestVersionOrHigher; + + try + { + var response = await initialHttp.GetAsync(url); + if (!response.IsSuccessStatusCode) + { + Console.WriteLine($"Try {tris}/{MAX_TRIES} to get the app port from Rust runtime"); + await Task.Delay(wait4Try); + continue; + } + + var appPortContent = await response.Content.ReadAsStringAsync(); + var appPort = int.Parse(appPortContent); + Console.WriteLine($"Received app port from Rust runtime: '{appPort}'"); + return appPort; + } + catch (Exception e) + { + Console.WriteLine($"Error: Was not able to get the app port from Rust runtime: '{e.Message}'"); + Console.WriteLine(e.InnerException); + throw; + } } Console.WriteLine("Failed to receive the app port from Rust runtime."); @@ -80,10 +133,18 @@ public sealed class RustService(string apiPort) : IDisposable { const string URL = "/system/dotnet/ready"; this.logger!.LogInformation("Notifying Rust runtime that the app is ready."); - var response = await this.http.GetAsync(URL); - if (!response.IsSuccessStatusCode) + try { - this.logger!.LogError($"Failed to notify Rust runtime that the app is ready: '{response.StatusCode}'"); + var response = await this.http.GetAsync(URL); + if (!response.IsSuccessStatusCode) + { + this.logger!.LogError($"Failed to notify Rust runtime that the app is ready: '{response.StatusCode}'"); + } + } + catch (Exception e) + { + this.logger!.LogError(e, "Failed to notify the Rust runtime that the app is ready."); + throw; } } diff --git a/runtime/Cargo.lock b/runtime/Cargo.lock index 98b91160..f84c23dd 100644 --- a/runtime/Cargo.lock +++ b/runtime/Cargo.lock @@ -964,9 +964,9 @@ dependencies = [ [[package]] name = "flexi_logger" -version = "0.28.5" +version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cca927478b3747ba47f98af6ba0ac0daea4f12d12f55e9104071b3dc00276310" +checksum = "a250587a211932896a131f214a4f64c047b826ce072d2018764e5ff5141df8fa" dependencies = [ "chrono", "glob", @@ -2145,6 +2145,7 @@ dependencies = [ "pbkdf2", "rand 0.8.5", "rand_chacha 0.3.1", + "rcgen", "reqwest 0.12.4", "rocket", "serde", @@ -2709,6 +2710,16 @@ dependencies = [ "syn 2.0.72", ] +[[package]] +name = "pem" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e459365e590736a54c3fa561947c84837534b8e9af6fc5bf781307e82658fae" +dependencies = [ + "base64 0.22.1", + "serde", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -3098,6 +3109,19 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2ff9a1f06a88b01621b7ae906ef0211290d1c8a168a15542486a8f61c0833b9" +[[package]] +name = "rcgen" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54077e1872c46788540de1ea3d7f4ccb1983d12f9aa909b234468676c1a36779" +dependencies = [ + "pem", + "ring", + "rustls-pki-types", + "time", + "yasna", +] + [[package]] name = "redox_syscall" version = "0.4.1" @@ -3299,6 +3323,21 @@ dependencies = [ "windows 0.37.0", ] +[[package]] +name = "ring" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.15", + "libc", + "spin", + "untrusted", + "windows-sys 0.52.0", +] + [[package]] name = "rocket" version = "0.5.1" @@ -3372,12 +3411,15 @@ dependencies = [ "percent-encoding", "pin-project-lite", "ref-cast", + "rustls", + "rustls-pemfile 1.0.4", "serde", "smallvec", "stable-pattern", "state 0.6.0", "time", "tokio", + "tokio-rustls", "uncased", ] @@ -3409,6 +3451,18 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rustls" +version = "0.21.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" +dependencies = [ + "log", + "ring", + "rustls-webpki", + "sct", +] + [[package]] name = "rustls-pemfile" version = "1.0.4" @@ -3434,6 +3488,16 @@ version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" +[[package]] +name = "rustls-webpki" +version = "0.101.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.17" @@ -3476,6 +3540,16 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sct" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "security-framework" version = "2.11.1" @@ -4352,6 +4426,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.15" @@ -4596,6 +4680,12 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "229730647fbc343e3a80e463c1db7f78f3855d3f3739bee0dda773c9a037c90a" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "url" version = "2.5.2" @@ -5400,6 +5490,15 @@ dependencies = [ "is-terminal", ] +[[package]] +name = "yasna" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" +dependencies = [ + "time", +] + [[package]] name = "zip" version = "0.6.6" diff --git a/runtime/Cargo.toml b/runtime/Cargo.toml index 4ec4e97c..692fff4f 100644 --- a/runtime/Cargo.toml +++ b/runtime/Cargo.toml @@ -16,10 +16,10 @@ serde_json = "1.0" keyring = { version = "3.2", features = ["apple-native", "windows-native", "sync-secret-service"] } arboard = "3.4.0" tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "macros"] } -flexi_logger = "0.28" +flexi_logger = "0.29" log = { version = "0.4", features = ["kv"] } once_cell = "1.19.0" -rocket = { version = "0.5", features = ["json"] } +rocket = { version = "0.5", features = ["json", "tls"] } rand = "0.8" rand_chacha = "0.3.1" base64 = "0.22.1" @@ -29,6 +29,7 @@ cbc = "0.1.2" pbkdf2 = "0.12.2" hmac = "0.12.1" sha2 = "0.10.8" +rcgen = { version = "0.13.1", features = ["pem"] } [target.'cfg(target_os = "linux")'.dependencies] # See issue https://github.com/tauri-apps/tauri/issues/4470 diff --git a/runtime/src/main.rs b/runtime/src/main.rs index 19936178..cf280b22 100644 --- a/runtime/src/main.rs +++ b/runtime/src/main.rs @@ -28,13 +28,14 @@ use log::{debug, error, info, kv, warn}; use log::kv::{Key, Value, VisitSource}; use pbkdf2::pbkdf2; use rand::{RngCore, SeedableRng}; +use rcgen::generate_simple_self_signed; use rocket::figment::Figment; use rocket::{data, get, post, routes, Data, Request}; -use rocket::config::Shutdown; +use rocket::config::{Shutdown}; use rocket::data::{Outcome, ToByteUnit}; use rocket::http::Status; use rocket::serde::json::Json; -use sha2::Sha512; +use sha2::{Sha256, Sha512, Digest}; use tauri::updater::UpdateResponse; use tokio::io::AsyncReadExt; @@ -153,6 +154,20 @@ async fn main() { info!("Running in production mode."); } + info!("Try to generate a TLS certificate for the runtime API server..."); + + let subject_alt_names = vec!["localhost".to_string()]; + let certificate_data = generate_simple_self_signed(subject_alt_names).unwrap(); + let certificate_binary_data = certificate_data.cert.der().to_vec(); + let certificate_fingerprint = Sha256::digest(certificate_binary_data).to_vec(); + let certificate_fingerprint = certificate_fingerprint.iter().fold(String::new(), |mut result, byte| { + result.push_str(&format!("{:02x}", byte)); + result + }); + let certificate_fingerprint = certificate_fingerprint.to_uppercase(); + info!("Certificate fingerprint: '{certificate_fingerprint}'."); + info!("Done generating certificate for the runtime API server."); + let api_port = *API_SERVER_PORT; info!("Try to start the API server on 'http://localhost:{api_port}'..."); @@ -178,6 +193,10 @@ async fn main() { // No colors and emojis in the log output: .merge(("cli_colors", false)) + // Read the TLS certificate and key from the generated certificate data in-memory: + .merge(("tls.certs", certificate_data.cert.pem().as_bytes())) + .merge(("tls.key", certificate_data.key_pair.serialize_pem().as_bytes())) + // Set the shutdown configuration: .merge(("shutdown", Shutdown { @@ -234,6 +253,7 @@ async fn main() { .envs(HashMap::from_iter([ (String::from("AI_STUDIO_SECRET_PASSWORD"), secret_password), (String::from("AI_STUDIO_SECRET_KEY_SALT"), secret_key_salt), + (String::from("AI_STUDIO_CERTIFICATE_FINGERPRINT"), certificate_fingerprint), ])) .spawn() .expect("Failed to spawn .NET server process.") @@ -251,6 +271,7 @@ async fn main() { .envs(HashMap::from_iter([ (String::from("AI_STUDIO_SECRET_PASSWORD"), secret_password), (String::from("AI_STUDIO_SECRET_KEY_SALT"), secret_key_salt), + (String::from("AI_STUDIO_CERTIFICATE_FINGERPRINT"), certificate_fingerprint), ])) .spawn() .expect("Failed to spawn .NET server process.")