Added TLS and API token support for Qdrant communication.

This commit is contained in:
PaulKoudelka 2026-01-13 16:38:22 +01:00
parent 8400422044
commit 1dcfd19f72
24 changed files with 302 additions and 168 deletions

View File

@ -344,7 +344,7 @@ jobs:
echo "Cleaning up ..."
rm -fr "$TMP"
- name: Install PDFium (Windows)
- name: Deploy PDFium (Windows)
if: matrix.platform == 'windows-latest'
env:
PDFIUM_VERSION: ${{ env.PDFIUM_VERSION }}
@ -464,7 +464,7 @@ jobs:
echo "Cleaning up ..."
rm -fr "$TMP"
- name: Install Qdrant (Windows)
- name: Deploy Qdrant (Windows)
if: matrix.platform == 'windows-latest'
env:
QDRANT_VERSION: ${{ env.QDRANT_VERSION }}
@ -479,6 +479,11 @@ jobs:
$DB_SOURCE = "qdrant.exe"
$DB_TARGET = "qdrant.exe"
}
"win-arm64" {
$QDRANT_FILE = "x86_64-pc-windows-msvc.zip"
$DB_SOURCE = "qdrant.exe"
$DB_TARGET = "qdrant.exe""
}
default {
Write-Error "Unknown platform: $($env:DOTNET_RUNTIME)"
exit 1

View File

@ -91,10 +91,11 @@ public static class Qdrant
RID.OSX_ARM64 => new("qdrant", "qdrant-aarch64-apple-darwin"),
RID.OSX_X64 => new("qdrant", "qdrant-x86_64-apple-darwin"),
RID.LINUX_ARM64 => new("qdrant", "qdrant-aarch64-unknown-linux-gnu"),
RID.LINUX_ARM64 => new("qdrant", "qdrant-aarch64-unknown-linux-musl"),
RID.LINUX_X64 => new("qdrant", "qdrant-x86_64-unknown-linux-gnu"),
RID.WIN_X64 => new("qdrant.exe", "qdrant-x86_64-pc-windows-msvc.exe"),
RID.WIN_ARM64 => new("qdrant.exe", "qdrant-aarch64-pc-windows-msvc.exe"),
_ => new(string.Empty, string.Empty),
};
@ -111,7 +112,7 @@ public static class Qdrant
RID.OSX_X64 => $"{baseUrl}x86_64-apple-darwin.tar.gz",
RID.WIN_X64 => $"{baseUrl}x86_64-pc-windows-msvc.zip",
#warning We have to handle Qdrant for Windows ARM
RID.WIN_ARM64 => $"{baseUrl}x86_64-pc-windows-msvc.zip",
_ => string.Empty,
};

View File

@ -52,6 +52,7 @@
<PackageReference Include="Microsoft.Extensions.FileProviders.Embedded" Version="9.0.11" />
<PackageReference Include="MudBlazor" Version="8.12.0" />
<PackageReference Include="MudBlazor.Markdown" Version="8.11.0" />
<PackageReference Include="Qdrant.Client" Version="1.16.1" />
<PackageReference Include="ReverseMarkdown" Version="4.7.1" />
<PackageReference Include="LuaCSharp" Version="0.4.2" />
</ItemGroup>

View File

@ -25,12 +25,12 @@
</MudText>
<MudCollapse Expanded="@showDatabaseDetails">
<MudText Typo="Typo.body1" Class="mt-2 mb-2">
@foreach (var (Label, Value) in DatabaseClient.GetDisplayInfo())
@foreach (var (label, value) in DatabaseDisplayInfo)
{
<div style="display: flex; align-items: center; gap: 8px;">
<MudIcon Icon="@Icons.Material.Filled.ArrowRightAlt"/>
<span>@Label: @Value</span>
<MudCopyClipboardButton TooltipMessage="@(T("Copies the following to the clipboard")+": "+Value)" StringContent=@Value/>
<span>@label: @value</span>
<MudCopyClipboardButton TooltipMessage="@(T("Copies the following to the clipboard")+": "+value)" StringContent=@value/>
</div>
}
</MudText>

View File

@ -70,6 +70,8 @@ public partial class About : MSGComponentBase
private bool showDatabaseDetails = false;
private IPluginMetadata? configPlug = PluginFactory.AvailablePlugins.FirstOrDefault(x => x.Type is PluginType.CONFIGURATION);
private List<(string Label, string Value)> DatabaseDisplayInfo = new();
/// <summary>
/// Determines whether the enterprise configuration has details that can be shown/hidden.
@ -105,6 +107,11 @@ public partial class About : MSGComponentBase
this.osLanguage = await this.RustService.ReadUserLanguage();
this.logPaths = await this.RustService.GetLogPaths();
await foreach (var item in this.DatabaseClient.GetDisplayInfo())
{
this.DatabaseDisplayInfo.Add(item);
}
// Determine the Pandoc version may take some time, so we start it here
// without waiting for the result:
_ = this.DeterminePandocVersion();

View File

@ -27,6 +27,7 @@ internal sealed class Program
public static string API_TOKEN = null!;
public static IServiceProvider SERVICE_PROVIDER = null!;
public static ILoggerFactory LOGGER_FACTORY = null!;
public static DatabaseClient DATABASE_CLIENT = null!;
public static async Task Main()
{
@ -102,6 +103,20 @@ internal sealed class Program
Console.WriteLine("Error: Failed to get the Qdrant gRPC port from Rust.");
return;
}
if (qdrantInfo.Fingerprint == string.Empty)
{
Console.WriteLine("Error: Failed to get the Qdrant fingerprint from Rust.");
return;
}
if (qdrantInfo.ApiToken == string.Empty)
{
Console.WriteLine("Error: Failed to get the Qdrant API token from Rust.");
return;
}
var databaseClient = new QdrantClientImplementation("Qdrant", qdrantInfo.Path, qdrantInfo.PortHttp, qdrantInfo.PortGrpc, qdrantInfo.Fingerprint, qdrantInfo.ApiToken);
var builder = WebApplication.CreateBuilder();
@ -155,7 +170,7 @@ internal sealed class Program
builder.Services.AddHostedService<UpdateService>();
builder.Services.AddHostedService<TemporaryChatService>();
builder.Services.AddHostedService<EnterpriseEnvironmentService>();
builder.Services.AddSingleton<DatabaseClient>(new QdrantClient("Qdrant", qdrantInfo.Path, qdrantInfo.PortHttp, qdrantInfo.PortGrpc));
builder.Services.AddSingleton<DatabaseClient>(databaseClient);
// ReSharper disable AccessToDisposedClosure
builder.Services.AddHostedService<RustService>(_ => rust);
@ -211,6 +226,10 @@ internal sealed class Program
RUST_SERVICE = rust;
ENCRYPTION = encryption;
var databaseLogger = app.Services.GetRequiredService<ILogger<DatabaseClient>>();
databaseClient.SetLogger(databaseLogger);
DATABASE_CLIENT = databaseClient;
programLogger.LogInformation("Initialize internal file system.");
app.Use(Redirect.HandlerContentAsync);
@ -238,7 +257,6 @@ internal sealed class Program
await rust.AppIsReady();
programLogger.LogInformation("The AI Studio server is ready.");
TaskScheduler.UnobservedTaskException += (sender, taskArgs) =>
{
programLogger.LogError(taskArgs.Exception, $"Unobserved task exception by sender '{sender ?? "n/a"}'.");
@ -248,6 +266,7 @@ internal sealed class Program
await serverTask;
RUST_SERVICE.Dispose();
DATABASE_CLIENT.Dispose();
PluginFactory.Dispose();
programLogger.LogInformation("The AI Studio server was stopped.");
}

View File

@ -1,57 +1,29 @@
namespace AIStudio.Tools.Databases;
public abstract class DatabaseClient
public abstract class DatabaseClient(string name, string path)
{
public string Name { get; }
private string Path { get; }
public DatabaseClient(string name, string path)
{
this.Name = name;
this.Path = path;
}
public string Name => name;
private string Path => path;
protected ILogger<DatabaseClient>? logger;
public abstract IEnumerable<(string Label, string Value)> GetDisplayInfo();
public abstract IAsyncEnumerable<(string Label, string Value)> GetDisplayInfo();
public string GetStorageSize()
{
if (string.IsNullOrEmpty(this.Path))
if (string.IsNullOrWhiteSpace(this.Path))
{
Console.WriteLine($"Error: Database path '{this.Path}' cannot be null or empty.");
this.logger!.LogError($"Error: Database path '{this.Path}' cannot be null or empty.");
return "0 B";
}
if (!Directory.Exists(this.Path))
{
Console.WriteLine($"Error: Database path '{this.Path}' does not exist.");
this.logger!.LogError($"Error: Database path '{this.Path}' does not exist.");
return "0 B";
}
long size = 0;
var stack = new Stack<string>();
stack.Push(this.Path);
while (stack.Count > 0)
{
string directory = stack.Pop();
try
{
var files = Directory.GetFiles(directory);
size += files.Sum(file => new FileInfo(file).Length);
var subDirectories = Directory.GetDirectories(directory);
foreach (var subDirectory in subDirectories)
{
stack.Push(subDirectory);
}
}
catch (UnauthorizedAccessException)
{
Console.WriteLine($"No access to {directory}");
}
catch (Exception ex)
{
Console.WriteLine($"An error encountered while processing {directory}: ");
Console.WriteLine($"{ ex.Message}");
}
}
var files = Directory.EnumerateFiles(this.Path, "*", SearchOption.AllDirectories)
.Where(file => !System.IO.Path.GetDirectoryName(file)!.Contains("cert", StringComparison.OrdinalIgnoreCase));
var size = files.Sum(file => new FileInfo(file).Length);
return FormatBytes(size);
}
@ -68,4 +40,11 @@ public abstract class DatabaseClient
return $"{size:0##} {suffixes[suffixIndex]}";
}
public void SetLogger(ILogger<DatabaseClient> logService)
{
this.logger = logService;
}
public abstract void Dispose();
}

View File

@ -1,15 +0,0 @@
namespace AIStudio.Tools.Databases.Qdrant;
public class QdrantClient(string name, string path, int httpPort, int grpcPort) : DatabaseClient(name, path)
{
private int HttpPort { get; } = httpPort;
private int GrpcPort { get; } = grpcPort;
private string IpAddress { get; } = "127.0.0.1";
public override IEnumerable<(string Label, string Value)> GetDisplayInfo()
{
yield return ("HTTP Port", this.HttpPort.ToString());
yield return ("gRPC Port", this.GrpcPort.ToString());
yield return ("Storage Size", $"{base.GetStorageSize()}");
}
}

View File

@ -0,0 +1,61 @@
using Qdrant.Client;
using Qdrant.Client.Grpc;
namespace AIStudio.Tools.Databases.Qdrant;
public class QdrantClientImplementation : DatabaseClient
{
private int HttpPort { get; }
private int GrpcPort { get; }
private string IpAddress => "localhost";
private QdrantClient GrpcClient { get; }
private string Fingerprint { get; }
private string ApiToken { get; }
public QdrantClientImplementation(string name, string path, int httpPort, int grpcPort, string fingerprint, string apiToken): base(name, path)
{
this.HttpPort = httpPort;
this.GrpcPort = grpcPort;
this.Fingerprint = fingerprint;
this.ApiToken = apiToken;
this.GrpcClient = this.CreateQdrantClient();
}
public QdrantClient CreateQdrantClient()
{
var address = "https://" + this.IpAddress + ":" + this.GrpcPort;
var channel = QdrantChannel.ForAddress(address, new ClientConfiguration
{
ApiKey = this.ApiToken,
CertificateThumbprint = this.Fingerprint
});
var grpcClient = new QdrantGrpcClient(channel);
return new QdrantClient(grpcClient);
}
public async Task<string> GetVersion()
{
var operation = await this.GrpcClient.HealthAsync();
return "v"+operation.Version;
}
public async Task<string> GetCollectionsAmount()
{
var operation = await this.GrpcClient.ListCollectionsAsync();
return operation.Count.ToString();
}
public override async IAsyncEnumerable<(string Label, string Value)> GetDisplayInfo()
{
yield return ("HTTP port", this.HttpPort.ToString());
yield return ("gRPC port", this.GrpcPort.ToString());
yield return ("Extracted version", await this.GetVersion());
yield return ("Storage size", $"{base.GetStorageSize()}");
yield return ("Amount of collections", await this.GetCollectionsAmount());
}
public override void Dispose()
{
this.GrpcClient.Dispose();
}
}

View File

@ -10,4 +10,6 @@ public record struct QdrantInfo
public string Path { get; init; }
public int PortHttp { get; init; }
public int PortGrpc { get; init; }
public string Fingerprint { get; init; }
public string ApiToken { get; init; }
}

View File

@ -9,7 +9,7 @@ public sealed partial class RustService
try
{
var cts = new CancellationTokenSource(TimeSpan.FromSeconds(45));
var response = await this.http.GetFromJsonAsync<QdrantInfo>("/system/qdrant/port", this.jsonRustSerializerOptions, cts.Token);
var response = await this.http.GetFromJsonAsync<QdrantInfo>("/system/qdrant/info", this.jsonRustSerializerOptions, cts.Token);
return response;
}
catch (Exception e)
@ -20,6 +20,8 @@ public sealed partial class RustService
Path = string.Empty,
PortHttp = 0,
PortGrpc = 0,
Fingerprint = string.Empty,
ApiToken = string.Empty,
};
}
}

View File

@ -39,6 +39,7 @@ pdfium-render = "0.8.34"
sys-locale = "0.3.2"
cfg-if = "1.0.1"
pptx-to-md = "0.4.0"
tempfile = "3.8"
# Fixes security vulnerability downstream, where the upstream is not fixed yet:
url = "2.5"

View File

@ -332,10 +332,10 @@ telemetry_disabled: true
# Required if either service.enable_tls or cluster.p2p.enable_tls is true.
tls:
# Server certificate chain file
cert: ./tls/cert.pem
# cert: ./tls/cert.pem
# Server private key file
key: ./tls/key.pem
# key: ./tls/key.pem
# Certificate authority certificate file.
# This certificate will be used to validate the certificates

View File

@ -1,21 +1,5 @@
use log::info;
use once_cell::sync::Lazy;
use rand::{RngCore, SeedableRng};
use rocket::http::Status;
use rocket::Request;
use rocket::request::FromRequest;
/// The API token used to authenticate requests.
pub static API_TOKEN: Lazy<APIToken> = Lazy::new(|| {
let mut token = [0u8; 32];
let mut rng = rand_chacha::ChaChaRng::from_os_rng();
rng.fill_bytes(&mut token);
let token = APIToken::from_bytes(token.to_vec());
info!("API token was generated successfully.");
token
});
use rand_chacha::ChaChaRng;
/// The API token data structure used to authenticate requests.
pub struct APIToken {
@ -34,7 +18,7 @@ impl APIToken {
}
/// Creates a new API token from a hexadecimal text.
fn from_hex_text(hex_text: &str) -> Self {
pub fn from_hex_text(hex_text: &str) -> Self {
APIToken {
hex_text: hex_text.to_string(),
}
@ -45,40 +29,14 @@ impl APIToken {
}
/// Validates the received token against the valid token.
fn validate(&self, received_token: &Self) -> bool {
pub fn validate(&self, received_token: &Self) -> bool {
received_token.to_hex_text() == self.to_hex_text()
}
}
/// The request outcome type used to handle API token requests.
type RequestOutcome<R, T> = rocket::request::Outcome<R, T>;
/// The request outcome implementation for the API token.
#[rocket::async_trait]
impl<'r> FromRequest<'r> for APIToken {
type Error = APITokenError;
/// Handles the API token requests.
async fn from_request(request: &'r Request<'_>) -> RequestOutcome<Self, Self::Error> {
let token = request.headers().get_one("token");
match token {
Some(token) => {
let received_token = APIToken::from_hex_text(token);
if API_TOKEN.validate(&received_token) {
RequestOutcome::Success(received_token)
} else {
RequestOutcome::Error((Status::Unauthorized, APITokenError::Invalid))
}
}
None => RequestOutcome::Error((Status::Unauthorized, APITokenError::Missing)),
}
}
}
/// The API token error types.
#[derive(Debug)]
pub enum APITokenError {
Missing,
Invalid,
pub fn generate_api_token() -> APIToken {
let mut token = [0u8; 32];
let mut rng = ChaChaRng::from_os_rng();
rng.fill_bytes(&mut token);
APIToken::from_bytes(token.to_vec())
}

View File

@ -17,7 +17,7 @@ use crate::dotnet::stop_dotnet_server;
use crate::environment::{is_prod, is_dev, CONFIG_DIRECTORY, DATA_DIRECTORY};
use crate::log::switch_to_file_logging;
use crate::pdfium::PDFIUM_LIB_PATH;
use crate::qdrant::start_qdrant_server;
use crate::qdrant::{start_qdrant_server, stop_qdrant_server};
/// The Tauri main window.
static MAIN_WINDOW: Lazy<Mutex<Option<Window>>> = Lazy::new(|| Mutex::new(None));
@ -174,6 +174,7 @@ pub fn start_tauri() {
RunEvent::ExitRequested { .. } => {
warn!(Source = "Tauri"; "Run event: exit was requested.");
stop_qdrant_server();
}
RunEvent::Ready => {

View File

@ -1,38 +0,0 @@
use std::sync::OnceLock;
use log::info;
use rcgen::generate_simple_self_signed;
use sha2::{Sha256, Digest};
/// The certificate used for the runtime API server.
pub static CERTIFICATE: OnceLock<Vec<u8>> = OnceLock::new();
/// The private key used for the certificate of the runtime API server.
pub static CERTIFICATE_PRIVATE_KEY: OnceLock<Vec<u8>> = OnceLock::new();
/// The fingerprint of the certificate used for the runtime API server.
pub static CERTIFICATE_FINGERPRINT: OnceLock<String> = OnceLock::new();
/// Generates a TLS certificate for the runtime API server.
pub fn generate_certificate() {
info!("Try to generate a TLS certificate for the runtime API server...");
let subject_alt_names = vec!["localhost".to_string()];
let certificate_data = generate_simple_self_signed(subject_alt_names).unwrap();
let certificate_binary_data = certificate_data.cert.der().to_vec();
let certificate_fingerprint = Sha256::digest(certificate_binary_data).to_vec();
let certificate_fingerprint = certificate_fingerprint.iter().fold(String::new(), |mut result, byte| {
result.push_str(&format!("{:02x}", byte));
result
});
let certificate_fingerprint = certificate_fingerprint.to_uppercase();
CERTIFICATE_FINGERPRINT.set(certificate_fingerprint.clone()).expect("Could not set the certificate fingerprint.");
CERTIFICATE.set(certificate_data.cert.pem().as_bytes().to_vec()).expect("Could not set the certificate.");
CERTIFICATE_PRIVATE_KEY.set(certificate_data.signing_key.serialize_pem().as_bytes().to_vec()).expect("Could not set the private key.");
info!("Certificate fingerprint: '{certificate_fingerprint}'.");
info!("Done generating certificate for the runtime API server.");
}

View File

@ -0,0 +1,22 @@
use log::info;
use rcgen::generate_simple_self_signed;
use sha2::{Sha256, Digest};
pub fn generate_certificate() -> (Vec<u8>, Vec<u8>, String) {
let subject_alt_names = vec!["localhost".to_string()];
let certificate_data = generate_simple_self_signed(subject_alt_names).unwrap();
let certificate_binary_data = certificate_data.cert.der().to_vec();
let certificate_fingerprint = Sha256::digest(certificate_binary_data).to_vec();
let certificate_fingerprint = certificate_fingerprint.iter().fold(String::new(), |mut result, byte| {
result.push_str(&format!("{:02x}", byte));
result
});
let certificate_fingerprint = certificate_fingerprint.to_uppercase();
info!("Certificate fingerprint: '{certificate_fingerprint}'.");
(certificate_data.cert.pem().as_bytes().to_vec(), certificate_data.signing_key.serialize_pem().as_bytes().to_vec(), certificate_fingerprint.clone())
}

View File

@ -7,9 +7,10 @@ use once_cell::sync::Lazy;
use rocket::get;
use tauri::api::process::{Command, CommandChild, CommandEvent};
use tauri::Url;
use crate::api_token::{APIToken, API_TOKEN};
use crate::api_token::APIToken;
use crate::runtime_api_token::API_TOKEN;
use crate::app_window::change_location_to;
use crate::certificate::CERTIFICATE_FINGERPRINT;
use crate::runtime_certificate::CERTIFICATE_FINGERPRINT;
use crate::encryption::ENCRYPTION;
use crate::environment::is_dev;
use crate::network::get_available_port;

View File

@ -8,9 +8,11 @@ pub mod app_window;
pub mod secret;
pub mod clipboard;
pub mod runtime_api;
pub mod certificate;
pub mod runtime_certificate;
pub mod file_data;
pub mod metadata;
pub mod pdfium;
pub mod pandoc;
pub mod qdrant;
pub mod qdrant;
pub mod certificate_factory;
pub mod runtime_api_token;

View File

@ -6,7 +6,7 @@ extern crate core;
use log::{info, warn};
use mindwork_ai_studio::app_window::start_tauri;
use mindwork_ai_studio::certificate::{generate_certificate};
use mindwork_ai_studio::runtime_certificate::{generate_runtime_certificate};
use mindwork_ai_studio::dotnet::start_dotnet_server;
use mindwork_ai_studio::environment::is_dev;
use mindwork_ai_studio::log::init_logging;
@ -46,7 +46,7 @@ async fn main() {
info!("Running in production mode.");
}
generate_certificate();
generate_runtime_certificate();
start_runtime_api();
if is_dev() {

View File

@ -1,6 +1,8 @@
use std::collections::HashMap;
use std::fs::File;
use std::io::Write;
use std::path::Path;
use std::sync::{Arc, Mutex};
use std::sync::{Arc, Mutex, OnceLock};
use log::{debug, error, info, warn};
use once_cell::sync::Lazy;
use rocket::get;
@ -9,6 +11,9 @@ use rocket::serde::Serialize;
use tauri::api::process::{Command, CommandChild, CommandEvent};
use crate::api_token::{APIToken};
use crate::environment::DATA_DIRECTORY;
use crate::certificate_factory::generate_certificate;
use std::path::PathBuf;
use tempfile::{TempDir, Builder};
// Qdrant server process started in a separate process and can communicate
// via HTTP or gRPC with the .NET server and the runtime process
@ -23,26 +28,38 @@ static QDRANT_SERVER_PORT_GRPC: Lazy<u16> = Lazy::new(|| {
crate::network::get_available_port().unwrap_or(6334)
});
pub static CERTIFICATE_FINGERPRINT: OnceLock<String> = OnceLock::new();
static API_TOKEN: Lazy<APIToken> = Lazy::new(|| {
crate::api_token::generate_api_token()
});
static TMPDIR: Lazy<Mutex<Option<TempDir>>> = Lazy::new(|| Mutex::new(None));
#[derive(Serialize)]
pub struct ProvideQdrantInfo {
path: String,
port_http: u16,
port_grpc: u16,
fingerprint: String,
api_token: String,
}
#[get("/system/qdrant/port")]
#[get("/system/qdrant/info")]
pub fn qdrant_port(_token: APIToken) -> Json<ProvideQdrantInfo> {
return Json(ProvideQdrantInfo {
path: Path::new(DATA_DIRECTORY.get().unwrap()).join("databases").join("qdrant").to_str().unwrap().to_string(),
port_http: *QDRANT_SERVER_PORT_HTTP,
port_grpc: *QDRANT_SERVER_PORT_GRPC,
fingerprint: CERTIFICATE_FINGERPRINT.get().expect("Certificate fingerprint not available").to_string(),
api_token: API_TOKEN.to_hex_text().to_string(),
});
}
/// Starts the Qdrant server in a separate process.
pub fn start_qdrant_server() {
let base_path = DATA_DIRECTORY.get().unwrap();
let base_path = DATA_DIRECTORY.get().unwrap();
let (cert_path, key_path) =create_temp_tls_files(Path::new(base_path).join("databases").join("qdrant")).unwrap();
let storage_path = Path::new(base_path).join("databases").join("qdrant").join("storage").to_str().unwrap().to_string();
let snapshot_path = Path::new(base_path).join("databases").join("qdrant").join("snapshots").to_str().unwrap().to_string();
@ -54,6 +71,10 @@ pub fn start_qdrant_server() {
(String::from("QDRANT_INIT_FILE_PATH"), init_path),
(String::from("QDRANT__STORAGE__STORAGE_PATH"), storage_path),
(String::from("QDRANT__STORAGE__SNAPSHOTS_PATH"), snapshot_path),
(String::from("QDRANT__TLS__CERT"), cert_path.to_str().unwrap().to_string()),
(String::from("QDRANT__TLS__KEY"), key_path.to_str().unwrap().to_string()),
(String::from("QDRANT__SERVICE__ENABLE_TLS"), "true".to_string()),
(String::from("QDRANT__SERVICE__API_KEY"), API_TOKEN.to_hex_text().to_string()),
]);
let server_spawn_clone = QDRANT_SERVER.clone();
@ -97,13 +118,51 @@ pub fn start_qdrant_server() {
/// Stops the Qdrant server process.
pub fn stop_qdrant_server() {
drop_tmpdir();
if let Some(server_process) = QDRANT_SERVER.lock().unwrap().take() {
let server_kill_result = server_process.kill();
match server_kill_result {
Ok(_) => info!("Qdrant server process was stopped."),
Err(e) => error!("Failed to stop Qdrant server process: {e}."),
Ok(_) => warn!(Source = "Qdrant"; "Qdrant server process was stopped."),
Err(e) => error!(Source = "Qdrant"; "Failed to stop Qdrant server process: {e}."),
}
} else {
warn!("Qdrant server process was not started or is already stopped.");
warn!(Source = "Qdrant"; "Qdrant server process was not started or is already stopped.");
}
}
pub fn create_temp_tls_files(path: PathBuf) -> Result<(PathBuf, PathBuf), Box<dyn std::error::Error>> {
let (certificate, cert_private_key, cert_fingerprint) = generate_certificate();
let temp_dir = init_tmpdir_in(path);
let cert_path = temp_dir.join("cert.pem");
let key_path = temp_dir.join("key.pem");
let mut cert_file = File::create(&cert_path)?;
cert_file.write_all(&*certificate)?;
let mut key_file = File::create(&key_path)?;
key_file.write_all(&*cert_private_key)?;
CERTIFICATE_FINGERPRINT.set(cert_fingerprint).expect("Could not set the certificate fingerprint.");
Ok((cert_path, key_path))
}
pub fn init_tmpdir_in<P: AsRef<Path>>(path: P) -> PathBuf {
let mut guard = TMPDIR.lock().unwrap();
let dir = guard.get_or_insert_with(|| {
Builder::new()
.prefix("cert-")
.tempdir_in(path)
.expect("failed to create tempdir")
});
dir.path().to_path_buf()
}
pub fn drop_tmpdir() {
let mut guard = TMPDIR.lock().unwrap();
*guard = None;
warn!(Source = "Qdrant"; "Temporary directory for TLS was dropped.");
}

View File

@ -3,7 +3,7 @@ use once_cell::sync::Lazy;
use rocket::config::Shutdown;
use rocket::figment::Figment;
use rocket::routes;
use crate::certificate::{CERTIFICATE, CERTIFICATE_PRIVATE_KEY};
use crate::runtime_certificate::{CERTIFICATE, CERTIFICATE_PRIVATE_KEY};
use crate::environment::is_dev;
use crate::network::get_available_port;

View File

@ -0,0 +1,40 @@
use once_cell::sync::Lazy;
use rocket::http::Status;
use rocket::Request;
use rocket::request::FromRequest;
use crate::api_token::{generate_api_token, APIToken};
pub static API_TOKEN: Lazy<APIToken> = Lazy::new(|| generate_api_token());
/// The request outcome type used to handle API token requests.
type RequestOutcome<R, T> = rocket::request::Outcome<R, T>;
/// The request outcome implementation for the API token.
#[rocket::async_trait]
impl<'r> FromRequest<'r> for APIToken {
type Error = APITokenError;
/// Handles the API token requests.
async fn from_request(request: &'r Request<'_>) -> RequestOutcome<Self, Self::Error> {
let token = request.headers().get_one("token");
match token {
Some(token) => {
let received_token = APIToken::from_hex_text(token);
if API_TOKEN.validate(&received_token) {
RequestOutcome::Success(received_token)
} else {
RequestOutcome::Error((Status::Unauthorized, APITokenError::Invalid))
}
}
None => RequestOutcome::Error((Status::Unauthorized, APITokenError::Missing)),
}
}
}
/// The API token error types.
#[derive(Debug)]
pub enum APITokenError {
Missing,
Invalid,
}

View File

@ -0,0 +1,26 @@
use std::sync::OnceLock;
use log::info;
use crate::certificate_factory::generate_certificate;
/// The certificate used for the runtime API server.
pub static CERTIFICATE: OnceLock<Vec<u8>> = OnceLock::new();
/// The private key used for the certificate of the runtime API server.
pub static CERTIFICATE_PRIVATE_KEY: OnceLock<Vec<u8>> = OnceLock::new();
/// The fingerprint of the certificate used for the runtime API server.
pub static CERTIFICATE_FINGERPRINT: OnceLock<String> = OnceLock::new();
/// Generates a TLS certificate for the runtime API server.
pub fn generate_runtime_certificate() {
info!("Try to generate a TLS certificate for the runtime API server...");
let (certificate, cer_private_key, cer_fingerprint) = generate_certificate();
CERTIFICATE_FINGERPRINT.set(cer_fingerprint).expect("Could not set the certificate fingerprint.");
CERTIFICATE.set(certificate).expect("Could not set the certificate.");
CERTIFICATE_PRIVATE_KEY.set(cer_private_key).expect("Could not set the private key.");
info!("Done generating certificate for the runtime API server.");
}