Migrated IPC server to use axum

This commit is contained in:
Thorsten Sommer 2026-05-12 19:17:49 +02:00
parent f69186f7a9
commit 978b261c13
Signed by untrusted user who does not match committer: tsommer
GPG Key ID: 371BBA77A02C0108
15 changed files with 387 additions and 914 deletions

794
runtime/Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -25,7 +25,9 @@ async-stream = "0.3.6"
flexi_logger = "0.31.8"
log = { version = "0.4.29", features = ["kv"] }
once_cell = "1.21.4"
rocket = { version = "0.5.1", features = ["json", "tls"] }
axum = { version = "0.8.9", features = ["http2", "json", "query", "tokio"] }
axum-server = { version = "0.8.0", features = ["tls-rustls"] }
rustls = { version = "0.23.28", default-features = false, features = ["aws_lc_rs"] }
rand = "0.10.1"
rand_chacha = "0.10.0"
base64 = "0.22.1"
@ -46,7 +48,6 @@ strum_macros = "0.28.0"
sysinfo = "0.38.4"
# Fixes security vulnerability downstream, where the upstream is not fixed yet:
time = "0.3.47" # -> Rocket
bytes = "1.11.1" # -> almost every dependency
[target.'cfg(target_os = "linux")'.dependencies]

View File

@ -1,13 +1,16 @@
use std::collections::HashMap;
use std::convert::Infallible;
use std::sync::Mutex;
use std::time::Duration;
use async_stream::stream;
use axum::body::Body;
use axum::http::header::CONTENT_TYPE;
use axum::response::{IntoResponse, Response};
use axum::Json;
use bytes::Bytes;
use log::{debug, error, info, trace, warn};
use once_cell::sync::Lazy;
use rocket::{get, post};
use rocket::response::stream::TextStream;
use rocket::serde::json::Json;
use rocket::serde::Serialize;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use strum_macros::Display;
use tauri::{DragDropEvent,RunEvent, Manager, WindowEvent, generate_context};
use tauri::path::PathResolver;
@ -256,8 +259,7 @@ fn should_open_in_system_browser<R: tauri::Runtime>(webview: &tauri::Webview<R>,
/// When the client disconnects, the stream is closed. But we try to not lose events in between.
/// The client is expected to reconnect automatically when the connection is closed and continue
/// listening for events.
#[get("/events")]
pub async fn get_event_stream(_token: APIToken) -> TextStream![String] {
pub async fn get_event_stream(_token: APIToken) -> Response {
// Get the lock to the event broadcast sender:
let event_broadcast_lock = EVENT_BROADCAST.lock().unwrap();
@ -269,8 +271,7 @@ pub async fn get_event_stream(_token: APIToken) -> TextStream![String] {
// Drop the lock to allow other access to the sender:
drop(event_broadcast_lock);
// Create the event stream:
TextStream! {
let stream = stream! {
loop {
// Wait at most 3 seconds for an event:
match time::timeout(Duration::from_secs(3), event_receiver.recv()).await {
@ -281,11 +282,11 @@ pub async fn get_event_stream(_token: APIToken) -> TextStream![String] {
// is serialized as a single line so that the client can parse it
// correctly:
let event_json = serde_json::to_string(&event).unwrap();
yield event_json;
yield Ok::<Bytes, Infallible>(Bytes::from(event_json));
// The client expects a newline after each event because we are using
// a method to read the stream line-by-line:
yield "\n".to_string();
yield Ok::<Bytes, Infallible>(Bytes::from("\n"));
},
// Case: we lagged behind and missed some events
@ -305,15 +306,17 @@ pub async fn get_event_stream(_token: APIToken) -> TextStream![String] {
// Again, we have to serialize the event as a single line:
let event_json = serde_json::to_string(&ping_event).unwrap();
yield event_json;
yield Ok::<Bytes, Infallible>(Bytes::from(event_json));
// The client expects a newline after each event because we are using
// a method to read the stream line-by-line:
yield "\n".to_string();
yield Ok::<Bytes, Infallible>(Bytes::from("\n"));
},
}
}
}
};
([(CONTENT_TYPE, "application/jsonl")], Body::from_stream(stream)).into_response()
}
/// Data structure representing a Tauri event for our event API.
@ -428,7 +431,6 @@ pub async fn change_location_to(url: &str) {
}
/// Checks for updates.
#[get("/updates/check")]
pub async fn check_for_update(_token: APIToken) -> Json<CheckUpdateResponse> {
if is_dev() {
warn!(Source = "Updater"; "The app is running in development mode; skipping update check.");
@ -514,7 +516,6 @@ pub struct CheckUpdateResponse {
}
/// Installs the update.
#[get("/updates/install")]
pub async fn install_update(_token: APIToken) {
if is_dev() {
warn!(Source = "Updater"; "The app is running in development mode; skipping update installation.");
@ -623,8 +624,7 @@ fn register_shortcut_with_callback<R: tauri::Runtime>(
}
/// Requests a controlled shutdown of the entire desktop application.
#[post("/app/exit")]
pub fn exit_app(_token: APIToken) -> Json<AppExitResponse> {
pub async fn exit_app(_token: APIToken) -> Json<AppExitResponse> {
let app_handle = {
let main_window_lock = MAIN_WINDOW.lock().unwrap();
match main_window_lock.as_ref() {
@ -653,8 +653,7 @@ pub fn exit_app(_token: APIToken) -> Json<AppExitResponse> {
/// Registers or updates a global shortcut. If the shortcut string is empty,
/// the existing shortcut for that name will be unregistered.
#[post("/shortcuts/register", data = "<payload>")]
pub fn register_shortcut(_token: APIToken, payload: Json<RegisterShortcutRequest>) -> Json<ShortcutResponse> {
pub async fn register_shortcut(_token: APIToken, payload: Json<RegisterShortcutRequest>) -> Json<ShortcutResponse> {
let id = payload.id;
let new_shortcut = payload.shortcut.clone();
@ -761,8 +760,7 @@ pub struct ShortcutValidationResponse {
/// Validates a shortcut string without registering it.
/// Checks if the shortcut syntax is valid and if it
/// conflicts with existing shortcuts.
#[post("/shortcuts/validate", data = "<payload>")]
pub fn validate_shortcut(_token: APIToken, payload: Json<ValidateShortcutRequest>) -> Json<ShortcutValidationResponse> {
pub async fn validate_shortcut(_token: APIToken, payload: Json<ValidateShortcutRequest>) -> Json<ShortcutValidationResponse> {
let shortcut = payload.shortcut.clone();
// Empty shortcuts are always valid (means "disabled"):
@ -816,8 +814,7 @@ pub fn validate_shortcut(_token: APIToken, payload: Json<ValidateShortcutRequest
/// The shortcuts remain in our internal map, so they can be re-registered on resume.
/// This is useful when opening a dialog to configure shortcuts, so the user can
/// press the current shortcut to re-enter it without triggering the action.
#[post("/shortcuts/suspend")]
pub fn suspend_shortcuts(_token: APIToken) -> Json<ShortcutResponse> {
pub async fn suspend_shortcuts(_token: APIToken) -> Json<ShortcutResponse> {
// Get the main window to access the global shortcut manager:
let main_window_lock = MAIN_WINDOW.lock().unwrap();
let main_window = match main_window_lock.as_ref() {
@ -853,8 +850,7 @@ pub fn suspend_shortcuts(_token: APIToken) -> Json<ShortcutResponse> {
}
/// Resumes shortcut processing by re-registering all shortcuts with the OS.
#[post("/shortcuts/resume")]
pub fn resume_shortcuts(_token: APIToken) -> Json<ShortcutResponse> {
pub async fn resume_shortcuts(_token: APIToken) -> Json<ShortcutResponse> {
// Get the main window to access the global shortcut manager:
let main_window_lock = MAIN_WINDOW.lock().unwrap();
let main_window = match main_window_lock.as_ref() {
@ -954,36 +950,6 @@ fn validate_shortcut_syntax(shortcut: &str) -> bool {
has_key
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn tauri_localhost_is_tauri_asset_url() {
let https_url = tauri::Url::parse("https://tauri.localhost/index.html").unwrap();
let http_url = tauri::Url::parse("http://tauri.localhost/index.html").unwrap();
assert!(is_tauri_asset_url(&https_url));
assert!(is_tauri_asset_url(&http_url));
}
#[test]
fn localhost_app_url_is_not_tauri_asset_url() {
let url = tauri::Url::parse("http://localhost:12345/").unwrap();
assert!(!is_tauri_asset_url(&url));
assert!(is_local_http_url(&url));
}
#[test]
fn external_url_is_not_internal_url() {
let url = tauri::Url::parse("https://example.com/").unwrap();
assert!(!is_tauri_asset_url(&url));
assert!(!is_local_http_url(&url));
}
}
fn set_pdfium_path<R: tauri::Runtime>(path_resolver: &PathResolver<R>) {
let resource_dir = match path_resolver.resource_dir() {
Ok(path) => path,
@ -1012,3 +978,33 @@ fn set_pdfium_path<R: tauri::Runtime>(path_resolver: &PathResolver<R>) {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn tauri_localhost_is_tauri_asset_url() {
let https_url = tauri::Url::parse("https://tauri.localhost/index.html").unwrap();
let http_url = tauri::Url::parse("http://tauri.localhost/index.html").unwrap();
assert!(is_tauri_asset_url(&https_url));
assert!(is_tauri_asset_url(&http_url));
}
#[test]
fn localhost_app_url_is_not_tauri_asset_url() {
let url = tauri::Url::parse("http://localhost:12345/").unwrap();
assert!(!is_tauri_asset_url(&url));
assert!(is_local_http_url(&url));
}
#[test]
fn external_url_is_not_internal_url() {
let url = tauri::Url::parse("https://example.com/").unwrap();
assert!(!is_tauri_asset_url(&url));
assert!(!is_local_http_url(&url));
}
}

View File

@ -1,14 +1,13 @@
use arboard::Clipboard;
use log::{debug, error};
use rocket::post;
use rocket::serde::json::Json;
use axum::Json;
use serde::Serialize;
use crate::api_token::APIToken;
use crate::encryption::{EncryptedText, ENCRYPTION};
/// Sets the clipboard text to the provided encrypted text.
#[post("/clipboard/set", data = "<encrypted_text>")]
pub fn set_clipboard(_token: APIToken, encrypted_text: EncryptedText) -> Json<SetClipboardResponse> {
pub async fn set_clipboard(_token: APIToken, encrypted_text: String) -> Json<SetClipboardResponse> {
let encrypted_text = EncryptedText::new(encrypted_text);
// Decrypt this text first:
let decrypted_text = match ENCRYPTION.decrypt(&encrypted_text) {

View File

@ -5,7 +5,6 @@ use base64::Engine;
use base64::prelude::BASE64_STANDARD;
use log::{error, info, warn};
use once_cell::sync::Lazy;
use rocket::get;
use tauri::Url;
use tauri_plugin_shell::process::{CommandChild, CommandEvent};
use tauri_plugin_shell::ShellExt;
@ -89,8 +88,7 @@ fn sanitize_stdout_line(line: &str) -> String {
/// Returns the desired port of the .NET server. Our .NET app calls this endpoint to get
/// the port where the .NET server should listen to.
#[get("/system/dotnet/port")]
pub fn dotnet_port(_token: APIToken) -> String {
pub async fn dotnet_port(_token: APIToken) -> String {
let dotnet_server_port = *DOTNET_SERVER_PORT;
format!("{dotnet_server_port}")
}
@ -179,7 +177,6 @@ pub fn start_dotnet_server<R: tauri::Runtime>(app_handle: tauri::AppHandle<R>) {
}
/// This endpoint is called by the .NET server to signal that the server is ready.
#[get("/system/dotnet/ready")]
pub async fn dotnet_ready(_token: APIToken) {
// We create a manual scope for the lock to be released as soon as possible.

View File

@ -9,19 +9,13 @@ use once_cell::sync::Lazy;
use pbkdf2::pbkdf2;
use rand::rngs::SysRng;
use rand::{Rng, SeedableRng};
use rocket::{data, Data, Request};
use rocket::data::ToByteUnit;
use rocket::http::Status;
use rocket::serde::{Deserialize, Serialize};
use serde::{Deserialize, Serialize};
use sha2::Sha512;
use tokio::io::AsyncReadExt;
type Aes256CbcEnc = cbc::Encryptor<aes::Aes256>;
type Aes256CbcDec = cbc::Decryptor<aes::Aes256>;
type DataOutcome<'r, T> = data::Outcome<'r, T>;
/// The encryption instance used for the IPC channel.
pub static ENCRYPTION: Lazy<Encryption> = Lazy::new(|| {
//
@ -170,27 +164,4 @@ impl fmt::Display for EncryptedText {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "**********")
}
}
/// Use Case: When we receive encrypted text from the client as body (e.g., in a POST request).
/// We must interpret the body as EncryptedText.
#[rocket::async_trait]
impl<'r> data::FromData<'r> for EncryptedText {
type Error = String;
/// Parses the data as EncryptedText.
async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> DataOutcome<'r, Self> {
let content_type = req.content_type();
if content_type.map_or(true, |ct| !ct.is_text()) {
return DataOutcome::Forward((data, Status::Ok));
}
let mut stream = data.open(2.mebibytes());
let mut body = String::new();
if let Err(e) = stream.read_to_string(&mut body).await {
return DataOutcome::Error((Status::InternalServerError, format!("Failed to read data: {}", e)));
}
DataOutcome::Success(EncryptedText(body))
}
}

View File

@ -1,7 +1,6 @@
use crate::api_token::APIToken;
use axum::Json;
use log::{debug, info, warn};
use rocket::get;
use rocket::serde::json::Json;
use serde::Serialize;
use std::collections::{HashMap, HashSet};
use std::env;
@ -29,8 +28,7 @@ pub static CONFIG_DIRECTORY: OnceLock<String> = OnceLock::new();
static USER_LANGUAGE: OnceLock<String> = OnceLock::new();
/// Returns the config directory.
#[get("/system/directories/config")]
pub fn get_config_directory(_token: APIToken) -> String {
pub async fn get_config_directory(_token: APIToken) -> String {
match CONFIG_DIRECTORY.get() {
Some(config_directory) => config_directory.clone(),
None => String::from(""),
@ -38,8 +36,7 @@ pub fn get_config_directory(_token: APIToken) -> String {
}
/// Returns the data directory.
#[get("/system/directories/data")]
pub fn get_data_directory(_token: APIToken) -> String {
pub async fn get_data_directory(_token: APIToken) -> String {
match DATA_DIRECTORY.get() {
Some(data_directory) => data_directory.clone(),
None => String::from(""),
@ -150,8 +147,7 @@ fn detect_user_language() -> (String, LanguageDetectionSource) {
)
}
#[get("/system/language")]
pub fn read_user_language(_token: APIToken) -> String {
pub async fn read_user_language(_token: APIToken) -> String {
USER_LANGUAGE
.get_or_init(|| {
let (user_language, source) = detect_user_language();
@ -194,8 +190,7 @@ struct EnterpriseSourceData {
encryption_secret: String,
}
#[get("/system/enterprise/config/id")]
pub fn read_enterprise_env_config_id(_token: APIToken) -> String {
pub async fn read_enterprise_env_config_id(_token: APIToken) -> String {
debug!("Trying to read the effective enterprise configuration ID.");
resolve_effective_enterprise_config_source()
.configs
@ -205,8 +200,7 @@ pub fn read_enterprise_env_config_id(_token: APIToken) -> String {
.unwrap_or_default()
}
#[get("/system/enterprise/config/server")]
pub fn read_enterprise_env_config_server_url(_token: APIToken) -> String {
pub async fn read_enterprise_env_config_server_url(_token: APIToken) -> String {
debug!("Trying to read the effective enterprise configuration server URL.");
resolve_effective_enterprise_config_source()
.configs
@ -216,15 +210,13 @@ pub fn read_enterprise_env_config_server_url(_token: APIToken) -> String {
.unwrap_or_default()
}
#[get("/system/enterprise/config/encryption_secret")]
pub fn read_enterprise_env_config_encryption_secret(_token: APIToken) -> String {
pub async fn read_enterprise_env_config_encryption_secret(_token: APIToken) -> String {
debug!("Trying to read the effective enterprise configuration encryption secret.");
resolve_effective_enterprise_secret_source().encryption_secret
}
/// Returns all enterprise configurations from the effective source.
#[get("/system/enterprise/configs")]
pub fn read_enterprise_configs(_token: APIToken) -> Json<Vec<EnterpriseConfig>> {
pub async fn read_enterprise_configs(_token: APIToken) -> Json<Vec<EnterpriseConfig>> {
info!("Trying to read the effective enterprise configurations.");
Json(resolve_effective_enterprise_config_source().configs)
}

View File

@ -1,7 +1,7 @@
use log::{error, info};
use rocket::post;
use rocket::serde::{Deserialize, Serialize};
use rocket::serde::json::Json;
use axum::extract::Query;
use axum::Json;
use serde::{Deserialize, Serialize};
use tauri_plugin_dialog::{DialogExt, FileDialogBuilder};
use crate::api_token::APIToken;
use crate::app_window::MAIN_WINDOW;
@ -11,6 +11,11 @@ pub struct PreviousDirectory {
path: String,
}
#[derive(Deserialize)]
pub struct SelectDirectoryQuery {
title: String,
}
#[derive(Clone, Deserialize)]
pub struct FileTypeFilter {
filter_name: String,
@ -61,10 +66,9 @@ pub struct PreviousFile {
}
/// Let the user select a directory.
#[post("/select/directory?<title>", data = "<previous_directory>")]
pub fn select_directory(
pub async fn select_directory(
_token: APIToken,
title: &str,
Query(query): Query<SelectDirectoryQuery>,
previous_directory: Option<Json<PreviousDirectory>>,
) -> Json<DirectorySelectionResponse> {
let main_window_lock = MAIN_WINDOW.lock().unwrap();
@ -79,7 +83,7 @@ pub fn select_directory(
}
};
let mut dialog = main_window.dialog().file().set_parent(main_window).set_title(title);
let mut dialog = main_window.dialog().file().set_parent(main_window).set_title(&query.title);
if let Some(previous) = previous_directory {
dialog = dialog.set_directory(previous.path.clone());
}
@ -118,8 +122,7 @@ pub fn select_directory(
}
/// Let the user select a file.
#[post("/select/file", data = "<payload>")]
pub fn select_file(
pub async fn select_file(
_token: APIToken,
payload: Json<SelectFileOptions>,
) -> Json<FileSelectionResponse> {
@ -178,8 +181,7 @@ pub fn select_file(
}
/// Let the user select some files.
#[post("/select/files", data = "<payload>")]
pub fn select_files(
pub async fn select_files(
_token: APIToken,
payload: Json<SelectFileOptions>,
) -> Json<FilesSelectionResponse> {
@ -229,8 +231,7 @@ pub fn select_files(
}
}
#[post("/save/file", data = "<payload>")]
pub fn save_file(_token: APIToken, payload: Json<SaveFileOptions>) -> Json<FileSaveResponse> {
pub async fn save_file(_token: APIToken, payload: Json<SaveFileOptions>) -> Json<FileSaveResponse> {
// Create a new file dialog builder:
let file_dialog = MAIN_WINDOW
.lock()

View File

@ -1,19 +1,18 @@
use std::cmp::min;
use std::convert::Infallible;
use crate::api_token::APIToken;
use crate::pandoc::PandocProcessBuilder;
use crate::pdfium::PdfiumInit;
use async_stream::stream;
use axum::extract::Query;
use axum::response::sse::{Event, Sse};
use base64::{engine::general_purpose, Engine as _};
use calamine::{open_workbook_auto, Reader};
use file_format::{FileFormat, Kind};
use futures::{Stream, StreamExt};
use pdfium_render::prelude::Pdfium;
use pptx_to_md::{ImageHandlingMode, ParserConfig, PptxContainer};
use rocket::get;
use rocket::response::stream::{Event, EventStream};
use rocket::serde::Serialize;
use rocket::tokio::select;
use rocket::Shutdown;
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::pin::Pin;
use log::{debug, error};
@ -82,39 +81,45 @@ const IMAGE_SEGMENT_SIZE_IN_CHARS: usize = 8_192; // equivalent to ~ 5500 token
type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
type ChunkStream = Pin<Box<dyn Stream<Item = Result<Chunk>> + Send>>;
#[get("/retrieval/fs/extract?<path>&<stream_id>&<extract_images>")]
pub async fn extract_data(_token: APIToken, path: String, stream_id: String, extract_images: bool, mut end: Shutdown) -> EventStream![] {
EventStream! {
let stream_result = stream_data(&path, extract_images).await;
let id_ref = &stream_id;
#[derive(Deserialize)]
pub struct ExtractDataQuery {
path: String,
stream_id: String,
extract_images: bool,
}
pub async fn extract_data(
_token: APIToken,
Query(query): Query<ExtractDataQuery>,
) -> Sse<impl Stream<Item = std::result::Result<Event, Infallible>>> {
let stream = stream! {
let stream_result = stream_data(&query.path, query.extract_images).await;
let id_ref = &query.stream_id;
match stream_result {
Ok(mut stream) => {
loop {
let chunk = select! {
chunk = stream.next() => match chunk {
Some(Ok(mut chunk)) => {
chunk.set_stream_id(id_ref);
chunk
},
Some(Err(e)) => {
yield Event::json(&format!("Error: {e}"));
break;
},
None => break,
while let Some(chunk) = stream.next().await {
match chunk {
Ok(mut chunk) => {
chunk.set_stream_id(id_ref);
yield Ok(Event::default().json_data(&chunk).unwrap_or_else(|e| Event::default().data(format!("Error: {e}"))));
},
_ = &mut end => break,
};
yield Event::json(&chunk);
Err(e) => {
yield Ok(Event::default().json_data(format!("Error: {e}")).unwrap_or_else(|_| Event::default().data(format!("Error: {e}"))));
break;
},
}
}
},
Err(e) => {
yield Event::json(&format!("Error starting stream: {e}"));
yield Ok(Event::default().json_data(format!("Error starting stream: {e}")).unwrap_or_else(|_| Event::default().data(format!("Error starting stream: {e}"))));
}
}
}
};
Sse::new(stream)
}
async fn stream_data(file_path: &str, extract_images: bool) -> Result<ChunkStream> {

View File

@ -8,9 +8,8 @@ use flexi_logger::{DeferredNow, Duplicate, FileSpec, Logger, LoggerHandle};
use flexi_logger::writers::FileLogWriter;
use log::{kv, Level};
use log::kv::{Key, Value, VisitSource};
use rocket::{get, post};
use rocket::serde::json::Json;
use rocket::serde::{Deserialize, Serialize};
use axum::Json;
use serde::{Deserialize, Serialize};
use crate::api_token::APIToken;
use crate::environment::is_dev;
@ -34,14 +33,17 @@ pub fn init_logging() {
false => log_config.push_str("info, "),
};
// Set the log level for the Rocket library:
log_config.push_str("rocket=info, ");
// Set the log level for the Rocket server:
log_config.push_str("rocket::server=warn, ");
// Set the log level for the Reqwest library:
log_config.push_str("reqwest::async_impl::client=info");
// Keep noisy HTTP/TLS internals at info level even in development builds:
log_config.push_str("h2=info, ");
log_config.push_str("hyper=info, ");
log_config.push_str("hyper_util=info, ");
log_config.push_str("axum=info, ");
log_config.push_str("axum_server=info, ");
log_config.push_str("tower=info, ");
log_config.push_str("tower_http=info, ");
log_config.push_str("rustls=info, ");
log_config.push_str("tokio_rustls=info, ");
log_config.push_str("reqwest=info");
// Configure the initial filename. On Unix systems, the file should start
// with a dot to be hidden.
@ -224,7 +226,6 @@ fn file_logger_format(
write!(w, "{}", &record.args())
}
#[get("/log/paths")]
pub async fn get_log_paths(_token: APIToken) -> Json<LogPathsResponse> {
Json(LogPathsResponse {
log_startup_path: LOG_STARTUP_PATH.get().expect("No startup log path was set").clone(),
@ -269,9 +270,7 @@ fn log_with_level(
}
/// Logs an event from the .NET server.
#[post("/log/event", data = "<event>")]
pub fn log_event(_token: APIToken, event: Json<LogEvent>) -> Json<LogEventResponse> {
let event = event.into_inner();
pub async fn log_event(_token: APIToken, Json(event): Json<LogEvent>) -> Json<LogEventResponse> {
let level = parse_dotnet_log_level(&event.level);
let message = event.message.as_str();
let category = event.category.as_str();

View File

@ -1,7 +1,6 @@
// Prevents an additional console window on Windows in release, DO NOT REMOVE!!
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]
extern crate rocket;
extern crate core;
use log::{info, warn};
@ -12,7 +11,6 @@ use mindwork_ai_studio::log::init_logging;
use mindwork_ai_studio::metadata::MetaData;
use mindwork_ai_studio::runtime_api::start_runtime_api;
#[tokio::main]
async fn main() {
let metadata = MetaData::init_from_string(include_str!("../../metadata.txt"));

View File

@ -7,9 +7,8 @@ use std::path::Path;
use std::sync::{Arc, Mutex, OnceLock};
use log::{debug, error, info, warn};
use once_cell::sync::Lazy;
use rocket::get;
use rocket::serde::json::Json;
use rocket::serde::Serialize;
use axum::Json;
use serde::Serialize;
use crate::api_token::{APIToken};
use crate::environment::{is_dev, DATA_DIRECTORY};
use crate::certificate_factory::generate_certificate;
@ -70,8 +69,7 @@ pub struct ProvideQdrantInfo {
unavailable_reason: Option<String>,
}
#[get("/system/qdrant/info")]
pub fn qdrant_port(_token: APIToken) -> Json<ProvideQdrantInfo> {
pub async fn qdrant_port(_token: APIToken) -> Json<ProvideQdrantInfo> {
let status = QDRANT_STATUS.lock().unwrap();
let is_available = status.is_available;
let unavailable_reason = status.unavailable_reason.clone();

View File

@ -1,12 +1,16 @@
use log::info;
use once_cell::sync::Lazy;
use rocket::config::Shutdown;
use rocket::figment::Figment;
use rocket::routes;
use axum::routing::{get, post};
use axum::Router;
use axum_server::tls_rustls::RustlsConfig;
use std::net::SocketAddr;
use std::sync::Once;
use crate::runtime_certificate::{CERTIFICATE, CERTIFICATE_PRIVATE_KEY};
use crate::environment::is_dev;
use crate::network::get_available_port;
static RUSTLS_CRYPTO_PROVIDER_INIT: Once = Once::new();
/// The port used for the runtime API server. In the development environment, we use a fixed
/// port, in the production environment we use the next available port. This differentiation
/// is necessary because we cannot communicate the port to the .NET server in the development
@ -24,109 +28,55 @@ pub static API_SERVER_PORT: Lazy<u16> = Lazy::new(|| {
pub fn start_runtime_api() {
let api_port = *API_SERVER_PORT;
info!("Try to start the API server on 'http://localhost:{api_port}'...");
// Get the shutdown configuration:
let shutdown = create_shutdown();
// Configure the runtime API server:
let figment = Figment::from(rocket::Config::release_default())
let app = Router::new()
.route("/system/dotnet/port", get(crate::dotnet::dotnet_port))
.route("/system/dotnet/ready", get(crate::dotnet::dotnet_ready))
.route("/system/qdrant/info", get(crate::qdrant::qdrant_port))
.route("/clipboard/set", post(crate::clipboard::set_clipboard))
.route("/events", get(crate::app_window::get_event_stream))
.route("/updates/check", get(crate::app_window::check_for_update))
.route("/updates/install", get(crate::app_window::install_update))
.route("/app/exit", post(crate::app_window::exit_app))
.route("/select/directory", post(crate::file_actions::select_directory))
.route("/select/file", post(crate::file_actions::select_file))
.route("/select/files", post(crate::file_actions::select_files))
.route("/save/file", post(crate::file_actions::save_file))
.route("/secrets/get", post(crate::secret::get_secret))
.route("/secrets/store", post(crate::secret::store_secret))
.route("/secrets/delete", post(crate::secret::delete_secret))
.route("/system/directories/config", get(crate::environment::get_config_directory))
.route("/system/directories/data", get(crate::environment::get_data_directory))
.route("/system/language", get(crate::environment::read_user_language))
.route("/system/enterprise/config/id", get(crate::environment::read_enterprise_env_config_id))
.route("/system/enterprise/config/server", get(crate::environment::read_enterprise_env_config_server_url))
.route("/system/enterprise/config/encryption_secret", get(crate::environment::read_enterprise_env_config_encryption_secret))
.route("/system/enterprise/configs", get(crate::environment::read_enterprise_configs))
.route("/retrieval/fs/extract", get(crate::file_data::extract_data))
.route("/log/paths", get(crate::log::get_log_paths))
.route("/log/event", post(crate::log::log_event))
.route("/shortcuts/register", post(crate::app_window::register_shortcut))
.route("/shortcuts/validate", post(crate::app_window::validate_shortcut))
.route("/shortcuts/suspend", post(crate::app_window::suspend_shortcuts))
.route("/shortcuts/resume", post(crate::app_window::resume_shortcuts));
// We use the next available port which was determined before:
.merge(("port", api_port))
// The runtime API server should be accessible only from the local machine:
.merge(("address", "127.0.0.1"))
// We do not want to use the Ctrl+C signal to stop the server:
.merge(("ctrlc", false))
// Set a name for the server:
.merge(("ident", "AI Studio Runtime API"))
// Set the maximum number of workers and blocking threads:
.merge(("workers", 3))
.merge(("max_blocking", 12))
// No colors and emojis in the log output:
.merge(("cli_colors", false))
// Read the TLS certificate and key from the generated certificate data in-memory:
.merge(("tls.certs", CERTIFICATE.get().unwrap()))
.merge(("tls.key", CERTIFICATE_PRIVATE_KEY.get().unwrap()))
// Set the shutdown configuration:
.merge(("shutdown", shutdown));
//
// Start the runtime API server in a separate thread. This is necessary
// because the server is blocking, and we need to run the Tauri app in
// parallel:
//
tauri::async_runtime::spawn(async move {
rocket::custom(figment)
.mount("/", routes![
crate::dotnet::dotnet_port,
crate::dotnet::dotnet_ready,
crate::qdrant::qdrant_port,
crate::clipboard::set_clipboard,
crate::app_window::get_event_stream,
crate::app_window::check_for_update,
crate::app_window::install_update,
crate::app_window::exit_app,
crate::file_actions::select_directory,
crate::file_actions::select_file,
crate::file_actions::select_files,
crate::file_actions::save_file,
crate::secret::get_secret,
crate::secret::store_secret,
crate::secret::delete_secret,
crate::environment::get_data_directory,
crate::environment::get_config_directory,
crate::environment::read_user_language,
crate::environment::read_enterprise_env_config_id,
crate::environment::read_enterprise_env_config_server_url,
crate::environment::read_enterprise_env_config_encryption_secret,
crate::environment::read_enterprise_configs,
crate::file_data::extract_data,
crate::log::get_log_paths,
crate::log::log_event,
crate::app_window::register_shortcut,
crate::app_window::validate_shortcut,
crate::app_window::suspend_shortcuts,
crate::app_window::resume_shortcuts,
])
.ignite().await.unwrap()
.launch().await.unwrap();
install_rustls_crypto_provider();
let cert = CERTIFICATE.get().unwrap().clone();
let key = CERTIFICATE_PRIVATE_KEY.get().unwrap().clone();
let tls_config = RustlsConfig::from_pem(cert, key).await.unwrap();
let addr = SocketAddr::from(([127, 0, 0, 1], api_port));
axum_server::bind_rustls(addr, tls_config)
.serve(app.into_make_service())
.await
.unwrap();
});
}
fn create_shutdown() -> Shutdown {
//
// Create a shutdown configuration, depending on the operating system:
//
#[cfg(unix)]
{
use std::collections::HashSet;
let mut shutdown = Shutdown {
// We do not want to use the Ctrl+C signal to stop the server:
ctrlc: false,
// Everything else is set to default for now:
..Shutdown::default()
};
shutdown.signals = HashSet::new();
shutdown
}
#[cfg(windows)]
{
Shutdown {
// We do not want to use the Ctrl+C signal to stop the server:
ctrlc: false,
// Everything else is set to default for now:
..Shutdown::default()
}
}
fn install_rustls_crypto_provider() {
RUSTLS_CRYPTO_PROVIDER_INIT.call_once(|| {
let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
});
}

View File

@ -1,33 +1,29 @@
use once_cell::sync::Lazy;
use rocket::http::Status;
use rocket::Request;
use rocket::request::FromRequest;
use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use axum::http::StatusCode;
use crate::api_token::{generate_api_token, APIToken};
pub static API_TOKEN: Lazy<APIToken> = Lazy::new(|| generate_api_token());
pub static API_TOKEN: Lazy<APIToken> = Lazy::new(generate_api_token);
/// The request outcome type used to handle API token requests.
type RequestOutcome<R, T> = rocket::request::Outcome<R, T>;
impl<S> FromRequestParts<S> for APIToken
where
S: Send + Sync,
{
type Rejection = StatusCode;
/// The request outcome implementation for the API token.
#[rocket::async_trait]
impl<'r> FromRequest<'r> for APIToken {
type Error = APITokenError;
/// Handles the API token requests.
async fn from_request(request: &'r Request<'_>) -> RequestOutcome<Self, Self::Error> {
let token = request.headers().get_one("token");
match token {
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
match parts.headers.get("token").and_then(|value| value.to_str().ok()) {
Some(token) => {
let received_token = APIToken::from_hex_text(token);
if API_TOKEN.validate(&received_token) {
RequestOutcome::Success(received_token)
Ok(received_token)
} else {
RequestOutcome::Error((Status::Unauthorized, APITokenError::Invalid))
Err(StatusCode::UNAUTHORIZED)
}
}
None => RequestOutcome::Error((Status::Unauthorized, APITokenError::Missing)),
None => Err(StatusCode::UNAUTHORIZED),
}
}
}

View File

@ -1,15 +1,13 @@
use keyring::Entry;
use log::{error, info, warn};
use rocket::post;
use rocket::serde::json::Json;
use axum::Json;
use serde::{Deserialize, Serialize};
use keyring::error::Error::NoEntry;
use crate::api_token::APIToken;
use crate::encryption::{EncryptedText, ENCRYPTION};
/// Stores a secret in the secret store using the operating system's keyring.
#[post("/secrets/store", data = "<request>")]
pub fn store_secret(_token: APIToken, request: Json<StoreSecret>) -> Json<StoreSecretResponse> {
pub async fn store_secret(_token: APIToken, request: Json<StoreSecret>) -> Json<StoreSecretResponse> {
let user_name = request.user_name.as_str();
let decrypted_text = match ENCRYPTION.decrypt(&request.secret) {
Ok(text) => text,
@ -60,8 +58,7 @@ pub struct StoreSecretResponse {
}
/// Retrieves a secret from the secret store using the operating system's keyring.
#[post("/secrets/get", data = "<request>")]
pub fn get_secret(_token: APIToken, request: Json<RequestSecret>) -> Json<RequestedSecret> {
pub async fn get_secret(_token: APIToken, request: Json<RequestSecret>) -> Json<RequestedSecret> {
let user_name = request.user_name.as_str();
let service = format!("mindwork-ai-studio::{}", request.destination);
let entry = Entry::new(service.as_str(), user_name).unwrap();
@ -121,8 +118,7 @@ pub struct RequestedSecret {
}
/// Deletes a secret from the secret store using the operating system's keyring.
#[post("/secrets/delete", data = "<request>")]
pub fn delete_secret(_token: APIToken, request: Json<RequestSecret>) -> Json<DeleteSecretResponse> {
pub async fn delete_secret(_token: APIToken, request: Json<RequestSecret>) -> Json<DeleteSecretResponse> {
let user_name = request.user_name.as_str();
let service = format!("mindwork-ai-studio::{}", request.destination);
let entry = Entry::new(service.as_str(), user_name).unwrap();