Refactored network code (#253)

This commit is contained in:
Thorsten Sommer 2025-01-04 14:11:32 +01:00 committed by GitHub
parent a54c0bdbbf
commit 33a2728644
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 224 additions and 740 deletions

View File

@ -10,11 +10,6 @@ namespace AIStudio.Provider.Anthropic;
public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://api.anthropic.com/v1/", logger) public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://api.anthropic.com/v1/", logger)
{ {
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
};
#region Implementation of IProvider #region Implementation of IProvider
public override string Id => LLMProviders.ANTHROPIC.ToName(); public override string Id => LLMProviders.ANTHROPIC.ToName();
@ -60,117 +55,24 @@ public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://ap
Stream = true, Stream = true,
}, JSON_SERIALIZER_OPTIONS); }, JSON_SERIALIZER_OPTIONS);
StreamReader? streamReader = null; async Task<HttpRequestMessage> RequestBuilder()
try
{ {
async Task<HttpRequestMessage> RequestBuilder() // Build the HTTP post request:
{ var request = new HttpRequestMessage(HttpMethod.Post, "messages");
// Build the HTTP post request:
var request = new HttpRequestMessage(HttpMethod.Post, "messages");
// Set the authorization header: // Set the authorization header:
request.Headers.Add("x-api-key", await requestedSecret.Secret.Decrypt(ENCRYPTION)); request.Headers.Add("x-api-key", await requestedSecret.Secret.Decrypt(ENCRYPTION));
// Set the Anthropic version: // Set the Anthropic version:
request.Headers.Add("anthropic-version", "2023-06-01"); request.Headers.Add("anthropic-version", "2023-06-01");
// Set the content: // Set the content:
request.Content = new StringContent(chatRequest, Encoding.UTF8, "application/json"); request.Content = new StringContent(chatRequest, Encoding.UTF8, "application/json");
return request; return request;
}
// Send the request using exponential backoff:
var responseData = await this.SendRequest(RequestBuilder, token);
if(responseData.IsFailedAfterAllRetries)
{
this.logger.LogError($"Anthropic chat completion failed: {responseData.ErrorMessage}");
yield break;
}
// Open the response stream:
var stream = await responseData.Response!.Content.ReadAsStreamAsync(token);
// Add a stream reader to read the stream, line by line:
streamReader = new StreamReader(stream);
}
catch (Exception e)
{
this.logger.LogError($"Failed to stream chat completion from Anthropic '{this.InstanceName}': {e.Message}");
} }
if (streamReader is null) await foreach (var content in this.StreamChatCompletionInternal<ResponseStreamLine>("Anthropic", RequestBuilder, token))
yield break; yield return content;
// Read the stream, line by line:
while(true)
{
try
{
if(streamReader.EndOfStream)
break;
}
catch (Exception e)
{
this.logger.LogWarning($"Failed to read the end-of-stream state from Anthropic '{this.InstanceName}': {e.Message}");
break;
}
// Check if the token is canceled:
if(token.IsCancellationRequested)
yield break;
// Read the next line:
string? line;
try
{
line = await streamReader.ReadLineAsync(token);
}
catch (Exception e)
{
this.logger.LogError($"Failed to read the stream from Anthropic '{this.InstanceName}': {e.Message}");
break;
}
// Skip empty lines:
if(string.IsNullOrWhiteSpace(line))
continue;
// Check for the end of the stream:
if(line.StartsWith("event: message_stop", StringComparison.InvariantCulture))
yield break;
// Skip lines that do not start with "data: ". Regard
// to the specification, we only want to read the data lines:
if(!line.StartsWith("data: ", StringComparison.InvariantCulture))
continue;
// Ignore any type except "content_block_delta":
if(!line.Contains("\"content_block_delta\"", StringComparison.InvariantCulture))
continue;
ResponseStreamLine anthropicResponse;
try
{
// We know that the line starts with "data: ". Hence, we can
// skip the first 6 characters to get the JSON data after that.
var jsonData = line[6..];
// Deserialize the JSON data:
anthropicResponse = JsonSerializer.Deserialize<ResponseStreamLine>(jsonData, JSON_SERIALIZER_OPTIONS);
}
catch
{
// Skip invalid JSON data:
continue;
}
// Skip empty responses:
if(anthropicResponse == default || string.IsNullOrWhiteSpace(anthropicResponse.Delta.Text))
continue;
// Yield the response:
yield return anthropicResponse.Delta.Text;
}
} }
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously

View File

@ -7,7 +7,14 @@ namespace AIStudio.Provider.Anthropic;
/// <param name="Type">The type of the response line.</param> /// <param name="Type">The type of the response line.</param>
/// <param name="Index">The index of the response line.</param> /// <param name="Index">The index of the response line.</param>
/// <param name="Delta">The delta of the response line.</param> /// <param name="Delta">The delta of the response line.</param>
public readonly record struct ResponseStreamLine(string Type, int Index, Delta Delta); public readonly record struct ResponseStreamLine(string Type, int Index, Delta Delta) : IResponseStreamLine
{
/// <inheritdoc />
public bool ContainsContent()
{
    // A line equal to the default value never counts as content:
    if (this == default)
        return false;

    // Otherwise, the delta text must hold something other than white space:
    return !string.IsNullOrWhiteSpace(this.Delta.Text);
}
/// <inheritdoc />
public string GetContent()
{
    // The content of this line is the text of its delta:
    var delta = this.Delta;
    return delta.Text;
}
}
/// <summary> /// <summary>
/// The delta object of a response line. /// The delta object of a response line.

View File

@ -1,4 +1,6 @@
using System.Net; using System.Net;
using System.Runtime.CompilerServices;
using System.Text.Json;
using AIStudio.Chat; using AIStudio.Chat;
using AIStudio.Settings; using AIStudio.Settings;
@ -31,6 +33,11 @@ public abstract class BaseProvider : IProvider, ISecretId
protected static readonly RustService RUST_SERVICE; protected static readonly RustService RUST_SERVICE;
protected static readonly Encryption ENCRYPTION; protected static readonly Encryption ENCRYPTION;
/// <summary>
/// JSON serializer options available to all derived providers. Property names
/// are mapped with the lower-case snake_case naming policy.
/// </summary>
protected static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
};
/// <summary> /// <summary>
/// Constructor for the base provider. /// Constructor for the base provider.
@ -127,4 +134,105 @@ public abstract class BaseProvider : IProvider, ISecretId
return new HttpRateLimitedStreamResult(true, false, string.Empty, response); return new HttpRateLimitedStreamResult(true, false, string.Empty, response);
} }
/// <summary>
/// Sends a chat completion request and streams the provider's answer as text chunks.
/// </summary>
/// <remarks>
/// The response body is read line by line. Only lines that start with <c>data: </c> are
/// considered; the payload after that prefix is deserialized into <typeparamref name="T"/>
/// using <see cref="JSON_SERIALIZER_OPTIONS"/>. Lines that are empty, lack the prefix, cannot
/// be deserialized, or report no content are skipped. Streaming stops at <c>data: [DONE]</c>,
/// at the end of the response, when the token is canceled, or when reading fails.
/// </remarks>
/// <typeparam name="T">The provider-specific data model of one streamed line.</typeparam>
/// <param name="providerName">The provider name; only used in log messages.</param>
/// <param name="requestBuilder">Factory for the HTTP request; handed to <c>SendRequest</c>.</param>
/// <param name="token">Token used to cancel the request and the streaming.</param>
/// <returns>The content chunks, in the order they were received.</returns>
protected async IAsyncEnumerable<string> StreamChatCompletionInternal<T>(string providerName, Func<Task<HttpRequestMessage>> requestBuilder, [EnumeratorCancellation] CancellationToken token = default) where T : struct, IResponseStreamLine
{
    StreamReader? streamReader = null;
    try
    {
        // Send the request using exponential backoff:
        var responseData = await this.SendRequest(requestBuilder, token);
        if (responseData.IsFailedAfterAllRetries)
        {
            this.logger.LogError($"The {providerName} chat completion failed: {responseData.ErrorMessage}");
            yield break;
        }

        // Open the response stream:
        // NOTE(review): the response object itself is not disposed here; confirm who owns it.
        var providerStream = await responseData.Response!.Content.ReadAsStreamAsync(token);

        // Add a stream reader to read the stream, line by line:
        streamReader = new StreamReader(providerStream);
    }
    catch (Exception e)
    {
        this.logger.LogError($"Failed to stream chat completion from {providerName} '{this.InstanceName}': {e.Message}");
    }

    if (streamReader is null)
        yield break;

    // The try/finally releases the reader (and the stream it wraps) on every exit path:
    // normal end, the [DONE] marker, cancellation, an exception thrown by the response
    // model, and a consumer that stops enumerating early (disposing the iterator runs
    // the finally block).
    try
    {
        // Read the stream, line by line:
        while (true)
        {
            // Check if the token is canceled:
            if (token.IsCancellationRequested)
            {
                this.logger.LogWarning($"The user canceled the chat completion for {providerName} '{this.InstanceName}'.");
                yield break;
            }

            // Read the next line. StreamReader.EndOfStream is deliberately not used:
            // it may block synchronously while waiting for data. ReadLineAsync
            // returns null once the end of the stream is reached.
            string? line;
            try
            {
                line = await streamReader.ReadLineAsync(token);
            }
            catch (OperationCanceledException) when (token.IsCancellationRequested)
            {
                // A cancellation while waiting for data is not a read error:
                this.logger.LogWarning($"The user canceled the chat completion for {providerName} '{this.InstanceName}'.");
                break;
            }
            catch (Exception e)
            {
                this.logger.LogError($"Failed to read the stream from {providerName} '{this.InstanceName}': {e.Message}");
                break;
            }

            // A null line marks the end of the stream:
            if (line is null)
                break;

            // Skip empty lines:
            if (string.IsNullOrWhiteSpace(line))
                continue;

            // Skip lines that do not start with "data: ". According to the
            // specification, we only want to read the data lines. The markers
            // are protocol text, so they are compared ordinally:
            if (!line.StartsWith("data: ", StringComparison.Ordinal))
                continue;

            // Check if the line is the end of the stream:
            if (line.StartsWith("data: [DONE]", StringComparison.Ordinal))
                yield break;

            T providerResponse;
            try
            {
                // We know that the line starts with "data: ". Hence, we can
                // skip the first 6 characters to get the JSON data after that.
                var jsonData = line[6..];

                // Deserialize the JSON data:
                providerResponse = JsonSerializer.Deserialize<T>(jsonData, JSON_SERIALIZER_OPTIONS);
            }
            catch
            {
                // Skip invalid JSON data:
                continue;
            }

            // Skip empty responses:
            if (!providerResponse.ContainsContent())
                continue;

            // Yield the response:
            yield return providerResponse.GetContent();
        }
    }
    finally
    {
        streamReader.Dispose();
    }
}
} }

View File

@ -10,11 +10,6 @@ namespace AIStudio.Provider.Fireworks;
public class ProviderFireworks(ILogger logger) : BaseProvider("https://api.fireworks.ai/inference/v1/", logger) public class ProviderFireworks(ILogger logger) : BaseProvider("https://api.fireworks.ai/inference/v1/", logger)
{ {
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
};
#region Implementation of IProvider #region Implementation of IProvider
/// <inheritdoc /> /// <inheritdoc />
@ -69,110 +64,21 @@ public class ProviderFireworks(ILogger logger) : BaseProvider("https://api.firew
Stream = true, Stream = true,
}, JSON_SERIALIZER_OPTIONS); }, JSON_SERIALIZER_OPTIONS);
StreamReader? streamReader = null; async Task<HttpRequestMessage> RequestBuilder()
try
{ {
async Task<HttpRequestMessage> RequestBuilder() // Build the HTTP post request:
{ var request = new HttpRequestMessage(HttpMethod.Post, "chat/completions");
// Build the HTTP post request:
var request = new HttpRequestMessage(HttpMethod.Post, "chat/completions");
// Set the authorization header: // Set the authorization header:
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", await requestedSecret.Secret.Decrypt(ENCRYPTION)); request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", await requestedSecret.Secret.Decrypt(ENCRYPTION));
// Set the content: // Set the content:
request.Content = new StringContent(fireworksChatRequest, Encoding.UTF8, "application/json"); request.Content = new StringContent(fireworksChatRequest, Encoding.UTF8, "application/json");
return request; return request;
}
// Send the request using exponential backoff:
var responseData = await this.SendRequest(RequestBuilder, token);
if(responseData.IsFailedAfterAllRetries)
{
this.logger.LogError($"Fireworks chat completion failed: {responseData.ErrorMessage}");
yield break;
}
// Open the response stream:
var fireworksStream = await responseData.Response!.Content.ReadAsStreamAsync(token);
// Add a stream reader to read the stream, line by line:
streamReader = new StreamReader(fireworksStream);
}
catch (Exception e)
{
this.logger.LogError($"Failed to stream chat completion from Fireworks '{this.InstanceName}': {e.Message}");
} }
if (streamReader is null) await foreach (var content in this.StreamChatCompletionInternal<ResponseStreamLine>("Fireworks", RequestBuilder, token))
yield break; yield return content;
// Read the stream, line by line:
while(true)
{
try
{
if(streamReader.EndOfStream)
break;
}
catch (Exception e)
{
this.logger.LogWarning($"Failed to read the end-of-stream state from Fireworks '{this.InstanceName}': {e.Message}");
break;
}
// Check if the token is canceled:
if(token.IsCancellationRequested)
yield break;
// Read the next line:
string? line;
try
{
line = await streamReader.ReadLineAsync(token);
}
catch (Exception e)
{
this.logger.LogError($"Failed to read the stream from Fireworks '{this.InstanceName}': {e.Message}");
break;
}
// Skip empty lines:
if(string.IsNullOrWhiteSpace(line))
continue;
// Skip lines that do not start with "data: ". Regard
// to the specification, we only want to read the data lines:
if(!line.StartsWith("data: ", StringComparison.InvariantCulture))
continue;
// Check if the line is the end of the stream:
if (line.StartsWith("data: [DONE]", StringComparison.InvariantCulture))
yield break;
ResponseStreamLine fireworksResponse;
try
{
// We know that the line starts with "data: ". Hence, we can
// skip the first 6 characters to get the JSON data after that.
var jsonData = line[6..];
// Deserialize the JSON data:
fireworksResponse = JsonSerializer.Deserialize<ResponseStreamLine>(jsonData, JSON_SERIALIZER_OPTIONS);
}
catch
{
// Skip invalid JSON data:
continue;
}
// Skip empty responses:
if(fireworksResponse == default || fireworksResponse.Choices.Count == 0)
continue;
// Yield the response:
yield return fireworksResponse.Choices[0].Delta.Content;
}
} }
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously

View File

@ -8,7 +8,14 @@ namespace AIStudio.Provider.Fireworks;
/// <param name="Created">The timestamp of the response.</param> /// <param name="Created">The timestamp of the response.</param>
/// <param name="Model">The model used for the response.</param> /// <param name="Model">The model used for the response.</param>
/// <param name="Choices">The choices made by the AI.</param> /// <param name="Choices">The choices made by the AI.</param>
public readonly record struct ResponseStreamLine(string Id, string Object, uint Created, string Model, IList<Choice> Choices); public readonly record struct ResponseStreamLine(string Id, string Object, uint Created, string Model, IList<Choice> Choices) : IResponseStreamLine
{
/// <inheritdoc />
// The property pattern is false when Choices is null. This covers the default
// instance as well as a line whose JSON carries no choices, which previously
// caused a NullReferenceException on the Count access.
public bool ContainsContent() => this.Choices is { Count: > 0 };
/// <inheritdoc />
public string GetContent()
{
    // Only the first choice is used as the content of the line:
    var firstChoice = this.Choices[0];
    return firstChoice.Delta.Content;
}
}
/// <summary> /// <summary>
/// Data model for a choice made by the AI. /// Data model for a choice made by the AI.

View File

@ -11,11 +11,6 @@ namespace AIStudio.Provider.Google;
public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativelanguage.googleapis.com/v1beta/", logger) public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativelanguage.googleapis.com/v1beta/", logger)
{ {
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
};
#region Implementation of IProvider #region Implementation of IProvider
/// <inheritdoc /> /// <inheritdoc />
@ -70,110 +65,21 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela
Stream = true, Stream = true,
}, JSON_SERIALIZER_OPTIONS); }, JSON_SERIALIZER_OPTIONS);
StreamReader? streamReader = null; async Task<HttpRequestMessage> RequestBuilder()
try
{ {
async Task<HttpRequestMessage> RequestBuilder() // Build the HTTP post request:
{ var request = new HttpRequestMessage(HttpMethod.Post, "chat/completions");
// Build the HTTP post request:
var request = new HttpRequestMessage(HttpMethod.Post, "chat/completions");
// Set the authorization header: // Set the authorization header:
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", await requestedSecret.Secret.Decrypt(ENCRYPTION)); request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", await requestedSecret.Secret.Decrypt(ENCRYPTION));
// Set the content: // Set the content:
request.Content = new StringContent(geminiChatRequest, Encoding.UTF8, "application/json"); request.Content = new StringContent(geminiChatRequest, Encoding.UTF8, "application/json");
return request; return request;
}
// Send the request using exponential backoff:
var responseData = await this.SendRequest(RequestBuilder, token);
if(responseData.IsFailedAfterAllRetries)
{
this.logger.LogError($"Google chat completion failed: {responseData.ErrorMessage}");
yield break;
}
// Open the response stream:
var geminiStream = await responseData.Response!.Content.ReadAsStreamAsync(token);
// Add a stream reader to read the stream, line by line:
streamReader = new StreamReader(geminiStream);
}
catch (Exception e)
{
this.logger.LogError($"Failed to stream chat completion from Google '{this.InstanceName}': {e.Message}");
} }
if (streamReader is null) await foreach (var content in this.StreamChatCompletionInternal<ResponseStreamLine>("Google", RequestBuilder, token))
yield break; yield return content;
// Read the stream, line by line:
while(true)
{
try
{
if(streamReader.EndOfStream)
break;
}
catch (Exception e)
{
this.logger.LogWarning($"Failed to read the end-of-stream state from Google '{this.InstanceName}': {e.Message}");
break;
}
// Check if the token is canceled:
if(token.IsCancellationRequested)
yield break;
// Read the next line:
string? line;
try
{
line = await streamReader.ReadLineAsync(token);
}
catch (Exception e)
{
this.logger.LogError($"Failed to read the stream from Google '{this.InstanceName}': {e.Message}");
break;
}
// Skip empty lines:
if(string.IsNullOrWhiteSpace(line))
continue;
// Skip lines that do not start with "data: ". Regard
// to the specification, we only want to read the data lines:
if(!line.StartsWith("data: ", StringComparison.InvariantCulture))
continue;
// Check if the line is the end of the stream:
if (line.StartsWith("data: [DONE]", StringComparison.InvariantCulture))
yield break;
ResponseStreamLine geminiResponse;
try
{
// We know that the line starts with "data: ". Hence, we can
// skip the first 6 characters to get the JSON data after that.
var jsonData = line[6..];
// Deserialize the JSON data:
geminiResponse = JsonSerializer.Deserialize<ResponseStreamLine>(jsonData, JSON_SERIALIZER_OPTIONS);
}
catch
{
// Skip invalid JSON data:
continue;
}
// Skip empty responses:
if(geminiResponse == default || geminiResponse.Choices.Count == 0)
continue;
// Yield the response:
yield return geminiResponse.Choices[0].Delta.Content;
}
} }
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously

View File

@ -11,11 +11,6 @@ namespace AIStudio.Provider.Groq;
public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/openai/v1/", logger) public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/openai/v1/", logger)
{ {
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
};
#region Implementation of IProvider #region Implementation of IProvider
/// <inheritdoc /> /// <inheritdoc />
@ -72,110 +67,21 @@ public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/o
Stream = true, Stream = true,
}, JSON_SERIALIZER_OPTIONS); }, JSON_SERIALIZER_OPTIONS);
StreamReader? streamReader = null; async Task<HttpRequestMessage> RequestBuilder()
try
{ {
async Task<HttpRequestMessage> RequestBuilder() // Build the HTTP post request:
{ var request = new HttpRequestMessage(HttpMethod.Post, "chat/completions");
// Build the HTTP post request:
var request = new HttpRequestMessage(HttpMethod.Post, "chat/completions");
// Set the authorization header: // Set the authorization header:
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", await requestedSecret.Secret.Decrypt(ENCRYPTION)); request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", await requestedSecret.Secret.Decrypt(ENCRYPTION));
// Set the content: // Set the content:
request.Content = new StringContent(groqChatRequest, Encoding.UTF8, "application/json"); request.Content = new StringContent(groqChatRequest, Encoding.UTF8, "application/json");
return request; return request;
}
// Send the request using exponential backoff:
var responseData = await this.SendRequest(RequestBuilder, token);
if(responseData.IsFailedAfterAllRetries)
{
this.logger.LogError($"Groq chat completion failed: {responseData.ErrorMessage}");
yield break;
}
// Open the response stream:
var groqStream = await responseData.Response!.Content.ReadAsStreamAsync(token);
// Add a stream reader to read the stream, line by line:
streamReader = new StreamReader(groqStream);
}
catch (Exception e)
{
this.logger.LogError($"Failed to stream chat completion from Groq '{this.InstanceName}': {e.Message}");
} }
if (streamReader is null) await foreach (var content in this.StreamChatCompletionInternal<ResponseStreamLine>("Groq", RequestBuilder, token))
yield break; yield return content;
// Read the stream, line by line:
while(true)
{
try
{
if(streamReader.EndOfStream)
break;
}
catch (Exception e)
{
this.logger.LogWarning($"Failed to read the end-of-stream state from Groq '{this.InstanceName}': {e.Message}");
break;
}
// Check if the token is canceled:
if(token.IsCancellationRequested)
yield break;
// Read the next line:
string? line;
try
{
line = await streamReader.ReadLineAsync(token);
}
catch (Exception e)
{
this.logger.LogError($"Failed to read the stream from Groq '{this.InstanceName}': {e.Message}");
break;
}
// Skip empty lines:
if(string.IsNullOrWhiteSpace(line))
continue;
// Skip lines that do not start with "data: ". Regard
// to the specification, we only want to read the data lines:
if(!line.StartsWith("data: ", StringComparison.InvariantCulture))
continue;
// Check if the line is the end of the stream:
if (line.StartsWith("data: [DONE]", StringComparison.InvariantCulture))
yield break;
ResponseStreamLine groqResponse;
try
{
// We know that the line starts with "data: ". Hence, we can
// skip the first 6 characters to get the JSON data after that.
var jsonData = line[6..];
// Deserialize the JSON data:
groqResponse = JsonSerializer.Deserialize<ResponseStreamLine>(jsonData, JSON_SERIALIZER_OPTIONS);
}
catch
{
// Skip invalid JSON data:
continue;
}
// Skip empty responses:
if(groqResponse == default || groqResponse.Choices.Count == 0)
continue;
// Yield the response:
yield return groqResponse.Choices[0].Delta.Content;
}
} }
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously

View File

@ -0,0 +1,16 @@
namespace AIStudio.Provider;
/// <summary>
/// Contract for the data model of a single line of a streamed provider response.
/// </summary>
public interface IResponseStreamLine
{
    /// <summary>
    /// Tells whether this response line carries any content.
    /// </summary>
    /// <returns>True if content is present; otherwise, false.</returns>
    bool ContainsContent();

    /// <summary>
    /// Returns the content carried by this response line.
    /// </summary>
    /// <returns>The content of this response line.</returns>
    string GetContent();
}

View File

@ -11,11 +11,6 @@ namespace AIStudio.Provider.Mistral;
public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.mistral.ai/v1/", logger) public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.mistral.ai/v1/", logger)
{ {
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
};
#region Implementation of IProvider #region Implementation of IProvider
public override string Id => LLMProviders.MISTRAL.ToName(); public override string Id => LLMProviders.MISTRAL.ToName();
@ -71,110 +66,21 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.
SafePrompt = false, SafePrompt = false,
}, JSON_SERIALIZER_OPTIONS); }, JSON_SERIALIZER_OPTIONS);
StreamReader? streamReader = null; async Task<HttpRequestMessage> RequestBuilder()
try
{ {
async Task<HttpRequestMessage> RequestBuilder() // Build the HTTP post request:
{ var request = new HttpRequestMessage(HttpMethod.Post, "chat/completions");
// Build the HTTP post request:
var request = new HttpRequestMessage(HttpMethod.Post, "chat/completions");
// Set the authorization header: // Set the authorization header:
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", await requestedSecret.Secret.Decrypt(ENCRYPTION)); request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", await requestedSecret.Secret.Decrypt(ENCRYPTION));
// Set the content: // Set the content:
request.Content = new StringContent(mistralChatRequest, Encoding.UTF8, "application/json"); request.Content = new StringContent(mistralChatRequest, Encoding.UTF8, "application/json");
return request; return request;
}
// Send the request using exponential backoff:
var responseData = await this.SendRequest(RequestBuilder, token);
if(responseData.IsFailedAfterAllRetries)
{
this.logger.LogError($"Mistral chat completion failed: {responseData.ErrorMessage}");
yield break;
}
// Open the response stream:
var mistralStream = await responseData.Response!.Content.ReadAsStreamAsync(token);
// Add a stream reader to read the stream, line by line:
streamReader = new StreamReader(mistralStream);
}
catch (Exception e)
{
this.logger.LogError($"Failed to stream chat completion from Mistral '{this.InstanceName}': {e.Message}");
} }
if (streamReader is null) await foreach (var content in this.StreamChatCompletionInternal<ResponseStreamLine>("Mistral", RequestBuilder, token))
yield break; yield return content;
// Read the stream, line by line:
while(true)
{
try
{
if(streamReader.EndOfStream)
break;
}
catch (Exception e)
{
this.logger.LogWarning($"Failed to read the end-of-stream state from Mistral '{this.InstanceName}': {e.Message}");
break;
}
// Check if the token is canceled:
if(token.IsCancellationRequested)
yield break;
// Read the next line:
string? line;
try
{
line = await streamReader.ReadLineAsync(token);
}
catch (Exception e)
{
this.logger.LogError($"Failed to read the stream from Mistral '{this.InstanceName}': {e.Message}");
break;
}
// Skip empty lines:
if(string.IsNullOrWhiteSpace(line))
continue;
// Skip lines that do not start with "data: ". Regard
// to the specification, we only want to read the data lines:
if(!line.StartsWith("data: ", StringComparison.InvariantCulture))
continue;
// Check if the line is the end of the stream:
if (line.StartsWith("data: [DONE]", StringComparison.InvariantCulture))
yield break;
ResponseStreamLine mistralResponse;
try
{
// We know that the line starts with "data: ". Hence, we can
// skip the first 6 characters to get the JSON data after that.
var jsonData = line[6..];
// Deserialize the JSON data:
mistralResponse = JsonSerializer.Deserialize<ResponseStreamLine>(jsonData, JSON_SERIALIZER_OPTIONS);
}
catch
{
// Skip invalid JSON data:
continue;
}
// Skip empty responses:
if(mistralResponse == default || mistralResponse.Choices.Count == 0)
continue;
// Yield the response:
yield return mistralResponse.Choices[0].Delta.Content;
}
} }
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously

View File

@ -13,11 +13,6 @@ namespace AIStudio.Provider.OpenAI;
/// </summary> /// </summary>
public sealed class ProviderOpenAI(ILogger logger) : BaseProvider("https://api.openai.com/v1/", logger) public sealed class ProviderOpenAI(ILogger logger) : BaseProvider("https://api.openai.com/v1/", logger)
{ {
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
};
#region Implementation of IProvider #region Implementation of IProvider
/// <inheritdoc /> /// <inheritdoc />
@ -99,110 +94,21 @@ public sealed class ProviderOpenAI(ILogger logger) : BaseProvider("https://api.o
Stream = true, Stream = true,
}, JSON_SERIALIZER_OPTIONS); }, JSON_SERIALIZER_OPTIONS);
StreamReader? streamReader = null; async Task<HttpRequestMessage> RequestBuilder()
try
{ {
async Task<HttpRequestMessage> RequestBuilder() // Build the HTTP post request:
{ var request = new HttpRequestMessage(HttpMethod.Post, "chat/completions");
// Build the HTTP post request:
var request = new HttpRequestMessage(HttpMethod.Post, "chat/completions");
// Set the authorization header: // Set the authorization header:
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", await requestedSecret.Secret.Decrypt(ENCRYPTION)); request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", await requestedSecret.Secret.Decrypt(ENCRYPTION));
// Set the content: // Set the content:
request.Content = new StringContent(openAIChatRequest, Encoding.UTF8, "application/json"); request.Content = new StringContent(openAIChatRequest, Encoding.UTF8, "application/json");
return request; return request;
}
// Send the request using exponential backoff:
var responseData = await this.SendRequest(RequestBuilder, token);
if(responseData.IsFailedAfterAllRetries)
{
this.logger.LogError($"OpenAI chat completion failed: {responseData.ErrorMessage}");
yield break;
}
// Open the response stream:
var openAIStream = await responseData.Response!.Content.ReadAsStreamAsync(token);
// Add a stream reader to read the stream, line by line:
streamReader = new StreamReader(openAIStream);
}
catch (Exception e)
{
this.logger.LogError($"Failed to stream chat completion from OpenAI '{this.InstanceName}': {e.Message}");
} }
if (streamReader is null) await foreach (var content in this.StreamChatCompletionInternal<ResponseStreamLine>("OpenAI", RequestBuilder, token))
yield break; yield return content;
// Read the stream, line by line:
while(true)
{
try
{
if(streamReader.EndOfStream)
break;
}
catch (Exception e)
{
this.logger.LogWarning($"Failed to read the end-of-stream state from OpenAI '{this.InstanceName}': {e.Message}");
break;
}
// Check if the token is canceled:
if(token.IsCancellationRequested)
yield break;
// Read the next line:
string? line;
try
{
line = await streamReader.ReadLineAsync(token);
}
catch (Exception e)
{
this.logger.LogError($"Failed to read the stream from OpenAI '{this.InstanceName}': {e.Message}");
break;
}
// Skip empty lines:
if(string.IsNullOrWhiteSpace(line))
continue;
// Skip lines that do not start with "data: ". Regard
// to the specification, we only want to read the data lines:
if(!line.StartsWith("data: ", StringComparison.InvariantCulture))
continue;
// Check if the line is the end of the stream:
if (line.StartsWith("data: [DONE]", StringComparison.InvariantCulture))
yield break;
ResponseStreamLine openAIResponse;
try
{
// We know that the line starts with "data: ". Hence, we can
// skip the first 6 characters to get the JSON data after that.
var jsonData = line[6..];
// Deserialize the JSON data:
openAIResponse = JsonSerializer.Deserialize<ResponseStreamLine>(jsonData, JSON_SERIALIZER_OPTIONS);
}
catch
{
// Skip invalid JSON data:
continue;
}
// Skip empty responses:
if(openAIResponse == default || openAIResponse.Choices.Count == 0)
continue;
// Yield the response:
yield return openAIResponse.Choices[0].Delta.Content;
}
} }
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously

View File

@ -9,7 +9,14 @@ namespace AIStudio.Provider.OpenAI;
/// <param name="Model">The model used for the response.</param> /// <param name="Model">The model used for the response.</param>
/// <param name="SystemFingerprint">The system fingerprint; together with the seed, this allows you to reproduce the response.</param> /// <param name="SystemFingerprint">The system fingerprint; together with the seed, this allows you to reproduce the response.</param>
/// <param name="Choices">The choices made by the AI.</param> /// <param name="Choices">The choices made by the AI.</param>
public readonly record struct ResponseStreamLine(string Id, string Object, uint Created, string Model, string SystemFingerprint, IList<Choice> Choices); public readonly record struct ResponseStreamLine(string Id, string Object, uint Created, string Model, string SystemFingerprint, IList<Choice> Choices) : IResponseStreamLine
{
/// <inheritdoc />
public bool ContainsContent() => this != default && this.Choices.Count > 0;
/// <inheritdoc />
public string GetContent() => this.Choices[0].Delta.Content;
}
/// <summary> /// <summary>
/// Data model for a choice made by the AI. /// Data model for a choice made by the AI.

View File

@ -11,11 +11,6 @@ namespace AIStudio.Provider.SelfHosted;
public sealed class ProviderSelfHosted(ILogger logger, Host host, string hostname) : BaseProvider($"{hostname}{host.BaseURL()}", logger) public sealed class ProviderSelfHosted(ILogger logger, Host host, string hostname) : BaseProvider($"{hostname}{host.BaseURL()}", logger)
{ {
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
};
#region Implementation of IProvider #region Implementation of IProvider
public override string Id => LLMProviders.SELF_HOSTED.ToName(); public override string Id => LLMProviders.SELF_HOSTED.ToName();
@ -67,113 +62,22 @@ public sealed class ProviderSelfHosted(ILogger logger, Host host, string hostnam
MaxTokens = -1, MaxTokens = -1,
}, JSON_SERIALIZER_OPTIONS); }, JSON_SERIALIZER_OPTIONS);
StreamReader? streamReader = null; async Task<HttpRequestMessage> RequestBuilder()
try
{ {
async Task<HttpRequestMessage> RequestBuilder() // Build the HTTP post request:
{ var request = new HttpRequestMessage(HttpMethod.Post, host.ChatURL());
// Build the HTTP post request:
var request = new HttpRequestMessage(HttpMethod.Post, host.ChatURL());
// Set the authorization header: // Set the authorization header:
if (requestedSecret.Success) if (requestedSecret.Success)
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", await requestedSecret.Secret.Decrypt(ENCRYPTION)); request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", await requestedSecret.Secret.Decrypt(ENCRYPTION));
// Set the content: // Set the content:
request.Content = new StringContent(providerChatRequest, Encoding.UTF8, "application/json"); request.Content = new StringContent(providerChatRequest, Encoding.UTF8, "application/json");
return request; return request;
}
// Send the request using exponential backoff:
var responseData = await this.SendRequest(RequestBuilder, token);
if(responseData.IsFailedAfterAllRetries)
{
this.logger.LogError($"Self-hosted provider's chat completion failed: {responseData.ErrorMessage}");
yield break;
}
// Open the response stream:
var providerStream = await responseData.Response!.Content.ReadAsStreamAsync(token);
// Add a stream reader to read the stream, line by line:
streamReader = new StreamReader(providerStream);
}
catch(Exception e)
{
this.logger.LogError($"Failed to stream chat completion from self-hosted provider '{this.InstanceName}': {e.Message}");
}
if (streamReader is null)
yield break;
// Read the stream, line by line:
while (true)
{
try
{
if(streamReader.EndOfStream)
break;
}
catch (Exception e)
{
this.logger.LogWarning($"Failed to read the end-of-stream state from self-hosted provider '{this.InstanceName}': {e.Message}");
break;
}
// Check if the token is canceled:
if (token.IsCancellationRequested)
yield break;
// Read the next line:
string? line;
try
{
line = await streamReader.ReadLineAsync(token);
}
catch (Exception e)
{
this.logger.LogError($"Failed to read the stream from self-hosted provider '{this.InstanceName}': {e.Message}");
break;
}
// Skip empty lines:
if (string.IsNullOrWhiteSpace(line))
continue;
// Skip lines that do not start with "data: ". According
// to the specification, we only want to read the data lines:
if (!line.StartsWith("data: ", StringComparison.InvariantCulture))
continue;
// Check if the line is the end of the stream:
if (line.StartsWith("data: [DONE]", StringComparison.InvariantCulture))
yield break;
ResponseStreamLine providerResponse;
try
{
// We know that the line starts with "data: ". Hence, we can
// skip the first 6 characters to get the JSON data after that.
var jsonData = line[6..];
// Deserialize the JSON data:
providerResponse = JsonSerializer.Deserialize<ResponseStreamLine>(jsonData, JSON_SERIALIZER_OPTIONS);
}
catch
{
// Skip invalid JSON data:
continue;
}
// Skip empty responses:
if (providerResponse == default || providerResponse.Choices.Count == 0)
continue;
// Yield the response:
yield return providerResponse.Choices[0].Delta.Content;
} }
streamReader.Dispose(); await foreach (var content in this.StreamChatCompletionInternal<ResponseStreamLine>("self-hosted provider", RequestBuilder, token))
yield return content;
} }
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously

View File

@ -0,0 +1,3 @@
# v0.9.25, build 200 (2025-01-xx xx:xx UTC)
- Improved the stop generation button behavior to ensure that the AI stops generating content immediately (which will save compute time, energy and financial resources).
- Restructured the streaming network code: it is now centralized in one place instead of being duplicated in each individual provider. This will allow for easier maintenance and updates in the future.