using System.Net.Http.Headers;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using AIStudio.Chat;
using AIStudio.Settings;
using AIStudio.Tools.PluginSystem;
using AIStudio.Tools.Rust;
using AIStudio.Tools.ToolCallingSystem;
using AIStudio.Tools.Services;
using Microsoft.Extensions.DependencyInjection;
namespace AIStudio.Provider.OpenAI;
/// <summary>
/// The OpenAI provider.
/// </summary>
public sealed class ProviderOpenAI() : BaseProvider(LLMProviders.OPEN_AI, "https://api.openai.com/v1/", LOGGER)
{
    private static readonly ILogger LOGGER = Program.LOGGER_FACTORY.CreateLogger();

    /// <summary>
    /// The maximum number of local tool calls that are executed for one chat completion.
    /// The next tool call the model requests after reaching this limit is blocked and ends the tool loop.
    /// </summary>
    private const int MAX_TOOL_CALLS = 10;

    /// <summary>
    /// Looks up the localized text for the given English fallback text via <c>I18N.I.T</c>,
    /// using this provider's namespace and type name as the lookup context.
    /// </summary>
    private static string TB(string fallbackEN) => I18N.I.T(fallbackEN, typeof(ProviderOpenAI).Namespace, nameof(ProviderOpenAI));

    #region Implementation of IProvider

    /// <inheritdoc />
    public override string Id => LLMProviders.OPEN_AI.ToName();

    /// <inheritdoc />
    public override string InstanceName { get; set; } = "OpenAI";

    /// <inheritdoc />
    public override async IAsyncEnumerable StreamChatCompletion(Model chatModel, ChatThread chatThread, SettingsManager settingsManager, [EnumeratorCancellation] CancellationToken token = default)
    {
        // Get the API key:
        var requestedSecret = await RUST_SERVICE.GetAPIKey(this, SecretStoreType.LLM_PROVIDER);
        if(!requestedSecret.Success)
            yield break;

        // Unfortunately, OpenAI changed the name of the system prompt based on the model.
        // All models whose ID starts with "o" (the o-series reasoning models), all models whose
        // ID contains "4o" (GPT-4o), and all GPT-5 models have the system prompt named "developer".
        // All other models have the system prompt named "system".
        //
        // To complicate it even more: The early versions of reasoning models, which were released
        // before the 17th of December 2024, have no system prompt at all. For these models, the
        // system prompt is sent with the "user" role instead.

        // Apply the basic rule first:
        var systemPromptRole =
            chatModel.Id.StartsWith('o') ||
            chatModel.Id.StartsWith("gpt-5", StringComparison.Ordinal) ||
            chatModel.Id.Contains("4o") ? "developer" : "system";

        // Check if the model is an early version of the reasoning models:
        systemPromptRole = chatModel.Id switch
        {
            "o1-mini" => "user",
            "o1-mini-2024-09-12" => "user",
            "o1-preview" => "user",
            "o1-preview-2024-09-12" => "user",

            _ => systemPromptRole,
        };

        // Read the model capabilities:
        var modelCapabilities = this.Provider.GetModelCapabilities(chatModel);

        // Check if we are using the Responses API or the Chat Completion API:
        var usingResponsesAPI = modelCapabilities.Contains(Capability.RESPONSES_API);

        // Prepare the request path based on the API we are using:
        var requestPath = usingResponsesAPI ? "responses" : "chat/completions";

        LOGGER.LogInformation("Using the system prompt role '{SystemPromptRole}' and the '{RequestPath}' API for model '{ChatModelId}'.", systemPromptRole, requestPath, chatModel.Id);

        //
        // Prepare the tools we want to use:
        //
        var providerConfidence = this.Provider.GetConfidence(settingsManager).Level;
        var minimumWebSearchConfidence = settingsManager.GetMinimumProviderConfidenceForTool(ToolSelectionRules.WEB_SEARCH_TOOL_ID);
        var isWebSearchAllowed = ToolSelectionRules.IsProviderConfidenceAllowed(providerConfidence, minimumWebSearchConfidence);
        IList providerTools = modelCapabilities.Contains(Capability.WEB_SEARCH) && isWebSearchAllowed
            ? [ ProviderTools.WEB_SEARCH ]
            : [];

        // Parse the API parameters:
        var apiParameters = this.ParseAdditionalApiParameters("input", "store", "tools");

        //
        // Chat Completion API: delegate completely to the shared OpenAI-compatible streaming implementation.
        //
        if (!usingResponsesAPI)
        {
            await foreach (var content in this.StreamOpenAICompatibleChatCompletion(
                               "OpenAI",
                               chatModel,
                               chatThread,
                               settingsManager,
                               () => chatThread.Blocks.BuildMessagesAsync(
                                   this.Provider,
                                   chatModel,
                                   role => role switch
                                   {
                                       ChatRole.USER => "user",
                                       ChatRole.AI => "assistant",
                                       ChatRole.AGENT => "assistant",
                                       ChatRole.SYSTEM => systemPromptRole,

                                       _ => "user",
                                   },
                                   text => new SubContentText
                                   {
                                       Text = text,
                                   },
                                   async attachment => new SubContentImageUrlNested
                                   {
                                       ImageUrl = new SubContentImageUrlData
                                       {
                                           Url = await attachment.TryAsBase64(token: token) is (true, var base64Content)
                                               ? $"data:{attachment.DetermineMimeType()};base64,{base64Content}"
                                               : string.Empty,
                                       },
                                   }),
                               (systemPrompt, messages, apiParameters, stream, tools) => Task.FromResult(new ChatCompletionAPIRequest
                               {
                                   Model = chatModel.Id,
                                   Messages = [systemPrompt, ..messages],
                                   Stream = stream,
                                   Tools = tools,
                                   ParallelToolCalls = tools is null ? null : true,
                                   AdditionalApiParameters = apiParameters,
                               }),
                               systemPromptRole: systemPromptRole,
                               requestPath: "chat/completions",
                               token: token))
                yield return content;

            yield break;
        }

        //
        // From here on, we are using the Responses API.
        //

        // Prepare the system prompt:
        var systemPrompt = new TextMessage
        {
            Role = systemPromptRole,
            Content = chatThread.PrepareSystemPrompt(settingsManager),
        };

        // Build the list of messages:
        var messages = await chatThread.Blocks.BuildMessagesAsync(
            this.Provider, chatModel,
            role => role switch
            {
                ChatRole.USER => "user",
                ChatRole.AI => "assistant",
                ChatRole.AGENT => "assistant",
                ChatRole.SYSTEM => systemPromptRole,

                _ => "user",
            },
            text => new SubContentInputText
            {
                Text = text,
            },
            async attachment => new SubContentInputImage
            {
                ImageUrl = await attachment.TryAsBase64(token: token) is (true, var base64Content)
                    ? $"data:{attachment.DetermineMimeType()};base64,{base64Content}"
                    : string.Empty,
            });

        var baseInput = new List { systemPrompt };
        baseInput.AddRange(messages.Cast());

        var toolRegistry = Program.SERVICE_PROVIDER.GetService();
        var toolExecutor = Program.SERVICE_PROVIDER.GetService();

        // Reset the tool traces of the last AI block, which receives this completion:
        var currentAssistantContent = chatThread.Blocks.LastOrDefault(x => x.Role is ChatRole.AI)?.Content as ContentText;
        currentAssistantContent?.ToolInvocations.Clear();

        IReadOnlyList<(ToolDefinition Definition, IToolImplementation Implementation)> runnableTools = toolRegistry is null
            ? []
            : await toolRegistry.GetRunnableToolsAsync(
                chatThread.RuntimeComponent,
                chatThread.RuntimeSelectedToolIds,
                modelCapabilities,
                providerConfidence,
                settingsManager.IsToolSelectionVisible(chatThread.RuntimeComponent));

        // Local tools are available and can be executed: run the (non-streaming) tool-calling loop instead:
        if (toolExecutor is not null && runnableTools.Count > 0)
        {
            await foreach (var content in this.StreamResponsesWithLocalTools(
                               chatModel,
                               baseInput,
                               apiParameters,
                               runnableTools,
                               toolExecutor,
                               currentAssistantContent,
                               requestedSecret,
                               providerConfidence,
                               token))
                yield return content;

            yield break;
        }

        // Runnable local tools exist, but no tool executor is registered:
        // send the request without any tools, including the provider tools.
        if (runnableTools.Count > 0)
            providerTools = [];

        //
        // Create the Responses API request:
        //
        var openAIChatRequest = JsonSerializer.Serialize(new ResponsesAPIRequest
        {
            Model = chatModel.Id,

            // All messages go into the input field:
            Input = baseInput,

            // Right now, we only support streaming completions:
            Stream = true,

            // We do not want to store any data on OpenAI's servers:
            Store = false,

            // Tools we want to use:
            Tools = providerTools,

            // Additional API parameters:
            AdditionalApiParameters = apiParameters
        }, JSON_SERIALIZER_OPTIONS);

        async Task RequestBuilder()
        {
            // Build the HTTP post request:
            var request = new HttpRequestMessage(HttpMethod.Post, requestPath);

            // Set the authorization header:
            request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", await requestedSecret.Secret.Decrypt(ENCRYPTION));

            // Set the content:
            request.Content = new StringContent(openAIChatRequest, Encoding.UTF8, "application/json");
            return request;
        }

        await foreach (var content in this.StreamResponsesInternal("OpenAI", RequestBuilder, token))
            yield return content;
    }

    #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
    /// <inheritdoc />
    public override async IAsyncEnumerable StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default)
    {
        yield break;
    }
    #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously

    /// <inheritdoc />
    public override async Task TranscribeAudioAsync(Model transcriptionModel, string audioFilePath, SettingsManager settingsManager, CancellationToken token = default)
    {
        var requestedSecret = await RUST_SERVICE.GetAPIKey(this, SecretStoreType.TRANSCRIPTION_PROVIDER);
        return await this.PerformStandardTranscriptionRequest(requestedSecret, transcriptionModel, audioFilePath, token: token);
    }

    /// <inheritdoc />
    public override async Task>> EmbedTextAsync(Model embeddingModel, SettingsManager settingsManager, CancellationToken token = default, params List texts)
    {
        var requestedSecret = await RUST_SERVICE.GetAPIKey(this, SecretStoreType.EMBEDDING_PROVIDER);
        return await this.PerformStandardTextEmbeddingRequest(requestedSecret, embeddingModel, token: token, texts: texts);
    }

    /// <inheritdoc />
    public override async Task> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
    {
        // NOTE(review): the "o1-" / "o3-" prefixes do not match bare model IDs like "o1" or "o3",
        // so such models are filtered out here — confirm this is intended.
        var models = await this.LoadModels(SecretStoreType.LLM_PROVIDER, ["chatgpt-", "gpt-", "o1-", "o3-", "o4-"], token, apiKeyProvisional);

        // Remove models that are not usable for text chat:
        return models.Where(model => !model.Id.Contains("image", StringComparison.OrdinalIgnoreCase) &&
                                     !model.Id.Contains("realtime", StringComparison.OrdinalIgnoreCase) &&
                                     !model.Id.Contains("audio", StringComparison.OrdinalIgnoreCase) &&
                                     !model.Id.Contains("tts", StringComparison.OrdinalIgnoreCase) &&
                                     !model.Id.Contains("transcribe", StringComparison.OrdinalIgnoreCase));
    }

    /// <inheritdoc />
    public override Task> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default)
    {
        return this.LoadModels(SecretStoreType.IMAGE_PROVIDER, ["dall-e-", "gpt-image"], token, apiKeyProvisional);
    }

    /// <inheritdoc />
    public override Task> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default)
    {
        return this.LoadModels(SecretStoreType.EMBEDDING_PROVIDER, ["text-embedding-"], token, apiKeyProvisional);
    }

    /// <inheritdoc />
    public override async Task> GetTranscriptionModels(string? apiKeyProvisional = null, CancellationToken token = default)
    {
        var models = await this.LoadModels(SecretStoreType.TRANSCRIPTION_PROVIDER, ["whisper-", "gpt-"], token, apiKeyProvisional);

        // Model IDs are identifiers, so compare them ordinally (case-insensitive):
        return models.Where(model => model.Id.StartsWith("whisper-", StringComparison.OrdinalIgnoreCase) ||
                                     model.Id.Contains("-transcribe", StringComparison.OrdinalIgnoreCase));
    }

    #endregion

    /// <summary>
    /// Runs the local tool-calling loop against the Responses API.
    /// </summary>
    /// <remarks>
    /// Each round sends a non-streaming request (<c>Stream = false</c>) that contains the base input plus all
    /// output items and function call outputs collected so far. When the model requests function calls, they
    /// are executed via <paramref name="toolExecutor"/> and their outputs are appended for the next round.
    /// The loop ends when the model answers without function calls (its text output is yielded as one chunk),
    /// when no response is available, or when more than <see cref="MAX_TOOL_CALLS"/> tool calls were requested.
    /// </remarks>
    private async IAsyncEnumerable StreamResponsesWithLocalTools(
        Model chatModel,
        IList baseInput,
        IDictionary apiParameters,
        IReadOnlyList<(ToolDefinition Definition, IToolImplementation Implementation)> runnableTools,
        ToolExecutor toolExecutor,
        ContentText? currentAssistantContent,
        RequestedSecret requestedSecret,
        ConfidenceLevel providerConfidence,
        [EnumeratorCancellation] CancellationToken token)
    {
        var providerTools = runnableTools
            .Select(x => (object)ProviderToolAdapters.ToResponsesTool(x.Definition))
            .ToList();

        // Output items and function call outputs accumulated over all rounds:
        var internalItems = new List();
        var toolCallCount = 0;
        while (true)
        {
            var requestDto = new ResponsesAPIRequest
            {
                Model = chatModel.Id,
                Input = [..baseInput, ..internalItems],
                Stream = false,
                Store = false,
                Tools = providerTools,
                AdditionalApiParameters = apiParameters,
            };

            var response = await this.ExecuteResponsesRequest(requestDto, requestedSecret, token);
            if (response is null)
            {
                // No usable response: stop the tool loop.
                await ResetToolRuntimeStatus(currentAssistantContent);
                yield break;
            }

            var functionCalls = response.GetFunctionCalls();
            if (functionCalls.Count == 0)
            {
                // The model answered without requesting tools, i.e., this is the final answer:
                await ResetToolRuntimeStatus(currentAssistantContent);

                var textOutput = response.GetTextOutput();
                if (!string.IsNullOrWhiteSpace(textOutput))
                    yield return new ContentStreamChunk(textOutput, []);
                else if (toolCallCount > 0)
                    yield return new ContentStreamChunk(TB("The model completed the tool call but did not return a final answer."), []);

                yield break;
            }

            // Show which tools are running right now:
            if (currentAssistantContent is not null)
            {
                currentAssistantContent.ToolRuntimeStatus = new ToolRuntimeStatus
                {
                    IsRunning = true,
                    ToolNames = functionCalls
                        .Select(x => runnableTools.FirstOrDefault(tool => tool.Definition.Function.Name.Equals(x.Name, StringComparison.Ordinal)).Implementation?.GetDisplayName() ?? x.Name)
                        .ToList(),
                };

                await currentAssistantContent.StreamingEvent();
            }

            // The model's output (including the function calls) becomes part of the next round's input:
            foreach (var outputItem in response.Output)
                internalItems.Add(outputItem);

            foreach (var functionCall in functionCalls)
            {
                toolCallCount++;
                if (toolCallCount > MAX_TOOL_CALLS)
                {
                    var limitMessage = string.Format(TB("Tool calling stopped because the maximum of {0} tool calls was reached."), MAX_TOOL_CALLS);
                    currentAssistantContent?.ToolInvocations.Add(new ToolInvocationTrace
                    {
                        Order = toolCallCount,
                        ToolId = functionCall.Name,
                        ToolName = functionCall.Name,
                        ToolCallId = functionCall.CallId,
                        Status = ToolInvocationTraceStatus.BLOCKED,
                        StatusMessage = limitMessage,
                        Result = limitMessage,
                    });

                    await ResetToolRuntimeStatus(currentAssistantContent);
                    yield return new ContentStreamChunk(limitMessage, []);
                    yield break;
                }

                var (toolContent, trace) = await toolExecutor.ExecuteAsync(
                    functionCall.CallId,
                    functionCall.Name,
                    functionCall.Arguments,
                    runnableTools,
                    providerConfidence,
                    toolCallCount,
                    token);

                currentAssistantContent?.ToolInvocations.Add(trace);
                internalItems.Add(new ResponsesFunctionCallOutputItem
                {
                    CallId = functionCall.CallId,
                    Output = toolContent,
                });
            }

            if (currentAssistantContent is not null)
                await currentAssistantContent.StreamingEvent();
        }
    }

    /// <summary>
    /// Sends one non-streaming request to the Responses API endpoint.
    /// </summary>
    /// <returns>
    /// The deserialized response, or null when the server answered with a non-success status code
    /// (the error is logged and reported via the message bus). Transport and deserialization
    /// exceptions are not caught here.
    /// </returns>
    private async Task ExecuteResponsesRequest(ResponsesAPIRequest requestDto, RequestedSecret requestedSecret, CancellationToken token)
    {
        using var request = new HttpRequestMessage(HttpMethod.Post, "responses");
        request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", await requestedSecret.Secret.Decrypt(ENCRYPTION));
        request.Content = new StringContent(JsonSerializer.Serialize(requestDto, JSON_SERIALIZER_OPTIONS), Encoding.UTF8, "application/json");

        using var response = await this.httpClient.SendAsync(request, token);
        if (!response.IsSuccessStatusCode)
        {
            var responseBody = await response.Content.ReadAsStringAsync(token);
            LOGGER.LogError("Tool calling Responses API request failed with status code {ResponseStatusCode} and body: '{ResponseBody}'.", response.StatusCode, responseBody);
            await MessageBus.INSTANCE.SendError(new(
                Icons.Material.Filled.Build,
                string.Format(TB("The tool calling request failed with status code {0}. See the logs for details."), (int)response.StatusCode)));

            return null;
        }

        return await response.Content.ReadFromJsonAsync(JSON_SERIALIZER_OPTIONS, token);
    }

    /// <summary>
    /// Clears the tool runtime status of the given assistant content, if any, and raises its streaming event.
    /// </summary>
    private static async Task ResetToolRuntimeStatus(ContentText? assistantContent)
    {
        if (assistantContent is null)
            return;

        assistantContent.ToolRuntimeStatus = new();
        await assistantContent.StreamingEvent();
    }

    /// <summary>
    /// Loads the model list from the <c>models</c> endpoint and keeps the models whose ID starts
    /// with one of the given prefixes (ordinal, case-sensitive comparison).
    /// </summary>
    /// <param name="storeType">The secret store to read the API key from when no provisional key is given.</param>
    /// <param name="prefixes">The accepted model ID prefixes.</param>
    /// <param name="token">The cancellation token.</param>
    /// <param name="apiKeyProvisional">An API key to use instead of the stored one.</param>
    /// <returns>The matching models; empty when no API key is available or the request fails.</returns>
    private async Task> LoadModels(SecretStoreType storeType, string[] prefixes, CancellationToken token, string? apiKeyProvisional = null)
    {
        var secretKey = apiKeyProvisional switch
        {
            not null => apiKeyProvisional,
            _ => await RUST_SERVICE.GetAPIKey(this, storeType) switch
            {
                { Success: true } result => await result.Secret.Decrypt(ENCRYPTION),
                _ => null,
            }
        };

        if (secretKey is null)
            return [];

        using var request = new HttpRequestMessage(HttpMethod.Get, "models");
        request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", secretKey);

        using var response = await this.httpClient.SendAsync(request, token);
        if(!response.IsSuccessStatusCode)
            return [];

        var modelResponse = await response.Content.ReadFromJsonAsync(token);

        // Model IDs are identifiers, not natural-language text, so compare them ordinally:
        return modelResponse.Data.Where(model => prefixes.Any(prefix => model.Id.StartsWith(prefix, StringComparison.Ordinal)));
    }
}