Refactored runtime API to use TLS encryption

This commit is contained in:
Thorsten Sommer 2024-08-30 22:46:20 +02:00
parent 67726a91d4
commit f929789535
Signed by: tsommer
GPG Key ID: 371BBA77A02C0108
6 changed files with 228 additions and 42 deletions

View File

@ -50,6 +50,7 @@
<ThirdPartyComponent Name="rand" Developer="Rust developers & Open Source Community" LicenseName="MIT" LicenseUrl="https://github.com/rust-random/rand/blob/master/LICENSE-MIT" RepositoryUrl="https://github.com/rust-random/rand" UseCase="We must generate random numbers, e.g., for securing the interprocess communication between the user interface and the runtime. The rand library is great for this purpose."/>
<ThirdPartyComponent Name="base64" Developer="Marshall Pierce, Alice Maz & Open Source Community" LicenseName="MIT" LicenseUrl="https://github.com/marshallpierce/rust-base64/blob/master/LICENSE-MIT" RepositoryUrl="https://github.com/marshallpierce/rust-base64" UseCase="For some data transfers, we need to encode the data in base64. This Rust library is great for this purpose."/>
<ThirdPartyComponent Name="Rust Crypto" Developer="Artyom Pavlov, Tony Arcieri, Brian Warner, Arthur Gautier, Vlad Filippov, Friedel Ziegelmayer, Nicolas Stalder & Open Source Community" LicenseName="MIT" LicenseUrl="https://github.com/RustCrypto/traits/blob/master/cipher/LICENSE-MIT" RepositoryUrl="https://github.com/RustCrypto" UseCase="When transferring sensitive data between Rust runtime and .NET app, we encrypt the data. We use some libraries from the Rust Crypto project for this purpose: cipher, aes, cbc, pbkdf2, hmac, and sha2. We are thankful for the great work of the Rust Crypto project."/>
<ThirdPartyComponent Name="rcgen" Developer="RustTLS developers, est31 & Open Source Community" LicenseName="MIT" LicenseUrl="https://github.com/rustls/rcgen/blob/main/LICENSE" RepositoryUrl="https://github.com/rustls/rcgen" UseCase="For the secure communication between the user interface and the runtime, we need to create certificates. This Rust library is great for this purpose."/>
<ThirdPartyComponent Name="HtmlAgilityPack" Developer="ZZZ Projects & Open Source Community" LicenseName="MIT" LicenseUrl="https://github.com/zzzprojects/html-agility-pack/blob/master/LICENSE" RepositoryUrl="https://github.com/zzzprojects/html-agility-pack" UseCase="We use the HtmlAgilityPack to extract content from the web. This is necessary, e.g., when you provide a URL as input for an assistant."/>
<ThirdPartyComponent Name="ReverseMarkdown" Developer="Babu Annamalai & Open Source Community" LicenseName="MIT" LicenseUrl="https://github.com/mysticmind/reversemarkdown-net/blob/master/LICENSE" RepositoryUrl="https://github.com/mysticmind/reversemarkdown-net" UseCase="This library is used to convert HTML to Markdown. This is necessary, e.g., when you provide a URL as input for an assistant."/>
<ThirdPartyComponent Name="wikEd diff" Developer="Cacycle & Open Source Community" LicenseName="None (public domain)" LicenseUrl="https://en.wikipedia.org/wiki/User:Cacycle/diff#License" RepositoryUrl="https://en.wikipedia.org/wiki/User:Cacycle/diff" UseCase="This library is used to display the differences between two texts. This is necessary, e.g., for the grammar and spelling assistant."/>

View File

@ -23,16 +23,7 @@ internal sealed class Program
{
if(args.Length == 0)
{
Console.WriteLine("Please provide the port of the runtime API.");
return;
}
var rustApiPort = args[0];
using var rust = new RustService(rustApiPort);
var appPort = await rust.GetAppPort();
if(appPort == 0)
{
Console.WriteLine("Failed to get the app port from Rust.");
Console.WriteLine("Error: Please provide the port of the runtime API.");
return;
}
@ -40,7 +31,7 @@ internal sealed class Program
var secretPasswordEncoded = Environment.GetEnvironmentVariable("AI_STUDIO_SECRET_PASSWORD");
if(string.IsNullOrWhiteSpace(secretPasswordEncoded))
{
Console.WriteLine("The AI_STUDIO_SECRET_PASSWORD environment variable is not set.");
Console.WriteLine("Error: The AI_STUDIO_SECRET_PASSWORD environment variable is not set.");
return;
}
@ -48,12 +39,28 @@ internal sealed class Program
var secretKeySaltEncoded = Environment.GetEnvironmentVariable("AI_STUDIO_SECRET_KEY_SALT");
if(string.IsNullOrWhiteSpace(secretKeySaltEncoded))
{
Console.WriteLine("The AI_STUDIO_SECRET_KEY_SALT environment variable is not set.");
Console.WriteLine("Error: The AI_STUDIO_SECRET_KEY_SALT environment variable is not set.");
return;
}
var secretKeySalt = Convert.FromBase64String(secretKeySaltEncoded);
var certificateFingerprint = Environment.GetEnvironmentVariable("AI_STUDIO_CERTIFICATE_FINGERPRINT");
if(string.IsNullOrWhiteSpace(certificateFingerprint))
{
Console.WriteLine("Error: The AI_STUDIO_CERTIFICATE_FINGERPRINT environment variable is not set.");
return;
}
var rustApiPort = args[0];
using var rust = new RustService(rustApiPort, certificateFingerprint);
var appPort = await rust.GetAppPort();
if(appPort == 0)
{
Console.WriteLine("Error: Failed to get the app port from Rust.");
return;
}
var builder = WebApplication.CreateBuilder();
builder.WebHost.ConfigureKestrel(kestrelServerOptions =>
@ -62,10 +69,6 @@ internal sealed class Program
{
listenOptions.Protocols = HttpProtocols.Http1AndHttp2AndHttp3;
});
kestrelServerOptions.ConfigureHttpsDefaults(adapterOptions =>
{
});
});
builder.Logging.ClearProviders();

View File

@ -1,3 +1,4 @@
using System.Security.Cryptography;
using System.Text.Json;
using AIStudio.Provider;
@ -10,12 +11,9 @@ namespace AIStudio.Tools;
/// <summary>
/// Calling Rust functions.
/// </summary>
public sealed class RustService(string apiPort) : IDisposable
public sealed class RustService : IDisposable
{
private readonly HttpClient http = new()
{
BaseAddress = new Uri($"http://127.0.0.1:{apiPort}"),
};
private readonly HttpClient http;
private readonly JsonSerializerOptions jsonRustSerializerOptions = new()
{
@ -25,6 +23,33 @@ public sealed class RustService(string apiPort) : IDisposable
private ILogger<RustService>? logger;
private Encryption? encryptor;
private readonly string apiPort;
private readonly string certificateFingerprint;
public RustService(string apiPort, string certificateFingerprint)
{
this.apiPort = apiPort;
this.certificateFingerprint = certificateFingerprint;
var certificateValidationHandler = new HttpClientHandler
{
ServerCertificateCustomValidationCallback = (_, certificate, _, _) =>
{
if(certificate is null)
return false;
var currentCertificateFingerprint = certificate.GetCertHashString(HashAlgorithmName.SHA256);
return currentCertificateFingerprint == certificateFingerprint;
},
};
this.http = new HttpClient(certificateValidationHandler)
{
BaseAddress = new Uri($"https://127.0.0.1:{apiPort}"),
DefaultRequestVersion = Version.Parse("2.0"),
DefaultVersionPolicy = HttpVersionPolicy.RequestVersionOrHigher,
};
}
public void SetLogger(ILogger<RustService> logService)
{
this.logger = logService;
@ -48,7 +73,7 @@ public sealed class RustService(string apiPort) : IDisposable
const int MAX_TRIES = 160;
var tris = 0;
var wait4Try = TimeSpan.FromMilliseconds(250);
var url = new Uri($"http://127.0.0.1:{apiPort}/system/dotnet/port");
var url = new Uri($"https://127.0.0.1:{this.apiPort}/system/dotnet/port");
while (tris++ < MAX_TRIES)
{
//
@ -57,19 +82,47 @@ public sealed class RustService(string apiPort) : IDisposable
// instance, we would always get the same result (403 forbidden),
// without even trying to connect to the Rust server.
//
using var initialHttp = new HttpClient();
var response = await initialHttp.GetAsync(url);
if (!response.IsSuccessStatusCode)
{
Console.WriteLine($"Try {tris}/{MAX_TRIES} to get the app port from Rust runtime");
await Task.Delay(wait4Try);
continue;
}
var appPortContent = await response.Content.ReadAsStringAsync();
var appPort = int.Parse(appPortContent);
Console.WriteLine($"Received app port from Rust runtime: '{appPort}'");
return appPort;
using var initialHttp = new HttpClient(new HttpClientHandler
{
//
// Note III: We have to create also a new HttpClientHandler instance
// for each try to avoid .NET is caching the result. This is necessary
// because it gets disposed when the HttpClient instance gets disposed.
//
ServerCertificateCustomValidationCallback = (_, certificate, _, _) =>
{
if(certificate is null)
return false;
var currentCertificateFingerprint = certificate.GetCertHashString(HashAlgorithmName.SHA256);
return currentCertificateFingerprint == this.certificateFingerprint;
}
});
initialHttp.DefaultRequestVersion = Version.Parse("2.0");
initialHttp.DefaultVersionPolicy = HttpVersionPolicy.RequestVersionOrHigher;
try
{
var response = await initialHttp.GetAsync(url);
if (!response.IsSuccessStatusCode)
{
Console.WriteLine($"Try {tris}/{MAX_TRIES} to get the app port from Rust runtime");
await Task.Delay(wait4Try);
continue;
}
var appPortContent = await response.Content.ReadAsStringAsync();
var appPort = int.Parse(appPortContent);
Console.WriteLine($"Received app port from Rust runtime: '{appPort}'");
return appPort;
}
catch (Exception e)
{
Console.WriteLine($"Error: Was not able to get the app port from Rust runtime: '{e.Message}'");
Console.WriteLine(e.InnerException);
throw;
}
}
Console.WriteLine("Failed to receive the app port from Rust runtime.");
@ -80,10 +133,18 @@ public sealed class RustService(string apiPort) : IDisposable
{
const string URL = "/system/dotnet/ready";
this.logger!.LogInformation("Notifying Rust runtime that the app is ready.");
var response = await this.http.GetAsync(URL);
if (!response.IsSuccessStatusCode)
try
{
this.logger!.LogError($"Failed to notify Rust runtime that the app is ready: '{response.StatusCode}'");
var response = await this.http.GetAsync(URL);
if (!response.IsSuccessStatusCode)
{
this.logger!.LogError($"Failed to notify Rust runtime that the app is ready: '{response.StatusCode}'");
}
}
catch (Exception e)
{
this.logger!.LogError(e, "Failed to notify the Rust runtime that the app is ready.");
throw;
}
}

103
runtime/Cargo.lock generated
View File

@ -964,9 +964,9 @@ dependencies = [
[[package]]
name = "flexi_logger"
version = "0.28.5"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cca927478b3747ba47f98af6ba0ac0daea4f12d12f55e9104071b3dc00276310"
checksum = "a250587a211932896a131f214a4f64c047b826ce072d2018764e5ff5141df8fa"
dependencies = [
"chrono",
"glob",
@ -2145,6 +2145,7 @@ dependencies = [
"pbkdf2",
"rand 0.8.5",
"rand_chacha 0.3.1",
"rcgen",
"reqwest 0.12.4",
"rocket",
"serde",
@ -2709,6 +2710,16 @@ dependencies = [
"syn 2.0.72",
]
[[package]]
name = "pem"
version = "3.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e459365e590736a54c3fa561947c84837534b8e9af6fc5bf781307e82658fae"
dependencies = [
"base64 0.22.1",
"serde",
]
[[package]]
name = "percent-encoding"
version = "2.3.1"
@ -3098,6 +3109,19 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2ff9a1f06a88b01621b7ae906ef0211290d1c8a168a15542486a8f61c0833b9"
[[package]]
name = "rcgen"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54077e1872c46788540de1ea3d7f4ccb1983d12f9aa909b234468676c1a36779"
dependencies = [
"pem",
"ring",
"rustls-pki-types",
"time",
"yasna",
]
[[package]]
name = "redox_syscall"
version = "0.4.1"
@ -3299,6 +3323,21 @@ dependencies = [
"windows 0.37.0",
]
[[package]]
name = "ring"
version = "0.17.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d"
dependencies = [
"cc",
"cfg-if",
"getrandom 0.2.15",
"libc",
"spin",
"untrusted",
"windows-sys 0.52.0",
]
[[package]]
name = "rocket"
version = "0.5.1"
@ -3372,12 +3411,15 @@ dependencies = [
"percent-encoding",
"pin-project-lite",
"ref-cast",
"rustls",
"rustls-pemfile 1.0.4",
"serde",
"smallvec",
"stable-pattern",
"state 0.6.0",
"time",
"tokio",
"tokio-rustls",
"uncased",
]
@ -3409,6 +3451,18 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "rustls"
version = "0.21.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e"
dependencies = [
"log",
"ring",
"rustls-webpki",
"sct",
]
[[package]]
name = "rustls-pemfile"
version = "1.0.4"
@ -3434,6 +3488,16 @@ version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d"
[[package]]
name = "rustls-webpki"
version = "0.101.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765"
dependencies = [
"ring",
"untrusted",
]
[[package]]
name = "rustversion"
version = "1.0.17"
@ -3476,6 +3540,16 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "sct"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414"
dependencies = [
"ring",
"untrusted",
]
[[package]]
name = "security-framework"
version = "2.11.1"
@ -4352,6 +4426,16 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081"
dependencies = [
"rustls",
"tokio",
]
[[package]]
name = "tokio-stream"
version = "0.1.15"
@ -4596,6 +4680,12 @@ version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "229730647fbc343e3a80e463c1db7f78f3855d3f3739bee0dda773c9a037c90a"
[[package]]
name = "untrusted"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]]
name = "url"
version = "2.5.2"
@ -5400,6 +5490,15 @@ dependencies = [
"is-terminal",
]
[[package]]
name = "yasna"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd"
dependencies = [
"time",
]
[[package]]
name = "zip"
version = "0.6.6"

View File

@ -16,10 +16,10 @@ serde_json = "1.0"
keyring = { version = "3.2", features = ["apple-native", "windows-native", "sync-secret-service"] }
arboard = "3.4.0"
tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "macros"] }
flexi_logger = "0.28"
flexi_logger = "0.29"
log = { version = "0.4", features = ["kv"] }
once_cell = "1.19.0"
rocket = { version = "0.5", features = ["json"] }
rocket = { version = "0.5", features = ["json", "tls"] }
rand = "0.8"
rand_chacha = "0.3.1"
base64 = "0.22.1"
@ -29,6 +29,7 @@ cbc = "0.1.2"
pbkdf2 = "0.12.2"
hmac = "0.12.1"
sha2 = "0.10.8"
rcgen = { version = "0.13.1", features = ["pem"] }
[target.'cfg(target_os = "linux")'.dependencies]
# See issue https://github.com/tauri-apps/tauri/issues/4470

View File

@ -28,13 +28,14 @@ use log::{debug, error, info, kv, warn};
use log::kv::{Key, Value, VisitSource};
use pbkdf2::pbkdf2;
use rand::{RngCore, SeedableRng};
use rcgen::generate_simple_self_signed;
use rocket::figment::Figment;
use rocket::{data, get, post, routes, Data, Request};
use rocket::config::Shutdown;
use rocket::config::{Shutdown};
use rocket::data::{Outcome, ToByteUnit};
use rocket::http::Status;
use rocket::serde::json::Json;
use sha2::Sha512;
use sha2::{Sha256, Sha512, Digest};
use tauri::updater::UpdateResponse;
use tokio::io::AsyncReadExt;
@ -153,6 +154,20 @@ async fn main() {
info!("Running in production mode.");
}
info!("Try to generate a TLS certificate for the runtime API server...");
let subject_alt_names = vec!["localhost".to_string()];
let certificate_data = generate_simple_self_signed(subject_alt_names).unwrap();
let certificate_binary_data = certificate_data.cert.der().to_vec();
let certificate_fingerprint = Sha256::digest(certificate_binary_data).to_vec();
let certificate_fingerprint = certificate_fingerprint.iter().fold(String::new(), |mut result, byte| {
result.push_str(&format!("{:02x}", byte));
result
});
let certificate_fingerprint = certificate_fingerprint.to_uppercase();
info!("Certificate fingerprint: '{certificate_fingerprint}'.");
info!("Done generating certificate for the runtime API server.");
let api_port = *API_SERVER_PORT;
info!("Try to start the API server on 'http://localhost:{api_port}'...");
@ -178,6 +193,10 @@ async fn main() {
// No colors and emojis in the log output:
.merge(("cli_colors", false))
// Read the TLS certificate and key from the generated certificate data in-memory:
.merge(("tls.certs", certificate_data.cert.pem().as_bytes()))
.merge(("tls.key", certificate_data.key_pair.serialize_pem().as_bytes()))
// Set the shutdown configuration:
.merge(("shutdown", Shutdown {
@ -234,6 +253,7 @@ async fn main() {
.envs(HashMap::from_iter([
(String::from("AI_STUDIO_SECRET_PASSWORD"), secret_password),
(String::from("AI_STUDIO_SECRET_KEY_SALT"), secret_key_salt),
(String::from("AI_STUDIO_CERTIFICATE_FINGERPRINT"), certificate_fingerprint),
]))
.spawn()
.expect("Failed to spawn .NET server process.")
@ -251,6 +271,7 @@ async fn main() {
.envs(HashMap::from_iter([
(String::from("AI_STUDIO_SECRET_PASSWORD"), secret_password),
(String::from("AI_STUDIO_SECRET_KEY_SALT"), secret_key_salt),
(String::from("AI_STUDIO_CERTIFICATE_FINGERPRINT"), certificate_fingerprint),
]))
.spawn()
.expect("Failed to spawn .NET server process.")