using System.Net;
using System.Net.Http.Headers;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using AIStudio.Chat;
using AIStudio.Provider.Anthropic;
using AIStudio.Provider.OpenAI;
using AIStudio.Provider.SelfHosted;
using AIStudio.Settings;
using AIStudio.Tools.ToolCallingSystem;
using AIStudio.Tools.MIME;
using AIStudio.Tools.PluginSystem;
using AIStudio.Tools.Rust;
using AIStudio.Tools.Services;
using Microsoft.Extensions.DependencyInjection;
using Host = AIStudio.Provider.SelfHosted.Host;
namespace AIStudio.Provider;
/// <summary>
/// The base class for all providers.
/// </summary>
public abstract class BaseProvider : IProvider, ISecretId
{
// Forwards the English fallback text to I18N.I.T, together with the namespace
// and the name of BaseProvider.
private static string TB(string fallbackEN)
{
    var providerType = typeof(BaseProvider);
    return I18N.I.T(fallbackEN, providerType.Namespace, nameof(BaseProvider));
}
/// <summary>
/// The HTTP client used for all requests of this provider.
/// </summary>
protected readonly HttpClient httpClient = new();
/// <summary>
/// The logger to use.
/// </summary>
private readonly ILogger logger;
// Static constructor: copies the Rust service and the encryption instance
// from the Program class into the static fields of this class.
static BaseProvider()
{
    (RUST_SERVICE, ENCRYPTION) = (Program.RUST_SERVICE, Program.ENCRYPTION);
}
/// <summary>
/// The Rust service; assigned from <c>Program.RUST_SERVICE</c> in the static constructor.
/// </summary>
protected static readonly RustService RUST_SERVICE;
/// <summary>
/// The encryption instance; assigned from <c>Program.ENCRYPTION</c> in the static constructor.
/// </summary>
protected static readonly Encryption ENCRYPTION;
/// <summary>
/// The JSON serializer options that the streaming methods of this class pass to
/// <c>JsonSerializer.Deserialize</c> when they parse the lines of a provider response.
/// </summary>
protected static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{
// Property names are mapped using the lower snake_case naming policy:
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
Converters =
{
// Enum values are converted as strings, using the lower snake_case naming policy:
new JsonStringEnumConverter(JsonNamingPolicy.SnakeCaseLower),
// Project-specific converters:
new AnnotationConverter(),
new MessageBaseConverter(),
new SubContentConverter(),
new SubContentImageSourceConverter(),
new SubContentImageUrlConverter(),
},
// Trailing commas in the JSON input are not accepted:
AllowTrailingCommas = false
};
/// <summary>
/// Creates the base provider and points its HTTP client at the provider's base URL.
/// </summary>
/// <param name="provider">The provider enum value.</param>
/// <param name="url">The base URL for the provider.</param>
/// <param name="logger">The logger to use.</param>
protected BaseProvider(LLMProviders provider, string url, ILogger logger)
{
    this.logger = logger;
    this.Provider = provider;

    // Use the given URL as the base address of the HTTP client:
    var baseAddress = new Uri(url);
    this.httpClient.BaseAddress = baseAddress;
}
#region Handling of IProvider, which all providers must implement
/// <inheritdoc />
public LLMProviders Provider { get; }
/// <inheritdoc />
public abstract string Id { get; }
/// <inheritdoc />
public abstract string InstanceName { get; set; }
/// <inheritdoc />
public string AdditionalJsonApiParameters { get; init; } = string.Empty;
/// <inheritdoc />
public abstract IAsyncEnumerable StreamChatCompletion(Model chatModel, ChatThread chatThread, SettingsManager settingsManager, CancellationToken token = default);
/// <inheritdoc />
// NOTE(review): the default value of promptNegative is FilterOperator.String.Empty, not
// string.Empty. Confirm that this constant is really the intended default.
public abstract IAsyncEnumerable StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, CancellationToken token = default);
/// <inheritdoc />
public abstract Task TranscribeAudioAsync(Model transcriptionModel, string audioFilePath, SettingsManager settingsManager, CancellationToken token = default);
/// <inheritdoc />
public abstract Task>> EmbedTextAsync(Model embeddingModel, SettingsManager settingsManager, CancellationToken token = default, params List texts);
/// <inheritdoc />
public abstract Task> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default);
/// <inheritdoc />
public abstract Task> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default);
/// <inheritdoc />
public abstract Task> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default);
/// <inheritdoc />
public abstract Task> GetTranscriptionModels(string? apiKeyProvisional = null, CancellationToken token = default);
#endregion
/// <summary>
/// Whether this provider was imported from an enterprise configuration plugin.
/// When true, <c>SecretId</c> is prefixed with <c>ISecretId.ENTERPRISE_KEY_PREFIX</c>.
/// </summary>
public bool IsEnterpriseConfiguration { get; init; }
#region Implementation of ISecretId

/// <summary>
/// The ID under which the secret of this provider is stored. For providers that come
/// from an enterprise configuration, the ID is prefixed with the enterprise key prefix.
/// </summary>
public string SecretId
{
    get
    {
        if (!this.IsEnterpriseConfiguration)
            return this.Id;

        return $"{ISecretId.ENTERPRISE_KEY_PREFIX}::{this.Id}";
    }
}

/// <summary>
/// The name of the secret, which is the instance name of this provider.
/// </summary>
public string SecretName
{
    get { return this.InstanceName; }
}

#endregion
/// <summary>
/// Sends a request and handles rate limiting by exponential backoff.
/// </summary>
/// <remarks>
/// <para>
/// The request is rebuilt through <paramref name="requestBuilder"/> for every attempt.
/// The status codes 400, 401, 403, 404, 500, and 503 are reported to the user right away
/// and are not retried. Every other non-success status code is retried after a growing
/// delay until the maximum number of attempts is reached.
/// </para>
/// <para>
/// The request counts as successful if, and only if, one attempt returned a success status
/// code; failures of earlier attempts do not turn a later success into a failure. On success,
/// the result carries the response, whose headers have been read but whose body has not.
/// The caller is responsible for disposing that response. The responses of failed attempts
/// are disposed by this method.
/// </para>
/// </remarks>
/// <param name="requestBuilder">A function that builds the request.</param>
/// <param name="token">The cancellation token.</param>
/// <returns>The status object of the request.</returns>
private async Task<HttpRateLimitedStreamResult> SendRequest(Func<Task<HttpRequestMessage>> requestBuilder, CancellationToken token = default)
{
    const int MAX_RETRIES = 6;
    const double RETRY_DELAY_SECONDS = 4;
    const double MAX_DELAY_SECONDS = 90;

    HttpResponseMessage? response = null;
    string? errorMessage = string.Empty;

    for (var attempt = 1; attempt <= MAX_RETRIES; attempt++)
    {
        using var request = await requestBuilder();

        //
        // Send the request with the ResponseHeadersRead option.
        // This allows us to read the stream as soon as the headers are received.
        // This is important because we want to stream the responses.
        //
        // Please notice: We do not dispose a successful response here. The caller is
        // responsible for disposing the response object. This is important because the
        // response object is used to read the stream.
        //
        var nextResponse = await this.httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, token);
        if (nextResponse.IsSuccessStatusCode)
        {
            response = nextResponse;
            break;
        }

        //
        // The attempt failed. Nobody else ever sees this response, so we take everything
        // we need out of it and dispose it right here:
        //
        HttpStatusCode statusCode;
        string errorBody;
        using (nextResponse)
        {
            statusCode = nextResponse.StatusCode;
            errorMessage = nextResponse.ReasonPhrase;
            errorBody = await nextResponse.Content.ReadAsStringAsync(token);
        }

        //
        // Status codes for which a retry is pointless get reported to the user immediately:
        //
        var isFatal = true;
        switch (statusCode)
        {
            case HttpStatusCode.Forbidden:
                await MessageBus.INSTANCE.SendError(new(Icons.Material.Filled.Block, string.Format(TB("We tried to communicate with the LLM provider '{0}' (type={1}). You might not be able to use this provider from your location. The provider message is: '{2}'"), this.InstanceName, this.Provider, errorMessage)));
                break;

            case HttpStatusCode.BadRequest:
                // Check if the error body contains "context" and "token" (case-insensitive),
                // which indicates that the context window is likely exceeded:
                if (errorBody.Contains("context", StringComparison.InvariantCultureIgnoreCase) &&
                    errorBody.Contains("token", StringComparison.InvariantCultureIgnoreCase))
                {
                    await MessageBus.INSTANCE.SendError(new(Icons.Material.Filled.CloudOff, string.Format(TB("We tried to communicate with the LLM provider '{0}' (type={1}). The data of the chat, including all file attachments, is probably too large for the selected model and provider. The provider message is: '{2}'"), this.InstanceName, this.Provider, errorMessage)));
                }
                else
                {
                    await MessageBus.INSTANCE.SendError(new(Icons.Material.Filled.CloudOff, string.Format(TB("We tried to communicate with the LLM provider '{0}' (type={1}). The required message format might be changed. The provider message is: '{2}'"), this.InstanceName, this.Provider, errorMessage)));
                }

                break;

            case HttpStatusCode.NotFound:
                await MessageBus.INSTANCE.SendError(new(Icons.Material.Filled.CloudOff, string.Format(TB("We tried to communicate with the LLM provider '{0}' (type={1}). Something was not found. The provider message is: '{2}'"), this.InstanceName, this.Provider, errorMessage)));
                break;

            case HttpStatusCode.Unauthorized:
                await MessageBus.INSTANCE.SendError(new(Icons.Material.Filled.Key, string.Format(TB("We tried to communicate with the LLM provider '{0}' (type={1}). The API key might be invalid. The provider message is: '{2}'"), this.InstanceName, this.Provider, errorMessage)));
                break;

            case HttpStatusCode.InternalServerError:
                await MessageBus.INSTANCE.SendError(new(Icons.Material.Filled.CloudOff, string.Format(TB("We tried to communicate with the LLM provider '{0}' (type={1}). The server might be down or having issues. The provider message is: '{2}'"), this.InstanceName, this.Provider, errorMessage)));
                break;

            case HttpStatusCode.ServiceUnavailable:
                await MessageBus.INSTANCE.SendError(new(Icons.Material.Filled.CloudOff, string.Format(TB("We tried to communicate with the LLM provider '{0}' (type={1}). The provider is overloaded. The message is: '{2}'"), this.InstanceName, this.Provider, errorMessage)));
                break;

            default:
                // Any other status code (e.g., rate limiting) is worth a retry:
                isFatal = false;
                break;
        }

        if (isFatal)
        {
            this.logger.LogError("Failed request with status code {ResponseStatusCode} (message = '{ResponseReasonPhrase}', error body = '{ErrorBody}').", statusCode, errorMessage, errorBody);
            break;
        }

        // Waiting after the last attempt would only postpone the error message:
        if (attempt == MAX_RETRIES)
        {
            this.logger.LogDebug("Failed request with status code {ResponseStatusCode} (message = '{ErrorMessage}'). No retries left.", statusCode, errorMessage);
            break;
        }

        // Exponential backoff, capped at the maximum delay:
        var timeSeconds = Math.Min(Math.Pow(RETRY_DELAY_SECONDS, attempt + 1), MAX_DELAY_SECONDS);
        this.logger.LogDebug("Failed request with status code {ResponseStatusCode} (message = '{ErrorMessage}'). Retrying in {TimeSeconds:0.00} seconds.", statusCode, errorMessage, timeSeconds);
        await Task.Delay(TimeSpan.FromSeconds(timeSeconds), token);
    }

    //
    // Only the presence of a successful response decides about success. Neither the number
    // of attempts nor the error message of an earlier, failed attempt is a reliable signal:
    // a request may succeed on the last attempt or after a transient failure.
    //
    // Please notice: as before, this summary is also sent after one of the fatal status
    // codes above was already reported to the user.
    //
    if (response is null)
    {
        await MessageBus.INSTANCE.SendError(new DataErrorMessage(Icons.Material.Filled.CloudOff, string.Format(TB("We tried to communicate with the LLM provider '{0}' (type={1}). Even after {2} retries, there were some problems with the request. The provider message is: '{3}'."), this.InstanceName, this.Provider, MAX_RETRIES, errorMessage)));
        return new HttpRateLimitedStreamResult(false, true, errorMessage ?? $"Failed after {MAX_RETRIES} retries; no provider message available", response);
    }

    return new HttpRateLimitedStreamResult(true, false, string.Empty, response);
}
/// <summary>
/// Streams the chat completion from the provider using the Chat Completion API.
/// </summary>
/// <remarks>
/// Only lines starting with <c>data: </c> are processed; the line <c>data: [DONE]</c> ends
/// the stream. Lines that cannot be deserialized are skipped. The HTTP response and the
/// stream reader are disposed on every exit path, including the case that the consumer
/// stops enumerating before the stream ends.
/// </remarks>
/// <param name="providerName">The name of the provider.</param>
/// <param name="requestBuilder">A function that builds the request.</param>
/// <param name="token">The cancellation token to use.</param>
/// <typeparam name="TDelta">The type of the delta lines inside the stream.</typeparam>
/// <typeparam name="TAnnotation">The type of the annotation lines inside the stream.</typeparam>
/// <returns>The stream of content chunks.</returns>
protected async IAsyncEnumerable<ContentStreamChunk> StreamChatCompletionInternal<TDelta, TAnnotation>(string providerName, Func<Task<HttpRequestMessage>> requestBuilder, [EnumeratorCancellation] CancellationToken token = default) where TDelta : IResponseStreamLine where TAnnotation : IAnnotationStreamLine
{
    // Check if annotations are supported:
    var annotationSupported = typeof(TAnnotation) != typeof(NoResponsesAnnotationStreamLine) && typeof(TAnnotation) != typeof(NoChatCompletionAnnotationStreamLine);

    HttpResponseMessage? response = null;
    StreamReader? streamReader = null;
    try
    {
        // Send the request using exponential backoff:
        var responseData = await this.SendRequest(requestBuilder, token);

        // From here on, this method owns the response and has to dispose it:
        response = responseData.Response;
        if(responseData.IsFailedAfterAllRetries)
        {
            this.logger.LogError($"The {providerName} chat completion failed: {responseData.ErrorMessage}");
            response?.Dispose();
            yield break;
        }

        // Open the response stream:
        var providerStream = await response!.Content.ReadAsStreamAsync(token);

        // Add a stream reader to read the stream, line by line:
        streamReader = new StreamReader(providerStream);
    }
    catch(Exception e)
    {
        await MessageBus.INSTANCE.SendError(new(Icons.Material.Filled.Stream, string.Format(TB("Tried to communicate with the LLM provider '{0}'. There were some problems with the request. The provider message is: '{1}'"), this.InstanceName, e.Message)));
        this.logger.LogError($"Failed to stream chat completion from {providerName} '{this.InstanceName}': {e.Message}");
    }

    if (streamReader is null)
    {
        // Opening the stream failed; the response may still exist:
        response?.Dispose();
        yield break;
    }

    //
    // Read the stream, line by line. The finally block releases the reader and the
    // response no matter how we leave the loop: end of stream, [DONE] marker, cancellation,
    // read error, or the consumer disposing the enumerator early.
    //
    try
    {
        while (true)
        {
            try
            {
                if(streamReader.EndOfStream)
                    break;
            }
            catch (Exception e)
            {
                await MessageBus.INSTANCE.SendError(new(Icons.Material.Filled.Stream, string.Format(TB("Tried to stream the LLM provider '{0}' answer. There were some problems with the stream. The message is: '{1}'"), this.InstanceName, e.Message)));
                this.logger.LogWarning($"Failed to read the end-of-stream state from {providerName} '{this.InstanceName}': {e.Message}");
                break;
            }

            // Check if the token is canceled:
            if (token.IsCancellationRequested)
            {
                this.logger.LogWarning($"The user canceled the chat completion for {providerName} '{this.InstanceName}'.");
                yield break;
            }

            //
            // Read the next line:
            //
            string? line;
            try
            {
                line = await streamReader.ReadLineAsync(token);
            }
            catch (Exception e)
            {
                await MessageBus.INSTANCE.SendError(new(Icons.Material.Filled.Stream, string.Format(TB("Tried to stream the LLM provider '{0}' answer. Was not able to read the stream. The message is: '{1}'"), this.InstanceName, e.Message)));
                this.logger.LogError($"Failed to read the stream from {providerName} '{this.InstanceName}': {e.Message}");
                break;
            }

            // Skip empty lines:
            if (string.IsNullOrWhiteSpace(line))
                continue;

            // Skip lines that do not start with "data: ". Regard
            // to the specification, we only want to read the data lines:
            if (!line.StartsWith("data: ", StringComparison.InvariantCulture))
                continue;

            // Check if the line is the end of the stream:
            if (line.StartsWith("data: [DONE]", StringComparison.InvariantCulture))
                yield break;

            //
            // Process annotation lines:
            //
            if (annotationSupported && line.Contains("""
                "annotations":[
                """, StringComparison.InvariantCulture))
            {
                TAnnotation? providerResponse;
                try
                {
                    // We know that the line starts with "data: ". Hence, we can
                    // skip the first 6 characters to get the JSON data after that.
                    var jsonData = line[6..];

                    // Deserialize the JSON data:
                    providerResponse = JsonSerializer.Deserialize<TAnnotation>(jsonData, JSON_SERIALIZER_OPTIONS);
                    if (providerResponse is null)
                        continue;
                }
                catch
                {
                    // Skip invalid JSON data:
                    continue;
                }

                // Skip empty responses:
                if (!providerResponse.ContainsSources())
                    continue;

                // Yield the response:
                yield return new(string.Empty, providerResponse.GetSources());
            }

            //
            // Process delta lines:
            //
            else
            {
                TDelta? providerResponse;
                try
                {
                    // We know that the line starts with "data: ". Hence, we can
                    // skip the first 6 characters to get the JSON data after that.
                    var jsonData = line[6..];

                    // Deserialize the JSON data:
                    providerResponse = JsonSerializer.Deserialize<TDelta>(jsonData, JSON_SERIALIZER_OPTIONS);
                    if (providerResponse is null)
                        continue;
                }
                catch
                {
                    // Skip invalid JSON data:
                    continue;
                }

                // Skip empty responses:
                if (!providerResponse.ContainsContent())
                    continue;

                // Yield the response:
                yield return providerResponse.GetContent();
            }
        }
    }
    finally
    {
        streamReader.Dispose();
        response?.Dispose();
    }
}
/// <summary>
/// Streams the chat completion from the provider using the Responses API.
/// </summary>
/// <remarks>
/// Lines starting with <c>data: {"type":"response.output_text.delta"</c> are processed as
/// deltas and, when annotations are supported, lines starting with
/// <c>data: {"type":"response.output_text.annotation.added"</c> as annotations. All other
/// lines are ignored; a line starting with <c>event: response.completed</c> ends the stream.
/// The HTTP response and the stream reader are disposed on every exit path, including the
/// case that the consumer stops enumerating before the stream ends.
/// </remarks>
/// <param name="providerName">The name of the provider.</param>
/// <param name="requestBuilder">A function that builds the request.</param>
/// <param name="token">The cancellation token to use.</param>
/// <typeparam name="TDelta">The type of the delta lines inside the stream.</typeparam>
/// <typeparam name="TAnnotation">The type of the annotation lines inside the stream.</typeparam>
/// <returns>The stream of content chunks.</returns>
protected async IAsyncEnumerable<ContentStreamChunk> StreamResponsesInternal<TDelta, TAnnotation>(string providerName, Func<Task<HttpRequestMessage>> requestBuilder, [EnumeratorCancellation] CancellationToken token = default) where TDelta : IResponseStreamLine where TAnnotation : IAnnotationStreamLine
{
    // Check if annotations are supported:
    var annotationSupported = typeof(TAnnotation) != typeof(NoResponsesAnnotationStreamLine) && typeof(TAnnotation) != typeof(NoChatCompletionAnnotationStreamLine);

    HttpResponseMessage? response = null;
    StreamReader? streamReader = null;
    try
    {
        // Send the request using exponential backoff:
        var responseData = await this.SendRequest(requestBuilder, token);

        // From here on, this method owns the response and has to dispose it:
        response = responseData.Response;
        if(responseData.IsFailedAfterAllRetries)
        {
            this.logger.LogError($"The {providerName} responses call failed: {responseData.ErrorMessage}");
            response?.Dispose();
            yield break;
        }

        // Open the response stream:
        var providerStream = await response!.Content.ReadAsStreamAsync(token);

        // Add a stream reader to read the stream, line by line:
        streamReader = new StreamReader(providerStream);
    }
    catch(Exception e)
    {
        await MessageBus.INSTANCE.SendError(new(Icons.Material.Filled.Stream, string.Format(TB("Tried to communicate with the LLM provider '{0}'. There were some problems with the request. The provider message is: '{1}'"), this.InstanceName, e.Message)));
        this.logger.LogError($"Failed to stream responses from {providerName} '{this.InstanceName}': {e.Message}");
    }

    if (streamReader is null)
    {
        // Opening the stream failed; the response may still exist:
        response?.Dispose();
        yield break;
    }

    //
    // Read the stream, line by line. The finally block releases the reader and the
    // response no matter how we leave the loop: end of stream, completed event,
    // cancellation, read error, or the consumer disposing the enumerator early.
    //
    try
    {
        while (true)
        {
            try
            {
                if(streamReader.EndOfStream)
                    break;
            }
            catch (Exception e)
            {
                await MessageBus.INSTANCE.SendError(new(Icons.Material.Filled.Stream, string.Format(TB("Tried to stream the LLM provider '{0}' answer. There were some problems with the stream. The message is: '{1}'"), this.InstanceName, e.Message)));
                this.logger.LogWarning($"Failed to read the end-of-stream state from {providerName} '{this.InstanceName}': {e.Message}");
                break;
            }

            // Check if the token is canceled:
            if (token.IsCancellationRequested)
            {
                this.logger.LogWarning($"The user canceled the responses for {providerName} '{this.InstanceName}'.");
                yield break;
            }

            //
            // Read the next line:
            //
            string? line;
            try
            {
                line = await streamReader.ReadLineAsync(token);
            }
            catch (Exception e)
            {
                await MessageBus.INSTANCE.SendError(new(Icons.Material.Filled.Stream, string.Format(TB("Tried to stream the LLM provider '{0}' answer. Was not able to read the stream. The message is: '{1}'"), this.InstanceName, e.Message)));
                this.logger.LogError($"Failed to read the stream from {providerName} '{this.InstanceName}': {e.Message}");
                break;
            }

            // Skip empty lines:
            if (string.IsNullOrWhiteSpace(line))
                continue;

            // Check if the line is the end of the stream:
            if (line.StartsWith("event: response.completed", StringComparison.InvariantCulture))
                yield break;

            //
            // Find delta lines:
            //
            if (line.StartsWith("""
                data: {"type":"response.output_text.delta"
                """, StringComparison.InvariantCulture))
            {
                TDelta? providerResponse;
                try
                {
                    // We know that the line starts with "data: ". Hence, we can
                    // skip the first 6 characters to get the JSON data after that.
                    var jsonData = line[6..];

                    // Deserialize the JSON data:
                    providerResponse = JsonSerializer.Deserialize<TDelta>(jsonData, JSON_SERIALIZER_OPTIONS);
                    if (providerResponse is null)
                        continue;
                }
                catch
                {
                    // Skip invalid JSON data:
                    continue;
                }

                // Skip empty responses:
                if (!providerResponse.ContainsContent())
                    continue;

                // Yield the response:
                yield return providerResponse.GetContent();
            }

            //
            // Find annotation added lines:
            //
            else if (annotationSupported && line.StartsWith("""
                data: {"type":"response.output_text.annotation.added"
                """, StringComparison.InvariantCulture))
            {
                TAnnotation? providerResponse;
                try
                {
                    // We know that the line starts with "data: ". Hence, we can
                    // skip the first 6 characters to get the JSON data after that.
                    var jsonData = line[6..];

                    // Deserialize the JSON data:
                    providerResponse = JsonSerializer.Deserialize<TAnnotation>(jsonData, JSON_SERIALIZER_OPTIONS);
                    if (providerResponse is null)
                        continue;
                }
                catch
                {
                    // Skip invalid JSON data:
                    continue;
                }

                // Skip empty responses:
                if (!providerResponse.ContainsSources())
                    continue;

                // Yield the response:
                yield return new(string.Empty, providerResponse.GetSources());
            }
        }
    }
    finally
    {
        streamReader.Dispose();
        response?.Dispose();
    }
}
///
/// Streams the chat completion from an OpenAI-compatible provider using the Chat Completion API.
///
/// The provider name for logging and error reporting.
/// The selected chat model.
/// The current chat thread.
/// The settings manager.
/// Builds the provider-specific base messages.
/// Builds the provider-specific request body.
/// The secret store type.
/// Whether the API key is optional.
/// The system prompt role to use.
/// The request path, relative to the provider base URL.
/// Optional additional headers to add.
/// The cancellation token.
/// The delta stream line type.
/// The annotation stream line type.
/// The streamed content chunks.
protected async IAsyncEnumerable StreamOpenAICompatibleChatCompletion(
string providerName,
Model chatModel,
ChatThread chatThread,
SettingsManager settingsManager,
Func>> messagesFactory,
Func, IDictionary, bool, IList