Refactored network code (#253)

Thorsten Sommer 2025-01-04 14:11:32 +01:00 committed by GitHub
parent a54c0bdbbf
commit 33a2728644
13 changed files with 224 additions and 740 deletions


@@ -10,11 +10,6 @@ namespace AIStudio.Provider.Anthropic;
public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://api.anthropic.com/v1/", logger)
{
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
};
#region Implementation of IProvider
public override string Id => LLMProviders.ANTHROPIC.ToName();
@@ -60,9 +55,6 @@ public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://ap
Stream = true,
}, JSON_SERIALIZER_OPTIONS);
StreamReader? streamReader = null;
try
{
async Task<HttpRequestMessage> RequestBuilder()
{
// Build the HTTP post request:
@@ -79,98 +71,8 @@ public sealed class ProviderAnthropic(ILogger logger) : BaseProvider("https://ap
return request;
}
await foreach (var content in this.StreamChatCompletionInternal<ResponseStreamLine>("Anthropic", RequestBuilder, token))
yield return content;
// Send the request using exponential backoff:
var responseData = await this.SendRequest(RequestBuilder, token);
if(responseData.IsFailedAfterAllRetries)
{
this.logger.LogError($"Anthropic chat completion failed: {responseData.ErrorMessage}");
yield break;
}
// Open the response stream:
var stream = await responseData.Response!.Content.ReadAsStreamAsync(token);
// Add a stream reader to read the stream, line by line:
streamReader = new StreamReader(stream);
}
catch (Exception e)
{
this.logger.LogError($"Failed to stream chat completion from Anthropic '{this.InstanceName}': {e.Message}");
}
if (streamReader is null)
yield break;
// Read the stream, line by line:
while(true)
{
try
{
if(streamReader.EndOfStream)
break;
}
catch (Exception e)
{
this.logger.LogWarning($"Failed to read the end-of-stream state from Anthropic '{this.InstanceName}': {e.Message}");
break;
}
// Check if the token is canceled:
if(token.IsCancellationRequested)
yield break;
// Read the next line:
string? line;
try
{
line = await streamReader.ReadLineAsync(token);
}
catch (Exception e)
{
this.logger.LogError($"Failed to read the stream from Anthropic '{this.InstanceName}': {e.Message}");
break;
}
// Skip empty lines:
if(string.IsNullOrWhiteSpace(line))
continue;
// Check for the end of the stream:
if(line.StartsWith("event: message_stop", StringComparison.InvariantCulture))
yield break;
// Skip lines that do not start with "data: ". Regard
// to the specification, we only want to read the data lines:
if(!line.StartsWith("data: ", StringComparison.InvariantCulture))
continue;
// Ignore any type except "content_block_delta":
if(!line.Contains("\"content_block_delta\"", StringComparison.InvariantCulture))
continue;
ResponseStreamLine anthropicResponse;
try
{
// We know that the line starts with "data: ". Hence, we can
// skip the first 6 characters to get the JSON data after that.
var jsonData = line[6..];
// Deserialize the JSON data:
anthropicResponse = JsonSerializer.Deserialize<ResponseStreamLine>(jsonData, JSON_SERIALIZER_OPTIONS);
}
catch
{
// Skip invalid JSON data:
continue;
}
// Skip empty responses:
if(anthropicResponse == default || string.IsNullOrWhiteSpace(anthropicResponse.Delta.Text))
continue;
// Yield the response:
yield return anthropicResponse.Delta.Text;
}
}
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously


@@ -7,7 +7,14 @@ namespace AIStudio.Provider.Anthropic;
/// <param name="Type">The type of the response line.</param>
/// <param name="Index">The index of the response line.</param>
/// <param name="Delta">The delta of the response line.</param>
public readonly record struct ResponseStreamLine(string Type, int Index, Delta Delta);
public readonly record struct ResponseStreamLine(string Type, int Index, Delta Delta) : IResponseStreamLine
{
/// <inheritdoc />
public bool ContainsContent() => this != default && !string.IsNullOrWhiteSpace(this.Delta.Text);
/// <inheritdoc />
public string GetContent() => this.Delta.Text;
}
/// <summary>
/// The delta object of a response line.


@@ -1,4 +1,6 @@
using System.Net;
using System.Runtime.CompilerServices;
using System.Text.Json;
using AIStudio.Chat;
using AIStudio.Settings;
@@ -32,6 +34,11 @@ public abstract class BaseProvider : IProvider, ISecretId
protected static readonly Encryption ENCRYPTION;
protected static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
};
/// <summary>
/// Constructor for the base provider.
/// </summary>
@@ -127,4 +134,105 @@ public abstract class BaseProvider : IProvider, ISecretId
return new HttpRateLimitedStreamResult(true, false, string.Empty, response);
}
protected async IAsyncEnumerable<string> StreamChatCompletionInternal<T>(string providerName, Func<Task<HttpRequestMessage>> requestBuilder, [EnumeratorCancellation] CancellationToken token = default) where T : struct, IResponseStreamLine
{
StreamReader? streamReader = null;
try
{
// Send the request using exponential backoff:
var responseData = await this.SendRequest(requestBuilder, token);
if(responseData.IsFailedAfterAllRetries)
{
this.logger.LogError($"The {providerName} chat completion failed: {responseData.ErrorMessage}");
yield break;
}
// Open the response stream:
var providerStream = await responseData.Response!.Content.ReadAsStreamAsync(token);
// Add a stream reader to read the stream, line by line:
streamReader = new StreamReader(providerStream);
}
catch(Exception e)
{
this.logger.LogError($"Failed to stream chat completion from {providerName} '{this.InstanceName}': {e.Message}");
}
if (streamReader is null)
yield break;
// Read the stream, line by line:
while (true)
{
try
{
if(streamReader.EndOfStream)
break;
}
catch (Exception e)
{
this.logger.LogWarning($"Failed to read the end-of-stream state from {providerName} '{this.InstanceName}': {e.Message}");
break;
}
// Check if the token is canceled:
if (token.IsCancellationRequested)
{
this.logger.LogWarning($"The user canceled the chat completion for {providerName} '{this.InstanceName}'.");
streamReader.Close();
yield break;
}
// Read the next line:
string? line;
try
{
line = await streamReader.ReadLineAsync(token);
}
catch (Exception e)
{
this.logger.LogError($"Failed to read the stream from {providerName} '{this.InstanceName}': {e.Message}");
break;
}
// Skip empty lines:
if (string.IsNullOrWhiteSpace(line))
continue;
// Skip lines that do not start with "data: ". According
// to the specification, we only want to read the data lines:
if (!line.StartsWith("data: ", StringComparison.InvariantCulture))
continue;
// Check if the line is the end of the stream:
if (line.StartsWith("data: [DONE]", StringComparison.InvariantCulture))
yield break;
T providerResponse;
try
{
// We know that the line starts with "data: ". Hence, we can
// skip the first 6 characters to get the JSON data after that.
var jsonData = line[6..];
// Deserialize the JSON data:
providerResponse = JsonSerializer.Deserialize<T>(jsonData, JSON_SERIALIZER_OPTIONS);
}
catch
{
// Skip invalid JSON data:
continue;
}
// Skip empty responses:
if (!providerResponse.ContainsContent())
continue;
// Yield the response:
yield return providerResponse.GetContent();
}
streamReader.Dispose();
}
}
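With this helper in place, each provider only supplies its request builder and its response-line type; retries, SSE parsing, cancellation, and disposal all happen in the shared loop above. A minimal sketch of the per-provider pattern, assuming a hypothetical "Example" provider (the endpoint path and the surrounding method signature are placeholders, not part of this commit):

// Sketch only: the per-provider streaming body after the refactoring. Payload
// construction is elided; the delegation mirrors the provider diffs in this commit.
async Task<HttpRequestMessage> RequestBuilder()
{
    // Build the HTTP post request against the provider's chat completion endpoint:
    var request = new HttpRequestMessage(HttpMethod.Post, "chat/completions");
    // ... attach the auth header and the serialized JSON payload here ...
    return request;
}

// Retries, SSE parsing, and cancellation are handled by BaseProvider:
await foreach (var content in this.StreamChatCompletionInternal<ResponseStreamLine>("Example", RequestBuilder, token))
    yield return content;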


@@ -10,11 +10,6 @@ namespace AIStudio.Provider.Fireworks;
public class ProviderFireworks(ILogger logger) : BaseProvider("https://api.fireworks.ai/inference/v1/", logger)
{
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
};
#region Implementation of IProvider
/// <inheritdoc />
@@ -69,9 +64,6 @@ public class ProviderFireworks(ILogger logger) : BaseProvider("https://api.firew
Stream = true,
}, JSON_SERIALIZER_OPTIONS);
StreamReader? streamReader = null;
try
{
async Task<HttpRequestMessage> RequestBuilder()
{
// Build the HTTP post request:
@@ -85,94 +77,8 @@ public class ProviderFireworks(ILogger logger) : BaseProvider("https://api.firew
return request;
}
await foreach (var content in this.StreamChatCompletionInternal<ResponseStreamLine>("Fireworks", RequestBuilder, token))
yield return content;
// Send the request using exponential backoff:
var responseData = await this.SendRequest(RequestBuilder, token);
if(responseData.IsFailedAfterAllRetries)
{
this.logger.LogError($"Fireworks chat completion failed: {responseData.ErrorMessage}");
yield break;
}
// Open the response stream:
var fireworksStream = await responseData.Response!.Content.ReadAsStreamAsync(token);
// Add a stream reader to read the stream, line by line:
streamReader = new StreamReader(fireworksStream);
}
catch (Exception e)
{
this.logger.LogError($"Failed to stream chat completion from Fireworks '{this.InstanceName}': {e.Message}");
}
if (streamReader is null)
yield break;
// Read the stream, line by line:
while(true)
{
try
{
if(streamReader.EndOfStream)
break;
}
catch (Exception e)
{
this.logger.LogWarning($"Failed to read the end-of-stream state from Fireworks '{this.InstanceName}': {e.Message}");
break;
}
// Check if the token is canceled:
if(token.IsCancellationRequested)
yield break;
// Read the next line:
string? line;
try
{
line = await streamReader.ReadLineAsync(token);
}
catch (Exception e)
{
this.logger.LogError($"Failed to read the stream from Fireworks '{this.InstanceName}': {e.Message}");
break;
}
// Skip empty lines:
if(string.IsNullOrWhiteSpace(line))
continue;
// Skip lines that do not start with "data: ". Regard
// to the specification, we only want to read the data lines:
if(!line.StartsWith("data: ", StringComparison.InvariantCulture))
continue;
// Check if the line is the end of the stream:
if (line.StartsWith("data: [DONE]", StringComparison.InvariantCulture))
yield break;
ResponseStreamLine fireworksResponse;
try
{
// We know that the line starts with "data: ". Hence, we can
// skip the first 6 characters to get the JSON data after that.
var jsonData = line[6..];
// Deserialize the JSON data:
fireworksResponse = JsonSerializer.Deserialize<ResponseStreamLine>(jsonData, JSON_SERIALIZER_OPTIONS);
}
catch
{
// Skip invalid JSON data:
continue;
}
// Skip empty responses:
if(fireworksResponse == default || fireworksResponse.Choices.Count == 0)
continue;
// Yield the response:
yield return fireworksResponse.Choices[0].Delta.Content;
}
}
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously


@@ -8,7 +8,14 @@ namespace AIStudio.Provider.Fireworks;
/// <param name="Created">The timestamp of the response.</param>
/// <param name="Model">The model used for the response.</param>
/// <param name="Choices">The choices made by the AI.</param>
public readonly record struct ResponseStreamLine(string Id, string Object, uint Created, string Model, IList<Choice> Choices);
public readonly record struct ResponseStreamLine(string Id, string Object, uint Created, string Model, IList<Choice> Choices) : IResponseStreamLine
{
/// <inheritdoc />
public bool ContainsContent() => this != default && this.Choices.Count > 0;
/// <inheritdoc />
public string GetContent() => this.Choices[0].Delta.Content;
}
/// <summary>
/// Data model for a choice made by the AI.


@@ -11,11 +11,6 @@ namespace AIStudio.Provider.Google;
public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativelanguage.googleapis.com/v1beta/", logger)
{
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
};
#region Implementation of IProvider
/// <inheritdoc />
@@ -70,9 +65,6 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela
Stream = true,
}, JSON_SERIALIZER_OPTIONS);
StreamReader? streamReader = null;
try
{
async Task<HttpRequestMessage> RequestBuilder()
{
// Build the HTTP post request:
@@ -86,94 +78,8 @@ public class ProviderGoogle(ILogger logger) : BaseProvider("https://generativela
return request;
}
await foreach (var content in this.StreamChatCompletionInternal<ResponseStreamLine>("Google", RequestBuilder, token))
yield return content;
// Send the request using exponential backoff:
var responseData = await this.SendRequest(RequestBuilder, token);
if(responseData.IsFailedAfterAllRetries)
{
this.logger.LogError($"Google chat completion failed: {responseData.ErrorMessage}");
yield break;
}
// Open the response stream:
var geminiStream = await responseData.Response!.Content.ReadAsStreamAsync(token);
// Add a stream reader to read the stream, line by line:
streamReader = new StreamReader(geminiStream);
}
catch (Exception e)
{
this.logger.LogError($"Failed to stream chat completion from Google '{this.InstanceName}': {e.Message}");
}
if (streamReader is null)
yield break;
// Read the stream, line by line:
while(true)
{
try
{
if(streamReader.EndOfStream)
break;
}
catch (Exception e)
{
this.logger.LogWarning($"Failed to read the end-of-stream state from Google '{this.InstanceName}': {e.Message}");
break;
}
// Check if the token is canceled:
if(token.IsCancellationRequested)
yield break;
// Read the next line:
string? line;
try
{
line = await streamReader.ReadLineAsync(token);
}
catch (Exception e)
{
this.logger.LogError($"Failed to read the stream from Google '{this.InstanceName}': {e.Message}");
break;
}
// Skip empty lines:
if(string.IsNullOrWhiteSpace(line))
continue;
// Skip lines that do not start with "data: ". Regard
// to the specification, we only want to read the data lines:
if(!line.StartsWith("data: ", StringComparison.InvariantCulture))
continue;
// Check if the line is the end of the stream:
if (line.StartsWith("data: [DONE]", StringComparison.InvariantCulture))
yield break;
ResponseStreamLine geminiResponse;
try
{
// We know that the line starts with "data: ". Hence, we can
// skip the first 6 characters to get the JSON data after that.
var jsonData = line[6..];
// Deserialize the JSON data:
geminiResponse = JsonSerializer.Deserialize<ResponseStreamLine>(jsonData, JSON_SERIALIZER_OPTIONS);
}
catch
{
// Skip invalid JSON data:
continue;
}
// Skip empty responses:
if(geminiResponse == default || geminiResponse.Choices.Count == 0)
continue;
// Yield the response:
yield return geminiResponse.Choices[0].Delta.Content;
}
}
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously


@@ -11,11 +11,6 @@ namespace AIStudio.Provider.Groq;
public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/openai/v1/", logger)
{
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
};
#region Implementation of IProvider
/// <inheritdoc />
@@ -72,9 +67,6 @@ public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/o
Stream = true,
}, JSON_SERIALIZER_OPTIONS);
StreamReader? streamReader = null;
try
{
async Task<HttpRequestMessage> RequestBuilder()
{
// Build the HTTP post request:
@@ -88,94 +80,8 @@ public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/o
return request;
}
await foreach (var content in this.StreamChatCompletionInternal<ResponseStreamLine>("Groq", RequestBuilder, token))
yield return content;
// Send the request using exponential backoff:
var responseData = await this.SendRequest(RequestBuilder, token);
if(responseData.IsFailedAfterAllRetries)
{
this.logger.LogError($"Groq chat completion failed: {responseData.ErrorMessage}");
yield break;
}
// Open the response stream:
var groqStream = await responseData.Response!.Content.ReadAsStreamAsync(token);
// Add a stream reader to read the stream, line by line:
streamReader = new StreamReader(groqStream);
}
catch (Exception e)
{
this.logger.LogError($"Failed to stream chat completion from Groq '{this.InstanceName}': {e.Message}");
}
if (streamReader is null)
yield break;
// Read the stream, line by line:
while(true)
{
try
{
if(streamReader.EndOfStream)
break;
}
catch (Exception e)
{
this.logger.LogWarning($"Failed to read the end-of-stream state from Groq '{this.InstanceName}': {e.Message}");
break;
}
// Check if the token is canceled:
if(token.IsCancellationRequested)
yield break;
// Read the next line:
string? line;
try
{
line = await streamReader.ReadLineAsync(token);
}
catch (Exception e)
{
this.logger.LogError($"Failed to read the stream from Groq '{this.InstanceName}': {e.Message}");
break;
}
// Skip empty lines:
if(string.IsNullOrWhiteSpace(line))
continue;
// Skip lines that do not start with "data: ". Regard
// to the specification, we only want to read the data lines:
if(!line.StartsWith("data: ", StringComparison.InvariantCulture))
continue;
// Check if the line is the end of the stream:
if (line.StartsWith("data: [DONE]", StringComparison.InvariantCulture))
yield break;
ResponseStreamLine groqResponse;
try
{
// We know that the line starts with "data: ". Hence, we can
// skip the first 6 characters to get the JSON data after that.
var jsonData = line[6..];
// Deserialize the JSON data:
groqResponse = JsonSerializer.Deserialize<ResponseStreamLine>(jsonData, JSON_SERIALIZER_OPTIONS);
}
catch
{
// Skip invalid JSON data:
continue;
}
// Skip empty responses:
if(groqResponse == default || groqResponse.Choices.Count == 0)
continue;
// Yield the response:
yield return groqResponse.Choices[0].Delta.Content;
}
}
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously


@@ -0,0 +1,16 @@
namespace AIStudio.Provider;
public interface IResponseStreamLine
{
/// <summary>
/// Checks if the response line contains any content.
/// </summary>
/// <returns>True when the response line contains content, false otherwise.</returns>
public bool ContainsContent();
/// <summary>
/// Gets the content of the response line.
/// </summary>
/// <returns>The content of the response line.</returns>
public string GetContent();
}
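Each provider's response-line record implements this interface over its own wire format, as the diffs in this commit show. A minimal sketch of an OpenAI-style implementation, assuming the existing Choice and Delta records (the record name "ExampleStreamLine" is illustrative, not part of the commit):

// Sketch only: an OpenAI-style stream line exposing its delta text through the interface.
public readonly record struct ExampleStreamLine(string Id, string Object, IList<Choice> Choices) : IResponseStreamLine
{
    // True when the line carries at least one choice:
    public bool ContainsContent() => this != default && this.Choices.Count > 0;

    // The text fragment handed to the chat UI:
    public string GetContent() => this.Choices[0].Delta.Content;
}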


@@ -11,11 +11,6 @@ namespace AIStudio.Provider.Mistral;
public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.mistral.ai/v1/", logger)
{
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
};
#region Implementation of IProvider
public override string Id => LLMProviders.MISTRAL.ToName();
@@ -71,9 +66,6 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.
SafePrompt = false,
}, JSON_SERIALIZER_OPTIONS);
StreamReader? streamReader = null;
try
{
async Task<HttpRequestMessage> RequestBuilder()
{
// Build the HTTP post request:
@@ -87,94 +79,8 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.
return request;
}
await foreach (var content in this.StreamChatCompletionInternal<ResponseStreamLine>("Mistral", RequestBuilder, token))
yield return content;
// Send the request using exponential backoff:
var responseData = await this.SendRequest(RequestBuilder, token);
if(responseData.IsFailedAfterAllRetries)
{
this.logger.LogError($"Mistral chat completion failed: {responseData.ErrorMessage}");
yield break;
}
// Open the response stream:
var mistralStream = await responseData.Response!.Content.ReadAsStreamAsync(token);
// Add a stream reader to read the stream, line by line:
streamReader = new StreamReader(mistralStream);
}
catch (Exception e)
{
this.logger.LogError($"Failed to stream chat completion from Mistral '{this.InstanceName}': {e.Message}");
}
if (streamReader is null)
yield break;
// Read the stream, line by line:
while(true)
{
try
{
if(streamReader.EndOfStream)
break;
}
catch (Exception e)
{
this.logger.LogWarning($"Failed to read the end-of-stream state from Mistral '{this.InstanceName}': {e.Message}");
break;
}
// Check if the token is canceled:
if(token.IsCancellationRequested)
yield break;
// Read the next line:
string? line;
try
{
line = await streamReader.ReadLineAsync(token);
}
catch (Exception e)
{
this.logger.LogError($"Failed to read the stream from Mistral '{this.InstanceName}': {e.Message}");
break;
}
// Skip empty lines:
if(string.IsNullOrWhiteSpace(line))
continue;
// Skip lines that do not start with "data: ". Regard
// to the specification, we only want to read the data lines:
if(!line.StartsWith("data: ", StringComparison.InvariantCulture))
continue;
// Check if the line is the end of the stream:
if (line.StartsWith("data: [DONE]", StringComparison.InvariantCulture))
yield break;
ResponseStreamLine mistralResponse;
try
{
// We know that the line starts with "data: ". Hence, we can
// skip the first 6 characters to get the JSON data after that.
var jsonData = line[6..];
// Deserialize the JSON data:
mistralResponse = JsonSerializer.Deserialize<ResponseStreamLine>(jsonData, JSON_SERIALIZER_OPTIONS);
}
catch
{
// Skip invalid JSON data:
continue;
}
// Skip empty responses:
if(mistralResponse == default || mistralResponse.Choices.Count == 0)
continue;
// Yield the response:
yield return mistralResponse.Choices[0].Delta.Content;
}
}
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously


@@ -13,11 +13,6 @@ namespace AIStudio.Provider.OpenAI;
/// </summary>
public sealed class ProviderOpenAI(ILogger logger) : BaseProvider("https://api.openai.com/v1/", logger)
{
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
};
#region Implementation of IProvider
/// <inheritdoc />
@@ -99,9 +94,6 @@ public sealed class ProviderOpenAI(ILogger logger) : BaseProvider("https://api.o
Stream = true,
}, JSON_SERIALIZER_OPTIONS);
StreamReader? streamReader = null;
try
{
async Task<HttpRequestMessage> RequestBuilder()
{
// Build the HTTP post request:
@@ -115,94 +107,8 @@ public sealed class ProviderOpenAI(ILogger logger) : BaseProvider("https://api.o
return request;
}
await foreach (var content in this.StreamChatCompletionInternal<ResponseStreamLine>("OpenAI", RequestBuilder, token))
yield return content;
// Send the request using exponential backoff:
var responseData = await this.SendRequest(RequestBuilder, token);
if(responseData.IsFailedAfterAllRetries)
{
this.logger.LogError($"OpenAI chat completion failed: {responseData.ErrorMessage}");
yield break;
}
// Open the response stream:
var openAIStream = await responseData.Response!.Content.ReadAsStreamAsync(token);
// Add a stream reader to read the stream, line by line:
streamReader = new StreamReader(openAIStream);
}
catch (Exception e)
{
this.logger.LogError($"Failed to stream chat completion from OpenAI '{this.InstanceName}': {e.Message}");
}
if (streamReader is null)
yield break;
// Read the stream, line by line:
while(true)
{
try
{
if(streamReader.EndOfStream)
break;
}
catch (Exception e)
{
this.logger.LogWarning($"Failed to read the end-of-stream state from OpenAI '{this.InstanceName}': {e.Message}");
break;
}
// Check if the token is canceled:
if(token.IsCancellationRequested)
yield break;
// Read the next line:
string? line;
try
{
line = await streamReader.ReadLineAsync(token);
}
catch (Exception e)
{
this.logger.LogError($"Failed to read the stream from OpenAI '{this.InstanceName}': {e.Message}");
break;
}
// Skip empty lines:
if(string.IsNullOrWhiteSpace(line))
continue;
// Skip lines that do not start with "data: ". Regard
// to the specification, we only want to read the data lines:
if(!line.StartsWith("data: ", StringComparison.InvariantCulture))
continue;
// Check if the line is the end of the stream:
if (line.StartsWith("data: [DONE]", StringComparison.InvariantCulture))
yield break;
ResponseStreamLine openAIResponse;
try
{
// We know that the line starts with "data: ". Hence, we can
// skip the first 6 characters to get the JSON data after that.
var jsonData = line[6..];
// Deserialize the JSON data:
openAIResponse = JsonSerializer.Deserialize<ResponseStreamLine>(jsonData, JSON_SERIALIZER_OPTIONS);
}
catch
{
// Skip invalid JSON data:
continue;
}
// Skip empty responses:
if(openAIResponse == default || openAIResponse.Choices.Count == 0)
continue;
// Yield the response:
yield return openAIResponse.Choices[0].Delta.Content;
}
}
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously


@@ -9,7 +9,14 @@ namespace AIStudio.Provider.OpenAI;
/// <param name="Model">The model used for the response.</param>
/// <param name="SystemFingerprint">The system fingerprint; together with the seed, this allows you to reproduce the response.</param>
/// <param name="Choices">The choices made by the AI.</param>
public readonly record struct ResponseStreamLine(string Id, string Object, uint Created, string Model, string SystemFingerprint, IList<Choice> Choices);
public readonly record struct ResponseStreamLine(string Id, string Object, uint Created, string Model, string SystemFingerprint, IList<Choice> Choices) : IResponseStreamLine
{
/// <inheritdoc />
public bool ContainsContent() => this != default && this.Choices.Count > 0;
/// <inheritdoc />
public string GetContent() => this.Choices[0].Delta.Content;
}
/// <summary>
/// Data model for a choice made by the AI.


@@ -11,11 +11,6 @@ namespace AIStudio.Provider.SelfHosted;
public sealed class ProviderSelfHosted(ILogger logger, Host host, string hostname) : BaseProvider($"{hostname}{host.BaseURL()}", logger)
{
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new()
{
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
};
#region Implementation of IProvider
public override string Id => LLMProviders.SELF_HOSTED.ToName();
@@ -67,9 +62,6 @@ public sealed class ProviderSelfHosted(ILogger logger, Host host, string hostnam
MaxTokens = -1,
}, JSON_SERIALIZER_OPTIONS);
StreamReader? streamReader = null;
try
{
async Task<HttpRequestMessage> RequestBuilder()
{
// Build the HTTP post request:
@@ -84,96 +76,8 @@ public sealed class ProviderSelfHosted(ILogger logger, Host host, string hostnam
return request;
}
await foreach (var content in this.StreamChatCompletionInternal<ResponseStreamLine>("self-hosted provider", RequestBuilder, token))
yield return content;
// Send the request using exponential backoff:
var responseData = await this.SendRequest(RequestBuilder, token);
if(responseData.IsFailedAfterAllRetries)
{
this.logger.LogError($"Self-hosted provider's chat completion failed: {responseData.ErrorMessage}");
yield break;
}
// Open the response stream:
var providerStream = await responseData.Response!.Content.ReadAsStreamAsync(token);
// Add a stream reader to read the stream, line by line:
streamReader = new StreamReader(providerStream);
}
catch(Exception e)
{
this.logger.LogError($"Failed to stream chat completion from self-hosted provider '{this.InstanceName}': {e.Message}");
}
if (streamReader is null)
yield break;
// Read the stream, line by line:
while (true)
{
try
{
if(streamReader.EndOfStream)
break;
}
catch (Exception e)
{
this.logger.LogWarning($"Failed to read the end-of-stream state from self-hosted provider '{this.InstanceName}': {e.Message}");
break;
}
// Check if the token is canceled:
if (token.IsCancellationRequested)
yield break;
// Read the next line:
string? line;
try
{
line = await streamReader.ReadLineAsync(token);
}
catch (Exception e)
{
this.logger.LogError($"Failed to read the stream from self-hosted provider '{this.InstanceName}': {e.Message}");
break;
}
// Skip empty lines:
if (string.IsNullOrWhiteSpace(line))
continue;
// Skip lines that do not start with "data: ". Regard
// to the specification, we only want to read the data lines:
if (!line.StartsWith("data: ", StringComparison.InvariantCulture))
continue;
// Check if the line is the end of the stream:
if (line.StartsWith("data: [DONE]", StringComparison.InvariantCulture))
yield break;
ResponseStreamLine providerResponse;
try
{
// We know that the line starts with "data: ". Hence, we can
// skip the first 6 characters to get the JSON data after that.
var jsonData = line[6..];
// Deserialize the JSON data:
providerResponse = JsonSerializer.Deserialize<ResponseStreamLine>(jsonData, JSON_SERIALIZER_OPTIONS);
}
catch
{
// Skip invalid JSON data:
continue;
}
// Skip empty responses:
if (providerResponse == default || providerResponse.Choices.Count == 0)
continue;
// Yield the response:
yield return providerResponse.Choices[0].Delta.Content;
}
streamReader.Dispose();
}
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously


@@ -0,0 +1,3 @@
# v0.9.25, build 200 (2025-01-xx xx:xx UTC)
- Improved the stop generation button behavior to ensure that the AI stops generating content immediately (which will save compute time, energy and financial resources).
- Restructured the streaming network code so that it is centralized in the base provider instead of duplicated in each individual provider. This will allow for easier maintenance and updates in the future.
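The immediate stop comes from the centralized loop in BaseProvider: cancellation is now checked before each line is read, and the stream reader is closed on the spot instead of after the current response finishes. Excerpt from the new StreamChatCompletionInternal method shown above:

// From the centralized streaming loop in BaseProvider:
if (token.IsCancellationRequested)
{
    this.logger.LogWarning($"The user canceled the chat completion for {providerName} '{this.InstanceName}'.");
    streamReader.Close();
    yield break;
}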