mirror of
				https://github.com/MindWorkAI/AI-Studio.git
				synced 2025-11-04 01:40:21 +00:00 
			
		
		
		
	Added method to get embedding models of a certain provider
This commit is contained in:
		
							parent
							
								
									340f0ef5cd
								
							
						
					
					
						commit
						c9ce1ac468
					
				@ -162,13 +162,17 @@ public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://ap
 | 
				
			|||||||
        }.AsEnumerable());
 | 
					        }.AsEnumerable());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
 | 
					 | 
				
			||||||
    /// <inheritdoc />
 | 
					    /// <inheritdoc />
 | 
				
			||||||
    public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
 | 
					    public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
        return Task.FromResult(Enumerable.Empty<Model>());
 | 
					        return Task.FromResult(Enumerable.Empty<Model>());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
 | 
					    
 | 
				
			||||||
 | 
					    /// <inheritdoc />
 | 
				
			||||||
 | 
					    public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        return Task.FromResult(Enumerable.Empty<Model>());
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #endregion
 | 
					    #endregion
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -156,5 +156,11 @@ public class ProviderFireworks(ILogger logger) : BaseProvider("https://api.firew
 | 
				
			|||||||
        return Task.FromResult(Enumerable.Empty<Model>());
 | 
					        return Task.FromResult(Enumerable.Empty<Model>());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
 | 
					    /// <inheritdoc />
 | 
				
			||||||
 | 
					    public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        return Task.FromResult(Enumerable.Empty<Model>());
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #endregion
 | 
					    #endregion
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -146,9 +146,15 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela
 | 
				
			|||||||
    #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
 | 
					    #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /// <inheritdoc />
 | 
					    /// <inheritdoc />
 | 
				
			||||||
    public Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
 | 
					    public async Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
        return this.LoadModels(token, apiKeyProvisional);
 | 
					        var modelResponse = await this.LoadModels(token, apiKeyProvisional);
 | 
				
			||||||
 | 
					        if(modelResponse == default)
 | 
				
			||||||
 | 
					            return [];
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        return modelResponse.Models.Where(model =>
 | 
				
			||||||
 | 
					                model.Name.StartsWith("models/gemini-", StringComparison.InvariantCultureIgnoreCase))
 | 
				
			||||||
 | 
					            .Select(n => new Provider.Model(n.Name.Replace("models/", string.Empty), n.DisplayName));
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /// <inheritdoc />
 | 
					    /// <inheritdoc />
 | 
				
			||||||
@ -157,9 +163,20 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela
 | 
				
			|||||||
        return Task.FromResult(Enumerable.Empty<Provider.Model>());
 | 
					        return Task.FromResult(Enumerable.Empty<Provider.Model>());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public async Task<IEnumerable<Provider.Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        var modelResponse = await this.LoadModels(token, apiKeyProvisional);
 | 
				
			||||||
 | 
					        if(modelResponse == default)
 | 
				
			||||||
 | 
					            return [];
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        return modelResponse.Models.Where(model =>
 | 
				
			||||||
 | 
					                model.Name.StartsWith("models/text-embedding-", StringComparison.InvariantCultureIgnoreCase))
 | 
				
			||||||
 | 
					            .Select(n => new Provider.Model(n.Name.Replace("models/", string.Empty), n.DisplayName));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #endregion
 | 
					    #endregion
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    private async Task<IEnumerable<Provider.Model>> LoadModels(CancellationToken token, string? apiKeyProvisional = null)
 | 
					    private async Task<ModelsResponse> LoadModels(CancellationToken token, string? apiKeyProvisional = null)
 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
        var secretKey = apiKeyProvisional switch
 | 
					        var secretKey = apiKeyProvisional switch
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
@ -172,17 +189,15 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela
 | 
				
			|||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if (secretKey is null)
 | 
					        if (secretKey is null)
 | 
				
			||||||
            return [];
 | 
					            return default;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        var request = new HttpRequestMessage(HttpMethod.Get, $"models?key={secretKey}");
 | 
					        var request = new HttpRequestMessage(HttpMethod.Get, $"models?key={secretKey}");
 | 
				
			||||||
        var response = await this.httpClient.SendAsync(request, token);
 | 
					        var response = await this.httpClient.SendAsync(request, token);
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        if(!response.IsSuccessStatusCode)
 | 
					        if(!response.IsSuccessStatusCode)
 | 
				
			||||||
            return [];
 | 
					            return default;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        var modelResponse = await response.Content.ReadFromJsonAsync<ModelsResponse>(token);
 | 
					        var modelResponse = await response.Content.ReadFromJsonAsync<ModelsResponse>(token);
 | 
				
			||||||
        return modelResponse.Models.Where(model =>
 | 
					        return modelResponse;
 | 
				
			||||||
            model.Name.StartsWith("models/gemini-", StringComparison.InvariantCultureIgnoreCase))
 | 
					 | 
				
			||||||
            .Select(n => new Provider.Model(n.Name.Replace("models/", string.Empty), n.DisplayName));
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -159,6 +159,12 @@ public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/o
 | 
				
			|||||||
        return Task.FromResult<IEnumerable<Model>>(Array.Empty<Model>());
 | 
					        return Task.FromResult<IEnumerable<Model>>(Array.Empty<Model>());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
 | 
					    /// <inheritdoc />
 | 
				
			||||||
 | 
					    public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        return Task.FromResult(Enumerable.Empty<Model>());
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #endregion
 | 
					    #endregion
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    private async Task<IEnumerable<Model>> LoadModels(CancellationToken token, string? apiKeyProvisional = null)
 | 
					    private async Task<IEnumerable<Model>> LoadModels(CancellationToken token, string? apiKeyProvisional = null)
 | 
				
			||||||
 | 
				
			|||||||
@ -53,4 +53,12 @@ public interface IProvider
 | 
				
			|||||||
    /// <param name="token">The cancellation token.</param>
 | 
					    /// <param name="token">The cancellation token.</param>
 | 
				
			||||||
    /// <returns>The list of image models.</returns>
 | 
					    /// <returns>The list of image models.</returns>
 | 
				
			||||||
    public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default);
 | 
					    public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /// <summary>
 | 
				
			||||||
 | 
					    /// Load all possible embedding models that can be used with this provider.
 | 
				
			||||||
 | 
					    /// </summary>
 | 
				
			||||||
 | 
					    /// <param name="apiKeyProvisional">The provisional API key to use. Useful when the user is adding a new provider. When null, the stored API key is used.</param>
 | 
				
			||||||
 | 
					    /// <param name="token">The cancellation token.</param>
 | 
				
			||||||
 | 
					    /// <returns>The list of embedding models.</returns>
 | 
				
			||||||
 | 
					    public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -148,6 +148,37 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    /// <inheritdoc />
 | 
					    /// <inheritdoc />
 | 
				
			||||||
    public async Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
 | 
					    public async Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        var modelResponse = await this.LoadModelList(apiKeyProvisional, token);
 | 
				
			||||||
 | 
					        if(modelResponse == default)
 | 
				
			||||||
 | 
					            return [];
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        return modelResponse.Data.Where(n => 
 | 
				
			||||||
 | 
					            !n.Id.StartsWith("code", StringComparison.InvariantCulture) &&
 | 
				
			||||||
 | 
					            !n.Id.Contains("embed", StringComparison.InvariantCulture))
 | 
				
			||||||
 | 
					            .Select(n => new Provider.Model(n.Id, null));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    /// <inheritdoc />
 | 
				
			||||||
 | 
					    public async Task<IEnumerable<Provider.Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        var modelResponse = await this.LoadModelList(apiKeyProvisional, token);
 | 
				
			||||||
 | 
					        if(modelResponse == default)
 | 
				
			||||||
 | 
					            return [];
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        return modelResponse.Data.Where(n => n.Id.Contains("embed", StringComparison.InvariantCulture))
 | 
				
			||||||
 | 
					            .Select(n => new Provider.Model(n.Id, null));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    /// <inheritdoc />
 | 
				
			||||||
 | 
					    public Task<IEnumerable<Provider.Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        return Task.FromResult(Enumerable.Empty<Provider.Model>());
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #endregion
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    private async Task<ModelsResponse> LoadModelList(string? apiKeyProvisional, CancellationToken token)
 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
        var secretKey = apiKeyProvisional switch
 | 
					        var secretKey = apiKeyProvisional switch
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
@ -160,29 +191,16 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.
 | 
				
			|||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if (secretKey is null)
 | 
					        if (secretKey is null)
 | 
				
			||||||
            return [];
 | 
					            return default;
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        var request = new HttpRequestMessage(HttpMethod.Get, "models");
 | 
					        var request = new HttpRequestMessage(HttpMethod.Get, "models");
 | 
				
			||||||
        request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", secretKey);
 | 
					        request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", secretKey);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        var response = await this.httpClient.SendAsync(request, token);
 | 
					        var response = await this.httpClient.SendAsync(request, token);
 | 
				
			||||||
        if(!response.IsSuccessStatusCode)
 | 
					        if(!response.IsSuccessStatusCode)
 | 
				
			||||||
            return [];
 | 
					            return default;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        var modelResponse = await response.Content.ReadFromJsonAsync<ModelsResponse>(token);
 | 
					        var modelResponse = await response.Content.ReadFromJsonAsync<ModelsResponse>(token);
 | 
				
			||||||
        return modelResponse.Data.Where(n => 
 | 
					        return modelResponse;
 | 
				
			||||||
            !n.Id.StartsWith("code", StringComparison.InvariantCulture) &&
 | 
					 | 
				
			||||||
            !n.Id.Contains("embed", StringComparison.InvariantCulture))
 | 
					 | 
				
			||||||
            .Select(n => new Provider.Model(n.Id, null));
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					 | 
				
			||||||
    #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
 | 
					 | 
				
			||||||
    /// <inheritdoc />
 | 
					 | 
				
			||||||
    public Task<IEnumerable<Provider.Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
 | 
					 | 
				
			||||||
    {
 | 
					 | 
				
			||||||
        return Task.FromResult(Enumerable.Empty<Provider.Model>());
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    #endregion
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -16,6 +16,8 @@ public class NoProvider : IProvider
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) => Task.FromResult<IEnumerable<Model>>([]);
 | 
					    public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) => Task.FromResult<IEnumerable<Model>>([]);
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
 | 
					    public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) => Task.FromResult<IEnumerable<Model>>([]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public async IAsyncEnumerable<string> StreamChatCompletion(Model chatModel, ChatThread chatChatThread, [EnumeratorCancellation] CancellationToken token = default)
 | 
					    public async IAsyncEnumerable<string> StreamChatCompletion(Model chatModel, ChatThread chatChatThread, [EnumeratorCancellation] CancellationToken token = default)
 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
        await Task.FromResult(0);
 | 
					        await Task.FromResult(0);
 | 
				
			||||||
 | 
				
			|||||||
@ -162,6 +162,12 @@ public sealed class ProviderOpenAI(ILogger logger) : BaseProvider("https://api.o
 | 
				
			|||||||
        return this.LoadModels(["dall-e-"], token, apiKeyProvisional);
 | 
					        return this.LoadModels(["dall-e-"], token, apiKeyProvisional);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
 | 
					    /// <inheritdoc />
 | 
				
			||||||
 | 
					    public Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        return this.LoadModels(["text-embedding-"], token, apiKeyProvisional);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #endregion
 | 
					    #endregion
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    private async Task<IEnumerable<Model>> LoadModels(string[] prefixes, CancellationToken token, string? apiKeyProvisional = null)
 | 
					    private async Task<IEnumerable<Model>> LoadModels(string[] prefixes, CancellationToken token, string? apiKeyProvisional = null)
 | 
				
			||||||
 | 
				
			|||||||
@ -200,13 +200,16 @@ public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provide
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
 | 
					 | 
				
			||||||
    /// <inheritdoc />
 | 
					    /// <inheritdoc />
 | 
				
			||||||
    public Task<IEnumerable<Provider.Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
 | 
					    public Task<IEnumerable<Provider.Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
        return Task.FromResult(Enumerable.Empty<Provider.Model>());
 | 
					        return Task.FromResult(Enumerable.Empty<Provider.Model>());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
 | 
					
 | 
				
			||||||
 | 
					    public Task<IEnumerable<Provider.Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        return Task.FromResult(Enumerable.Empty<Provider.Model>());
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #endregion
 | 
					    #endregion
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user