Merge branch '34-choosing-a-model' into 'main'

Resolve "Choosing a model"

Closes #34

See merge request products/mindwork-ai-studio!4
This commit is contained in:
Thorsten 2024-05-19 18:30:46 +00:00
commit 329b05e84a
15 changed files with 135 additions and 46 deletions

View File

@ -102,7 +102,7 @@ public partial class Chat : ComponentBase
// Use the selected provider to get the AI response.
// By awaiting this line, we wait for the entire
// content to be streamed.
await aiText.CreateFromProviderAsync(this.selectedProvider.UsedProvider.CreateProvider(), this.JsRuntime, this.SettingsManager, new Model("gpt-4o"), this.chatThread);
await aiText.CreateFromProviderAsync(this.selectedProvider.UsedProvider.CreateProvider(this.selectedProvider.InstanceName), this.JsRuntime, this.SettingsManager, this.selectedProvider.Model, this.chatThread);
// Disable the stream state:
this.isStreaming = false;

View File

@ -7,20 +7,23 @@
<MudTable Items="@this.SettingsManager.ConfigurationData.Providers">
<ColGroup>
<col style="width: 3em;"/>
<col style="width: 6em;"/>
<col style="width: 12em;"/>
<col style="width: 12em;"/>
<col/>
<col style="width: 20em;"/>
</ColGroup>
<HeaderContent>
<MudTh>#</MudTh>
<MudTh>Instance Name</MudTh>
<MudTh>Provider</MudTh>
<MudTh>Name</MudTh>
<MudTh>Model</MudTh>
<MudTh Style="text-align: left;">Actions</MudTh>
</HeaderContent>
<RowTemplate>
<MudTd></MudTd>
<MudTd>@context.UsedProvider</MudTd>
<MudTd>@context.Num</MudTd>
<MudTd>@context.InstanceName</MudTd>
<MudTd>@context.UsedProvider</MudTd>
<MudTd>@context.Model</MudTd>
<MudTd Style="text-align: left;">
<MudButton Variant="Variant.Filled" Color="Color.Info" StartIcon="@Icons.Material.Filled.Edit" Class="mr-2" OnClick="() => this.EditProvider(context)">
Edit

View File

@ -52,6 +52,8 @@ public partial class Settings : ComponentBase
return;
var addedProvider = (AIStudio.Settings.Provider)dialogResult.Data;
addedProvider = addedProvider with { Num = this.SettingsManager.ConfigurationData.NextProviderNum++ };
this.SettingsManager.ConfigurationData.Providers.Add(addedProvider);
await this.SettingsManager.StoreSettings();
}
@ -60,9 +62,11 @@ public partial class Settings : ComponentBase
{
var dialogParameters = new DialogParameters<ProviderDialog>
{
{ x => x.DataNum, provider.Num },
{ x => x.DataId, provider.Id },
{ x => x.DataInstanceName, provider.InstanceName },
{ x => x.DataProvider, provider.UsedProvider },
{ x => x.DataModel, provider.Model },
{ x => x.IsEditing, true },
};
@ -72,6 +76,12 @@ public partial class Settings : ComponentBase
return;
var editedProvider = (AIStudio.Settings.Provider)dialogResult.Data;
// Set the provider number if it's not set. This is important for providers
// added before we started saving the provider number.
if(editedProvider.Num == 0)
editedProvider = editedProvider with { Num = this.SettingsManager.ConfigurationData.NextProviderNum++ };
this.SettingsManager.ConfigurationData.Providers[this.SettingsManager.ConfigurationData.Providers.IndexOf(provider)] = editedProvider;
await this.SettingsManager.StoreSettings();
}
@ -88,9 +98,7 @@ public partial class Settings : ComponentBase
if (dialogResult.Canceled)
return;
var providerInstance = provider.UsedProvider.CreateProvider();
providerInstance.InstanceName = provider.InstanceName;
var providerInstance = provider.UsedProvider.CreateProvider(provider.InstanceName);
var deleteSecretResponse = await this.SettingsManager.DeleteAPIKey(this.JsRuntime, providerInstance);
if(deleteSecretResponse.Success)
{

View File

@ -20,12 +20,18 @@
<EnableUnsafeUTF7Encoding>false</EnableUnsafeUTF7Encoding> <!-- Remove unsafe UTF7 encoding -->
<JsonSerializerIsReflectionEnabledByDefault>true</JsonSerializerIsReflectionEnabledByDefault> <!-- Enable reflection for JSON serialization -->
<SuppressTrimAnalysisWarnings>true</SuppressTrimAnalysisWarnings> <!-- Suppress trim analysis warnings -->
<NoWarn>IL2026</NoWarn> <!-- Suppress warning IL2026: Usage of methods marked as RequiresUnreferencedCode. None issue here, since we use trim mode partial, though. -->
<!--
IL2026: Usage of methods marked as RequiresUnreferencedCode. None issue here, since we use partial trim mode, though.
CS8974: Converting method group to non-delegate type; Did you intend to invoke the method? We have this issue with MudBlazor validation methods.
-->
<NoWarn>IL2026, CS8974</NoWarn>
</PropertyGroup>
<ItemGroup>
<!-- Remove launchSettings.json from the project -->
<None Remove="Properties\launchSettings.json" />
<None Remove="build.nu" />
</ItemGroup>
<ItemGroup>

View File

@ -1,15 +1,16 @@
using System.Reflection;
using AIStudio;
using AIStudio.Components;
using AIStudio.Settings;
using AIStudio.Tools;
using Microsoft.Extensions.FileProviders;
using MudBlazor;
using MudBlazor.Services;
#if !DEBUG
using System.Reflection;
using Microsoft.Extensions.FileProviders;
#endif
var builder = WebApplication.CreateBuilder();
builder.Services.AddMudServices(config =>
{

View File

@ -53,7 +53,7 @@ public interface IProvider
/// <param name="settings">The settings manager to access the API key.</param>
/// <param name="token">The cancellation token.</param>
/// <returns>The list of text models.</returns>
public Task<IList<Model>> GetTextModels(IJSRuntime jsRuntime, SettingsManager settings, CancellationToken token = default);
public Task<IEnumerable<Model>> GetTextModels(IJSRuntime jsRuntime, SettingsManager settings, CancellationToken token = default);
/// <summary>
/// Load all possible image models that can be used with this provider.
@ -62,5 +62,5 @@ public interface IProvider
/// <param name="settings">The settings manager to access the API key.</param>
/// <param name="token">The cancellation token.</param>
/// <returns>The list of image models.</returns>
public Task<IList<Model>> GetImageModels(IJSRuntime jsRuntime, SettingsManager settings, CancellationToken token = default);
public Task<IEnumerable<Model>> GetImageModels(IJSRuntime jsRuntime, SettingsManager settings, CancellationToken token = default);
}

View File

@ -3,5 +3,5 @@ namespace AIStudio.Provider;
/// <summary>
/// An image URL.
/// </summary>
/// <param name="url">The image URL.</param>
public readonly record struct ImageURL(string url);
/// <param name="URL">The image URL.</param>
public readonly record struct ImageURL(string URL);

View File

@ -4,4 +4,11 @@ namespace AIStudio.Provider;
/// The data model for the model to use.
/// </summary>
/// <param name="Id">The model's ID.</param>
public readonly record struct Model(string Id);
public readonly record struct Model(string Id)
{
#region Overrides of ValueType
public override string ToString() => string.IsNullOrWhiteSpace(this.Id) ? "no model selected" : this.Id;
#endregion
}

View File

@ -17,9 +17,9 @@ public class NoProvider : IProvider
public string InstanceName { get; set; } = "None";
public Task<IList<Model>> GetTextModels(IJSRuntime jsRuntime, SettingsManager settings, CancellationToken token = default) => Task.FromResult<IList<Model>>(new List<Model>());
public Task<IEnumerable<Model>> GetTextModels(IJSRuntime jsRuntime, SettingsManager settings, CancellationToken token = default) => Task.FromResult<IEnumerable<Model>>([]);
public Task<IList<Model>> GetImageModels(IJSRuntime jsRuntime, SettingsManager settings, CancellationToken token = default) => Task.FromResult<IList<Model>>(new List<Model>());
public Task<IEnumerable<Model>> GetImageModels(IJSRuntime jsRuntime, SettingsManager settings, CancellationToken token = default) => Task.FromResult<IEnumerable<Model>>([]);
public async IAsyncEnumerable<string> StreamChatCompletion(IJSRuntime jsRuntime, SettingsManager settings, Model chatModel, ChatThread chatChatThread, [EnumeratorCancellation] CancellationToken token = default)
{

View File

@ -155,34 +155,33 @@ public sealed class ProviderOpenAI() : BaseProvider("https://api.openai.com/v1/"
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc />
public async Task<IList<Model>> GetTextModels(IJSRuntime jsRuntime, SettingsManager settings, CancellationToken token = default)
public Task<IEnumerable<Model>> GetTextModels(IJSRuntime jsRuntime, SettingsManager settings, CancellationToken token = default)
{
return await this.LoadModels(jsRuntime, settings, "gpt-", token);
return this.LoadModels(jsRuntime, settings, "gpt-", token);
}
/// <inheritdoc />
public async Task<IList<Model>> GetImageModels(IJSRuntime jsRuntime, SettingsManager settings, CancellationToken token = default)
public Task<IEnumerable<Model>> GetImageModels(IJSRuntime jsRuntime, SettingsManager settings, CancellationToken token = default)
{
return await this.LoadModels(jsRuntime, settings, "dall-e-", token);
return this.LoadModels(jsRuntime, settings, "dall-e-", token);
}
#endregion
private async Task<IList<Model>> LoadModels(IJSRuntime jsRuntime, SettingsManager settings, string prefix, CancellationToken token)
private async Task<IEnumerable<Model>> LoadModels(IJSRuntime jsRuntime, SettingsManager settings, string prefix, CancellationToken token)
{
var requestedSecret = await settings.GetAPIKey(jsRuntime, this);
if(!requestedSecret.Success)
return new List<Model>();
if (!requestedSecret.Success)
return [];
var request = new HttpRequestMessage(HttpMethod.Get, "models");
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", requestedSecret.Secret);
var emptyList = new List<Model>();
var response = await this.httpClient.SendAsync(request, token);
if(!response.IsSuccessStatusCode)
return emptyList;
return [];
var modelResponse = await response.Content.ReadFromJsonAsync<ModelsResponse>(token);
return modelResponse.Data.Where(n => n.Id.StartsWith(prefix, StringComparison.InvariantCulture)).ToList();
return modelResponse.Data.Where(n => n.Id.StartsWith(prefix, StringComparison.InvariantCulture));
}
}

View File

@ -25,17 +25,19 @@ public static class ExtensionsProvider
{
Providers.OPEN_AI => "OpenAI",
Providers.NONE => "No provider selected",
_ => "Unknown",
};
/// <summary>
/// Creates a new provider instance based on the provider value.
/// </summary>
/// <param name="provider">The provider value.</param>
/// <param name="instanceName">The used instance name.</param>
/// <returns>The provider instance.</returns>
public static IProvider CreateProvider(this Providers provider) => provider switch
public static IProvider CreateProvider(this Providers provider, string instanceName) => provider switch
{
Providers.OPEN_AI => new ProviderOpenAI(),
Providers.OPEN_AI => new ProviderOpenAI { InstanceName = instanceName },
_ => new NoProvider(),
};

View File

@ -14,7 +14,12 @@ public sealed class Data
/// <summary>
/// List of configured providers.
/// </summary>
public List<Provider> Providers { get; init; } = new();
public List<Provider> Providers { get; init; } = [];
/// <summary>
/// The next provider number to use.
/// </summary>
public uint NextProviderNum { get; set; } = 1;
/// <summary>
/// Should we save energy? When true, we will update content streamed

View File

@ -5,10 +5,12 @@ namespace AIStudio.Settings;
/// <summary>
/// Data model for configured providers.
/// </summary>
/// <param name="Num">The provider's number.</param>
/// <param name="Id">The provider's ID.</param>
/// <param name="InstanceName">The provider's instance name. Useful for multiple instances of the same provider, e.g., to distinguish between different OpenAI API keys.</param>
/// <param name="UsedProvider">The provider used.</param>
public readonly record struct Provider(string Id, string InstanceName, Providers UsedProvider)
/// <param name="Model">The LLM model to use for chat.</param>
public readonly record struct Provider(uint Num, string Id, string InstanceName, Providers UsedProvider, Model Model)
{
#region Overrides of ValueType
@ -19,7 +21,7 @@ public readonly record struct Provider(string Id, string InstanceName, Providers
/// <returns>A string that represents the current provider in a human-readable format.</returns>
public override string ToString()
{
return $"{this.InstanceName} ({this.UsedProvider.ToName()})";
return $"{this.InstanceName} ({this.UsedProvider.ToName()}, {this.Model})";
}
#endregion

View File

@ -9,6 +9,7 @@
T="string"
@bind-Text="@this.DataInstanceName"
Label="Instance Name"
Class="mb-3"
Adornment="Adornment.Start"
AdornmentIcon="@Icons.Material.Filled.Lightbulb"
AdornmentColor="Color.Info"
@ -16,7 +17,7 @@
/>
@* ReSharper disable once CSharpWarnings::CS8974 *@
<MudSelect @bind-Value="@this.DataProvider" Label="Provider" OpenIcon="@Icons.Material.Filled.AccountBalance" AdornmentColor="Color.Info" Adornment="Adornment.Start" Validation="@this.ValidatingProvider">
<MudSelect @bind-Value="@this.DataProvider" Label="Provider" Class="mb-3" OpenIcon="@Icons.Material.Filled.AccountBalance" AdornmentColor="Color.Info" Adornment="Adornment.Start" Validation="@this.ValidatingProvider">
@foreach (Providers provider in Enum.GetValues(typeof(Providers)))
{
<MudSelectItem Value="@provider">@provider</MudSelectItem>
@ -28,12 +29,24 @@
T="string"
@bind-Text="@this.dataAPIKey"
Label="API Key"
Class="mb-3"
Adornment="Adornment.Start"
AdornmentIcon="@Icons.Material.Filled.VpnKey"
AdornmentColor="Color.Info"
InputType="InputType.Password"
Validation="@this.ValidatingAPIKey"
/>
<MudStack Row="@true" AlignItems="AlignItems.Center">
<MudButton Disabled="@(!this.CanLoadModels)" Variant="Variant.Filled" Size="Size.Small" StartIcon="@Icons.Material.Filled.Refresh" OnClick="this.ReloadModels">Reload</MudButton>
<MudSelect @bind-Value="@this.DataModel" Label="Model" Class="mb-3" OpenIcon="@Icons.Material.Filled.FaceRetouchingNatural" AdornmentColor="Color.Info" Adornment="Adornment.Start" Validation="@this.ValidatingModel">
@foreach (var model in this.availableModels)
{
<MudSelectItem Value="@model">@model</MudSelectItem>
}
</MudSelect>
</MudStack>
</MudForm>
@if (this.dataIssues.Any())

View File

@ -16,6 +16,12 @@ public partial class ProviderDialog : ComponentBase
{
[CascadingParameter]
private MudDialogInstance MudDialog { get; set; } = null!;
/// <summary>
/// The provider's number in the list.
/// </summary>
[Parameter]
public uint DataNum { get; set; }
/// <summary>
/// The provider's ID.
@ -35,6 +41,12 @@ public partial class ProviderDialog : ComponentBase
[Parameter]
public Providers DataProvider { get; set; } = Providers.NONE;
/// <summary>
/// The LLM model to use, e.g., GPT-4o.
/// </summary>
[Parameter]
public Model DataModel { get; set; }
/// <summary>
/// Should the dialog be in editing mode?
/// </summary>
@ -50,7 +62,7 @@ public partial class ProviderDialog : ComponentBase
/// <summary>
/// The list of used instance names. We need this to check for uniqueness.
/// </summary>
private List<string> usedInstanceNames { get; set; } = [];
private List<string> UsedInstanceNames { get; set; } = [];
private bool dataIsValid;
private string[] dataIssues = [];
@ -60,28 +72,33 @@ public partial class ProviderDialog : ComponentBase
// We get the form reference from Blazor code to validate it manually:
private MudForm form = null!;
private readonly List<Model> availableModels = new();
#region Overrides of ComponentBase
protected override async Task OnInitializedAsync()
{
// Load the used instance names:
this.usedInstanceNames = this.SettingsManager.ConfigurationData.Providers.Select(x => x.InstanceName.ToLowerInvariant()).ToList();
this.UsedInstanceNames = this.SettingsManager.ConfigurationData.Providers.Select(x => x.InstanceName.ToLowerInvariant()).ToList();
// When editing, we need to load the data:
if(this.IsEditing)
{
this.dataEditingPreviousInstanceName = this.DataInstanceName.ToLowerInvariant();
var provider = this.DataProvider.CreateProvider();
var provider = this.DataProvider.CreateProvider(this.DataInstanceName);
if(provider is NoProvider)
return;
provider.InstanceName = this.DataInstanceName;
// Load the API key:
var requestedSecret = await this.SettingsManager.GetAPIKey(this.JsRuntime, provider);
if(requestedSecret.Success)
{
this.dataAPIKey = requestedSecret.Secret;
// Now, we try to load the list of available models:
await this.ReloadModels();
}
else
{
this.dataAPIKeyStorageIssue = $"Failed to load the API key from the operating system. The message was: {requestedSecret.Issue}. You might ignore this message and provide the API key again.";
@ -118,14 +135,15 @@ public partial class ProviderDialog : ComponentBase
// We just return this data to the parent component:
var addedProvider = new Provider
{
Num = this.DataNum,
Id = this.DataId,
InstanceName = this.DataInstanceName,
UsedProvider = this.DataProvider,
Model = this.DataModel,
};
// We need to instantiate the provider to store the API key:
var provider = this.DataProvider.CreateProvider();
provider.InstanceName = this.DataInstanceName;
var provider = this.DataProvider.CreateProvider(this.DataInstanceName);
// Store the API key in the OS secure storage:
var storeResponse = await this.SettingsManager.SetAPIKey(this.JsRuntime, provider, this.dataAPIKey);
@ -147,6 +165,14 @@ public partial class ProviderDialog : ComponentBase
return null;
}
private string? ValidatingModel(Model model)
{
if (model == default)
return "Please select a model.";
return null;
}
[GeneratedRegex("^[a-zA-Z0-9 ]+$")]
private static partial Regex InstanceNameRegex();
@ -170,7 +196,7 @@ public partial class ProviderDialog : ComponentBase
// The instance name must be unique:
var lowerInstanceName = instanceName.ToLowerInvariant();
if (lowerInstanceName != this.dataEditingPreviousInstanceName && this.usedInstanceNames.Contains(lowerInstanceName))
if (lowerInstanceName != this.dataEditingPreviousInstanceName && this.UsedInstanceNames.Contains(lowerInstanceName))
return "The instance name must be unique; the chosen name is already in use.";
return null;
@ -188,4 +214,21 @@ public partial class ProviderDialog : ComponentBase
}
private void Cancel() => this.MudDialog.Cancel();
private bool CanLoadModels => !string.IsNullOrWhiteSpace(this.dataAPIKey) && this.DataProvider != Providers.NONE && !string.IsNullOrWhiteSpace(this.DataInstanceName);
private async Task ReloadModels()
{
var provider = this.DataProvider.CreateProvider(this.DataInstanceName);
if(provider is NoProvider)
return;
var models = await provider.GetTextModels(this.JsRuntime, this.SettingsManager);
// Order descending by ID means that the newest models probably come first:
var orderedModels = models.OrderByDescending(n => n.Id);
this.availableModels.Clear();
this.availableModels.AddRange(orderedModels);
}
}