mirror of
https://github.com/MindWorkAI/AI-Studio.git
synced 2025-04-28 17:59:46 +00:00
Implemented a mandatory & secret API token for the runtime API
This commit is contained in:
parent
f929789535
commit
a19c1a0b7e
@ -18,6 +18,7 @@ internal sealed class Program
|
||||
{
|
||||
public static RustService RUST_SERVICE = null!;
|
||||
public static Encryption ENCRYPTION = null!;
|
||||
public static string API_TOKEN = null!;
|
||||
|
||||
public static async Task Main(string[] args)
|
||||
{
|
||||
@ -52,6 +53,15 @@ internal sealed class Program
|
||||
return;
|
||||
}
|
||||
|
||||
var apiToken = Environment.GetEnvironmentVariable("AI_STUDIO_API_TOKEN");
|
||||
if(string.IsNullOrWhiteSpace(apiToken))
|
||||
{
|
||||
Console.WriteLine("Error: The AI_STUDIO_API_TOKEN environment variable is not set.");
|
||||
return;
|
||||
}
|
||||
|
||||
API_TOKEN = apiToken;
|
||||
|
||||
var rustApiPort = args[0];
|
||||
using var rust = new RustService(rustApiPort, certificateFingerprint);
|
||||
var appPort = await rust.GetAppPort();
|
||||
|
18
app/MindWork AI Studio/Tools/HttpRequestHeadersExtensions.cs
Normal file
18
app/MindWork AI Studio/Tools/HttpRequestHeadersExtensions.cs
Normal file
@ -0,0 +1,18 @@
|
||||
using System.Net.Http.Headers;
|
||||
|
||||
namespace AIStudio.Tools;
|
||||
|
||||
public static class HttpRequestHeadersExtensions
|
||||
{
|
||||
private static readonly string API_TOKEN;
|
||||
|
||||
static HttpRequestHeadersExtensions()
|
||||
{
|
||||
API_TOKEN = Program.API_TOKEN;
|
||||
}
|
||||
|
||||
public static void AddApiToken(this HttpRequestHeaders headers)
|
||||
{
|
||||
headers.Add("token", API_TOKEN);
|
||||
}
|
||||
}
|
@ -48,6 +48,8 @@ public sealed class RustService : IDisposable
|
||||
DefaultRequestVersion = Version.Parse("2.0"),
|
||||
DefaultVersionPolicy = HttpVersionPolicy.RequestVersionOrHigher,
|
||||
};
|
||||
|
||||
this.http.DefaultRequestHeaders.AddApiToken();
|
||||
}
|
||||
|
||||
public void SetLogger(ILogger<RustService> logService)
|
||||
@ -99,8 +101,10 @@ public sealed class RustService : IDisposable
|
||||
return currentCertificateFingerprint == this.certificateFingerprint;
|
||||
}
|
||||
});
|
||||
|
||||
initialHttp.DefaultRequestVersion = Version.Parse("2.0");
|
||||
initialHttp.DefaultVersionPolicy = HttpVersionPolicy.RequestVersionOrHigher;
|
||||
initialHttp.DefaultRequestHeaders.AddApiToken();
|
||||
|
||||
try
|
||||
{
|
||||
|
@ -32,16 +32,22 @@ use rcgen::generate_simple_self_signed;
|
||||
use rocket::figment::Figment;
|
||||
use rocket::{data, get, post, routes, Data, Request};
|
||||
use rocket::config::{Shutdown};
|
||||
use rocket::data::{Outcome, ToByteUnit};
|
||||
use rocket::data::{ToByteUnit};
|
||||
use rocket::http::Status;
|
||||
use rocket::request::{FromRequest};
|
||||
use rocket::serde::json::Json;
|
||||
use sha2::{Sha256, Sha512, Digest};
|
||||
use tauri::updater::UpdateResponse;
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
type Aes256CbcEnc = cbc::Encryptor<aes::Aes256>;
|
||||
|
||||
type Aes256CbcDec = cbc::Decryptor<aes::Aes256>;
|
||||
|
||||
type DataOutcome<'r, T> = data::Outcome<'r, T>;
|
||||
|
||||
type RequestOutcome<R, T> = rocket::request::Outcome<R, T>;
|
||||
|
||||
// The .NET server is started in a separate process and communicates with this
|
||||
// runtime process via IPC. However, we do net start the .NET server in
|
||||
// the development environment.
|
||||
@ -88,6 +94,13 @@ static ENCRYPTION: Lazy<Encryption> = Lazy::new(|| {
|
||||
Encryption::new(&secret_key, &secret_key_salt).unwrap()
|
||||
});
|
||||
|
||||
static API_TOKEN: Lazy<APIToken> = Lazy::new(|| {
|
||||
let mut token = [0u8; 32];
|
||||
let mut rng = rand_chacha::ChaChaRng::from_entropy();
|
||||
rng.fill_bytes(&mut token);
|
||||
APIToken::from_bytes(token.to_vec())
|
||||
});
|
||||
|
||||
static DATA_DIRECTORY: OnceLock<String> = OnceLock::new();
|
||||
|
||||
static CONFIG_DIRECTORY: OnceLock<String> = OnceLock::new();
|
||||
@ -229,6 +242,13 @@ async fn main() {
|
||||
let secret_password = BASE64_STANDARD.encode(ENCRYPTION.secret_password);
|
||||
let secret_key_salt = BASE64_STANDARD.encode(ENCRYPTION.secret_key_salt);
|
||||
|
||||
let dotnet_server_environment = HashMap::from_iter([
|
||||
(String::from("AI_STUDIO_SECRET_PASSWORD"), secret_password),
|
||||
(String::from("AI_STUDIO_SECRET_KEY_SALT"), secret_key_salt),
|
||||
(String::from("AI_STUDIO_CERTIFICATE_FINGERPRINT"), certificate_fingerprint),
|
||||
(String::from("AI_STUDIO_API_TOKEN"), API_TOKEN.to_hex_text().to_string()),
|
||||
]);
|
||||
|
||||
info!("Secret password for the IPC channel was generated successfully.");
|
||||
info!("Try to start the .NET server...");
|
||||
let server_spawn_clone = DOTNET_SERVER.clone();
|
||||
@ -248,13 +268,7 @@ async fn main() {
|
||||
// We provide the runtime API server port to the .NET server:
|
||||
.args(["run", "--project", "../app/MindWork AI Studio", "--", format!("{api_port}").as_str()])
|
||||
|
||||
// Provide the secret password & salt for the IPC channel to the .NET server by using
|
||||
// an environment variable. We must use a HashMap for this:
|
||||
.envs(HashMap::from_iter([
|
||||
(String::from("AI_STUDIO_SECRET_PASSWORD"), secret_password),
|
||||
(String::from("AI_STUDIO_SECRET_KEY_SALT"), secret_key_salt),
|
||||
(String::from("AI_STUDIO_CERTIFICATE_FINGERPRINT"), certificate_fingerprint),
|
||||
]))
|
||||
.envs(dotnet_server_environment)
|
||||
.spawn()
|
||||
.expect("Failed to spawn .NET server process.")
|
||||
}
|
||||
@ -266,13 +280,7 @@ async fn main() {
|
||||
// Provide the runtime API server port to the .NET server:
|
||||
.args([format!("{api_port}").as_str()])
|
||||
|
||||
// Provide the secret password & salt for the IPC channel to the .NET server by using
|
||||
// an environment variable. We must use a HashMap for this:
|
||||
.envs(HashMap::from_iter([
|
||||
(String::from("AI_STUDIO_SECRET_PASSWORD"), secret_password),
|
||||
(String::from("AI_STUDIO_SECRET_KEY_SALT"), secret_key_salt),
|
||||
(String::from("AI_STUDIO_CERTIFICATE_FINGERPRINT"), certificate_fingerprint),
|
||||
]))
|
||||
.envs(dotnet_server_environment)
|
||||
.spawn()
|
||||
.expect("Failed to spawn .NET server process.")
|
||||
}
|
||||
@ -428,6 +436,62 @@ async fn main() {
|
||||
}
|
||||
}
|
||||
|
||||
struct APIToken{
|
||||
hex_text: String,
|
||||
}
|
||||
|
||||
impl APIToken {
|
||||
fn from_bytes(bytes: Vec<u8>) -> Self {
|
||||
APIToken {
|
||||
hex_text: bytes.iter().fold(String::new(), |mut result, byte| {
|
||||
result.push_str(&format!("{:02x}", byte));
|
||||
result
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn from_hex_text(hex_text: &str) -> Self {
|
||||
APIToken {
|
||||
hex_text: hex_text.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_hex_text(&self) -> &str {
|
||||
self.hex_text.as_str()
|
||||
}
|
||||
|
||||
fn validate(&self, received_token: &Self) -> bool {
|
||||
received_token.to_hex_text() == self.to_hex_text()
|
||||
}
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for APIToken {
|
||||
type Error = APITokenError;
|
||||
|
||||
async fn from_request(request: &'r Request<'_>) -> RequestOutcome<Self, Self::Error> {
|
||||
let token = request.headers().get_one("token");
|
||||
match token {
|
||||
Some(token) => {
|
||||
let received_token = APIToken::from_hex_text(token);
|
||||
if API_TOKEN.validate(&received_token) {
|
||||
RequestOutcome::Success(received_token)
|
||||
} else {
|
||||
RequestOutcome::Error((Status::Unauthorized, APITokenError::Invalid))
|
||||
}
|
||||
}
|
||||
|
||||
None => RequestOutcome::Error((Status::Unauthorized, APITokenError::Missing)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum APITokenError {
|
||||
Missing,
|
||||
Invalid,
|
||||
}
|
||||
|
||||
//
|
||||
// Data structure for iterating over key-value pairs of log messages.
|
||||
//
|
||||
@ -615,30 +679,30 @@ impl fmt::Display for EncryptedText {
|
||||
#[rocket::async_trait]
|
||||
impl<'r> data::FromData<'r> for EncryptedText {
|
||||
type Error = String;
|
||||
async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> Outcome<'r, Self> {
|
||||
async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> DataOutcome<'r, Self> {
|
||||
let content_type = req.content_type();
|
||||
if content_type.map_or(true, |ct| !ct.is_text()) {
|
||||
return Outcome::Forward((data, Status::Ok));
|
||||
return DataOutcome::Forward((data, Status::Ok));
|
||||
}
|
||||
|
||||
let mut stream = data.open(2.mebibytes());
|
||||
let mut body = String::new();
|
||||
if let Err(e) = stream.read_to_string(&mut body).await {
|
||||
return Outcome::Error((Status::InternalServerError, format!("Failed to read data: {}", e)));
|
||||
return DataOutcome::Error((Status::InternalServerError, format!("Failed to read data: {}", e)));
|
||||
}
|
||||
|
||||
Outcome::Success(EncryptedText(body))
|
||||
DataOutcome::Success(EncryptedText(body))
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/system/dotnet/port")]
|
||||
fn dotnet_port() -> String {
|
||||
fn dotnet_port(_token: APIToken) -> String {
|
||||
let dotnet_server_port = *DOTNET_SERVER_PORT;
|
||||
format!("{dotnet_server_port}")
|
||||
}
|
||||
|
||||
#[get("/system/directories/data")]
|
||||
fn get_data_directory() -> String {
|
||||
fn get_data_directory(_token: APIToken) -> String {
|
||||
match DATA_DIRECTORY.get() {
|
||||
Some(data_directory) => data_directory.clone(),
|
||||
None => String::from(""),
|
||||
@ -646,7 +710,7 @@ fn get_data_directory() -> String {
|
||||
}
|
||||
|
||||
#[get("/system/directories/config")]
|
||||
fn get_config_directory() -> String {
|
||||
fn get_config_directory(_token: APIToken) -> String {
|
||||
match CONFIG_DIRECTORY.get() {
|
||||
Some(config_directory) => config_directory.clone(),
|
||||
None => String::from(""),
|
||||
@ -654,7 +718,7 @@ fn get_config_directory() -> String {
|
||||
}
|
||||
|
||||
#[get("/system/dotnet/ready")]
|
||||
async fn dotnet_ready() {
|
||||
async fn dotnet_ready(_token: APIToken) {
|
||||
let main_window_spawn_clone = &MAIN_WINDOW;
|
||||
let dotnet_server_port = *DOTNET_SERVER_PORT;
|
||||
let url = match Url::parse(format!("http://localhost:{dotnet_server_port}").as_str())
|
||||
@ -724,7 +788,7 @@ fn stop_servers() {
|
||||
}
|
||||
|
||||
#[get("/updates/check")]
|
||||
async fn check_for_update() -> Json<CheckUpdateResponse> {
|
||||
async fn check_for_update(_token: APIToken) -> Json<CheckUpdateResponse> {
|
||||
let app_handle = MAIN_WINDOW.lock().unwrap().as_ref().unwrap().app_handle();
|
||||
let response = app_handle.updater().check().await;
|
||||
match response {
|
||||
@ -777,7 +841,7 @@ struct CheckUpdateResponse {
|
||||
}
|
||||
|
||||
#[get("/updates/install")]
|
||||
async fn install_update() {
|
||||
async fn install_update(_token: APIToken) {
|
||||
let cloned_response_option = CHECK_UPDATE_RESPONSE.lock().unwrap().clone();
|
||||
match cloned_response_option {
|
||||
Some(update_response) => {
|
||||
@ -791,7 +855,7 @@ async fn install_update() {
|
||||
}
|
||||
|
||||
#[post("/secrets/store", data = "<request>")]
|
||||
fn store_secret(request: Json<StoreSecret>) -> Json<StoreSecretResponse> {
|
||||
fn store_secret(_token: APIToken, request: Json<StoreSecret>) -> Json<StoreSecretResponse> {
|
||||
let user_name = request.user_name.as_str();
|
||||
let decrypted_text = match ENCRYPTION.decrypt(&request.secret) {
|
||||
Ok(text) => text,
|
||||
@ -840,7 +904,7 @@ struct StoreSecretResponse {
|
||||
}
|
||||
|
||||
#[post("/secrets/get", data = "<request>")]
|
||||
fn get_secret(request: Json<RequestSecret>) -> Json<RequestedSecret> {
|
||||
fn get_secret(_token: APIToken, request: Json<RequestSecret>) -> Json<RequestedSecret> {
|
||||
let user_name = request.user_name.as_str();
|
||||
let service = format!("mindwork-ai-studio::{}", request.destination);
|
||||
let entry = Entry::new(service.as_str(), user_name).unwrap();
|
||||
@ -894,7 +958,7 @@ struct RequestedSecret {
|
||||
}
|
||||
|
||||
#[post("/secrets/delete", data = "<request>")]
|
||||
fn delete_secret(request: Json<RequestSecret>) -> Json<DeleteSecretResponse> {
|
||||
fn delete_secret(_token: APIToken, request: Json<RequestSecret>) -> Json<DeleteSecretResponse> {
|
||||
let user_name = request.user_name.as_str();
|
||||
let service = format!("mindwork-ai-studio::{}", request.destination);
|
||||
let entry = Entry::new(service.as_str(), user_name).unwrap();
|
||||
@ -938,7 +1002,7 @@ struct DeleteSecretResponse {
|
||||
}
|
||||
|
||||
#[post("/clipboard/set", data = "<encrypted_text>")]
|
||||
fn set_clipboard(encrypted_text: EncryptedText) -> Json<SetClipboardResponse> {
|
||||
fn set_clipboard(_token: APIToken, encrypted_text: EncryptedText) -> Json<SetClipboardResponse> {
|
||||
|
||||
// Decrypt this text first:
|
||||
let decrypted_text = match ENCRYPTION.decrypt(&encrypted_text) {
|
||||
|
Loading…
Reference in New Issue
Block a user