From 1dcfd19f722c4558cae8c4420c9425a4dac553f8 Mon Sep 17 00:00:00 2001 From: PaulKoudelka Date: Tue, 13 Jan 2026 16:38:22 +0100 Subject: [PATCH] Added TLS and API token support for Qdrant communication. --- .github/workflows/build-and-release.yml | 9 ++- app/Build/Commands/Qdrant.cs | 5 +- .../MindWork AI Studio.csproj | 1 + app/MindWork AI Studio/Pages/About.razor | 6 +- app/MindWork AI Studio/Pages/About.razor.cs | 7 ++ app/MindWork AI Studio/Program.cs | 23 +++++- .../Tools/Databases/DatabaseClient.cs | 57 +++++---------- .../Tools/Databases/Qdrant/QdrantClient.cs | 15 ---- .../Qdrant/QdrantClientImplementation.cs | 61 ++++++++++++++++ .../Tools/Rust/QdrantInfo.cs | 2 + .../Tools/Services/RustService.Databases.cs | 4 +- runtime/Cargo.toml | 1 + .../resources/databases/qdrant/config.yaml | 4 +- runtime/src/api_token.rs | 58 +++------------ runtime/src/app_window.rs | 3 +- runtime/src/certificate.rs | 38 ---------- runtime/src/certificate_factory.rs | 22 ++++++ runtime/src/dotnet.rs | 5 +- runtime/src/lib.rs | 6 +- runtime/src/main.rs | 4 +- runtime/src/qdrant.rs | 71 +++++++++++++++++-- runtime/src/runtime_api.rs | 2 +- runtime/src/runtime_api_token.rs | 40 +++++++++++ runtime/src/runtime_certificate.rs | 26 +++++++ 24 files changed, 302 insertions(+), 168 deletions(-) delete mode 100644 app/MindWork AI Studio/Tools/Databases/Qdrant/QdrantClient.cs create mode 100644 app/MindWork AI Studio/Tools/Databases/Qdrant/QdrantClientImplementation.cs delete mode 100644 runtime/src/certificate.rs create mode 100644 runtime/src/certificate_factory.rs create mode 100644 runtime/src/runtime_api_token.rs create mode 100644 runtime/src/runtime_certificate.rs diff --git a/.github/workflows/build-and-release.yml b/.github/workflows/build-and-release.yml index 8cf531b8..cf8daf45 100644 --- a/.github/workflows/build-and-release.yml +++ b/.github/workflows/build-and-release.yml @@ -344,7 +344,7 @@ jobs: echo "Cleaning up ..." rm -fr "$TMP" - - name: Install PDFium (Windows) + - name: Deploy PDFium (Windows) if: matrix.platform == 'windows-latest' env: PDFIUM_VERSION: ${{ env.PDFIUM_VERSION }} @@ -464,7 +464,7 @@ jobs: echo "Cleaning up ..." rm -fr "$TMP" - - name: Install Qdrant (Windows) + - name: Deploy Qdrant (Windows) if: matrix.platform == 'windows-latest' env: QDRANT_VERSION: ${{ env.QDRANT_VERSION }} @@ -479,6 +479,11 @@ jobs: $DB_SOURCE = "qdrant.exe" $DB_TARGET = "qdrant.exe" } + "win-arm64" { + $QDRANT_FILE = "x86_64-pc-windows-msvc.zip" + $DB_SOURCE = "qdrant.exe" + $DB_TARGET = "qdrant.exe"" + } default { Write-Error "Unknown platform: $($env:DOTNET_RUNTIME)" exit 1 diff --git a/app/Build/Commands/Qdrant.cs b/app/Build/Commands/Qdrant.cs index 9a573823..4133332e 100644 --- a/app/Build/Commands/Qdrant.cs +++ b/app/Build/Commands/Qdrant.cs @@ -91,10 +91,11 @@ public static class Qdrant RID.OSX_ARM64 => new("qdrant", "qdrant-aarch64-apple-darwin"), RID.OSX_X64 => new("qdrant", "qdrant-x86_64-apple-darwin"), - RID.LINUX_ARM64 => new("qdrant", "qdrant-aarch64-unknown-linux-gnu"), + RID.LINUX_ARM64 => new("qdrant", "qdrant-aarch64-unknown-linux-musl"), RID.LINUX_X64 => new("qdrant", "qdrant-x86_64-unknown-linux-gnu"), RID.WIN_X64 => new("qdrant.exe", "qdrant-x86_64-pc-windows-msvc.exe"), + RID.WIN_ARM64 => new("qdrant.exe", "qdrant-aarch64-pc-windows-msvc.exe"), _ => new(string.Empty, string.Empty), }; @@ -111,7 +112,7 @@ public static class Qdrant RID.OSX_X64 => $"{baseUrl}x86_64-apple-darwin.tar.gz", RID.WIN_X64 => $"{baseUrl}x86_64-pc-windows-msvc.zip", - #warning We have to handle Qdrant for Windows ARM + RID.WIN_ARM64 => $"{baseUrl}x86_64-pc-windows-msvc.zip", _ => string.Empty, }; diff --git a/app/MindWork AI Studio/MindWork AI Studio.csproj b/app/MindWork AI Studio/MindWork AI Studio.csproj index b4b16cd2..572c3504 100644 --- a/app/MindWork AI Studio/MindWork AI Studio.csproj +++ b/app/MindWork AI Studio/MindWork AI Studio.csproj @@ -52,6 +52,7 @@ + diff --git a/app/MindWork AI Studio/Pages/About.razor b/app/MindWork AI Studio/Pages/About.razor index d9fd748b..b442a833 100644 --- a/app/MindWork AI Studio/Pages/About.razor +++ b/app/MindWork AI Studio/Pages/About.razor @@ -25,12 +25,12 @@ - @foreach (var (Label, Value) in DatabaseClient.GetDisplayInfo()) + @foreach (var (label, value) in DatabaseDisplayInfo) {
- @Label: @Value - + @label: @value +
}
diff --git a/app/MindWork AI Studio/Pages/About.razor.cs b/app/MindWork AI Studio/Pages/About.razor.cs index 9371c610..5f3c2fce 100644 --- a/app/MindWork AI Studio/Pages/About.razor.cs +++ b/app/MindWork AI Studio/Pages/About.razor.cs @@ -70,6 +70,8 @@ public partial class About : MSGComponentBase private bool showDatabaseDetails = false; private IPluginMetadata? configPlug = PluginFactory.AvailablePlugins.FirstOrDefault(x => x.Type is PluginType.CONFIGURATION); + + private List<(string Label, string Value)> DatabaseDisplayInfo = new(); /// /// Determines whether the enterprise configuration has details that can be shown/hidden. @@ -105,6 +107,11 @@ public partial class About : MSGComponentBase this.osLanguage = await this.RustService.ReadUserLanguage(); this.logPaths = await this.RustService.GetLogPaths(); + await foreach (var item in this.DatabaseClient.GetDisplayInfo()) + { + this.DatabaseDisplayInfo.Add(item); + } + // Determine the Pandoc version may take some time, so we start it here // without waiting for the result: _ = this.DeterminePandocVersion(); diff --git a/app/MindWork AI Studio/Program.cs b/app/MindWork AI Studio/Program.cs index 0b63f17a..b67cdcfe 100644 --- a/app/MindWork AI Studio/Program.cs +++ b/app/MindWork AI Studio/Program.cs @@ -27,6 +27,7 @@ internal sealed class Program public static string API_TOKEN = null!; public static IServiceProvider SERVICE_PROVIDER = null!; public static ILoggerFactory LOGGER_FACTORY = null!; + public static DatabaseClient DATABASE_CLIENT = null!; public static async Task Main() { @@ -102,6 +103,20 @@ internal sealed class Program Console.WriteLine("Error: Failed to get the Qdrant gRPC port from Rust."); return; } + + if (qdrantInfo.Fingerprint == string.Empty) + { + Console.WriteLine("Error: Failed to get the Qdrant fingerprint from Rust."); + return; + } + + if (qdrantInfo.ApiToken == string.Empty) + { + Console.WriteLine("Error: Failed to get the Qdrant API token from Rust."); + return; + } + + var databaseClient = new QdrantClientImplementation("Qdrant", qdrantInfo.Path, qdrantInfo.PortHttp, qdrantInfo.PortGrpc, qdrantInfo.Fingerprint, qdrantInfo.ApiToken); var builder = WebApplication.CreateBuilder(); @@ -155,7 +170,7 @@ internal sealed class Program builder.Services.AddHostedService(); builder.Services.AddHostedService(); builder.Services.AddHostedService(); - builder.Services.AddSingleton(new QdrantClient("Qdrant", qdrantInfo.Path, qdrantInfo.PortHttp, qdrantInfo.PortGrpc)); + builder.Services.AddSingleton(databaseClient); // ReSharper disable AccessToDisposedClosure builder.Services.AddHostedService(_ => rust); @@ -211,6 +226,10 @@ internal sealed class Program RUST_SERVICE = rust; ENCRYPTION = encryption; + + var databaseLogger = app.Services.GetRequiredService>(); + databaseClient.SetLogger(databaseLogger); + DATABASE_CLIENT = databaseClient; programLogger.LogInformation("Initialize internal file system."); app.Use(Redirect.HandlerContentAsync); @@ -238,7 +257,6 @@ internal sealed class Program await rust.AppIsReady(); programLogger.LogInformation("The AI Studio server is ready."); - TaskScheduler.UnobservedTaskException += (sender, taskArgs) => { programLogger.LogError(taskArgs.Exception, $"Unobserved task exception by sender '{sender ?? "n/a"}'."); @@ -248,6 +266,7 @@ internal sealed class Program await serverTask; RUST_SERVICE.Dispose(); + DATABASE_CLIENT.Dispose(); PluginFactory.Dispose(); programLogger.LogInformation("The AI Studio server was stopped."); } diff --git a/app/MindWork AI Studio/Tools/Databases/DatabaseClient.cs b/app/MindWork AI Studio/Tools/Databases/DatabaseClient.cs index 0ca84e01..3881a9fc 100644 --- a/app/MindWork AI Studio/Tools/Databases/DatabaseClient.cs +++ b/app/MindWork AI Studio/Tools/Databases/DatabaseClient.cs @@ -1,57 +1,29 @@ namespace AIStudio.Tools.Databases; -public abstract class DatabaseClient +public abstract class DatabaseClient(string name, string path) { - public string Name { get; } - private string Path { get; } - - public DatabaseClient(string name, string path) - { - this.Name = name; - this.Path = path; - } + public string Name => name; + private string Path => path; + protected ILogger? logger; - public abstract IEnumerable<(string Label, string Value)> GetDisplayInfo(); + public abstract IAsyncEnumerable<(string Label, string Value)> GetDisplayInfo(); public string GetStorageSize() { - if (string.IsNullOrEmpty(this.Path)) + if (string.IsNullOrWhiteSpace(this.Path)) { - Console.WriteLine($"Error: Database path '{this.Path}' cannot be null or empty."); + this.logger!.LogError($"Error: Database path '{this.Path}' cannot be null or empty."); return "0 B"; } if (!Directory.Exists(this.Path)) { - Console.WriteLine($"Error: Database path '{this.Path}' does not exist."); + this.logger!.LogError($"Error: Database path '{this.Path}' does not exist."); return "0 B"; } - long size = 0; - var stack = new Stack(); - stack.Push(this.Path); - while (stack.Count > 0) - { - string directory = stack.Pop(); - try - { - var files = Directory.GetFiles(directory); - size += files.Sum(file => new FileInfo(file).Length); - var subDirectories = Directory.GetDirectories(directory); - foreach (var subDirectory in subDirectories) - { - stack.Push(subDirectory); - } - } - catch (UnauthorizedAccessException) - { - Console.WriteLine($"No access to {directory}"); - } - catch (Exception ex) - { - Console.WriteLine($"An error encountered while processing {directory}: "); - Console.WriteLine($"{ ex.Message}"); - } - } + var files = Directory.EnumerateFiles(this.Path, "*", SearchOption.AllDirectories) + .Where(file => !System.IO.Path.GetDirectoryName(file)!.Contains("cert", StringComparison.OrdinalIgnoreCase)); + var size = files.Sum(file => new FileInfo(file).Length); return FormatBytes(size); } @@ -68,4 +40,11 @@ public abstract class DatabaseClient return $"{size:0##} {suffixes[suffixIndex]}"; } + + public void SetLogger(ILogger logService) + { + this.logger = logService; + } + + public abstract void Dispose(); } \ No newline at end of file diff --git a/app/MindWork AI Studio/Tools/Databases/Qdrant/QdrantClient.cs b/app/MindWork AI Studio/Tools/Databases/Qdrant/QdrantClient.cs deleted file mode 100644 index c3a4fabd..00000000 --- a/app/MindWork AI Studio/Tools/Databases/Qdrant/QdrantClient.cs +++ /dev/null @@ -1,15 +0,0 @@ -namespace AIStudio.Tools.Databases.Qdrant; - -public class QdrantClient(string name, string path, int httpPort, int grpcPort) : DatabaseClient(name, path) -{ - private int HttpPort { get; } = httpPort; - private int GrpcPort { get; } = grpcPort; - private string IpAddress { get; } = "127.0.0.1"; - - public override IEnumerable<(string Label, string Value)> GetDisplayInfo() - { - yield return ("HTTP Port", this.HttpPort.ToString()); - yield return ("gRPC Port", this.GrpcPort.ToString()); - yield return ("Storage Size", $"{base.GetStorageSize()}"); - } -} \ No newline at end of file diff --git a/app/MindWork AI Studio/Tools/Databases/Qdrant/QdrantClientImplementation.cs b/app/MindWork AI Studio/Tools/Databases/Qdrant/QdrantClientImplementation.cs new file mode 100644 index 00000000..4ef49dc1 --- /dev/null +++ b/app/MindWork AI Studio/Tools/Databases/Qdrant/QdrantClientImplementation.cs @@ -0,0 +1,61 @@ +using Qdrant.Client; +using Qdrant.Client.Grpc; + +namespace AIStudio.Tools.Databases.Qdrant; + +public class QdrantClientImplementation : DatabaseClient +{ + private int HttpPort { get; } + private int GrpcPort { get; } + private string IpAddress => "localhost"; + private QdrantClient GrpcClient { get; } + private string Fingerprint { get; } + private string ApiToken { get; } + + public QdrantClientImplementation(string name, string path, int httpPort, int grpcPort, string fingerprint, string apiToken): base(name, path) + { + this.HttpPort = httpPort; + this.GrpcPort = grpcPort; + this.Fingerprint = fingerprint; + this.ApiToken = apiToken; + this.GrpcClient = this.CreateQdrantClient(); + } + + public QdrantClient CreateQdrantClient() + { + var address = "https://" + this.IpAddress + ":" + this.GrpcPort; + var channel = QdrantChannel.ForAddress(address, new ClientConfiguration + { + ApiKey = this.ApiToken, + CertificateThumbprint = this.Fingerprint + }); + var grpcClient = new QdrantGrpcClient(channel); + return new QdrantClient(grpcClient); + } + + public async Task GetVersion() + { + var operation = await this.GrpcClient.HealthAsync(); + return "v"+operation.Version; + } + + public async Task GetCollectionsAmount() + { + var operation = await this.GrpcClient.ListCollectionsAsync(); + return operation.Count.ToString(); + } + + public override async IAsyncEnumerable<(string Label, string Value)> GetDisplayInfo() + { + yield return ("HTTP port", this.HttpPort.ToString()); + yield return ("gRPC port", this.GrpcPort.ToString()); + yield return ("Extracted version", await this.GetVersion()); + yield return ("Storage size", $"{base.GetStorageSize()}"); + yield return ("Amount of collections", await this.GetCollectionsAmount()); + } + + public override void Dispose() + { + this.GrpcClient.Dispose(); + } +} \ No newline at end of file diff --git a/app/MindWork AI Studio/Tools/Rust/QdrantInfo.cs b/app/MindWork AI Studio/Tools/Rust/QdrantInfo.cs index 8cbe5e9c..6f9b2e5c 100644 --- a/app/MindWork AI Studio/Tools/Rust/QdrantInfo.cs +++ b/app/MindWork AI Studio/Tools/Rust/QdrantInfo.cs @@ -10,4 +10,6 @@ public record struct QdrantInfo public string Path { get; init; } public int PortHttp { get; init; } public int PortGrpc { get; init; } + public string Fingerprint { get; init; } + public string ApiToken { get; init; } } \ No newline at end of file diff --git a/app/MindWork AI Studio/Tools/Services/RustService.Databases.cs b/app/MindWork AI Studio/Tools/Services/RustService.Databases.cs index ae42316d..a4e0eade 100644 --- a/app/MindWork AI Studio/Tools/Services/RustService.Databases.cs +++ b/app/MindWork AI Studio/Tools/Services/RustService.Databases.cs @@ -9,7 +9,7 @@ public sealed partial class RustService try { var cts = new CancellationTokenSource(TimeSpan.FromSeconds(45)); - var response = await this.http.GetFromJsonAsync("/system/qdrant/port", this.jsonRustSerializerOptions, cts.Token); + var response = await this.http.GetFromJsonAsync("/system/qdrant/info", this.jsonRustSerializerOptions, cts.Token); return response; } catch (Exception e) @@ -20,6 +20,8 @@ public sealed partial class RustService Path = string.Empty, PortHttp = 0, PortGrpc = 0, + Fingerprint = string.Empty, + ApiToken = string.Empty, }; } } diff --git a/runtime/Cargo.toml b/runtime/Cargo.toml index 3eb33086..a518b32d 100644 --- a/runtime/Cargo.toml +++ b/runtime/Cargo.toml @@ -39,6 +39,7 @@ pdfium-render = "0.8.34" sys-locale = "0.3.2" cfg-if = "1.0.1" pptx-to-md = "0.4.0" +tempfile = "3.8" # Fixes security vulnerability downstream, where the upstream is not fixed yet: url = "2.5" diff --git a/runtime/resources/databases/qdrant/config.yaml b/runtime/resources/databases/qdrant/config.yaml index 267f81c2..1149a0ec 100644 --- a/runtime/resources/databases/qdrant/config.yaml +++ b/runtime/resources/databases/qdrant/config.yaml @@ -332,10 +332,10 @@ telemetry_disabled: true # Required if either service.enable_tls or cluster.p2p.enable_tls is true. tls: # Server certificate chain file - cert: ./tls/cert.pem + # cert: ./tls/cert.pem # Server private key file - key: ./tls/key.pem + # key: ./tls/key.pem # Certificate authority certificate file. # This certificate will be used to validate the certificates diff --git a/runtime/src/api_token.rs b/runtime/src/api_token.rs index 31759185..e945095e 100644 --- a/runtime/src/api_token.rs +++ b/runtime/src/api_token.rs @@ -1,21 +1,5 @@ -use log::info; -use once_cell::sync::Lazy; use rand::{RngCore, SeedableRng}; -use rocket::http::Status; -use rocket::Request; -use rocket::request::FromRequest; - -/// The API token used to authenticate requests. -pub static API_TOKEN: Lazy = Lazy::new(|| { - let mut token = [0u8; 32]; - let mut rng = rand_chacha::ChaChaRng::from_os_rng(); - rng.fill_bytes(&mut token); - - let token = APIToken::from_bytes(token.to_vec()); - info!("API token was generated successfully."); - - token -}); +use rand_chacha::ChaChaRng; /// The API token data structure used to authenticate requests. pub struct APIToken { @@ -34,7 +18,7 @@ impl APIToken { } /// Creates a new API token from a hexadecimal text. - fn from_hex_text(hex_text: &str) -> Self { + pub fn from_hex_text(hex_text: &str) -> Self { APIToken { hex_text: hex_text.to_string(), } @@ -45,40 +29,14 @@ impl APIToken { } /// Validates the received token against the valid token. - fn validate(&self, received_token: &Self) -> bool { + pub fn validate(&self, received_token: &Self) -> bool { received_token.to_hex_text() == self.to_hex_text() } } -/// The request outcome type used to handle API token requests. -type RequestOutcome = rocket::request::Outcome; - -/// The request outcome implementation for the API token. -#[rocket::async_trait] -impl<'r> FromRequest<'r> for APIToken { - type Error = APITokenError; - - /// Handles the API token requests. - async fn from_request(request: &'r Request<'_>) -> RequestOutcome { - let token = request.headers().get_one("token"); - match token { - Some(token) => { - let received_token = APIToken::from_hex_text(token); - if API_TOKEN.validate(&received_token) { - RequestOutcome::Success(received_token) - } else { - RequestOutcome::Error((Status::Unauthorized, APITokenError::Invalid)) - } - } - - None => RequestOutcome::Error((Status::Unauthorized, APITokenError::Missing)), - } - } -} - -/// The API token error types. -#[derive(Debug)] -pub enum APITokenError { - Missing, - Invalid, +pub fn generate_api_token() -> APIToken { + let mut token = [0u8; 32]; + let mut rng = ChaChaRng::from_os_rng(); + rng.fill_bytes(&mut token); + APIToken::from_bytes(token.to_vec()) } \ No newline at end of file diff --git a/runtime/src/app_window.rs b/runtime/src/app_window.rs index 7cd97b8b..69661553 100644 --- a/runtime/src/app_window.rs +++ b/runtime/src/app_window.rs @@ -17,7 +17,7 @@ use crate::dotnet::stop_dotnet_server; use crate::environment::{is_prod, is_dev, CONFIG_DIRECTORY, DATA_DIRECTORY}; use crate::log::switch_to_file_logging; use crate::pdfium::PDFIUM_LIB_PATH; -use crate::qdrant::start_qdrant_server; +use crate::qdrant::{start_qdrant_server, stop_qdrant_server}; /// The Tauri main window. static MAIN_WINDOW: Lazy>> = Lazy::new(|| Mutex::new(None)); @@ -174,6 +174,7 @@ pub fn start_tauri() { RunEvent::ExitRequested { .. } => { warn!(Source = "Tauri"; "Run event: exit was requested."); + stop_qdrant_server(); } RunEvent::Ready => { diff --git a/runtime/src/certificate.rs b/runtime/src/certificate.rs deleted file mode 100644 index 8cf7fb38..00000000 --- a/runtime/src/certificate.rs +++ /dev/null @@ -1,38 +0,0 @@ -use std::sync::OnceLock; -use log::info; -use rcgen::generate_simple_self_signed; -use sha2::{Sha256, Digest}; - -/// The certificate used for the runtime API server. -pub static CERTIFICATE: OnceLock> = OnceLock::new(); - -/// The private key used for the certificate of the runtime API server. -pub static CERTIFICATE_PRIVATE_KEY: OnceLock> = OnceLock::new(); - -/// The fingerprint of the certificate used for the runtime API server. -pub static CERTIFICATE_FINGERPRINT: OnceLock = OnceLock::new(); - -/// Generates a TLS certificate for the runtime API server. -pub fn generate_certificate() { - - info!("Try to generate a TLS certificate for the runtime API server..."); - - let subject_alt_names = vec!["localhost".to_string()]; - let certificate_data = generate_simple_self_signed(subject_alt_names).unwrap(); - let certificate_binary_data = certificate_data.cert.der().to_vec(); - - let certificate_fingerprint = Sha256::digest(certificate_binary_data).to_vec(); - let certificate_fingerprint = certificate_fingerprint.iter().fold(String::new(), |mut result, byte| { - result.push_str(&format!("{:02x}", byte)); - result - }); - - let certificate_fingerprint = certificate_fingerprint.to_uppercase(); - - CERTIFICATE_FINGERPRINT.set(certificate_fingerprint.clone()).expect("Could not set the certificate fingerprint."); - CERTIFICATE.set(certificate_data.cert.pem().as_bytes().to_vec()).expect("Could not set the certificate."); - CERTIFICATE_PRIVATE_KEY.set(certificate_data.signing_key.serialize_pem().as_bytes().to_vec()).expect("Could not set the private key."); - - info!("Certificate fingerprint: '{certificate_fingerprint}'."); - info!("Done generating certificate for the runtime API server."); -} \ No newline at end of file diff --git a/runtime/src/certificate_factory.rs b/runtime/src/certificate_factory.rs new file mode 100644 index 00000000..3c30d34a --- /dev/null +++ b/runtime/src/certificate_factory.rs @@ -0,0 +1,22 @@ +use log::info; +use rcgen::generate_simple_self_signed; +use sha2::{Sha256, Digest}; + +pub fn generate_certificate() -> (Vec, Vec, String) { + + let subject_alt_names = vec!["localhost".to_string()]; + let certificate_data = generate_simple_self_signed(subject_alt_names).unwrap(); + let certificate_binary_data = certificate_data.cert.der().to_vec(); + + let certificate_fingerprint = Sha256::digest(certificate_binary_data).to_vec(); + let certificate_fingerprint = certificate_fingerprint.iter().fold(String::new(), |mut result, byte| { + result.push_str(&format!("{:02x}", byte)); + result + }); + + let certificate_fingerprint = certificate_fingerprint.to_uppercase(); + + info!("Certificate fingerprint: '{certificate_fingerprint}'."); + + (certificate_data.cert.pem().as_bytes().to_vec(), certificate_data.signing_key.serialize_pem().as_bytes().to_vec(), certificate_fingerprint.clone()) +} \ No newline at end of file diff --git a/runtime/src/dotnet.rs b/runtime/src/dotnet.rs index 26b793f5..fb792a15 100644 --- a/runtime/src/dotnet.rs +++ b/runtime/src/dotnet.rs @@ -7,9 +7,10 @@ use once_cell::sync::Lazy; use rocket::get; use tauri::api::process::{Command, CommandChild, CommandEvent}; use tauri::Url; -use crate::api_token::{APIToken, API_TOKEN}; +use crate::api_token::APIToken; +use crate::runtime_api_token::API_TOKEN; use crate::app_window::change_location_to; -use crate::certificate::CERTIFICATE_FINGERPRINT; +use crate::runtime_certificate::CERTIFICATE_FINGERPRINT; use crate::encryption::ENCRYPTION; use crate::environment::is_dev; use crate::network::get_available_port; diff --git a/runtime/src/lib.rs b/runtime/src/lib.rs index bd7da307..e99f528b 100644 --- a/runtime/src/lib.rs +++ b/runtime/src/lib.rs @@ -8,9 +8,11 @@ pub mod app_window; pub mod secret; pub mod clipboard; pub mod runtime_api; -pub mod certificate; +pub mod runtime_certificate; pub mod file_data; pub mod metadata; pub mod pdfium; pub mod pandoc; -pub mod qdrant; \ No newline at end of file +pub mod qdrant; +pub mod certificate_factory; +pub mod runtime_api_token; \ No newline at end of file diff --git a/runtime/src/main.rs b/runtime/src/main.rs index bfbe4750..91427472 100644 --- a/runtime/src/main.rs +++ b/runtime/src/main.rs @@ -6,7 +6,7 @@ extern crate core; use log::{info, warn}; use mindwork_ai_studio::app_window::start_tauri; -use mindwork_ai_studio::certificate::{generate_certificate}; +use mindwork_ai_studio::runtime_certificate::{generate_runtime_certificate}; use mindwork_ai_studio::dotnet::start_dotnet_server; use mindwork_ai_studio::environment::is_dev; use mindwork_ai_studio::log::init_logging; @@ -46,7 +46,7 @@ async fn main() { info!("Running in production mode."); } - generate_certificate(); + generate_runtime_certificate(); start_runtime_api(); if is_dev() { diff --git a/runtime/src/qdrant.rs b/runtime/src/qdrant.rs index 3b2b94ce..4a945d5d 100644 --- a/runtime/src/qdrant.rs +++ b/runtime/src/qdrant.rs @@ -1,6 +1,8 @@ use std::collections::HashMap; +use std::fs::File; +use std::io::Write; use std::path::Path; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, OnceLock}; use log::{debug, error, info, warn}; use once_cell::sync::Lazy; use rocket::get; @@ -9,6 +11,9 @@ use rocket::serde::Serialize; use tauri::api::process::{Command, CommandChild, CommandEvent}; use crate::api_token::{APIToken}; use crate::environment::DATA_DIRECTORY; +use crate::certificate_factory::generate_certificate; +use std::path::PathBuf; +use tempfile::{TempDir, Builder}; // Qdrant server process started in a separate process and can communicate // via HTTP or gRPC with the .NET server and the runtime process @@ -23,26 +28,38 @@ static QDRANT_SERVER_PORT_GRPC: Lazy = Lazy::new(|| { crate::network::get_available_port().unwrap_or(6334) }); +pub static CERTIFICATE_FINGERPRINT: OnceLock = OnceLock::new(); +static API_TOKEN: Lazy = Lazy::new(|| { + crate::api_token::generate_api_token() +}); + +static TMPDIR: Lazy>> = Lazy::new(|| Mutex::new(None)); + #[derive(Serialize)] pub struct ProvideQdrantInfo { path: String, port_http: u16, port_grpc: u16, + fingerprint: String, + api_token: String, } -#[get("/system/qdrant/port")] +#[get("/system/qdrant/info")] pub fn qdrant_port(_token: APIToken) -> Json { return Json(ProvideQdrantInfo { path: Path::new(DATA_DIRECTORY.get().unwrap()).join("databases").join("qdrant").to_str().unwrap().to_string(), port_http: *QDRANT_SERVER_PORT_HTTP, port_grpc: *QDRANT_SERVER_PORT_GRPC, + fingerprint: CERTIFICATE_FINGERPRINT.get().expect("Certificate fingerprint not available").to_string(), + api_token: API_TOKEN.to_hex_text().to_string(), }); } /// Starts the Qdrant server in a separate process. pub fn start_qdrant_server() { - let base_path = DATA_DIRECTORY.get().unwrap(); + let base_path = DATA_DIRECTORY.get().unwrap(); + let (cert_path, key_path) =create_temp_tls_files(Path::new(base_path).join("databases").join("qdrant")).unwrap(); let storage_path = Path::new(base_path).join("databases").join("qdrant").join("storage").to_str().unwrap().to_string(); let snapshot_path = Path::new(base_path).join("databases").join("qdrant").join("snapshots").to_str().unwrap().to_string(); @@ -54,6 +71,10 @@ pub fn start_qdrant_server() { (String::from("QDRANT_INIT_FILE_PATH"), init_path), (String::from("QDRANT__STORAGE__STORAGE_PATH"), storage_path), (String::from("QDRANT__STORAGE__SNAPSHOTS_PATH"), snapshot_path), + (String::from("QDRANT__TLS__CERT"), cert_path.to_str().unwrap().to_string()), + (String::from("QDRANT__TLS__KEY"), key_path.to_str().unwrap().to_string()), + (String::from("QDRANT__SERVICE__ENABLE_TLS"), "true".to_string()), + (String::from("QDRANT__SERVICE__API_KEY"), API_TOKEN.to_hex_text().to_string()), ]); let server_spawn_clone = QDRANT_SERVER.clone(); @@ -97,13 +118,51 @@ pub fn start_qdrant_server() { /// Stops the Qdrant server process. pub fn stop_qdrant_server() { + drop_tmpdir(); if let Some(server_process) = QDRANT_SERVER.lock().unwrap().take() { let server_kill_result = server_process.kill(); match server_kill_result { - Ok(_) => info!("Qdrant server process was stopped."), - Err(e) => error!("Failed to stop Qdrant server process: {e}."), + Ok(_) => warn!(Source = "Qdrant"; "Qdrant server process was stopped."), + Err(e) => error!(Source = "Qdrant"; "Failed to stop Qdrant server process: {e}."), } } else { - warn!("Qdrant server process was not started or is already stopped."); + warn!(Source = "Qdrant"; "Qdrant server process was not started or is already stopped."); } +} + +pub fn create_temp_tls_files(path: PathBuf) -> Result<(PathBuf, PathBuf), Box> { + let (certificate, cert_private_key, cert_fingerprint) = generate_certificate(); + + let temp_dir = init_tmpdir_in(path); + + let cert_path = temp_dir.join("cert.pem"); + let key_path = temp_dir.join("key.pem"); + + let mut cert_file = File::create(&cert_path)?; + cert_file.write_all(&*certificate)?; + + let mut key_file = File::create(&key_path)?; + key_file.write_all(&*cert_private_key)?; + + CERTIFICATE_FINGERPRINT.set(cert_fingerprint).expect("Could not set the certificate fingerprint."); + + Ok((cert_path, key_path)) +} + +pub fn init_tmpdir_in>(path: P) -> PathBuf { + let mut guard = TMPDIR.lock().unwrap(); + let dir = guard.get_or_insert_with(|| { + Builder::new() + .prefix("cert-") + .tempdir_in(path) + .expect("failed to create tempdir") + }); + + dir.path().to_path_buf() +} + +pub fn drop_tmpdir() { + let mut guard = TMPDIR.lock().unwrap(); + *guard = None; + warn!(Source = "Qdrant"; "Temporary directory for TLS was dropped."); } \ No newline at end of file diff --git a/runtime/src/runtime_api.rs b/runtime/src/runtime_api.rs index 529d9636..d08c5abe 100644 --- a/runtime/src/runtime_api.rs +++ b/runtime/src/runtime_api.rs @@ -3,7 +3,7 @@ use once_cell::sync::Lazy; use rocket::config::Shutdown; use rocket::figment::Figment; use rocket::routes; -use crate::certificate::{CERTIFICATE, CERTIFICATE_PRIVATE_KEY}; +use crate::runtime_certificate::{CERTIFICATE, CERTIFICATE_PRIVATE_KEY}; use crate::environment::is_dev; use crate::network::get_available_port; diff --git a/runtime/src/runtime_api_token.rs b/runtime/src/runtime_api_token.rs new file mode 100644 index 00000000..f1e762c9 --- /dev/null +++ b/runtime/src/runtime_api_token.rs @@ -0,0 +1,40 @@ +use once_cell::sync::Lazy; +use rocket::http::Status; +use rocket::Request; +use rocket::request::FromRequest; +use crate::api_token::{generate_api_token, APIToken}; + +pub static API_TOKEN: Lazy = Lazy::new(|| generate_api_token()); + +/// The request outcome type used to handle API token requests. +type RequestOutcome = rocket::request::Outcome; + +/// The request outcome implementation for the API token. +#[rocket::async_trait] +impl<'r> FromRequest<'r> for APIToken { + type Error = APITokenError; + + /// Handles the API token requests. + async fn from_request(request: &'r Request<'_>) -> RequestOutcome { + let token = request.headers().get_one("token"); + match token { + Some(token) => { + let received_token = APIToken::from_hex_text(token); + if API_TOKEN.validate(&received_token) { + RequestOutcome::Success(received_token) + } else { + RequestOutcome::Error((Status::Unauthorized, APITokenError::Invalid)) + } + } + + None => RequestOutcome::Error((Status::Unauthorized, APITokenError::Missing)), + } + } +} + +/// The API token error types. +#[derive(Debug)] +pub enum APITokenError { + Missing, + Invalid, +} \ No newline at end of file diff --git a/runtime/src/runtime_certificate.rs b/runtime/src/runtime_certificate.rs new file mode 100644 index 00000000..abbde65c --- /dev/null +++ b/runtime/src/runtime_certificate.rs @@ -0,0 +1,26 @@ +use std::sync::OnceLock; +use log::info; +use crate::certificate_factory::generate_certificate; + +/// The certificate used for the runtime API server. +pub static CERTIFICATE: OnceLock> = OnceLock::new(); + +/// The private key used for the certificate of the runtime API server. +pub static CERTIFICATE_PRIVATE_KEY: OnceLock> = OnceLock::new(); + +/// The fingerprint of the certificate used for the runtime API server. +pub static CERTIFICATE_FINGERPRINT: OnceLock = OnceLock::new(); + +/// Generates a TLS certificate for the runtime API server. +pub fn generate_runtime_certificate() { + + info!("Try to generate a TLS certificate for the runtime API server..."); + + let (certificate, cer_private_key, cer_fingerprint) = generate_certificate(); + + CERTIFICATE_FINGERPRINT.set(cer_fingerprint).expect("Could not set the certificate fingerprint."); + CERTIFICATE.set(certificate).expect("Could not set the certificate."); + CERTIFICATE_PRIVATE_KEY.set(cer_private_key).expect("Could not set the private key."); + + info!("Done generating certificate for the runtime API server."); +} \ No newline at end of file