mirror of
https://github.com/MindWorkAI/AI-Studio.git
synced 2025-03-12 18:09:06 +00:00
Refactored the RAG process (#287)
This commit is contained in:
parent
77d427610b
commit
f01cf498e2
@ -19,7 +19,7 @@ Things we are currently working on:
|
||||
- [ ] App: Implement the process to vectorize one local file using embeddings
|
||||
- [ ] Runtime: Integration of the vector database [LanceDB](https://github.com/lancedb/lancedb)
|
||||
- [ ] App: Implement the continuous process of vectorizing data
|
||||
- [x] ~~App: Define a common retrieval context interface for the integration of RAG processes in chats (PR [#281](https://github.com/MindWorkAI/AI-Studio/pull/281), [#284](https://github.com/MindWorkAI/AI-Studio/pull/284), [#286](https://github.com/MindWorkAI/AI-Studio/pull/286))~~
|
||||
- [x] ~~App: Define a common retrieval context interface for the integration of RAG processes in chats (PR [#281](https://github.com/MindWorkAI/AI-Studio/pull/281), [#284](https://github.com/MindWorkAI/AI-Studio/pull/284), [#286](https://github.com/MindWorkAI/AI-Studio/pull/286), [#287](https://github.com/MindWorkAI/AI-Studio/pull/287))~~
|
||||
- [ ] App: Define a common augmentation interface for the integration of RAG processes in chats
|
||||
- [x] ~~App: Integrate data sources in chats (PR [#282](https://github.com/MindWorkAI/AI-Studio/pull/282))~~
|
||||
|
||||
|
@ -1,11 +1,8 @@
|
||||
using System.Text.Json.Serialization;
|
||||
|
||||
using AIStudio.Agents;
|
||||
using AIStudio.Components;
|
||||
using AIStudio.Provider;
|
||||
using AIStudio.Settings;
|
||||
using AIStudio.Tools.RAG;
|
||||
using AIStudio.Tools.Services;
|
||||
using AIStudio.Tools.RAG.RAGProcesses;
|
||||
|
||||
namespace AIStudio.Chat;
|
||||
|
||||
@ -44,195 +41,11 @@ public sealed class ContentText : IContent
|
||||
if(chatThread is null)
|
||||
return;
|
||||
|
||||
var logger = Program.SERVICE_PROVIDER.GetService<ILogger<ContentText>>()!;
|
||||
var settings = Program.SERVICE_PROVIDER.GetService<SettingsManager>()!;
|
||||
var dataSourceService = Program.SERVICE_PROVIDER.GetService<DataSourceService>()!;
|
||||
|
||||
//
|
||||
// 1. Check if the user wants to bind any data sources to the chat:
|
||||
//
|
||||
if (chatThread.DataSourceOptions.IsEnabled() && lastPrompt is not null)
|
||||
// Call the RAG process. Right now, we only have one RAG process:
|
||||
if (lastPrompt is not null)
|
||||
{
|
||||
logger.LogInformation("Data sources are enabled for this chat.");
|
||||
|
||||
// Across the different code-branches, we keep track of whether it
|
||||
// makes sense to proceed with the RAG process:
|
||||
var proceedWithRAG = true;
|
||||
|
||||
//
|
||||
// When the user wants to bind data sources to the chat, we
|
||||
// have to check if the data sources are available for the
|
||||
// selected provider. Also, we have to check if any ERI
|
||||
// data sources changed its security requirements.
|
||||
//
|
||||
List<IDataSource> preselectedDataSources = chatThread.DataSourceOptions.PreselectedDataSourceIds.Select(id => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == id)).Where(ds => ds is not null).ToList()!;
|
||||
var dataSources = await dataSourceService.GetDataSources(provider, preselectedDataSources);
|
||||
var selectedDataSources = dataSources.SelectedDataSources;
|
||||
|
||||
//
|
||||
// Should the AI select the data sources?
|
||||
//
|
||||
if (chatThread.DataSourceOptions.AutomaticDataSourceSelection)
|
||||
{
|
||||
// Get the agent for the data source selection:
|
||||
var selectionAgent = Program.SERVICE_PROVIDER.GetService<AgentDataSourceSelection>()!;
|
||||
|
||||
// Let the AI agent do its work:
|
||||
IReadOnlyList<DataSourceAgentSelected> finalAISelection = [];
|
||||
var aiSelectedDataSources = await selectionAgent.PerformSelectionAsync(provider, lastPrompt, chatThread, dataSources, token);
|
||||
|
||||
// Check if the AI selected any data sources:
|
||||
if(aiSelectedDataSources.Count is 0)
|
||||
{
|
||||
logger.LogWarning("The AI did not select any data sources. The RAG process is skipped.");
|
||||
proceedWithRAG = false;
|
||||
|
||||
// Send the selected data sources to the data source selection component.
|
||||
// Then, the user can see which data sources were selected by the AI.
|
||||
await MessageBus.INSTANCE.SendMessage(null, Event.RAG_AUTO_DATA_SOURCES_SELECTED, finalAISelection);
|
||||
chatThread.AISelectedDataSources = finalAISelection;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Log the selected data sources:
|
||||
var selectedDataSourceInfo = aiSelectedDataSources.Select(ds => $"[Id={ds.Id}, reason={ds.Reason}, confidence={ds.Confidence}]").Aggregate((a, b) => $"'{a}', '{b}'");
|
||||
logger.LogInformation($"The AI selected the data sources automatically. {aiSelectedDataSources.Count} data source(s) are selected: {selectedDataSourceInfo}.");
|
||||
|
||||
//
|
||||
// Check how many data sources were hallucinated by the AI:
|
||||
//
|
||||
var totalAISelectedDataSources = aiSelectedDataSources.Count;
|
||||
|
||||
// Filter out the data sources that are not available:
|
||||
aiSelectedDataSources = aiSelectedDataSources.Where(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id) is not null).ToList();
|
||||
|
||||
// Store the real AI-selected data sources:
|
||||
finalAISelection = aiSelectedDataSources.Select(x => new DataSourceAgentSelected { DataSource = settings.ConfigurationData.DataSources.First(ds => ds.Id == x.Id), AIDecision = x, Selected = false }).ToList();
|
||||
|
||||
var numHallucinatedSources = totalAISelectedDataSources - aiSelectedDataSources.Count;
|
||||
if(numHallucinatedSources > 0)
|
||||
logger.LogWarning($"The AI hallucinated {numHallucinatedSources} data source(s). We ignore them.");
|
||||
|
||||
if (aiSelectedDataSources.Count > 3)
|
||||
{
|
||||
//
|
||||
// We have more than 3 data sources. Let's filter by confidence.
|
||||
// In order to do that, we must identify the lower and upper
|
||||
// bounds of the confidence interval:
|
||||
//
|
||||
var confidenceValues = aiSelectedDataSources.Select(x => x.Confidence).ToList();
|
||||
var lowerBound = confidenceValues.Min();
|
||||
var upperBound = confidenceValues.Max();
|
||||
|
||||
//
|
||||
// Next, we search for a threshold so that we have between 2 and 3
|
||||
// data sources. When not possible, we take all data sources.
|
||||
//
|
||||
var threshold = 0.0f;
|
||||
|
||||
// Check the case where the confidence values are too close:
|
||||
if (upperBound - lowerBound >= 0.01)
|
||||
{
|
||||
var previousThreshold = 0.0f;
|
||||
for (var i = 0; i < 10; i++)
|
||||
{
|
||||
threshold = lowerBound + (upperBound - lowerBound) * i / 10;
|
||||
var numMatches = aiSelectedDataSources.Count(x => x.Confidence >= threshold);
|
||||
if (numMatches <= 1)
|
||||
{
|
||||
threshold = previousThreshold;
|
||||
break;
|
||||
}
|
||||
|
||||
if (numMatches is <= 3 and >= 2)
|
||||
break;
|
||||
|
||||
previousThreshold = threshold;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Filter the data sources by the threshold:
|
||||
//
|
||||
aiSelectedDataSources = aiSelectedDataSources.Where(x => x.Confidence >= threshold).ToList();
|
||||
foreach (var dataSource in finalAISelection)
|
||||
if(aiSelectedDataSources.Any(x => x.Id == dataSource.DataSource.Id))
|
||||
dataSource.Selected = true;
|
||||
|
||||
logger.LogInformation($"The AI selected {aiSelectedDataSources.Count} data source(s) with a confidence of at least {threshold}.");
|
||||
|
||||
// Transform the final data sources to the actual data sources:
|
||||
selectedDataSources = aiSelectedDataSources.Select(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id)).Where(ds => ds is not null).ToList()!;
|
||||
}
|
||||
|
||||
// We have max. 3 data sources. We take all of them:
|
||||
else
|
||||
{
|
||||
// Transform the selected data sources to the actual data sources:
|
||||
selectedDataSources = aiSelectedDataSources.Select(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id)).Where(ds => ds is not null).ToList()!;
|
||||
|
||||
// Mark the data sources as selected:
|
||||
foreach (var dataSource in finalAISelection)
|
||||
dataSource.Selected = true;
|
||||
}
|
||||
|
||||
// Send the selected data sources to the data source selection component.
|
||||
// Then, the user can see which data sources were selected by the AI.
|
||||
await MessageBus.INSTANCE.SendMessage(null, Event.RAG_AUTO_DATA_SOURCES_SELECTED, finalAISelection);
|
||||
chatThread.AISelectedDataSources = finalAISelection;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
//
|
||||
// No, the user made the choice manually:
|
||||
//
|
||||
var selectedDataSourceInfo = selectedDataSources.Select(ds => ds.Name).Aggregate((a, b) => $"'{a}', '{b}'");
|
||||
logger.LogInformation($"The user selected the data sources manually. {selectedDataSources.Count} data source(s) are selected: {selectedDataSourceInfo}.");
|
||||
}
|
||||
|
||||
if(selectedDataSources.Count == 0)
|
||||
{
|
||||
logger.LogWarning("No data sources are selected. The RAG process is skipped.");
|
||||
proceedWithRAG = false;
|
||||
}
|
||||
|
||||
//
|
||||
// Trigger the retrieval part of the (R)AG process:
|
||||
//
|
||||
var dataContexts = new List<IRetrievalContext>();
|
||||
if (proceedWithRAG)
|
||||
{
|
||||
//
|
||||
// We kick off the retrieval process for each data source in parallel:
|
||||
//
|
||||
var retrievalTasks = new List<Task<IReadOnlyList<IRetrievalContext>>>(selectedDataSources.Count);
|
||||
foreach (var dataSource in selectedDataSources)
|
||||
retrievalTasks.Add(dataSource.RetrieveDataAsync(lastPrompt, chatThread, token));
|
||||
|
||||
//
|
||||
// Wait for all retrieval tasks to finish:
|
||||
//
|
||||
foreach (var retrievalTask in retrievalTasks)
|
||||
{
|
||||
try
|
||||
{
|
||||
dataContexts.AddRange(await retrievalTask);
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
logger.LogError(e, "An error occurred during the retrieval process.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Perform the augmentation of the R(A)G process:
|
||||
//
|
||||
if (proceedWithRAG)
|
||||
{
|
||||
|
||||
}
|
||||
var rag = new AISrcSelWithRetCtxVal();
|
||||
chatThread = await rag.ProcessAsync(provider, lastPrompt, chatThread, token);
|
||||
}
|
||||
|
||||
// Store the last time we got a response. We use this later
|
||||
@ -241,6 +54,9 @@ public sealed class ContentText : IContent
|
||||
// the user chose.
|
||||
var last = DateTimeOffset.Now;
|
||||
|
||||
// Get the settings manager:
|
||||
var settings = Program.SERVICE_PROVIDER.GetService<SettingsManager>()!;
|
||||
|
||||
// Start another thread by using a task to uncouple
|
||||
// the UI thread from the AI processing:
|
||||
await Task.Run(async () =>
|
||||
|
10
app/MindWork AI Studio/Tools/RAG/DataSelectionResult.cs
Normal file
10
app/MindWork AI Studio/Tools/RAG/DataSelectionResult.cs
Normal file
@ -0,0 +1,10 @@
|
||||
using AIStudio.Settings;
|
||||
|
||||
namespace AIStudio.Tools.RAG;
|
||||
|
||||
/// <summary>
|
||||
/// Result of any data selection process.
|
||||
/// </summary>
|
||||
/// <param name="ProceedWithRAG">Makes it sense to proceed with the RAG process?</param>
|
||||
/// <param name="SelectedDataSources">The selected data sources.</param>
|
||||
public readonly record struct DataSelectionResult(bool ProceedWithRAG, IReadOnlyList<IDataSource> SelectedDataSources);
|
@ -0,0 +1,147 @@
|
||||
using AIStudio.Agents;
|
||||
using AIStudio.Chat;
|
||||
using AIStudio.Components;
|
||||
using AIStudio.Provider;
|
||||
using AIStudio.Settings;
|
||||
|
||||
namespace AIStudio.Tools.RAG.DataSourceSelectionProcesses;
|
||||
|
||||
public class AgenticSrcSelWithDynHeur : IDataSourceSelectionProcess
|
||||
{
|
||||
#region Implementation of IDataSourceSelectionProcess
|
||||
|
||||
/// <inheritdoc />
|
||||
public string TechnicalName => "AgenticSrcSelWithDynHeur";
|
||||
|
||||
/// <inheritdoc />
|
||||
public string UIName => "Automatic AI data source selection with heuristik source reduction";
|
||||
|
||||
/// <inheritdoc />
|
||||
public string Description => "Automatically selects the appropriate data sources based on the last prompt. Applies a heuristic reduction at the end to reduce the number of data sources.";
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task<DataSelectionResult> SelectDataSourcesAsync(IProvider provider, IContent lastPrompt, ChatThread chatThread, AllowedSelectedDataSources dataSources, CancellationToken token = default)
|
||||
{
|
||||
var proceedWithRAG = true;
|
||||
IReadOnlyList<IDataSource> selectedDataSources = [];
|
||||
IReadOnlyList<DataSourceAgentSelected> finalAISelection = [];
|
||||
|
||||
// Get the logger:
|
||||
var logger = Program.SERVICE_PROVIDER.GetService<ILogger<AgenticSrcSelWithDynHeur>>()!;
|
||||
|
||||
// Get the settings manager:
|
||||
var settings = Program.SERVICE_PROVIDER.GetService<SettingsManager>()!;
|
||||
|
||||
// Get the agent for the data source selection:
|
||||
var selectionAgent = Program.SERVICE_PROVIDER.GetService<AgentDataSourceSelection>()!;
|
||||
|
||||
try
|
||||
{
|
||||
// Let the AI agent do its work:
|
||||
var aiSelectedDataSources = await selectionAgent.PerformSelectionAsync(provider, lastPrompt, chatThread, dataSources, token);
|
||||
|
||||
// Check if the AI selected any data sources:
|
||||
if (aiSelectedDataSources.Count is 0)
|
||||
{
|
||||
logger.LogWarning("The AI did not select any data sources. The RAG process is skipped.");
|
||||
proceedWithRAG = false;
|
||||
|
||||
return new(proceedWithRAG, selectedDataSources);
|
||||
}
|
||||
|
||||
// Log the selected data sources:
|
||||
var selectedDataSourceInfo = aiSelectedDataSources.Select(ds => $"[Id={ds.Id}, reason={ds.Reason}, confidence={ds.Confidence}]").Aggregate((a, b) => $"'{a}', '{b}'");
|
||||
logger.LogInformation($"The AI selected the data sources automatically. {aiSelectedDataSources.Count} data source(s) are selected: {selectedDataSourceInfo}.");
|
||||
|
||||
//
|
||||
// Check how many data sources were hallucinated by the AI:
|
||||
//
|
||||
var totalAISelectedDataSources = aiSelectedDataSources.Count;
|
||||
|
||||
// Filter out the data sources that are not available:
|
||||
aiSelectedDataSources = aiSelectedDataSources.Where(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id) is not null).ToList();
|
||||
|
||||
// Store the real AI-selected data sources:
|
||||
finalAISelection = aiSelectedDataSources.Select(x => new DataSourceAgentSelected { DataSource = settings.ConfigurationData.DataSources.First(ds => ds.Id == x.Id), AIDecision = x, Selected = false }).ToList();
|
||||
|
||||
var numHallucinatedSources = totalAISelectedDataSources - aiSelectedDataSources.Count;
|
||||
if (numHallucinatedSources > 0)
|
||||
logger.LogWarning($"The AI hallucinated {numHallucinatedSources} data source(s). We ignore them.");
|
||||
|
||||
if (aiSelectedDataSources.Count > 3)
|
||||
{
|
||||
//
|
||||
// We have more than 3 data sources. Let's filter by confidence.
|
||||
// In order to do that, we must identify the lower and upper
|
||||
// bounds of the confidence interval:
|
||||
//
|
||||
var confidenceValues = aiSelectedDataSources.Select(x => x.Confidence).ToList();
|
||||
var lowerBound = confidenceValues.Min();
|
||||
var upperBound = confidenceValues.Max();
|
||||
|
||||
//
|
||||
// Next, we search for a threshold so that we have between 2 and 3
|
||||
// data sources. When not possible, we take all data sources.
|
||||
//
|
||||
var threshold = 0.0f;
|
||||
|
||||
// Check the case where the confidence values are too close:
|
||||
if (upperBound - lowerBound >= 0.01)
|
||||
{
|
||||
var previousThreshold = 0.0f;
|
||||
for (var i = 0; i < 10; i++)
|
||||
{
|
||||
threshold = lowerBound + (upperBound - lowerBound) * i / 10;
|
||||
var numMatches = aiSelectedDataSources.Count(x => x.Confidence >= threshold);
|
||||
if (numMatches <= 1)
|
||||
{
|
||||
threshold = previousThreshold;
|
||||
break;
|
||||
}
|
||||
|
||||
if (numMatches is <= 3 and >= 2)
|
||||
break;
|
||||
|
||||
previousThreshold = threshold;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Filter the data sources by the threshold:
|
||||
//
|
||||
aiSelectedDataSources = aiSelectedDataSources.Where(x => x.Confidence >= threshold).ToList();
|
||||
foreach (var dataSource in finalAISelection)
|
||||
if (aiSelectedDataSources.Any(x => x.Id == dataSource.DataSource.Id))
|
||||
dataSource.Selected = true;
|
||||
|
||||
logger.LogInformation($"The AI selected {aiSelectedDataSources.Count} data source(s) with a confidence of at least {threshold}.");
|
||||
|
||||
// Transform the final data sources to the actual data sources:
|
||||
selectedDataSources = aiSelectedDataSources.Select(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id)).Where(ds => ds is not null).ToList()!;
|
||||
return new(proceedWithRAG, selectedDataSources);
|
||||
}
|
||||
|
||||
//
|
||||
// Case: we have max. 3 data sources. We take all of them:
|
||||
//
|
||||
|
||||
// Transform the selected data sources to the actual data sources:
|
||||
selectedDataSources = aiSelectedDataSources.Select(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id)).Where(ds => ds is not null).ToList()!;
|
||||
|
||||
// Mark the data sources as selected:
|
||||
foreach (var dataSource in finalAISelection)
|
||||
dataSource.Selected = true;
|
||||
|
||||
return new(proceedWithRAG, selectedDataSources);
|
||||
}
|
||||
finally
|
||||
{
|
||||
// Send the selected data sources to the data source selection component.
|
||||
// Then, the user can see which data sources were selected by the AI.
|
||||
await MessageBus.INSTANCE.SendMessage(null, Event.RAG_AUTO_DATA_SOURCES_SELECTED, finalAISelection);
|
||||
chatThread.AISelectedDataSources = finalAISelection;
|
||||
}
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
@ -0,0 +1,33 @@
|
||||
using AIStudio.Chat;
|
||||
using AIStudio.Provider;
|
||||
|
||||
namespace AIStudio.Tools.RAG;
|
||||
|
||||
public interface IDataSourceSelectionProcess
|
||||
{
|
||||
/// <summary>
|
||||
/// How is the RAG process called?
|
||||
/// </summary>
|
||||
public string TechnicalName { get; }
|
||||
|
||||
/// <summary>
|
||||
/// How is the RAG process called in the UI?
|
||||
/// </summary>
|
||||
public string UIName { get; }
|
||||
|
||||
/// <summary>
|
||||
/// How works the RAG process?
|
||||
/// </summary>
|
||||
public string Description { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Starts the data source selection process.
|
||||
/// </summary>
|
||||
/// <param name="provider">The LLM provider. Used as default for data selection agents.</param>
|
||||
/// <param name="lastPrompt">The last prompt that was issued by the user.</param>
|
||||
/// <param name="chatThread">The chat thread.</param>
|
||||
/// <param name="dataSources">The allowed data sources yielded by the data source service.</param>
|
||||
/// <param name="token">The cancellation token.</param>
|
||||
/// <returns></returns>
|
||||
public Task<DataSelectionResult> SelectDataSourcesAsync(IProvider provider, IContent lastPrompt, ChatThread chatThread, AllowedSelectedDataSources dataSources, CancellationToken token = default);
|
||||
}
|
32
app/MindWork AI Studio/Tools/RAG/IRagProcess.cs
Normal file
32
app/MindWork AI Studio/Tools/RAG/IRagProcess.cs
Normal file
@ -0,0 +1,32 @@
|
||||
using AIStudio.Chat;
|
||||
using AIStudio.Provider;
|
||||
|
||||
namespace AIStudio.Tools.RAG;
|
||||
|
||||
public interface IRagProcess
|
||||
{
|
||||
/// <summary>
|
||||
/// How is the RAG process called?
|
||||
/// </summary>
|
||||
public string TechnicalName { get; }
|
||||
|
||||
/// <summary>
|
||||
/// How is the RAG process called in the UI?
|
||||
/// </summary>
|
||||
public string UIName { get; }
|
||||
|
||||
/// <summary>
|
||||
/// How works the RAG process?
|
||||
/// </summary>
|
||||
public string Description { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Starts the RAG process.
|
||||
/// </summary>
|
||||
/// <param name="provider">The LLM provider. Used to check whether the data sources are allowed to be used by this LLM.</param>
|
||||
/// <param name="lastPrompt">The last prompt that was issued by the user.</param>
|
||||
/// <param name="chatThread">The chat thread.</param>
|
||||
/// <param name="token">The cancellation token.</param>
|
||||
/// <returns>The altered chat thread.</returns>
|
||||
public Task<ChatThread> ProcessAsync(IProvider provider, IContent lastPrompt, ChatThread chatThread, CancellationToken token = default);
|
||||
}
|
@ -0,0 +1,117 @@
|
||||
using AIStudio.Chat;
|
||||
using AIStudio.Provider;
|
||||
using AIStudio.Settings;
|
||||
using AIStudio.Tools.RAG.DataSourceSelectionProcesses;
|
||||
using AIStudio.Tools.Services;
|
||||
|
||||
namespace AIStudio.Tools.RAG.RAGProcesses;
|
||||
|
||||
public sealed class AISrcSelWithRetCtxVal : IRagProcess
|
||||
{
|
||||
#region Implementation of IRagProcess
|
||||
|
||||
/// <inheritdoc />
|
||||
public string TechnicalName => "AISrcSelWithRetCtxVal";
|
||||
|
||||
/// <inheritdoc />
|
||||
public string UIName => "AI source selection with AI retrieval context validation";
|
||||
|
||||
/// <inheritdoc />
|
||||
public string Description => "This RAG process filters data sources, automatically selects appropriate sources, optionally allows manual source selection, retrieves data, and automatically validates the retrieval context.";
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task<ChatThread> ProcessAsync(IProvider provider, IContent lastPrompt, ChatThread chatThread, CancellationToken token = default)
|
||||
{
|
||||
var logger = Program.SERVICE_PROVIDER.GetService<ILogger<AISrcSelWithRetCtxVal>>()!;
|
||||
var settings = Program.SERVICE_PROVIDER.GetService<SettingsManager>()!;
|
||||
var dataSourceService = Program.SERVICE_PROVIDER.GetService<DataSourceService>()!;
|
||||
|
||||
//
|
||||
// 1. Check if the user wants to bind any data sources to the chat:
|
||||
//
|
||||
if (chatThread.DataSourceOptions.IsEnabled())
|
||||
{
|
||||
logger.LogInformation("Data sources are enabled for this chat.");
|
||||
|
||||
// Across the different code-branches, we keep track of whether it
|
||||
// makes sense to proceed with the RAG process:
|
||||
var proceedWithRAG = true;
|
||||
|
||||
//
|
||||
// When the user wants to bind data sources to the chat, we
|
||||
// have to check if the data sources are available for the
|
||||
// selected provider. Also, we have to check if any ERI
|
||||
// data sources changed its security requirements.
|
||||
//
|
||||
List<IDataSource> preselectedDataSources = chatThread.DataSourceOptions.PreselectedDataSourceIds.Select(id => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == id)).Where(ds => ds is not null).ToList()!;
|
||||
var dataSources = await dataSourceService.GetDataSources(provider, preselectedDataSources);
|
||||
var selectedDataSources = dataSources.SelectedDataSources;
|
||||
|
||||
//
|
||||
// Should the AI select the data sources?
|
||||
//
|
||||
if (chatThread.DataSourceOptions.AutomaticDataSourceSelection)
|
||||
{
|
||||
var dataSourceSelectionProcess = new AgenticSrcSelWithDynHeur();
|
||||
var result = await dataSourceSelectionProcess.SelectDataSourcesAsync(provider, lastPrompt, chatThread, dataSources, token);
|
||||
proceedWithRAG = result.ProceedWithRAG;
|
||||
selectedDataSources = result.SelectedDataSources;
|
||||
}
|
||||
else
|
||||
{
|
||||
//
|
||||
// No, the user made the choice manually:
|
||||
//
|
||||
var selectedDataSourceInfo = selectedDataSources.Select(ds => ds.Name).Aggregate((a, b) => $"'{a}', '{b}'");
|
||||
logger.LogInformation($"The user selected the data sources manually. {selectedDataSources.Count} data source(s) are selected: {selectedDataSourceInfo}.");
|
||||
}
|
||||
|
||||
if(selectedDataSources.Count == 0)
|
||||
{
|
||||
logger.LogWarning("No data sources are selected. The RAG process is skipped.");
|
||||
proceedWithRAG = false;
|
||||
}
|
||||
|
||||
//
|
||||
// Trigger the retrieval part of the (R)AG process:
|
||||
//
|
||||
var dataContexts = new List<IRetrievalContext>();
|
||||
if (proceedWithRAG)
|
||||
{
|
||||
//
|
||||
// We kick off the retrieval process for each data source in parallel:
|
||||
//
|
||||
var retrievalTasks = new List<Task<IReadOnlyList<IRetrievalContext>>>(selectedDataSources.Count);
|
||||
foreach (var dataSource in selectedDataSources)
|
||||
retrievalTasks.Add(dataSource.RetrieveDataAsync(lastPrompt, chatThread, token));
|
||||
|
||||
//
|
||||
// Wait for all retrieval tasks to finish:
|
||||
//
|
||||
foreach (var retrievalTask in retrievalTasks)
|
||||
{
|
||||
try
|
||||
{
|
||||
dataContexts.AddRange(await retrievalTask);
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
logger.LogError(e, "An error occurred during the retrieval process.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Perform the augmentation of the R(A)G process:
|
||||
//
|
||||
if (proceedWithRAG)
|
||||
{
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return chatThread;
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
Loading…
Reference in New Issue
Block a user