mirror of
https://github.com/MindWorkAI/AI-Studio.git
synced 2025-04-28 21:19:47 +00:00
Added method to get embedding models of a certain provider
This commit is contained in:
parent
340f0ef5cd
commit
c9ce1ac468
@ -162,13 +162,17 @@ public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://ap
|
|||||||
}.AsEnumerable());
|
}.AsEnumerable());
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
|
|
||||||
/// <inheritdoc />
|
/// <inheritdoc />
|
||||||
public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||||
{
|
{
|
||||||
return Task.FromResult(Enumerable.Empty<Model>());
|
return Task.FromResult(Enumerable.Empty<Model>());
|
||||||
}
|
}
|
||||||
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
|
|
||||||
|
/// <inheritdoc />
|
||||||
|
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||||
|
{
|
||||||
|
return Task.FromResult(Enumerable.Empty<Model>());
|
||||||
|
}
|
||||||
|
|
||||||
#endregion
|
#endregion
|
||||||
}
|
}
|
@ -155,6 +155,12 @@ public class ProviderFireworks(ILogger logger) : BaseProvider("https://api.firew
|
|||||||
{
|
{
|
||||||
return Task.FromResult(Enumerable.Empty<Model>());
|
return Task.FromResult(Enumerable.Empty<Model>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <inheritdoc />
|
||||||
|
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||||
|
{
|
||||||
|
return Task.FromResult(Enumerable.Empty<Model>());
|
||||||
|
}
|
||||||
|
|
||||||
#endregion
|
#endregion
|
||||||
}
|
}
|
@ -146,9 +146,15 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela
|
|||||||
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
|
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
|
||||||
|
|
||||||
/// <inheritdoc />
|
/// <inheritdoc />
|
||||||
public Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
public async Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||||
{
|
{
|
||||||
return this.LoadModels(token, apiKeyProvisional);
|
var modelResponse = await this.LoadModels(token, apiKeyProvisional);
|
||||||
|
if(modelResponse == default)
|
||||||
|
return [];
|
||||||
|
|
||||||
|
return modelResponse.Models.Where(model =>
|
||||||
|
model.Name.StartsWith("models/gemini-", StringComparison.InvariantCultureIgnoreCase))
|
||||||
|
.Select(n => new Provider.Model(n.Name.Replace("models/", string.Empty), n.DisplayName));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <inheritdoc />
|
/// <inheritdoc />
|
||||||
@ -157,9 +163,20 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela
|
|||||||
return Task.FromResult(Enumerable.Empty<Provider.Model>());
|
return Task.FromResult(Enumerable.Empty<Provider.Model>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public async Task<IEnumerable<Provider.Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||||
|
{
|
||||||
|
var modelResponse = await this.LoadModels(token, apiKeyProvisional);
|
||||||
|
if(modelResponse == default)
|
||||||
|
return [];
|
||||||
|
|
||||||
|
return modelResponse.Models.Where(model =>
|
||||||
|
model.Name.StartsWith("models/text-embedding-", StringComparison.InvariantCultureIgnoreCase))
|
||||||
|
.Select(n => new Provider.Model(n.Name.Replace("models/", string.Empty), n.DisplayName));
|
||||||
|
}
|
||||||
|
|
||||||
#endregion
|
#endregion
|
||||||
|
|
||||||
private async Task<IEnumerable<Provider.Model>> LoadModels(CancellationToken token, string? apiKeyProvisional = null)
|
private async Task<ModelsResponse> LoadModels(CancellationToken token, string? apiKeyProvisional = null)
|
||||||
{
|
{
|
||||||
var secretKey = apiKeyProvisional switch
|
var secretKey = apiKeyProvisional switch
|
||||||
{
|
{
|
||||||
@ -170,19 +187,17 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela
|
|||||||
_ => null,
|
_ => null,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if (secretKey is null)
|
if (secretKey is null)
|
||||||
return [];
|
return default;
|
||||||
|
|
||||||
var request = new HttpRequestMessage(HttpMethod.Get, $"models?key={secretKey}");
|
var request = new HttpRequestMessage(HttpMethod.Get, $"models?key={secretKey}");
|
||||||
var response = await this.httpClient.SendAsync(request, token);
|
var response = await this.httpClient.SendAsync(request, token);
|
||||||
|
|
||||||
if(!response.IsSuccessStatusCode)
|
if(!response.IsSuccessStatusCode)
|
||||||
return [];
|
return default;
|
||||||
|
|
||||||
var modelResponse = await response.Content.ReadFromJsonAsync<ModelsResponse>(token);
|
var modelResponse = await response.Content.ReadFromJsonAsync<ModelsResponse>(token);
|
||||||
return modelResponse.Models.Where(model =>
|
return modelResponse;
|
||||||
model.Name.StartsWith("models/gemini-", StringComparison.InvariantCultureIgnoreCase))
|
|
||||||
.Select(n => new Provider.Model(n.Name.Replace("models/", string.Empty), n.DisplayName));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -158,6 +158,12 @@ public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/o
|
|||||||
{
|
{
|
||||||
return Task.FromResult<IEnumerable<Model>>(Array.Empty<Model>());
|
return Task.FromResult<IEnumerable<Model>>(Array.Empty<Model>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <inheritdoc />
|
||||||
|
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||||
|
{
|
||||||
|
return Task.FromResult(Enumerable.Empty<Model>());
|
||||||
|
}
|
||||||
|
|
||||||
#endregion
|
#endregion
|
||||||
|
|
||||||
|
@ -53,4 +53,12 @@ public interface IProvider
|
|||||||
/// <param name="token">The cancellation token.</param>
|
/// <param name="token">The cancellation token.</param>
|
||||||
/// <returns>The list of image models.</returns>
|
/// <returns>The list of image models.</returns>
|
||||||
public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default);
|
public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default);
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Load all possible embedding models that can be used with this provider.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="apiKeyProvisional">The provisional API key to use. Useful when the user is adding a new provider. When null, the stored API key is used.</param>
|
||||||
|
/// <param name="token">The cancellation token.</param>
|
||||||
|
/// <returns>The list of embedding models.</returns>
|
||||||
|
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default);
|
||||||
}
|
}
|
@ -148,6 +148,37 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.
|
|||||||
|
|
||||||
/// <inheritdoc />
|
/// <inheritdoc />
|
||||||
public async Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
public async Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||||
|
{
|
||||||
|
var modelResponse = await this.LoadModelList(apiKeyProvisional, token);
|
||||||
|
if(modelResponse == default)
|
||||||
|
return [];
|
||||||
|
|
||||||
|
return modelResponse.Data.Where(n =>
|
||||||
|
!n.Id.StartsWith("code", StringComparison.InvariantCulture) &&
|
||||||
|
!n.Id.Contains("embed", StringComparison.InvariantCulture))
|
||||||
|
.Select(n => new Provider.Model(n.Id, null));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <inheritdoc />
|
||||||
|
public async Task<IEnumerable<Provider.Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||||
|
{
|
||||||
|
var modelResponse = await this.LoadModelList(apiKeyProvisional, token);
|
||||||
|
if(modelResponse == default)
|
||||||
|
return [];
|
||||||
|
|
||||||
|
return modelResponse.Data.Where(n => n.Id.Contains("embed", StringComparison.InvariantCulture))
|
||||||
|
.Select(n => new Provider.Model(n.Id, null));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <inheritdoc />
|
||||||
|
public Task<IEnumerable<Provider.Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||||
|
{
|
||||||
|
return Task.FromResult(Enumerable.Empty<Provider.Model>());
|
||||||
|
}
|
||||||
|
|
||||||
|
#endregion
|
||||||
|
|
||||||
|
private async Task<ModelsResponse> LoadModelList(string? apiKeyProvisional, CancellationToken token)
|
||||||
{
|
{
|
||||||
var secretKey = apiKeyProvisional switch
|
var secretKey = apiKeyProvisional switch
|
||||||
{
|
{
|
||||||
@ -160,29 +191,16 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.
|
|||||||
};
|
};
|
||||||
|
|
||||||
if (secretKey is null)
|
if (secretKey is null)
|
||||||
return [];
|
return default;
|
||||||
|
|
||||||
var request = new HttpRequestMessage(HttpMethod.Get, "models");
|
var request = new HttpRequestMessage(HttpMethod.Get, "models");
|
||||||
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", secretKey);
|
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", secretKey);
|
||||||
|
|
||||||
var response = await this.httpClient.SendAsync(request, token);
|
var response = await this.httpClient.SendAsync(request, token);
|
||||||
if(!response.IsSuccessStatusCode)
|
if(!response.IsSuccessStatusCode)
|
||||||
return [];
|
return default;
|
||||||
|
|
||||||
var modelResponse = await response.Content.ReadFromJsonAsync<ModelsResponse>(token);
|
var modelResponse = await response.Content.ReadFromJsonAsync<ModelsResponse>(token);
|
||||||
return modelResponse.Data.Where(n =>
|
return modelResponse;
|
||||||
!n.Id.StartsWith("code", StringComparison.InvariantCulture) &&
|
|
||||||
!n.Id.Contains("embed", StringComparison.InvariantCulture))
|
|
||||||
.Select(n => new Provider.Model(n.Id, null));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
|
|
||||||
/// <inheritdoc />
|
|
||||||
public Task<IEnumerable<Provider.Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
|
||||||
{
|
|
||||||
return Task.FromResult(Enumerable.Empty<Provider.Model>());
|
|
||||||
}
|
|
||||||
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
|
|
||||||
|
|
||||||
#endregion
|
|
||||||
}
|
}
|
@ -15,6 +15,8 @@ public class NoProvider : IProvider
|
|||||||
public Task<IEnumerable<Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) => Task.FromResult<IEnumerable<Model>>([]);
|
public Task<IEnumerable<Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) => Task.FromResult<IEnumerable<Model>>([]);
|
||||||
|
|
||||||
public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) => Task.FromResult<IEnumerable<Model>>([]);
|
public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) => Task.FromResult<IEnumerable<Model>>([]);
|
||||||
|
|
||||||
|
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) => Task.FromResult<IEnumerable<Model>>([]);
|
||||||
|
|
||||||
public async IAsyncEnumerable<string> StreamChatCompletion(Model chatModel, ChatThread chatChatThread, [EnumeratorCancellation] CancellationToken token = default)
|
public async IAsyncEnumerable<string> StreamChatCompletion(Model chatModel, ChatThread chatChatThread, [EnumeratorCancellation] CancellationToken token = default)
|
||||||
{
|
{
|
||||||
|
@ -161,6 +161,12 @@ public sealed class ProviderOpenAI(ILogger logger) : BaseProvider("https://api.o
|
|||||||
{
|
{
|
||||||
return this.LoadModels(["dall-e-"], token, apiKeyProvisional);
|
return this.LoadModels(["dall-e-"], token, apiKeyProvisional);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <inheritdoc />
|
||||||
|
public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||||
|
{
|
||||||
|
return this.LoadModels(["text-embedding-"], token, apiKeyProvisional);
|
||||||
|
}
|
||||||
|
|
||||||
#endregion
|
#endregion
|
||||||
|
|
||||||
|
@ -200,13 +200,16 @@ public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provide
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
|
|
||||||
/// <inheritdoc />
|
/// <inheritdoc />
|
||||||
public Task<IEnumerable<Provider.Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
public Task<IEnumerable<Provider.Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||||
{
|
{
|
||||||
return Task.FromResult(Enumerable.Empty<Provider.Model>());
|
return Task.FromResult(Enumerable.Empty<Provider.Model>());
|
||||||
}
|
}
|
||||||
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
|
|
||||||
|
public Task<IEnumerable<Provider.Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
|
||||||
|
{
|
||||||
|
return Task.FromResult(Enumerable.Empty<Provider.Model>());
|
||||||
|
}
|
||||||
|
|
||||||
#endregion
|
#endregion
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user