Added Groq provider (#200)

Thorsten Sommer 2024-11-09 20:13:14 +01:00 committed by GitHub
parent 119100bbce
commit 2ae40c59a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 224 additions and 2 deletions

View File

@@ -3,5 +3,6 @@
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=LLM/@EntryIndexedValue">LLM</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=LM/@EntryIndexedValue">LM</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=MSG/@EntryIndexedValue">MSG</s:String>
<s:Boolean x:Key="/Default/UserDictionary/Words/=groq/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=ollama/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=tauri_0027s/@EntryIndexedValue">True</s:Boolean></wpf:ResourceDictionary>

View File

@@ -358,6 +358,7 @@ public partial class ProviderDialog : ComponentBase
LLMProviders.MISTRAL => true,
LLMProviders.ANTHROPIC => true,
LLMProviders.GROQ => true,
LLMProviders.FIREWORKS => true,
_ => false,
@@ -369,7 +370,9 @@ public partial class ProviderDialog : ComponentBase
LLMProviders.MISTRAL => true,
LLMProviders.ANTHROPIC => true,
LLMProviders.GROQ => true,
LLMProviders.FIREWORKS => true,
LLMProviders.SELF_HOSTED => this.DataHost is Host.OLLAMA,
_ => false,
@@ -411,7 +414,8 @@ public partial class ProviderDialog : ComponentBase
LLMProviders.OPEN_AI => "https://platform.openai.com/signup",
LLMProviders.MISTRAL => "https://console.mistral.ai/",
LLMProviders.ANTHROPIC => "https://console.anthropic.com/dashboard",
LLMProviders.GROQ => "https://console.groq.com/",
LLMProviders.FIREWORKS => "https://fireworks.ai/login",
_ => string.Empty,

View File

@@ -131,6 +131,7 @@ public partial class Settings : ComponentBase, IMessageBusReceiver, IDisposable
LLMProviders.OPEN_AI => true,
LLMProviders.MISTRAL => true,
LLMProviders.ANTHROPIC => true,
LLMProviders.GROQ => true,
LLMProviders.FIREWORKS => true,
_ => false,
@@ -141,6 +142,7 @@ public partial class Settings : ComponentBase, IMessageBusReceiver, IDisposable
LLMProviders.OPEN_AI => "https://platform.openai.com/usage",
LLMProviders.MISTRAL => "https://console.mistral.ai/usage/",
LLMProviders.ANTHROPIC => "https://console.anthropic.com/settings/plans",
LLMProviders.GROQ => "https://console.groq.com/settings/usage",
LLMProviders.FIREWORKS => "https://fireworks.ai/account/billing",
_ => string.Empty,

View File

@@ -0,0 +1,17 @@
using AIStudio.Provider.OpenAI;

namespace AIStudio.Provider.Groq;

/// <summary>
/// The Groq chat request model.
/// </summary>
/// <param name="Model">Which model to use for chat completion.</param>
/// <param name="Messages">The chat messages.</param>
/// <param name="Stream">Whether to stream the chat completion.</param>
/// <param name="Seed">The seed for the chat completion.</param>
public readonly record struct ChatRequest(
    string Model,
    IList<Message> Messages,
    bool Stream,
    int Seed
);
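
For context: the new ProviderGroq (next file in this commit) serializes this record with JsonNamingPolicy.SnakeCaseLower, which produces the snake_case JSON body that Groq's OpenAI-compatible chat endpoint expects. Below is a minimal, self-contained sketch of that serialization; the Message stand-in, the demo class, and the model id are illustrative placeholders, not part of this commit.

using System;
using System.Collections.Generic;
using System.Text.Json;

// Stand-in for AIStudio.Provider.OpenAI.Message, only for this sketch.
public readonly record struct Message(string Role, string Content);

public readonly record struct ChatRequest(
    string Model,
    IList<Message> Messages,
    bool Stream,
    int Seed
);

public static class ChatRequestDemo
{
    public static void Main()
    {
        var options = new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower };

        var request = new ChatRequest(
            Model: "llama-3.1-8b-instant", // placeholder model id
            Messages: [ new Message("system", "You are a helpful assistant.") ],
            Stream: true,
            Seed: 42
        );

        // Prints: {"model":"llama-3.1-8b-instant","messages":[{"role":"system","content":"You are a helpful assistant."}],"stream":true,"seed":42}
        Console.WriteLine(JsonSerializer.Serialize(request, options));
    }
}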

View File

@@ -0,0 +1,191 @@
using System.Net.Http.Headers;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;

using AIStudio.Chat;
using AIStudio.Provider.OpenAI;

namespace AIStudio.Provider.Groq;

public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/openai/v1/", logger), IProvider
{
    private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
    {
        PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
    };

    #region Implementation of IProvider

    /// <inheritdoc />
    public string Id => "Groq";

    /// <inheritdoc />
    public string InstanceName { get; set; } = "Groq";

    /// <inheritdoc />
    public async IAsyncEnumerable<string> StreamChatCompletion(Model chatModel, ChatThread chatThread, [EnumeratorCancellation] CancellationToken token = default)
    {
        // Get the API key:
        var requestedSecret = await RUST_SERVICE.GetAPIKey(this);
        if(!requestedSecret.Success)
            yield break;

        // Prepare the system prompt:
        var systemPrompt = new Message
        {
            Role = "system",
            Content = chatThread.SystemPrompt,
        };

        // Prepare the Groq HTTP chat request (OpenAI-compatible format):
        var groqChatRequest = JsonSerializer.Serialize(new ChatRequest
        {
            Model = chatModel.Id,

            // Build the messages:
            // - First of all the system prompt
            // - Then non-empty user and AI messages
            Messages = [systemPrompt, ..chatThread.Blocks.Where(n => n.ContentType is ContentType.TEXT && !string.IsNullOrWhiteSpace((n.Content as ContentText)?.Text)).Select(n => new Message
            {
                Role = n.Role switch
                {
                    ChatRole.USER => "user",
                    ChatRole.AI => "assistant",
                    ChatRole.AGENT => "assistant",
                    ChatRole.SYSTEM => "system",

                    _ => "user",
                },

                Content = n.Content switch
                {
                    ContentText text => text.Text,
                    _ => string.Empty,
                }
            }).ToList()],

            Seed = chatThread.Seed,

            // Right now, we only support streaming completions:
            Stream = true,
        }, JSON_SERIALIZER_OPTIONS);

        // Build the HTTP post request:
        var request = new HttpRequestMessage(HttpMethod.Post, "chat/completions");

        // Set the authorization header:
        request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", await requestedSecret.Secret.Decrypt(ENCRYPTION));

        // Set the content:
        request.Content = new StringContent(groqChatRequest, Encoding.UTF8, "application/json");

        // Send the request with the ResponseHeadersRead option.
        // This allows us to read the stream as soon as the headers are received.
        // This is important because we want to stream the responses.
        var response = await this.httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, token);

        // Open the response stream:
        var groqStream = await response.Content.ReadAsStreamAsync(token);

        // Add a stream reader to read the stream, line by line:
        var streamReader = new StreamReader(groqStream);

        // Read the stream, line by line:
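        // Illustrative shape of the server-sent-event lines parsed below
        // (OpenAI-compatible streaming format; example values, not captured output):
        //
        //   data: {"choices":[{"delta":{"content":"Hello"},"index":0}]}
        //   data: [DONE]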
        while(!streamReader.EndOfStream)
        {
            // Check if the token is canceled:
            if(token.IsCancellationRequested)
                yield break;

            // Read the next line:
            var line = await streamReader.ReadLineAsync(token);

            // Skip empty lines:
            if(string.IsNullOrWhiteSpace(line))
                continue;

            // Skip lines that do not start with "data: ". According
            // to the specification, we only want to read the data lines:
            if(!line.StartsWith("data: ", StringComparison.InvariantCulture))
                continue;

            // Check if the line is the end of the stream:
            if (line.StartsWith("data: [DONE]", StringComparison.InvariantCulture))
                yield break;

            ResponseStreamLine groqResponse;
            try
            {
                // We know that the line starts with "data: ". Hence, we can
                // skip the first 6 characters to get the JSON data after that.
                var jsonData = line[6..];

                // Deserialize the JSON data:
                groqResponse = JsonSerializer.Deserialize<ResponseStreamLine>(jsonData, JSON_SERIALIZER_OPTIONS);
            }
            catch
            {
                // Skip invalid JSON data:
                continue;
            }

            // Skip empty responses:
            if(groqResponse == default || groqResponse.Choices.Count == 0)
                continue;

            // Yield the response:
            yield return groqResponse.Choices[0].Delta.Content;
        }
    }
    #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
    /// <inheritdoc />
    public async IAsyncEnumerable<ImageURL> StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default)
    {
        yield break;
    }
    #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously

    /// <inheritdoc />
    public Task<IEnumerable<Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
    {
        return this.LoadModels(token, apiKeyProvisional);
    }

    /// <inheritdoc />
    public Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
    {
        return Task.FromResult<IEnumerable<Model>>(Array.Empty<Model>());
    }

    #endregion

    private async Task<IEnumerable<Model>> LoadModels(CancellationToken token, string? apiKeyProvisional = null)
    {
        var secretKey = apiKeyProvisional switch
        {
            not null => apiKeyProvisional,
            _ => await RUST_SERVICE.GetAPIKey(this) switch
            {
                { Success: true } result => await result.Secret.Decrypt(ENCRYPTION),
                _ => null,
            }
        };

        if (secretKey is null)
            return [];

        var request = new HttpRequestMessage(HttpMethod.Get, "models");
        request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", secretKey);

        var response = await this.httpClient.SendAsync(request, token);
        if(!response.IsSuccessStatusCode)
            return [];

        var modelResponse = await response.Content.ReadFromJsonAsync<ModelsResponse>(token);
        return modelResponse.Data.Where(n =>
            !n.Id.StartsWith("whisper-", StringComparison.InvariantCultureIgnoreCase) &&
            !n.Id.StartsWith("distil-", StringComparison.InvariantCultureIgnoreCase));
    }
}
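
A brief consumption sketch of the streaming API above: the constructor and StreamChatCompletion signature match this diff, while logger, chatModel, chatThread, cancellationToken, and the instance name are assumed to be set up elsewhere and are purely illustrative.

// Minimal usage sketch; logger, chatModel, chatThread, and cancellationToken
// are assumed to exist and be configured elsewhere in the application.
var groq = new ProviderGroq(logger) { InstanceName = "My Groq" };

await foreach (var chunk in groq.StreamChatCompletion(chatModel, chatThread, cancellationToken))
{
    // Each yielded string is one streamed text delta of the assistant's answer.
    Console.Write(chunk);
}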

View File

@@ -12,6 +12,7 @@ public enum LLMProviders
MISTRAL = 3,
FIREWORKS = 5,
GROQ = 6,
SELF_HOSTED = 4,
}

View File

@ -1,5 +1,6 @@
using AIStudio.Provider.Anthropic; using AIStudio.Provider.Anthropic;
using AIStudio.Provider.Fireworks; using AIStudio.Provider.Fireworks;
using AIStudio.Provider.Groq;
using AIStudio.Provider.Mistral; using AIStudio.Provider.Mistral;
using AIStudio.Provider.OpenAI; using AIStudio.Provider.OpenAI;
using AIStudio.Provider.SelfHosted; using AIStudio.Provider.SelfHosted;
@ -22,6 +23,7 @@ public static class LLMProvidersExtensions
LLMProviders.ANTHROPIC => "Anthropic", LLMProviders.ANTHROPIC => "Anthropic",
LLMProviders.MISTRAL => "Mistral", LLMProviders.MISTRAL => "Mistral",
LLMProviders.GROQ => "Groq",
LLMProviders.FIREWORKS => "Fireworks.ai", LLMProviders.FIREWORKS => "Fireworks.ai",
LLMProviders.SELF_HOSTED => "Self-hosted", LLMProviders.SELF_HOSTED => "Self-hosted",
@ -48,6 +50,8 @@ public static class LLMProvidersExtensions
"https://openai.com/enterprise-privacy/" "https://openai.com/enterprise-privacy/"
).WithLevel(settingsManager.GetConfiguredConfidenceLevel(llmProvider)), ).WithLevel(settingsManager.GetConfiguredConfidenceLevel(llmProvider)),
LLMProviders.GROQ => Confidence.USA_NO_TRAINING.WithRegion("America, U.S.").WithSources("https://wow.groq.com/terms-of-use/").WithLevel(settingsManager.GetConfiguredConfidenceLevel(llmProvider)),
LLMProviders.ANTHROPIC => Confidence.USA_NO_TRAINING.WithRegion("America, U.S.").WithSources("https://www.anthropic.com/legal/commercial-terms").WithLevel(settingsManager.GetConfiguredConfidenceLevel(llmProvider)), LLMProviders.ANTHROPIC => Confidence.USA_NO_TRAINING.WithRegion("America, U.S.").WithSources("https://www.anthropic.com/legal/commercial-terms").WithLevel(settingsManager.GetConfiguredConfidenceLevel(llmProvider)),
LLMProviders.MISTRAL => Confidence.GDPR_NO_TRAINING.WithRegion("Europe, France").WithSources("https://mistral.ai/terms/#terms-of-service-la-plateforme").WithLevel(settingsManager.GetConfiguredConfidenceLevel(llmProvider)), LLMProviders.MISTRAL => Confidence.GDPR_NO_TRAINING.WithRegion("Europe, France").WithSources("https://mistral.ai/terms/#terms-of-service-la-plateforme").WithLevel(settingsManager.GetConfiguredConfidenceLevel(llmProvider)),
@ -73,6 +77,7 @@ public static class LLMProvidersExtensions
LLMProviders.ANTHROPIC => new ProviderAnthropic(logger) { InstanceName = providerSettings.InstanceName }, LLMProviders.ANTHROPIC => new ProviderAnthropic(logger) { InstanceName = providerSettings.InstanceName },
LLMProviders.MISTRAL => new ProviderMistral(logger) { InstanceName = providerSettings.InstanceName }, LLMProviders.MISTRAL => new ProviderMistral(logger) { InstanceName = providerSettings.InstanceName },
LLMProviders.GROQ => new ProviderGroq(logger) { InstanceName = providerSettings.InstanceName },
LLMProviders.FIREWORKS => new ProviderFireworks(logger) { InstanceName = providerSettings.InstanceName }, LLMProviders.FIREWORKS => new ProviderFireworks(logger) { InstanceName = providerSettings.InstanceName },
LLMProviders.SELF_HOSTED => new ProviderSelfHosted(logger, providerSettings) { InstanceName = providerSettings.InstanceName }, LLMProviders.SELF_HOSTED => new ProviderSelfHosted(logger, providerSettings) { InstanceName = providerSettings.InstanceName },

View File

@@ -1,2 +1,3 @@
# v0.9.18, build 193 (2024-11-xx xx:xx UTC)
- Added new Anthropic model `claude-3-5-haiku-20241022` as well as the alias `claude-3-5-haiku-latest`.
- Added [Groq](https://console.groq.com/) as a new provider option.