AI-Studio/app/MindWork AI Studio/Provider/OpenAI/ProviderOpenAI.cs

576 lines
25 KiB
C#
Raw Normal View History

using System.Net;
2024-05-04 09:11:23 +00:00
using System.Net.Http.Headers;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using AIStudio.Chat;
using AIStudio.Settings;
using AIStudio.Tools.PluginSystem;
using AIStudio.Tools.Rust;
using AIStudio.Tools.ToolCallingSystem;
using AIStudio.Tools.Services;
using Microsoft.Extensions.DependencyInjection;
2024-05-04 09:11:23 +00:00
namespace AIStudio.Provider.OpenAI;
/// <summary>
/// The OpenAI provider.
/// </summary>
public sealed class ProviderOpenAI() : BaseProvider(LLMProviders.OPEN_AI, new Uri("https://api.openai.com/v1/"), ExternalHttpTrustPolicy.SYSTEM_TRUST_ONLY, LOGGER)
2024-05-04 09:11:23 +00:00
{
2025-09-03 19:25:17 +00:00
// Logger for this provider type; the same instance is passed to the BaseProvider constructor in the class declaration.
private static readonly ILogger<ProviderOpenAI> LOGGER = Program.LOGGER_FACTORY.CreateLogger<ProviderOpenAI>();
// Shortcut for text lookups: forwards the English fallback text to I18N.I.T together with
// this type's namespace and name.
private static string TB(string fallbackEN) => I18N.I.T(fallbackEN, typeof(ProviderOpenAI).Namespace, nameof(ProviderOpenAI));
2025-09-03 19:25:17 +00:00
2024-05-04 09:11:23 +00:00
#region Implementation of IProvider
/// <inheritdoc />
2024-12-03 14:24:40 +00:00
// The identifier is whatever ToName() yields for the OPEN_AI provider enum member.
public override string Id => LLMProviders.OPEN_AI.ToName();
2024-05-04 09:11:23 +00:00
/// <inheritdoc />
2024-12-03 14:24:40 +00:00
// Settable instance name; starts out as "OpenAI".
public override string InstanceName { get; set; } = "OpenAI";
2024-05-04 09:11:23 +00:00
/// <inheritdoc />
// Always true for this provider; the Get*Models methods of this class load their lists via LoadModels.
public override bool HasModelLoadingCapability => true;
/// <summary>
/// Classifies a failed request by its HTTP status code and response body.
/// A 429 (Too Many Requests) response whose body carries an "insufficient_quota"
/// marker is reported as <see cref="ProviderRequestFailureReason.INSUFFICIENT_QUOTA"/>;
/// every other case is delegated to the base implementation.
/// </summary>
protected override ProviderRequestFailureReason ClassifyProviderRequestFailure(HttpStatusCode statusCode, string responseBody)
{
    // The body is only inspected when the status code is 429.
    var quotaExhausted = statusCode is HttpStatusCode.TooManyRequests && HasInsufficientQuotaError(responseBody);

    return quotaExhausted
        ? ProviderRequestFailureReason.INSUFFICIENT_QUOTA
        : base.ClassifyProviderRequestFailure(statusCode, responseBody);
}
/// <summary>
/// Classifies a failed request by the error details reported by the provider.
/// The failure counts as <see cref="ProviderRequestFailureReason.INSUFFICIENT_QUOTA"/> when the
/// error code, the error type, or any "type"/"code" value inside the response body equals
/// "insufficient_quota" (case-insensitive). Otherwise, the base implementation decides.
/// </summary>
protected override ProviderRequestFailureReason ClassifyProviderRequestFailure(string? errorCode, string? errorType, string? errorMessage, string responseBody)
{
    // Cheap string checks first; the body is only parsed when neither of them matches.
    var quotaExhausted =
        IsInsufficientQuota(errorCode) ||
        IsInsufficientQuota(errorType) ||
        HasInsufficientQuotaError(responseBody);

    return quotaExhausted
        ? ProviderRequestFailureReason.INSUFFICIENT_QUOTA
        : base.ClassifyProviderRequestFailure(errorCode, errorType, errorMessage, responseBody);
}
/// <summary>
/// Returns the user-facing message for a classified request failure.
/// Only the insufficient-quota case gets an OpenAI-specific text; all other
/// reasons use the message of the base implementation.
/// </summary>
protected override string GetProviderRequestFailureUserMessage(ProviderRequestFailureReason failureReason)
{
    if (failureReason is ProviderRequestFailureReason.INSUFFICIENT_QUOTA)
        return TB("It looks like you do not have any API credits left with OpenAI. Please add credits to your account and try again.");

    return base.GetProviderRequestFailureUserMessage(failureReason);
}
2024-05-04 09:11:23 +00:00
/// <inheritdoc />
public override async IAsyncEnumerable<ContentStreamChunk> StreamChatCompletion(Model chatModel, ChatThread chatThread, SettingsManager settingsManager, [EnumeratorCancellation] CancellationToken token = default)
{
    // Get the API key:
    var requestedSecret = await RUST_SERVICE.GetAPIKey(this, SecretStoreType.LLM_PROVIDER);
    if(!requestedSecret.Success)
        yield break;

    // Unfortunately, OpenAI changed the name of the system prompt based on the model.
    // All models that start with "o" (the omni aka reasoning models), all GPT4o models,
    // and all newer models have the system prompt named "developer". All other models
    // have the system prompt named "system". We need to check this to get the correct
    // system prompt.
    //
    // To complicate it even more: The early versions of reasoning models, which were released
    // before the 17th of December 2024, have no system prompt at all. We need to check this
    // as well.

    // Apply the basic rule first:
    var systemPromptRole =
        chatModel.Id.StartsWith('o') ||
        chatModel.Id.StartsWith("gpt-5", StringComparison.Ordinal) ||
        chatModel.Id.Contains("4o") ? "developer" : "system";

    // Check if the model is an early version of the reasoning models:
    systemPromptRole = chatModel.Id switch
    {
        "o1-mini" => "user",
        "o1-mini-2024-09-12" => "user",
        "o1-preview" => "user",
        "o1-preview-2024-09-12" => "user",

        _ => systemPromptRole,
    };

    // Read the model capabilities:
    var modelCapabilities = this.Provider.GetModelCapabilities(chatModel);

    // Check if we are using the Responses API or the Chat Completion API:
    var usingResponsesAPI = modelCapabilities.Contains(Capability.RESPONSES_API);

    // Prepare the request path based on the API we are using:
    var requestPath = usingResponsesAPI ? "responses" : "chat/completions";

    LOGGER.LogInformation("Using the system prompt role '{SystemPromptRole}' and the '{RequestPath}' API for model '{ChatModelId}'.", systemPromptRole, requestPath, chatModel.Id);

    //
    // Prepare the tools we want to use:
    //
    var providerConfidence = this.Provider.GetConfidence(settingsManager).Level;
    var minimumWebSearchConfidence = settingsManager.GetMinimumProviderConfidenceForTool(ToolSelectionRules.WEB_SEARCH_TOOL_ID);
    var isWebSearchAllowed = ToolSelectionRules.IsProviderConfidenceAllowed(providerConfidence, minimumWebSearchConfidence);
    IList<object> providerTools = modelCapabilities.Contains(Capability.WEB_SEARCH) && isWebSearchAllowed
        ? [ ProviderTools.WEB_SEARCH ]
        : [];

    // Parse the API parameters:
    var apiParameters = this.ParseAdditionalApiParameters("input", "store", "tools");

    //
    // Chat Completion API: the whole request is delegated to the shared
    // OpenAI-compatible streaming helper. This branch always ends the method.
    //
    if (!usingResponsesAPI)
    {
        await foreach (var content in this.StreamOpenAICompatibleChatCompletion<ChatCompletionAPIRequest, ChatCompletionDeltaStreamLine, ChatCompletionAnnotationStreamLine>(
                           "OpenAI",
                           chatModel,
                           chatThread,
                           settingsManager,

                           // The helper supplies the system prompt, API parameters, and tools.
                           // The parameter names are deliberately distinct from the locals of
                           // the enclosing method to avoid accidental shadowing.
                           async (chatSystemPrompt, chatApiParameters, tools) =>
                           {
                               var chatMessages = await chatThread.Blocks.BuildMessagesAsync(
                                   this.Provider,
                                   chatModel,
                                   role => role switch
                                   {
                                       ChatRole.USER => "user",
                                       ChatRole.AI => "assistant",
                                       ChatRole.AGENT => "assistant",
                                       ChatRole.SYSTEM => systemPromptRole,

                                       _ => "user",
                                   },
                                   text => new SubContentText
                                   {
                                       Text = text,
                                   },
                                   async attachment => new SubContentImageUrlNested
                                   {
                                       ImageUrl = new SubContentImageUrlData
                                       {
                                           // An attachment that cannot be read as Base64 yields an empty URL:
                                           Url = await attachment.TryAsBase64(token: token) is (true, var base64Content)
                                               ? $"data:{attachment.DetermineMimeType()};base64,{base64Content}"
                                               : string.Empty,
                                       },
                                   });

                               return new ChatCompletionAPIRequest
                               {
                                   Model = chatModel.Id,
                                   Messages = [chatSystemPrompt, ..chatMessages],
                                   Stream = true,
                                   Tools = tools,
                                   ParallelToolCalls = tools is null ? null : true,
                                   AdditionalApiParameters = chatApiParameters,
                               };
                           },
                           systemPromptRole: systemPromptRole,
                           requestPath: requestPath,
                           token: token))
            yield return content;

        yield break;
    }

    //
    // From here on, we are always using the Responses API.
    //

    // Prepare the system prompt:
    var systemPrompt = new TextMessage
    {
        Role = systemPromptRole,
        Content = chatThread.PrepareSystemPrompt(settingsManager),
    };

    // Build the list of messages:
    var messages = await chatThread.Blocks.BuildMessagesAsync(
        this.Provider, chatModel,
        role => role switch
        {
            ChatRole.USER => "user",
            ChatRole.AI => "assistant",
            ChatRole.AGENT => "assistant",
            ChatRole.SYSTEM => systemPromptRole,

            _ => "user",
        },
        text => new SubContentInputText
        {
            Text = text,
        },
        async attachment => new SubContentInputImage
        {
            // An attachment that cannot be read as Base64 yields an empty URL:
            ImageUrl = await attachment.TryAsBase64(token: token) is (true, var base64Content)
                ? $"data:{attachment.DetermineMimeType()};base64,{base64Content}"
                : string.Empty,
        });

    // The system prompt is the first input item, followed by all messages:
    var baseInput = new List<object> { systemPrompt };
    baseInput.AddRange(messages.Cast<object>());

    // Both services are optional; GetService returns null when they are not registered:
    var toolRegistry = Program.SERVICE_PROVIDER.GetService<ToolRegistry>();
    var toolExecutor = Program.SERVICE_PROVIDER.GetService<ToolExecutor>();

    // Drop the tool invocation traces recorded on the most recent AI block:
    var currentAssistantContent = chatThread.Blocks.LastOrDefault(x => x.Role is ChatRole.AI)?.Content as ContentText;
    currentAssistantContent?.ToolInvocations.Clear();

    IReadOnlyList<(ToolDefinition Definition, IToolImplementation Implementation)> runnableTools = toolRegistry is null
        ? []
        : await toolRegistry.GetRunnableToolsAsync(
            chatThread.RuntimeComponent,
            chatThread.RuntimeSelectedToolIds,
            modelCapabilities,
            providerConfidence,
            settingsManager.IsToolSelectionVisible(chatThread.RuntimeComponent));

    // When local tools can be executed, the dedicated tool-calling loop handles the request:
    if (toolExecutor is not null && runnableTools.Count > 0)
    {
        await foreach (var content in this.StreamResponsesWithLocalTools(
                           chatModel,
                           baseInput,
                           apiParameters,
                           runnableTools,
                           toolExecutor,
                           currentAssistantContent,
                           requestedSecret,
                           providerConfidence,
                           token))
            yield return content;

        yield break;
    }

    // Runnable tools exist, but no tool executor is available (see the check above).
    // In that case, no provider tools are sent with the request either:
    if (runnableTools.Count > 0)
        providerTools = [];

    //
    // Create the request for the Responses API:
    //
    var responsesRequestBody = JsonSerializer.Serialize(new ResponsesAPIRequest
    {
        Model = chatModel.Id,

        // All messages go into the input field:
        Input = baseInput,

        // Right now, we only support streaming completions:
        Stream = true,

        // We do not want to store any data on OpenAI's servers:
        Store = false,

        // Tools we want to use:
        Tools = providerTools,

        // Additional API parameters:
        AdditionalApiParameters = apiParameters
    }, JSON_SERIALIZER_OPTIONS);

    async Task<HttpRequestMessage> RequestBuilder()
    {
        // Build the HTTP post request:
        var request = new HttpRequestMessage(HttpMethod.Post, requestPath);

        // Set the authorization header:
        request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", await requestedSecret.Secret.Decrypt(ENCRYPTION));

        // Set the content:
        request.Content = new StringContent(responsesRequestBody, Encoding.UTF8, "application/json");
        return request;
    }

    await foreach (var content in this.StreamResponsesInternal<ResponsesDeltaStreamLine, ResponsesAnnotationStreamLine>("OpenAI", RequestBuilder, token))
        yield return content;
}
/// <summary>
/// Runs the tool-calling loop against the Responses API using non-streaming requests.
/// Each round sends the base input plus all items accumulated so far. When the response
/// contains function calls, they are executed through <paramref name="toolExecutor"/> and their
/// outputs are appended for the next round. When it contains none, its text output (if any)
/// is yielded as a single chunk and the loop ends. At most 10 tool calls are executed.
/// </summary>
/// <param name="chatModel">The model to send the requests to.</param>
/// <param name="baseInput">The input items that are part of every request.</param>
/// <param name="apiParameters">Additional API parameters passed along with every request.</param>
/// <param name="runnableTools">The tools offered to the model, with their implementations.</param>
/// <param name="toolExecutor">Executes the function calls requested by the model.</param>
/// <param name="currentAssistantContent">Optional content that receives runtime status updates and invocation traces.</param>
/// <param name="requestedSecret">The API key used to authorize the requests.</param>
/// <param name="providerConfidence">The provider confidence level handed to the tool executor.</param>
/// <param name="token">The cancellation token.</param>
private async IAsyncEnumerable<ContentStreamChunk> StreamResponsesWithLocalTools(
    Model chatModel,
    IList<object> baseInput,
    IDictionary<string, object> apiParameters,
    IReadOnlyList<(ToolDefinition Definition, IToolImplementation Implementation)> runnableTools,
    ToolExecutor toolExecutor,
    ContentText? currentAssistantContent,
    RequestedSecret requestedSecret,
    ConfidenceLevel providerConfidence,
    [EnumeratorCancellation] CancellationToken token)
{
    // Replaces the runtime status with a fresh instance and raises the streaming event.
    // Does nothing when there is no assistant content to update.
    async Task ResetToolRuntimeStatus()
    {
        if (currentAssistantContent is null)
            return;

        currentAssistantContent.ToolRuntimeStatus = new();
        await currentAssistantContent.StreamingEvent();
    }

    // Convert the tool definitions into the format of the Responses API:
    var responsesTools = new List<object>(runnableTools.Count);
    foreach (var runnableTool in runnableTools)
        responsesTools.Add(ProviderToolAdapters.ToResponsesTool(runnableTool.Definition));

    // Model output items and tool results collected across all rounds:
    var accumulatedItems = new List<object>();
    var executedToolCalls = 0;

    while (true)
    {
        var roundResponse = await this.ExecuteResponsesRequest(new ResponsesAPIRequest
        {
            Model = chatModel.Id,
            Input = [..baseInput, ..accumulatedItems],
            Stream = false,
            Store = false,
            Tools = responsesTools,
            AdditionalApiParameters = apiParameters,
        }, requestedSecret, token);

        // The request failed; ExecuteResponsesRequest has already reported the error:
        if (roundResponse is null)
        {
            await ResetToolRuntimeStatus();
            yield break;
        }

        var requestedCalls = roundResponse.GetFunctionCalls();

        // No further tool calls: this is the final answer of the model.
        if (requestedCalls.Count == 0)
        {
            await ResetToolRuntimeStatus();

            var finalText = roundResponse.GetTextOutput();
            if (!string.IsNullOrWhiteSpace(finalText))
                yield return new ContentStreamChunk(finalText, []);
            else if (executedToolCalls > 0)
                yield return new ContentStreamChunk("The model completed the tool call but did not return a final answer.", []);

            yield break;
        }

        // Publish which tools are about to run. Unknown tool names fall back to the raw name:
        if (currentAssistantContent is not null)
        {
            currentAssistantContent.ToolRuntimeStatus = new ToolRuntimeStatus
            {
                IsRunning = true,
                ToolNames = requestedCalls
                    .Select(call => runnableTools.FirstOrDefault(candidate => candidate.Definition.Function.Name.Equals(call.Name, StringComparison.Ordinal)).Implementation?.GetDisplayName() ?? call.Name)
                    .ToList(),
            };

            await currentAssistantContent.StreamingEvent();
        }

        // The output items of the model are sent back in the next round:
        foreach (var producedItem in roundResponse.Output)
            accumulatedItems.Add(producedItem);

        foreach (var call in requestedCalls)
        {
            executedToolCalls++;

            // Within the limit: execute the tool and record its result.
            if (executedToolCalls <= 10)
            {
                var (toolOutput, invocationTrace) = await toolExecutor.ExecuteAsync(
                    call.CallId,
                    call.Name,
                    call.Arguments,
                    runnableTools,
                    providerConfidence,
                    executedToolCalls,
                    token);

                currentAssistantContent?.ToolInvocations.Add(invocationTrace);
                accumulatedItems.Add(new ResponsesFunctionCallOutputItem
                {
                    CallId = call.CallId,
                    Output = toolOutput,
                });

                continue;
            }

            // Limit exceeded: record the blocked call, reset the status, and stop.
            const string limitMessage = "Tool calling stopped because the maximum of 10 tool calls was reached.";
            currentAssistantContent?.ToolInvocations.Add(new ToolInvocationTrace
            {
                Order = executedToolCalls,
                ToolId = call.Name,
                ToolName = call.Name,
                ToolCallId = call.CallId,
                Status = ToolInvocationTraceStatus.BLOCKED,
                StatusMessage = limitMessage,
                Result = limitMessage,
            });

            await ResetToolRuntimeStatus();

            yield return new ContentStreamChunk(limitMessage, []);
            yield break;
        }

        // Notify listeners once all tool calls of this round are done:
        if (currentAssistantContent is not null)
            await currentAssistantContent.StreamingEvent();
    }
}
/// <summary>
/// Sends a single POST request to the "responses" endpoint and deserializes the result.
/// </summary>
/// <param name="requestDto">The request payload; serialized with the provider's JSON options.</param>
/// <param name="requestedSecret">The API key used for the bearer authorization header.</param>
/// <param name="token">The cancellation token.</param>
/// <returns>
/// The deserialized response, or null when the server answered with a non-success status code.
/// In the latter case, the failure is logged and an error is sent through the message bus.
/// </returns>
private async Task<ResponsesResponse?> ExecuteResponsesRequest(ResponsesAPIRequest requestDto, RequestedSecret requestedSecret, CancellationToken token)
{
    using var httpRequest = new HttpRequestMessage(HttpMethod.Post, "responses");

    // Authorize the request with the decrypted API key:
    var bearerToken = await requestedSecret.Secret.Decrypt(ENCRYPTION);
    httpRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", bearerToken);

    // Attach the JSON payload:
    var payload = JsonSerializer.Serialize(requestDto, JSON_SERIALIZER_OPTIONS);
    httpRequest.Content = new StringContent(payload, Encoding.UTF8, "application/json");

    using var httpResponse = await this.HttpClient.SendAsync(httpRequest, token);

    // Happy path: hand back the parsed response body.
    if (httpResponse.IsSuccessStatusCode)
        return await httpResponse.Content.ReadFromJsonAsync<ResponsesResponse>(JSON_SERIALIZER_OPTIONS, token);

    // Failure: the full body goes into the log, the user only sees the numeric status code.
    var errorBody = await httpResponse.Content.ReadAsStringAsync(token);
    LOGGER.LogError("Tool calling Responses API request failed with status code {ResponseStatusCode} and body: '{ResponseBody}'.", httpResponse.StatusCode, errorBody);

    var userMessage = string.Format(TB("The tool calling request failed with status code {0}. See the logs for details."), (int)httpResponse.StatusCode);
    await MessageBus.INSTANCE.SendError(new(
        Icons.Material.Filled.Build,
        userMessage));

    return null;
}
2024-05-04 09:11:23 +00:00
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
2025-09-03 08:08:04 +00:00
2024-05-04 09:11:23 +00:00
/// <inheritdoc />
2024-12-03 14:24:40 +00:00
// Image generation produces no output in this provider: the stream completes
// immediately without yielding any image URL, and all parameters are ignored.
// NOTE(review): the default of promptNegative is FilterOperator.String.Empty rather than
// string.Empty. Confirm that this is intended; it is left untouched to keep the signature stable.
public override async IAsyncEnumerable<ImageURL> StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default)
{
    yield break;
}
2025-09-03 08:08:04 +00:00
2024-05-04 09:11:23 +00:00
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc />
public override async Task<TranscriptionResult> TranscribeAudioAsync(Model transcriptionModel, string audioFilePath, SettingsManager settingsManager, CancellationToken token = default)
{
    // The API key is requested from the transcription-provider secret store:
    var transcriptionSecret = await RUST_SERVICE.GetAPIKey(this, SecretStoreType.TRANSCRIPTION_PROVIDER);

    // The request itself is handled by the shared implementation of the base provider:
    var transcription = await this.PerformStandardTranscriptionRequest(transcriptionSecret, transcriptionModel, audioFilePath, token: token);
    return transcription;
}
/// <inheritdoc />
public override async Task<IReadOnlyList<IReadOnlyList<float>>> EmbedTextAsync(Model embeddingModel, SettingsManager settingsManager, CancellationToken token = default, params List<string> texts)
{
    // The API key is requested from the embedding-provider secret store:
    var embeddingSecret = await RUST_SERVICE.GetAPIKey(this, SecretStoreType.EMBEDDING_PROVIDER);

    // The request itself is handled by the shared implementation of the base provider:
    var embeddings = await this.PerformStandardTextEmbeddingRequest(embeddingSecret, embeddingModel, token: token, texts: texts);
    return embeddings;
}
2024-05-04 09:11:23 +00:00
/// <inheritdoc />
public override async Task<ModelLoadResult> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
{
    // Models whose ID contains one of these markers (case-insensitive) are not offered as text models:
    string[] nonTextMarkers = ["image", "realtime", "audio", "tts", "transcribe"];

    // Load every model whose ID starts with one of the known text-model prefixes:
    var loaded = await this.LoadModels(SecretStoreType.LLM_PROVIDER, ["chatgpt-", "gpt-", "o1-", "o3-", "o4-"], token, apiKeyProvisional);

    var textModels = loaded.Models.Where(model =>
        !nonTextMarkers.Any(marker => model.Id.Contains(marker, StringComparison.OrdinalIgnoreCase)));

    return loaded with { Models = [..textModels] };
}
/// <inheritdoc />
public override Task<ModelLoadResult> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) =>
    // Image models are identified by the "dall-e-" and "gpt-image" ID prefixes:
    this.LoadModels(SecretStoreType.IMAGE_PROVIDER, ["dall-e-", "gpt-image"], token, apiKeyProvisional);
2024-12-03 14:24:40 +00:00
/// <inheritdoc />
public override Task<ModelLoadResult> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) =>
    // Embedding models are identified by the "text-embedding-" ID prefix:
    this.LoadModels(SecretStoreType.EMBEDDING_PROVIDER, ["text-embedding-"], token, apiKeyProvisional);
2025-05-11 10:51:35 +00:00
2026-01-09 11:45:21 +00:00
/// <inheritdoc />
public override async Task<ModelLoadResult> GetTranscriptionModels(string? apiKeyProvisional = null, CancellationToken token = default)
{
    // Load the "whisper-" and "gpt-" candidates first; the "gpt-" models are narrowed
    // down below to those whose ID contains "-transcribe".
    var result = await this.LoadModels(SecretStoreType.TRANSCRIPTION_PROVIDER, ["whisper-", "gpt-"], token, apiKeyProvisional);
    return result with
    {
        Models =
        [
            // Model IDs are machine identifiers, not linguistic text. Hence, they are compared
            // ordinally (case-insensitive), in line with the filtering done in GetTextModels.
            ..result.Models.Where(model => model.Id.StartsWith("whisper-", StringComparison.OrdinalIgnoreCase) ||
                                           model.Id.Contains("-transcribe", StringComparison.OrdinalIgnoreCase))
        ]
    };
}
2024-05-04 09:11:23 +00:00
#endregion
/// <summary>
/// Loads the models from the "models" endpoint and keeps those whose ID starts with
/// at least one of the given prefixes.
/// </summary>
/// <param name="storeType">The secret store from which the API key is taken.</param>
/// <param name="prefixes">The case-sensitive ID prefixes a model must match to be kept.</param>
/// <param name="token">The cancellation token.</param>
/// <param name="apiKeyProvisional">An optional provisional API key, forwarded to the base implementation.</param>
/// <returns>The result produced by the base provider's model loading.</returns>
private Task<ModelLoadResult> LoadModels(SecretStoreType storeType, string[] prefixes, CancellationToken token, string? apiKeyProvisional = null)
{
    return this.LoadModelsResponse<ModelsResponse>(
        storeType,
        "models",

        // Model IDs are machine identifiers, so the prefix check is ordinal rather than
        // culture-sensitive. It remains case-sensitive, as before.
        modelResponse => modelResponse.Data.Where(model => prefixes.Any(prefix => model.Id.StartsWith(prefix, StringComparison.Ordinal))),
        token,
        apiKeyProvisional);
}
/// <summary>
/// Checks whether a response body, interpreted as JSON, contains an "insufficient_quota" marker.
/// </summary>
/// <param name="responseBody">The raw response body.</param>
/// <returns>
/// True when the parsed JSON contains the marker; false when the body is null, empty,
/// only whitespace, not valid JSON, or does not contain the marker.
/// </returns>
private static bool HasInsufficientQuotaError(string responseBody)
{
    if (!string.IsNullOrWhiteSpace(responseBody))
    {
        try
        {
            using (var parsedBody = JsonDocument.Parse(responseBody))
                return HasInsufficientQuotaError(parsedBody.RootElement);
        }
        catch (JsonException)
        {
            // The body is not valid JSON; treat it as "no quota error found".
        }
    }

    return false;
}
2024-06-03 17:42:53 +00:00
/// <summary>
/// Recursively searches a JSON element for an object whose "type" or "code" property
/// is the string "insufficient_quota" (case-insensitive).
/// </summary>
/// <param name="element">The element to inspect; objects and arrays are searched in depth.</param>
/// <returns>True when such an object is found; false otherwise, including for all scalar values.</returns>
private static bool HasInsufficientQuotaError(JsonElement element)
{
    // Arrays: any matching item is enough.
    if (element.ValueKind is JsonValueKind.Array)
        return element.EnumerateArray().Any(item => HasInsufficientQuotaError(item));

    // Scalars can never carry the marker.
    if (element.ValueKind is not JsonValueKind.Object)
        return false;

    // Objects: check the object's own properties first, then descend into all property values.
    var isQuotaError =
        HasJsonStringValue(element, "type", "insufficient_quota") ||
        HasJsonStringValue(element, "code", "insufficient_quota");

    return isQuotaError || element.EnumerateObject().Any(property => HasInsufficientQuotaError(property.Value));
}
/// <summary>
/// Checks whether the given value equals "insufficient_quota", ignoring case.
/// </summary>
/// <param name="value">The error code or type to check; null never matches.</param>
private static bool IsInsufficientQuota(string? value) =>
    string.Equals(value, "insufficient_quota", StringComparison.OrdinalIgnoreCase);
/// <summary>
/// Checks whether a JSON object has a string property with the expected value, ignoring case.
/// </summary>
/// <param name="element">The JSON object to inspect.</param>
/// <param name="propertyName">The name of the property to look up.</param>
/// <param name="expectedValue">The value the property must have.</param>
/// <returns>
/// True when the property exists, is a JSON string, and matches the expected value;
/// false when it is missing, of another kind, or has a different value.
/// </returns>
private static bool HasJsonStringValue(JsonElement element, string propertyName, string expectedValue)
{
    if (!element.TryGetProperty(propertyName, out var candidate))
        return false;

    // Numbers, booleans, nested objects, etc. never count as a match:
    if (candidate.ValueKind is not JsonValueKind.String)
        return false;

    return string.Equals(candidate.GetString(), expectedValue, StringComparison.OrdinalIgnoreCase);
}
}