2025-01-01 19:11:42 +00:00
using System.Net ;
2024-12-03 14:24:40 +00:00
using AIStudio.Chat ;
2025-01-02 13:50:54 +00:00
using AIStudio.Settings ;
2024-12-03 14:24:40 +00:00
2024-09-01 18:10:03 +00:00
using RustService = AIStudio . Tools . RustService ;
2024-05-04 08:59:13 +00:00
namespace AIStudio.Provider ;
/// <summary>
/// The base class for all providers.
/// </summary>
public abstract class BaseProvider : IProvider, ISecretId
{
    /// <summary>
    /// The HTTP client used for all requests of this provider instance.
    /// </summary>
    protected readonly HttpClient httpClient = new();

    /// <summary>
    /// The logger to use.
    /// </summary>
    protected readonly ILogger logger;

    static BaseProvider()
    {
        RUST_SERVICE = Program.RUST_SERVICE;
        ENCRYPTION = Program.ENCRYPTION;
    }

    /// <summary>
    /// The Rust service, copied from <c>Program.RUST_SERVICE</c> by the static constructor.
    /// </summary>
    protected static readonly RustService RUST_SERVICE;

    /// <summary>
    /// The encryption service, copied from <c>Program.ENCRYPTION</c> by the static constructor.
    /// </summary>
    protected static readonly Encryption ENCRYPTION;

    /// <summary>
    /// Constructor for the base provider.
    /// </summary>
    /// <param name="url">The base URL for the provider.</param>
    /// <param name="loggerService">The logger service to use.</param>
    protected BaseProvider(string url, ILogger loggerService)
    {
        this.logger = loggerService;

        // Set the base URL:
        this.httpClient.BaseAddress = new(url);
    }

    #region Handling of IProvider, which all providers must implement

    /// <inheritdoc />
    public abstract string Id { get; }

    /// <inheritdoc />
    public abstract string InstanceName { get; set; }

    /// <inheritdoc />
    public abstract IAsyncEnumerable<string> StreamChatCompletion(Model chatModel, ChatThread chatThread, SettingsManager settingsManager, CancellationToken token = default);

    // NOTE(review): the default negative prompt used to be MudBlazor's FilterOperator.String.Empty.
    // That is a filter-operator label constant ("is empty"), not an empty string. Verify this against
    // the MudBlazor version in use. IProvider and the overriding providers probably declare the same
    // default and should be aligned too.
    /// <inheritdoc />
    public abstract IAsyncEnumerable<ImageURL> StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = "", ImageURL referenceImageURL = default, CancellationToken token = default);

    /// <inheritdoc />
    public abstract Task<IEnumerable<Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default);

    /// <inheritdoc />
    public abstract Task<IEnumerable<Model>> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default);

    /// <inheritdoc />
    public abstract Task<IEnumerable<Model>> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default);

    #endregion

    #region Implementation of ISecretId

    /// <inheritdoc />
    public string SecretId => this.Id;

    /// <inheritdoc />
    public string SecretName => this.InstanceName;

    #endregion

    /// <summary>
    /// Sends a request and handles rate limiting by exponential backoff.
    /// </summary>
    /// <remarks>
    /// <paramref name="requestBuilder"/> is invoked once per attempt, because an
    /// <see cref="HttpRequestMessage"/> cannot be sent twice.
    /// <list type="bullet">
    /// <item>A <c>400 Bad Request</c> is not retried.</item>
    /// <item>Any other non-success status is retried, up to six attempts in total.</item>
    /// <item>The wait between attempts is min(4^(attempt + 1), 90) seconds, i.e. 16 s, 64 s, then 90 s.</item>
    /// <item>There is no wait after the final attempt.</item>
    /// </list>
    /// Exceptions thrown by <paramref name="requestBuilder"/> or by sending are not caught.
    /// </remarks>
    /// <param name="requestBuilder">A function that builds the request.</param>
    /// <param name="token">The cancellation token.</param>
    /// <returns>
    /// The status object of the request. On success it carries the response, whose body has not
    /// been read yet. On failure the response is <c>null</c>, and the error message is either the
    /// provider's reason phrase or a generic fallback text.
    /// </returns>
    protected async Task<HttpRateLimitedStreamResult> SendRequest(Func<Task<HttpRequestMessage>> requestBuilder, CancellationToken token = default)
    {
        const int MAX_RETRIES = 6;
        const double RETRY_DELAY_SECONDS = 4;
        const double MAX_RETRY_DELAY_SECONDS = 90;

        string? errorMessage = null;
        for (var attempt = 1; attempt <= MAX_RETRIES; attempt++)
        {
            using var request = await requestBuilder();

            // Send the request with the ResponseHeadersRead option.
            // This allows us to read the stream as soon as the headers are received.
            // This is important because we want to stream the responses.
            var response = await this.httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, token);

            // Success wins, regardless of any earlier failed attempts:
            if (response.IsSuccessStatusCode)
                return new HttpRateLimitedStreamResult(true, false, string.Empty, response);

            // The failed response is never handed to the caller, so release its connection now:
            var statusCode = response.StatusCode;
            errorMessage = response.ReasonPhrase;
            response.Dispose();

            // Retrying a bad request will not make it succeed:
            if (statusCode is HttpStatusCode.BadRequest)
            {
                this.logger.LogError("Failed request with status code {StatusCode} (message = '{Message}').", statusCode, errorMessage);
                var badRequestMessage = string.IsNullOrWhiteSpace(errorMessage)
                    ? $"Failed with status code {statusCode}; no provider message available"
                    : errorMessage;

                return new HttpRateLimitedStreamResult(false, true, badRequestMessage, null);
            }

            // Don't wait after the last attempt; there is nothing left to retry:
            if (attempt == MAX_RETRIES)
            {
                this.logger.LogError("Failed request with status code {StatusCode} (message = '{Message}'); giving up after {MaxRetries} attempts.", statusCode, errorMessage, MAX_RETRIES);
                break;
            }

            var timeSeconds = Math.Min(Math.Pow(RETRY_DELAY_SECONDS, attempt + 1), MAX_RETRY_DELAY_SECONDS);
            this.logger.LogDebug("Failed request with status code {StatusCode} (message = '{Message}'). Retrying in {Delay:0.00} seconds.", statusCode, errorMessage, timeSeconds);
            await Task.Delay(TimeSpan.FromSeconds(timeSeconds), token);
        }

        var failureMessage = string.IsNullOrWhiteSpace(errorMessage)
            ? $"Failed after {MAX_RETRIES} retries; no provider message available"
            : errorMessage;

        return new HttpRateLimitedStreamResult(false, true, failureMessage, null);
    }
}