From dc1f9efd1fd27c077620c4018d39ae6a0db15e82 Mon Sep 17 00:00:00 2001 From: Thorsten Sommer Date: Mon, 2 Dec 2024 19:51:10 +0100 Subject: [PATCH] Refactored secret handling --- .../Provider/Anthropic/ProviderAnthropic.cs | 16 ++++---- .../Provider/BaseProvider.cs | 37 ++++++++++++++++++- .../Provider/Fireworks/ProviderFireworks.cs | 16 ++++---- .../Provider/Google/ProviderGoogle.cs | 16 ++++---- .../Provider/Groq/ProviderGroq.cs | 16 ++++---- .../Provider/LLMProvidersExtensions.cs | 32 ++++++++++++---- .../Provider/Mistral/ProviderMistral.cs | 16 ++++---- .../Provider/OpenAI/ProviderOpenAI.cs | 16 ++++---- .../Provider/SelfHosted/ProviderSelfHosted.cs | 20 +++++----- app/MindWork AI Studio/Settings/Provider.cs | 16 +++++++- app/MindWork AI Studio/Tools/ISecretId.cs | 17 +++++++++ app/MindWork AI Studio/Tools/RustService.cs | 37 +++++++++---------- 12 files changed, 168 insertions(+), 87 deletions(-) create mode 100644 app/MindWork AI Studio/Tools/ISecretId.cs diff --git a/app/MindWork AI Studio/Provider/Anthropic/ProviderAnthropic.cs b/app/MindWork AI Studio/Provider/Anthropic/ProviderAnthropic.cs index f0903ded..dc9767dc 100644 --- a/app/MindWork AI Studio/Provider/Anthropic/ProviderAnthropic.cs +++ b/app/MindWork AI Studio/Provider/Anthropic/ProviderAnthropic.cs @@ -7,7 +7,7 @@ using AIStudio.Provider.OpenAI; namespace AIStudio.Provider.Anthropic; -public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://api.anthropic.com/v1/", logger), IProvider +public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://api.anthropic.com/v1/", logger) { private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new() { @@ -16,12 +16,12 @@ public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://ap #region Implementation of IProvider - public string Id => "Anthropic"; + public override string Id => LLMProviders.ANTHROPIC.ToName(); - public string InstanceName { get; set; } = "Anthropic"; + public override string InstanceName { get; set; } = "Anthropic"; /// - public async IAsyncEnumerable StreamChatCompletion(Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) + public override async IAsyncEnumerable StreamChatCompletion(Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) { // Get the API key: var requestedSecret = await RUST_SERVICE.GetAPIKey(this); @@ -136,14 +136,14 @@ public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://ap #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously /// - public async IAsyncEnumerable StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) + public override async IAsyncEnumerable StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) { yield break; } #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously /// - public Task> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) + public override Task> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) { return Task.FromResult(new[] { @@ -163,13 +163,13 @@ public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://ap } /// - public Task> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) + public override Task> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) { return Task.FromResult(Enumerable.Empty()); } /// - public Task> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) + public override Task> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) { return Task.FromResult(Enumerable.Empty()); } diff --git a/app/MindWork AI Studio/Provider/BaseProvider.cs b/app/MindWork AI Studio/Provider/BaseProvider.cs index 9ae1e2a4..e52d0c20 100644 --- a/app/MindWork AI Studio/Provider/BaseProvider.cs +++ b/app/MindWork AI Studio/Provider/BaseProvider.cs @@ -1,3 +1,5 @@ +using AIStudio.Chat; + using RustService = AIStudio.Tools.RustService; namespace AIStudio.Provider; @@ -5,7 +7,7 @@ namespace AIStudio.Provider; /// /// The base class for all providers. /// -public abstract class BaseProvider +public abstract class BaseProvider : IProvider, ISecretId { /// /// The HTTP client to use it for all requests. @@ -39,4 +41,37 @@ public abstract class BaseProvider // Set the base URL: this.httpClient.BaseAddress = new(url); } + + #region Handling of IProvider, which all providers must implement + + /// + public abstract string Id { get; } + + /// + public abstract string InstanceName { get; set; } + + /// + public abstract IAsyncEnumerable StreamChatCompletion(Model chatModel, ChatThread chatThread, CancellationToken token = default); + + /// + public abstract IAsyncEnumerable StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, CancellationToken token = default); + + /// + public abstract Task> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default); + + /// + public abstract Task> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default); + + /// + public abstract Task> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default); + + #endregion + + #region Implementation of ISecretId + + public string SecretId => this.Id; + + public string SecretName => this.InstanceName; + + #endregion } \ No newline at end of file diff --git a/app/MindWork AI Studio/Provider/Fireworks/ProviderFireworks.cs b/app/MindWork AI Studio/Provider/Fireworks/ProviderFireworks.cs index c60e3fe6..709aad15 100644 --- a/app/MindWork AI Studio/Provider/Fireworks/ProviderFireworks.cs +++ b/app/MindWork AI Studio/Provider/Fireworks/ProviderFireworks.cs @@ -7,7 +7,7 @@ using AIStudio.Chat; namespace AIStudio.Provider.Fireworks; -public class ProviderFireworks(ILogger logger) : BaseProvider("https://api.fireworks.ai/inference/v1/", logger), IProvider +public class ProviderFireworks(ILogger logger) : BaseProvider("https://api.fireworks.ai/inference/v1/", logger) { private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new() { @@ -17,13 +17,13 @@ public class ProviderFireworks(ILogger logger) : BaseProvider("https://api.firew #region Implementation of IProvider /// - public string Id => "Fireworks.ai"; + public override string Id => LLMProviders.FIREWORKS.ToName(); /// - public string InstanceName { get; set; } = "Fireworks.ai"; + public override string InstanceName { get; set; } = "Fireworks.ai"; /// - public async IAsyncEnumerable StreamChatCompletion(Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) + public override async IAsyncEnumerable StreamChatCompletion(Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) { // Get the API key: var requestedSecret = await RUST_SERVICE.GetAPIKey(this); @@ -138,26 +138,26 @@ public class ProviderFireworks(ILogger logger) : BaseProvider("https://api.firew #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously /// - public async IAsyncEnumerable StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) + public override async IAsyncEnumerable StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) { yield break; } #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously /// - public Task> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) + public override Task> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) { return Task.FromResult(Enumerable.Empty()); } /// - public Task> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) + public override Task> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) { return Task.FromResult(Enumerable.Empty()); } /// - public Task> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) + public override Task> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) { return Task.FromResult(Enumerable.Empty()); } diff --git a/app/MindWork AI Studio/Provider/Google/ProviderGoogle.cs b/app/MindWork AI Studio/Provider/Google/ProviderGoogle.cs index 226d4cb1..6ca6d923 100644 --- a/app/MindWork AI Studio/Provider/Google/ProviderGoogle.cs +++ b/app/MindWork AI Studio/Provider/Google/ProviderGoogle.cs @@ -8,7 +8,7 @@ using AIStudio.Provider.OpenAI; namespace AIStudio.Provider.Google; -public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativelanguage.googleapis.com/v1beta/", logger), IProvider +public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativelanguage.googleapis.com/v1beta/", logger) { private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new() { @@ -18,13 +18,13 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela #region Implementation of IProvider /// - public string Id => "Google"; + public override string Id => LLMProviders.GOOGLE.ToName(); /// - public string InstanceName { get; set; } = "Google Gemini"; + public override string InstanceName { get; set; } = "Google Gemini"; /// - public async IAsyncEnumerable StreamChatCompletion(Provider.Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) + public override async IAsyncEnumerable StreamChatCompletion(Provider.Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) { // Get the API key: var requestedSecret = await RUST_SERVICE.GetAPIKey(this); @@ -139,14 +139,14 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously /// - public async IAsyncEnumerable StreamImageCompletion(Provider.Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) + public override async IAsyncEnumerable StreamImageCompletion(Provider.Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) { yield break; } #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously /// - public async Task> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) + public override async Task> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) { var modelResponse = await this.LoadModels(token, apiKeyProvisional); if(modelResponse == default) @@ -158,12 +158,12 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela } /// - public Task> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) + public override Task> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) { return Task.FromResult(Enumerable.Empty()); } - public async Task> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) + public override async Task> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) { var modelResponse = await this.LoadModels(token, apiKeyProvisional); if(modelResponse == default) diff --git a/app/MindWork AI Studio/Provider/Groq/ProviderGroq.cs b/app/MindWork AI Studio/Provider/Groq/ProviderGroq.cs index 703a449f..477f9a0f 100644 --- a/app/MindWork AI Studio/Provider/Groq/ProviderGroq.cs +++ b/app/MindWork AI Studio/Provider/Groq/ProviderGroq.cs @@ -8,7 +8,7 @@ using AIStudio.Provider.OpenAI; namespace AIStudio.Provider.Groq; -public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/openai/v1/", logger), IProvider +public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/openai/v1/", logger) { private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new() { @@ -18,13 +18,13 @@ public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/o #region Implementation of IProvider /// - public string Id => "Groq"; + public override string Id => LLMProviders.GROQ.ToName(); /// - public string InstanceName { get; set; } = "Groq"; + public override string InstanceName { get; set; } = "Groq"; /// - public async IAsyncEnumerable StreamChatCompletion(Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) + public override async IAsyncEnumerable StreamChatCompletion(Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) { // Get the API key: var requestedSecret = await RUST_SERVICE.GetAPIKey(this); @@ -141,26 +141,26 @@ public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/o #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously /// - public async IAsyncEnumerable StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) + public override async IAsyncEnumerable StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) { yield break; } #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously /// - public Task> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) + public override Task> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) { return this.LoadModels(token, apiKeyProvisional); } /// - public Task> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) + public override Task> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) { return Task.FromResult>(Array.Empty()); } /// - public Task> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) + public override Task> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) { return Task.FromResult(Enumerable.Empty()); } diff --git a/app/MindWork AI Studio/Provider/LLMProvidersExtensions.cs b/app/MindWork AI Studio/Provider/LLMProvidersExtensions.cs index 725c16ad..cfb98438 100644 --- a/app/MindWork AI Studio/Provider/LLMProvidersExtensions.cs +++ b/app/MindWork AI Studio/Provider/LLMProvidersExtensions.cs @@ -103,20 +103,36 @@ public static class LLMProvidersExtensions /// The logger to use. /// The provider instance. public static IProvider CreateProvider(this Settings.Provider providerSettings, ILogger logger) + { + return providerSettings.UsedLLMProvider.CreateProvider(providerSettings.InstanceName, providerSettings.Host, providerSettings.Hostname, logger); + } + + /// + /// Creates a new provider instance based on the embedding provider value. + /// + /// The embedding provider settings. + /// The logger to use. + /// The provider instance. + public static IProvider CreateProvider(this EmbeddingProvider embeddingProviderSettings, ILogger logger) + { + return embeddingProviderSettings.UsedLLMProvider.CreateProvider(embeddingProviderSettings.Name, embeddingProviderSettings.Host, embeddingProviderSettings.Hostname, logger); + } + + private static IProvider CreateProvider(this LLMProviders provider, string instanceName, Host host, string hostname, ILogger logger) { try { - return providerSettings.UsedLLMProvider switch + return provider switch { - LLMProviders.OPEN_AI => new ProviderOpenAI(logger) { InstanceName = providerSettings.InstanceName }, - LLMProviders.ANTHROPIC => new ProviderAnthropic(logger) { InstanceName = providerSettings.InstanceName }, - LLMProviders.MISTRAL => new ProviderMistral(logger) { InstanceName = providerSettings.InstanceName }, - LLMProviders.GOOGLE => new ProviderGoogle(logger) { InstanceName = providerSettings.InstanceName }, + LLMProviders.OPEN_AI => new ProviderOpenAI(logger) { InstanceName = instanceName }, + LLMProviders.ANTHROPIC => new ProviderAnthropic(logger) { InstanceName = instanceName }, + LLMProviders.MISTRAL => new ProviderMistral(logger) { InstanceName = instanceName }, + LLMProviders.GOOGLE => new ProviderGoogle(logger) { InstanceName = instanceName }, - LLMProviders.GROQ => new ProviderGroq(logger) { InstanceName = providerSettings.InstanceName }, - LLMProviders.FIREWORKS => new ProviderFireworks(logger) { InstanceName = providerSettings.InstanceName }, + LLMProviders.GROQ => new ProviderGroq(logger) { InstanceName = instanceName }, + LLMProviders.FIREWORKS => new ProviderFireworks(logger) { InstanceName = instanceName }, - LLMProviders.SELF_HOSTED => new ProviderSelfHosted(logger, providerSettings) { InstanceName = providerSettings.InstanceName }, + LLMProviders.SELF_HOSTED => new ProviderSelfHosted(logger, host, hostname) { InstanceName = instanceName }, _ => new NoProvider(), }; diff --git a/app/MindWork AI Studio/Provider/Mistral/ProviderMistral.cs b/app/MindWork AI Studio/Provider/Mistral/ProviderMistral.cs index 6d2a10b9..633fa94b 100644 --- a/app/MindWork AI Studio/Provider/Mistral/ProviderMistral.cs +++ b/app/MindWork AI Studio/Provider/Mistral/ProviderMistral.cs @@ -8,7 +8,7 @@ using AIStudio.Provider.OpenAI; namespace AIStudio.Provider.Mistral; -public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.mistral.ai/v1/", logger), IProvider +public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.mistral.ai/v1/", logger) { private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new() { @@ -17,12 +17,12 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api. #region Implementation of IProvider - public string Id => "Mistral"; + public override string Id => LLMProviders.MISTRAL.ToName(); - public string InstanceName { get; set; } = "Mistral"; + public override string InstanceName { get; set; } = "Mistral"; /// - public async IAsyncEnumerable StreamChatCompletion(Provider.Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) + public override async IAsyncEnumerable StreamChatCompletion(Provider.Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) { // Get the API key: var requestedSecret = await RUST_SERVICE.GetAPIKey(this); @@ -140,14 +140,14 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api. #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously /// - public async IAsyncEnumerable StreamImageCompletion(Provider.Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) + public override async IAsyncEnumerable StreamImageCompletion(Provider.Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) { yield break; } #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously /// - public async Task> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) + public override async Task> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) { var modelResponse = await this.LoadModelList(apiKeyProvisional, token); if(modelResponse == default) @@ -160,7 +160,7 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api. } /// - public async Task> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) + public override async Task> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) { var modelResponse = await this.LoadModelList(apiKeyProvisional, token); if(modelResponse == default) @@ -171,7 +171,7 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api. } /// - public Task> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) + public override Task> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) { return Task.FromResult(Enumerable.Empty()); } diff --git a/app/MindWork AI Studio/Provider/OpenAI/ProviderOpenAI.cs b/app/MindWork AI Studio/Provider/OpenAI/ProviderOpenAI.cs index bae64478..2f1c25ad 100644 --- a/app/MindWork AI Studio/Provider/OpenAI/ProviderOpenAI.cs +++ b/app/MindWork AI Studio/Provider/OpenAI/ProviderOpenAI.cs @@ -10,7 +10,7 @@ namespace AIStudio.Provider.OpenAI; /// /// The OpenAI provider. /// -public sealed class ProviderOpenAI(ILogger logger) : BaseProvider("https://api.openai.com/v1/", logger), IProvider +public sealed class ProviderOpenAI(ILogger logger) : BaseProvider("https://api.openai.com/v1/", logger) { private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new() { @@ -20,13 +20,13 @@ public sealed class ProviderOpenAI(ILogger logger) : BaseProvider("https://api.o #region Implementation of IProvider /// - public string Id => "OpenAI"; + public override string Id => LLMProviders.OPEN_AI.ToName(); /// - public string InstanceName { get; set; } = "OpenAI"; + public override string InstanceName { get; set; } = "OpenAI"; /// - public async IAsyncEnumerable StreamChatCompletion(Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) + public override async IAsyncEnumerable StreamChatCompletion(Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) { // Get the API key: var requestedSecret = await RUST_SERVICE.GetAPIKey(this); @@ -144,26 +144,26 @@ public sealed class ProviderOpenAI(ILogger logger) : BaseProvider("https://api.o #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously /// - public async IAsyncEnumerable StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) + public override async IAsyncEnumerable StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) { yield break; } #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously /// - public Task> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) + public override Task> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) { return this.LoadModels(["gpt-", "o1-"], token, apiKeyProvisional); } /// - public Task> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) + public override Task> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) { return this.LoadModels(["dall-e-"], token, apiKeyProvisional); } /// - public Task> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) + public override Task> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) { return this.LoadModels(["text-embedding-"], token, apiKeyProvisional); } diff --git a/app/MindWork AI Studio/Provider/SelfHosted/ProviderSelfHosted.cs b/app/MindWork AI Studio/Provider/SelfHosted/ProviderSelfHosted.cs index e1bbfbae..8845b73d 100644 --- a/app/MindWork AI Studio/Provider/SelfHosted/ProviderSelfHosted.cs +++ b/app/MindWork AI Studio/Provider/SelfHosted/ProviderSelfHosted.cs @@ -8,7 +8,7 @@ using AIStudio.Provider.OpenAI; namespace AIStudio.Provider.SelfHosted; -public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provider) : BaseProvider($"{provider.Hostname}{provider.Host.BaseURL()}", logger), IProvider +public sealed class ProviderSelfHosted(ILogger logger, Host host, string hostname) : BaseProvider($"{hostname}{host.BaseURL()}", logger) { private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new() { @@ -17,12 +17,12 @@ public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provide #region Implementation of IProvider - public string Id => "Self-hosted"; + public override string Id => LLMProviders.SELF_HOSTED.ToName(); - public string InstanceName { get; set; } = "Self-hosted"; + public override string InstanceName { get; set; } = "Self-hosted"; /// - public async IAsyncEnumerable StreamChatCompletion(Provider.Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) + public override async IAsyncEnumerable StreamChatCompletion(Provider.Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) { // Get the API key: var requestedSecret = await RUST_SERVICE.GetAPIKey(this, isTrying: true); @@ -70,7 +70,7 @@ public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provide try { // Build the HTTP post request: - var request = new HttpRequestMessage(HttpMethod.Post, provider.Host.ChatURL()); + var request = new HttpRequestMessage(HttpMethod.Post, host.ChatURL()); // Set the authorization header: if (requestedSecret.Success) @@ -148,18 +148,18 @@ public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provide #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously /// - public async IAsyncEnumerable StreamImageCompletion(Provider.Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) + public override async IAsyncEnumerable StreamImageCompletion(Provider.Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) { yield break; } #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously - public async Task> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) + public override async Task> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) { try { - switch (provider.Host) + switch (host) { case Host.LLAMACPP: // Right now, llama.cpp only supports one model. @@ -201,12 +201,12 @@ public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provide } /// - public Task> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) + public override Task> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) { return Task.FromResult(Enumerable.Empty()); } - public Task> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) + public override Task> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) { return Task.FromResult(Enumerable.Empty()); } diff --git a/app/MindWork AI Studio/Settings/Provider.cs b/app/MindWork AI Studio/Settings/Provider.cs index 6279cf75..b349016d 100644 --- a/app/MindWork AI Studio/Settings/Provider.cs +++ b/app/MindWork AI Studio/Settings/Provider.cs @@ -1,3 +1,5 @@ +using System.Text.Json.Serialization; + using AIStudio.Provider; using Host = AIStudio.Provider.SelfHosted.Host; @@ -22,7 +24,7 @@ public readonly record struct Provider( Model Model, bool IsSelfHosted = false, string Hostname = "http://localhost:1234", - Host Host = Host.NONE) + Host Host = Host.NONE) : ISecretId { #region Overrides of ValueType @@ -40,4 +42,16 @@ public readonly record struct Provider( } #endregion + + #region Implementation of ISecretId + + /// + [JsonIgnore] + public string SecretId => this.Id; + + /// + [JsonIgnore] + public string SecretName => this.InstanceName; + + #endregion } \ No newline at end of file diff --git a/app/MindWork AI Studio/Tools/ISecretId.cs b/app/MindWork AI Studio/Tools/ISecretId.cs new file mode 100644 index 00000000..c1198913 --- /dev/null +++ b/app/MindWork AI Studio/Tools/ISecretId.cs @@ -0,0 +1,17 @@ +namespace AIStudio.Tools; + +/// +/// Represents an interface defining a secret identifier. +/// +public interface ISecretId +{ + /// + /// The unique ID of the secret. + /// + public string SecretId { get; } + + /// + /// The instance name of the secret. + /// + public string SecretName { get; } +} \ No newline at end of file diff --git a/app/MindWork AI Studio/Tools/RustService.cs b/app/MindWork AI Studio/Tools/RustService.cs index bf40e0d6..a38cf5fb 100644 --- a/app/MindWork AI Studio/Tools/RustService.cs +++ b/app/MindWork AI Studio/Tools/RustService.cs @@ -1,7 +1,6 @@ using System.Security.Cryptography; using System.Text.Json; -using AIStudio.Provider; using AIStudio.Tools.Rust; // ReSharper disable NotAccessedPositionalProperty.Local @@ -255,71 +254,71 @@ public sealed class RustService : IDisposable } /// - /// Try to get the API key for the given provider. + /// Try to get the API key for the given secret ID. /// - /// The provider to get the API key for. + /// The secret ID to get the API key for. /// Indicates if we are trying to get the API key. In that case, we don't log errors. /// The requested secret. - public async Task GetAPIKey(IProvider provider, bool isTrying = false) + public async Task GetAPIKey(ISecretId secretId, bool isTrying = false) { - var secretRequest = new SelectSecretRequest($"provider::{provider.Id}::{provider.InstanceName}::api_key", Environment.UserName, isTrying); + var secretRequest = new SelectSecretRequest($"provider::{secretId.SecretId}::{secretId.SecretName}::api_key", Environment.UserName, isTrying); var result = await this.http.PostAsJsonAsync("/secrets/get", secretRequest, this.jsonRustSerializerOptions); if (!result.IsSuccessStatusCode) { if(!isTrying) - this.logger!.LogError($"Failed to get the API key for provider '{provider.Id}' due to an API issue: '{result.StatusCode}'"); + this.logger!.LogError($"Failed to get the API key for secret ID '{secretId.SecretId}' due to an API issue: '{result.StatusCode}'"); return new RequestedSecret(false, new EncryptedText(string.Empty), "Failed to get the API key due to an API issue."); } var secret = await result.Content.ReadFromJsonAsync(this.jsonRustSerializerOptions); if (!secret.Success && !isTrying) - this.logger!.LogError($"Failed to get the API key for provider '{provider.Id}': '{secret.Issue}'"); + this.logger!.LogError($"Failed to get the API key for secret ID '{secretId.SecretId}': '{secret.Issue}'"); return secret; } /// - /// Try to store the API key for the given provider. + /// Try to store the API key for the given secret ID. /// - /// The provider to store the API key for. + /// The secret ID to store the API key for. /// The API key to store. /// The store secret response. - public async Task SetAPIKey(IProvider provider, string key) + public async Task SetAPIKey(ISecretId secretId, string key) { var encryptedKey = await this.encryptor!.Encrypt(key); - var request = new StoreSecretRequest($"provider::{provider.Id}::{provider.InstanceName}::api_key", Environment.UserName, encryptedKey); + var request = new StoreSecretRequest($"provider::{secretId.SecretId}::{secretId.SecretName}::api_key", Environment.UserName, encryptedKey); var result = await this.http.PostAsJsonAsync("/secrets/store", request, this.jsonRustSerializerOptions); if (!result.IsSuccessStatusCode) { - this.logger!.LogError($"Failed to store the API key for provider '{provider.Id}' due to an API issue: '{result.StatusCode}'"); + this.logger!.LogError($"Failed to store the API key for secret ID '{secretId.SecretId}' due to an API issue: '{result.StatusCode}'"); return new StoreSecretResponse(false, "Failed to get the API key due to an API issue."); } var state = await result.Content.ReadFromJsonAsync(this.jsonRustSerializerOptions); if (!state.Success) - this.logger!.LogError($"Failed to store the API key for provider '{provider.Id}': '{state.Issue}'"); + this.logger!.LogError($"Failed to store the API key for secret ID '{secretId.SecretId}': '{state.Issue}'"); return state; } /// - /// Tries to delete the API key for the given provider. + /// Tries to delete the API key for the given secret ID. /// - /// The provider to delete the API key for. + /// The secret ID to delete the API key for. /// The delete secret response. - public async Task DeleteAPIKey(IProvider provider) + public async Task DeleteAPIKey(ISecretId secretId) { - var request = new SelectSecretRequest($"provider::{provider.Id}::{provider.InstanceName}::api_key", Environment.UserName, false); + var request = new SelectSecretRequest($"provider::{secretId.SecretId}::{secretId.SecretName}::api_key", Environment.UserName, false); var result = await this.http.PostAsJsonAsync("/secrets/delete", request, this.jsonRustSerializerOptions); if (!result.IsSuccessStatusCode) { - this.logger!.LogError($"Failed to delete the API key for provider '{provider.Id}' due to an API issue: '{result.StatusCode}'"); + this.logger!.LogError($"Failed to delete the API key for secret ID '{secretId.SecretId}' due to an API issue: '{result.StatusCode}'"); return new DeleteSecretResponse{Success = false, WasEntryFound = false, Issue = "Failed to delete the API key due to an API issue."}; } var state = await result.Content.ReadFromJsonAsync(this.jsonRustSerializerOptions); if (!state.Success) - this.logger!.LogError($"Failed to delete the API key for provider '{provider.Id}': '{state.Issue}'"); + this.logger!.LogError($"Failed to delete the API key for secret ID '{secretId.SecretId}': '{state.Issue}'"); return state; }