Refactored secret handling

This commit is contained in:
Thorsten Sommer 2024-12-02 19:51:10 +01:00
parent 311ad20114
commit dc1f9efd1f
Signed by: tsommer
GPG Key ID: 371BBA77A02C0108
12 changed files with 168 additions and 87 deletions

View File

@ -7,7 +7,7 @@ using AIStudio.Provider.OpenAI;
namespace AIStudio.Provider.Anthropic; namespace AIStudio.Provider.Anthropic;
public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://api.anthropic.com/v1/", logger), IProvider public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://api.anthropic.com/v1/", logger)
{ {
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new() private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{ {
@ -16,12 +16,12 @@ public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://ap
#region Implementation of IProvider #region Implementation of IProvider
public string Id => "Anthropic"; public override string Id => LLMProviders.ANTHROPIC.ToName();
public string InstanceName { get; set; } = "Anthropic"; public override string InstanceName { get; set; } = "Anthropic";
/// <inheritdoc /> /// <inheritdoc />
public async IAsyncEnumerable<string> StreamChatCompletion(Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) public override async IAsyncEnumerable<string> StreamChatCompletion(Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default)
{ {
// Get the API key: // Get the API key:
var requestedSecret = await RUST_SERVICE.GetAPIKey(this); var requestedSecret = await RUST_SERVICE.GetAPIKey(this);
@ -136,14 +136,14 @@ public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://ap
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc /> /// <inheritdoc />
public async IAsyncEnumerable<ImageURL> StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) public override async IAsyncEnumerable<ImageURL> StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default)
{ {
yield break; yield break;
} }
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc /> /// <inheritdoc />
public Task<IEnumerable<Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) public override Task<IEnumerable<Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
return Task.FromResult(new[] return Task.FromResult(new[]
{ {
@ -163,13 +163,13 @@ public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://ap
} }
/// <inheritdoc /> /// <inheritdoc />
public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) public override Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
return Task.FromResult(Enumerable.Empty<Model>()); return Task.FromResult(Enumerable.Empty<Model>());
} }
/// <inheritdoc /> /// <inheritdoc />
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) public override Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
return Task.FromResult(Enumerable.Empty<Model>()); return Task.FromResult(Enumerable.Empty<Model>());
} }

View File

@ -1,3 +1,5 @@
using AIStudio.Chat;
using RustService = AIStudio.Tools.RustService; using RustService = AIStudio.Tools.RustService;
namespace AIStudio.Provider; namespace AIStudio.Provider;
@ -5,7 +7,7 @@ namespace AIStudio.Provider;
/// <summary> /// <summary>
/// The base class for all providers. /// The base class for all providers.
/// </summary> /// </summary>
public abstract class BaseProvider public abstract class BaseProvider : IProvider, ISecretId
{ {
/// <summary> /// <summary>
/// The HTTP client to use it for all requests. /// The HTTP client to use it for all requests.
@ -39,4 +41,37 @@ public abstract class BaseProvider
// Set the base URL: // Set the base URL:
this.httpClient.BaseAddress = new(url); this.httpClient.BaseAddress = new(url);
} }
#region Handling of IProvider, which all providers must implement
/// <inheritdoc />
public abstract string Id { get; }
/// <inheritdoc />
public abstract string InstanceName { get; set; }
/// <inheritdoc />
public abstract IAsyncEnumerable<string> StreamChatCompletion(Model chatModel, ChatThread chatThread, CancellationToken token = default);
/// <inheritdoc />
public abstract IAsyncEnumerable<ImageURL> StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, CancellationToken token = default);
/// <inheritdoc />
public abstract Task<IEnumerable<Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default);
/// <inheritdoc />
public abstract Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default);
/// <inheritdoc />
public abstract Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default);
#endregion
#region Implementation of ISecretId
public string SecretId => this.Id;
public string SecretName => this.InstanceName;
#endregion
} }

View File

@ -7,7 +7,7 @@ using AIStudio.Chat;
namespace AIStudio.Provider.Fireworks; namespace AIStudio.Provider.Fireworks;
public class ProviderFireworks(ILogger logger) : BaseProvider("https://api.fireworks.ai/inference/v1/", logger), IProvider public class ProviderFireworks(ILogger logger) : BaseProvider("https://api.fireworks.ai/inference/v1/", logger)
{ {
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new() private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{ {
@ -17,13 +17,13 @@ public class ProviderFireworks(ILogger logger) : BaseProvider("https://api.firew
#region Implementation of IProvider #region Implementation of IProvider
/// <inheritdoc /> /// <inheritdoc />
public string Id => "Fireworks.ai"; public override string Id => LLMProviders.FIREWORKS.ToName();
/// <inheritdoc /> /// <inheritdoc />
public string InstanceName { get; set; } = "Fireworks.ai"; public override string InstanceName { get; set; } = "Fireworks.ai";
/// <inheritdoc /> /// <inheritdoc />
public async IAsyncEnumerable<string> StreamChatCompletion(Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) public override async IAsyncEnumerable<string> StreamChatCompletion(Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default)
{ {
// Get the API key: // Get the API key:
var requestedSecret = await RUST_SERVICE.GetAPIKey(this); var requestedSecret = await RUST_SERVICE.GetAPIKey(this);
@ -138,26 +138,26 @@ public class ProviderFireworks(ILogger logger) : BaseProvider("https://api.firew
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc /> /// <inheritdoc />
public async IAsyncEnumerable<ImageURL> StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) public override async IAsyncEnumerable<ImageURL> StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default)
{ {
yield break; yield break;
} }
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc /> /// <inheritdoc />
public Task<IEnumerable<Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) public override Task<IEnumerable<Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
return Task.FromResult(Enumerable.Empty<Model>()); return Task.FromResult(Enumerable.Empty<Model>());
} }
/// <inheritdoc /> /// <inheritdoc />
public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) public override Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
return Task.FromResult(Enumerable.Empty<Model>()); return Task.FromResult(Enumerable.Empty<Model>());
} }
/// <inheritdoc /> /// <inheritdoc />
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) public override Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
return Task.FromResult(Enumerable.Empty<Model>()); return Task.FromResult(Enumerable.Empty<Model>());
} }

View File

@ -8,7 +8,7 @@ using AIStudio.Provider.OpenAI;
namespace AIStudio.Provider.Google; namespace AIStudio.Provider.Google;
public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativelanguage.googleapis.com/v1beta/", logger), IProvider public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativelanguage.googleapis.com/v1beta/", logger)
{ {
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new() private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{ {
@ -18,13 +18,13 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela
#region Implementation of IProvider #region Implementation of IProvider
/// <inheritdoc /> /// <inheritdoc />
public string Id => "Google"; public override string Id => LLMProviders.GOOGLE.ToName();
/// <inheritdoc /> /// <inheritdoc />
public string InstanceName { get; set; } = "Google Gemini"; public override string InstanceName { get; set; } = "Google Gemini";
/// <inheritdoc /> /// <inheritdoc />
public async IAsyncEnumerable<string> StreamChatCompletion(Provider.Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) public override async IAsyncEnumerable<string> StreamChatCompletion(Provider.Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default)
{ {
// Get the API key: // Get the API key:
var requestedSecret = await RUST_SERVICE.GetAPIKey(this); var requestedSecret = await RUST_SERVICE.GetAPIKey(this);
@ -139,14 +139,14 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc /> /// <inheritdoc />
public async IAsyncEnumerable<ImageURL> StreamImageCompletion(Provider.Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) public override async IAsyncEnumerable<ImageURL> StreamImageCompletion(Provider.Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default)
{ {
yield break; yield break;
} }
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc /> /// <inheritdoc />
public async Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) public override async Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
var modelResponse = await this.LoadModels(token, apiKeyProvisional); var modelResponse = await this.LoadModels(token, apiKeyProvisional);
if(modelResponse == default) if(modelResponse == default)
@ -158,12 +158,12 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela
} }
/// <inheritdoc /> /// <inheritdoc />
public Task<IEnumerable<Provider.Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) public override Task<IEnumerable<Provider.Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
return Task.FromResult(Enumerable.Empty<Provider.Model>()); return Task.FromResult(Enumerable.Empty<Provider.Model>());
} }
public async Task<IEnumerable<Provider.Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) public override async Task<IEnumerable<Provider.Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
var modelResponse = await this.LoadModels(token, apiKeyProvisional); var modelResponse = await this.LoadModels(token, apiKeyProvisional);
if(modelResponse == default) if(modelResponse == default)

View File

@ -8,7 +8,7 @@ using AIStudio.Provider.OpenAI;
namespace AIStudio.Provider.Groq; namespace AIStudio.Provider.Groq;
public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/openai/v1/", logger), IProvider public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/openai/v1/", logger)
{ {
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new() private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{ {
@ -18,13 +18,13 @@ public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/o
#region Implementation of IProvider #region Implementation of IProvider
/// <inheritdoc /> /// <inheritdoc />
public string Id => "Groq"; public override string Id => LLMProviders.GROQ.ToName();
/// <inheritdoc /> /// <inheritdoc />
public string InstanceName { get; set; } = "Groq"; public override string InstanceName { get; set; } = "Groq";
/// <inheritdoc /> /// <inheritdoc />
public async IAsyncEnumerable<string> StreamChatCompletion(Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) public override async IAsyncEnumerable<string> StreamChatCompletion(Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default)
{ {
// Get the API key: // Get the API key:
var requestedSecret = await RUST_SERVICE.GetAPIKey(this); var requestedSecret = await RUST_SERVICE.GetAPIKey(this);
@ -141,26 +141,26 @@ public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/o
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc /> /// <inheritdoc />
public async IAsyncEnumerable<ImageURL> StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) public override async IAsyncEnumerable<ImageURL> StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default)
{ {
yield break; yield break;
} }
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc /> /// <inheritdoc />
public Task<IEnumerable<Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) public override Task<IEnumerable<Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
return this.LoadModels(token, apiKeyProvisional); return this.LoadModels(token, apiKeyProvisional);
} }
/// <inheritdoc /> /// <inheritdoc />
public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) public override Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
return Task.FromResult<IEnumerable<Model>>(Array.Empty<Model>()); return Task.FromResult<IEnumerable<Model>>(Array.Empty<Model>());
} }
/// <inheritdoc /> /// <inheritdoc />
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) public override Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
return Task.FromResult(Enumerable.Empty<Model>()); return Task.FromResult(Enumerable.Empty<Model>());
} }

View File

@ -103,20 +103,36 @@ public static class LLMProvidersExtensions
/// <param name="logger">The logger to use.</param> /// <param name="logger">The logger to use.</param>
/// <returns>The provider instance.</returns> /// <returns>The provider instance.</returns>
public static IProvider CreateProvider(this Settings.Provider providerSettings, ILogger logger) public static IProvider CreateProvider(this Settings.Provider providerSettings, ILogger logger)
{
return providerSettings.UsedLLMProvider.CreateProvider(providerSettings.InstanceName, providerSettings.Host, providerSettings.Hostname, logger);
}
/// <summary>
/// Creates a new provider instance based on the embedding provider value.
/// </summary>
/// <param name="embeddingProviderSettings">The embedding provider settings.</param>
/// <param name="logger">The logger to use.</param>
/// <returns>The provider instance.</returns>
public static IProvider CreateProvider(this EmbeddingProvider embeddingProviderSettings, ILogger logger)
{
return embeddingProviderSettings.UsedLLMProvider.CreateProvider(embeddingProviderSettings.Name, embeddingProviderSettings.Host, embeddingProviderSettings.Hostname, logger);
}
private static IProvider CreateProvider(this LLMProviders provider, string instanceName, Host host, string hostname, ILogger logger)
{ {
try try
{ {
return providerSettings.UsedLLMProvider switch return provider switch
{ {
LLMProviders.OPEN_AI => new ProviderOpenAI(logger) { InstanceName = providerSettings.InstanceName }, LLMProviders.OPEN_AI => new ProviderOpenAI(logger) { InstanceName = instanceName },
LLMProviders.ANTHROPIC => new ProviderAnthropic(logger) { InstanceName = providerSettings.InstanceName }, LLMProviders.ANTHROPIC => new ProviderAnthropic(logger) { InstanceName = instanceName },
LLMProviders.MISTRAL => new ProviderMistral(logger) { InstanceName = providerSettings.InstanceName }, LLMProviders.MISTRAL => new ProviderMistral(logger) { InstanceName = instanceName },
LLMProviders.GOOGLE => new ProviderGoogle(logger) { InstanceName = providerSettings.InstanceName }, LLMProviders.GOOGLE => new ProviderGoogle(logger) { InstanceName = instanceName },
LLMProviders.GROQ => new ProviderGroq(logger) { InstanceName = providerSettings.InstanceName }, LLMProviders.GROQ => new ProviderGroq(logger) { InstanceName = instanceName },
LLMProviders.FIREWORKS => new ProviderFireworks(logger) { InstanceName = providerSettings.InstanceName }, LLMProviders.FIREWORKS => new ProviderFireworks(logger) { InstanceName = instanceName },
LLMProviders.SELF_HOSTED => new ProviderSelfHosted(logger, providerSettings) { InstanceName = providerSettings.InstanceName }, LLMProviders.SELF_HOSTED => new ProviderSelfHosted(logger, host, hostname) { InstanceName = instanceName },
_ => new NoProvider(), _ => new NoProvider(),
}; };

View File

@ -8,7 +8,7 @@ using AIStudio.Provider.OpenAI;
namespace AIStudio.Provider.Mistral; namespace AIStudio.Provider.Mistral;
public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.mistral.ai/v1/", logger), IProvider public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.mistral.ai/v1/", logger)
{ {
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new() private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{ {
@ -17,12 +17,12 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.
#region Implementation of IProvider #region Implementation of IProvider
public string Id => "Mistral"; public override string Id => LLMProviders.MISTRAL.ToName();
public string InstanceName { get; set; } = "Mistral"; public override string InstanceName { get; set; } = "Mistral";
/// <inheritdoc /> /// <inheritdoc />
public async IAsyncEnumerable<string> StreamChatCompletion(Provider.Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) public override async IAsyncEnumerable<string> StreamChatCompletion(Provider.Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default)
{ {
// Get the API key: // Get the API key:
var requestedSecret = await RUST_SERVICE.GetAPIKey(this); var requestedSecret = await RUST_SERVICE.GetAPIKey(this);
@ -140,14 +140,14 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc /> /// <inheritdoc />
public async IAsyncEnumerable<ImageURL> StreamImageCompletion(Provider.Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) public override async IAsyncEnumerable<ImageURL> StreamImageCompletion(Provider.Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default)
{ {
yield break; yield break;
} }
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc /> /// <inheritdoc />
public async Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) public override async Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
var modelResponse = await this.LoadModelList(apiKeyProvisional, token); var modelResponse = await this.LoadModelList(apiKeyProvisional, token);
if(modelResponse == default) if(modelResponse == default)
@ -160,7 +160,7 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.
} }
/// <inheritdoc /> /// <inheritdoc />
public async Task<IEnumerable<Provider.Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) public override async Task<IEnumerable<Provider.Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
var modelResponse = await this.LoadModelList(apiKeyProvisional, token); var modelResponse = await this.LoadModelList(apiKeyProvisional, token);
if(modelResponse == default) if(modelResponse == default)
@ -171,7 +171,7 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.
} }
/// <inheritdoc /> /// <inheritdoc />
public Task<IEnumerable<Provider.Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) public override Task<IEnumerable<Provider.Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
return Task.FromResult(Enumerable.Empty<Provider.Model>()); return Task.FromResult(Enumerable.Empty<Provider.Model>());
} }

View File

@ -10,7 +10,7 @@ namespace AIStudio.Provider.OpenAI;
/// <summary> /// <summary>
/// The OpenAI provider. /// The OpenAI provider.
/// </summary> /// </summary>
public sealed class ProviderOpenAI(ILogger logger) : BaseProvider("https://api.openai.com/v1/", logger), IProvider public sealed class ProviderOpenAI(ILogger logger) : BaseProvider("https://api.openai.com/v1/", logger)
{ {
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new() private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{ {
@ -20,13 +20,13 @@ public sealed class ProviderOpenAI(ILogger logger) : BaseProvider("https://api.o
#region Implementation of IProvider #region Implementation of IProvider
/// <inheritdoc /> /// <inheritdoc />
public string Id => "OpenAI"; public override string Id => LLMProviders.OPEN_AI.ToName();
/// <inheritdoc /> /// <inheritdoc />
public string InstanceName { get; set; } = "OpenAI"; public override string InstanceName { get; set; } = "OpenAI";
/// <inheritdoc /> /// <inheritdoc />
public async IAsyncEnumerable<string> StreamChatCompletion(Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) public override async IAsyncEnumerable<string> StreamChatCompletion(Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default)
{ {
// Get the API key: // Get the API key:
var requestedSecret = await RUST_SERVICE.GetAPIKey(this); var requestedSecret = await RUST_SERVICE.GetAPIKey(this);
@ -144,26 +144,26 @@ public sealed class ProviderOpenAI(ILogger logger) : BaseProvider("https://api.o
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc /> /// <inheritdoc />
public async IAsyncEnumerable<ImageURL> StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) public override async IAsyncEnumerable<ImageURL> StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default)
{ {
yield break; yield break;
} }
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc /> /// <inheritdoc />
public Task<IEnumerable<Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) public override Task<IEnumerable<Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
return this.LoadModels(["gpt-", "o1-"], token, apiKeyProvisional); return this.LoadModels(["gpt-", "o1-"], token, apiKeyProvisional);
} }
/// <inheritdoc /> /// <inheritdoc />
public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) public override Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
return this.LoadModels(["dall-e-"], token, apiKeyProvisional); return this.LoadModels(["dall-e-"], token, apiKeyProvisional);
} }
/// <inheritdoc /> /// <inheritdoc />
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) public override Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
return this.LoadModels(["text-embedding-"], token, apiKeyProvisional); return this.LoadModels(["text-embedding-"], token, apiKeyProvisional);
} }

View File

@ -8,7 +8,7 @@ using AIStudio.Provider.OpenAI;
namespace AIStudio.Provider.SelfHosted; namespace AIStudio.Provider.SelfHosted;
public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provider) : BaseProvider($"{provider.Hostname}{provider.Host.BaseURL()}", logger), IProvider public sealed class ProviderSelfHosted(ILogger logger, Host host, string hostname) : BaseProvider($"{hostname}{host.BaseURL()}", logger)
{ {
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new() private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{ {
@ -17,12 +17,12 @@ public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provide
#region Implementation of IProvider #region Implementation of IProvider
public string Id => "Self-hosted"; public override string Id => LLMProviders.SELF_HOSTED.ToName();
public string InstanceName { get; set; } = "Self-hosted"; public override string InstanceName { get; set; } = "Self-hosted";
/// <inheritdoc /> /// <inheritdoc />
public async IAsyncEnumerable<string> StreamChatCompletion(Provider.Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default) public override async IAsyncEnumerable<string> StreamChatCompletion(Provider.Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default)
{ {
// Get the API key: // Get the API key:
var requestedSecret = await RUST_SERVICE.GetAPIKey(this, isTrying: true); var requestedSecret = await RUST_SERVICE.GetAPIKey(this, isTrying: true);
@ -70,7 +70,7 @@ public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provide
try try
{ {
// Build the HTTP post request: // Build the HTTP post request:
var request = new HttpRequestMessage(HttpMethod.Post, provider.Host.ChatURL()); var request = new HttpRequestMessage(HttpMethod.Post, host.ChatURL());
// Set the authorization header: // Set the authorization header:
if (requestedSecret.Success) if (requestedSecret.Success)
@ -148,18 +148,18 @@ public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provide
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc /> /// <inheritdoc />
public async IAsyncEnumerable<ImageURL> StreamImageCompletion(Provider.Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) public override async IAsyncEnumerable<ImageURL> StreamImageCompletion(Provider.Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default)
{ {
yield break; yield break;
} }
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
public async Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) public override async Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
try try
{ {
switch (provider.Host) switch (host)
{ {
case Host.LLAMACPP: case Host.LLAMACPP:
// Right now, llama.cpp only supports one model. // Right now, llama.cpp only supports one model.
@ -201,12 +201,12 @@ public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provide
} }
/// <inheritdoc /> /// <inheritdoc />
public Task<IEnumerable<Provider.Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) public override Task<IEnumerable<Provider.Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
return Task.FromResult(Enumerable.Empty<Provider.Model>()); return Task.FromResult(Enumerable.Empty<Provider.Model>());
} }
public Task<IEnumerable<Provider.Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) public override Task<IEnumerable<Provider.Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
return Task.FromResult(Enumerable.Empty<Provider.Model>()); return Task.FromResult(Enumerable.Empty<Provider.Model>());
} }

View File

@ -1,3 +1,5 @@
using System.Text.Json.Serialization;
using AIStudio.Provider; using AIStudio.Provider;
using Host = AIStudio.Provider.SelfHosted.Host; using Host = AIStudio.Provider.SelfHosted.Host;
@ -22,7 +24,7 @@ public readonly record struct Provider(
Model Model, Model Model,
bool IsSelfHosted = false, bool IsSelfHosted = false,
string Hostname = "http://localhost:1234", string Hostname = "http://localhost:1234",
Host Host = Host.NONE) Host Host = Host.NONE) : ISecretId
{ {
#region Overrides of ValueType #region Overrides of ValueType
@ -40,4 +42,16 @@ public readonly record struct Provider(
} }
#endregion #endregion
#region Implementation of ISecretId
/// <inheritdoc />
[JsonIgnore]
public string SecretId => this.Id;
/// <inheritdoc />
[JsonIgnore]
public string SecretName => this.InstanceName;
#endregion
} }

View File

@ -0,0 +1,17 @@
namespace AIStudio.Tools;
/// <summary>
/// Represents an interface defining a secret identifier.
/// </summary>
public interface ISecretId
{
/// <summary>
/// The unique ID of the secret.
/// </summary>
public string SecretId { get; }
/// <summary>
/// The instance name of the secret.
/// </summary>
public string SecretName { get; }
}

View File

@ -1,7 +1,6 @@
using System.Security.Cryptography; using System.Security.Cryptography;
using System.Text.Json; using System.Text.Json;
using AIStudio.Provider;
using AIStudio.Tools.Rust; using AIStudio.Tools.Rust;
// ReSharper disable NotAccessedPositionalProperty.Local // ReSharper disable NotAccessedPositionalProperty.Local
@ -255,71 +254,71 @@ public sealed class RustService : IDisposable
} }
/// <summary> /// <summary>
/// Try to get the API key for the given provider. /// Try to get the API key for the given secret ID.
/// </summary> /// </summary>
/// <param name="provider">The provider to get the API key for.</param> /// <param name="secretId">The secret ID to get the API key for.</param>
/// <param name="isTrying">Indicates if we are trying to get the API key. In that case, we don't log errors.</param> /// <param name="isTrying">Indicates if we are trying to get the API key. In that case, we don't log errors.</param>
/// <returns>The requested secret.</returns> /// <returns>The requested secret.</returns>
public async Task<RequestedSecret> GetAPIKey(IProvider provider, bool isTrying = false) public async Task<RequestedSecret> GetAPIKey(ISecretId secretId, bool isTrying = false)
{ {
var secretRequest = new SelectSecretRequest($"provider::{provider.Id}::{provider.InstanceName}::api_key", Environment.UserName, isTrying); var secretRequest = new SelectSecretRequest($"provider::{secretId.SecretId}::{secretId.SecretName}::api_key", Environment.UserName, isTrying);
var result = await this.http.PostAsJsonAsync("/secrets/get", secretRequest, this.jsonRustSerializerOptions); var result = await this.http.PostAsJsonAsync("/secrets/get", secretRequest, this.jsonRustSerializerOptions);
if (!result.IsSuccessStatusCode) if (!result.IsSuccessStatusCode)
{ {
if(!isTrying) if(!isTrying)
this.logger!.LogError($"Failed to get the API key for provider '{provider.Id}' due to an API issue: '{result.StatusCode}'"); this.logger!.LogError($"Failed to get the API key for secret ID '{secretId.SecretId}' due to an API issue: '{result.StatusCode}'");
return new RequestedSecret(false, new EncryptedText(string.Empty), "Failed to get the API key due to an API issue."); return new RequestedSecret(false, new EncryptedText(string.Empty), "Failed to get the API key due to an API issue.");
} }
var secret = await result.Content.ReadFromJsonAsync<RequestedSecret>(this.jsonRustSerializerOptions); var secret = await result.Content.ReadFromJsonAsync<RequestedSecret>(this.jsonRustSerializerOptions);
if (!secret.Success && !isTrying) if (!secret.Success && !isTrying)
this.logger!.LogError($"Failed to get the API key for provider '{provider.Id}': '{secret.Issue}'"); this.logger!.LogError($"Failed to get the API key for secret ID '{secretId.SecretId}': '{secret.Issue}'");
return secret; return secret;
} }
/// <summary> /// <summary>
/// Try to store the API key for the given provider. /// Try to store the API key for the given secret ID.
/// </summary> /// </summary>
/// <param name="provider">The provider to store the API key for.</param> /// <param name="secretId">The secret ID to store the API key for.</param>
/// <param name="key">The API key to store.</param> /// <param name="key">The API key to store.</param>
/// <returns>The store secret response.</returns> /// <returns>The store secret response.</returns>
public async Task<StoreSecretResponse> SetAPIKey(IProvider provider, string key) public async Task<StoreSecretResponse> SetAPIKey(ISecretId secretId, string key)
{ {
var encryptedKey = await this.encryptor!.Encrypt(key); var encryptedKey = await this.encryptor!.Encrypt(key);
var request = new StoreSecretRequest($"provider::{provider.Id}::{provider.InstanceName}::api_key", Environment.UserName, encryptedKey); var request = new StoreSecretRequest($"provider::{secretId.SecretId}::{secretId.SecretName}::api_key", Environment.UserName, encryptedKey);
var result = await this.http.PostAsJsonAsync("/secrets/store", request, this.jsonRustSerializerOptions); var result = await this.http.PostAsJsonAsync("/secrets/store", request, this.jsonRustSerializerOptions);
if (!result.IsSuccessStatusCode) if (!result.IsSuccessStatusCode)
{ {
this.logger!.LogError($"Failed to store the API key for provider '{provider.Id}' due to an API issue: '{result.StatusCode}'"); this.logger!.LogError($"Failed to store the API key for secret ID '{secretId.SecretId}' due to an API issue: '{result.StatusCode}'");
return new StoreSecretResponse(false, "Failed to get the API key due to an API issue."); return new StoreSecretResponse(false, "Failed to get the API key due to an API issue.");
} }
var state = await result.Content.ReadFromJsonAsync<StoreSecretResponse>(this.jsonRustSerializerOptions); var state = await result.Content.ReadFromJsonAsync<StoreSecretResponse>(this.jsonRustSerializerOptions);
if (!state.Success) if (!state.Success)
this.logger!.LogError($"Failed to store the API key for provider '{provider.Id}': '{state.Issue}'"); this.logger!.LogError($"Failed to store the API key for secret ID '{secretId.SecretId}': '{state.Issue}'");
return state; return state;
} }
/// <summary> /// <summary>
/// Tries to delete the API key for the given provider. /// Tries to delete the API key for the given secret ID.
/// </summary> /// </summary>
/// <param name="provider">The provider to delete the API key for.</param> /// <param name="secretId">The secret ID to delete the API key for.</param>
/// <returns>The delete secret response.</returns> /// <returns>The delete secret response.</returns>
public async Task<DeleteSecretResponse> DeleteAPIKey(IProvider provider) public async Task<DeleteSecretResponse> DeleteAPIKey(ISecretId secretId)
{ {
var request = new SelectSecretRequest($"provider::{provider.Id}::{provider.InstanceName}::api_key", Environment.UserName, false); var request = new SelectSecretRequest($"provider::{secretId.SecretId}::{secretId.SecretName}::api_key", Environment.UserName, false);
var result = await this.http.PostAsJsonAsync("/secrets/delete", request, this.jsonRustSerializerOptions); var result = await this.http.PostAsJsonAsync("/secrets/delete", request, this.jsonRustSerializerOptions);
if (!result.IsSuccessStatusCode) if (!result.IsSuccessStatusCode)
{ {
this.logger!.LogError($"Failed to delete the API key for provider '{provider.Id}' due to an API issue: '{result.StatusCode}'"); this.logger!.LogError($"Failed to delete the API key for secret ID '{secretId.SecretId}' due to an API issue: '{result.StatusCode}'");
return new DeleteSecretResponse{Success = false, WasEntryFound = false, Issue = "Failed to delete the API key due to an API issue."}; return new DeleteSecretResponse{Success = false, WasEntryFound = false, Issue = "Failed to delete the API key due to an API issue."};
} }
var state = await result.Content.ReadFromJsonAsync<DeleteSecretResponse>(this.jsonRustSerializerOptions); var state = await result.Content.ReadFromJsonAsync<DeleteSecretResponse>(this.jsonRustSerializerOptions);
if (!state.Success) if (!state.Success)
this.logger!.LogError($"Failed to delete the API key for provider '{provider.Id}': '{state.Issue}'"); this.logger!.LogError($"Failed to delete the API key for secret ID '{secretId.SecretId}': '{state.Issue}'");
return state; return state;
} }