mirror of
https://github.com/MindWorkAI/AI-Studio.git
synced 2026-05-13 09:14:12 +00:00
Migrated IPC server to use axum
This commit is contained in:
parent
f69186f7a9
commit
978b261c13
794
runtime/Cargo.lock
generated
794
runtime/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -25,7 +25,9 @@ async-stream = "0.3.6"
|
||||
flexi_logger = "0.31.8"
|
||||
log = { version = "0.4.29", features = ["kv"] }
|
||||
once_cell = "1.21.4"
|
||||
rocket = { version = "0.5.1", features = ["json", "tls"] }
|
||||
axum = { version = "0.8.9", features = ["http2", "json", "query", "tokio"] }
|
||||
axum-server = { version = "0.8.0", features = ["tls-rustls"] }
|
||||
rustls = { version = "0.23.28", default-features = false, features = ["aws_lc_rs"] }
|
||||
rand = "0.10.1"
|
||||
rand_chacha = "0.10.0"
|
||||
base64 = "0.22.1"
|
||||
@ -46,7 +48,6 @@ strum_macros = "0.28.0"
|
||||
sysinfo = "0.38.4"
|
||||
|
||||
# Fixes security vulnerability downstream, where the upstream is not fixed yet:
|
||||
time = "0.3.47" # -> Rocket
|
||||
bytes = "1.11.1" # -> almost every dependency
|
||||
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
|
||||
@ -1,13 +1,16 @@
|
||||
use std::collections::HashMap;
|
||||
use std::convert::Infallible;
|
||||
use std::sync::Mutex;
|
||||
use std::time::Duration;
|
||||
use async_stream::stream;
|
||||
use axum::body::Body;
|
||||
use axum::http::header::CONTENT_TYPE;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::Json;
|
||||
use bytes::Bytes;
|
||||
use log::{debug, error, info, trace, warn};
|
||||
use once_cell::sync::Lazy;
|
||||
use rocket::{get, post};
|
||||
use rocket::response::stream::TextStream;
|
||||
use rocket::serde::json::Json;
|
||||
use rocket::serde::Serialize;
|
||||
use serde::Deserialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use strum_macros::Display;
|
||||
use tauri::{DragDropEvent,RunEvent, Manager, WindowEvent, generate_context};
|
||||
use tauri::path::PathResolver;
|
||||
@ -256,8 +259,7 @@ fn should_open_in_system_browser<R: tauri::Runtime>(webview: &tauri::Webview<R>,
|
||||
/// When the client disconnects, the stream is closed. But we try to not lose events in between.
|
||||
/// The client is expected to reconnect automatically when the connection is closed and continue
|
||||
/// listening for events.
|
||||
#[get("/events")]
|
||||
pub async fn get_event_stream(_token: APIToken) -> TextStream![String] {
|
||||
pub async fn get_event_stream(_token: APIToken) -> Response {
|
||||
// Get the lock to the event broadcast sender:
|
||||
let event_broadcast_lock = EVENT_BROADCAST.lock().unwrap();
|
||||
|
||||
@ -269,8 +271,7 @@ pub async fn get_event_stream(_token: APIToken) -> TextStream![String] {
|
||||
// Drop the lock to allow other access to the sender:
|
||||
drop(event_broadcast_lock);
|
||||
|
||||
// Create the event stream:
|
||||
TextStream! {
|
||||
let stream = stream! {
|
||||
loop {
|
||||
// Wait at most 3 seconds for an event:
|
||||
match time::timeout(Duration::from_secs(3), event_receiver.recv()).await {
|
||||
@ -281,11 +282,11 @@ pub async fn get_event_stream(_token: APIToken) -> TextStream![String] {
|
||||
// is serialized as a single line so that the client can parse it
|
||||
// correctly:
|
||||
let event_json = serde_json::to_string(&event).unwrap();
|
||||
yield event_json;
|
||||
yield Ok::<Bytes, Infallible>(Bytes::from(event_json));
|
||||
|
||||
// The client expects a newline after each event because we are using
|
||||
// a method to read the stream line-by-line:
|
||||
yield "\n".to_string();
|
||||
yield Ok::<Bytes, Infallible>(Bytes::from("\n"));
|
||||
},
|
||||
|
||||
// Case: we lagged behind and missed some events
|
||||
@ -305,15 +306,17 @@ pub async fn get_event_stream(_token: APIToken) -> TextStream![String] {
|
||||
|
||||
// Again, we have to serialize the event as a single line:
|
||||
let event_json = serde_json::to_string(&ping_event).unwrap();
|
||||
yield event_json;
|
||||
yield Ok::<Bytes, Infallible>(Bytes::from(event_json));
|
||||
|
||||
// The client expects a newline after each event because we are using
|
||||
// a method to read the stream line-by-line:
|
||||
yield "\n".to_string();
|
||||
yield Ok::<Bytes, Infallible>(Bytes::from("\n"));
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
([(CONTENT_TYPE, "application/jsonl")], Body::from_stream(stream)).into_response()
|
||||
}
|
||||
|
||||
/// Data structure representing a Tauri event for our event API.
|
||||
@ -428,7 +431,6 @@ pub async fn change_location_to(url: &str) {
|
||||
}
|
||||
|
||||
/// Checks for updates.
|
||||
#[get("/updates/check")]
|
||||
pub async fn check_for_update(_token: APIToken) -> Json<CheckUpdateResponse> {
|
||||
if is_dev() {
|
||||
warn!(Source = "Updater"; "The app is running in development mode; skipping update check.");
|
||||
@ -514,7 +516,6 @@ pub struct CheckUpdateResponse {
|
||||
}
|
||||
|
||||
/// Installs the update.
|
||||
#[get("/updates/install")]
|
||||
pub async fn install_update(_token: APIToken) {
|
||||
if is_dev() {
|
||||
warn!(Source = "Updater"; "The app is running in development mode; skipping update installation.");
|
||||
@ -623,8 +624,7 @@ fn register_shortcut_with_callback<R: tauri::Runtime>(
|
||||
}
|
||||
|
||||
/// Requests a controlled shutdown of the entire desktop application.
|
||||
#[post("/app/exit")]
|
||||
pub fn exit_app(_token: APIToken) -> Json<AppExitResponse> {
|
||||
pub async fn exit_app(_token: APIToken) -> Json<AppExitResponse> {
|
||||
let app_handle = {
|
||||
let main_window_lock = MAIN_WINDOW.lock().unwrap();
|
||||
match main_window_lock.as_ref() {
|
||||
@ -653,8 +653,7 @@ pub fn exit_app(_token: APIToken) -> Json<AppExitResponse> {
|
||||
|
||||
/// Registers or updates a global shortcut. If the shortcut string is empty,
|
||||
/// the existing shortcut for that name will be unregistered.
|
||||
#[post("/shortcuts/register", data = "<payload>")]
|
||||
pub fn register_shortcut(_token: APIToken, payload: Json<RegisterShortcutRequest>) -> Json<ShortcutResponse> {
|
||||
pub async fn register_shortcut(_token: APIToken, payload: Json<RegisterShortcutRequest>) -> Json<ShortcutResponse> {
|
||||
let id = payload.id;
|
||||
let new_shortcut = payload.shortcut.clone();
|
||||
|
||||
@ -761,8 +760,7 @@ pub struct ShortcutValidationResponse {
|
||||
/// Validates a shortcut string without registering it.
|
||||
/// Checks if the shortcut syntax is valid and if it
|
||||
/// conflicts with existing shortcuts.
|
||||
#[post("/shortcuts/validate", data = "<payload>")]
|
||||
pub fn validate_shortcut(_token: APIToken, payload: Json<ValidateShortcutRequest>) -> Json<ShortcutValidationResponse> {
|
||||
pub async fn validate_shortcut(_token: APIToken, payload: Json<ValidateShortcutRequest>) -> Json<ShortcutValidationResponse> {
|
||||
let shortcut = payload.shortcut.clone();
|
||||
|
||||
// Empty shortcuts are always valid (means "disabled"):
|
||||
@ -816,8 +814,7 @@ pub fn validate_shortcut(_token: APIToken, payload: Json<ValidateShortcutRequest
|
||||
/// The shortcuts remain in our internal map, so they can be re-registered on resume.
|
||||
/// This is useful when opening a dialog to configure shortcuts, so the user can
|
||||
/// press the current shortcut to re-enter it without triggering the action.
|
||||
#[post("/shortcuts/suspend")]
|
||||
pub fn suspend_shortcuts(_token: APIToken) -> Json<ShortcutResponse> {
|
||||
pub async fn suspend_shortcuts(_token: APIToken) -> Json<ShortcutResponse> {
|
||||
// Get the main window to access the global shortcut manager:
|
||||
let main_window_lock = MAIN_WINDOW.lock().unwrap();
|
||||
let main_window = match main_window_lock.as_ref() {
|
||||
@ -853,8 +850,7 @@ pub fn suspend_shortcuts(_token: APIToken) -> Json<ShortcutResponse> {
|
||||
}
|
||||
|
||||
/// Resumes shortcut processing by re-registering all shortcuts with the OS.
|
||||
#[post("/shortcuts/resume")]
|
||||
pub fn resume_shortcuts(_token: APIToken) -> Json<ShortcutResponse> {
|
||||
pub async fn resume_shortcuts(_token: APIToken) -> Json<ShortcutResponse> {
|
||||
// Get the main window to access the global shortcut manager:
|
||||
let main_window_lock = MAIN_WINDOW.lock().unwrap();
|
||||
let main_window = match main_window_lock.as_ref() {
|
||||
@ -954,36 +950,6 @@ fn validate_shortcut_syntax(shortcut: &str) -> bool {
|
||||
has_key
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn tauri_localhost_is_tauri_asset_url() {
|
||||
let https_url = tauri::Url::parse("https://tauri.localhost/index.html").unwrap();
|
||||
let http_url = tauri::Url::parse("http://tauri.localhost/index.html").unwrap();
|
||||
|
||||
assert!(is_tauri_asset_url(&https_url));
|
||||
assert!(is_tauri_asset_url(&http_url));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn localhost_app_url_is_not_tauri_asset_url() {
|
||||
let url = tauri::Url::parse("http://localhost:12345/").unwrap();
|
||||
|
||||
assert!(!is_tauri_asset_url(&url));
|
||||
assert!(is_local_http_url(&url));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn external_url_is_not_internal_url() {
|
||||
let url = tauri::Url::parse("https://example.com/").unwrap();
|
||||
|
||||
assert!(!is_tauri_asset_url(&url));
|
||||
assert!(!is_local_http_url(&url));
|
||||
}
|
||||
}
|
||||
|
||||
fn set_pdfium_path<R: tauri::Runtime>(path_resolver: &PathResolver<R>) {
|
||||
let resource_dir = match path_resolver.resource_dir() {
|
||||
Ok(path) => path,
|
||||
@ -1012,3 +978,33 @@ fn set_pdfium_path<R: tauri::Runtime>(path_resolver: &PathResolver<R>) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn tauri_localhost_is_tauri_asset_url() {
|
||||
let https_url = tauri::Url::parse("https://tauri.localhost/index.html").unwrap();
|
||||
let http_url = tauri::Url::parse("http://tauri.localhost/index.html").unwrap();
|
||||
|
||||
assert!(is_tauri_asset_url(&https_url));
|
||||
assert!(is_tauri_asset_url(&http_url));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn localhost_app_url_is_not_tauri_asset_url() {
|
||||
let url = tauri::Url::parse("http://localhost:12345/").unwrap();
|
||||
|
||||
assert!(!is_tauri_asset_url(&url));
|
||||
assert!(is_local_http_url(&url));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn external_url_is_not_internal_url() {
|
||||
let url = tauri::Url::parse("https://example.com/").unwrap();
|
||||
|
||||
assert!(!is_tauri_asset_url(&url));
|
||||
assert!(!is_local_http_url(&url));
|
||||
}
|
||||
}
|
||||
@ -1,14 +1,13 @@
|
||||
use arboard::Clipboard;
|
||||
use log::{debug, error};
|
||||
use rocket::post;
|
||||
use rocket::serde::json::Json;
|
||||
use axum::Json;
|
||||
use serde::Serialize;
|
||||
use crate::api_token::APIToken;
|
||||
use crate::encryption::{EncryptedText, ENCRYPTION};
|
||||
|
||||
/// Sets the clipboard text to the provided encrypted text.
|
||||
#[post("/clipboard/set", data = "<encrypted_text>")]
|
||||
pub fn set_clipboard(_token: APIToken, encrypted_text: EncryptedText) -> Json<SetClipboardResponse> {
|
||||
pub async fn set_clipboard(_token: APIToken, encrypted_text: String) -> Json<SetClipboardResponse> {
|
||||
let encrypted_text = EncryptedText::new(encrypted_text);
|
||||
|
||||
// Decrypt this text first:
|
||||
let decrypted_text = match ENCRYPTION.decrypt(&encrypted_text) {
|
||||
|
||||
@ -5,7 +5,6 @@ use base64::Engine;
|
||||
use base64::prelude::BASE64_STANDARD;
|
||||
use log::{error, info, warn};
|
||||
use once_cell::sync::Lazy;
|
||||
use rocket::get;
|
||||
use tauri::Url;
|
||||
use tauri_plugin_shell::process::{CommandChild, CommandEvent};
|
||||
use tauri_plugin_shell::ShellExt;
|
||||
@ -89,8 +88,7 @@ fn sanitize_stdout_line(line: &str) -> String {
|
||||
|
||||
/// Returns the desired port of the .NET server. Our .NET app calls this endpoint to get
|
||||
/// the port where the .NET server should listen to.
|
||||
#[get("/system/dotnet/port")]
|
||||
pub fn dotnet_port(_token: APIToken) -> String {
|
||||
pub async fn dotnet_port(_token: APIToken) -> String {
|
||||
let dotnet_server_port = *DOTNET_SERVER_PORT;
|
||||
format!("{dotnet_server_port}")
|
||||
}
|
||||
@ -179,7 +177,6 @@ pub fn start_dotnet_server<R: tauri::Runtime>(app_handle: tauri::AppHandle<R>) {
|
||||
}
|
||||
|
||||
/// This endpoint is called by the .NET server to signal that the server is ready.
|
||||
#[get("/system/dotnet/ready")]
|
||||
pub async fn dotnet_ready(_token: APIToken) {
|
||||
|
||||
// We create a manual scope for the lock to be released as soon as possible.
|
||||
|
||||
@ -9,19 +9,13 @@ use once_cell::sync::Lazy;
|
||||
use pbkdf2::pbkdf2;
|
||||
use rand::rngs::SysRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rocket::{data, Data, Request};
|
||||
use rocket::data::ToByteUnit;
|
||||
use rocket::http::Status;
|
||||
use rocket::serde::{Deserialize, Serialize};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::Sha512;
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
type Aes256CbcEnc = cbc::Encryptor<aes::Aes256>;
|
||||
|
||||
type Aes256CbcDec = cbc::Decryptor<aes::Aes256>;
|
||||
|
||||
type DataOutcome<'r, T> = data::Outcome<'r, T>;
|
||||
|
||||
/// The encryption instance used for the IPC channel.
|
||||
pub static ENCRYPTION: Lazy<Encryption> = Lazy::new(|| {
|
||||
//
|
||||
@ -170,27 +164,4 @@ impl fmt::Display for EncryptedText {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "**********")
|
||||
}
|
||||
}
|
||||
|
||||
/// Use Case: When we receive encrypted text from the client as body (e.g., in a POST request).
|
||||
/// We must interpret the body as EncryptedText.
|
||||
#[rocket::async_trait]
|
||||
impl<'r> data::FromData<'r> for EncryptedText {
|
||||
type Error = String;
|
||||
|
||||
/// Parses the data as EncryptedText.
|
||||
async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> DataOutcome<'r, Self> {
|
||||
let content_type = req.content_type();
|
||||
if content_type.map_or(true, |ct| !ct.is_text()) {
|
||||
return DataOutcome::Forward((data, Status::Ok));
|
||||
}
|
||||
|
||||
let mut stream = data.open(2.mebibytes());
|
||||
let mut body = String::new();
|
||||
if let Err(e) = stream.read_to_string(&mut body).await {
|
||||
return DataOutcome::Error((Status::InternalServerError, format!("Failed to read data: {}", e)));
|
||||
}
|
||||
|
||||
DataOutcome::Success(EncryptedText(body))
|
||||
}
|
||||
}
|
||||
@ -1,7 +1,6 @@
|
||||
use crate::api_token::APIToken;
|
||||
use axum::Json;
|
||||
use log::{debug, info, warn};
|
||||
use rocket::get;
|
||||
use rocket::serde::json::Json;
|
||||
use serde::Serialize;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::env;
|
||||
@ -29,8 +28,7 @@ pub static CONFIG_DIRECTORY: OnceLock<String> = OnceLock::new();
|
||||
static USER_LANGUAGE: OnceLock<String> = OnceLock::new();
|
||||
|
||||
/// Returns the config directory.
|
||||
#[get("/system/directories/config")]
|
||||
pub fn get_config_directory(_token: APIToken) -> String {
|
||||
pub async fn get_config_directory(_token: APIToken) -> String {
|
||||
match CONFIG_DIRECTORY.get() {
|
||||
Some(config_directory) => config_directory.clone(),
|
||||
None => String::from(""),
|
||||
@ -38,8 +36,7 @@ pub fn get_config_directory(_token: APIToken) -> String {
|
||||
}
|
||||
|
||||
/// Returns the data directory.
|
||||
#[get("/system/directories/data")]
|
||||
pub fn get_data_directory(_token: APIToken) -> String {
|
||||
pub async fn get_data_directory(_token: APIToken) -> String {
|
||||
match DATA_DIRECTORY.get() {
|
||||
Some(data_directory) => data_directory.clone(),
|
||||
None => String::from(""),
|
||||
@ -150,8 +147,7 @@ fn detect_user_language() -> (String, LanguageDetectionSource) {
|
||||
)
|
||||
}
|
||||
|
||||
#[get("/system/language")]
|
||||
pub fn read_user_language(_token: APIToken) -> String {
|
||||
pub async fn read_user_language(_token: APIToken) -> String {
|
||||
USER_LANGUAGE
|
||||
.get_or_init(|| {
|
||||
let (user_language, source) = detect_user_language();
|
||||
@ -194,8 +190,7 @@ struct EnterpriseSourceData {
|
||||
encryption_secret: String,
|
||||
}
|
||||
|
||||
#[get("/system/enterprise/config/id")]
|
||||
pub fn read_enterprise_env_config_id(_token: APIToken) -> String {
|
||||
pub async fn read_enterprise_env_config_id(_token: APIToken) -> String {
|
||||
debug!("Trying to read the effective enterprise configuration ID.");
|
||||
resolve_effective_enterprise_config_source()
|
||||
.configs
|
||||
@ -205,8 +200,7 @@ pub fn read_enterprise_env_config_id(_token: APIToken) -> String {
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
#[get("/system/enterprise/config/server")]
|
||||
pub fn read_enterprise_env_config_server_url(_token: APIToken) -> String {
|
||||
pub async fn read_enterprise_env_config_server_url(_token: APIToken) -> String {
|
||||
debug!("Trying to read the effective enterprise configuration server URL.");
|
||||
resolve_effective_enterprise_config_source()
|
||||
.configs
|
||||
@ -216,15 +210,13 @@ pub fn read_enterprise_env_config_server_url(_token: APIToken) -> String {
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
#[get("/system/enterprise/config/encryption_secret")]
|
||||
pub fn read_enterprise_env_config_encryption_secret(_token: APIToken) -> String {
|
||||
pub async fn read_enterprise_env_config_encryption_secret(_token: APIToken) -> String {
|
||||
debug!("Trying to read the effective enterprise configuration encryption secret.");
|
||||
resolve_effective_enterprise_secret_source().encryption_secret
|
||||
}
|
||||
|
||||
/// Returns all enterprise configurations from the effective source.
|
||||
#[get("/system/enterprise/configs")]
|
||||
pub fn read_enterprise_configs(_token: APIToken) -> Json<Vec<EnterpriseConfig>> {
|
||||
pub async fn read_enterprise_configs(_token: APIToken) -> Json<Vec<EnterpriseConfig>> {
|
||||
info!("Trying to read the effective enterprise configurations.");
|
||||
Json(resolve_effective_enterprise_config_source().configs)
|
||||
}
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
use log::{error, info};
|
||||
use rocket::post;
|
||||
use rocket::serde::{Deserialize, Serialize};
|
||||
use rocket::serde::json::Json;
|
||||
use axum::extract::Query;
|
||||
use axum::Json;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tauri_plugin_dialog::{DialogExt, FileDialogBuilder};
|
||||
use crate::api_token::APIToken;
|
||||
use crate::app_window::MAIN_WINDOW;
|
||||
@ -11,6 +11,11 @@ pub struct PreviousDirectory {
|
||||
path: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SelectDirectoryQuery {
|
||||
title: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
pub struct FileTypeFilter {
|
||||
filter_name: String,
|
||||
@ -61,10 +66,9 @@ pub struct PreviousFile {
|
||||
}
|
||||
|
||||
/// Let the user select a directory.
|
||||
#[post("/select/directory?<title>", data = "<previous_directory>")]
|
||||
pub fn select_directory(
|
||||
pub async fn select_directory(
|
||||
_token: APIToken,
|
||||
title: &str,
|
||||
Query(query): Query<SelectDirectoryQuery>,
|
||||
previous_directory: Option<Json<PreviousDirectory>>,
|
||||
) -> Json<DirectorySelectionResponse> {
|
||||
let main_window_lock = MAIN_WINDOW.lock().unwrap();
|
||||
@ -79,7 +83,7 @@ pub fn select_directory(
|
||||
}
|
||||
};
|
||||
|
||||
let mut dialog = main_window.dialog().file().set_parent(main_window).set_title(title);
|
||||
let mut dialog = main_window.dialog().file().set_parent(main_window).set_title(&query.title);
|
||||
if let Some(previous) = previous_directory {
|
||||
dialog = dialog.set_directory(previous.path.clone());
|
||||
}
|
||||
@ -118,8 +122,7 @@ pub fn select_directory(
|
||||
}
|
||||
|
||||
/// Let the user select a file.
|
||||
#[post("/select/file", data = "<payload>")]
|
||||
pub fn select_file(
|
||||
pub async fn select_file(
|
||||
_token: APIToken,
|
||||
payload: Json<SelectFileOptions>,
|
||||
) -> Json<FileSelectionResponse> {
|
||||
@ -178,8 +181,7 @@ pub fn select_file(
|
||||
}
|
||||
|
||||
/// Let the user select some files.
|
||||
#[post("/select/files", data = "<payload>")]
|
||||
pub fn select_files(
|
||||
pub async fn select_files(
|
||||
_token: APIToken,
|
||||
payload: Json<SelectFileOptions>,
|
||||
) -> Json<FilesSelectionResponse> {
|
||||
@ -229,8 +231,7 @@ pub fn select_files(
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/save/file", data = "<payload>")]
|
||||
pub fn save_file(_token: APIToken, payload: Json<SaveFileOptions>) -> Json<FileSaveResponse> {
|
||||
pub async fn save_file(_token: APIToken, payload: Json<SaveFileOptions>) -> Json<FileSaveResponse> {
|
||||
// Create a new file dialog builder:
|
||||
let file_dialog = MAIN_WINDOW
|
||||
.lock()
|
||||
|
||||
@ -1,19 +1,18 @@
|
||||
use std::cmp::min;
|
||||
use std::convert::Infallible;
|
||||
use crate::api_token::APIToken;
|
||||
use crate::pandoc::PandocProcessBuilder;
|
||||
use crate::pdfium::PdfiumInit;
|
||||
use async_stream::stream;
|
||||
use axum::extract::Query;
|
||||
use axum::response::sse::{Event, Sse};
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use calamine::{open_workbook_auto, Reader};
|
||||
use file_format::{FileFormat, Kind};
|
||||
use futures::{Stream, StreamExt};
|
||||
use pdfium_render::prelude::Pdfium;
|
||||
use pptx_to_md::{ImageHandlingMode, ParserConfig, PptxContainer};
|
||||
use rocket::get;
|
||||
use rocket::response::stream::{Event, EventStream};
|
||||
use rocket::serde::Serialize;
|
||||
use rocket::tokio::select;
|
||||
use rocket::Shutdown;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
use std::pin::Pin;
|
||||
use log::{debug, error};
|
||||
@ -82,39 +81,45 @@ const IMAGE_SEGMENT_SIZE_IN_CHARS: usize = 8_192; // equivalent to ~ 5500 token
|
||||
type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
|
||||
type ChunkStream = Pin<Box<dyn Stream<Item = Result<Chunk>> + Send>>;
|
||||
|
||||
#[get("/retrieval/fs/extract?<path>&<stream_id>&<extract_images>")]
|
||||
pub async fn extract_data(_token: APIToken, path: String, stream_id: String, extract_images: bool, mut end: Shutdown) -> EventStream![] {
|
||||
EventStream! {
|
||||
let stream_result = stream_data(&path, extract_images).await;
|
||||
let id_ref = &stream_id;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct ExtractDataQuery {
|
||||
path: String,
|
||||
stream_id: String,
|
||||
extract_images: bool,
|
||||
}
|
||||
|
||||
pub async fn extract_data(
|
||||
_token: APIToken,
|
||||
Query(query): Query<ExtractDataQuery>,
|
||||
) -> Sse<impl Stream<Item = std::result::Result<Event, Infallible>>> {
|
||||
let stream = stream! {
|
||||
let stream_result = stream_data(&query.path, query.extract_images).await;
|
||||
let id_ref = &query.stream_id;
|
||||
|
||||
match stream_result {
|
||||
Ok(mut stream) => {
|
||||
loop {
|
||||
let chunk = select! {
|
||||
chunk = stream.next() => match chunk {
|
||||
Some(Ok(mut chunk)) => {
|
||||
chunk.set_stream_id(id_ref);
|
||||
chunk
|
||||
},
|
||||
Some(Err(e)) => {
|
||||
yield Event::json(&format!("Error: {e}"));
|
||||
break;
|
||||
},
|
||||
None => break,
|
||||
while let Some(chunk) = stream.next().await {
|
||||
match chunk {
|
||||
Ok(mut chunk) => {
|
||||
chunk.set_stream_id(id_ref);
|
||||
yield Ok(Event::default().json_data(&chunk).unwrap_or_else(|e| Event::default().data(format!("Error: {e}"))));
|
||||
},
|
||||
_ = &mut end => break,
|
||||
};
|
||||
|
||||
yield Event::json(&chunk);
|
||||
|
||||
Err(e) => {
|
||||
yield Ok(Event::default().json_data(format!("Error: {e}")).unwrap_or_else(|_| Event::default().data(format!("Error: {e}"))));
|
||||
break;
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
Err(e) => {
|
||||
yield Event::json(&format!("Error starting stream: {e}"));
|
||||
yield Ok(Event::default().json_data(format!("Error starting stream: {e}")).unwrap_or_else(|_| Event::default().data(format!("Error starting stream: {e}"))));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Sse::new(stream)
|
||||
}
|
||||
|
||||
async fn stream_data(file_path: &str, extract_images: bool) -> Result<ChunkStream> {
|
||||
|
||||
@ -8,9 +8,8 @@ use flexi_logger::{DeferredNow, Duplicate, FileSpec, Logger, LoggerHandle};
|
||||
use flexi_logger::writers::FileLogWriter;
|
||||
use log::{kv, Level};
|
||||
use log::kv::{Key, Value, VisitSource};
|
||||
use rocket::{get, post};
|
||||
use rocket::serde::json::Json;
|
||||
use rocket::serde::{Deserialize, Serialize};
|
||||
use axum::Json;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::api_token::APIToken;
|
||||
use crate::environment::is_dev;
|
||||
|
||||
@ -34,14 +33,17 @@ pub fn init_logging() {
|
||||
false => log_config.push_str("info, "),
|
||||
};
|
||||
|
||||
// Set the log level for the Rocket library:
|
||||
log_config.push_str("rocket=info, ");
|
||||
|
||||
// Set the log level for the Rocket server:
|
||||
log_config.push_str("rocket::server=warn, ");
|
||||
|
||||
// Set the log level for the Reqwest library:
|
||||
log_config.push_str("reqwest::async_impl::client=info");
|
||||
// Keep noisy HTTP/TLS internals at info level even in development builds:
|
||||
log_config.push_str("h2=info, ");
|
||||
log_config.push_str("hyper=info, ");
|
||||
log_config.push_str("hyper_util=info, ");
|
||||
log_config.push_str("axum=info, ");
|
||||
log_config.push_str("axum_server=info, ");
|
||||
log_config.push_str("tower=info, ");
|
||||
log_config.push_str("tower_http=info, ");
|
||||
log_config.push_str("rustls=info, ");
|
||||
log_config.push_str("tokio_rustls=info, ");
|
||||
log_config.push_str("reqwest=info");
|
||||
|
||||
// Configure the initial filename. On Unix systems, the file should start
|
||||
// with a dot to be hidden.
|
||||
@ -224,7 +226,6 @@ fn file_logger_format(
|
||||
write!(w, "{}", &record.args())
|
||||
}
|
||||
|
||||
#[get("/log/paths")]
|
||||
pub async fn get_log_paths(_token: APIToken) -> Json<LogPathsResponse> {
|
||||
Json(LogPathsResponse {
|
||||
log_startup_path: LOG_STARTUP_PATH.get().expect("No startup log path was set").clone(),
|
||||
@ -269,9 +270,7 @@ fn log_with_level(
|
||||
}
|
||||
|
||||
/// Logs an event from the .NET server.
|
||||
#[post("/log/event", data = "<event>")]
|
||||
pub fn log_event(_token: APIToken, event: Json<LogEvent>) -> Json<LogEventResponse> {
|
||||
let event = event.into_inner();
|
||||
pub async fn log_event(_token: APIToken, Json(event): Json<LogEvent>) -> Json<LogEventResponse> {
|
||||
let level = parse_dotnet_log_level(&event.level);
|
||||
let message = event.message.as_str();
|
||||
let category = event.category.as_str();
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
// Prevents an additional console window on Windows in release, DO NOT REMOVE!!
|
||||
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]
|
||||
|
||||
extern crate rocket;
|
||||
extern crate core;
|
||||
|
||||
use log::{info, warn};
|
||||
@ -12,7 +11,6 @@ use mindwork_ai_studio::log::init_logging;
|
||||
use mindwork_ai_studio::metadata::MetaData;
|
||||
use mindwork_ai_studio::runtime_api::start_runtime_api;
|
||||
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let metadata = MetaData::init_from_string(include_str!("../../metadata.txt"));
|
||||
|
||||
@ -7,9 +7,8 @@ use std::path::Path;
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
use log::{debug, error, info, warn};
|
||||
use once_cell::sync::Lazy;
|
||||
use rocket::get;
|
||||
use rocket::serde::json::Json;
|
||||
use rocket::serde::Serialize;
|
||||
use axum::Json;
|
||||
use serde::Serialize;
|
||||
use crate::api_token::{APIToken};
|
||||
use crate::environment::{is_dev, DATA_DIRECTORY};
|
||||
use crate::certificate_factory::generate_certificate;
|
||||
@ -70,8 +69,7 @@ pub struct ProvideQdrantInfo {
|
||||
unavailable_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[get("/system/qdrant/info")]
|
||||
pub fn qdrant_port(_token: APIToken) -> Json<ProvideQdrantInfo> {
|
||||
pub async fn qdrant_port(_token: APIToken) -> Json<ProvideQdrantInfo> {
|
||||
let status = QDRANT_STATUS.lock().unwrap();
|
||||
let is_available = status.is_available;
|
||||
let unavailable_reason = status.unavailable_reason.clone();
|
||||
|
||||
@ -1,12 +1,16 @@
|
||||
use log::info;
|
||||
use once_cell::sync::Lazy;
|
||||
use rocket::config::Shutdown;
|
||||
use rocket::figment::Figment;
|
||||
use rocket::routes;
|
||||
use axum::routing::{get, post};
|
||||
use axum::Router;
|
||||
use axum_server::tls_rustls::RustlsConfig;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Once;
|
||||
use crate::runtime_certificate::{CERTIFICATE, CERTIFICATE_PRIVATE_KEY};
|
||||
use crate::environment::is_dev;
|
||||
use crate::network::get_available_port;
|
||||
|
||||
static RUSTLS_CRYPTO_PROVIDER_INIT: Once = Once::new();
|
||||
|
||||
/// The port used for the runtime API server. In the development environment, we use a fixed
|
||||
/// port, in the production environment we use the next available port. This differentiation
|
||||
/// is necessary because we cannot communicate the port to the .NET server in the development
|
||||
@ -24,109 +28,55 @@ pub static API_SERVER_PORT: Lazy<u16> = Lazy::new(|| {
|
||||
pub fn start_runtime_api() {
|
||||
let api_port = *API_SERVER_PORT;
|
||||
info!("Try to start the API server on 'http://localhost:{api_port}'...");
|
||||
|
||||
// Get the shutdown configuration:
|
||||
let shutdown = create_shutdown();
|
||||
|
||||
// Configure the runtime API server:
|
||||
let figment = Figment::from(rocket::Config::release_default())
|
||||
let app = Router::new()
|
||||
.route("/system/dotnet/port", get(crate::dotnet::dotnet_port))
|
||||
.route("/system/dotnet/ready", get(crate::dotnet::dotnet_ready))
|
||||
.route("/system/qdrant/info", get(crate::qdrant::qdrant_port))
|
||||
.route("/clipboard/set", post(crate::clipboard::set_clipboard))
|
||||
.route("/events", get(crate::app_window::get_event_stream))
|
||||
.route("/updates/check", get(crate::app_window::check_for_update))
|
||||
.route("/updates/install", get(crate::app_window::install_update))
|
||||
.route("/app/exit", post(crate::app_window::exit_app))
|
||||
.route("/select/directory", post(crate::file_actions::select_directory))
|
||||
.route("/select/file", post(crate::file_actions::select_file))
|
||||
.route("/select/files", post(crate::file_actions::select_files))
|
||||
.route("/save/file", post(crate::file_actions::save_file))
|
||||
.route("/secrets/get", post(crate::secret::get_secret))
|
||||
.route("/secrets/store", post(crate::secret::store_secret))
|
||||
.route("/secrets/delete", post(crate::secret::delete_secret))
|
||||
.route("/system/directories/config", get(crate::environment::get_config_directory))
|
||||
.route("/system/directories/data", get(crate::environment::get_data_directory))
|
||||
.route("/system/language", get(crate::environment::read_user_language))
|
||||
.route("/system/enterprise/config/id", get(crate::environment::read_enterprise_env_config_id))
|
||||
.route("/system/enterprise/config/server", get(crate::environment::read_enterprise_env_config_server_url))
|
||||
.route("/system/enterprise/config/encryption_secret", get(crate::environment::read_enterprise_env_config_encryption_secret))
|
||||
.route("/system/enterprise/configs", get(crate::environment::read_enterprise_configs))
|
||||
.route("/retrieval/fs/extract", get(crate::file_data::extract_data))
|
||||
.route("/log/paths", get(crate::log::get_log_paths))
|
||||
.route("/log/event", post(crate::log::log_event))
|
||||
.route("/shortcuts/register", post(crate::app_window::register_shortcut))
|
||||
.route("/shortcuts/validate", post(crate::app_window::validate_shortcut))
|
||||
.route("/shortcuts/suspend", post(crate::app_window::suspend_shortcuts))
|
||||
.route("/shortcuts/resume", post(crate::app_window::resume_shortcuts));
|
||||
|
||||
// We use the next available port which was determined before:
|
||||
.merge(("port", api_port))
|
||||
|
||||
// The runtime API server should be accessible only from the local machine:
|
||||
.merge(("address", "127.0.0.1"))
|
||||
|
||||
// We do not want to use the Ctrl+C signal to stop the server:
|
||||
.merge(("ctrlc", false))
|
||||
|
||||
// Set a name for the server:
|
||||
.merge(("ident", "AI Studio Runtime API"))
|
||||
|
||||
// Set the maximum number of workers and blocking threads:
|
||||
.merge(("workers", 3))
|
||||
.merge(("max_blocking", 12))
|
||||
|
||||
// No colors and emojis in the log output:
|
||||
.merge(("cli_colors", false))
|
||||
|
||||
// Read the TLS certificate and key from the generated certificate data in-memory:
|
||||
.merge(("tls.certs", CERTIFICATE.get().unwrap()))
|
||||
.merge(("tls.key", CERTIFICATE_PRIVATE_KEY.get().unwrap()))
|
||||
|
||||
// Set the shutdown configuration:
|
||||
.merge(("shutdown", shutdown));
|
||||
|
||||
//
|
||||
// Start the runtime API server in a separate thread. This is necessary
|
||||
// because the server is blocking, and we need to run the Tauri app in
|
||||
// parallel:
|
||||
//
|
||||
tauri::async_runtime::spawn(async move {
|
||||
rocket::custom(figment)
|
||||
.mount("/", routes![
|
||||
crate::dotnet::dotnet_port,
|
||||
crate::dotnet::dotnet_ready,
|
||||
crate::qdrant::qdrant_port,
|
||||
crate::clipboard::set_clipboard,
|
||||
crate::app_window::get_event_stream,
|
||||
crate::app_window::check_for_update,
|
||||
crate::app_window::install_update,
|
||||
crate::app_window::exit_app,
|
||||
crate::file_actions::select_directory,
|
||||
crate::file_actions::select_file,
|
||||
crate::file_actions::select_files,
|
||||
crate::file_actions::save_file,
|
||||
crate::secret::get_secret,
|
||||
crate::secret::store_secret,
|
||||
crate::secret::delete_secret,
|
||||
crate::environment::get_data_directory,
|
||||
crate::environment::get_config_directory,
|
||||
crate::environment::read_user_language,
|
||||
crate::environment::read_enterprise_env_config_id,
|
||||
crate::environment::read_enterprise_env_config_server_url,
|
||||
crate::environment::read_enterprise_env_config_encryption_secret,
|
||||
crate::environment::read_enterprise_configs,
|
||||
crate::file_data::extract_data,
|
||||
crate::log::get_log_paths,
|
||||
crate::log::log_event,
|
||||
crate::app_window::register_shortcut,
|
||||
crate::app_window::validate_shortcut,
|
||||
crate::app_window::suspend_shortcuts,
|
||||
crate::app_window::resume_shortcuts,
|
||||
])
|
||||
.ignite().await.unwrap()
|
||||
.launch().await.unwrap();
|
||||
install_rustls_crypto_provider();
|
||||
|
||||
let cert = CERTIFICATE.get().unwrap().clone();
|
||||
let key = CERTIFICATE_PRIVATE_KEY.get().unwrap().clone();
|
||||
let tls_config = RustlsConfig::from_pem(cert, key).await.unwrap();
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], api_port));
|
||||
|
||||
axum_server::bind_rustls(addr, tls_config)
|
||||
.serve(app.into_make_service())
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
}
|
||||
|
||||
fn create_shutdown() -> Shutdown {
|
||||
//
|
||||
// Create a shutdown configuration, depending on the operating system:
|
||||
//
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::collections::HashSet;
|
||||
let mut shutdown = Shutdown {
|
||||
// We do not want to use the Ctrl+C signal to stop the server:
|
||||
ctrlc: false,
|
||||
|
||||
// Everything else is set to default for now:
|
||||
..Shutdown::default()
|
||||
};
|
||||
|
||||
shutdown.signals = HashSet::new();
|
||||
shutdown
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
{
|
||||
Shutdown {
|
||||
// We do not want to use the Ctrl+C signal to stop the server:
|
||||
ctrlc: false,
|
||||
|
||||
// Everything else is set to default for now:
|
||||
..Shutdown::default()
|
||||
}
|
||||
}
|
||||
fn install_rustls_crypto_provider() {
|
||||
RUSTLS_CRYPTO_PROVIDER_INIT.call_once(|| {
|
||||
let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
|
||||
});
|
||||
}
|
||||
@ -1,33 +1,29 @@
|
||||
use once_cell::sync::Lazy;
|
||||
use rocket::http::Status;
|
||||
use rocket::Request;
|
||||
use rocket::request::FromRequest;
|
||||
use axum::extract::FromRequestParts;
|
||||
use axum::http::request::Parts;
|
||||
use axum::http::StatusCode;
|
||||
use crate::api_token::{generate_api_token, APIToken};
|
||||
|
||||
pub static API_TOKEN: Lazy<APIToken> = Lazy::new(|| generate_api_token());
|
||||
pub static API_TOKEN: Lazy<APIToken> = Lazy::new(generate_api_token);
|
||||
|
||||
/// The request outcome type used to handle API token requests.
|
||||
type RequestOutcome<R, T> = rocket::request::Outcome<R, T>;
|
||||
impl<S> FromRequestParts<S> for APIToken
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = StatusCode;
|
||||
|
||||
/// The request outcome implementation for the API token.
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for APIToken {
|
||||
type Error = APITokenError;
|
||||
|
||||
/// Handles the API token requests.
|
||||
async fn from_request(request: &'r Request<'_>) -> RequestOutcome<Self, Self::Error> {
|
||||
let token = request.headers().get_one("token");
|
||||
match token {
|
||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
match parts.headers.get("token").and_then(|value| value.to_str().ok()) {
|
||||
Some(token) => {
|
||||
let received_token = APIToken::from_hex_text(token);
|
||||
if API_TOKEN.validate(&received_token) {
|
||||
RequestOutcome::Success(received_token)
|
||||
Ok(received_token)
|
||||
} else {
|
||||
RequestOutcome::Error((Status::Unauthorized, APITokenError::Invalid))
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
}
|
||||
|
||||
None => RequestOutcome::Error((Status::Unauthorized, APITokenError::Missing)),
|
||||
None => Err(StatusCode::UNAUTHORIZED),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,15 +1,13 @@
|
||||
use keyring::Entry;
|
||||
use log::{error, info, warn};
|
||||
use rocket::post;
|
||||
use rocket::serde::json::Json;
|
||||
use axum::Json;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use keyring::error::Error::NoEntry;
|
||||
use crate::api_token::APIToken;
|
||||
use crate::encryption::{EncryptedText, ENCRYPTION};
|
||||
|
||||
/// Stores a secret in the secret store using the operating system's keyring.
|
||||
#[post("/secrets/store", data = "<request>")]
|
||||
pub fn store_secret(_token: APIToken, request: Json<StoreSecret>) -> Json<StoreSecretResponse> {
|
||||
pub async fn store_secret(_token: APIToken, request: Json<StoreSecret>) -> Json<StoreSecretResponse> {
|
||||
let user_name = request.user_name.as_str();
|
||||
let decrypted_text = match ENCRYPTION.decrypt(&request.secret) {
|
||||
Ok(text) => text,
|
||||
@ -60,8 +58,7 @@ pub struct StoreSecretResponse {
|
||||
}
|
||||
|
||||
/// Retrieves a secret from the secret store using the operating system's keyring.
|
||||
#[post("/secrets/get", data = "<request>")]
|
||||
pub fn get_secret(_token: APIToken, request: Json<RequestSecret>) -> Json<RequestedSecret> {
|
||||
pub async fn get_secret(_token: APIToken, request: Json<RequestSecret>) -> Json<RequestedSecret> {
|
||||
let user_name = request.user_name.as_str();
|
||||
let service = format!("mindwork-ai-studio::{}", request.destination);
|
||||
let entry = Entry::new(service.as_str(), user_name).unwrap();
|
||||
@ -121,8 +118,7 @@ pub struct RequestedSecret {
|
||||
}
|
||||
|
||||
/// Deletes a secret from the secret store using the operating system's keyring.
|
||||
#[post("/secrets/delete", data = "<request>")]
|
||||
pub fn delete_secret(_token: APIToken, request: Json<RequestSecret>) -> Json<DeleteSecretResponse> {
|
||||
pub async fn delete_secret(_token: APIToken, request: Json<RequestSecret>) -> Json<DeleteSecretResponse> {
|
||||
let user_name = request.user_name.as_str();
|
||||
let service = format!("mindwork-ai-studio::{}", request.destination);
|
||||
let entry = Entry::new(service.as_str(), user_name).unwrap();
|
||||
|
||||
Loading…
Reference in New Issue
Block a user