Implemented a mandatory & secret API token for the runtime API

This commit is contained in:
Thorsten Sommer 2024-09-01 18:59:25 +02:00
parent f929789535
commit a19c1a0b7e
Signed by: tsommer
GPG Key ID: 371BBA77A02C0108
4 changed files with 125 additions and 29 deletions

View File

@ -18,6 +18,7 @@ internal sealed class Program
{ {
public static RustService RUST_SERVICE = null!; public static RustService RUST_SERVICE = null!;
public static Encryption ENCRYPTION = null!; public static Encryption ENCRYPTION = null!;
public static string API_TOKEN = null!;
public static async Task Main(string[] args) public static async Task Main(string[] args)
{ {
@ -52,6 +53,15 @@ internal sealed class Program
return; return;
} }
var apiToken = Environment.GetEnvironmentVariable("AI_STUDIO_API_TOKEN");
if(string.IsNullOrWhiteSpace(apiToken))
{
Console.WriteLine("Error: The AI_STUDIO_API_TOKEN environment variable is not set.");
return;
}
API_TOKEN = apiToken;
var rustApiPort = args[0]; var rustApiPort = args[0];
using var rust = new RustService(rustApiPort, certificateFingerprint); using var rust = new RustService(rustApiPort, certificateFingerprint);
var appPort = await rust.GetAppPort(); var appPort = await rust.GetAppPort();

View File

@ -0,0 +1,18 @@
using System.Net.Http.Headers;
namespace AIStudio.Tools;
public static class HttpRequestHeadersExtensions
{
private static readonly string API_TOKEN;
static HttpRequestHeadersExtensions()
{
API_TOKEN = Program.API_TOKEN;
}
public static void AddApiToken(this HttpRequestHeaders headers)
{
headers.Add("token", API_TOKEN);
}
}

View File

@ -48,6 +48,8 @@ public sealed class RustService : IDisposable
DefaultRequestVersion = Version.Parse("2.0"), DefaultRequestVersion = Version.Parse("2.0"),
DefaultVersionPolicy = HttpVersionPolicy.RequestVersionOrHigher, DefaultVersionPolicy = HttpVersionPolicy.RequestVersionOrHigher,
}; };
this.http.DefaultRequestHeaders.AddApiToken();
} }
public void SetLogger(ILogger<RustService> logService) public void SetLogger(ILogger<RustService> logService)
@ -99,8 +101,10 @@ public sealed class RustService : IDisposable
return currentCertificateFingerprint == this.certificateFingerprint; return currentCertificateFingerprint == this.certificateFingerprint;
} }
}); });
initialHttp.DefaultRequestVersion = Version.Parse("2.0"); initialHttp.DefaultRequestVersion = Version.Parse("2.0");
initialHttp.DefaultVersionPolicy = HttpVersionPolicy.RequestVersionOrHigher; initialHttp.DefaultVersionPolicy = HttpVersionPolicy.RequestVersionOrHigher;
initialHttp.DefaultRequestHeaders.AddApiToken();
try try
{ {

View File

@ -32,16 +32,22 @@ use rcgen::generate_simple_self_signed;
use rocket::figment::Figment; use rocket::figment::Figment;
use rocket::{data, get, post, routes, Data, Request}; use rocket::{data, get, post, routes, Data, Request};
use rocket::config::{Shutdown}; use rocket::config::{Shutdown};
use rocket::data::{Outcome, ToByteUnit}; use rocket::data::{ToByteUnit};
use rocket::http::Status; use rocket::http::Status;
use rocket::request::{FromRequest};
use rocket::serde::json::Json; use rocket::serde::json::Json;
use sha2::{Sha256, Sha512, Digest}; use sha2::{Sha256, Sha512, Digest};
use tauri::updater::UpdateResponse; use tauri::updater::UpdateResponse;
use tokio::io::AsyncReadExt; use tokio::io::AsyncReadExt;
type Aes256CbcEnc = cbc::Encryptor<aes::Aes256>; type Aes256CbcEnc = cbc::Encryptor<aes::Aes256>;
type Aes256CbcDec = cbc::Decryptor<aes::Aes256>; type Aes256CbcDec = cbc::Decryptor<aes::Aes256>;
type DataOutcome<'r, T> = data::Outcome<'r, T>;
type RequestOutcome<R, T> = rocket::request::Outcome<R, T>;
// The .NET server is started in a separate process and communicates with this // The .NET server is started in a separate process and communicates with this
// runtime process via IPC. However, we do net start the .NET server in // runtime process via IPC. However, we do net start the .NET server in
// the development environment. // the development environment.
@ -88,6 +94,13 @@ static ENCRYPTION: Lazy<Encryption> = Lazy::new(|| {
Encryption::new(&secret_key, &secret_key_salt).unwrap() Encryption::new(&secret_key, &secret_key_salt).unwrap()
}); });
static API_TOKEN: Lazy<APIToken> = Lazy::new(|| {
let mut token = [0u8; 32];
let mut rng = rand_chacha::ChaChaRng::from_entropy();
rng.fill_bytes(&mut token);
APIToken::from_bytes(token.to_vec())
});
static DATA_DIRECTORY: OnceLock<String> = OnceLock::new(); static DATA_DIRECTORY: OnceLock<String> = OnceLock::new();
static CONFIG_DIRECTORY: OnceLock<String> = OnceLock::new(); static CONFIG_DIRECTORY: OnceLock<String> = OnceLock::new();
@ -229,6 +242,13 @@ async fn main() {
let secret_password = BASE64_STANDARD.encode(ENCRYPTION.secret_password); let secret_password = BASE64_STANDARD.encode(ENCRYPTION.secret_password);
let secret_key_salt = BASE64_STANDARD.encode(ENCRYPTION.secret_key_salt); let secret_key_salt = BASE64_STANDARD.encode(ENCRYPTION.secret_key_salt);
let dotnet_server_environment = HashMap::from_iter([
(String::from("AI_STUDIO_SECRET_PASSWORD"), secret_password),
(String::from("AI_STUDIO_SECRET_KEY_SALT"), secret_key_salt),
(String::from("AI_STUDIO_CERTIFICATE_FINGERPRINT"), certificate_fingerprint),
(String::from("AI_STUDIO_API_TOKEN"), API_TOKEN.to_hex_text().to_string()),
]);
info!("Secret password for the IPC channel was generated successfully."); info!("Secret password for the IPC channel was generated successfully.");
info!("Try to start the .NET server..."); info!("Try to start the .NET server...");
let server_spawn_clone = DOTNET_SERVER.clone(); let server_spawn_clone = DOTNET_SERVER.clone();
@ -248,13 +268,7 @@ async fn main() {
// We provide the runtime API server port to the .NET server: // We provide the runtime API server port to the .NET server:
.args(["run", "--project", "../app/MindWork AI Studio", "--", format!("{api_port}").as_str()]) .args(["run", "--project", "../app/MindWork AI Studio", "--", format!("{api_port}").as_str()])
// Provide the secret password & salt for the IPC channel to the .NET server by using .envs(dotnet_server_environment)
// an environment variable. We must use a HashMap for this:
.envs(HashMap::from_iter([
(String::from("AI_STUDIO_SECRET_PASSWORD"), secret_password),
(String::from("AI_STUDIO_SECRET_KEY_SALT"), secret_key_salt),
(String::from("AI_STUDIO_CERTIFICATE_FINGERPRINT"), certificate_fingerprint),
]))
.spawn() .spawn()
.expect("Failed to spawn .NET server process.") .expect("Failed to spawn .NET server process.")
} }
@ -266,13 +280,7 @@ async fn main() {
// Provide the runtime API server port to the .NET server: // Provide the runtime API server port to the .NET server:
.args([format!("{api_port}").as_str()]) .args([format!("{api_port}").as_str()])
// Provide the secret password & salt for the IPC channel to the .NET server by using .envs(dotnet_server_environment)
// an environment variable. We must use a HashMap for this:
.envs(HashMap::from_iter([
(String::from("AI_STUDIO_SECRET_PASSWORD"), secret_password),
(String::from("AI_STUDIO_SECRET_KEY_SALT"), secret_key_salt),
(String::from("AI_STUDIO_CERTIFICATE_FINGERPRINT"), certificate_fingerprint),
]))
.spawn() .spawn()
.expect("Failed to spawn .NET server process.") .expect("Failed to spawn .NET server process.")
} }
@ -428,6 +436,62 @@ async fn main() {
} }
} }
struct APIToken{
hex_text: String,
}
impl APIToken {
fn from_bytes(bytes: Vec<u8>) -> Self {
APIToken {
hex_text: bytes.iter().fold(String::new(), |mut result, byte| {
result.push_str(&format!("{:02x}", byte));
result
}),
}
}
fn from_hex_text(hex_text: &str) -> Self {
APIToken {
hex_text: hex_text.to_string(),
}
}
fn to_hex_text(&self) -> &str {
self.hex_text.as_str()
}
fn validate(&self, received_token: &Self) -> bool {
received_token.to_hex_text() == self.to_hex_text()
}
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for APIToken {
type Error = APITokenError;
async fn from_request(request: &'r Request<'_>) -> RequestOutcome<Self, Self::Error> {
let token = request.headers().get_one("token");
match token {
Some(token) => {
let received_token = APIToken::from_hex_text(token);
if API_TOKEN.validate(&received_token) {
RequestOutcome::Success(received_token)
} else {
RequestOutcome::Error((Status::Unauthorized, APITokenError::Invalid))
}
}
None => RequestOutcome::Error((Status::Unauthorized, APITokenError::Missing)),
}
}
}
#[derive(Debug)]
enum APITokenError {
Missing,
Invalid,
}
// //
// Data structure for iterating over key-value pairs of log messages. // Data structure for iterating over key-value pairs of log messages.
// //
@ -615,30 +679,30 @@ impl fmt::Display for EncryptedText {
#[rocket::async_trait] #[rocket::async_trait]
impl<'r> data::FromData<'r> for EncryptedText { impl<'r> data::FromData<'r> for EncryptedText {
type Error = String; type Error = String;
async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> Outcome<'r, Self> { async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> DataOutcome<'r, Self> {
let content_type = req.content_type(); let content_type = req.content_type();
if content_type.map_or(true, |ct| !ct.is_text()) { if content_type.map_or(true, |ct| !ct.is_text()) {
return Outcome::Forward((data, Status::Ok)); return DataOutcome::Forward((data, Status::Ok));
} }
let mut stream = data.open(2.mebibytes()); let mut stream = data.open(2.mebibytes());
let mut body = String::new(); let mut body = String::new();
if let Err(e) = stream.read_to_string(&mut body).await { if let Err(e) = stream.read_to_string(&mut body).await {
return Outcome::Error((Status::InternalServerError, format!("Failed to read data: {}", e))); return DataOutcome::Error((Status::InternalServerError, format!("Failed to read data: {}", e)));
} }
Outcome::Success(EncryptedText(body)) DataOutcome::Success(EncryptedText(body))
} }
} }
#[get("/system/dotnet/port")] #[get("/system/dotnet/port")]
fn dotnet_port() -> String { fn dotnet_port(_token: APIToken) -> String {
let dotnet_server_port = *DOTNET_SERVER_PORT; let dotnet_server_port = *DOTNET_SERVER_PORT;
format!("{dotnet_server_port}") format!("{dotnet_server_port}")
} }
#[get("/system/directories/data")] #[get("/system/directories/data")]
fn get_data_directory() -> String { fn get_data_directory(_token: APIToken) -> String {
match DATA_DIRECTORY.get() { match DATA_DIRECTORY.get() {
Some(data_directory) => data_directory.clone(), Some(data_directory) => data_directory.clone(),
None => String::from(""), None => String::from(""),
@ -646,7 +710,7 @@ fn get_data_directory() -> String {
} }
#[get("/system/directories/config")] #[get("/system/directories/config")]
fn get_config_directory() -> String { fn get_config_directory(_token: APIToken) -> String {
match CONFIG_DIRECTORY.get() { match CONFIG_DIRECTORY.get() {
Some(config_directory) => config_directory.clone(), Some(config_directory) => config_directory.clone(),
None => String::from(""), None => String::from(""),
@ -654,7 +718,7 @@ fn get_config_directory() -> String {
} }
#[get("/system/dotnet/ready")] #[get("/system/dotnet/ready")]
async fn dotnet_ready() { async fn dotnet_ready(_token: APIToken) {
let main_window_spawn_clone = &MAIN_WINDOW; let main_window_spawn_clone = &MAIN_WINDOW;
let dotnet_server_port = *DOTNET_SERVER_PORT; let dotnet_server_port = *DOTNET_SERVER_PORT;
let url = match Url::parse(format!("http://localhost:{dotnet_server_port}").as_str()) let url = match Url::parse(format!("http://localhost:{dotnet_server_port}").as_str())
@ -724,7 +788,7 @@ fn stop_servers() {
} }
#[get("/updates/check")] #[get("/updates/check")]
async fn check_for_update() -> Json<CheckUpdateResponse> { async fn check_for_update(_token: APIToken) -> Json<CheckUpdateResponse> {
let app_handle = MAIN_WINDOW.lock().unwrap().as_ref().unwrap().app_handle(); let app_handle = MAIN_WINDOW.lock().unwrap().as_ref().unwrap().app_handle();
let response = app_handle.updater().check().await; let response = app_handle.updater().check().await;
match response { match response {
@ -777,7 +841,7 @@ struct CheckUpdateResponse {
} }
#[get("/updates/install")] #[get("/updates/install")]
async fn install_update() { async fn install_update(_token: APIToken) {
let cloned_response_option = CHECK_UPDATE_RESPONSE.lock().unwrap().clone(); let cloned_response_option = CHECK_UPDATE_RESPONSE.lock().unwrap().clone();
match cloned_response_option { match cloned_response_option {
Some(update_response) => { Some(update_response) => {
@ -791,7 +855,7 @@ async fn install_update() {
} }
#[post("/secrets/store", data = "<request>")] #[post("/secrets/store", data = "<request>")]
fn store_secret(request: Json<StoreSecret>) -> Json<StoreSecretResponse> { fn store_secret(_token: APIToken, request: Json<StoreSecret>) -> Json<StoreSecretResponse> {
let user_name = request.user_name.as_str(); let user_name = request.user_name.as_str();
let decrypted_text = match ENCRYPTION.decrypt(&request.secret) { let decrypted_text = match ENCRYPTION.decrypt(&request.secret) {
Ok(text) => text, Ok(text) => text,
@ -840,7 +904,7 @@ struct StoreSecretResponse {
} }
#[post("/secrets/get", data = "<request>")] #[post("/secrets/get", data = "<request>")]
fn get_secret(request: Json<RequestSecret>) -> Json<RequestedSecret> { fn get_secret(_token: APIToken, request: Json<RequestSecret>) -> Json<RequestedSecret> {
let user_name = request.user_name.as_str(); let user_name = request.user_name.as_str();
let service = format!("mindwork-ai-studio::{}", request.destination); let service = format!("mindwork-ai-studio::{}", request.destination);
let entry = Entry::new(service.as_str(), user_name).unwrap(); let entry = Entry::new(service.as_str(), user_name).unwrap();
@ -894,7 +958,7 @@ struct RequestedSecret {
} }
#[post("/secrets/delete", data = "<request>")] #[post("/secrets/delete", data = "<request>")]
fn delete_secret(request: Json<RequestSecret>) -> Json<DeleteSecretResponse> { fn delete_secret(_token: APIToken, request: Json<RequestSecret>) -> Json<DeleteSecretResponse> {
let user_name = request.user_name.as_str(); let user_name = request.user_name.as_str();
let service = format!("mindwork-ai-studio::{}", request.destination); let service = format!("mindwork-ai-studio::{}", request.destination);
let entry = Entry::new(service.as_str(), user_name).unwrap(); let entry = Entry::new(service.as_str(), user_name).unwrap();
@ -938,7 +1002,7 @@ struct DeleteSecretResponse {
} }
#[post("/clipboard/set", data = "<encrypted_text>")] #[post("/clipboard/set", data = "<encrypted_text>")]
fn set_clipboard(encrypted_text: EncryptedText) -> Json<SetClipboardResponse> { fn set_clipboard(_token: APIToken, encrypted_text: EncryptedText) -> Json<SetClipboardResponse> {
// Decrypt this text first: // Decrypt this text first:
let decrypted_text = match ENCRYPTION.decrypt(&encrypted_text) { let decrypted_text = match ENCRYPTION.decrypt(&encrypted_text) {