Added method to get embedding models of a certain provider

This commit is contained in:
Thorsten Sommer 2024-12-01 11:55:16 +01:00
parent 340f0ef5cd
commit c9ce1ac468
Signed by: tsommer
GPG Key ID: 371BBA77A02C0108
9 changed files with 97 additions and 29 deletions

View File

@ -162,13 +162,17 @@ public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://ap
}.AsEnumerable());
}
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc />
public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
{
return Task.FromResult(Enumerable.Empty<Model>());
}
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc />
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
{
return Task.FromResult(Enumerable.Empty<Model>());
}
#endregion
}

View File

@ -155,6 +155,12 @@ public class ProviderFireworks(ILogger logger) : BaseProvider("https://api.firew
{
return Task.FromResult(Enumerable.Empty<Model>());
}
/// <inheritdoc />
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
{
return Task.FromResult(Enumerable.Empty<Model>());
}
#endregion
}

View File

@ -146,9 +146,15 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc />
public Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
public async Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
{
return this.LoadModels(token, apiKeyProvisional);
var modelResponse = await this.LoadModels(token, apiKeyProvisional);
if(modelResponse == default)
return [];
return modelResponse.Models.Where(model =>
model.Name.StartsWith("models/gemini-", StringComparison.InvariantCultureIgnoreCase))
.Select(n => new Provider.Model(n.Name.Replace("models/", string.Empty), n.DisplayName));
}
/// <inheritdoc />
@ -157,9 +163,20 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela
return Task.FromResult(Enumerable.Empty<Provider.Model>());
}
public async Task<IEnumerable<Provider.Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
{
var modelResponse = await this.LoadModels(token, apiKeyProvisional);
if(modelResponse == default)
return [];
return modelResponse.Models.Where(model =>
model.Name.StartsWith("models/text-embedding-", StringComparison.InvariantCultureIgnoreCase))
.Select(n => new Provider.Model(n.Name.Replace("models/", string.Empty), n.DisplayName));
}
#endregion
private async Task<IEnumerable<Provider.Model>> LoadModels(CancellationToken token, string? apiKeyProvisional = null)
private async Task<ModelsResponse> LoadModels(CancellationToken token, string? apiKeyProvisional = null)
{
var secretKey = apiKeyProvisional switch
{
@ -170,19 +187,17 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela
_ => null,
}
};
if (secretKey is null)
return [];
return default;
var request = new HttpRequestMessage(HttpMethod.Get, $"models?key={secretKey}");
var response = await this.httpClient.SendAsync(request, token);
if(!response.IsSuccessStatusCode)
return [];
return default;
var modelResponse = await response.Content.ReadFromJsonAsync<ModelsResponse>(token);
return modelResponse.Models.Where(model =>
model.Name.StartsWith("models/gemini-", StringComparison.InvariantCultureIgnoreCase))
.Select(n => new Provider.Model(n.Name.Replace("models/", string.Empty), n.DisplayName));
return modelResponse;
}
}

View File

@ -158,6 +158,12 @@ public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/o
{
return Task.FromResult<IEnumerable<Model>>(Array.Empty<Model>());
}
/// <inheritdoc />
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
{
return Task.FromResult(Enumerable.Empty<Model>());
}
#endregion

View File

@ -53,4 +53,12 @@ public interface IProvider
/// <param name="token">The cancellation token.</param>
/// <returns>The list of image models.</returns>
public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default);
/// <summary>
/// Load all possible embedding models that can be used with this provider.
/// </summary>
/// <param name="apiKeyProvisional">The provisional API key to use. Useful when the user is adding a new provider. When null, the stored API key is used.</param>
/// <param name="token">The cancellation token.</param>
/// <returns>The list of embedding models.</returns>
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default);
}

View File

@ -148,6 +148,37 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.
/// <inheritdoc />
public async Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
{
var modelResponse = await this.LoadModelList(apiKeyProvisional, token);
if(modelResponse == default)
return [];
return modelResponse.Data.Where(n =>
!n.Id.StartsWith("code", StringComparison.InvariantCulture) &&
!n.Id.Contains("embed", StringComparison.InvariantCulture))
.Select(n => new Provider.Model(n.Id, null));
}
/// <inheritdoc />
public async Task<IEnumerable<Provider.Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
{
var modelResponse = await this.LoadModelList(apiKeyProvisional, token);
if(modelResponse == default)
return [];
return modelResponse.Data.Where(n => n.Id.Contains("embed", StringComparison.InvariantCulture))
.Select(n => new Provider.Model(n.Id, null));
}
/// <inheritdoc />
public Task<IEnumerable<Provider.Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
{
return Task.FromResult(Enumerable.Empty<Provider.Model>());
}
#endregion
private async Task<ModelsResponse> LoadModelList(string? apiKeyProvisional, CancellationToken token)
{
var secretKey = apiKeyProvisional switch
{
@ -160,29 +191,16 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.
};
if (secretKey is null)
return [];
return default;
var request = new HttpRequestMessage(HttpMethod.Get, "models");
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", secretKey);
var response = await this.httpClient.SendAsync(request, token);
if(!response.IsSuccessStatusCode)
return [];
return default;
var modelResponse = await response.Content.ReadFromJsonAsync<ModelsResponse>(token);
return modelResponse.Data.Where(n =>
!n.Id.StartsWith("code", StringComparison.InvariantCulture) &&
!n.Id.Contains("embed", StringComparison.InvariantCulture))
.Select(n => new Provider.Model(n.Id, null));
return modelResponse;
}
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc />
public Task<IEnumerable<Provider.Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
{
return Task.FromResult(Enumerable.Empty<Provider.Model>());
}
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
#endregion
}

View File

@ -15,6 +15,8 @@ public class NoProvider : IProvider
public Task<IEnumerable<Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) => Task.FromResult<IEnumerable<Model>>([]);
public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) => Task.FromResult<IEnumerable<Model>>([]);
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) => Task.FromResult<IEnumerable<Model>>([]);
public async IAsyncEnumerable<string> StreamChatCompletion(Model chatModel, ChatThread chatChatThread, [EnumeratorCancellation] CancellationToken token = default)
{

View File

@ -161,6 +161,12 @@ public sealed class ProviderOpenAI(ILogger logger) : BaseProvider("https://api.o
{
return this.LoadModels(["dall-e-"], token, apiKeyProvisional);
}
/// <inheritdoc />
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
{
return this.LoadModels(["text-embedding-"], token, apiKeyProvisional);
}
#endregion

View File

@ -200,13 +200,16 @@ public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provide
}
}
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc />
public Task<IEnumerable<Provider.Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
{
return Task.FromResult(Enumerable.Empty<Provider.Model>());
}
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
public Task<IEnumerable<Provider.Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
{
return Task.FromResult(Enumerable.Empty<Provider.Model>());
}
#endregion
}