Migrated IPC server to use axum

This commit is contained in:
Thorsten Sommer 2026-05-12 19:17:49 +02:00
parent f69186f7a9
commit 978b261c13
Signed by untrusted user who does not match committer: tsommer
GPG Key ID: 371BBA77A02C0108
15 changed files with 387 additions and 914 deletions

794
runtime/Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -25,7 +25,9 @@ async-stream = "0.3.6"
flexi_logger = "0.31.8" flexi_logger = "0.31.8"
log = { version = "0.4.29", features = ["kv"] } log = { version = "0.4.29", features = ["kv"] }
once_cell = "1.21.4" once_cell = "1.21.4"
rocket = { version = "0.5.1", features = ["json", "tls"] } axum = { version = "0.8.9", features = ["http2", "json", "query", "tokio"] }
axum-server = { version = "0.8.0", features = ["tls-rustls"] }
rustls = { version = "0.23.28", default-features = false, features = ["aws_lc_rs"] }
rand = "0.10.1" rand = "0.10.1"
rand_chacha = "0.10.0" rand_chacha = "0.10.0"
base64 = "0.22.1" base64 = "0.22.1"
@ -46,7 +48,6 @@ strum_macros = "0.28.0"
sysinfo = "0.38.4" sysinfo = "0.38.4"
# Fixes security vulnerability downstream, where the upstream is not fixed yet: # Fixes security vulnerability downstream, where the upstream is not fixed yet:
time = "0.3.47" # -> Rocket
bytes = "1.11.1" # -> almost every dependency bytes = "1.11.1" # -> almost every dependency
[target.'cfg(target_os = "linux")'.dependencies] [target.'cfg(target_os = "linux")'.dependencies]

View File

@ -1,13 +1,16 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::convert::Infallible;
use std::sync::Mutex; use std::sync::Mutex;
use std::time::Duration; use std::time::Duration;
use async_stream::stream;
use axum::body::Body;
use axum::http::header::CONTENT_TYPE;
use axum::response::{IntoResponse, Response};
use axum::Json;
use bytes::Bytes;
use log::{debug, error, info, trace, warn}; use log::{debug, error, info, trace, warn};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use rocket::{get, post}; use serde::{Deserialize, Serialize};
use rocket::response::stream::TextStream;
use rocket::serde::json::Json;
use rocket::serde::Serialize;
use serde::Deserialize;
use strum_macros::Display; use strum_macros::Display;
use tauri::{DragDropEvent,RunEvent, Manager, WindowEvent, generate_context}; use tauri::{DragDropEvent,RunEvent, Manager, WindowEvent, generate_context};
use tauri::path::PathResolver; use tauri::path::PathResolver;
@ -256,8 +259,7 @@ fn should_open_in_system_browser<R: tauri::Runtime>(webview: &tauri::Webview<R>,
/// When the client disconnects, the stream is closed. But we try to not lose events in between. /// When the client disconnects, the stream is closed. But we try to not lose events in between.
/// The client is expected to reconnect automatically when the connection is closed and continue /// The client is expected to reconnect automatically when the connection is closed and continue
/// listening for events. /// listening for events.
#[get("/events")] pub async fn get_event_stream(_token: APIToken) -> Response {
pub async fn get_event_stream(_token: APIToken) -> TextStream![String] {
// Get the lock to the event broadcast sender: // Get the lock to the event broadcast sender:
let event_broadcast_lock = EVENT_BROADCAST.lock().unwrap(); let event_broadcast_lock = EVENT_BROADCAST.lock().unwrap();
@ -269,8 +271,7 @@ pub async fn get_event_stream(_token: APIToken) -> TextStream![String] {
// Drop the lock to allow other access to the sender: // Drop the lock to allow other access to the sender:
drop(event_broadcast_lock); drop(event_broadcast_lock);
// Create the event stream: let stream = stream! {
TextStream! {
loop { loop {
// Wait at most 3 seconds for an event: // Wait at most 3 seconds for an event:
match time::timeout(Duration::from_secs(3), event_receiver.recv()).await { match time::timeout(Duration::from_secs(3), event_receiver.recv()).await {
@ -281,11 +282,11 @@ pub async fn get_event_stream(_token: APIToken) -> TextStream![String] {
// is serialized as a single line so that the client can parse it // is serialized as a single line so that the client can parse it
// correctly: // correctly:
let event_json = serde_json::to_string(&event).unwrap(); let event_json = serde_json::to_string(&event).unwrap();
yield event_json; yield Ok::<Bytes, Infallible>(Bytes::from(event_json));
// The client expects a newline after each event because we are using // The client expects a newline after each event because we are using
// a method to read the stream line-by-line: // a method to read the stream line-by-line:
yield "\n".to_string(); yield Ok::<Bytes, Infallible>(Bytes::from("\n"));
}, },
// Case: we lagged behind and missed some events // Case: we lagged behind and missed some events
@ -305,15 +306,17 @@ pub async fn get_event_stream(_token: APIToken) -> TextStream![String] {
// Again, we have to serialize the event as a single line: // Again, we have to serialize the event as a single line:
let event_json = serde_json::to_string(&ping_event).unwrap(); let event_json = serde_json::to_string(&ping_event).unwrap();
yield event_json; yield Ok::<Bytes, Infallible>(Bytes::from(event_json));
// The client expects a newline after each event because we are using // The client expects a newline after each event because we are using
// a method to read the stream line-by-line: // a method to read the stream line-by-line:
yield "\n".to_string(); yield Ok::<Bytes, Infallible>(Bytes::from("\n"));
}, },
} }
} }
} };
([(CONTENT_TYPE, "application/jsonl")], Body::from_stream(stream)).into_response()
} }
/// Data structure representing a Tauri event for our event API. /// Data structure representing a Tauri event for our event API.
@ -428,7 +431,6 @@ pub async fn change_location_to(url: &str) {
} }
/// Checks for updates. /// Checks for updates.
#[get("/updates/check")]
pub async fn check_for_update(_token: APIToken) -> Json<CheckUpdateResponse> { pub async fn check_for_update(_token: APIToken) -> Json<CheckUpdateResponse> {
if is_dev() { if is_dev() {
warn!(Source = "Updater"; "The app is running in development mode; skipping update check."); warn!(Source = "Updater"; "The app is running in development mode; skipping update check.");
@ -514,7 +516,6 @@ pub struct CheckUpdateResponse {
} }
/// Installs the update. /// Installs the update.
#[get("/updates/install")]
pub async fn install_update(_token: APIToken) { pub async fn install_update(_token: APIToken) {
if is_dev() { if is_dev() {
warn!(Source = "Updater"; "The app is running in development mode; skipping update installation."); warn!(Source = "Updater"; "The app is running in development mode; skipping update installation.");
@ -623,8 +624,7 @@ fn register_shortcut_with_callback<R: tauri::Runtime>(
} }
/// Requests a controlled shutdown of the entire desktop application. /// Requests a controlled shutdown of the entire desktop application.
#[post("/app/exit")] pub async fn exit_app(_token: APIToken) -> Json<AppExitResponse> {
pub fn exit_app(_token: APIToken) -> Json<AppExitResponse> {
let app_handle = { let app_handle = {
let main_window_lock = MAIN_WINDOW.lock().unwrap(); let main_window_lock = MAIN_WINDOW.lock().unwrap();
match main_window_lock.as_ref() { match main_window_lock.as_ref() {
@ -653,8 +653,7 @@ pub fn exit_app(_token: APIToken) -> Json<AppExitResponse> {
/// Registers or updates a global shortcut. If the shortcut string is empty, /// Registers or updates a global shortcut. If the shortcut string is empty,
/// the existing shortcut for that name will be unregistered. /// the existing shortcut for that name will be unregistered.
#[post("/shortcuts/register", data = "<payload>")] pub async fn register_shortcut(_token: APIToken, payload: Json<RegisterShortcutRequest>) -> Json<ShortcutResponse> {
pub fn register_shortcut(_token: APIToken, payload: Json<RegisterShortcutRequest>) -> Json<ShortcutResponse> {
let id = payload.id; let id = payload.id;
let new_shortcut = payload.shortcut.clone(); let new_shortcut = payload.shortcut.clone();
@ -761,8 +760,7 @@ pub struct ShortcutValidationResponse {
/// Validates a shortcut string without registering it. /// Validates a shortcut string without registering it.
/// Checks if the shortcut syntax is valid and if it /// Checks if the shortcut syntax is valid and if it
/// conflicts with existing shortcuts. /// conflicts with existing shortcuts.
#[post("/shortcuts/validate", data = "<payload>")] pub async fn validate_shortcut(_token: APIToken, payload: Json<ValidateShortcutRequest>) -> Json<ShortcutValidationResponse> {
pub fn validate_shortcut(_token: APIToken, payload: Json<ValidateShortcutRequest>) -> Json<ShortcutValidationResponse> {
let shortcut = payload.shortcut.clone(); let shortcut = payload.shortcut.clone();
// Empty shortcuts are always valid (means "disabled"): // Empty shortcuts are always valid (means "disabled"):
@ -816,8 +814,7 @@ pub fn validate_shortcut(_token: APIToken, payload: Json<ValidateShortcutRequest
/// The shortcuts remain in our internal map, so they can be re-registered on resume. /// The shortcuts remain in our internal map, so they can be re-registered on resume.
/// This is useful when opening a dialog to configure shortcuts, so the user can /// This is useful when opening a dialog to configure shortcuts, so the user can
/// press the current shortcut to re-enter it without triggering the action. /// press the current shortcut to re-enter it without triggering the action.
#[post("/shortcuts/suspend")] pub async fn suspend_shortcuts(_token: APIToken) -> Json<ShortcutResponse> {
pub fn suspend_shortcuts(_token: APIToken) -> Json<ShortcutResponse> {
// Get the main window to access the global shortcut manager: // Get the main window to access the global shortcut manager:
let main_window_lock = MAIN_WINDOW.lock().unwrap(); let main_window_lock = MAIN_WINDOW.lock().unwrap();
let main_window = match main_window_lock.as_ref() { let main_window = match main_window_lock.as_ref() {
@ -853,8 +850,7 @@ pub fn suspend_shortcuts(_token: APIToken) -> Json<ShortcutResponse> {
} }
/// Resumes shortcut processing by re-registering all shortcuts with the OS. /// Resumes shortcut processing by re-registering all shortcuts with the OS.
#[post("/shortcuts/resume")] pub async fn resume_shortcuts(_token: APIToken) -> Json<ShortcutResponse> {
pub fn resume_shortcuts(_token: APIToken) -> Json<ShortcutResponse> {
// Get the main window to access the global shortcut manager: // Get the main window to access the global shortcut manager:
let main_window_lock = MAIN_WINDOW.lock().unwrap(); let main_window_lock = MAIN_WINDOW.lock().unwrap();
let main_window = match main_window_lock.as_ref() { let main_window = match main_window_lock.as_ref() {
@ -954,36 +950,6 @@ fn validate_shortcut_syntax(shortcut: &str) -> bool {
has_key has_key
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn tauri_localhost_is_tauri_asset_url() {
let https_url = tauri::Url::parse("https://tauri.localhost/index.html").unwrap();
let http_url = tauri::Url::parse("http://tauri.localhost/index.html").unwrap();
assert!(is_tauri_asset_url(&https_url));
assert!(is_tauri_asset_url(&http_url));
}
#[test]
fn localhost_app_url_is_not_tauri_asset_url() {
let url = tauri::Url::parse("http://localhost:12345/").unwrap();
assert!(!is_tauri_asset_url(&url));
assert!(is_local_http_url(&url));
}
#[test]
fn external_url_is_not_internal_url() {
let url = tauri::Url::parse("https://example.com/").unwrap();
assert!(!is_tauri_asset_url(&url));
assert!(!is_local_http_url(&url));
}
}
fn set_pdfium_path<R: tauri::Runtime>(path_resolver: &PathResolver<R>) { fn set_pdfium_path<R: tauri::Runtime>(path_resolver: &PathResolver<R>) {
let resource_dir = match path_resolver.resource_dir() { let resource_dir = match path_resolver.resource_dir() {
Ok(path) => path, Ok(path) => path,
@ -1012,3 +978,33 @@ fn set_pdfium_path<R: tauri::Runtime>(path_resolver: &PathResolver<R>) {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn tauri_localhost_is_tauri_asset_url() {
let https_url = tauri::Url::parse("https://tauri.localhost/index.html").unwrap();
let http_url = tauri::Url::parse("http://tauri.localhost/index.html").unwrap();
assert!(is_tauri_asset_url(&https_url));
assert!(is_tauri_asset_url(&http_url));
}
#[test]
fn localhost_app_url_is_not_tauri_asset_url() {
let url = tauri::Url::parse("http://localhost:12345/").unwrap();
assert!(!is_tauri_asset_url(&url));
assert!(is_local_http_url(&url));
}
#[test]
fn external_url_is_not_internal_url() {
let url = tauri::Url::parse("https://example.com/").unwrap();
assert!(!is_tauri_asset_url(&url));
assert!(!is_local_http_url(&url));
}
}

View File

@ -1,14 +1,13 @@
use arboard::Clipboard; use arboard::Clipboard;
use log::{debug, error}; use log::{debug, error};
use rocket::post; use axum::Json;
use rocket::serde::json::Json;
use serde::Serialize; use serde::Serialize;
use crate::api_token::APIToken; use crate::api_token::APIToken;
use crate::encryption::{EncryptedText, ENCRYPTION}; use crate::encryption::{EncryptedText, ENCRYPTION};
/// Sets the clipboard text to the provided encrypted text. /// Sets the clipboard text to the provided encrypted text.
#[post("/clipboard/set", data = "<encrypted_text>")] pub async fn set_clipboard(_token: APIToken, encrypted_text: String) -> Json<SetClipboardResponse> {
pub fn set_clipboard(_token: APIToken, encrypted_text: EncryptedText) -> Json<SetClipboardResponse> { let encrypted_text = EncryptedText::new(encrypted_text);
// Decrypt this text first: // Decrypt this text first:
let decrypted_text = match ENCRYPTION.decrypt(&encrypted_text) { let decrypted_text = match ENCRYPTION.decrypt(&encrypted_text) {

View File

@ -5,7 +5,6 @@ use base64::Engine;
use base64::prelude::BASE64_STANDARD; use base64::prelude::BASE64_STANDARD;
use log::{error, info, warn}; use log::{error, info, warn};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use rocket::get;
use tauri::Url; use tauri::Url;
use tauri_plugin_shell::process::{CommandChild, CommandEvent}; use tauri_plugin_shell::process::{CommandChild, CommandEvent};
use tauri_plugin_shell::ShellExt; use tauri_plugin_shell::ShellExt;
@ -89,8 +88,7 @@ fn sanitize_stdout_line(line: &str) -> String {
/// Returns the desired port of the .NET server. Our .NET app calls this endpoint to get /// Returns the desired port of the .NET server. Our .NET app calls this endpoint to get
/// the port where the .NET server should listen to. /// the port where the .NET server should listen to.
#[get("/system/dotnet/port")] pub async fn dotnet_port(_token: APIToken) -> String {
pub fn dotnet_port(_token: APIToken) -> String {
let dotnet_server_port = *DOTNET_SERVER_PORT; let dotnet_server_port = *DOTNET_SERVER_PORT;
format!("{dotnet_server_port}") format!("{dotnet_server_port}")
} }
@ -179,7 +177,6 @@ pub fn start_dotnet_server<R: tauri::Runtime>(app_handle: tauri::AppHandle<R>) {
} }
/// This endpoint is called by the .NET server to signal that the server is ready. /// This endpoint is called by the .NET server to signal that the server is ready.
#[get("/system/dotnet/ready")]
pub async fn dotnet_ready(_token: APIToken) { pub async fn dotnet_ready(_token: APIToken) {
// We create a manual scope for the lock to be released as soon as possible. // We create a manual scope for the lock to be released as soon as possible.

View File

@ -9,19 +9,13 @@ use once_cell::sync::Lazy;
use pbkdf2::pbkdf2; use pbkdf2::pbkdf2;
use rand::rngs::SysRng; use rand::rngs::SysRng;
use rand::{Rng, SeedableRng}; use rand::{Rng, SeedableRng};
use rocket::{data, Data, Request}; use serde::{Deserialize, Serialize};
use rocket::data::ToByteUnit;
use rocket::http::Status;
use rocket::serde::{Deserialize, Serialize};
use sha2::Sha512; use sha2::Sha512;
use tokio::io::AsyncReadExt;
type Aes256CbcEnc = cbc::Encryptor<aes::Aes256>; type Aes256CbcEnc = cbc::Encryptor<aes::Aes256>;
type Aes256CbcDec = cbc::Decryptor<aes::Aes256>; type Aes256CbcDec = cbc::Decryptor<aes::Aes256>;
type DataOutcome<'r, T> = data::Outcome<'r, T>;
/// The encryption instance used for the IPC channel. /// The encryption instance used for the IPC channel.
pub static ENCRYPTION: Lazy<Encryption> = Lazy::new(|| { pub static ENCRYPTION: Lazy<Encryption> = Lazy::new(|| {
// //
@ -170,27 +164,4 @@ impl fmt::Display for EncryptedText {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "**********") write!(f, "**********")
} }
}
/// Use Case: When we receive encrypted text from the client as body (e.g., in a POST request).
/// We must interpret the body as EncryptedText.
#[rocket::async_trait]
impl<'r> data::FromData<'r> for EncryptedText {
type Error = String;
/// Parses the data as EncryptedText.
async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> DataOutcome<'r, Self> {
let content_type = req.content_type();
if content_type.map_or(true, |ct| !ct.is_text()) {
return DataOutcome::Forward((data, Status::Ok));
}
let mut stream = data.open(2.mebibytes());
let mut body = String::new();
if let Err(e) = stream.read_to_string(&mut body).await {
return DataOutcome::Error((Status::InternalServerError, format!("Failed to read data: {}", e)));
}
DataOutcome::Success(EncryptedText(body))
}
} }

View File

@ -1,7 +1,6 @@
use crate::api_token::APIToken; use crate::api_token::APIToken;
use axum::Json;
use log::{debug, info, warn}; use log::{debug, info, warn};
use rocket::get;
use rocket::serde::json::Json;
use serde::Serialize; use serde::Serialize;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::env; use std::env;
@ -29,8 +28,7 @@ pub static CONFIG_DIRECTORY: OnceLock<String> = OnceLock::new();
static USER_LANGUAGE: OnceLock<String> = OnceLock::new(); static USER_LANGUAGE: OnceLock<String> = OnceLock::new();
/// Returns the config directory. /// Returns the config directory.
#[get("/system/directories/config")] pub async fn get_config_directory(_token: APIToken) -> String {
pub fn get_config_directory(_token: APIToken) -> String {
match CONFIG_DIRECTORY.get() { match CONFIG_DIRECTORY.get() {
Some(config_directory) => config_directory.clone(), Some(config_directory) => config_directory.clone(),
None => String::from(""), None => String::from(""),
@ -38,8 +36,7 @@ pub fn get_config_directory(_token: APIToken) -> String {
} }
/// Returns the data directory. /// Returns the data directory.
#[get("/system/directories/data")] pub async fn get_data_directory(_token: APIToken) -> String {
pub fn get_data_directory(_token: APIToken) -> String {
match DATA_DIRECTORY.get() { match DATA_DIRECTORY.get() {
Some(data_directory) => data_directory.clone(), Some(data_directory) => data_directory.clone(),
None => String::from(""), None => String::from(""),
@ -150,8 +147,7 @@ fn detect_user_language() -> (String, LanguageDetectionSource) {
) )
} }
#[get("/system/language")] pub async fn read_user_language(_token: APIToken) -> String {
pub fn read_user_language(_token: APIToken) -> String {
USER_LANGUAGE USER_LANGUAGE
.get_or_init(|| { .get_or_init(|| {
let (user_language, source) = detect_user_language(); let (user_language, source) = detect_user_language();
@ -194,8 +190,7 @@ struct EnterpriseSourceData {
encryption_secret: String, encryption_secret: String,
} }
#[get("/system/enterprise/config/id")] pub async fn read_enterprise_env_config_id(_token: APIToken) -> String {
pub fn read_enterprise_env_config_id(_token: APIToken) -> String {
debug!("Trying to read the effective enterprise configuration ID."); debug!("Trying to read the effective enterprise configuration ID.");
resolve_effective_enterprise_config_source() resolve_effective_enterprise_config_source()
.configs .configs
@ -205,8 +200,7 @@ pub fn read_enterprise_env_config_id(_token: APIToken) -> String {
.unwrap_or_default() .unwrap_or_default()
} }
#[get("/system/enterprise/config/server")] pub async fn read_enterprise_env_config_server_url(_token: APIToken) -> String {
pub fn read_enterprise_env_config_server_url(_token: APIToken) -> String {
debug!("Trying to read the effective enterprise configuration server URL."); debug!("Trying to read the effective enterprise configuration server URL.");
resolve_effective_enterprise_config_source() resolve_effective_enterprise_config_source()
.configs .configs
@ -216,15 +210,13 @@ pub fn read_enterprise_env_config_server_url(_token: APIToken) -> String {
.unwrap_or_default() .unwrap_or_default()
} }
#[get("/system/enterprise/config/encryption_secret")] pub async fn read_enterprise_env_config_encryption_secret(_token: APIToken) -> String {
pub fn read_enterprise_env_config_encryption_secret(_token: APIToken) -> String {
debug!("Trying to read the effective enterprise configuration encryption secret."); debug!("Trying to read the effective enterprise configuration encryption secret.");
resolve_effective_enterprise_secret_source().encryption_secret resolve_effective_enterprise_secret_source().encryption_secret
} }
/// Returns all enterprise configurations from the effective source. /// Returns all enterprise configurations from the effective source.
#[get("/system/enterprise/configs")] pub async fn read_enterprise_configs(_token: APIToken) -> Json<Vec<EnterpriseConfig>> {
pub fn read_enterprise_configs(_token: APIToken) -> Json<Vec<EnterpriseConfig>> {
info!("Trying to read the effective enterprise configurations."); info!("Trying to read the effective enterprise configurations.");
Json(resolve_effective_enterprise_config_source().configs) Json(resolve_effective_enterprise_config_source().configs)
} }

View File

@ -1,7 +1,7 @@
use log::{error, info}; use log::{error, info};
use rocket::post; use axum::extract::Query;
use rocket::serde::{Deserialize, Serialize}; use axum::Json;
use rocket::serde::json::Json; use serde::{Deserialize, Serialize};
use tauri_plugin_dialog::{DialogExt, FileDialogBuilder}; use tauri_plugin_dialog::{DialogExt, FileDialogBuilder};
use crate::api_token::APIToken; use crate::api_token::APIToken;
use crate::app_window::MAIN_WINDOW; use crate::app_window::MAIN_WINDOW;
@ -11,6 +11,11 @@ pub struct PreviousDirectory {
path: String, path: String,
} }
#[derive(Deserialize)]
pub struct SelectDirectoryQuery {
title: String,
}
#[derive(Clone, Deserialize)] #[derive(Clone, Deserialize)]
pub struct FileTypeFilter { pub struct FileTypeFilter {
filter_name: String, filter_name: String,
@ -61,10 +66,9 @@ pub struct PreviousFile {
} }
/// Let the user select a directory. /// Let the user select a directory.
#[post("/select/directory?<title>", data = "<previous_directory>")] pub async fn select_directory(
pub fn select_directory(
_token: APIToken, _token: APIToken,
title: &str, Query(query): Query<SelectDirectoryQuery>,
previous_directory: Option<Json<PreviousDirectory>>, previous_directory: Option<Json<PreviousDirectory>>,
) -> Json<DirectorySelectionResponse> { ) -> Json<DirectorySelectionResponse> {
let main_window_lock = MAIN_WINDOW.lock().unwrap(); let main_window_lock = MAIN_WINDOW.lock().unwrap();
@ -79,7 +83,7 @@ pub fn select_directory(
} }
}; };
let mut dialog = main_window.dialog().file().set_parent(main_window).set_title(title); let mut dialog = main_window.dialog().file().set_parent(main_window).set_title(&query.title);
if let Some(previous) = previous_directory { if let Some(previous) = previous_directory {
dialog = dialog.set_directory(previous.path.clone()); dialog = dialog.set_directory(previous.path.clone());
} }
@ -118,8 +122,7 @@ pub fn select_directory(
} }
/// Let the user select a file. /// Let the user select a file.
#[post("/select/file", data = "<payload>")] pub async fn select_file(
pub fn select_file(
_token: APIToken, _token: APIToken,
payload: Json<SelectFileOptions>, payload: Json<SelectFileOptions>,
) -> Json<FileSelectionResponse> { ) -> Json<FileSelectionResponse> {
@ -178,8 +181,7 @@ pub fn select_file(
} }
/// Let the user select some files. /// Let the user select some files.
#[post("/select/files", data = "<payload>")] pub async fn select_files(
pub fn select_files(
_token: APIToken, _token: APIToken,
payload: Json<SelectFileOptions>, payload: Json<SelectFileOptions>,
) -> Json<FilesSelectionResponse> { ) -> Json<FilesSelectionResponse> {
@ -229,8 +231,7 @@ pub fn select_files(
} }
} }
#[post("/save/file", data = "<payload>")] pub async fn save_file(_token: APIToken, payload: Json<SaveFileOptions>) -> Json<FileSaveResponse> {
pub fn save_file(_token: APIToken, payload: Json<SaveFileOptions>) -> Json<FileSaveResponse> {
// Create a new file dialog builder: // Create a new file dialog builder:
let file_dialog = MAIN_WINDOW let file_dialog = MAIN_WINDOW
.lock() .lock()

View File

@ -1,19 +1,18 @@
use std::cmp::min; use std::cmp::min;
use std::convert::Infallible;
use crate::api_token::APIToken; use crate::api_token::APIToken;
use crate::pandoc::PandocProcessBuilder; use crate::pandoc::PandocProcessBuilder;
use crate::pdfium::PdfiumInit; use crate::pdfium::PdfiumInit;
use async_stream::stream; use async_stream::stream;
use axum::extract::Query;
use axum::response::sse::{Event, Sse};
use base64::{engine::general_purpose, Engine as _}; use base64::{engine::general_purpose, Engine as _};
use calamine::{open_workbook_auto, Reader}; use calamine::{open_workbook_auto, Reader};
use file_format::{FileFormat, Kind}; use file_format::{FileFormat, Kind};
use futures::{Stream, StreamExt}; use futures::{Stream, StreamExt};
use pdfium_render::prelude::Pdfium; use pdfium_render::prelude::Pdfium;
use pptx_to_md::{ImageHandlingMode, ParserConfig, PptxContainer}; use pptx_to_md::{ImageHandlingMode, ParserConfig, PptxContainer};
use rocket::get; use serde::{Deserialize, Serialize};
use rocket::response::stream::{Event, EventStream};
use rocket::serde::Serialize;
use rocket::tokio::select;
use rocket::Shutdown;
use std::path::Path; use std::path::Path;
use std::pin::Pin; use std::pin::Pin;
use log::{debug, error}; use log::{debug, error};
@ -82,39 +81,45 @@ const IMAGE_SEGMENT_SIZE_IN_CHARS: usize = 8_192; // equivalent to ~ 5500 token
type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>; type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
type ChunkStream = Pin<Box<dyn Stream<Item = Result<Chunk>> + Send>>; type ChunkStream = Pin<Box<dyn Stream<Item = Result<Chunk>> + Send>>;
#[get("/retrieval/fs/extract?<path>&<stream_id>&<extract_images>")] #[derive(Deserialize)]
pub async fn extract_data(_token: APIToken, path: String, stream_id: String, extract_images: bool, mut end: Shutdown) -> EventStream![] { pub struct ExtractDataQuery {
EventStream! { path: String,
let stream_result = stream_data(&path, extract_images).await; stream_id: String,
let id_ref = &stream_id; extract_images: bool,
}
pub async fn extract_data(
_token: APIToken,
Query(query): Query<ExtractDataQuery>,
) -> Sse<impl Stream<Item = std::result::Result<Event, Infallible>>> {
let stream = stream! {
let stream_result = stream_data(&query.path, query.extract_images).await;
let id_ref = &query.stream_id;
match stream_result { match stream_result {
Ok(mut stream) => { Ok(mut stream) => {
loop { while let Some(chunk) = stream.next().await {
let chunk = select! { match chunk {
chunk = stream.next() => match chunk { Ok(mut chunk) => {
Some(Ok(mut chunk)) => { chunk.set_stream_id(id_ref);
chunk.set_stream_id(id_ref); yield Ok(Event::default().json_data(&chunk).unwrap_or_else(|e| Event::default().data(format!("Error: {e}"))));
chunk
},
Some(Err(e)) => {
yield Event::json(&format!("Error: {e}"));
break;
},
None => break,
}, },
_ = &mut end => break,
}; Err(e) => {
yield Ok(Event::default().json_data(format!("Error: {e}")).unwrap_or_else(|_| Event::default().data(format!("Error: {e}"))));
yield Event::json(&chunk); break;
},
}
} }
}, },
Err(e) => { Err(e) => {
yield Event::json(&format!("Error starting stream: {e}")); yield Ok(Event::default().json_data(format!("Error starting stream: {e}")).unwrap_or_else(|_| Event::default().data(format!("Error starting stream: {e}"))));
} }
} }
} };
Sse::new(stream)
} }
async fn stream_data(file_path: &str, extract_images: bool) -> Result<ChunkStream> { async fn stream_data(file_path: &str, extract_images: bool) -> Result<ChunkStream> {

View File

@ -8,9 +8,8 @@ use flexi_logger::{DeferredNow, Duplicate, FileSpec, Logger, LoggerHandle};
use flexi_logger::writers::FileLogWriter; use flexi_logger::writers::FileLogWriter;
use log::{kv, Level}; use log::{kv, Level};
use log::kv::{Key, Value, VisitSource}; use log::kv::{Key, Value, VisitSource};
use rocket::{get, post}; use axum::Json;
use rocket::serde::json::Json; use serde::{Deserialize, Serialize};
use rocket::serde::{Deserialize, Serialize};
use crate::api_token::APIToken; use crate::api_token::APIToken;
use crate::environment::is_dev; use crate::environment::is_dev;
@ -34,14 +33,17 @@ pub fn init_logging() {
false => log_config.push_str("info, "), false => log_config.push_str("info, "),
}; };
// Set the log level for the Rocket library: // Keep noisy HTTP/TLS internals at info level even in development builds:
log_config.push_str("rocket=info, "); log_config.push_str("h2=info, ");
log_config.push_str("hyper=info, ");
// Set the log level for the Rocket server: log_config.push_str("hyper_util=info, ");
log_config.push_str("rocket::server=warn, "); log_config.push_str("axum=info, ");
log_config.push_str("axum_server=info, ");
// Set the log level for the Reqwest library: log_config.push_str("tower=info, ");
log_config.push_str("reqwest::async_impl::client=info"); log_config.push_str("tower_http=info, ");
log_config.push_str("rustls=info, ");
log_config.push_str("tokio_rustls=info, ");
log_config.push_str("reqwest=info");
// Configure the initial filename. On Unix systems, the file should start // Configure the initial filename. On Unix systems, the file should start
// with a dot to be hidden. // with a dot to be hidden.
@ -224,7 +226,6 @@ fn file_logger_format(
write!(w, "{}", &record.args()) write!(w, "{}", &record.args())
} }
#[get("/log/paths")]
pub async fn get_log_paths(_token: APIToken) -> Json<LogPathsResponse> { pub async fn get_log_paths(_token: APIToken) -> Json<LogPathsResponse> {
Json(LogPathsResponse { Json(LogPathsResponse {
log_startup_path: LOG_STARTUP_PATH.get().expect("No startup log path was set").clone(), log_startup_path: LOG_STARTUP_PATH.get().expect("No startup log path was set").clone(),
@ -269,9 +270,7 @@ fn log_with_level(
} }
/// Logs an event from the .NET server. /// Logs an event from the .NET server.
#[post("/log/event", data = "<event>")] pub async fn log_event(_token: APIToken, Json(event): Json<LogEvent>) -> Json<LogEventResponse> {
pub fn log_event(_token: APIToken, event: Json<LogEvent>) -> Json<LogEventResponse> {
let event = event.into_inner();
let level = parse_dotnet_log_level(&event.level); let level = parse_dotnet_log_level(&event.level);
let message = event.message.as_str(); let message = event.message.as_str();
let category = event.category.as_str(); let category = event.category.as_str();

View File

@ -1,7 +1,6 @@
// Prevents an additional console window on Windows in release, DO NOT REMOVE!! // Prevents an additional console window on Windows in release, DO NOT REMOVE!!
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] #![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]
extern crate rocket;
extern crate core; extern crate core;
use log::{info, warn}; use log::{info, warn};
@ -12,7 +11,6 @@ use mindwork_ai_studio::log::init_logging;
use mindwork_ai_studio::metadata::MetaData; use mindwork_ai_studio::metadata::MetaData;
use mindwork_ai_studio::runtime_api::start_runtime_api; use mindwork_ai_studio::runtime_api::start_runtime_api;
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
let metadata = MetaData::init_from_string(include_str!("../../metadata.txt")); let metadata = MetaData::init_from_string(include_str!("../../metadata.txt"));

View File

@ -7,9 +7,8 @@ use std::path::Path;
use std::sync::{Arc, Mutex, OnceLock}; use std::sync::{Arc, Mutex, OnceLock};
use log::{debug, error, info, warn}; use log::{debug, error, info, warn};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use rocket::get; use axum::Json;
use rocket::serde::json::Json; use serde::Serialize;
use rocket::serde::Serialize;
use crate::api_token::{APIToken}; use crate::api_token::{APIToken};
use crate::environment::{is_dev, DATA_DIRECTORY}; use crate::environment::{is_dev, DATA_DIRECTORY};
use crate::certificate_factory::generate_certificate; use crate::certificate_factory::generate_certificate;
@ -70,8 +69,7 @@ pub struct ProvideQdrantInfo {
unavailable_reason: Option<String>, unavailable_reason: Option<String>,
} }
#[get("/system/qdrant/info")] pub async fn qdrant_port(_token: APIToken) -> Json<ProvideQdrantInfo> {
pub fn qdrant_port(_token: APIToken) -> Json<ProvideQdrantInfo> {
let status = QDRANT_STATUS.lock().unwrap(); let status = QDRANT_STATUS.lock().unwrap();
let is_available = status.is_available; let is_available = status.is_available;
let unavailable_reason = status.unavailable_reason.clone(); let unavailable_reason = status.unavailable_reason.clone();

View File

@ -1,12 +1,16 @@
use log::info; use log::info;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use rocket::config::Shutdown; use axum::routing::{get, post};
use rocket::figment::Figment; use axum::Router;
use rocket::routes; use axum_server::tls_rustls::RustlsConfig;
use std::net::SocketAddr;
use std::sync::Once;
use crate::runtime_certificate::{CERTIFICATE, CERTIFICATE_PRIVATE_KEY}; use crate::runtime_certificate::{CERTIFICATE, CERTIFICATE_PRIVATE_KEY};
use crate::environment::is_dev; use crate::environment::is_dev;
use crate::network::get_available_port; use crate::network::get_available_port;
static RUSTLS_CRYPTO_PROVIDER_INIT: Once = Once::new();
/// The port used for the runtime API server. In the development environment, we use a fixed /// The port used for the runtime API server. In the development environment, we use a fixed
/// port, in the production environment we use the next available port. This differentiation /// port, in the production environment we use the next available port. This differentiation
/// is necessary because we cannot communicate the port to the .NET server in the development /// is necessary because we cannot communicate the port to the .NET server in the development
@ -24,109 +28,55 @@ pub static API_SERVER_PORT: Lazy<u16> = Lazy::new(|| {
pub fn start_runtime_api() { pub fn start_runtime_api() {
let api_port = *API_SERVER_PORT; let api_port = *API_SERVER_PORT;
info!("Try to start the API server on 'http://localhost:{api_port}'..."); info!("Try to start the API server on 'http://localhost:{api_port}'...");
// Get the shutdown configuration:
let shutdown = create_shutdown();
// Configure the runtime API server: let app = Router::new()
let figment = Figment::from(rocket::Config::release_default()) .route("/system/dotnet/port", get(crate::dotnet::dotnet_port))
.route("/system/dotnet/ready", get(crate::dotnet::dotnet_ready))
.route("/system/qdrant/info", get(crate::qdrant::qdrant_port))
.route("/clipboard/set", post(crate::clipboard::set_clipboard))
.route("/events", get(crate::app_window::get_event_stream))
.route("/updates/check", get(crate::app_window::check_for_update))
.route("/updates/install", get(crate::app_window::install_update))
.route("/app/exit", post(crate::app_window::exit_app))
.route("/select/directory", post(crate::file_actions::select_directory))
.route("/select/file", post(crate::file_actions::select_file))
.route("/select/files", post(crate::file_actions::select_files))
.route("/save/file", post(crate::file_actions::save_file))
.route("/secrets/get", post(crate::secret::get_secret))
.route("/secrets/store", post(crate::secret::store_secret))
.route("/secrets/delete", post(crate::secret::delete_secret))
.route("/system/directories/config", get(crate::environment::get_config_directory))
.route("/system/directories/data", get(crate::environment::get_data_directory))
.route("/system/language", get(crate::environment::read_user_language))
.route("/system/enterprise/config/id", get(crate::environment::read_enterprise_env_config_id))
.route("/system/enterprise/config/server", get(crate::environment::read_enterprise_env_config_server_url))
.route("/system/enterprise/config/encryption_secret", get(crate::environment::read_enterprise_env_config_encryption_secret))
.route("/system/enterprise/configs", get(crate::environment::read_enterprise_configs))
.route("/retrieval/fs/extract", get(crate::file_data::extract_data))
.route("/log/paths", get(crate::log::get_log_paths))
.route("/log/event", post(crate::log::log_event))
.route("/shortcuts/register", post(crate::app_window::register_shortcut))
.route("/shortcuts/validate", post(crate::app_window::validate_shortcut))
.route("/shortcuts/suspend", post(crate::app_window::suspend_shortcuts))
.route("/shortcuts/resume", post(crate::app_window::resume_shortcuts));
// We use the next available port which was determined before:
.merge(("port", api_port))
// The runtime API server should be accessible only from the local machine:
.merge(("address", "127.0.0.1"))
// We do not want to use the Ctrl+C signal to stop the server:
.merge(("ctrlc", false))
// Set a name for the server:
.merge(("ident", "AI Studio Runtime API"))
// Set the maximum number of workers and blocking threads:
.merge(("workers", 3))
.merge(("max_blocking", 12))
// No colors and emojis in the log output:
.merge(("cli_colors", false))
// Read the TLS certificate and key from the generated certificate data in-memory:
.merge(("tls.certs", CERTIFICATE.get().unwrap()))
.merge(("tls.key", CERTIFICATE_PRIVATE_KEY.get().unwrap()))
// Set the shutdown configuration:
.merge(("shutdown", shutdown));
//
// Start the runtime API server in a separate thread. This is necessary
// because the server is blocking, and we need to run the Tauri app in
// parallel:
//
tauri::async_runtime::spawn(async move { tauri::async_runtime::spawn(async move {
rocket::custom(figment) install_rustls_crypto_provider();
.mount("/", routes![
crate::dotnet::dotnet_port, let cert = CERTIFICATE.get().unwrap().clone();
crate::dotnet::dotnet_ready, let key = CERTIFICATE_PRIVATE_KEY.get().unwrap().clone();
crate::qdrant::qdrant_port, let tls_config = RustlsConfig::from_pem(cert, key).await.unwrap();
crate::clipboard::set_clipboard, let addr = SocketAddr::from(([127, 0, 0, 1], api_port));
crate::app_window::get_event_stream,
crate::app_window::check_for_update, axum_server::bind_rustls(addr, tls_config)
crate::app_window::install_update, .serve(app.into_make_service())
crate::app_window::exit_app, .await
crate::file_actions::select_directory, .unwrap();
crate::file_actions::select_file,
crate::file_actions::select_files,
crate::file_actions::save_file,
crate::secret::get_secret,
crate::secret::store_secret,
crate::secret::delete_secret,
crate::environment::get_data_directory,
crate::environment::get_config_directory,
crate::environment::read_user_language,
crate::environment::read_enterprise_env_config_id,
crate::environment::read_enterprise_env_config_server_url,
crate::environment::read_enterprise_env_config_encryption_secret,
crate::environment::read_enterprise_configs,
crate::file_data::extract_data,
crate::log::get_log_paths,
crate::log::log_event,
crate::app_window::register_shortcut,
crate::app_window::validate_shortcut,
crate::app_window::suspend_shortcuts,
crate::app_window::resume_shortcuts,
])
.ignite().await.unwrap()
.launch().await.unwrap();
}); });
} }
fn create_shutdown() -> Shutdown { fn install_rustls_crypto_provider() {
// RUSTLS_CRYPTO_PROVIDER_INIT.call_once(|| {
// Create a shutdown configuration, depending on the operating system: let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
// });
#[cfg(unix)]
{
use std::collections::HashSet;
let mut shutdown = Shutdown {
// We do not want to use the Ctrl+C signal to stop the server:
ctrlc: false,
// Everything else is set to default for now:
..Shutdown::default()
};
shutdown.signals = HashSet::new();
shutdown
}
#[cfg(windows)]
{
Shutdown {
// We do not want to use the Ctrl+C signal to stop the server:
ctrlc: false,
// Everything else is set to default for now:
..Shutdown::default()
}
}
} }

View File

@ -1,33 +1,29 @@
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use rocket::http::Status; use axum::extract::FromRequestParts;
use rocket::Request; use axum::http::request::Parts;
use rocket::request::FromRequest; use axum::http::StatusCode;
use crate::api_token::{generate_api_token, APIToken}; use crate::api_token::{generate_api_token, APIToken};
pub static API_TOKEN: Lazy<APIToken> = Lazy::new(|| generate_api_token()); pub static API_TOKEN: Lazy<APIToken> = Lazy::new(generate_api_token);
/// The request outcome type used to handle API token requests. impl<S> FromRequestParts<S> for APIToken
type RequestOutcome<R, T> = rocket::request::Outcome<R, T>; where
S: Send + Sync,
{
type Rejection = StatusCode;
/// The request outcome implementation for the API token. async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
#[rocket::async_trait] match parts.headers.get("token").and_then(|value| value.to_str().ok()) {
impl<'r> FromRequest<'r> for APIToken {
type Error = APITokenError;
/// Handles the API token requests.
async fn from_request(request: &'r Request<'_>) -> RequestOutcome<Self, Self::Error> {
let token = request.headers().get_one("token");
match token {
Some(token) => { Some(token) => {
let received_token = APIToken::from_hex_text(token); let received_token = APIToken::from_hex_text(token);
if API_TOKEN.validate(&received_token) { if API_TOKEN.validate(&received_token) {
RequestOutcome::Success(received_token) Ok(received_token)
} else { } else {
RequestOutcome::Error((Status::Unauthorized, APITokenError::Invalid)) Err(StatusCode::UNAUTHORIZED)
} }
} }
None => RequestOutcome::Error((Status::Unauthorized, APITokenError::Missing)), None => Err(StatusCode::UNAUTHORIZED),
} }
} }
} }

View File

@ -1,15 +1,13 @@
use keyring::Entry; use keyring::Entry;
use log::{error, info, warn}; use log::{error, info, warn};
use rocket::post; use axum::Json;
use rocket::serde::json::Json;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use keyring::error::Error::NoEntry; use keyring::error::Error::NoEntry;
use crate::api_token::APIToken; use crate::api_token::APIToken;
use crate::encryption::{EncryptedText, ENCRYPTION}; use crate::encryption::{EncryptedText, ENCRYPTION};
/// Stores a secret in the secret store using the operating system's keyring. /// Stores a secret in the secret store using the operating system's keyring.
#[post("/secrets/store", data = "<request>")] pub async fn store_secret(_token: APIToken, request: Json<StoreSecret>) -> Json<StoreSecretResponse> {
pub fn store_secret(_token: APIToken, request: Json<StoreSecret>) -> Json<StoreSecretResponse> {
let user_name = request.user_name.as_str(); let user_name = request.user_name.as_str();
let decrypted_text = match ENCRYPTION.decrypt(&request.secret) { let decrypted_text = match ENCRYPTION.decrypt(&request.secret) {
Ok(text) => text, Ok(text) => text,
@ -60,8 +58,7 @@ pub struct StoreSecretResponse {
} }
/// Retrieves a secret from the secret store using the operating system's keyring. /// Retrieves a secret from the secret store using the operating system's keyring.
#[post("/secrets/get", data = "<request>")] pub async fn get_secret(_token: APIToken, request: Json<RequestSecret>) -> Json<RequestedSecret> {
pub fn get_secret(_token: APIToken, request: Json<RequestSecret>) -> Json<RequestedSecret> {
let user_name = request.user_name.as_str(); let user_name = request.user_name.as_str();
let service = format!("mindwork-ai-studio::{}", request.destination); let service = format!("mindwork-ai-studio::{}", request.destination);
let entry = Entry::new(service.as_str(), user_name).unwrap(); let entry = Entry::new(service.as_str(), user_name).unwrap();
@ -121,8 +118,7 @@ pub struct RequestedSecret {
} }
/// Deletes a secret from the secret store using the operating system's keyring. /// Deletes a secret from the secret store using the operating system's keyring.
#[post("/secrets/delete", data = "<request>")] pub async fn delete_secret(_token: APIToken, request: Json<RequestSecret>) -> Json<DeleteSecretResponse> {
pub fn delete_secret(_token: APIToken, request: Json<RequestSecret>) -> Json<DeleteSecretResponse> {
let user_name = request.user_name.as_str(); let user_name = request.user_name.as_str();
let service = format!("mindwork-ai-studio::{}", request.destination); let service = format!("mindwork-ai-studio::{}", request.destination);
let entry = Entry::new(service.as_str(), user_name).unwrap(); let entry = Entry::new(service.as_str(), user_name).unwrap();