Allow the use of an API key for self-hosted ollama instances (#156)

This commit is contained in:
Thorsten Sommer 2024-10-07 13:26:25 +02:00 committed by GitHub
parent 776fa8ac58
commit 37e113af0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 134 additions and 78 deletions

View File

@ -13,6 +13,7 @@ public partial class Changelog
public static readonly Log[] LOGS =
[
new (188, "v0.9.13, build 188 (2024-10-07 11:18 UTC)", "v0.9.13.md"),
new (187, "v0.9.12, build 187 (2024-09-15 20:49 UTC)", "v0.9.12.md"),
new (186, "v0.9.11, build 186 (2024-09-15 10:33 UTC)", "v0.9.11.md"),
new (185, "v0.9.10, build 185 (2024-09-12 20:52 UTC)", "v0.9.10.md"),

View File

@ -19,7 +19,7 @@
<MudTextField
T="string"
@bind-Text="@this.dataAPIKey"
Label="API Key"
Label="@this.APIKeyText"
Disabled="@(!this.NeedAPIKey)"
Class="mb-3"
Adornment="Adornment.Start"

View File

@ -133,7 +133,7 @@ public partial class ProviderDialog : ComponentBase
//
// We cannot load the API key for self-hosted providers:
//
if (this.DataLLMProvider is LLMProviders.SELF_HOSTED)
if (this.DataLLMProvider is LLMProviders.SELF_HOSTED && this.DataHost is not Host.OLLAMA)
{
await this.ReloadModels();
await base.OnInitializedAsync();
@ -149,7 +149,7 @@ public partial class ProviderDialog : ComponentBase
}
// Load the API key:
var requestedSecret = await this.RustService.GetAPIKey(provider);
var requestedSecret = await this.RustService.GetAPIKey(provider, isTrying: this.DataLLMProvider is LLMProviders.SELF_HOSTED);
if(requestedSecret.Success)
{
this.dataAPIKey = await requestedSecret.Secret.Decrypt(this.encryption);
@ -159,8 +159,15 @@ public partial class ProviderDialog : ComponentBase
}
else
{
this.dataAPIKeyStorageIssue = $"Failed to load the API key from the operating system. The message was: {requestedSecret.Issue}. You might ignore this message and provide the API key again.";
await this.form.Validate();
this.dataAPIKey = string.Empty;
if (this.DataLLMProvider is not LLMProviders.SELF_HOSTED)
{
this.dataAPIKeyStorageIssue = $"Failed to load the API key from the operating system. The message was: {requestedSecret.Issue}. You might ignore this message and provide the API key again.";
await this.form.Validate();
}
// We still try to load the models. Some local hosts don't need an API key:
await this.ReloadModels();
}
}
@ -192,7 +199,7 @@ public partial class ProviderDialog : ComponentBase
// Use the data model to store the provider.
// We just return this data to the parent component:
var addedProviderSettings = this.CreateProviderSettings();
if (addedProviderSettings.UsedLLMProvider != LLMProviders.SELF_HOSTED)
if (!string.IsNullOrWhiteSpace(this.dataAPIKey))
{
// We need to instantiate the provider to store the API key:
var provider = addedProviderSettings.CreateProvider(this.Logger);
@ -363,9 +370,16 @@ public partial class ProviderDialog : ComponentBase
LLMProviders.ANTHROPIC => true,
LLMProviders.FIREWORKS => true,
LLMProviders.SELF_HOSTED => this.DataHost is Host.OLLAMA,
_ => false,
};
private string APIKeyText => this.DataLLMProvider switch
{
LLMProviders.SELF_HOSTED => "(Optional) API Key",
_ => "API Key",
};
private bool NeedHostname => this.DataLLMProvider switch
{

View File

@ -1,3 +1,4 @@
using System.Net.Http.Headers;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
@ -23,6 +24,9 @@ public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provide
/// <inheritdoc />
public async IAsyncEnumerable<string> StreamChatCompletion(Provider.Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default)
{
// Get the API key:
var requestedSecret = await RUST_SERVICE.GetAPIKey(this, isTrying: true);
// Prepare the system prompt:
var systemPrompt = new Message
{
@ -62,68 +66,83 @@ public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provide
MaxTokens = -1,
}, JSON_SERIALIZER_OPTIONS);
// Build the HTTP post request:
var request = new HttpRequestMessage(HttpMethod.Post, provider.Host.ChatURL());
// Set the content:
request.Content = new StringContent(providerChatRequest, Encoding.UTF8, "application/json");
// Send the request with the ResponseHeadersRead option.
// This allows us to read the stream as soon as the headers are received.
// This is important because we want to stream the responses.
var response = await this.httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, token);
// Open the response stream:
var providerStream = await response.Content.ReadAsStreamAsync(token);
// Add a stream reader to read the stream, line by line:
var streamReader = new StreamReader(providerStream);
// Read the stream, line by line:
while(!streamReader.EndOfStream)
StreamReader? streamReader = default;
try
{
// Check if the token is canceled:
if(token.IsCancellationRequested)
yield break;
// Read the next line:
var line = await streamReader.ReadLineAsync(token);
// Skip empty lines:
if(string.IsNullOrWhiteSpace(line))
continue;
// Skip lines that do not start with "data: ". Regard
// to the specification, we only want to read the data lines:
if(!line.StartsWith("data: ", StringComparison.InvariantCulture))
continue;
// Build the HTTP post request:
var request = new HttpRequestMessage(HttpMethod.Post, provider.Host.ChatURL());
// Check if the line is the end of the stream:
if (line.StartsWith("data: [DONE]", StringComparison.InvariantCulture))
yield break;
// Set the authorization header:
if (requestedSecret.Success)
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", await requestedSecret.Secret.Decrypt(ENCRYPTION));
ResponseStreamLine providerResponse;
try
// Set the content:
request.Content = new StringContent(providerChatRequest, Encoding.UTF8, "application/json");
// Send the request with the ResponseHeadersRead option.
// This allows us to read the stream as soon as the headers are received.
// This is important because we want to stream the responses.
var response = await this.httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, token);
// Open the response stream:
var providerStream = await response.Content.ReadAsStreamAsync(token);
// Add a stream reader to read the stream, line by line:
streamReader = new StreamReader(providerStream);
}
catch(Exception e)
{
this.logger.LogError($"Failed to stream chat completion from self-hosted provider '{this.InstanceName}': {e.Message}");
}
if (streamReader is not null)
{
// Read the stream, line by line:
while (!streamReader.EndOfStream)
{
// We know that the line starts with "data: ". Hence, we can
// skip the first 6 characters to get the JSON data after that.
var jsonData = line[6..];
// Deserialize the JSON data:
providerResponse = JsonSerializer.Deserialize<ResponseStreamLine>(jsonData, JSON_SERIALIZER_OPTIONS);
// Check if the token is canceled:
if (token.IsCancellationRequested)
yield break;
// Read the next line:
var line = await streamReader.ReadLineAsync(token);
// Skip empty lines:
if (string.IsNullOrWhiteSpace(line))
continue;
// Skip lines that do not start with "data: ". Regard
// to the specification, we only want to read the data lines:
if (!line.StartsWith("data: ", StringComparison.InvariantCulture))
continue;
// Check if the line is the end of the stream:
if (line.StartsWith("data: [DONE]", StringComparison.InvariantCulture))
yield break;
ResponseStreamLine providerResponse;
try
{
// We know that the line starts with "data: ". Hence, we can
// skip the first 6 characters to get the JSON data after that.
var jsonData = line[6..];
// Deserialize the JSON data:
providerResponse = JsonSerializer.Deserialize<ResponseStreamLine>(jsonData, JSON_SERIALIZER_OPTIONS);
}
catch
{
// Skip invalid JSON data:
continue;
}
// Skip empty responses:
if (providerResponse == default || providerResponse.Choices.Count == 0)
continue;
// Yield the response:
yield return providerResponse.Choices[0].Delta.Content;
}
catch
{
// Skip invalid JSON data:
continue;
}
// Skip empty responses:
if(providerResponse == default || providerResponse.Choices.Count == 0)
continue;
// Yield the response:
yield return providerResponse.Choices[0].Delta.Content;
}
}
@ -149,7 +168,21 @@ public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provide
case Host.LM_STUDIO:
case Host.OLLAMA:
var secretKey = apiKeyProvisional switch
{
not null => apiKeyProvisional,
_ => await RUST_SERVICE.GetAPIKey(this, isTrying: true) switch
{
{ Success: true } result => await result.Secret.Decrypt(ENCRYPTION),
_ => null,
}
};
var lmStudioRequest = new HttpRequestMessage(HttpMethod.Get, "models");
if(secretKey is not null)
lmStudioRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", apiKeyProvisional);
var lmStudioResponse = await this.httpClient.SendAsync(lmStudioRequest, token);
if(!lmStudioResponse.IsSuccessStatusCode)
return [];

View File

@ -1,3 +1,3 @@
namespace AIStudio.Tools.Rust;
public readonly record struct SelectSecretRequest(string Destination, string UserName);
public readonly record struct SelectSecretRequest(string Destination, string UserName, bool IsTrying);

View File

@ -253,24 +253,26 @@ public sealed class RustService : IDisposable
throw;
}
}
/// <summary>
/// Try to get the API key for the given provider.
/// </summary>
/// <param name="provider">The provider to get the API key for.</param>
/// <param name="isTrying">Indicates if we are trying to get the API key. In that case, we don't log errors.</param>
/// <returns>The requested secret.</returns>
public async Task<RequestedSecret> GetAPIKey(IProvider provider)
public async Task<RequestedSecret> GetAPIKey(IProvider provider, bool isTrying = false)
{
var secretRequest = new SelectSecretRequest($"provider::{provider.Id}::{provider.InstanceName}::api_key", Environment.UserName);
var secretRequest = new SelectSecretRequest($"provider::{provider.Id}::{provider.InstanceName}::api_key", Environment.UserName, isTrying);
var result = await this.http.PostAsJsonAsync("/secrets/get", secretRequest, this.jsonRustSerializerOptions);
if (!result.IsSuccessStatusCode)
{
this.logger!.LogError($"Failed to get the API key for provider '{provider.Id}' due to an API issue: '{result.StatusCode}'");
if(!isTrying)
this.logger!.LogError($"Failed to get the API key for provider '{provider.Id}' due to an API issue: '{result.StatusCode}'");
return new RequestedSecret(false, new EncryptedText(string.Empty), "Failed to get the API key due to an API issue.");
}
var secret = await result.Content.ReadFromJsonAsync<RequestedSecret>(this.jsonRustSerializerOptions);
if (!secret.Success)
if (!secret.Success && !isTrying)
this.logger!.LogError($"Failed to get the API key for provider '{provider.Id}': '{secret.Issue}'");
return secret;
@ -307,7 +309,7 @@ public sealed class RustService : IDisposable
/// <returns>The delete secret response.</returns>
public async Task<DeleteSecretResponse> DeleteAPIKey(IProvider provider)
{
var request = new SelectSecretRequest($"provider::{provider.Id}::{provider.InstanceName}::api_key", Environment.UserName);
var request = new SelectSecretRequest($"provider::{provider.Id}::{provider.InstanceName}::api_key", Environment.UserName, false);
var result = await this.http.PostAsJsonAsync("/secrets/delete", request, this.jsonRustSerializerOptions);
if (!result.IsSuccessStatusCode)
{

View File

@ -0,0 +1,2 @@
# v0.9.13, build 188 (2024-10-07 11:18 UTC)
- Allow the use of an API key for self-hosted `ollama` instances. Useful when using `ollama` with, e.g., Open WebUI.

View File

@ -1,9 +1,9 @@
0.9.12
2024-09-15 20:49:12 UTC
187
0.9.13
2024-10-07 11:18:05 UTC
188
8.0.108 (commit 665a05cea7)
8.0.8 (commit 08338fcaa5)
1.81.0 (commit eeb90cda1)
7.8.0
1.7.1
8715054dda6, release
580ca9850b1, release

2
runtime/Cargo.lock generated
View File

@ -2130,7 +2130,7 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mindwork-ai-studio"
version = "0.9.12"
version = "0.9.13"
dependencies = [
"aes",
"arboard",

View File

@ -1,6 +1,6 @@
[package]
name = "mindwork-ai-studio"
version = "0.9.12"
version = "0.9.13"
edition = "2021"
description = "MindWork AI Studio"
authors = ["Thorsten Sommer"]

View File

@ -966,7 +966,10 @@ fn get_secret(_token: APIToken, request: Json<RequestSecret>) -> Json<RequestedS
},
Err(e) => {
error!(Source = "Secret Store"; "Failed to retrieve secret for '{service}' and user '{user_name}': {e}.");
if !request.is_trying {
error!(Source = "Secret Store"; "Failed to retrieve secret for '{service}' and user '{user_name}': {e}.");
}
Json(RequestedSecret {
success: false,
secret: EncryptedText::new(String::from("")),
@ -980,6 +983,7 @@ fn get_secret(_token: APIToken, request: Json<RequestSecret>) -> Json<RequestedS
struct RequestSecret {
destination: String,
user_name: String,
is_trying: bool,
}
#[derive(Serialize)]

View File

@ -6,7 +6,7 @@
},
"package": {
"productName": "MindWork AI Studio",
"version": "0.9.12"
"version": "0.9.13"
},
"tauri": {
"allowlist": {