mirror of
https://github.com/MindWorkAI/AI-Studio.git
synced 2025-04-28 17:39:47 +00:00
Added method to get embedding models of a certain provider
This commit is contained in:
parent
340f0ef5cd
commit
c9ce1ac468
@ -162,13 +162,17 @@ public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://ap
|
||||
}.AsEnumerable());
|
||||
}
|
||||
|
||||
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
|
||||
/// <inheritdoc />
|
||||
public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||
{
|
||||
return Task.FromResult(Enumerable.Empty<Model>());
|
||||
}
|
||||
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
|
||||
|
||||
/// <inheritdoc />
|
||||
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||
{
|
||||
return Task.FromResult(Enumerable.Empty<Model>());
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
@ -155,6 +155,12 @@ public class ProviderFireworks(ILogger logger) : BaseProvider("https://api.firew
|
||||
{
|
||||
return Task.FromResult(Enumerable.Empty<Model>());
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||
{
|
||||
return Task.FromResult(Enumerable.Empty<Model>());
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
@ -146,9 +146,15 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela
|
||||
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
|
||||
|
||||
/// <inheritdoc />
|
||||
public Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||
public async Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||
{
|
||||
return this.LoadModels(token, apiKeyProvisional);
|
||||
var modelResponse = await this.LoadModels(token, apiKeyProvisional);
|
||||
if(modelResponse == default)
|
||||
return [];
|
||||
|
||||
return modelResponse.Models.Where(model =>
|
||||
model.Name.StartsWith("models/gemini-", StringComparison.InvariantCultureIgnoreCase))
|
||||
.Select(n => new Provider.Model(n.Name.Replace("models/", string.Empty), n.DisplayName));
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
@ -157,9 +163,20 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela
|
||||
return Task.FromResult(Enumerable.Empty<Provider.Model>());
|
||||
}
|
||||
|
||||
public async Task<IEnumerable<Provider.Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||
{
|
||||
var modelResponse = await this.LoadModels(token, apiKeyProvisional);
|
||||
if(modelResponse == default)
|
||||
return [];
|
||||
|
||||
return modelResponse.Models.Where(model =>
|
||||
model.Name.StartsWith("models/text-embedding-", StringComparison.InvariantCultureIgnoreCase))
|
||||
.Select(n => new Provider.Model(n.Name.Replace("models/", string.Empty), n.DisplayName));
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
private async Task<IEnumerable<Provider.Model>> LoadModels(CancellationToken token, string? apiKeyProvisional = null)
|
||||
private async Task<ModelsResponse> LoadModels(CancellationToken token, string? apiKeyProvisional = null)
|
||||
{
|
||||
var secretKey = apiKeyProvisional switch
|
||||
{
|
||||
@ -170,19 +187,17 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela
|
||||
_ => null,
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
if (secretKey is null)
|
||||
return [];
|
||||
return default;
|
||||
|
||||
var request = new HttpRequestMessage(HttpMethod.Get, $"models?key={secretKey}");
|
||||
var response = await this.httpClient.SendAsync(request, token);
|
||||
|
||||
if(!response.IsSuccessStatusCode)
|
||||
return [];
|
||||
return default;
|
||||
|
||||
var modelResponse = await response.Content.ReadFromJsonAsync<ModelsResponse>(token);
|
||||
return modelResponse.Models.Where(model =>
|
||||
model.Name.StartsWith("models/gemini-", StringComparison.InvariantCultureIgnoreCase))
|
||||
.Select(n => new Provider.Model(n.Name.Replace("models/", string.Empty), n.DisplayName));
|
||||
return modelResponse;
|
||||
}
|
||||
}
|
@ -158,6 +158,12 @@ public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/o
|
||||
{
|
||||
return Task.FromResult<IEnumerable<Model>>(Array.Empty<Model>());
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||
{
|
||||
return Task.FromResult(Enumerable.Empty<Model>());
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
|
@ -53,4 +53,12 @@ public interface IProvider
|
||||
/// <param name="token">The cancellation token.</param>
|
||||
/// <returns>The list of image models.</returns>
|
||||
public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default);
|
||||
|
||||
/// <summary>
|
||||
/// Load all possible embedding models that can be used with this provider.
|
||||
/// </summary>
|
||||
/// <param name="apiKeyProvisional">The provisional API key to use. Useful when the user is adding a new provider. When null, the stored API key is used.</param>
|
||||
/// <param name="token">The cancellation token.</param>
|
||||
/// <returns>The list of embedding models.</returns>
|
||||
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default);
|
||||
}
|
@ -148,6 +148,37 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||
{
|
||||
var modelResponse = await this.LoadModelList(apiKeyProvisional, token);
|
||||
if(modelResponse == default)
|
||||
return [];
|
||||
|
||||
return modelResponse.Data.Where(n =>
|
||||
!n.Id.StartsWith("code", StringComparison.InvariantCulture) &&
|
||||
!n.Id.Contains("embed", StringComparison.InvariantCulture))
|
||||
.Select(n => new Provider.Model(n.Id, null));
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task<IEnumerable<Provider.Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||
{
|
||||
var modelResponse = await this.LoadModelList(apiKeyProvisional, token);
|
||||
if(modelResponse == default)
|
||||
return [];
|
||||
|
||||
return modelResponse.Data.Where(n => n.Id.Contains("embed", StringComparison.InvariantCulture))
|
||||
.Select(n => new Provider.Model(n.Id, null));
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public Task<IEnumerable<Provider.Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||
{
|
||||
return Task.FromResult(Enumerable.Empty<Provider.Model>());
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
private async Task<ModelsResponse> LoadModelList(string? apiKeyProvisional, CancellationToken token)
|
||||
{
|
||||
var secretKey = apiKeyProvisional switch
|
||||
{
|
||||
@ -160,29 +191,16 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.
|
||||
};
|
||||
|
||||
if (secretKey is null)
|
||||
return [];
|
||||
return default;
|
||||
|
||||
var request = new HttpRequestMessage(HttpMethod.Get, "models");
|
||||
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", secretKey);
|
||||
|
||||
var response = await this.httpClient.SendAsync(request, token);
|
||||
if(!response.IsSuccessStatusCode)
|
||||
return [];
|
||||
return default;
|
||||
|
||||
var modelResponse = await response.Content.ReadFromJsonAsync<ModelsResponse>(token);
|
||||
return modelResponse.Data.Where(n =>
|
||||
!n.Id.StartsWith("code", StringComparison.InvariantCulture) &&
|
||||
!n.Id.Contains("embed", StringComparison.InvariantCulture))
|
||||
.Select(n => new Provider.Model(n.Id, null));
|
||||
return modelResponse;
|
||||
}
|
||||
|
||||
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
|
||||
/// <inheritdoc />
|
||||
public Task<IEnumerable<Provider.Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||
{
|
||||
return Task.FromResult(Enumerable.Empty<Provider.Model>());
|
||||
}
|
||||
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
|
||||
|
||||
#endregion
|
||||
}
|
@ -15,6 +15,8 @@ public class NoProvider : IProvider
|
||||
public Task<IEnumerable<Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) => Task.FromResult<IEnumerable<Model>>([]);
|
||||
|
||||
public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) => Task.FromResult<IEnumerable<Model>>([]);
|
||||
|
||||
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) => Task.FromResult<IEnumerable<Model>>([]);
|
||||
|
||||
public async IAsyncEnumerable<string> StreamChatCompletion(Model chatModel, ChatThread chatChatThread, [EnumeratorCancellation] CancellationToken token = default)
|
||||
{
|
||||
|
@ -161,6 +161,12 @@ public sealed class ProviderOpenAI(ILogger logger) : BaseProvider("https://api.o
|
||||
{
|
||||
return this.LoadModels(["dall-e-"], token, apiKeyProvisional);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||
{
|
||||
return this.LoadModels(["text-embedding-"], token, apiKeyProvisional);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
|
@ -200,13 +200,16 @@ public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provide
|
||||
}
|
||||
}
|
||||
|
||||
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
|
||||
/// <inheritdoc />
|
||||
public Task<IEnumerable<Provider.Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||
{
|
||||
return Task.FromResult(Enumerable.Empty<Provider.Model>());
|
||||
}
|
||||
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
|
||||
|
||||
public Task<IEnumerable<Provider.Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||
{
|
||||
return Task.FromResult(Enumerable.Empty<Provider.Model>());
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
Loading…
Reference in New Issue
Block a user