using System.Net;
using AIStudio.Chat;
using AIStudio.Settings;
using RustService = AIStudio.Tools.RustService;
namespace AIStudio.Provider;
/// <summary>
/// The base class for all providers.
/// </summary>
public abstract class BaseProvider : IProvider, ISecretId
{
    /// <summary>
    /// The HTTP client used for all requests of this provider.
    /// Its base address is set by the constructor.
    /// </summary>
    protected readonly HttpClient httpClient = new();

    /// <summary>
    /// The logger to use.
    /// </summary>
    protected readonly ILogger logger;

    /// <summary>
    /// The Rust service shared by all providers; copied from <c>Program.RUST_SERVICE</c>.
    /// </summary>
    protected static readonly RustService RUST_SERVICE;

    /// <summary>
    /// The encryption helper shared by all providers; copied from <c>Program.ENCRYPTION</c>.
    /// </summary>
    protected static readonly Encryption ENCRYPTION;

    // An explicit static constructor (instead of field initializers) keeps the type from being
    // marked beforefieldinit: the values are copied from Program exactly when BaseProvider is
    // first used (first instance created or first static member accessed).
    // NOTE(review): presumably Program has initialized both services by then — confirm.
    static BaseProvider()
    {
        RUST_SERVICE = Program.RUST_SERVICE;
        ENCRYPTION = Program.ENCRYPTION;
    }

    /// <summary>
    /// Constructor for the base provider.
    /// </summary>
    /// <param name="url">The base URL for the provider.</param>
    /// <param name="loggerService">The logger service to use.</param>
    /// <exception cref="UriFormatException">Thrown when <paramref name="url"/> is not a valid absolute URI.</exception>
    protected BaseProvider(string url, ILogger loggerService)
    {
        this.logger = loggerService;

        // Set the base URL:
        this.httpClient.BaseAddress = new(url);
    }

    #region Handling of IProvider, which all providers must implement

    // NOTE(review): the generic type arguments of the return types below appear to have been
    // lost (e.g. `Task>` does not parse); restore them from the IProvider declaration.

    /// <inheritdoc />
    public abstract string Id { get; }

    /// <inheritdoc />
    public abstract string InstanceName { get; set; }

    /// <inheritdoc />
    public abstract IAsyncEnumerable StreamChatCompletion(Model chatModel, ChatThread chatThread, SettingsManager settingsManager, CancellationToken token = default);

    // NOTE(review): the default `FilterOperator.String.Empty` looks like a typo for `string.Empty`;
    // if it resolves to a filter-operator constant, the default negative prompt is a non-empty
    // string. Confirm against IProvider before changing (both defaults must stay in sync).
    /// <inheritdoc />
    public abstract IAsyncEnumerable StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, CancellationToken token = default);

    /// <inheritdoc />
    public abstract Task> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default);

    /// <inheritdoc />
    public abstract Task> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default);

    /// <inheritdoc />
    public abstract Task> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default);

    #endregion

    #region Implementation of ISecretId

    /// <summary>
    /// The secret ID of a provider is its <see cref="Id"/>.
    /// </summary>
    public string SecretId => this.Id;

    /// <summary>
    /// The secret name of a provider is its <see cref="InstanceName"/>.
    /// </summary>
    public string SecretName => this.InstanceName;

    #endregion

    /// <summary>
    /// Sends a request and handles rate limiting by exponential backoff.
    /// </summary>
    /// <remarks>
    /// The request is built anew for every attempt via <paramref name="requestBuilder"/>, because an
    /// <see cref="HttpRequestMessage"/> cannot be sent more than once. A 400 (Bad Request) answer is
    /// not retried. Any other non-success answer is retried, for at most 6 attempts in total, waiting
    /// 4^(attempt + 1) seconds (capped at 90 seconds) between attempts; there is no wait after the
    /// last attempt. Responses of failed attempts are disposed so their connections are released.
    /// Network-level errors thrown by the HTTP client are not retried; they propagate to the caller.
    /// </remarks>
    /// <param name="requestBuilder">A function that builds the request; it is called once per attempt.</param>
    /// <param name="token">The cancellation token.</param>
    /// <returns>
    /// The status object of the request. On success it carries the response, which is handed over
    /// to the caller (who is responsible for disposing it). On failure it carries the provider's
    /// last reason phrase (or a fallback text) and no response.
    /// </returns>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="requestBuilder"/> is null.</exception>
    protected async Task<HttpRateLimitedStreamResult> SendRequest(Func<Task<HttpRequestMessage>> requestBuilder, CancellationToken token = default)
    {
        ArgumentNullException.ThrowIfNull(requestBuilder);

        const int MAX_RETRIES = 6;
        const double RETRY_DELAY_SECONDS = 4;
        const double MAX_RETRY_DELAY_SECONDS = 90;

        string? errorMessage = null;
        for (var attempt = 1; attempt <= MAX_RETRIES; attempt++)
        {
            using var request = await requestBuilder();

            // Send the request with the ResponseHeadersRead option.
            // This allows us to read the stream as soon as the headers are received.
            // This is important because we want to stream the responses.
            var response = await this.httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, token);
            if (response.IsSuccessStatusCode)
                return new HttpRateLimitedStreamResult(true, false, string.Empty, response);

            // This response is never handed to the caller. Dispose it right away: its body was
            // not read, so otherwise its connection stays occupied:
            var statusCode = response.StatusCode;
            errorMessage = response.ReasonPhrase;
            response.Dispose();

            // Sending a bad request again will not make it succeed:
            if (statusCode is HttpStatusCode.BadRequest)
            {
                this.logger.LogError($"Failed request with status code {statusCode} (message = '{errorMessage}').");
                var badRequestMessage = string.IsNullOrWhiteSpace(errorMessage)
                    ? $"Failed request with status code {statusCode}; no provider message available"
                    : errorMessage;

                return new HttpRateLimitedStreamResult(false, true, badRequestMessage, null);
            }

            // Waiting after the last attempt would only delay reporting the failure:
            if (attempt == MAX_RETRIES)
                break;

            var timeSeconds = Math.Min(Math.Pow(RETRY_DELAY_SECONDS, attempt + 1), MAX_RETRY_DELAY_SECONDS);
            this.logger.LogDebug($"Failed request with status code {statusCode} (message = '{errorMessage}'). Retrying in {timeSeconds:0.00} seconds.");
            await Task.Delay(TimeSpan.FromSeconds(timeSeconds), token);
        }

        this.logger.LogError($"Failed request after {MAX_RETRIES} attempts (last message = '{errorMessage}').");
        var exhaustedMessage = string.IsNullOrWhiteSpace(errorMessage)
            ? $"Failed after {MAX_RETRIES} retries; no provider message available"
            : errorMessage;

        return new HttpRateLimitedStreamResult(false, true, exhaustedMessage, null);
    }
}