Added an agent to select the appropriate data sources for any prompt

This commit is contained in:
Thorsten Sommer 2025-02-16 23:18:24 +01:00
parent 1ccf0872bb
commit 171a83ba58
Signed by: tsommer
GPG Key ID: 371BBA77A02C0108
10 changed files with 577 additions and 4 deletions

View File

@ -0,0 +1,416 @@
using System.Text;
using System.Text.Json;
using AIStudio.Chat;
using AIStudio.Provider;
using AIStudio.Settings;
using AIStudio.Settings.DataModel;
using AIStudio.Tools.ERIClient;
using AIStudio.Tools.Services;
namespace AIStudio.Agents;
public sealed class AgentDataSourceSelection (ILogger<AgentDataSourceSelection> logger, ILogger<AgentBase> baseLogger, SettingsManager settingsManager, DataSourceService dataSourceService, ThreadSafeRandom rng) : AgentBase(baseLogger, settingsManager, dataSourceService, rng)
{
// Shared placeholder that ProcessInput returns whenever the input cannot be processed
// (no text content, still streaming, empty text, or no data source list supplied).
// It is created once during static initialization, so its Time is the moment this type
// was initialized, not the moment of the rejected request.
private static readonly ContentBlock EMPTY_BLOCK = new()
{
Content = null,
ContentType = ContentType.NONE,
Role = ChatRole.AGENT,
Time = DateTimeOffset.UtcNow,
};
// Every AI response produced by ProcessInput, in call order; exposed through GetAnswers().
private readonly List<ContentBlock> answers = new();
#region Overrides of AgentBase
/// <inheritdoc />
protected override Type Type
{
    get => Type.SYSTEM;
}

/// <inheritdoc />
public override string Id
{
    get => "Data Source Selection";
}
// The instruction text for the LLM. It describes the selection task and pins the
// response format: a JSON list of objects with "id", "reason", and "confidence",
// or the empty list `[]` when no data source is needed. PerformSelectionAsync
// deserializes the answer into a list of SelectedDataSource.
// NOTE(review): the embedded schema uses the array form of "items"; in JSON Schema
// draft-04 that reads as positional (tuple) validation rather than "every element
// looks like this" -- confirm whether the object form was intended. The text is only
// shown to the LLM as guidance; nothing in this class validates against it.
/// <inheritdoc />
protected override string JobDescription =>
"""
You receive a system and a user prompt, as well as a list of possible data sources as input.
Your task is to select the appropriate data sources for the given task. You may choose none,
one, or multiple sources, depending on what best fits the system and user prompt. You need
to estimate and assess which source, based on its description, might be helpful in
processing the prompts.
Your response is a JSON list in the following format:
```
[
{"id": "The data source ID", "reason": "Why did you choose this source?", "confidence": 0.87},
{"id": "The data source ID", "reason": "Why did you choose this source?", "confidence": 0.54}
]
```
You express your confidence as a floating-point number between 0.0 (maximum uncertainty) and
1.0 (you are absolutely certain that this source is needed).
The JSON schema is:
```
{
"$schema": "http://json-schema.org/draft-04/schema#",
"type": "array",
"items": [
{
"type": "object",
"properties": {
"id": {
"type": "string"
},
"reason": {
"type": "string"
},
"confidence": {
"type": "number"
}
},
"required": [
"id",
"reason",
"confidence"
]
}
]
}
```
When no data source is needed, you return an empty JSON list `[]`. You do not ask any
follow-up questions. You do not address the user. Your response consists solely of
the JSON list.
""";
// Builds the system prompt for one selection request: the job description comes
// first, followed on the next line by the caller-supplied list of available data
// sources (ProcessInput passes the "availableDataSources" entry here).
/// <inheritdoc />
protected override string SystemPrompt(string availableDataSources) => $"""
{this.JobDescription}
{availableDataSources}
""";
/// <inheritdoc />
/// <remarks>
/// Assigned by <see cref="PerformSelectionAsync"/> before the agent is asked for a selection.
/// </remarks>
public override Settings.Provider? ProviderSettings { get; set; }
/// <summary>
/// Context processing is a no-op for the data source selection agent;
/// callers are expected to go through the process input method instead.
/// </summary>
/// <returns>The very same chat thread, untouched.</returns>
public override Task<ChatThread> ProcessContext(ChatThread chatThread, IDictionary<string, string> additionalData)
{
    return Task.FromResult(chatThread);
}
/// <inheritdoc />
public override async Task<ContentBlock> ProcessInput(ContentBlock input, IDictionary<string, string> additionalData)
{
    // Only finished, non-empty text content can be handed to the LLM:
    if (input.Content is not ContentText userText
        || userText.InitialRemoteWait
        || userText.IsStreaming
        || string.IsNullOrWhiteSpace(userText.Text))
        return EMPTY_BLOCK;

    // Without a list of data sources there is nothing to choose from:
    var hasSourceList = additionalData.TryGetValue("availableDataSources", out var sourceList);
    if (!hasSourceList || string.IsNullOrWhiteSpace(sourceList))
        return EMPTY_BLOCK;

    // Run a fresh, single-turn conversation: system prompt, user request, AI response.
    var systemPrompt = this.SystemPrompt(sourceList);
    var agentThread = this.CreateChatThread(systemPrompt);
    var requestTime = this.AddUserRequest(agentThread, userText.Text);
    await this.AddAIResponseAsync(agentThread, requestTime);

    // The AI response is the last block of the thread; remember it for GetAnswers():
    var response = agentThread.Blocks[^1];
    this.answers.Add(response);
    return response;
}
/// <inheritdoc />
public override Task<bool> MadeDecision(ContentBlock input)
{
    return Task.FromResult(true);
}

/// <inheritdoc />
public override IReadOnlyCollection<ContentBlock> GetContext()
{
    return [];
}

/// <inheritdoc />
public override IReadOnlyCollection<ContentBlock> GetAnswers()
{
    return this.answers;
}
#endregion
/// <summary>
/// Lets the LLM choose which of the allowed data sources fit the current system and user prompt.
/// </summary>
/// <remarks>
/// The method (1) picks the provider for the agent, (2) turns the last prompt into text,
/// (3) describes the allowed data sources, (4) asks the agent, and (5) parses its JSON answer.
/// Every failure along the way is logged and results in an empty list; the IDs in the result
/// come straight from the LLM and are not checked against the allowed data sources here.
/// </remarks>
/// <param name="provider">The provider currently selected by the user. It is used unless a preselected agent provider is configured and found.</param>
/// <param name="lastPrompt">The latest prompt. Text is used as-is, images as Base64; any other content type counts as empty.</param>
/// <param name="chatThread">The chat thread whose system prompt is shown to the agent.</param>
/// <param name="dataSources">The data sources the agent may choose from.</param>
/// <param name="token">Forwarded to the image conversion and to the ERI server calls. The LLM request itself does not receive it.</param>
/// <returns>The data sources selected by the agent, or an empty list when nothing could be selected.</returns>
public async Task<List<SelectedDataSource>> PerformSelectionAsync(IProvider provider, IContent lastPrompt, ChatThread chatThread, AllowedSelectedDataSources dataSources, CancellationToken token = default)
{
    logger.LogInformation("The AI should select the appropriate data sources.");

    //
    // 1. Which LLM provider should the agent use?
    //

    // We start with the provider currently selected by the user:
    var agentProvider = this.SettingsManager.ConfigurationData.Providers.FirstOrDefault(x => x.Id == provider.Id);

    // If the user preselected an agent provider, we try to use this one:
    if (this.SettingsManager.ConfigurationData.AgentDataSourceSelection.PreselectAgentOptions)
    {
        var configuredAgentProvider = this.SettingsManager.ConfigurationData.Providers.FirstOrDefault(x => x.Id == this.SettingsManager.ConfigurationData.AgentDataSourceSelection.PreselectedAgentProvider);

        // If the configured agent provider is available, we use it:
        if (configuredAgentProvider != default)
            agentProvider = configuredAgentProvider;
    }

    // Assign the provider settings to the agent:
    logger.LogInformation($"The agent for the data source selection uses the provider '{agentProvider.InstanceName}' ({agentProvider.UsedLLMProvider.ToName()}, confidence={agentProvider.UsedLLMProvider.GetConfidence(this.SettingsManager).Level.GetName()}).");
    this.ProviderSettings = agentProvider;

    //
    // 2. Prepare the current system and user prompts as input for the agent:
    //
    var lastPromptContent = lastPrompt switch
    {
        ContentText text => text.Text,

        // Image prompts may be empty, e.g., when the image is too large:
        ContentImage image => await image.AsBase64(token),

        // Other content types are not supported yet:
        _ => string.Empty,
    };

    if (string.IsNullOrWhiteSpace(lastPromptContent))
    {
        logger.LogWarning("The last prompt is empty. The AI cannot select data sources.");
        return [];
    }

    //
    // 3. Prepare the allowed data sources as input for the agent:
    //
    logger.LogInformation("Preparing the list of allowed data sources for the agent to choose from.");
    var availableDataSources = await this.DescribeDataSourcesAsync(dataSources, token);

    // Without a single listed data source, any ID the LLM returns would be made up.
    // We skip the LLM call in this case:
    if (string.IsNullOrWhiteSpace(availableDataSources))
    {
        logger.LogWarning("There are no data sources the agent could choose from. The AI is not asked for a selection.");
        return [];
    }

    logger.LogInformation("Prepared the list of allowed data sources for the agent.");
    var additionalData = new Dictionary<string, string>
    {
        ["availableDataSources"] = availableDataSources,
    };

    //
    // 4. Let the agent select the data sources:
    //
    var prompt = $"""
        The system prompt is:
        ```
        {chatThread.SystemPrompt}
        ```
        The user prompt is:
        ```
        {lastPromptContent}
        ```
        """;

    // Call the agent:
    var aiResponse = await this.ProcessInput(new ContentBlock
    {
        Time = DateTimeOffset.UtcNow,
        ContentType = ContentType.TEXT,
        Role = ChatRole.USER,
        Content = new ContentText
        {
            Text = prompt,
        },
    }, additionalData);

    //
    // 5. Parse the agent response:
    //
    return this.ParseSelection(aiResponse);
}

/// <summary>
/// Builds the textual list of data sources that is shown to the agent.
/// </summary>
/// <param name="dataSources">The data sources the agent may choose from.</param>
/// <param name="token">Forwarded to the ERI server calls.</param>
/// <returns>The list with one line per data source, or an empty string when no data source could be listed.</returns>
private async Task<string> DescribeDataSourcesAsync(AllowedSelectedDataSources dataSources, CancellationToken token)
{
    // Notice: We do not dispose the Rust service here. The Rust service is a singleton
    // and will be disposed when the application shuts down:
    var rustService = Program.SERVICE_PROVIDER.GetService<RustService>()!;

    var sb = new StringBuilder();
    sb.AppendLine("The following data sources are available for selection:");

    // Data source types that are not handled below are not listed, hence the separate counter:
    var numListedDataSources = 0;
    foreach (var ds in dataSources.AllowedDataSources)
    {
        switch (ds)
        {
            case DataSourceLocalDirectory localDirectory:
                sb.AppendLine($"- Id={ds.Id}, name='{localDirectory.Name}', type=local directory, path='{localDirectory.Path}'");
                numListedDataSources++;
                break;

            case DataSourceLocalFile localFile:
                sb.AppendLine($"- Id={ds.Id}, name='{localFile.Name}', type=local file, path='{localFile.FilePath}'");
                numListedDataSources++;
                break;

            case IERIDataSource eriDataSource:
                var eriServerDescription = await this.GetERIServerDescriptionAsync(eriDataSource, rustService, token);

                //
                // Append the ERI data source to the list. Use the server description if available:
                //
                if (string.IsNullOrWhiteSpace(eriServerDescription))
                    sb.AppendLine($"- Id={ds.Id}, name='{eriDataSource.Name}', type=external data source");
                else
                    sb.AppendLine($"- Id={ds.Id}, name='{eriDataSource.Name}', type=external data source, description='{eriServerDescription}'");

                numListedDataSources++;
                break;
        }
    }

    return numListedDataSources > 0 ? sb.ToString() : string.Empty;
}

/// <summary>
/// Asks an ERI server for its description. This is best effort: any failure is logged and yields an empty string.
/// </summary>
/// <param name="eriDataSource">The ERI data source to contact.</param>
/// <param name="rustService">The Rust service that is handed to the ERI client for authentication.</param>
/// <param name="token">Forwarded to the ERI server calls.</param>
/// <returns>The server description on a single line, or an empty string when it is not available.</returns>
private async Task<string> GetERIServerDescriptionAsync(IERIDataSource eriDataSource, RustService rustService, CancellationToken token)
{
    try
    {
        //
        // Call the ERI server to get the server description:
        //
        using var eriClient = ERIClientFactory.Get(eriDataSource.Version, eriDataSource)!;
        var authResponse = await eriClient.AuthenticateAsync(eriDataSource, rustService, token);
        if (!authResponse.Successful)
        {
            logger.LogWarning($"Was not able to authenticate with the ERI data source '{eriDataSource.Name}'. Message: {authResponse.Message}");
            return string.Empty;
        }

        var serverDescriptionResponse = await eriClient.GetDataSourceInfoAsync(token);
        if (!serverDescriptionResponse.Successful)
        {
            logger.LogWarning($"Was not able to retrieve the server description from the ERI data source '{eriDataSource.Name}'. Message: {serverDescriptionResponse.Message}");
            return string.Empty;
        }

        // Remove all line breaks from the description, so that it fits into one list entry:
        return serverDescriptionResponse.Data.Description.Replace("\n", " ").Replace("\r", " ");
    }
    catch (Exception e)
    {
        logger.LogWarning($"The ERI data source '{eriDataSource.Name}' is not available. Thus, we cannot retrieve the server description. Error: {e.Message}");
        return string.Empty;
    }
}

/// <summary>
/// Turns the agent's answer into the list of selected data sources.
/// </summary>
/// <param name="aiResponse">The content block returned by the agent.</param>
/// <returns>The parsed selection, or an empty list when the answer is missing or cannot be parsed.</returns>
private List<SelectedDataSource> ParseSelection(ContentBlock aiResponse)
{
    if (aiResponse.Content is null)
    {
        logger.LogWarning("The agent did not return a response.");
        return [];
    }

    switch (aiResponse)
    {
        case { ContentType: ContentType.TEXT, Content: ContentText textContent }:
        {
            //
            // What we expect is a JSON list of SelectedDataSource objects.
            // We know how bad LLM may be in generating JSON without surrounding text.
            // Thus, we expect the worst and try to extract the JSON list from the text:
            //
            var json = this.ExtractJson(textContent.Text);

            // Nothing that looks like a JSON list was found; no need to run the parser:
            if (string.IsNullOrWhiteSpace(json))
            {
                logger.LogWarning("The agent answered with an invalid or unexpected JSON format.");
                return [];
            }

            try
            {
                var aiSelectedDataSources = JsonSerializer.Deserialize<List<SelectedDataSource>>(json, JSON_SERIALIZER_OPTIONS);
                return aiSelectedDataSources ?? [];
            }
            catch (Exception e)
            {
                // Keep the exception in the log, so that a broken answer can be diagnosed:
                logger.LogWarning(e, "The agent answered with an invalid or unexpected JSON format.");
                return [];
            }
        }

        case { ContentType: ContentType.TEXT }:
            logger.LogWarning("The agent answered with an unexpected inner content type.");
            return [];

        case { ContentType: ContentType.NONE }:
            logger.LogWarning("The agent did not return a response.");
            return [];

        default:
            logger.LogWarning($"The agent answered with an unexpected content type '{aiResponse.ContentType}'.");
            return [];
    }
}
/// <summary>
/// Extracts the JSON list from the given text. The text may contain additional
/// information around the JSON list. The method tries to extract the JSON list
/// from the text.
/// </summary>
/// <remarks>
/// Algorithm: The method searches for the first line that contains only a '[' character.
/// Then, it searches for the first following line that contains only a ']' character and
/// returns the text between these two lines (including the brackets). When no such pair
/// of lines exists, e.g., for `[]` or a list written on a single line, the method falls
/// back to the text from the first '[' to the last ']'. When the method cannot find the
/// JSON list at all, it returns an empty string.
/// <br/><br/>
/// This overload is using strings instead of spans. We can use this overload in any
/// async method. Thus, it is a wrapper around the span-based method. Yes, we are losing
/// the memory efficiency of the span-based method, but we still gain the performance
/// of the span-based method: the entire search algorithm is span-based.
/// </remarks>
/// <param name="text">The text that may contain the JSON list.</param>
/// <returns>The extracted JSON list.</returns>
private string ExtractJson(string text) => ExtractJson(text.AsSpan()).ToString();

/// <summary>
/// Extracts the JSON list from the given text. The text may contain additional
/// information around the JSON list. The method tries to extract the JSON list
/// from the text.
/// </summary>
/// <remarks>
/// Algorithm: The method searches for the first line that contains only a '[' character.
/// Then, it searches for the first following line that contains only a ']' character and
/// returns the text between these two lines (including the brackets). When no such pair
/// of lines exists, e.g., for `[]` or a list written on a single line, the method falls
/// back to the text from the first '[' to the last ']'. When the method cannot find the
/// JSON list at all, it returns an empty span.
/// </remarks>
/// <param name="text">The text that may contain the JSON list.</param>
/// <returns>The extracted JSON list.</returns>
private static ReadOnlySpan<char> ExtractJson(ReadOnlySpan<char> text)
{
    var startIndex = -1;
    var lineStart = 0;

    // The loop runs up to and including text.Length, so that a last line
    // without a trailing line break is handled as well:
    for (var i = 0; i <= text.Length; i++)
    {
        if (i < text.Length && text[i] != '\n')
            continue;

        if (startIndex < 0)
        {
            if (IsCharacterAloneInLine(text, lineStart, i, '['))
                startIndex = lineStart;
        }
        else if (IsCharacterAloneInLine(text, lineStart, i, ']'))
        {
            // Include the line break after the ']' when there is one, but stay within bounds:
            var endExclusive = Math.Min(text.Length, i + 1);
            return text[startIndex..endExclusive];
        }

        lineStart = i + 1;
    }

    //
    // Fallback for lists that do not have their brackets on separate lines, such
    // as the empty list `[]`. When we found the opening line, we keep it as the
    // start, so that brackets in the text before it are ignored:
    //
    var candidate = startIndex >= 0 ? text[startIndex..] : text;
    var firstBracket = candidate.IndexOf('[');
    var lastBracket = candidate.LastIndexOf(']');
    if (firstBracket >= 0 && lastBracket > firstBracket)
        return candidate[firstBracket..(lastBracket + 1)];

    return ReadOnlySpan<char>.Empty;
}

/// <summary>
/// Checks whether a line consists of exactly one occurrence of the given character,
/// optionally surrounded by whitespace.
/// </summary>
/// <param name="text">The text that contains the line.</param>
/// <param name="lineStart">The index of the first character of the line.</param>
/// <param name="lineEnd">The index right behind the last character of the line.</param>
/// <param name="character">The character that must be alone in the line.</param>
/// <returns>True, when the character is the only non-whitespace content of the line. Empty and whitespace-only lines yield false.</returns>
private static bool IsCharacterAloneInLine(ReadOnlySpan<char> text, int lineStart, int lineEnd, char character)
{
    var occurrences = 0;
    for (var i = lineStart; i < lineEnd; i++)
    {
        var current = text[i];
        if (current == character)
        {
            occurrences++;
            if (occurrences > 1)
                return false;
        }
        else if (!char.IsWhiteSpace(current))
            return false;
    }

    // A blank line does not contain the character at all:
    return occurrences == 1;
}
}

View File

@ -0,0 +1,9 @@
namespace AIStudio.Agents;
/// <summary>
/// Represents a selected data source, chosen by the agent.
/// </summary>
/// <remarks>
/// The data source selection agent deserializes the LLM's JSON answer into a list of
/// these values; each JSON object carries an "id", a "reason", and a "confidence".
/// All values come from the LLM and are not validated by this type.
/// </remarks>
/// <param name="Id">The data source ID as reported by the agent. It may not match any existing data source.</param>
/// <param name="Reason">The reason for selecting the data source.</param>
/// <param name="Confidence">The confidence of the agent in the selection. The agent is instructed to answer between 0.0 (maximum uncertainty) and 1.0 (absolutely certain).</param>
public readonly record struct SelectedDataSource(string Id, string Reason, float Confidence);

View File

@ -1,5 +1,6 @@
using System.Text.Json.Serialization;
using AIStudio.Agents;
using AIStudio.Provider;
using AIStudio.Settings;
using AIStudio.Tools.Services;
@ -41,11 +42,19 @@ public sealed class ContentText : IContent
if(chatThread is null)
return;
var logger = Program.SERVICE_PROVIDER.GetService<ILogger<ContentText>>()!;
//
// Check if the user wants to bind any data sources to the chat:
// 1. Check if the user wants to bind any data sources to the chat:
//
if (chatThread.DataSourceOptions.IsEnabled())
if (chatThread.DataSourceOptions.IsEnabled() && lastPrompt is not null)
{
logger.LogInformation("Data sources are enabled for this chat.");
// Across the different code-branches, we keep track of whether it
// makes sense to proceed with the RAG process:
var proceedWithRAG = true;
//
// When the user wants to bind data sources to the chat, we
// have to check if the data sources are available for the
@ -61,16 +70,122 @@ public sealed class ContentText : IContent
//
if (chatThread.DataSourceOptions.AutomaticDataSourceSelection)
{
// TODO: Start agent based on allowed data sources.
// Get the agent for the data source selection:
var selectionAgent = Program.SERVICE_PROVIDER.GetService<AgentDataSourceSelection>()!;
// Let the AI agent do its work:
var aiSelectedDataSources = await selectionAgent.PerformSelectionAsync(provider, lastPrompt, chatThread, dataSources, token);
// Check if the AI selected any data sources:
if(aiSelectedDataSources.Count is 0)
{
logger.LogWarning("The AI did not select any data sources. The RAG process is skipped.");
proceedWithRAG = false;
}
else
{
// Log the selected data sources:
var selectedDataSourceInfo = aiSelectedDataSources.Select(ds => $"[Id={ds.Id}, reason={ds.Reason}, confidence={ds.Confidence}]").Aggregate((a, b) => $"'{a}', '{b}'");
logger.LogInformation($"The AI selected the data sources automatically. {aiSelectedDataSources.Count} data source(s) are selected: {selectedDataSourceInfo}.");
//
// Check how many data sources were hallucinated by the AI:
//
var totalAISelectedDataSources = aiSelectedDataSources.Count;
// Filter out the data sources that are not available:
aiSelectedDataSources = aiSelectedDataSources.Where(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id) is not null).ToList();
var numHallucinatedSources = totalAISelectedDataSources - aiSelectedDataSources.Count;
if(numHallucinatedSources > 0)
logger.LogWarning($"The AI hallucinated {numHallucinatedSources} data source(s). We ignore them.");
if (aiSelectedDataSources.Count > 3)
{
//
// We have more than 3 data sources. Let's filter by confidence.
// In order to do that, we must identify the lower and upper
// bounds of the confidence interval:
//
var confidenceValues = aiSelectedDataSources.Select(x => x.Confidence).ToList();
var lowerBound = confidenceValues.Min();
var upperBound = confidenceValues.Max();
//
// Next, we search for a threshold so that we have between 2 and 3
// data sources. When not possible, we take all data sources.
//
var threshold = 0.0f;
// Check the case where the confidence values are too close:
if (upperBound - lowerBound >= 0.01)
{
var previousThreshold = 0.0f;
for (var i = 0; i < 10; i++)
{
threshold = lowerBound + (upperBound - lowerBound) * i / 10;
var numMatches = aiSelectedDataSources.Count(x => x.Confidence >= threshold);
if (numMatches <= 1)
{
threshold = previousThreshold;
break;
}
if (numMatches is <= 3 and >= 2)
break;
previousThreshold = threshold;
}
}
//
// Filter the data sources by the threshold:
//
aiSelectedDataSources = aiSelectedDataSources.Where(x => x.Confidence >= threshold).ToList();
logger.LogInformation($"The AI selected {aiSelectedDataSources.Count} data source(s) with a confidence of at least {threshold}.");
// Transform the final data sources to the actual data sources:
selectedDataSources = aiSelectedDataSources.Select(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id)).Where(ds => ds is not null).ToList()!;
}
// We have max. 3 data sources. We take all of them:
else
{
// Transform the selected data sources to the actual data sources.
selectedDataSources = aiSelectedDataSources.Select(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id)).Where(ds => ds is not null).ToList()!;
}
}
}
else
{
//
// No, the user made the choice manually:
//
var selectedDataSourceInfo = selectedDataSources.Select(ds => ds.Name).Aggregate((a, b) => $"'{a}', '{b}'");
logger.LogInformation($"The user selected the data sources manually. {selectedDataSources.Count} data source(s) are selected: {selectedDataSourceInfo}.");
}
if(selectedDataSources.Count == 0)
{
logger.LogWarning("No data sources are selected. The RAG process is skipped.");
proceedWithRAG = false;
}
//
// Trigger the retrieval part of the (R)AG process:
//
if (proceedWithRAG)
{
}
//
// Perform the augmentation of the R(A)G process:
//
if (proceedWithRAG)
{
}
}
// Store the last time we got a response. We use this later

View File

@ -0,0 +1,11 @@
@inherits SettingsPanelBase
@* Settings panel for the data source selection agent: a switch to preselect agent options and, when enabled, the choice of the LLM provider the agent should use. *@
<ExpansionPanel HeaderIcon="@Icons.Material.Filled.SelectAll" HeaderText="Agent: Data Source Selection Options">
    <MudPaper Class="pa-3 mb-8 border-dashed border rounded-lg">
        <MudText Typo="Typo.body1" Class="mb-3">
            Use Case: this agent is used to select the appropriate data sources for the current prompt.
        </MudText>
        <ConfigurationOption OptionDescription="Preselect data source selection options?" LabelOn="Options are preselected" LabelOff="No options are preselected" State="@(() => this.SettingsManager.ConfigurationData.AgentDataSourceSelection.PreselectAgentOptions)" StateUpdate="@(updatedState => this.SettingsManager.ConfigurationData.AgentDataSourceSelection.PreselectAgentOptions = updatedState)" OptionHelp="When enabled, you can preselect some agent options. This might be useful when you prefer a specific LLM."/>
        @* The provider selection is only active while the preselection switch above is turned on. *@
        <ConfigurationProviderSelection Data="@this.AvailableLLMProvidersFunc()" Disabled="@(() => !this.SettingsManager.ConfigurationData.AgentDataSourceSelection.PreselectAgentOptions)" SelectedValue="@(() => this.SettingsManager.ConfigurationData.AgentDataSourceSelection.PreselectedAgentProvider)" SelectionUpdate="@(selectedValue => this.SettingsManager.ConfigurationData.AgentDataSourceSelection.PreselectedAgentProvider = selectedValue)"/>
    </MudPaper>
</ExpansionPanel>

View File

@ -0,0 +1,3 @@
namespace AIStudio.Components.Settings;
/// <summary>
/// Code-behind of the settings panel for the data source selection agent.
/// It adds no members of its own.
/// </summary>
public partial class SettingsPanelAgentDataSourceSelection : SettingsPanelBase
{
}

View File

@ -27,6 +27,7 @@
<SettingsPanelSynonyms AvailableLLMProvidersFunc="() => this.availableLLMProviders" />
<SettingsPanelMyTasks AvailableLLMProvidersFunc="() => this.availableLLMProviders" />
<SettingsPanelAssistantBias AvailableLLMProvidersFunc="() => this.availableLLMProviders" />
<SettingsPanelAgentDataSourceSelection AvailableLLMProvidersFunc="() => this.availableLLMProviders" />
<SettingsPanelAgentContentCleaner AvailableLLMProvidersFunc="() => this.availableLLMProviders" />
</MudExpansionPanels>
</InnerScrolling>

View File

@ -118,6 +118,7 @@ internal sealed class Program
builder.Services.AddSingleton<ThreadSafeRandom>();
builder.Services.AddSingleton<DataSourceService>();
builder.Services.AddTransient<HTMLParser>();
builder.Services.AddTransient<AgentDataSourceSelection>();
builder.Services.AddTransient<AgentTextContentCleaner>();
builder.Services.AddHostedService<UpdateService>();
builder.Services.AddHostedService<TemporaryChatService>();

View File

@ -74,6 +74,8 @@ public sealed class Data
public DataTextContentCleaner TextContentCleaner { get; init; } = new();
public DataAgentDataSourceSelection AgentDataSourceSelection { get; init; } = new();
public DataAgenda Agenda { get; init; } = new();
public DataGrammarSpelling GrammarSpelling { get; init; } = new();

View File

@ -0,0 +1,14 @@
namespace AIStudio.Settings.DataModel;
/// <summary>
/// Settings for the data source selection agent.
/// </summary>
public sealed class DataAgentDataSourceSelection
{
/// <summary>
/// Preselect any data source selection agent options?
/// </summary>
public bool PreselectAgentOptions { get; set; }
/// <summary>
/// The ID of the provider preselected for the data source selection agent.
/// Empty when no provider is preselected.
/// </summary>
public string PreselectedAgentProvider { get; set; } = string.Empty;
}

View File

@ -2,5 +2,6 @@
- Added the possibility to select data sources for chats. This preview feature is hidden behind the RAG feature flag, check your app options in case you want to enable it.
- Added an option to all data sources to select a local security policy. This preview feature is hidden behind the RAG feature flag.
- Added an option to preselect data sources and options for new chats. This preview feature is hidden behind the RAG feature flag.
- Added an agent to select the appropriate data sources for any prompt. This preview feature is hidden behind the RAG feature flag.
- Improved confidence card for small spaces.
- Fixed a bug in which 'APP_SETTINGS' appeared as a valid destination in the "send to" menu.