token count now adapts to chosen provider tokenizer

This commit is contained in:
PaulKoudelka 2026-04-14 16:27:10 +02:00
parent e07ca378d4
commit 41573406d5
10 changed files with 263336 additions and 39 deletions

View File

@ -994,6 +994,16 @@ public partial class ChatComponent : MSGComponentBase, IAsyncDisposable
this.tokenCount = "0"; this.tokenCount = "0";
return; return;
} }
var tokenizerResponse = await this.RustService.EnsureTokenizer(this.Provider);
if (tokenizerResponse is null)
return;
if (!tokenizerResponse.Value.Success)
{
this.Logger.LogWarning($"Failed to initialize the tokenizer for the provider: {tokenizerResponse.Value.Message}");
return;
}
var response = await this.RustService.GetTokenCount(this.inputField.Value); var response = await this.RustService.GetTokenCount(this.inputField.Value);
if (response is null) if (response is null)
return; return;

View File

@ -243,7 +243,7 @@ public partial class EmbeddingProviderDialog : MSGComponentBase, ISecretId
if (!this.dataIsValid) if (!this.dataIsValid)
return; return;
var response = await this.RustService.StoreTokenizer(this.DataName, this.dataEditingPreviousInstanceName, this.dataFilePath); var response = await this.RustService.StoreTokenizer("embedding_"+this.DataName, "embedding_"+this.dataEditingPreviousInstanceName, this.dataFilePath);
Console.WriteLine($"Response from Rust: {response.Message}"); Console.WriteLine($"Response from Rust: {response.Message}");
if (!response.Success) if (!response.Success)
{ {

View File

@ -268,7 +268,7 @@ public partial class ProviderDialog : MSGComponentBase, ISecretId
if (!this.dataIsValid) if (!this.dataIsValid)
return; return;
var tokenizerResponse = await this.RustService.StoreTokenizer(this.DataInstanceName, this.dataEditingPreviousInstanceName, this.dataFilePath); var tokenizerResponse = await this.RustService.StoreTokenizer("chat_"+this.DataInstanceName, "chat_"+this.dataEditingPreviousInstanceName, this.dataFilePath);
if (!tokenizerResponse.Success) if (!tokenizerResponse.Success)
{ {
this.dataCustomTokenizerValidationIssue = tokenizerResponse.Message; this.dataCustomTokenizerValidationIssue = tokenizerResponse.Message;

View File

@ -1,9 +1,14 @@
using AIStudio.Tools.Rust; using AIStudio.Provider;
using AIStudio.Tools.Rust;
namespace AIStudio.Tools.Services; namespace AIStudio.Tools.Services;
public sealed partial class RustService public sealed partial class RustService
{ {
private readonly SemaphoreSlim tokenizerLock = new(1, 1);
private string currentTokenizerPath = string.Empty;
private bool hasInitializedTokenizer;
public async Task<TokenizerResponse> ValidateTokenizer(string filePath) public async Task<TokenizerResponse> ValidateTokenizer(string filePath)
{ {
var result = await this.http.PostAsJsonAsync("/tokenizer/validate", new { var result = await this.http.PostAsJsonAsync("/tokenizer/validate", new {
@ -66,4 +71,57 @@ public sealed partial class RustService
return null; return null;
} }
} }
}
public async Task<TokenizerResponse?> SetTokenizer(string providerName, string path)
{
Console.WriteLine($"Setting a new tokenizer for '{providerName}'");
var result = await this.http.PostAsJsonAsync("/tokenizer/set", new {
file_path = path,
}, this.jsonRustSerializerOptions);
if (!result.IsSuccessStatusCode)
{
this.logger!.LogError($"Failed to set the tokenizer '{result.StatusCode}'");
return new TokenizerResponse{
Success = false,
Message = "An error occured while sending the path to the Rust framework for setting a tokenizer: "+result.StatusCode,
TokenCount = 0
};
}
return await result.Content.ReadFromJsonAsync<TokenizerResponse>(this.jsonRustSerializerOptions);
}
public Task<TokenizerResponse?> EnsureTokenizer(Settings.Provider provider)
{
return this.EnsureTokenizer(provider.InstanceName, provider.TokenizerPath);
}
public Task<TokenizerResponse?> EnsureTokenizer(IProvider provider)
{
return this.EnsureTokenizer(provider.InstanceName, provider.TokenizerPath);
}
private async Task<TokenizerResponse?> EnsureTokenizer(string providerName, string path)
{
await this.tokenizerLock.WaitAsync();
try
{
if (this.hasInitializedTokenizer && this.currentTokenizerPath == path)
return new TokenizerResponse(true, 0, "Success");
var response = await this.SetTokenizer(providerName, path);
if (response is { Success: true })
{
this.currentTokenizerPath = path;
this.hasInitializedTokenizer = true;
}
return response;
}
finally
{
this.tokenizerLock.Release();
}
}
}

View File

@ -91,6 +91,7 @@ public sealed partial class RustService : BackgroundService
{ {
this.http.Dispose(); this.http.Dispose();
this.userLanguageLock.Dispose(); this.userLanguageLock.Dispose();
this.tokenizerLock.Dispose();
base.Dispose(); base.Dispose();
} }

File diff suppressed because it is too large Load Diff

View File

@ -21,6 +21,7 @@ use crate::pdfium::PDFIUM_LIB_PATH;
use crate::qdrant::{cleanup_qdrant, start_qdrant_server, stop_qdrant_server}; use crate::qdrant::{cleanup_qdrant, start_qdrant_server, stop_qdrant_server};
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
use crate::dotnet::create_startup_env_file; use crate::dotnet::create_startup_env_file;
use crate::tokenizer::set_path_resolver;
/// The Tauri main window. /// The Tauri main window.
static MAIN_WINDOW: Lazy<Mutex<Option<Window>>> = Lazy::new(|| Mutex::new(None)); static MAIN_WINDOW: Lazy<Mutex<Option<Window>>> = Lazy::new(|| Mutex::new(None));
@ -117,6 +118,8 @@ pub fn start_tauri() {
cleanup_qdrant(); cleanup_qdrant();
start_qdrant_server(app.path_resolver()); start_qdrant_server(app.path_resolver());
set_path_resolver(app.path_resolver());
info!(Source = "Bootloader Tauri"; "Reconfigure the file logger to use the app data directory {data_path:?}"); info!(Source = "Bootloader Tauri"; "Reconfigure the file logger to use the app data directory {data_path:?}");
switch_to_file_logging(data_path).map_err(|e| error!("Failed to switch logging to file: {e}")).unwrap(); switch_to_file_logging(data_path).map_err(|e| error!("Failed to switch logging to file: {e}")).unwrap();
set_pdfium_path(app.path_resolver()); set_pdfium_path(app.path_resolver());

View File

@ -43,10 +43,6 @@ async fn main() {
info!("Running in production mode."); info!("Running in production mode.");
} }
if let Err(e) = init_tokenizer() {
warn!(Source = "Tokenizer"; "Error during the initialisation of the tokenizer: {}", e);
}
generate_runtime_certificate(); generate_runtime_certificate();
start_runtime_api(); start_runtime_api();

View File

@ -92,6 +92,7 @@ pub fn start_runtime_api() {
crate::tokenizer::token_count, crate::tokenizer::token_count,
crate::tokenizer::validate_tokenizer, crate::tokenizer::validate_tokenizer,
crate::tokenizer::store_tokenizer, crate::tokenizer::store_tokenizer,
crate::tokenizer::set_tokenizer,
crate::app_window::register_shortcut, crate::app_window::register_shortcut,
crate::app_window::validate_shortcut, crate::app_window::validate_shortcut,
crate::app_window::suspend_shortcuts, crate::app_window::suspend_shortcuts,

View File

@ -1,16 +1,20 @@
use std::fs; use std::fs;
use std::path::{PathBuf}; use std::path::PathBuf;
use std::sync::OnceLock; use std::sync::{OnceLock, RwLock};
use rocket::{post}; use log::warn;
use rocket::post;
use rocket::serde::json::Json; use rocket::serde::json::Json;
use rocket::serde::Serialize; use rocket::serde::Serialize;
use serde::Deserialize; use serde::Deserialize;
use tauri::PathResolver;
use tokenizers::Error; use tokenizers::Error;
use tokenizers::tokenizer::{Tokenizer, Error as TokenizerError}; use tokenizers::tokenizer::{Tokenizer, Error as TokenizerError};
use crate::api_token::APIToken; use crate::api_token::APIToken;
use crate::environment::{DATA_DIRECTORY}; use crate::environment::DATA_DIRECTORY;
static TOKENIZER: OnceLock<Tokenizer> = OnceLock::new(); static TOKENIZER: OnceLock<RwLock<Option<Tokenizer>>> = OnceLock::new();
static TOKENIZER_PATH_RESOLVER: OnceLock<PathResolver> = OnceLock::new();
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct SetTokenText { pub struct SetTokenText {
@ -25,7 +29,7 @@ pub struct TokenizerStorage {
} }
#[derive(Clone, Deserialize)] #[derive(Clone, Deserialize)]
pub struct TokenizerValidation { pub struct TokenizerPath {
file_path: String, file_path: String,
} }
@ -53,15 +57,36 @@ impl From<Result<usize, TokenizerError>> for TokenizerResponse {
} }
} }
pub fn init_tokenizer() -> Result<(), Error>{ pub fn set_path_resolver(path_resolver: PathResolver) {
let mut target_dir = PathBuf::from("target"); match TOKENIZER_PATH_RESOLVER.set(path_resolver) {
target_dir.push("tokenizers"); Ok(_) => (),
fs::create_dir_all(&target_dir)?; Err(e) => warn!(Source = "Tokenizer"; "Could not set the path resolver: {:?}", e),
}
}
let mut local_tokenizer_path = target_dir.clone(); fn tokenizer_state() -> &'static RwLock<Option<Tokenizer>> {
local_tokenizer_path.push("tokenizer.json"); TOKENIZER.get_or_init(|| RwLock::new(None))
}
pub fn init_tokenizer(path: &str) -> Result<(), Error> {
let tokenizer_path = if path.trim().is_empty() {
let relative_source_path = String::from("resources/tokenizers/tokenizer.json");
let path_resolver = TOKENIZER_PATH_RESOLVER
.get()
.ok_or_else(|| Error::from("Tokenizer path resolver is not initialized"))?;
path_resolver
.resolve_resource(relative_source_path)
.ok_or_else(|| Error::from("Failed to resolve default tokenizer resource path"))?
} else {
PathBuf::from(path)
};
let tokenizer = Tokenizer::from_file(tokenizer_path)?;
let mut tokenizer_guard = tokenizer_state()
.write()
.map_err(|_| Error::from("Tokenizer state lock is poisoned"))?;
*tokenizer_guard = Some(tokenizer);
TOKENIZER.set(Tokenizer::from_file(local_tokenizer_path)?).expect("Could not set the tokenizer.");
Ok(()) Ok(())
} }
@ -93,13 +118,13 @@ fn validate_tokenizer_at_path(path: &PathBuf) -> Result<usize, TokenizerError> {
if token_count == 0 { if token_count == 0 {
return Err(TokenizerError::from( return Err(TokenizerError::from(
"Tokenizer produced 0 tokens for test string. The tokenizer is likely invalid or misconfigured." "Tokenizer produced 0 tokens for test string. The tokenizer is likely invalid or misconfigured.",
)); ));
} }
if encoding.get_tokens().iter().any(|t| t.is_empty()) { if encoding.get_tokens().iter().any(|t| t.is_empty()) {
return Err(TokenizerError::from( return Err(TokenizerError::from(
"Tokenizer produced empty tokens. The tokenizer is invalid." "Tokenizer produced empty tokens. The tokenizer is invalid.",
)); ));
} }
@ -113,41 +138,48 @@ fn handle_tokenizer_store(payload: &TokenizerStorage) -> Result<String, std::io:
let base_path = PathBuf::from(data_dir).join("tokenizers"); let base_path = PathBuf::from(data_dir).join("tokenizers");
// Delete previous model if file_path is empty
if payload.file_path.trim().is_empty() { if payload.file_path.trim().is_empty() {
if payload.previous_model_id.trim().is_empty() { if payload.previous_model_id.trim().is_empty() {
return Ok(String::from("")); // Nothing to delete return Ok(String::from(""));
} }
let previous_path = base_path.join(&payload.previous_model_id); let previous_path = base_path.join(&payload.previous_model_id);
fs::remove_dir_all(previous_path)?; fs::remove_dir_all(previous_path)?;
return Ok(String::from("")); return Ok(String::from(""));
} }
// Copy file
let source_path = PathBuf::from(&payload.file_path); let source_path = PathBuf::from(&payload.file_path);
let source_name = source_path.file_name() let source_name = source_path
.file_name()
.and_then(|n| n.to_str()) .and_then(|n| n.to_str())
.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid tokenizer file path"))?; .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid tokenizer file path"))?;
let model_path = &base_path.join(&payload.model_id); let model_path = &base_path.join(&payload.model_id);
let destination_path = &model_path.join(source_name); let destination_path = &model_path.join(source_name);
println!("source_path: {}, destination_path: {}", source_path.display(), destination_path.display()); println!(
"source_path: {}, destination_path: {}",
source_path.display(),
destination_path.display()
);
println!("equals {}", source_path.eq(destination_path)); println!("equals {}", source_path.eq(destination_path));
if !source_path.eq(destination_path) && model_path.exists() { if !source_path.eq(destination_path) && model_path.exists() {
fs::remove_dir_all(model_path)?; fs::remove_dir_all(model_path)?;
} }
fs::create_dir_all(model_path)?; fs::create_dir_all(model_path)?;
println!("Moving tokenizer file from {} to {}", source_path.display(), destination_path.display()); println!(
"Moving tokenizer file from {} to {}",
source_path.display(),
destination_path.display()
);
let previous_path = base_path.join(&payload.previous_model_id); let previous_path = base_path.join(&payload.previous_model_id);
// Delete previous tokenizer folder if specified if !payload.previous_model_id.trim().is_empty() && source_path.starts_with(&previous_path) {
if !payload.previous_model_id.trim().is_empty() && source_path.starts_with(&previous_path){
fs::rename(&source_path, &destination_path)?; fs::rename(&source_path, &destination_path)?;
if previous_path.exists() && !previous_path.eq(model_path) { if previous_path.exists() && !previous_path.eq(model_path) {
fs::remove_dir_all(previous_path)?; fs::remove_dir_all(previous_path)?;
} }
}else{ } else {
fs::copy( & source_path, & destination_path)?; fs::copy(&source_path, &destination_path)?;
} }
Ok(destination_path.to_str().unwrap().to_string()) Ok(destination_path.to_str().unwrap().to_string())
} }
@ -157,7 +189,11 @@ pub fn get_token_count(text: &str) -> Result<usize, TokenizerError> {
return Err(TokenizerError::from("Input text is empty")); return Err(TokenizerError::from("Input text is empty"));
} }
let tokenizer = TOKENIZER.get().cloned().ok_or_else(|| TokenizerError::from("Tokenizer not initialized"))?; let tokenizer = tokenizer_state()
.read()
.map_err(|_| TokenizerError::from("Tokenizer state lock is poisoned"))?
.clone()
.ok_or_else(|| TokenizerError::from("Tokenizer not initialized"))?;
let enc = tokenizer.encode(text, true)?; let enc = tokenizer.encode(text, true)?;
Ok(enc.len()) Ok(enc.len())
} }
@ -168,14 +204,17 @@ pub fn token_count(_token: APIToken, req: Json<SetTokenText>) -> Json<TokenizerR
} }
#[post("/tokenizer/validate", data = "<payload>")] #[post("/tokenizer/validate", data = "<payload>")]
pub fn validate_tokenizer(_token: APIToken, payload: Json<TokenizerValidation>) -> Json<TokenizerResponse>{ pub fn validate_tokenizer(_token: APIToken, payload: Json<TokenizerPath>) -> Json<TokenizerResponse> {
println!("Received tokenizer validation request: {}", payload.file_path); println!("Received tokenizer validation request: {}", payload.file_path);
Json(validate_tokenizer_at_path(&PathBuf::from(payload.file_path.clone())).into()) Json(validate_tokenizer_at_path(&PathBuf::from(payload.file_path.clone())).into())
} }
#[post("/tokenizer/store", data = "<payload>")] #[post("/tokenizer/store", data = "<payload>")]
pub fn store_tokenizer(_token: APIToken, payload: Json<TokenizerStorage>) -> Json<TokenizerResponse>{ pub fn store_tokenizer(_token: APIToken, payload: Json<TokenizerStorage>) -> Json<TokenizerResponse> {
println!("Received tokenizer store request: {}, {}, {}", payload.model_id, payload.previous_model_id, payload.file_path); println!(
"Received tokenizer store request: {}, {}, {}",
payload.model_id, payload.previous_model_id, payload.file_path
);
match handle_tokenizer_store(&payload) { match handle_tokenizer_store(&payload) {
Ok(dest_path) => Json(TokenizerResponse { Ok(dest_path) => Json(TokenizerResponse {
success: true, success: true,
@ -188,5 +227,20 @@ pub fn store_tokenizer(_token: APIToken, payload: Json<TokenizerStorage>) -> Jso
message: e.to_string(), message: e.to_string(),
}), }),
} }
}
} #[post("/tokenizer/set", data = "<payload>")]
pub fn set_tokenizer(_token: APIToken, payload: Json<TokenizerPath>) -> Json<TokenizerResponse> {
match init_tokenizer(&payload.file_path) {
Ok(_) => Json(TokenizerResponse {
success: true,
token_count: 0,
message: "Success".to_string(),
}),
Err(e) => Json(TokenizerResponse {
success: false,
token_count: 0,
message: e.to_string(),
}),
}
}