From 826d15729b50d11ca754e0f790dd82234aec13b1 Mon Sep 17 00:00:00 2001 From: Thorsten Sommer Date: Sat, 9 Nov 2024 22:04:00 +0100 Subject: [PATCH] Added the Google provider (#201) --- .../Dialogs/ProviderDialog.razor.cs | 5 +- .../Pages/Settings.razor.cs | 2 + .../Provider/Anthropic/ProviderAnthropic.cs | 18 +- .../Provider/Google/ChatRequest.cs | 15 ++ .../Provider/Google/Model.cs | 3 + .../Provider/Google/ModelsResponse.cs | 7 + .../Provider/Google/ProviderGoogle.cs | 188 ++++++++++++++++++ .../Provider/LLMProviders.cs | 1 + .../Provider/LLMProvidersExtensions.cs | 5 + .../Provider/Mistral/ProviderMistral.cs | 2 +- app/MindWork AI Studio/Provider/Model.cs | 26 ++- .../Provider/SelfHosted/ProviderSelfHosted.cs | 4 +- .../wwwroot/changelog/v0.9.18.md | 4 +- 13 files changed, 264 insertions(+), 16 deletions(-) create mode 100644 app/MindWork AI Studio/Provider/Google/ChatRequest.cs create mode 100644 app/MindWork AI Studio/Provider/Google/Model.cs create mode 100644 app/MindWork AI Studio/Provider/Google/ModelsResponse.cs create mode 100644 app/MindWork AI Studio/Provider/Google/ProviderGoogle.cs diff --git a/app/MindWork AI Studio/Dialogs/ProviderDialog.razor.cs b/app/MindWork AI Studio/Dialogs/ProviderDialog.razor.cs index 80090a2..c791606 100644 --- a/app/MindWork AI Studio/Dialogs/ProviderDialog.razor.cs +++ b/app/MindWork AI Studio/Dialogs/ProviderDialog.razor.cs @@ -105,7 +105,7 @@ public partial class ProviderDialog : ComponentBase Id = this.DataId, InstanceName = this.DataInstanceName, UsedLLMProvider = this.DataLLMProvider, - Model = this.DataLLMProvider is LLMProviders.FIREWORKS ? new Model(this.dataManuallyModel) : this.DataModel, + Model = this.DataLLMProvider is LLMProviders.FIREWORKS ? new Model(this.dataManuallyModel, null) : this.DataModel, IsSelfHosted = this.DataLLMProvider is LLMProviders.SELF_HOSTED, Hostname = this.DataHostname.EndsWith('/') ? 
this.DataHostname[..^1] : this.DataHostname, Host = this.DataHost, @@ -357,6 +357,7 @@ public partial class ProviderDialog : ComponentBase LLMProviders.OPEN_AI => true, LLMProviders.MISTRAL => true, LLMProviders.ANTHROPIC => true, + LLMProviders.GOOGLE => true, LLMProviders.GROQ => true, LLMProviders.FIREWORKS => true, @@ -369,6 +370,7 @@ public partial class ProviderDialog : ComponentBase LLMProviders.OPEN_AI => true, LLMProviders.MISTRAL => true, LLMProviders.ANTHROPIC => true, + LLMProviders.GOOGLE => true, LLMProviders.GROQ => true, LLMProviders.FIREWORKS => true, @@ -414,6 +416,7 @@ public partial class ProviderDialog : ComponentBase LLMProviders.OPEN_AI => "https://platform.openai.com/signup", LLMProviders.MISTRAL => "https://console.mistral.ai/", LLMProviders.ANTHROPIC => "https://console.anthropic.com/dashboard", + LLMProviders.GOOGLE => "https://console.cloud.google.com/", LLMProviders.GROQ => "https://console.groq.com/", LLMProviders.FIREWORKS => "https://fireworks.ai/login", diff --git a/app/MindWork AI Studio/Pages/Settings.razor.cs b/app/MindWork AI Studio/Pages/Settings.razor.cs index 58298a0..8a2e109 100644 --- a/app/MindWork AI Studio/Pages/Settings.razor.cs +++ b/app/MindWork AI Studio/Pages/Settings.razor.cs @@ -133,6 +133,7 @@ public partial class Settings : ComponentBase, IMessageBusReceiver, IDisposable LLMProviders.ANTHROPIC => true, LLMProviders.GROQ => true, LLMProviders.FIREWORKS => true, + LLMProviders.GOOGLE => true, _ => false, }; @@ -143,6 +144,7 @@ public partial class Settings : ComponentBase, IMessageBusReceiver, IDisposable LLMProviders.MISTRAL => "https://console.mistral.ai/usage/", LLMProviders.ANTHROPIC => "https://console.anthropic.com/settings/plans", LLMProviders.GROQ => "https://console.groq.com/settings/usage", + LLMProviders.GOOGLE => "https://console.cloud.google.com/billing", LLMProviders.FIREWORKS => "https://fireworks.ai/account/billing", _ => string.Empty, diff --git a/app/MindWork AI 
Studio/Provider/Anthropic/ProviderAnthropic.cs b/app/MindWork AI Studio/Provider/Anthropic/ProviderAnthropic.cs index 04c0f5c..8c7c9d7 100644 --- a/app/MindWork AI Studio/Provider/Anthropic/ProviderAnthropic.cs +++ b/app/MindWork AI Studio/Provider/Anthropic/ProviderAnthropic.cs @@ -147,18 +147,18 @@ public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://ap { return Task.FromResult(new[] { - new Model("claude-3-5-sonnet-latest"), - new Model("claude-3-5-sonnet-20240620"), - new Model("claude-3-5-sonnet-20241022"), + new Model("claude-3-5-sonnet-latest", "Claude 3.5 Sonnet (latest)"), + new Model("claude-3-5-sonnet-20240620", "Claude 3.5 Sonnet (20. June 2024)"), + new Model("claude-3-5-sonnet-20241022", "Claude 3.5 Sonnet (22. October 2024)"), - new Model("claude-3-5-haiku-latest"), - new Model("claude-3-5-heiku-20241022"), + new Model("claude-3-5-haiku-latest", "Claude 3.5 Haiku (latest)"), + new Model("claude-3-5-haiku-20241022", "Claude 3.5 Haiku (22. October 2024)"), - new Model("claude-3-opus-20240229"), - new Model("claude-3-opus-latest"), + new Model("claude-3-opus-20240229", "Claude 3.0 Opus (29. February 2024)"), + new Model("claude-3-opus-latest", "Claude 3.0 Opus (latest)"), - new Model("claude-3-sonnet-20240229"), - new Model("claude-3-haiku-20240307"), + new Model("claude-3-sonnet-20240229", "Claude 3.0 Sonnet (29. February 2024)"), + new Model("claude-3-haiku-20240307", "Claude 3.0 Haiku (7. March 2024)"), }.AsEnumerable()); } diff --git a/app/MindWork AI Studio/Provider/Google/ChatRequest.cs b/app/MindWork AI Studio/Provider/Google/ChatRequest.cs new file mode 100644 index 0000000..36b4abd --- /dev/null +++ b/app/MindWork AI Studio/Provider/Google/ChatRequest.cs @@ -0,0 +1,15 @@ +using AIStudio.Provider.OpenAI; + +namespace AIStudio.Provider.Google; + +/// +/// The Google chat request model. +/// +/// Which model to use for chat completion. +/// The chat messages. +/// Whether to stream the chat completion. 
+public readonly record struct ChatRequest( + string Model, + IList Messages, + bool Stream +); \ No newline at end of file diff --git a/app/MindWork AI Studio/Provider/Google/Model.cs b/app/MindWork AI Studio/Provider/Google/Model.cs new file mode 100644 index 0000000..f1a5328 --- /dev/null +++ b/app/MindWork AI Studio/Provider/Google/Model.cs @@ -0,0 +1,3 @@ +namespace AIStudio.Provider.Google; + +public readonly record struct Model(string Name, string DisplayName); \ No newline at end of file diff --git a/app/MindWork AI Studio/Provider/Google/ModelsResponse.cs b/app/MindWork AI Studio/Provider/Google/ModelsResponse.cs new file mode 100644 index 0000000..01cb81f --- /dev/null +++ b/app/MindWork AI Studio/Provider/Google/ModelsResponse.cs @@ -0,0 +1,7 @@ +namespace AIStudio.Provider.Google; + +/// +/// A data model for the response from the model endpoint. +/// +/// +public readonly record struct ModelsResponse(IList Models); \ No newline at end of file diff --git a/app/MindWork AI Studio/Provider/Google/ProviderGoogle.cs b/app/MindWork AI Studio/Provider/Google/ProviderGoogle.cs new file mode 100644 index 0000000..9e523dd --- /dev/null +++ b/app/MindWork AI Studio/Provider/Google/ProviderGoogle.cs @@ -0,0 +1,188 @@ +using System.Net.Http.Headers; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; + +using AIStudio.Chat; +using AIStudio.Provider.OpenAI; + +namespace AIStudio.Provider.Google; + +public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativelanguage.googleapis.com/v1beta/", logger), IProvider +{ + private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new() + { + PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower, + }; + + #region Implementation of IProvider + + /// + public string Id => "Google"; + + /// + public string InstanceName { get; set; } = "Google Gemini"; + + /// + public async IAsyncEnumerable StreamChatCompletion(Provider.Model chatModel, ChatThread chatThread, 
[EnumeratorCancellation] CancellationToken token = default) + { + // Get the API key: + var requestedSecret = await RUST_SERVICE.GetAPIKey(this); + if(!requestedSecret.Success) + yield break; + + // Prepare the system prompt: + var systemPrompt = new Message + { + Role = "system", + Content = chatThread.SystemPrompt, + }; + + // Prepare the Google HTTP chat request: + var geminiChatRequest = JsonSerializer.Serialize(new ChatRequest + { + Model = chatModel.Id, + + // Build the messages: + // - First of all the system prompt + // - Then non-empty user and AI messages + Messages = [systemPrompt, ..chatThread.Blocks.Where(n => n.ContentType is ContentType.TEXT && !string.IsNullOrWhiteSpace((n.Content as ContentText)?.Text)).Select(n => new Message + { + Role = n.Role switch + { + ChatRole.USER => "user", + ChatRole.AI => "assistant", + ChatRole.AGENT => "assistant", + ChatRole.SYSTEM => "system", + + _ => "user", + }, + + Content = n.Content switch + { + ContentText text => text.Text, + _ => string.Empty, + } + }).ToList()], + + // Right now, we only support streaming completions: + Stream = true, + }, JSON_SERIALIZER_OPTIONS); + + // Build the HTTP post request: + var request = new HttpRequestMessage(HttpMethod.Post, "openai/chat/completions"); + + // Set the authorization header: + request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", await requestedSecret.Secret.Decrypt(ENCRYPTION)); + + // Set the content: + request.Content = new StringContent(geminiChatRequest, Encoding.UTF8, "application/json"); + + // Send the request with the ResponseHeadersRead option. + // This allows us to read the stream as soon as the headers are received. + // This is important because we want to stream the responses. 
+ var response = await this.httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, token); + + // Open the response stream: + var geminiStream = await response.Content.ReadAsStreamAsync(token); + + // Add a stream reader to read the stream, line by line: + var streamReader = new StreamReader(geminiStream); + + // Read the stream, line by line: + while(!streamReader.EndOfStream) + { + // Check if the token is canceled: + if(token.IsCancellationRequested) + yield break; + + // Read the next line: + var line = await streamReader.ReadLineAsync(token); + + // Skip empty lines: + if(string.IsNullOrWhiteSpace(line)) + continue; + + // Skip lines that do not start with "data: ". According + // to the specification, we only want to read the data lines: + if(!line.StartsWith("data: ", StringComparison.InvariantCulture)) + continue; + + // Check if the line is the end of the stream: + if (line.StartsWith("data: [DONE]", StringComparison.InvariantCulture)) + yield break; + + ResponseStreamLine geminiResponse; + try + { + // We know that the line starts with "data: ". Hence, we can + // skip the first 6 characters to get the JSON data after that. 
+ var jsonData = line[6..]; + + // Deserialize the JSON data: + geminiResponse = JsonSerializer.Deserialize(jsonData, JSON_SERIALIZER_OPTIONS); + } + catch + { + // Skip invalid JSON data: + continue; + } + + // Skip empty responses: + if(geminiResponse == default || geminiResponse.Choices.Count == 0) + continue; + + // Yield the response: + yield return geminiResponse.Choices[0].Delta.Content; + } + } + + #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + /// + public async IAsyncEnumerable StreamImageCompletion(Provider.Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) + { + yield break; + } + #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously + + /// + public Task> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) + { + return this.LoadModels(token, apiKeyProvisional); + } + + /// + public Task> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) + { + return Task.FromResult(Enumerable.Empty()); + } + + #endregion + + private async Task> LoadModels(CancellationToken token, string? 
apiKeyProvisional = null) + { + var secretKey = apiKeyProvisional switch + { + not null => apiKeyProvisional, + _ => await RUST_SERVICE.GetAPIKey(this) switch + { + { Success: true } result => await result.Secret.Decrypt(ENCRYPTION), + _ => null, + } + }; + + if (secretKey is null) + return []; + + var request = new HttpRequestMessage(HttpMethod.Get, $"models?key={secretKey}"); + var response = await this.httpClient.SendAsync(request, token); + + if(!response.IsSuccessStatusCode) + return []; + + var modelResponse = await response.Content.ReadFromJsonAsync(token); + return modelResponse.Models.Where(model => + model.Name.StartsWith("models/gemini-", StringComparison.InvariantCultureIgnoreCase)) + .Select(n => new Provider.Model(n.Name.Replace("models/", string.Empty), n.DisplayName)); + } +} \ No newline at end of file diff --git a/app/MindWork AI Studio/Provider/LLMProviders.cs b/app/MindWork AI Studio/Provider/LLMProviders.cs index 375ca6e..24657cc 100644 --- a/app/MindWork AI Studio/Provider/LLMProviders.cs +++ b/app/MindWork AI Studio/Provider/LLMProviders.cs @@ -10,6 +10,7 @@ public enum LLMProviders OPEN_AI = 1, ANTHROPIC = 2, MISTRAL = 3, + GOOGLE = 7, FIREWORKS = 5, GROQ = 6, diff --git a/app/MindWork AI Studio/Provider/LLMProvidersExtensions.cs b/app/MindWork AI Studio/Provider/LLMProvidersExtensions.cs index f83c9c3..2dc84b5 100644 --- a/app/MindWork AI Studio/Provider/LLMProvidersExtensions.cs +++ b/app/MindWork AI Studio/Provider/LLMProvidersExtensions.cs @@ -1,5 +1,6 @@ using AIStudio.Provider.Anthropic; using AIStudio.Provider.Fireworks; +using AIStudio.Provider.Google; using AIStudio.Provider.Groq; using AIStudio.Provider.Mistral; using AIStudio.Provider.OpenAI; @@ -22,6 +23,7 @@ public static class LLMProvidersExtensions LLMProviders.OPEN_AI => "OpenAI", LLMProviders.ANTHROPIC => "Anthropic", LLMProviders.MISTRAL => "Mistral", + LLMProviders.GOOGLE => "Google", LLMProviders.GROQ => "Groq", LLMProviders.FIREWORKS => "Fireworks.ai", @@ -50,6 +52,8 @@ 
public static class LLMProvidersExtensions "https://openai.com/enterprise-privacy/" ).WithLevel(settingsManager.GetConfiguredConfidenceLevel(llmProvider)), + LLMProviders.GOOGLE => Confidence.USA_NO_TRAINING.WithRegion("America, U.S.").WithSources("https://ai.google.dev/gemini-api/terms").WithLevel(settingsManager.GetConfiguredConfidenceLevel(llmProvider)), + LLMProviders.GROQ => Confidence.USA_NO_TRAINING.WithRegion("America, U.S.").WithSources("https://wow.groq.com/terms-of-use/").WithLevel(settingsManager.GetConfiguredConfidenceLevel(llmProvider)), LLMProviders.ANTHROPIC => Confidence.USA_NO_TRAINING.WithRegion("America, U.S.").WithSources("https://www.anthropic.com/legal/commercial-terms").WithLevel(settingsManager.GetConfiguredConfidenceLevel(llmProvider)), @@ -76,6 +80,7 @@ public static class LLMProvidersExtensions LLMProviders.OPEN_AI => new ProviderOpenAI(logger) { InstanceName = providerSettings.InstanceName }, LLMProviders.ANTHROPIC => new ProviderAnthropic(logger) { InstanceName = providerSettings.InstanceName }, LLMProviders.MISTRAL => new ProviderMistral(logger) { InstanceName = providerSettings.InstanceName }, + LLMProviders.GOOGLE => new ProviderGoogle(logger) { InstanceName = providerSettings.InstanceName }, LLMProviders.GROQ => new ProviderGroq(logger) { InstanceName = providerSettings.InstanceName }, LLMProviders.FIREWORKS => new ProviderFireworks(logger) { InstanceName = providerSettings.InstanceName }, diff --git a/app/MindWork AI Studio/Provider/Mistral/ProviderMistral.cs b/app/MindWork AI Studio/Provider/Mistral/ProviderMistral.cs index 86d94e8..23296669 100644 --- a/app/MindWork AI Studio/Provider/Mistral/ProviderMistral.cs +++ b/app/MindWork AI Studio/Provider/Mistral/ProviderMistral.cs @@ -173,7 +173,7 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api. 
return modelResponse.Data.Where(n => !n.Id.StartsWith("code", StringComparison.InvariantCulture) && !n.Id.Contains("embed", StringComparison.InvariantCulture)) - .Select(n => new Provider.Model(n.Id)); + .Select(n => new Provider.Model(n.Id, null)); } #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously diff --git a/app/MindWork AI Studio/Provider/Model.cs b/app/MindWork AI Studio/Provider/Model.cs index af39709..ff66933 100644 --- a/app/MindWork AI Studio/Provider/Model.cs +++ b/app/MindWork AI Studio/Provider/Model.cs @@ -4,11 +4,33 @@ namespace AIStudio.Provider; /// The data model for the model to use. /// /// The model's ID. -public readonly record struct Model(string Id) +/// The model's display name. +public readonly record struct Model(string Id, string? DisplayName) { #region Overrides of ValueType - public override string ToString() => string.IsNullOrWhiteSpace(this.Id) ? "no model selected" : this.Id; + public override string ToString() + { + if(!string.IsNullOrWhiteSpace(this.DisplayName)) + return this.DisplayName; + + if(!string.IsNullOrWhiteSpace(this.Id)) + return this.Id; + + return "no model selected"; + } + + #endregion + + #region Implementation of IEquatable + + public bool Equals(Model? other) + { + if(other is null) + return false; + + return this.Id == other.Value.Id; + } #endregion } \ No newline at end of file diff --git a/app/MindWork AI Studio/Provider/SelfHosted/ProviderSelfHosted.cs b/app/MindWork AI Studio/Provider/SelfHosted/ProviderSelfHosted.cs index 9c595da..f50b34c 100644 --- a/app/MindWork AI Studio/Provider/SelfHosted/ProviderSelfHosted.cs +++ b/app/MindWork AI Studio/Provider/SelfHosted/ProviderSelfHosted.cs @@ -164,7 +164,7 @@ public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provide case Host.LLAMACPP: // Right now, llama.cpp only supports one model. // There is no API to list the model(s). 
- return [ new Provider.Model("as configured by llama.cpp") ]; + return [ new Provider.Model("as configured by llama.cpp", null) ]; case Host.LM_STUDIO: case Host.OLLAMA: @@ -188,7 +188,7 @@ public sealed class ProviderSelfHosted(ILogger logger, Settings.Provider provide return []; var lmStudioModelResponse = await lmStudioResponse.Content.ReadFromJsonAsync(token); - return lmStudioModelResponse.Data.Select(n => new Provider.Model(n.Id)); + return lmStudioModelResponse.Data.Select(n => new Provider.Model(n.Id, null)); } return []; diff --git a/app/MindWork AI Studio/wwwroot/changelog/v0.9.18.md b/app/MindWork AI Studio/wwwroot/changelog/v0.9.18.md index bf26d20..1d51270 100644 --- a/app/MindWork AI Studio/wwwroot/changelog/v0.9.18.md +++ b/app/MindWork AI Studio/wwwroot/changelog/v0.9.18.md @@ -1,3 +1,5 @@ # v0.9.18, build 193 (2024-11-xx xx:xx UTC) - Added new Anthropic model `claude-3-5-heiku-20241022` as well as the alias `claude-3-5-heiku-latest`. -- Added [Groq](https://console.groq.com/) as a new provider option. \ No newline at end of file +- Added [Groq](https://console.groq.com/) as a provider option. +- Added Google Gemini as a provider option. +- Improved model display while configuring a provider when the provider supports display names. \ No newline at end of file