diff --git a/app/MindWork AI Studio/Pages/About.razor b/app/MindWork AI Studio/Pages/About.razor
index cfc70a93..046704fd 100644
--- a/app/MindWork AI Studio/Pages/About.razor
+++ b/app/MindWork AI Studio/Pages/About.razor
@@ -50,6 +50,7 @@
+
diff --git a/app/MindWork AI Studio/Program.cs b/app/MindWork AI Studio/Program.cs
index b55aa32d..bddcae3e 100644
--- a/app/MindWork AI Studio/Program.cs
+++ b/app/MindWork AI Studio/Program.cs
@@ -23,16 +23,7 @@ internal sealed class Program
{
if(args.Length == 0)
{
- Console.WriteLine("Please provide the port of the runtime API.");
- return;
- }
-
- var rustApiPort = args[0];
- using var rust = new RustService(rustApiPort);
- var appPort = await rust.GetAppPort();
- if(appPort == 0)
- {
- Console.WriteLine("Failed to get the app port from Rust.");
+ Console.WriteLine("Error: Please provide the port of the runtime API.");
return;
}
@@ -40,7 +31,7 @@ internal sealed class Program
var secretPasswordEncoded = Environment.GetEnvironmentVariable("AI_STUDIO_SECRET_PASSWORD");
if(string.IsNullOrWhiteSpace(secretPasswordEncoded))
{
- Console.WriteLine("The AI_STUDIO_SECRET_PASSWORD environment variable is not set.");
+ Console.WriteLine("Error: The AI_STUDIO_SECRET_PASSWORD environment variable is not set.");
return;
}
@@ -48,12 +39,28 @@ internal sealed class Program
var secretKeySaltEncoded = Environment.GetEnvironmentVariable("AI_STUDIO_SECRET_KEY_SALT");
if(string.IsNullOrWhiteSpace(secretKeySaltEncoded))
{
- Console.WriteLine("The AI_STUDIO_SECRET_KEY_SALT environment variable is not set.");
+ Console.WriteLine("Error: The AI_STUDIO_SECRET_KEY_SALT environment variable is not set.");
return;
}
var secretKeySalt = Convert.FromBase64String(secretKeySaltEncoded);
+ var certificateFingerprint = Environment.GetEnvironmentVariable("AI_STUDIO_CERTIFICATE_FINGERPRINT");
+ if(string.IsNullOrWhiteSpace(certificateFingerprint))
+ {
+ Console.WriteLine("Error: The AI_STUDIO_CERTIFICATE_FINGERPRINT environment variable is not set.");
+ return;
+ }
+
+ var rustApiPort = args[0];
+ using var rust = new RustService(rustApiPort, certificateFingerprint);
+ var appPort = await rust.GetAppPort();
+ if(appPort == 0)
+ {
+ Console.WriteLine("Error: Failed to get the app port from Rust.");
+ return;
+ }
+
var builder = WebApplication.CreateBuilder();
builder.WebHost.ConfigureKestrel(kestrelServerOptions =>
@@ -62,10 +69,6 @@ internal sealed class Program
{
listenOptions.Protocols = HttpProtocols.Http1AndHttp2AndHttp3;
});
-
- kestrelServerOptions.ConfigureHttpsDefaults(adapterOptions =>
- {
- });
});
builder.Logging.ClearProviders();
diff --git a/app/MindWork AI Studio/Tools/RustService.cs b/app/MindWork AI Studio/Tools/RustService.cs
index 1bf7745e..f382575e 100644
--- a/app/MindWork AI Studio/Tools/RustService.cs
+++ b/app/MindWork AI Studio/Tools/RustService.cs
@@ -1,3 +1,4 @@
+using System.Security.Cryptography;
using System.Text.Json;
using AIStudio.Provider;
@@ -10,12 +11,9 @@ namespace AIStudio.Tools;
///
/// Calling Rust functions.
///
-public sealed class RustService(string apiPort) : IDisposable
+public sealed class RustService : IDisposable
{
- private readonly HttpClient http = new()
- {
- BaseAddress = new Uri($"http://127.0.0.1:{apiPort}"),
- };
+ private readonly HttpClient http;
private readonly JsonSerializerOptions jsonRustSerializerOptions = new()
{
@@ -24,6 +22,33 @@ public sealed class RustService(string apiPort) : IDisposable
private ILogger? logger;
private Encryption? encryptor;
+
+ private readonly string apiPort;
+ private readonly string certificateFingerprint;
+
+ public RustService(string apiPort, string certificateFingerprint)
+ {
+ this.apiPort = apiPort;
+ this.certificateFingerprint = certificateFingerprint;
+ var certificateValidationHandler = new HttpClientHandler
+ {
+ ServerCertificateCustomValidationCallback = (_, certificate, _, _) =>
+ {
+ if(certificate is null)
+ return false;
+
+ var currentCertificateFingerprint = certificate.GetCertHashString(HashAlgorithmName.SHA256);
+ return currentCertificateFingerprint == certificateFingerprint;
+ },
+ };
+
+ this.http = new HttpClient(certificateValidationHandler)
+ {
+ BaseAddress = new Uri($"https://127.0.0.1:{apiPort}"),
+ DefaultRequestVersion = Version.Parse("2.0"),
+ DefaultVersionPolicy = HttpVersionPolicy.RequestVersionOrHigher,
+ };
+ }
public void SetLogger(ILogger logService)
{
@@ -48,7 +73,7 @@ public sealed class RustService(string apiPort) : IDisposable
const int MAX_TRIES = 160;
var tris = 0;
var wait4Try = TimeSpan.FromMilliseconds(250);
- var url = new Uri($"http://127.0.0.1:{apiPort}/system/dotnet/port");
+ var url = new Uri($"https://127.0.0.1:{this.apiPort}/system/dotnet/port");
while (tris++ < MAX_TRIES)
{
//
@@ -57,19 +82,47 @@ public sealed class RustService(string apiPort) : IDisposable
// instance, we would always get the same result (403 forbidden),
// without even trying to connect to the Rust server.
//
- using var initialHttp = new HttpClient();
- var response = await initialHttp.GetAsync(url);
- if (!response.IsSuccessStatusCode)
+
+ using var initialHttp = new HttpClient(new HttpClientHandler
{
- Console.WriteLine($"Try {tris}/{MAX_TRIES} to get the app port from Rust runtime");
- await Task.Delay(wait4Try);
- continue;
- }
+ //
+ // Note III: We have to create also a new HttpClientHandler instance
+ // for each try to avoid .NET is caching the result. This is necessary
+ // because it gets disposed when the HttpClient instance gets disposed.
+ //
+ ServerCertificateCustomValidationCallback = (_, certificate, _, _) =>
+ {
+ if(certificate is null)
+ return false;
- var appPortContent = await response.Content.ReadAsStringAsync();
- var appPort = int.Parse(appPortContent);
- Console.WriteLine($"Received app port from Rust runtime: '{appPort}'");
- return appPort;
+ var currentCertificateFingerprint = certificate.GetCertHashString(HashAlgorithmName.SHA256);
+ return currentCertificateFingerprint == this.certificateFingerprint;
+ }
+ });
+ initialHttp.DefaultRequestVersion = Version.Parse("2.0");
+ initialHttp.DefaultVersionPolicy = HttpVersionPolicy.RequestVersionOrHigher;
+
+ try
+ {
+ var response = await initialHttp.GetAsync(url);
+ if (!response.IsSuccessStatusCode)
+ {
+ Console.WriteLine($"Try {tris}/{MAX_TRIES} to get the app port from Rust runtime");
+ await Task.Delay(wait4Try);
+ continue;
+ }
+
+ var appPortContent = await response.Content.ReadAsStringAsync();
+ var appPort = int.Parse(appPortContent);
+ Console.WriteLine($"Received app port from Rust runtime: '{appPort}'");
+ return appPort;
+ }
+ catch (Exception e)
+ {
+ Console.WriteLine($"Error: Was not able to get the app port from Rust runtime: '{e.Message}'");
+ Console.WriteLine(e.InnerException);
+ throw;
+ }
}
Console.WriteLine("Failed to receive the app port from Rust runtime.");
@@ -80,10 +133,18 @@ public sealed class RustService(string apiPort) : IDisposable
{
const string URL = "/system/dotnet/ready";
this.logger!.LogInformation("Notifying Rust runtime that the app is ready.");
- var response = await this.http.GetAsync(URL);
- if (!response.IsSuccessStatusCode)
+ try
{
- this.logger!.LogError($"Failed to notify Rust runtime that the app is ready: '{response.StatusCode}'");
+ var response = await this.http.GetAsync(URL);
+ if (!response.IsSuccessStatusCode)
+ {
+ this.logger!.LogError($"Failed to notify Rust runtime that the app is ready: '{response.StatusCode}'");
+ }
+ }
+ catch (Exception e)
+ {
+ this.logger!.LogError(e, "Failed to notify the Rust runtime that the app is ready.");
+ throw;
}
}
diff --git a/runtime/Cargo.lock b/runtime/Cargo.lock
index 98b91160..f84c23dd 100644
--- a/runtime/Cargo.lock
+++ b/runtime/Cargo.lock
@@ -964,9 +964,9 @@ dependencies = [
[[package]]
name = "flexi_logger"
-version = "0.28.5"
+version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cca927478b3747ba47f98af6ba0ac0daea4f12d12f55e9104071b3dc00276310"
+checksum = "a250587a211932896a131f214a4f64c047b826ce072d2018764e5ff5141df8fa"
dependencies = [
"chrono",
"glob",
@@ -2145,6 +2145,7 @@ dependencies = [
"pbkdf2",
"rand 0.8.5",
"rand_chacha 0.3.1",
+ "rcgen",
"reqwest 0.12.4",
"rocket",
"serde",
@@ -2709,6 +2710,16 @@ dependencies = [
"syn 2.0.72",
]
+[[package]]
+name = "pem"
+version = "3.0.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8e459365e590736a54c3fa561947c84837534b8e9af6fc5bf781307e82658fae"
+dependencies = [
+ "base64 0.22.1",
+ "serde",
+]
+
[[package]]
name = "percent-encoding"
version = "2.3.1"
@@ -3098,6 +3109,19 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2ff9a1f06a88b01621b7ae906ef0211290d1c8a168a15542486a8f61c0833b9"
+[[package]]
+name = "rcgen"
+version = "0.13.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "54077e1872c46788540de1ea3d7f4ccb1983d12f9aa909b234468676c1a36779"
+dependencies = [
+ "pem",
+ "ring",
+ "rustls-pki-types",
+ "time",
+ "yasna",
+]
+
[[package]]
name = "redox_syscall"
version = "0.4.1"
@@ -3299,6 +3323,21 @@ dependencies = [
"windows 0.37.0",
]
+[[package]]
+name = "ring"
+version = "0.17.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d"
+dependencies = [
+ "cc",
+ "cfg-if",
+ "getrandom 0.2.15",
+ "libc",
+ "spin",
+ "untrusted",
+ "windows-sys 0.52.0",
+]
+
[[package]]
name = "rocket"
version = "0.5.1"
@@ -3372,12 +3411,15 @@ dependencies = [
"percent-encoding",
"pin-project-lite",
"ref-cast",
+ "rustls",
+ "rustls-pemfile 1.0.4",
"serde",
"smallvec",
"stable-pattern",
"state 0.6.0",
"time",
"tokio",
+ "tokio-rustls",
"uncased",
]
@@ -3409,6 +3451,18 @@ dependencies = [
"windows-sys 0.52.0",
]
+[[package]]
+name = "rustls"
+version = "0.21.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e"
+dependencies = [
+ "log",
+ "ring",
+ "rustls-webpki",
+ "sct",
+]
+
[[package]]
name = "rustls-pemfile"
version = "1.0.4"
@@ -3434,6 +3488,16 @@ version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d"
+[[package]]
+name = "rustls-webpki"
+version = "0.101.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765"
+dependencies = [
+ "ring",
+ "untrusted",
+]
+
[[package]]
name = "rustversion"
version = "1.0.17"
@@ -3476,6 +3540,16 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
+[[package]]
+name = "sct"
+version = "0.7.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414"
+dependencies = [
+ "ring",
+ "untrusted",
+]
+
[[package]]
name = "security-framework"
version = "2.11.1"
@@ -4352,6 +4426,16 @@ dependencies = [
"tokio",
]
+[[package]]
+name = "tokio-rustls"
+version = "0.24.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081"
+dependencies = [
+ "rustls",
+ "tokio",
+]
+
[[package]]
name = "tokio-stream"
version = "0.1.15"
@@ -4596,6 +4680,12 @@ version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "229730647fbc343e3a80e463c1db7f78f3855d3f3739bee0dda773c9a037c90a"
+[[package]]
+name = "untrusted"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
+
[[package]]
name = "url"
version = "2.5.2"
@@ -5400,6 +5490,15 @@ dependencies = [
"is-terminal",
]
+[[package]]
+name = "yasna"
+version = "0.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd"
+dependencies = [
+ "time",
+]
+
[[package]]
name = "zip"
version = "0.6.6"
diff --git a/runtime/Cargo.toml b/runtime/Cargo.toml
index 4ec4e97c..692fff4f 100644
--- a/runtime/Cargo.toml
+++ b/runtime/Cargo.toml
@@ -16,10 +16,10 @@ serde_json = "1.0"
keyring = { version = "3.2", features = ["apple-native", "windows-native", "sync-secret-service"] }
arboard = "3.4.0"
tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "macros"] }
-flexi_logger = "0.28"
+flexi_logger = "0.29"
log = { version = "0.4", features = ["kv"] }
once_cell = "1.19.0"
-rocket = { version = "0.5", features = ["json"] }
+rocket = { version = "0.5", features = ["json", "tls"] }
rand = "0.8"
rand_chacha = "0.3.1"
base64 = "0.22.1"
@@ -29,6 +29,7 @@ cbc = "0.1.2"
pbkdf2 = "0.12.2"
hmac = "0.12.1"
sha2 = "0.10.8"
+rcgen = { version = "0.13.1", features = ["pem"] }
[target.'cfg(target_os = "linux")'.dependencies]
# See issue https://github.com/tauri-apps/tauri/issues/4470
diff --git a/runtime/src/main.rs b/runtime/src/main.rs
index 19936178..cf280b22 100644
--- a/runtime/src/main.rs
+++ b/runtime/src/main.rs
@@ -28,13 +28,14 @@ use log::{debug, error, info, kv, warn};
use log::kv::{Key, Value, VisitSource};
use pbkdf2::pbkdf2;
use rand::{RngCore, SeedableRng};
+use rcgen::generate_simple_self_signed;
use rocket::figment::Figment;
use rocket::{data, get, post, routes, Data, Request};
-use rocket::config::Shutdown;
+use rocket::config::{Shutdown};
use rocket::data::{Outcome, ToByteUnit};
use rocket::http::Status;
use rocket::serde::json::Json;
-use sha2::Sha512;
+use sha2::{Sha256, Sha512, Digest};
use tauri::updater::UpdateResponse;
use tokio::io::AsyncReadExt;
@@ -153,6 +154,20 @@ async fn main() {
info!("Running in production mode.");
}
+ info!("Try to generate a TLS certificate for the runtime API server...");
+
+ let subject_alt_names = vec!["localhost".to_string()];
+ let certificate_data = generate_simple_self_signed(subject_alt_names).unwrap();
+ let certificate_binary_data = certificate_data.cert.der().to_vec();
+ let certificate_fingerprint = Sha256::digest(certificate_binary_data).to_vec();
+ let certificate_fingerprint = certificate_fingerprint.iter().fold(String::new(), |mut result, byte| {
+ result.push_str(&format!("{:02x}", byte));
+ result
+ });
+ let certificate_fingerprint = certificate_fingerprint.to_uppercase();
+ info!("Certificate fingerprint: '{certificate_fingerprint}'.");
+ info!("Done generating certificate for the runtime API server.");
+
let api_port = *API_SERVER_PORT;
info!("Try to start the API server on 'http://localhost:{api_port}'...");
@@ -178,6 +193,10 @@ async fn main() {
// No colors and emojis in the log output:
.merge(("cli_colors", false))
+ // Read the TLS certificate and key from the generated certificate data in-memory:
+ .merge(("tls.certs", certificate_data.cert.pem().as_bytes()))
+ .merge(("tls.key", certificate_data.key_pair.serialize_pem().as_bytes()))
+
// Set the shutdown configuration:
.merge(("shutdown", Shutdown {
@@ -234,6 +253,7 @@ async fn main() {
.envs(HashMap::from_iter([
(String::from("AI_STUDIO_SECRET_PASSWORD"), secret_password),
(String::from("AI_STUDIO_SECRET_KEY_SALT"), secret_key_salt),
+ (String::from("AI_STUDIO_CERTIFICATE_FINGERPRINT"), certificate_fingerprint),
]))
.spawn()
.expect("Failed to spawn .NET server process.")
@@ -251,6 +271,7 @@ async fn main() {
.envs(HashMap::from_iter([
(String::from("AI_STUDIO_SECRET_PASSWORD"), secret_password),
(String::from("AI_STUDIO_SECRET_KEY_SALT"), secret_key_salt),
+ (String::from("AI_STUDIO_CERTIFICATE_FINGERPRINT"), certificate_fingerprint),
]))
.spawn()
.expect("Failed to spawn .NET server process.")