diff --git a/app/MindWork AI Studio/Program.cs b/app/MindWork AI Studio/Program.cs index bddcae3e..a29155c1 100644 --- a/app/MindWork AI Studio/Program.cs +++ b/app/MindWork AI Studio/Program.cs @@ -18,6 +18,7 @@ internal sealed class Program { public static RustService RUST_SERVICE = null!; public static Encryption ENCRYPTION = null!; + public static string API_TOKEN = null!; public static async Task Main(string[] args) { @@ -52,6 +53,15 @@ internal sealed class Program return; } + var apiToken = Environment.GetEnvironmentVariable("AI_STUDIO_API_TOKEN"); + if(string.IsNullOrWhiteSpace(apiToken)) + { + Console.WriteLine("Error: The AI_STUDIO_API_TOKEN environment variable is not set."); + return; + } + + API_TOKEN = apiToken; + var rustApiPort = args[0]; using var rust = new RustService(rustApiPort, certificateFingerprint); var appPort = await rust.GetAppPort(); diff --git a/app/MindWork AI Studio/Tools/HttpRequestHeadersExtensions.cs b/app/MindWork AI Studio/Tools/HttpRequestHeadersExtensions.cs new file mode 100644 index 00000000..7a3fc122 --- /dev/null +++ b/app/MindWork AI Studio/Tools/HttpRequestHeadersExtensions.cs @@ -0,0 +1,18 @@ +using System.Net.Http.Headers; + +namespace AIStudio.Tools; + +public static class HttpRequestHeadersExtensions +{ + private static readonly string API_TOKEN; + + static HttpRequestHeadersExtensions() + { + API_TOKEN = Program.API_TOKEN; + } + + public static void AddApiToken(this HttpRequestHeaders headers) + { + headers.Add("token", API_TOKEN); + } +} \ No newline at end of file diff --git a/app/MindWork AI Studio/Tools/RustService.cs b/app/MindWork AI Studio/Tools/RustService.cs index f382575e..1acd7de9 100644 --- a/app/MindWork AI Studio/Tools/RustService.cs +++ b/app/MindWork AI Studio/Tools/RustService.cs @@ -48,6 +48,8 @@ public sealed class RustService : IDisposable DefaultRequestVersion = Version.Parse("2.0"), DefaultVersionPolicy = HttpVersionPolicy.RequestVersionOrHigher, }; + + this.http.DefaultRequestHeaders.AddApiToken(); } public void SetLogger(ILogger logService) @@ -99,8 +101,10 @@ public sealed class RustService : IDisposable return currentCertificateFingerprint == this.certificateFingerprint; } }); + initialHttp.DefaultRequestVersion = Version.Parse("2.0"); initialHttp.DefaultVersionPolicy = HttpVersionPolicy.RequestVersionOrHigher; + initialHttp.DefaultRequestHeaders.AddApiToken(); try { diff --git a/runtime/src/main.rs b/runtime/src/main.rs index cf280b22..1286da80 100644 --- a/runtime/src/main.rs +++ b/runtime/src/main.rs @@ -32,16 +32,22 @@ use rcgen::generate_simple_self_signed; use rocket::figment::Figment; use rocket::{data, get, post, routes, Data, Request}; use rocket::config::{Shutdown}; -use rocket::data::{Outcome, ToByteUnit}; +use rocket::data::{ToByteUnit}; use rocket::http::Status; +use rocket::request::{FromRequest}; use rocket::serde::json::Json; use sha2::{Sha256, Sha512, Digest}; use tauri::updater::UpdateResponse; use tokio::io::AsyncReadExt; type Aes256CbcEnc = cbc::Encryptor; + type Aes256CbcDec = cbc::Decryptor; +type DataOutcome<'r, T> = data::Outcome<'r, T>; + +type RequestOutcome = rocket::request::Outcome; + // The .NET server is started in a separate process and communicates with this // runtime process via IPC. However, we do net start the .NET server in // the development environment. @@ -88,6 +94,13 @@ static ENCRYPTION: Lazy = Lazy::new(|| { Encryption::new(&secret_key, &secret_key_salt).unwrap() }); +static API_TOKEN: Lazy = Lazy::new(|| { + let mut token = [0u8; 32]; + let mut rng = rand_chacha::ChaChaRng::from_entropy(); + rng.fill_bytes(&mut token); + APIToken::from_bytes(token.to_vec()) +}); + static DATA_DIRECTORY: OnceLock = OnceLock::new(); static CONFIG_DIRECTORY: OnceLock = OnceLock::new(); @@ -229,6 +242,13 @@ async fn main() { let secret_password = BASE64_STANDARD.encode(ENCRYPTION.secret_password); let secret_key_salt = BASE64_STANDARD.encode(ENCRYPTION.secret_key_salt); + let dotnet_server_environment = HashMap::from_iter([ + (String::from("AI_STUDIO_SECRET_PASSWORD"), secret_password), + (String::from("AI_STUDIO_SECRET_KEY_SALT"), secret_key_salt), + (String::from("AI_STUDIO_CERTIFICATE_FINGERPRINT"), certificate_fingerprint), + (String::from("AI_STUDIO_API_TOKEN"), API_TOKEN.to_hex_text().to_string()), + ]); + info!("Secret password for the IPC channel was generated successfully."); info!("Try to start the .NET server..."); let server_spawn_clone = DOTNET_SERVER.clone(); @@ -248,13 +268,7 @@ async fn main() { // We provide the runtime API server port to the .NET server: .args(["run", "--project", "../app/MindWork AI Studio", "--", format!("{api_port}").as_str()]) - // Provide the secret password & salt for the IPC channel to the .NET server by using - // an environment variable. We must use a HashMap for this: - .envs(HashMap::from_iter([ - (String::from("AI_STUDIO_SECRET_PASSWORD"), secret_password), - (String::from("AI_STUDIO_SECRET_KEY_SALT"), secret_key_salt), - (String::from("AI_STUDIO_CERTIFICATE_FINGERPRINT"), certificate_fingerprint), - ])) + .envs(dotnet_server_environment) .spawn() .expect("Failed to spawn .NET server process.") } @@ -266,13 +280,7 @@ async fn main() { // Provide the runtime API server port to the .NET server: .args([format!("{api_port}").as_str()]) - // Provide the secret password & salt for the IPC channel to the .NET server by using - // an environment variable. We must use a HashMap for this: - .envs(HashMap::from_iter([ - (String::from("AI_STUDIO_SECRET_PASSWORD"), secret_password), - (String::from("AI_STUDIO_SECRET_KEY_SALT"), secret_key_salt), - (String::from("AI_STUDIO_CERTIFICATE_FINGERPRINT"), certificate_fingerprint), - ])) + .envs(dotnet_server_environment) .spawn() .expect("Failed to spawn .NET server process.") } @@ -428,6 +436,62 @@ async fn main() { } } +struct APIToken{ + hex_text: String, +} + +impl APIToken { + fn from_bytes(bytes: Vec) -> Self { + APIToken { + hex_text: bytes.iter().fold(String::new(), |mut result, byte| { + result.push_str(&format!("{:02x}", byte)); + result + }), + } + } + + fn from_hex_text(hex_text: &str) -> Self { + APIToken { + hex_text: hex_text.to_string(), + } + } + + fn to_hex_text(&self) -> &str { + self.hex_text.as_str() + } + + fn validate(&self, received_token: &Self) -> bool { + received_token.to_hex_text() == self.to_hex_text() + } +} + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for APIToken { + type Error = APITokenError; + + async fn from_request(request: &'r Request<'_>) -> RequestOutcome { + let token = request.headers().get_one("token"); + match token { + Some(token) => { + let received_token = APIToken::from_hex_text(token); + if API_TOKEN.validate(&received_token) { + RequestOutcome::Success(received_token) + } else { + RequestOutcome::Error((Status::Unauthorized, APITokenError::Invalid)) + } + } + + None => RequestOutcome::Error((Status::Unauthorized, APITokenError::Missing)), + } + } +} + +#[derive(Debug)] +enum APITokenError { + Missing, + Invalid, +} + // // Data structure for iterating over key-value pairs of log messages. // @@ -615,30 +679,30 @@ impl fmt::Display for EncryptedText { #[rocket::async_trait] impl<'r> data::FromData<'r> for EncryptedText { type Error = String; - async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> Outcome<'r, Self> { + async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> DataOutcome<'r, Self> { let content_type = req.content_type(); if content_type.map_or(true, |ct| !ct.is_text()) { - return Outcome::Forward((data, Status::Ok)); + return DataOutcome::Forward((data, Status::Ok)); } let mut stream = data.open(2.mebibytes()); let mut body = String::new(); if let Err(e) = stream.read_to_string(&mut body).await { - return Outcome::Error((Status::InternalServerError, format!("Failed to read data: {}", e))); + return DataOutcome::Error((Status::InternalServerError, format!("Failed to read data: {}", e))); } - Outcome::Success(EncryptedText(body)) + DataOutcome::Success(EncryptedText(body)) } } #[get("/system/dotnet/port")] -fn dotnet_port() -> String { +fn dotnet_port(_token: APIToken) -> String { let dotnet_server_port = *DOTNET_SERVER_PORT; format!("{dotnet_server_port}") } #[get("/system/directories/data")] -fn get_data_directory() -> String { +fn get_data_directory(_token: APIToken) -> String { match DATA_DIRECTORY.get() { Some(data_directory) => data_directory.clone(), None => String::from(""), @@ -646,7 +710,7 @@ fn get_data_directory() -> String { } #[get("/system/directories/config")] -fn get_config_directory() -> String { +fn get_config_directory(_token: APIToken) -> String { match CONFIG_DIRECTORY.get() { Some(config_directory) => config_directory.clone(), None => String::from(""), @@ -654,7 +718,7 @@ fn get_config_directory() -> String { } #[get("/system/dotnet/ready")] -async fn dotnet_ready() { +async fn dotnet_ready(_token: APIToken) { let main_window_spawn_clone = &MAIN_WINDOW; let dotnet_server_port = *DOTNET_SERVER_PORT; let url = match Url::parse(format!("http://localhost:{dotnet_server_port}").as_str()) @@ -724,7 +788,7 @@ fn stop_servers() { } #[get("/updates/check")] -async fn check_for_update() -> Json { +async fn check_for_update(_token: APIToken) -> Json { let app_handle = MAIN_WINDOW.lock().unwrap().as_ref().unwrap().app_handle(); let response = app_handle.updater().check().await; match response { @@ -777,7 +841,7 @@ struct CheckUpdateResponse { } #[get("/updates/install")] -async fn install_update() { +async fn install_update(_token: APIToken) { let cloned_response_option = CHECK_UPDATE_RESPONSE.lock().unwrap().clone(); match cloned_response_option { Some(update_response) => { @@ -791,7 +855,7 @@ async fn install_update() { } #[post("/secrets/store", data = "")] -fn store_secret(request: Json) -> Json { +fn store_secret(_token: APIToken, request: Json) -> Json { let user_name = request.user_name.as_str(); let decrypted_text = match ENCRYPTION.decrypt(&request.secret) { Ok(text) => text, @@ -840,7 +904,7 @@ struct StoreSecretResponse { } #[post("/secrets/get", data = "")] -fn get_secret(request: Json) -> Json { +fn get_secret(_token: APIToken, request: Json) -> Json { let user_name = request.user_name.as_str(); let service = format!("mindwork-ai-studio::{}", request.destination); let entry = Entry::new(service.as_str(), user_name).unwrap(); @@ -894,7 +958,7 @@ struct RequestedSecret { } #[post("/secrets/delete", data = "")] -fn delete_secret(request: Json) -> Json { +fn delete_secret(_token: APIToken, request: Json) -> Json { let user_name = request.user_name.as_str(); let service = format!("mindwork-ai-studio::{}", request.destination); let entry = Entry::new(service.as_str(), user_name).unwrap(); @@ -938,7 +1002,7 @@ struct DeleteSecretResponse { } #[post("/clipboard/set", data = "")] -fn set_clipboard(encrypted_text: EncryptedText) -> Json { +fn set_clipboard(_token: APIToken, encrypted_text: EncryptedText) -> Json { // Decrypt this text first: let decrypted_text = match ENCRYPTION.decrypt(&encrypted_text) {