mirror of
https://github.com/MindWorkAI/AI-Studio.git
synced 2026-02-12 20:01:37 +00:00
Added TLS and API token support for Qdrant communication.
This commit is contained in:
parent
8400422044
commit
1dcfd19f72
9
.github/workflows/build-and-release.yml
vendored
9
.github/workflows/build-and-release.yml
vendored
@ -344,7 +344,7 @@ jobs:
|
||||
echo "Cleaning up ..."
|
||||
rm -fr "$TMP"
|
||||
|
||||
- name: Install PDFium (Windows)
|
||||
- name: Deploy PDFium (Windows)
|
||||
if: matrix.platform == 'windows-latest'
|
||||
env:
|
||||
PDFIUM_VERSION: ${{ env.PDFIUM_VERSION }}
|
||||
@ -464,7 +464,7 @@ jobs:
|
||||
echo "Cleaning up ..."
|
||||
rm -fr "$TMP"
|
||||
|
||||
- name: Install Qdrant (Windows)
|
||||
- name: Deploy Qdrant (Windows)
|
||||
if: matrix.platform == 'windows-latest'
|
||||
env:
|
||||
QDRANT_VERSION: ${{ env.QDRANT_VERSION }}
|
||||
@ -479,6 +479,11 @@ jobs:
|
||||
$DB_SOURCE = "qdrant.exe"
|
||||
$DB_TARGET = "qdrant.exe"
|
||||
}
|
||||
"win-arm64" {
|
||||
$QDRANT_FILE = "x86_64-pc-windows-msvc.zip"
|
||||
$DB_SOURCE = "qdrant.exe"
|
||||
$DB_TARGET = "qdrant.exe""
|
||||
}
|
||||
default {
|
||||
Write-Error "Unknown platform: $($env:DOTNET_RUNTIME)"
|
||||
exit 1
|
||||
|
||||
@ -91,10 +91,11 @@ public static class Qdrant
|
||||
RID.OSX_ARM64 => new("qdrant", "qdrant-aarch64-apple-darwin"),
|
||||
RID.OSX_X64 => new("qdrant", "qdrant-x86_64-apple-darwin"),
|
||||
|
||||
RID.LINUX_ARM64 => new("qdrant", "qdrant-aarch64-unknown-linux-gnu"),
|
||||
RID.LINUX_ARM64 => new("qdrant", "qdrant-aarch64-unknown-linux-musl"),
|
||||
RID.LINUX_X64 => new("qdrant", "qdrant-x86_64-unknown-linux-gnu"),
|
||||
|
||||
RID.WIN_X64 => new("qdrant.exe", "qdrant-x86_64-pc-windows-msvc.exe"),
|
||||
RID.WIN_ARM64 => new("qdrant.exe", "qdrant-aarch64-pc-windows-msvc.exe"),
|
||||
|
||||
_ => new(string.Empty, string.Empty),
|
||||
};
|
||||
@ -111,7 +112,7 @@ public static class Qdrant
|
||||
RID.OSX_X64 => $"{baseUrl}x86_64-apple-darwin.tar.gz",
|
||||
|
||||
RID.WIN_X64 => $"{baseUrl}x86_64-pc-windows-msvc.zip",
|
||||
#warning We have to handle Qdrant for Windows ARM
|
||||
RID.WIN_ARM64 => $"{baseUrl}x86_64-pc-windows-msvc.zip",
|
||||
|
||||
_ => string.Empty,
|
||||
};
|
||||
|
||||
@ -52,6 +52,7 @@
|
||||
<PackageReference Include="Microsoft.Extensions.FileProviders.Embedded" Version="9.0.11" />
|
||||
<PackageReference Include="MudBlazor" Version="8.12.0" />
|
||||
<PackageReference Include="MudBlazor.Markdown" Version="8.11.0" />
|
||||
<PackageReference Include="Qdrant.Client" Version="1.16.1" />
|
||||
<PackageReference Include="ReverseMarkdown" Version="4.7.1" />
|
||||
<PackageReference Include="LuaCSharp" Version="0.4.2" />
|
||||
</ItemGroup>
|
||||
|
||||
@ -25,12 +25,12 @@
|
||||
</MudText>
|
||||
<MudCollapse Expanded="@showDatabaseDetails">
|
||||
<MudText Typo="Typo.body1" Class="mt-2 mb-2">
|
||||
@foreach (var (Label, Value) in DatabaseClient.GetDisplayInfo())
|
||||
@foreach (var (label, value) in DatabaseDisplayInfo)
|
||||
{
|
||||
<div style="display: flex; align-items: center; gap: 8px;">
|
||||
<MudIcon Icon="@Icons.Material.Filled.ArrowRightAlt"/>
|
||||
<span>@Label: @Value</span>
|
||||
<MudCopyClipboardButton TooltipMessage="@(T("Copies the following to the clipboard")+": "+Value)" StringContent=@Value/>
|
||||
<span>@label: @value</span>
|
||||
<MudCopyClipboardButton TooltipMessage="@(T("Copies the following to the clipboard")+": "+value)" StringContent=@value/>
|
||||
</div>
|
||||
}
|
||||
</MudText>
|
||||
|
||||
@ -70,6 +70,8 @@ public partial class About : MSGComponentBase
|
||||
private bool showDatabaseDetails = false;
|
||||
|
||||
private IPluginMetadata? configPlug = PluginFactory.AvailablePlugins.FirstOrDefault(x => x.Type is PluginType.CONFIGURATION);
|
||||
|
||||
private List<(string Label, string Value)> DatabaseDisplayInfo = new();
|
||||
|
||||
/// <summary>
|
||||
/// Determines whether the enterprise configuration has details that can be shown/hidden.
|
||||
@ -105,6 +107,11 @@ public partial class About : MSGComponentBase
|
||||
this.osLanguage = await this.RustService.ReadUserLanguage();
|
||||
this.logPaths = await this.RustService.GetLogPaths();
|
||||
|
||||
await foreach (var item in this.DatabaseClient.GetDisplayInfo())
|
||||
{
|
||||
this.DatabaseDisplayInfo.Add(item);
|
||||
}
|
||||
|
||||
// Determine the Pandoc version may take some time, so we start it here
|
||||
// without waiting for the result:
|
||||
_ = this.DeterminePandocVersion();
|
||||
|
||||
@ -27,6 +27,7 @@ internal sealed class Program
|
||||
public static string API_TOKEN = null!;
|
||||
public static IServiceProvider SERVICE_PROVIDER = null!;
|
||||
public static ILoggerFactory LOGGER_FACTORY = null!;
|
||||
public static DatabaseClient DATABASE_CLIENT = null!;
|
||||
|
||||
public static async Task Main()
|
||||
{
|
||||
@ -102,6 +103,20 @@ internal sealed class Program
|
||||
Console.WriteLine("Error: Failed to get the Qdrant gRPC port from Rust.");
|
||||
return;
|
||||
}
|
||||
|
||||
if (qdrantInfo.Fingerprint == string.Empty)
|
||||
{
|
||||
Console.WriteLine("Error: Failed to get the Qdrant fingerprint from Rust.");
|
||||
return;
|
||||
}
|
||||
|
||||
if (qdrantInfo.ApiToken == string.Empty)
|
||||
{
|
||||
Console.WriteLine("Error: Failed to get the Qdrant API token from Rust.");
|
||||
return;
|
||||
}
|
||||
|
||||
var databaseClient = new QdrantClientImplementation("Qdrant", qdrantInfo.Path, qdrantInfo.PortHttp, qdrantInfo.PortGrpc, qdrantInfo.Fingerprint, qdrantInfo.ApiToken);
|
||||
|
||||
var builder = WebApplication.CreateBuilder();
|
||||
|
||||
@ -155,7 +170,7 @@ internal sealed class Program
|
||||
builder.Services.AddHostedService<UpdateService>();
|
||||
builder.Services.AddHostedService<TemporaryChatService>();
|
||||
builder.Services.AddHostedService<EnterpriseEnvironmentService>();
|
||||
builder.Services.AddSingleton<DatabaseClient>(new QdrantClient("Qdrant", qdrantInfo.Path, qdrantInfo.PortHttp, qdrantInfo.PortGrpc));
|
||||
builder.Services.AddSingleton<DatabaseClient>(databaseClient);
|
||||
|
||||
// ReSharper disable AccessToDisposedClosure
|
||||
builder.Services.AddHostedService<RustService>(_ => rust);
|
||||
@ -211,6 +226,10 @@ internal sealed class Program
|
||||
|
||||
RUST_SERVICE = rust;
|
||||
ENCRYPTION = encryption;
|
||||
|
||||
var databaseLogger = app.Services.GetRequiredService<ILogger<DatabaseClient>>();
|
||||
databaseClient.SetLogger(databaseLogger);
|
||||
DATABASE_CLIENT = databaseClient;
|
||||
|
||||
programLogger.LogInformation("Initialize internal file system.");
|
||||
app.Use(Redirect.HandlerContentAsync);
|
||||
@ -238,7 +257,6 @@ internal sealed class Program
|
||||
await rust.AppIsReady();
|
||||
programLogger.LogInformation("The AI Studio server is ready.");
|
||||
|
||||
|
||||
TaskScheduler.UnobservedTaskException += (sender, taskArgs) =>
|
||||
{
|
||||
programLogger.LogError(taskArgs.Exception, $"Unobserved task exception by sender '{sender ?? "n/a"}'.");
|
||||
@ -248,6 +266,7 @@ internal sealed class Program
|
||||
await serverTask;
|
||||
|
||||
RUST_SERVICE.Dispose();
|
||||
DATABASE_CLIENT.Dispose();
|
||||
PluginFactory.Dispose();
|
||||
programLogger.LogInformation("The AI Studio server was stopped.");
|
||||
}
|
||||
|
||||
@ -1,57 +1,29 @@
|
||||
namespace AIStudio.Tools.Databases;
|
||||
|
||||
public abstract class DatabaseClient
|
||||
public abstract class DatabaseClient(string name, string path)
|
||||
{
|
||||
public string Name { get; }
|
||||
private string Path { get; }
|
||||
|
||||
public DatabaseClient(string name, string path)
|
||||
{
|
||||
this.Name = name;
|
||||
this.Path = path;
|
||||
}
|
||||
public string Name => name;
|
||||
private string Path => path;
|
||||
protected ILogger<DatabaseClient>? logger;
|
||||
|
||||
public abstract IEnumerable<(string Label, string Value)> GetDisplayInfo();
|
||||
public abstract IAsyncEnumerable<(string Label, string Value)> GetDisplayInfo();
|
||||
|
||||
public string GetStorageSize()
|
||||
{
|
||||
if (string.IsNullOrEmpty(this.Path))
|
||||
if (string.IsNullOrWhiteSpace(this.Path))
|
||||
{
|
||||
Console.WriteLine($"Error: Database path '{this.Path}' cannot be null or empty.");
|
||||
this.logger!.LogError($"Error: Database path '{this.Path}' cannot be null or empty.");
|
||||
return "0 B";
|
||||
}
|
||||
|
||||
if (!Directory.Exists(this.Path))
|
||||
{
|
||||
Console.WriteLine($"Error: Database path '{this.Path}' does not exist.");
|
||||
this.logger!.LogError($"Error: Database path '{this.Path}' does not exist.");
|
||||
return "0 B";
|
||||
}
|
||||
long size = 0;
|
||||
var stack = new Stack<string>();
|
||||
stack.Push(this.Path);
|
||||
while (stack.Count > 0)
|
||||
{
|
||||
string directory = stack.Pop();
|
||||
try
|
||||
{
|
||||
var files = Directory.GetFiles(directory);
|
||||
size += files.Sum(file => new FileInfo(file).Length);
|
||||
var subDirectories = Directory.GetDirectories(directory);
|
||||
foreach (var subDirectory in subDirectories)
|
||||
{
|
||||
stack.Push(subDirectory);
|
||||
}
|
||||
}
|
||||
catch (UnauthorizedAccessException)
|
||||
{
|
||||
Console.WriteLine($"No access to {directory}");
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
Console.WriteLine($"An error encountered while processing {directory}: ");
|
||||
Console.WriteLine($"{ ex.Message}");
|
||||
}
|
||||
}
|
||||
var files = Directory.EnumerateFiles(this.Path, "*", SearchOption.AllDirectories)
|
||||
.Where(file => !System.IO.Path.GetDirectoryName(file)!.Contains("cert", StringComparison.OrdinalIgnoreCase));
|
||||
var size = files.Sum(file => new FileInfo(file).Length);
|
||||
return FormatBytes(size);
|
||||
}
|
||||
|
||||
@ -68,4 +40,11 @@ public abstract class DatabaseClient
|
||||
|
||||
return $"{size:0##} {suffixes[suffixIndex]}";
|
||||
}
|
||||
|
||||
public void SetLogger(ILogger<DatabaseClient> logService)
|
||||
{
|
||||
this.logger = logService;
|
||||
}
|
||||
|
||||
public abstract void Dispose();
|
||||
}
|
||||
@ -1,15 +0,0 @@
|
||||
namespace AIStudio.Tools.Databases.Qdrant;
|
||||
|
||||
public class QdrantClient(string name, string path, int httpPort, int grpcPort) : DatabaseClient(name, path)
|
||||
{
|
||||
private int HttpPort { get; } = httpPort;
|
||||
private int GrpcPort { get; } = grpcPort;
|
||||
private string IpAddress { get; } = "127.0.0.1";
|
||||
|
||||
public override IEnumerable<(string Label, string Value)> GetDisplayInfo()
|
||||
{
|
||||
yield return ("HTTP Port", this.HttpPort.ToString());
|
||||
yield return ("gRPC Port", this.GrpcPort.ToString());
|
||||
yield return ("Storage Size", $"{base.GetStorageSize()}");
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,61 @@
|
||||
using Qdrant.Client;
|
||||
using Qdrant.Client.Grpc;
|
||||
|
||||
namespace AIStudio.Tools.Databases.Qdrant;
|
||||
|
||||
public class QdrantClientImplementation : DatabaseClient
|
||||
{
|
||||
private int HttpPort { get; }
|
||||
private int GrpcPort { get; }
|
||||
private string IpAddress => "localhost";
|
||||
private QdrantClient GrpcClient { get; }
|
||||
private string Fingerprint { get; }
|
||||
private string ApiToken { get; }
|
||||
|
||||
public QdrantClientImplementation(string name, string path, int httpPort, int grpcPort, string fingerprint, string apiToken): base(name, path)
|
||||
{
|
||||
this.HttpPort = httpPort;
|
||||
this.GrpcPort = grpcPort;
|
||||
this.Fingerprint = fingerprint;
|
||||
this.ApiToken = apiToken;
|
||||
this.GrpcClient = this.CreateQdrantClient();
|
||||
}
|
||||
|
||||
public QdrantClient CreateQdrantClient()
|
||||
{
|
||||
var address = "https://" + this.IpAddress + ":" + this.GrpcPort;
|
||||
var channel = QdrantChannel.ForAddress(address, new ClientConfiguration
|
||||
{
|
||||
ApiKey = this.ApiToken,
|
||||
CertificateThumbprint = this.Fingerprint
|
||||
});
|
||||
var grpcClient = new QdrantGrpcClient(channel);
|
||||
return new QdrantClient(grpcClient);
|
||||
}
|
||||
|
||||
public async Task<string> GetVersion()
|
||||
{
|
||||
var operation = await this.GrpcClient.HealthAsync();
|
||||
return "v"+operation.Version;
|
||||
}
|
||||
|
||||
public async Task<string> GetCollectionsAmount()
|
||||
{
|
||||
var operation = await this.GrpcClient.ListCollectionsAsync();
|
||||
return operation.Count.ToString();
|
||||
}
|
||||
|
||||
public override async IAsyncEnumerable<(string Label, string Value)> GetDisplayInfo()
|
||||
{
|
||||
yield return ("HTTP port", this.HttpPort.ToString());
|
||||
yield return ("gRPC port", this.GrpcPort.ToString());
|
||||
yield return ("Extracted version", await this.GetVersion());
|
||||
yield return ("Storage size", $"{base.GetStorageSize()}");
|
||||
yield return ("Amount of collections", await this.GetCollectionsAmount());
|
||||
}
|
||||
|
||||
public override void Dispose()
|
||||
{
|
||||
this.GrpcClient.Dispose();
|
||||
}
|
||||
}
|
||||
@ -10,4 +10,6 @@ public record struct QdrantInfo
|
||||
public string Path { get; init; }
|
||||
public int PortHttp { get; init; }
|
||||
public int PortGrpc { get; init; }
|
||||
public string Fingerprint { get; init; }
|
||||
public string ApiToken { get; init; }
|
||||
}
|
||||
@ -9,7 +9,7 @@ public sealed partial class RustService
|
||||
try
|
||||
{
|
||||
var cts = new CancellationTokenSource(TimeSpan.FromSeconds(45));
|
||||
var response = await this.http.GetFromJsonAsync<QdrantInfo>("/system/qdrant/port", this.jsonRustSerializerOptions, cts.Token);
|
||||
var response = await this.http.GetFromJsonAsync<QdrantInfo>("/system/qdrant/info", this.jsonRustSerializerOptions, cts.Token);
|
||||
return response;
|
||||
}
|
||||
catch (Exception e)
|
||||
@ -20,6 +20,8 @@ public sealed partial class RustService
|
||||
Path = string.Empty,
|
||||
PortHttp = 0,
|
||||
PortGrpc = 0,
|
||||
Fingerprint = string.Empty,
|
||||
ApiToken = string.Empty,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@ -39,6 +39,7 @@ pdfium-render = "0.8.34"
|
||||
sys-locale = "0.3.2"
|
||||
cfg-if = "1.0.1"
|
||||
pptx-to-md = "0.4.0"
|
||||
tempfile = "3.8"
|
||||
|
||||
# Fixes security vulnerability downstream, where the upstream is not fixed yet:
|
||||
url = "2.5"
|
||||
|
||||
@ -332,10 +332,10 @@ telemetry_disabled: true
|
||||
# Required if either service.enable_tls or cluster.p2p.enable_tls is true.
|
||||
tls:
|
||||
# Server certificate chain file
|
||||
cert: ./tls/cert.pem
|
||||
# cert: ./tls/cert.pem
|
||||
|
||||
# Server private key file
|
||||
key: ./tls/key.pem
|
||||
# key: ./tls/key.pem
|
||||
|
||||
# Certificate authority certificate file.
|
||||
# This certificate will be used to validate the certificates
|
||||
|
||||
@ -1,21 +1,5 @@
|
||||
use log::info;
|
||||
use once_cell::sync::Lazy;
|
||||
use rand::{RngCore, SeedableRng};
|
||||
use rocket::http::Status;
|
||||
use rocket::Request;
|
||||
use rocket::request::FromRequest;
|
||||
|
||||
/// The API token used to authenticate requests.
|
||||
pub static API_TOKEN: Lazy<APIToken> = Lazy::new(|| {
|
||||
let mut token = [0u8; 32];
|
||||
let mut rng = rand_chacha::ChaChaRng::from_os_rng();
|
||||
rng.fill_bytes(&mut token);
|
||||
|
||||
let token = APIToken::from_bytes(token.to_vec());
|
||||
info!("API token was generated successfully.");
|
||||
|
||||
token
|
||||
});
|
||||
use rand_chacha::ChaChaRng;
|
||||
|
||||
/// The API token data structure used to authenticate requests.
|
||||
pub struct APIToken {
|
||||
@ -34,7 +18,7 @@ impl APIToken {
|
||||
}
|
||||
|
||||
/// Creates a new API token from a hexadecimal text.
|
||||
fn from_hex_text(hex_text: &str) -> Self {
|
||||
pub fn from_hex_text(hex_text: &str) -> Self {
|
||||
APIToken {
|
||||
hex_text: hex_text.to_string(),
|
||||
}
|
||||
@ -45,40 +29,14 @@ impl APIToken {
|
||||
}
|
||||
|
||||
/// Validates the received token against the valid token.
|
||||
fn validate(&self, received_token: &Self) -> bool {
|
||||
pub fn validate(&self, received_token: &Self) -> bool {
|
||||
received_token.to_hex_text() == self.to_hex_text()
|
||||
}
|
||||
}
|
||||
|
||||
/// The request outcome type used to handle API token requests.
|
||||
type RequestOutcome<R, T> = rocket::request::Outcome<R, T>;
|
||||
|
||||
/// The request outcome implementation for the API token.
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for APIToken {
|
||||
type Error = APITokenError;
|
||||
|
||||
/// Handles the API token requests.
|
||||
async fn from_request(request: &'r Request<'_>) -> RequestOutcome<Self, Self::Error> {
|
||||
let token = request.headers().get_one("token");
|
||||
match token {
|
||||
Some(token) => {
|
||||
let received_token = APIToken::from_hex_text(token);
|
||||
if API_TOKEN.validate(&received_token) {
|
||||
RequestOutcome::Success(received_token)
|
||||
} else {
|
||||
RequestOutcome::Error((Status::Unauthorized, APITokenError::Invalid))
|
||||
}
|
||||
}
|
||||
|
||||
None => RequestOutcome::Error((Status::Unauthorized, APITokenError::Missing)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The API token error types.
|
||||
#[derive(Debug)]
|
||||
pub enum APITokenError {
|
||||
Missing,
|
||||
Invalid,
|
||||
pub fn generate_api_token() -> APIToken {
|
||||
let mut token = [0u8; 32];
|
||||
let mut rng = ChaChaRng::from_os_rng();
|
||||
rng.fill_bytes(&mut token);
|
||||
APIToken::from_bytes(token.to_vec())
|
||||
}
|
||||
@ -17,7 +17,7 @@ use crate::dotnet::stop_dotnet_server;
|
||||
use crate::environment::{is_prod, is_dev, CONFIG_DIRECTORY, DATA_DIRECTORY};
|
||||
use crate::log::switch_to_file_logging;
|
||||
use crate::pdfium::PDFIUM_LIB_PATH;
|
||||
use crate::qdrant::start_qdrant_server;
|
||||
use crate::qdrant::{start_qdrant_server, stop_qdrant_server};
|
||||
|
||||
/// The Tauri main window.
|
||||
static MAIN_WINDOW: Lazy<Mutex<Option<Window>>> = Lazy::new(|| Mutex::new(None));
|
||||
@ -174,6 +174,7 @@ pub fn start_tauri() {
|
||||
|
||||
RunEvent::ExitRequested { .. } => {
|
||||
warn!(Source = "Tauri"; "Run event: exit was requested.");
|
||||
stop_qdrant_server();
|
||||
}
|
||||
|
||||
RunEvent::Ready => {
|
||||
|
||||
@ -1,38 +0,0 @@
|
||||
use std::sync::OnceLock;
|
||||
use log::info;
|
||||
use rcgen::generate_simple_self_signed;
|
||||
use sha2::{Sha256, Digest};
|
||||
|
||||
/// The certificate used for the runtime API server.
|
||||
pub static CERTIFICATE: OnceLock<Vec<u8>> = OnceLock::new();
|
||||
|
||||
/// The private key used for the certificate of the runtime API server.
|
||||
pub static CERTIFICATE_PRIVATE_KEY: OnceLock<Vec<u8>> = OnceLock::new();
|
||||
|
||||
/// The fingerprint of the certificate used for the runtime API server.
|
||||
pub static CERTIFICATE_FINGERPRINT: OnceLock<String> = OnceLock::new();
|
||||
|
||||
/// Generates a TLS certificate for the runtime API server.
|
||||
pub fn generate_certificate() {
|
||||
|
||||
info!("Try to generate a TLS certificate for the runtime API server...");
|
||||
|
||||
let subject_alt_names = vec!["localhost".to_string()];
|
||||
let certificate_data = generate_simple_self_signed(subject_alt_names).unwrap();
|
||||
let certificate_binary_data = certificate_data.cert.der().to_vec();
|
||||
|
||||
let certificate_fingerprint = Sha256::digest(certificate_binary_data).to_vec();
|
||||
let certificate_fingerprint = certificate_fingerprint.iter().fold(String::new(), |mut result, byte| {
|
||||
result.push_str(&format!("{:02x}", byte));
|
||||
result
|
||||
});
|
||||
|
||||
let certificate_fingerprint = certificate_fingerprint.to_uppercase();
|
||||
|
||||
CERTIFICATE_FINGERPRINT.set(certificate_fingerprint.clone()).expect("Could not set the certificate fingerprint.");
|
||||
CERTIFICATE.set(certificate_data.cert.pem().as_bytes().to_vec()).expect("Could not set the certificate.");
|
||||
CERTIFICATE_PRIVATE_KEY.set(certificate_data.signing_key.serialize_pem().as_bytes().to_vec()).expect("Could not set the private key.");
|
||||
|
||||
info!("Certificate fingerprint: '{certificate_fingerprint}'.");
|
||||
info!("Done generating certificate for the runtime API server.");
|
||||
}
|
||||
22
runtime/src/certificate_factory.rs
Normal file
22
runtime/src/certificate_factory.rs
Normal file
@ -0,0 +1,22 @@
|
||||
use log::info;
|
||||
use rcgen::generate_simple_self_signed;
|
||||
use sha2::{Sha256, Digest};
|
||||
|
||||
pub fn generate_certificate() -> (Vec<u8>, Vec<u8>, String) {
|
||||
|
||||
let subject_alt_names = vec!["localhost".to_string()];
|
||||
let certificate_data = generate_simple_self_signed(subject_alt_names).unwrap();
|
||||
let certificate_binary_data = certificate_data.cert.der().to_vec();
|
||||
|
||||
let certificate_fingerprint = Sha256::digest(certificate_binary_data).to_vec();
|
||||
let certificate_fingerprint = certificate_fingerprint.iter().fold(String::new(), |mut result, byte| {
|
||||
result.push_str(&format!("{:02x}", byte));
|
||||
result
|
||||
});
|
||||
|
||||
let certificate_fingerprint = certificate_fingerprint.to_uppercase();
|
||||
|
||||
info!("Certificate fingerprint: '{certificate_fingerprint}'.");
|
||||
|
||||
(certificate_data.cert.pem().as_bytes().to_vec(), certificate_data.signing_key.serialize_pem().as_bytes().to_vec(), certificate_fingerprint.clone())
|
||||
}
|
||||
@ -7,9 +7,10 @@ use once_cell::sync::Lazy;
|
||||
use rocket::get;
|
||||
use tauri::api::process::{Command, CommandChild, CommandEvent};
|
||||
use tauri::Url;
|
||||
use crate::api_token::{APIToken, API_TOKEN};
|
||||
use crate::api_token::APIToken;
|
||||
use crate::runtime_api_token::API_TOKEN;
|
||||
use crate::app_window::change_location_to;
|
||||
use crate::certificate::CERTIFICATE_FINGERPRINT;
|
||||
use crate::runtime_certificate::CERTIFICATE_FINGERPRINT;
|
||||
use crate::encryption::ENCRYPTION;
|
||||
use crate::environment::is_dev;
|
||||
use crate::network::get_available_port;
|
||||
|
||||
@ -8,9 +8,11 @@ pub mod app_window;
|
||||
pub mod secret;
|
||||
pub mod clipboard;
|
||||
pub mod runtime_api;
|
||||
pub mod certificate;
|
||||
pub mod runtime_certificate;
|
||||
pub mod file_data;
|
||||
pub mod metadata;
|
||||
pub mod pdfium;
|
||||
pub mod pandoc;
|
||||
pub mod qdrant;
|
||||
pub mod qdrant;
|
||||
pub mod certificate_factory;
|
||||
pub mod runtime_api_token;
|
||||
@ -6,7 +6,7 @@ extern crate core;
|
||||
|
||||
use log::{info, warn};
|
||||
use mindwork_ai_studio::app_window::start_tauri;
|
||||
use mindwork_ai_studio::certificate::{generate_certificate};
|
||||
use mindwork_ai_studio::runtime_certificate::{generate_runtime_certificate};
|
||||
use mindwork_ai_studio::dotnet::start_dotnet_server;
|
||||
use mindwork_ai_studio::environment::is_dev;
|
||||
use mindwork_ai_studio::log::init_logging;
|
||||
@ -46,7 +46,7 @@ async fn main() {
|
||||
info!("Running in production mode.");
|
||||
}
|
||||
|
||||
generate_certificate();
|
||||
generate_runtime_certificate();
|
||||
start_runtime_api();
|
||||
|
||||
if is_dev() {
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
use log::{debug, error, info, warn};
|
||||
use once_cell::sync::Lazy;
|
||||
use rocket::get;
|
||||
@ -9,6 +11,9 @@ use rocket::serde::Serialize;
|
||||
use tauri::api::process::{Command, CommandChild, CommandEvent};
|
||||
use crate::api_token::{APIToken};
|
||||
use crate::environment::DATA_DIRECTORY;
|
||||
use crate::certificate_factory::generate_certificate;
|
||||
use std::path::PathBuf;
|
||||
use tempfile::{TempDir, Builder};
|
||||
|
||||
// Qdrant server process started in a separate process and can communicate
|
||||
// via HTTP or gRPC with the .NET server and the runtime process
|
||||
@ -23,26 +28,38 @@ static QDRANT_SERVER_PORT_GRPC: Lazy<u16> = Lazy::new(|| {
|
||||
crate::network::get_available_port().unwrap_or(6334)
|
||||
});
|
||||
|
||||
pub static CERTIFICATE_FINGERPRINT: OnceLock<String> = OnceLock::new();
|
||||
static API_TOKEN: Lazy<APIToken> = Lazy::new(|| {
|
||||
crate::api_token::generate_api_token()
|
||||
});
|
||||
|
||||
static TMPDIR: Lazy<Mutex<Option<TempDir>>> = Lazy::new(|| Mutex::new(None));
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ProvideQdrantInfo {
|
||||
path: String,
|
||||
port_http: u16,
|
||||
port_grpc: u16,
|
||||
fingerprint: String,
|
||||
api_token: String,
|
||||
}
|
||||
|
||||
#[get("/system/qdrant/port")]
|
||||
#[get("/system/qdrant/info")]
|
||||
pub fn qdrant_port(_token: APIToken) -> Json<ProvideQdrantInfo> {
|
||||
return Json(ProvideQdrantInfo {
|
||||
path: Path::new(DATA_DIRECTORY.get().unwrap()).join("databases").join("qdrant").to_str().unwrap().to_string(),
|
||||
port_http: *QDRANT_SERVER_PORT_HTTP,
|
||||
port_grpc: *QDRANT_SERVER_PORT_GRPC,
|
||||
fingerprint: CERTIFICATE_FINGERPRINT.get().expect("Certificate fingerprint not available").to_string(),
|
||||
api_token: API_TOKEN.to_hex_text().to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
/// Starts the Qdrant server in a separate process.
|
||||
pub fn start_qdrant_server() {
|
||||
|
||||
let base_path = DATA_DIRECTORY.get().unwrap();
|
||||
let base_path = DATA_DIRECTORY.get().unwrap();
|
||||
let (cert_path, key_path) =create_temp_tls_files(Path::new(base_path).join("databases").join("qdrant")).unwrap();
|
||||
|
||||
let storage_path = Path::new(base_path).join("databases").join("qdrant").join("storage").to_str().unwrap().to_string();
|
||||
let snapshot_path = Path::new(base_path).join("databases").join("qdrant").join("snapshots").to_str().unwrap().to_string();
|
||||
@ -54,6 +71,10 @@ pub fn start_qdrant_server() {
|
||||
(String::from("QDRANT_INIT_FILE_PATH"), init_path),
|
||||
(String::from("QDRANT__STORAGE__STORAGE_PATH"), storage_path),
|
||||
(String::from("QDRANT__STORAGE__SNAPSHOTS_PATH"), snapshot_path),
|
||||
(String::from("QDRANT__TLS__CERT"), cert_path.to_str().unwrap().to_string()),
|
||||
(String::from("QDRANT__TLS__KEY"), key_path.to_str().unwrap().to_string()),
|
||||
(String::from("QDRANT__SERVICE__ENABLE_TLS"), "true".to_string()),
|
||||
(String::from("QDRANT__SERVICE__API_KEY"), API_TOKEN.to_hex_text().to_string()),
|
||||
]);
|
||||
|
||||
let server_spawn_clone = QDRANT_SERVER.clone();
|
||||
@ -97,13 +118,51 @@ pub fn start_qdrant_server() {
|
||||
|
||||
/// Stops the Qdrant server process.
|
||||
pub fn stop_qdrant_server() {
|
||||
drop_tmpdir();
|
||||
if let Some(server_process) = QDRANT_SERVER.lock().unwrap().take() {
|
||||
let server_kill_result = server_process.kill();
|
||||
match server_kill_result {
|
||||
Ok(_) => info!("Qdrant server process was stopped."),
|
||||
Err(e) => error!("Failed to stop Qdrant server process: {e}."),
|
||||
Ok(_) => warn!(Source = "Qdrant"; "Qdrant server process was stopped."),
|
||||
Err(e) => error!(Source = "Qdrant"; "Failed to stop Qdrant server process: {e}."),
|
||||
}
|
||||
} else {
|
||||
warn!("Qdrant server process was not started or is already stopped.");
|
||||
warn!(Source = "Qdrant"; "Qdrant server process was not started or is already stopped.");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_temp_tls_files(path: PathBuf) -> Result<(PathBuf, PathBuf), Box<dyn std::error::Error>> {
|
||||
let (certificate, cert_private_key, cert_fingerprint) = generate_certificate();
|
||||
|
||||
let temp_dir = init_tmpdir_in(path);
|
||||
|
||||
let cert_path = temp_dir.join("cert.pem");
|
||||
let key_path = temp_dir.join("key.pem");
|
||||
|
||||
let mut cert_file = File::create(&cert_path)?;
|
||||
cert_file.write_all(&*certificate)?;
|
||||
|
||||
let mut key_file = File::create(&key_path)?;
|
||||
key_file.write_all(&*cert_private_key)?;
|
||||
|
||||
CERTIFICATE_FINGERPRINT.set(cert_fingerprint).expect("Could not set the certificate fingerprint.");
|
||||
|
||||
Ok((cert_path, key_path))
|
||||
}
|
||||
|
||||
pub fn init_tmpdir_in<P: AsRef<Path>>(path: P) -> PathBuf {
|
||||
let mut guard = TMPDIR.lock().unwrap();
|
||||
let dir = guard.get_or_insert_with(|| {
|
||||
Builder::new()
|
||||
.prefix("cert-")
|
||||
.tempdir_in(path)
|
||||
.expect("failed to create tempdir")
|
||||
});
|
||||
|
||||
dir.path().to_path_buf()
|
||||
}
|
||||
|
||||
pub fn drop_tmpdir() {
|
||||
let mut guard = TMPDIR.lock().unwrap();
|
||||
*guard = None;
|
||||
warn!(Source = "Qdrant"; "Temporary directory for TLS was dropped.");
|
||||
}
|
||||
@ -3,7 +3,7 @@ use once_cell::sync::Lazy;
|
||||
use rocket::config::Shutdown;
|
||||
use rocket::figment::Figment;
|
||||
use rocket::routes;
|
||||
use crate::certificate::{CERTIFICATE, CERTIFICATE_PRIVATE_KEY};
|
||||
use crate::runtime_certificate::{CERTIFICATE, CERTIFICATE_PRIVATE_KEY};
|
||||
use crate::environment::is_dev;
|
||||
use crate::network::get_available_port;
|
||||
|
||||
|
||||
40
runtime/src/runtime_api_token.rs
Normal file
40
runtime/src/runtime_api_token.rs
Normal file
@ -0,0 +1,40 @@
|
||||
use once_cell::sync::Lazy;
|
||||
use rocket::http::Status;
|
||||
use rocket::Request;
|
||||
use rocket::request::FromRequest;
|
||||
use crate::api_token::{generate_api_token, APIToken};
|
||||
|
||||
pub static API_TOKEN: Lazy<APIToken> = Lazy::new(|| generate_api_token());
|
||||
|
||||
/// The request outcome type used to handle API token requests.
|
||||
type RequestOutcome<R, T> = rocket::request::Outcome<R, T>;
|
||||
|
||||
/// The request outcome implementation for the API token.
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for APIToken {
|
||||
type Error = APITokenError;
|
||||
|
||||
/// Handles the API token requests.
|
||||
async fn from_request(request: &'r Request<'_>) -> RequestOutcome<Self, Self::Error> {
|
||||
let token = request.headers().get_one("token");
|
||||
match token {
|
||||
Some(token) => {
|
||||
let received_token = APIToken::from_hex_text(token);
|
||||
if API_TOKEN.validate(&received_token) {
|
||||
RequestOutcome::Success(received_token)
|
||||
} else {
|
||||
RequestOutcome::Error((Status::Unauthorized, APITokenError::Invalid))
|
||||
}
|
||||
}
|
||||
|
||||
None => RequestOutcome::Error((Status::Unauthorized, APITokenError::Missing)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The API token error types.
|
||||
#[derive(Debug)]
|
||||
pub enum APITokenError {
|
||||
Missing,
|
||||
Invalid,
|
||||
}
|
||||
26
runtime/src/runtime_certificate.rs
Normal file
26
runtime/src/runtime_certificate.rs
Normal file
@ -0,0 +1,26 @@
|
||||
use std::sync::OnceLock;
|
||||
use log::info;
|
||||
use crate::certificate_factory::generate_certificate;
|
||||
|
||||
/// The certificate used for the runtime API server.
|
||||
pub static CERTIFICATE: OnceLock<Vec<u8>> = OnceLock::new();
|
||||
|
||||
/// The private key used for the certificate of the runtime API server.
|
||||
pub static CERTIFICATE_PRIVATE_KEY: OnceLock<Vec<u8>> = OnceLock::new();
|
||||
|
||||
/// The fingerprint of the certificate used for the runtime API server.
|
||||
pub static CERTIFICATE_FINGERPRINT: OnceLock<String> = OnceLock::new();
|
||||
|
||||
/// Generates a TLS certificate for the runtime API server.
|
||||
pub fn generate_runtime_certificate() {
|
||||
|
||||
info!("Try to generate a TLS certificate for the runtime API server...");
|
||||
|
||||
let (certificate, cer_private_key, cer_fingerprint) = generate_certificate();
|
||||
|
||||
CERTIFICATE_FINGERPRINT.set(cer_fingerprint).expect("Could not set the certificate fingerprint.");
|
||||
CERTIFICATE.set(certificate).expect("Could not set the certificate.");
|
||||
CERTIFICATE_PRIVATE_KEY.set(cer_private_key).expect("Could not set the private key.");
|
||||
|
||||
info!("Done generating certificate for the runtime API server.");
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user