Allow the use of an API key for self-hosted ollama instances (#156)

This commit is contained in:
Thorsten Sommer 2024-10-07 13:26:25 +02:00 committed by GitHub
parent 776fa8ac58
commit 37e113af0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 134 additions and 78 deletions

View File

@ -13,6 +13,7 @@ public partial class Changelog
public static readonly Log[] LOGS = public static readonly Log[] LOGS =
[ [
new (188, "v0.9.13, build 188 (2024-10-07 11:18 UTC)", "v0.9.13.md"),
new (187, "v0.9.12, build 187 (2024-09-15 20:49 UTC)", "v0.9.12.md"), new (187, "v0.9.12, build 187 (2024-09-15 20:49 UTC)", "v0.9.12.md"),
new (186, "v0.9.11, build 186 (2024-09-15 10:33 UTC)", "v0.9.11.md"), new (186, "v0.9.11, build 186 (2024-09-15 10:33 UTC)", "v0.9.11.md"),
new (185, "v0.9.10, build 185 (2024-09-12 20:52 UTC)", "v0.9.10.md"), new (185, "v0.9.10, build 185 (2024-09-12 20:52 UTC)", "v0.9.10.md"),

View File

@ -19,7 +19,7 @@
<MudTextField <MudTextField
T="string" T="string"
@bind-Text="@this.dataAPIKey" @bind-Text="@this.dataAPIKey"
Label="API Key" Label="@this.APIKeyText"
Disabled="@(!this.NeedAPIKey)" Disabled="@(!this.NeedAPIKey)"
Class="mb-3" Class="mb-3"
Adornment="Adornment.Start" Adornment="Adornment.Start"

View File

@ -133,7 +133,7 @@ public partial class ProviderDialog : ComponentBase
// //
// We cannot load the API key for self-hosted providers: // We cannot load the API key for self-hosted providers:
// //
if (this.DataLLMProvider is LLMProviders.SELF_HOSTED) if (this.DataLLMProvider is LLMProviders.SELF_HOSTED && this.DataHost is not Host.OLLAMA)
{ {
await this.ReloadModels(); await this.ReloadModels();
await base.OnInitializedAsync(); await base.OnInitializedAsync();
@ -149,7 +149,7 @@ public partial class ProviderDialog : ComponentBase
} }
// Load the API key: // Load the API key:
var requestedSecret = await this.RustService.GetAPIKey(provider); var requestedSecret = await this.RustService.GetAPIKey(provider, isTrying: this.DataLLMProvider is LLMProviders.SELF_HOSTED);
if(requestedSecret.Success) if(requestedSecret.Success)
{ {
this.dataAPIKey = await requestedSecret.Secret.Decrypt(this.encryption); this.dataAPIKey = await requestedSecret.Secret.Decrypt(this.encryption);
@ -159,8 +159,15 @@ public partial class ProviderDialog : ComponentBase
} }
else else
{ {
this.dataAPIKeyStorageIssue = $"Failed to load the API key from the operating system. The message was: {requestedSecret.Issue}. You might ignore this message and provide the API key again."; this.dataAPIKey = string.Empty;
await this.form.Validate(); if (this.DataLLMProvider is not LLMProviders.SELF_HOSTED)
{
this.dataAPIKeyStorageIssue = $"Failed to load the API key from the operating system. The message was: {requestedSecret.Issue}. You might ignore this message and provide the API key again.";
await this.form.Validate();
}
// We still try to load the models. Some local hosts don't need an API key:
await this.ReloadModels();
} }
} }
@ -192,7 +199,7 @@ public partial class ProviderDialog : ComponentBase
// Use the data model to store the provider. // Use the data model to store the provider.
// We just return this data to the parent component: // We just return this data to the parent component:
var addedProviderSettings = this.CreateProviderSettings(); var addedProviderSettings = this.CreateProviderSettings();
if (addedProviderSettings.UsedLLMProvider != LLMProviders.SELF_HOSTED) if (!string.IsNullOrWhiteSpace(this.dataAPIKey))
{ {
// We need to instantiate the provider to store the API key: // We need to instantiate the provider to store the API key:
var provider = addedProviderSettings.CreateProvider(this.Logger); var provider = addedProviderSettings.CreateProvider(this.Logger);
@ -363,10 +370,17 @@ public partial class ProviderDialog : ComponentBase
LLMProviders.ANTHROPIC => true, LLMProviders.ANTHROPIC => true,
LLMProviders.FIREWORKS => true, LLMProviders.FIREWORKS => true,
LLMProviders.SELF_HOSTED => this.DataHost is Host.OLLAMA,
_ => false, _ => false,
}; };
private string APIKeyText => this.DataLLMProvider switch
{
LLMProviders.SELF_HOSTED => "(Optional) API Key",
_ => "API Key",
};
private bool NeedHostname => this.DataLLMProvider switch private bool NeedHostname => this.DataLLMProvider switch
{ {
LLMProviders.SELF_HOSTED => true, LLMProviders.SELF_HOSTED => true,

View File

@ -1,3 +1,4 @@
using System.Net.Http.Headers;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Text; using System.Text;
using System.Text.Json; using System.Text.Json;
@ -23,6 +24,9 @@ public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provide
/// <inheritdoc /> /// <inheritdoc />
public async IAsyncEnumerable<string> StreamChatCompletion(Provider.Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) public async IAsyncEnumerable<string> StreamChatCompletion(Provider.Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default)
{ {
// Get the API key:
var requestedSecret = await RUST_SERVICE.GetAPIKey(this, isTrying: true);
// Prepare the system prompt: // Prepare the system prompt:
var systemPrompt = new Message var systemPrompt = new Message
{ {
@ -62,68 +66,83 @@ public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provide
MaxTokens = -1, MaxTokens = -1,
}, JSON_SERIALIZER_OPTIONS); }, JSON_SERIALIZER_OPTIONS);
// Build the HTTP post request: StreamReader? streamReader = default;
var request = new HttpRequestMessage(HttpMethod.Post, provider.Host.ChatURL()); try
// Set the content:
request.Content = new StringContent(providerChatRequest, Encoding.UTF8, "application/json");
// Send the request with the ResponseHeadersRead option.
// This allows us to read the stream as soon as the headers are received.
// This is important because we want to stream the responses.
var response = await this.httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, token);
// Open the response stream:
var providerStream = await response.Content.ReadAsStreamAsync(token);
// Add a stream reader to read the stream, line by line:
var streamReader = new StreamReader(providerStream);
// Read the stream, line by line:
while(!streamReader.EndOfStream)
{ {
// Check if the token is canceled: // Build the HTTP post request:
if(token.IsCancellationRequested) var request = new HttpRequestMessage(HttpMethod.Post, provider.Host.ChatURL());
yield break;
// Read the next line: // Set the authorization header:
var line = await streamReader.ReadLineAsync(token); if (requestedSecret.Success)
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", await requestedSecret.Secret.Decrypt(ENCRYPTION));
// Skip empty lines: // Set the content:
if(string.IsNullOrWhiteSpace(line)) request.Content = new StringContent(providerChatRequest, Encoding.UTF8, "application/json");
continue;
// Skip lines that do not start with "data: ". Regard // Send the request with the ResponseHeadersRead option.
// to the specification, we only want to read the data lines: // This allows us to read the stream as soon as the headers are received.
if(!line.StartsWith("data: ", StringComparison.InvariantCulture)) // This is important because we want to stream the responses.
continue; var response = await this.httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, token);
// Check if the line is the end of the stream: // Open the response stream:
if (line.StartsWith("data: [DONE]", StringComparison.InvariantCulture)) var providerStream = await response.Content.ReadAsStreamAsync(token);
yield break;
ResponseStreamLine providerResponse; // Add a stream reader to read the stream, line by line:
try streamReader = new StreamReader(providerStream);
}
catch(Exception e)
{
this.logger.LogError($"Failed to stream chat completion from self-hosted provider '{this.InstanceName}': {e.Message}");
}
if (streamReader is not null)
{
// Read the stream, line by line:
while (!streamReader.EndOfStream)
{ {
// We know that the line starts with "data: ". Hence, we can // Check if the token is canceled:
// skip the first 6 characters to get the JSON data after that. if (token.IsCancellationRequested)
var jsonData = line[6..]; yield break;
// Deserialize the JSON data: // Read the next line:
providerResponse = JsonSerializer.Deserialize<ResponseStreamLine>(jsonData, JSON_SERIALIZER_OPTIONS); var line = await streamReader.ReadLineAsync(token);
// Skip empty lines:
if (string.IsNullOrWhiteSpace(line))
continue;
// Skip lines that do not start with "data: ". Regard
// to the specification, we only want to read the data lines:
if (!line.StartsWith("data: ", StringComparison.InvariantCulture))
continue;
// Check if the line is the end of the stream:
if (line.StartsWith("data: [DONE]", StringComparison.InvariantCulture))
yield break;
ResponseStreamLine providerResponse;
try
{
// We know that the line starts with "data: ". Hence, we can
// skip the first 6 characters to get the JSON data after that.
var jsonData = line[6..];
// Deserialize the JSON data:
providerResponse = JsonSerializer.Deserialize<ResponseStreamLine>(jsonData, JSON_SERIALIZER_OPTIONS);
}
catch
{
// Skip invalid JSON data:
continue;
}
// Skip empty responses:
if (providerResponse == default || providerResponse.Choices.Count == 0)
continue;
// Yield the response:
yield return providerResponse.Choices[0].Delta.Content;
} }
catch
{
// Skip invalid JSON data:
continue;
}
// Skip empty responses:
if(providerResponse == default || providerResponse.Choices.Count == 0)
continue;
// Yield the response:
yield return providerResponse.Choices[0].Delta.Content;
} }
} }
@ -149,7 +168,21 @@ public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provide
case Host.LM_STUDIO: case Host.LM_STUDIO:
case Host.OLLAMA: case Host.OLLAMA:
var secretKey = apiKeyProvisional switch
{
not null => apiKeyProvisional,
_ => await RUST_SERVICE.GetAPIKey(this, isTrying: true) switch
{
{ Success: true } result => await result.Secret.Decrypt(ENCRYPTION),
_ => null,
}
};
var lmStudioRequest = new HttpRequestMessage(HttpMethod.Get, "models"); var lmStudioRequest = new HttpRequestMessage(HttpMethod.Get, "models");
if(secretKey is not null)
lmStudioRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", apiKeyProvisional);
var lmStudioResponse = await this.httpClient.SendAsync(lmStudioRequest, token); var lmStudioResponse = await this.httpClient.SendAsync(lmStudioRequest, token);
if(!lmStudioResponse.IsSuccessStatusCode) if(!lmStudioResponse.IsSuccessStatusCode)
return []; return [];

View File

@ -1,3 +1,3 @@
namespace AIStudio.Tools.Rust; namespace AIStudio.Tools.Rust;
public readonly record struct SelectSecretRequest(string Destination, string UserName); public readonly record struct SelectSecretRequest(string Destination, string UserName, bool IsTrying);

View File

@ -258,19 +258,21 @@ public sealed class RustService : IDisposable
/// Try to get the API key for the given provider. /// Try to get the API key for the given provider.
/// </summary> /// </summary>
/// <param name="provider">The provider to get the API key for.</param> /// <param name="provider">The provider to get the API key for.</param>
/// <param name="isTrying">Indicates if we are trying to get the API key. In that case, we don't log errors.</param>
/// <returns>The requested secret.</returns> /// <returns>The requested secret.</returns>
public async Task<RequestedSecret> GetAPIKey(IProvider provider) public async Task<RequestedSecret> GetAPIKey(IProvider provider, bool isTrying = false)
{ {
var secretRequest = new SelectSecretRequest($"provider::{provider.Id}::{provider.InstanceName}::api_key", Environment.UserName); var secretRequest = new SelectSecretRequest($"provider::{provider.Id}::{provider.InstanceName}::api_key", Environment.UserName, isTrying);
var result = await this.http.PostAsJsonAsync("/secrets/get", secretRequest, this.jsonRustSerializerOptions); var result = await this.http.PostAsJsonAsync("/secrets/get", secretRequest, this.jsonRustSerializerOptions);
if (!result.IsSuccessStatusCode) if (!result.IsSuccessStatusCode)
{ {
this.logger!.LogError($"Failed to get the API key for provider '{provider.Id}' due to an API issue: '{result.StatusCode}'"); if(!isTrying)
this.logger!.LogError($"Failed to get the API key for provider '{provider.Id}' due to an API issue: '{result.StatusCode}'");
return new RequestedSecret(false, new EncryptedText(string.Empty), "Failed to get the API key due to an API issue."); return new RequestedSecret(false, new EncryptedText(string.Empty), "Failed to get the API key due to an API issue.");
} }
var secret = await result.Content.ReadFromJsonAsync<RequestedSecret>(this.jsonRustSerializerOptions); var secret = await result.Content.ReadFromJsonAsync<RequestedSecret>(this.jsonRustSerializerOptions);
if (!secret.Success) if (!secret.Success && !isTrying)
this.logger!.LogError($"Failed to get the API key for provider '{provider.Id}': '{secret.Issue}'"); this.logger!.LogError($"Failed to get the API key for provider '{provider.Id}': '{secret.Issue}'");
return secret; return secret;
@ -307,7 +309,7 @@ public sealed class RustService : IDisposable
/// <returns>The delete secret response.</returns> /// <returns>The delete secret response.</returns>
public async Task<DeleteSecretResponse> DeleteAPIKey(IProvider provider) public async Task<DeleteSecretResponse> DeleteAPIKey(IProvider provider)
{ {
var request = new SelectSecretRequest($"provider::{provider.Id}::{provider.InstanceName}::api_key", Environment.UserName); var request = new SelectSecretRequest($"provider::{provider.Id}::{provider.InstanceName}::api_key", Environment.UserName, false);
var result = await this.http.PostAsJsonAsync("/secrets/delete", request, this.jsonRustSerializerOptions); var result = await this.http.PostAsJsonAsync("/secrets/delete", request, this.jsonRustSerializerOptions);
if (!result.IsSuccessStatusCode) if (!result.IsSuccessStatusCode)
{ {

View File

@ -0,0 +1,2 @@
# v0.9.13, build 188 (2024-10-07 11:18 UTC)
- Allow the use of an API key for self-hosted `ollama` instances. Useful when using `ollama` with, e.g., Open WebUI.

View File

@ -1,9 +1,9 @@
0.9.12 0.9.13
2024-09-15 20:49:12 UTC 2024-10-07 11:18:05 UTC
187 188
8.0.108 (commit 665a05cea7) 8.0.108 (commit 665a05cea7)
8.0.8 (commit 08338fcaa5) 8.0.8 (commit 08338fcaa5)
1.81.0 (commit eeb90cda1) 1.81.0 (commit eeb90cda1)
7.8.0 7.8.0
1.7.1 1.7.1
8715054dda6, release 580ca9850b1, release

2
runtime/Cargo.lock generated
View File

@ -2130,7 +2130,7 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]] [[package]]
name = "mindwork-ai-studio" name = "mindwork-ai-studio"
version = "0.9.12" version = "0.9.13"
dependencies = [ dependencies = [
"aes", "aes",
"arboard", "arboard",

View File

@ -1,6 +1,6 @@
[package] [package]
name = "mindwork-ai-studio" name = "mindwork-ai-studio"
version = "0.9.12" version = "0.9.13"
edition = "2021" edition = "2021"
description = "MindWork AI Studio" description = "MindWork AI Studio"
authors = ["Thorsten Sommer"] authors = ["Thorsten Sommer"]

View File

@ -966,7 +966,10 @@ fn get_secret(_token: APIToken, request: Json<RequestSecret>) -> Json<RequestedS
}, },
Err(e) => { Err(e) => {
error!(Source = "Secret Store"; "Failed to retrieve secret for '{service}' and user '{user_name}': {e}."); if !request.is_trying {
error!(Source = "Secret Store"; "Failed to retrieve secret for '{service}' and user '{user_name}': {e}.");
}
Json(RequestedSecret { Json(RequestedSecret {
success: false, success: false,
secret: EncryptedText::new(String::from("")), secret: EncryptedText::new(String::from("")),
@ -980,6 +983,7 @@ fn get_secret(_token: APIToken, request: Json<RequestSecret>) -> Json<RequestedS
struct RequestSecret { struct RequestSecret {
destination: String, destination: String,
user_name: String, user_name: String,
is_trying: bool,
} }
#[derive(Serialize)] #[derive(Serialize)]

View File

@ -6,7 +6,7 @@
}, },
"package": { "package": {
"productName": "MindWork AI Studio", "productName": "MindWork AI Studio",
"version": "0.9.12" "version": "0.9.13"
}, },
"tauri": { "tauri": {
"allowlist": { "allowlist": {