diff --git a/README.md b/README.md index aa101ab..257852d 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Things we are currently working on: - [ ] App: Implement the process to vectorize one local file using embeddings - [ ] Runtime: Integration of the vector database [LanceDB](https://github.com/lancedb/lancedb) - [ ] App: Implement the continuous process of vectorizing data - - [x] ~~App: Define a common retrieval context interface for the integration of RAG processes in chats (PR [#281](https://github.com/MindWorkAI/AI-Studio/pull/281), [#284](https://github.com/MindWorkAI/AI-Studio/pull/284), [#286](https://github.com/MindWorkAI/AI-Studio/pull/286))~~ + - [x] ~~App: Define a common retrieval context interface for the integration of RAG processes in chats (PR [#281](https://github.com/MindWorkAI/AI-Studio/pull/281), [#284](https://github.com/MindWorkAI/AI-Studio/pull/284), [#286](https://github.com/MindWorkAI/AI-Studio/pull/286), [#287](https://github.com/MindWorkAI/AI-Studio/pull/287))~~ - [ ] App: Define a common augmentation interface for the integration of RAG processes in chats - [x] ~~App: Integrate data sources in chats (PR [#282](https://github.com/MindWorkAI/AI-Studio/pull/282))~~ diff --git a/app/MindWork AI Studio/Chat/ContentText.cs b/app/MindWork AI Studio/Chat/ContentText.cs index 75912e3..9f51175 100644 --- a/app/MindWork AI Studio/Chat/ContentText.cs +++ b/app/MindWork AI Studio/Chat/ContentText.cs @@ -1,11 +1,8 @@ using System.Text.Json.Serialization; -using AIStudio.Agents; -using AIStudio.Components; using AIStudio.Provider; using AIStudio.Settings; -using AIStudio.Tools.RAG; -using AIStudio.Tools.Services; +using AIStudio.Tools.RAG.RAGProcesses; namespace AIStudio.Chat; @@ -43,204 +40,23 @@ public sealed class ContentText : IContent { if(chatThread is null) return; - - var logger = Program.SERVICE_PROVIDER.GetService>()!; - var settings = Program.SERVICE_PROVIDER.GetService()!; - var dataSourceService = 
Program.SERVICE_PROVIDER.GetService()!; - // - // 1. Check if the user wants to bind any data sources to the chat: - // - if (chatThread.DataSourceOptions.IsEnabled() && lastPrompt is not null) + // Call the RAG process. Right now, we only have one RAG process: + if (lastPrompt is not null) { - logger.LogInformation("Data sources are enabled for this chat."); - - // Across the different code-branches, we keep track of whether it - // makes sense to proceed with the RAG process: - var proceedWithRAG = true; - - // - // When the user wants to bind data sources to the chat, we - // have to check if the data sources are available for the - // selected provider. Also, we have to check if any ERI - // data sources changed its security requirements. - // - List preselectedDataSources = chatThread.DataSourceOptions.PreselectedDataSourceIds.Select(id => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == id)).Where(ds => ds is not null).ToList()!; - var dataSources = await dataSourceService.GetDataSources(provider, preselectedDataSources); - var selectedDataSources = dataSources.SelectedDataSources; - - // - // Should the AI select the data sources? - // - if (chatThread.DataSourceOptions.AutomaticDataSourceSelection) - { - // Get the agent for the data source selection: - var selectionAgent = Program.SERVICE_PROVIDER.GetService()!; - - // Let the AI agent do its work: - IReadOnlyList finalAISelection = []; - var aiSelectedDataSources = await selectionAgent.PerformSelectionAsync(provider, lastPrompt, chatThread, dataSources, token); - - // Check if the AI selected any data sources: - if(aiSelectedDataSources.Count is 0) - { - logger.LogWarning("The AI did not select any data sources. The RAG process is skipped."); - proceedWithRAG = false; - - // Send the selected data sources to the data source selection component. - // Then, the user can see which data sources were selected by the AI. 
- await MessageBus.INSTANCE.SendMessage(null, Event.RAG_AUTO_DATA_SOURCES_SELECTED, finalAISelection); - chatThread.AISelectedDataSources = finalAISelection; - } - else - { - // Log the selected data sources: - var selectedDataSourceInfo = aiSelectedDataSources.Select(ds => $"[Id={ds.Id}, reason={ds.Reason}, confidence={ds.Confidence}]").Aggregate((a, b) => $"'{a}', '{b}'"); - logger.LogInformation($"The AI selected the data sources automatically. {aiSelectedDataSources.Count} data source(s) are selected: {selectedDataSourceInfo}."); - - // - // Check how many data sources were hallucinated by the AI: - // - var totalAISelectedDataSources = aiSelectedDataSources.Count; - - // Filter out the data sources that are not available: - aiSelectedDataSources = aiSelectedDataSources.Where(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id) is not null).ToList(); - - // Store the real AI-selected data sources: - finalAISelection = aiSelectedDataSources.Select(x => new DataSourceAgentSelected { DataSource = settings.ConfigurationData.DataSources.First(ds => ds.Id == x.Id), AIDecision = x, Selected = false }).ToList(); - - var numHallucinatedSources = totalAISelectedDataSources - aiSelectedDataSources.Count; - if(numHallucinatedSources > 0) - logger.LogWarning($"The AI hallucinated {numHallucinatedSources} data source(s). We ignore them."); - - if (aiSelectedDataSources.Count > 3) - { - // - // We have more than 3 data sources. Let's filter by confidence. - // In order to do that, we must identify the lower and upper - // bounds of the confidence interval: - // - var confidenceValues = aiSelectedDataSources.Select(x => x.Confidence).ToList(); - var lowerBound = confidenceValues.Min(); - var upperBound = confidenceValues.Max(); - - // - // Next, we search for a threshold so that we have between 2 and 3 - // data sources. When not possible, we take all data sources. 
- // - var threshold = 0.0f; - - // Check the case where the confidence values are too close: - if (upperBound - lowerBound >= 0.01) - { - var previousThreshold = 0.0f; - for (var i = 0; i < 10; i++) - { - threshold = lowerBound + (upperBound - lowerBound) * i / 10; - var numMatches = aiSelectedDataSources.Count(x => x.Confidence >= threshold); - if (numMatches <= 1) - { - threshold = previousThreshold; - break; - } - - if (numMatches is <= 3 and >= 2) - break; - - previousThreshold = threshold; - } - } - - // - // Filter the data sources by the threshold: - // - aiSelectedDataSources = aiSelectedDataSources.Where(x => x.Confidence >= threshold).ToList(); - foreach (var dataSource in finalAISelection) - if(aiSelectedDataSources.Any(x => x.Id == dataSource.DataSource.Id)) - dataSource.Selected = true; - - logger.LogInformation($"The AI selected {aiSelectedDataSources.Count} data source(s) with a confidence of at least {threshold}."); - - // Transform the final data sources to the actual data sources: - selectedDataSources = aiSelectedDataSources.Select(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id)).Where(ds => ds is not null).ToList()!; - } - - // We have max. 3 data sources. We take all of them: - else - { - // Transform the selected data sources to the actual data sources: - selectedDataSources = aiSelectedDataSources.Select(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id)).Where(ds => ds is not null).ToList()!; - - // Mark the data sources as selected: - foreach (var dataSource in finalAISelection) - dataSource.Selected = true; - } - - // Send the selected data sources to the data source selection component. - // Then, the user can see which data sources were selected by the AI. 
- await MessageBus.INSTANCE.SendMessage(null, Event.RAG_AUTO_DATA_SOURCES_SELECTED, finalAISelection); - chatThread.AISelectedDataSources = finalAISelection; - } - } - else - { - // - // No, the user made the choice manually: - // - var selectedDataSourceInfo = selectedDataSources.Select(ds => ds.Name).Aggregate((a, b) => $"'{a}', '{b}'"); - logger.LogInformation($"The user selected the data sources manually. {selectedDataSources.Count} data source(s) are selected: {selectedDataSourceInfo}."); - } - - if(selectedDataSources.Count == 0) - { - logger.LogWarning("No data sources are selected. The RAG process is skipped."); - proceedWithRAG = false; - } - - // - // Trigger the retrieval part of the (R)AG process: - // - var dataContexts = new List(); - if (proceedWithRAG) - { - // - // We kick off the retrieval process for each data source in parallel: - // - var retrievalTasks = new List>>(selectedDataSources.Count); - foreach (var dataSource in selectedDataSources) - retrievalTasks.Add(dataSource.RetrieveDataAsync(lastPrompt, chatThread, token)); - - // - // Wait for all retrieval tasks to finish: - // - foreach (var retrievalTask in retrievalTasks) - { - try - { - dataContexts.AddRange(await retrievalTask); - } - catch (Exception e) - { - logger.LogError(e, "An error occurred during the retrieval process."); - } - } - } - - // - // Perform the augmentation of the R(A)G process: - // - if (proceedWithRAG) - { - - } + var rag = new AISrcSelWithRetCtxVal(); + chatThread = await rag.ProcessAsync(provider, lastPrompt, chatThread, token); } - + // Store the last time we got a response. We use this later // to determine whether we should notify the UI about the // new content or not. Depends on the energy saving mode // the user chose. 
var last = DateTimeOffset.Now; + // Get the settings manager: + var settings = Program.SERVICE_PROVIDER.GetService()!; + // Start another thread by using a task to uncouple // the UI thread from the AI processing: await Task.Run(async () => diff --git a/app/MindWork AI Studio/Tools/RAG/DataSelectionResult.cs b/app/MindWork AI Studio/Tools/RAG/DataSelectionResult.cs new file mode 100644 index 0000000..6508b76 --- /dev/null +++ b/app/MindWork AI Studio/Tools/RAG/DataSelectionResult.cs @@ -0,0 +1,10 @@ +using AIStudio.Settings; + +namespace AIStudio.Tools.RAG; + +/// +/// Result of any data selection process. +/// +/// Does it make sense to proceed with the RAG process? +/// The selected data sources. +public readonly record struct DataSelectionResult(bool ProceedWithRAG, IReadOnlyList SelectedDataSources); \ No newline at end of file diff --git a/app/MindWork AI Studio/Tools/RAG/DataSourceSelectionProcesses/AgenticSrcSelWithDynHeur.cs b/app/MindWork AI Studio/Tools/RAG/DataSourceSelectionProcesses/AgenticSrcSelWithDynHeur.cs new file mode 100644 index 0000000..6409978 --- /dev/null +++ b/app/MindWork AI Studio/Tools/RAG/DataSourceSelectionProcesses/AgenticSrcSelWithDynHeur.cs @@ -0,0 +1,147 @@ +using AIStudio.Agents; +using AIStudio.Chat; +using AIStudio.Components; +using AIStudio.Provider; +using AIStudio.Settings; + +namespace AIStudio.Tools.RAG.DataSourceSelectionProcesses; + +public class AgenticSrcSelWithDynHeur : IDataSourceSelectionProcess +{ + #region Implementation of IDataSourceSelectionProcess + + /// + public string TechnicalName => "AgenticSrcSelWithDynHeur"; + + /// + public string UIName => "Automatic AI data source selection with heuristik source reduction"; + + /// + public string Description => "Automatically selects the appropriate data sources based on the last prompt. 
Applies a heuristic reduction at the end to reduce the number of data sources."; + + /// + public async Task SelectDataSourcesAsync(IProvider provider, IContent lastPrompt, ChatThread chatThread, AllowedSelectedDataSources dataSources, CancellationToken token = default) + { + var proceedWithRAG = true; + IReadOnlyList selectedDataSources = []; + IReadOnlyList finalAISelection = []; + + // Get the logger: + var logger = Program.SERVICE_PROVIDER.GetService>()!; + + // Get the settings manager: + var settings = Program.SERVICE_PROVIDER.GetService()!; + + // Get the agent for the data source selection: + var selectionAgent = Program.SERVICE_PROVIDER.GetService()!; + + try + { + // Let the AI agent do its work: + var aiSelectedDataSources = await selectionAgent.PerformSelectionAsync(provider, lastPrompt, chatThread, dataSources, token); + + // Check if the AI selected any data sources: + if (aiSelectedDataSources.Count is 0) + { + logger.LogWarning("The AI did not select any data sources. The RAG process is skipped."); + proceedWithRAG = false; + + return new(proceedWithRAG, selectedDataSources); + } + + // Log the selected data sources: + var selectedDataSourceInfo = aiSelectedDataSources.Select(ds => $"[Id={ds.Id}, reason={ds.Reason}, confidence={ds.Confidence}]").Aggregate((a, b) => $"'{a}', '{b}'"); + logger.LogInformation($"The AI selected the data sources automatically. 
{aiSelectedDataSources.Count} data source(s) are selected: {selectedDataSourceInfo}."); + + // + // Check how many data sources were hallucinated by the AI: + // + var totalAISelectedDataSources = aiSelectedDataSources.Count; + + // Filter out the data sources that are not available: + aiSelectedDataSources = aiSelectedDataSources.Where(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id) is not null).ToList(); + + // Store the real AI-selected data sources: + finalAISelection = aiSelectedDataSources.Select(x => new DataSourceAgentSelected { DataSource = settings.ConfigurationData.DataSources.First(ds => ds.Id == x.Id), AIDecision = x, Selected = false }).ToList(); + + var numHallucinatedSources = totalAISelectedDataSources - aiSelectedDataSources.Count; + if (numHallucinatedSources > 0) + logger.LogWarning($"The AI hallucinated {numHallucinatedSources} data source(s). We ignore them."); + + if (aiSelectedDataSources.Count > 3) + { + // + // We have more than 3 data sources. Let's filter by confidence. + // In order to do that, we must identify the lower and upper + // bounds of the confidence interval: + // + var confidenceValues = aiSelectedDataSources.Select(x => x.Confidence).ToList(); + var lowerBound = confidenceValues.Min(); + var upperBound = confidenceValues.Max(); + + // + // Next, we search for a threshold so that we have between 2 and 3 + // data sources. When not possible, we take all data sources. 
+ // + var threshold = 0.0f; + + // Check the case where the confidence values are too close: + if (upperBound - lowerBound >= 0.01) + { + var previousThreshold = 0.0f; + for (var i = 0; i < 10; i++) + { + threshold = lowerBound + (upperBound - lowerBound) * i / 10; + var numMatches = aiSelectedDataSources.Count(x => x.Confidence >= threshold); + if (numMatches <= 1) + { + threshold = previousThreshold; + break; + } + + if (numMatches is <= 3 and >= 2) + break; + + previousThreshold = threshold; + } + } + + // + // Filter the data sources by the threshold: + // + aiSelectedDataSources = aiSelectedDataSources.Where(x => x.Confidence >= threshold).ToList(); + foreach (var dataSource in finalAISelection) + if (aiSelectedDataSources.Any(x => x.Id == dataSource.DataSource.Id)) + dataSource.Selected = true; + + logger.LogInformation($"The AI selected {aiSelectedDataSources.Count} data source(s) with a confidence of at least {threshold}."); + + // Transform the final data sources to the actual data sources: + selectedDataSources = aiSelectedDataSources.Select(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id)).Where(ds => ds is not null).ToList()!; + return new(proceedWithRAG, selectedDataSources); + } + + // + // Case: we have max. 3 data sources. We take all of them: + // + + // Transform the selected data sources to the actual data sources: + selectedDataSources = aiSelectedDataSources.Select(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id)).Where(ds => ds is not null).ToList()!; + + // Mark the data sources as selected: + foreach (var dataSource in finalAISelection) + dataSource.Selected = true; + + return new(proceedWithRAG, selectedDataSources); + } + finally + { + // Send the selected data sources to the data source selection component. + // Then, the user can see which data sources were selected by the AI. 
+ await MessageBus.INSTANCE.SendMessage(null, Event.RAG_AUTO_DATA_SOURCES_SELECTED, finalAISelection); + chatThread.AISelectedDataSources = finalAISelection; + } + } + + #endregion +} \ No newline at end of file diff --git a/app/MindWork AI Studio/Tools/RAG/IDataSourceSelectionProcess.cs b/app/MindWork AI Studio/Tools/RAG/IDataSourceSelectionProcess.cs new file mode 100644 index 0000000..8213ec2 --- /dev/null +++ b/app/MindWork AI Studio/Tools/RAG/IDataSourceSelectionProcess.cs @@ -0,0 +1,33 @@ +using AIStudio.Chat; +using AIStudio.Provider; + +namespace AIStudio.Tools.RAG; + +public interface IDataSourceSelectionProcess +{ + /// + /// What is the data source selection process called? + /// + public string TechnicalName { get; } + + /// + /// What is the data source selection process called in the UI? + /// + public string UIName { get; } + + /// + /// How does the data source selection process work? + /// + public string Description { get; } + + /// + /// Starts the data source selection process. + /// + /// The LLM provider. Used as default for data selection agents. + /// The last prompt that was issued by the user. + /// The chat thread. + /// The allowed data sources yielded by the data source service. + /// The cancellation token. + /// + public Task SelectDataSourcesAsync(IProvider provider, IContent lastPrompt, ChatThread chatThread, AllowedSelectedDataSources dataSources, CancellationToken token = default); +} \ No newline at end of file diff --git a/app/MindWork AI Studio/Tools/RAG/IRagProcess.cs b/app/MindWork AI Studio/Tools/RAG/IRagProcess.cs new file mode 100644 index 0000000..d38e323 --- /dev/null +++ b/app/MindWork AI Studio/Tools/RAG/IRagProcess.cs @@ -0,0 +1,32 @@ +using AIStudio.Chat; +using AIStudio.Provider; + +namespace AIStudio.Tools.RAG; + +public interface IRagProcess +{ + /// + /// What is the RAG process called? + /// + public string TechnicalName { get; } + + /// + /// What is the RAG process called in the UI? + /// + public string UIName { get; } + + /// + /// How does the RAG process work? 
+ /// + public string Description { get; } + + /// + /// Starts the RAG process. + /// + /// The LLM provider. Used to check whether the data sources are allowed to be used by this LLM. + /// The last prompt that was issued by the user. + /// The chat thread. + /// The cancellation token. + /// The altered chat thread. + public Task ProcessAsync(IProvider provider, IContent lastPrompt, ChatThread chatThread, CancellationToken token = default); +} \ No newline at end of file diff --git a/app/MindWork AI Studio/Tools/RAG/RAGProcesses/AISrcSelWithRetCtxVal.cs b/app/MindWork AI Studio/Tools/RAG/RAGProcesses/AISrcSelWithRetCtxVal.cs new file mode 100644 index 0000000..8581a3e --- /dev/null +++ b/app/MindWork AI Studio/Tools/RAG/RAGProcesses/AISrcSelWithRetCtxVal.cs @@ -0,0 +1,117 @@ +using AIStudio.Chat; +using AIStudio.Provider; +using AIStudio.Settings; +using AIStudio.Tools.RAG.DataSourceSelectionProcesses; +using AIStudio.Tools.Services; + +namespace AIStudio.Tools.RAG.RAGProcesses; + +public sealed class AISrcSelWithRetCtxVal : IRagProcess +{ + #region Implementation of IRagProcess + + /// + public string TechnicalName => "AISrcSelWithRetCtxVal"; + + /// + public string UIName => "AI source selection with AI retrieval context validation"; + + /// + public string Description => "This RAG process filters data sources, automatically selects appropriate sources, optionally allows manual source selection, retrieves data, and automatically validates the retrieval context."; + + /// + public async Task ProcessAsync(IProvider provider, IContent lastPrompt, ChatThread chatThread, CancellationToken token = default) + { + var logger = Program.SERVICE_PROVIDER.GetService>()!; + var settings = Program.SERVICE_PROVIDER.GetService()!; + var dataSourceService = Program.SERVICE_PROVIDER.GetService()!; + + // + // 1. 
Check if the user wants to bind any data sources to the chat: + // + if (chatThread.DataSourceOptions.IsEnabled()) + { + logger.LogInformation("Data sources are enabled for this chat."); + + // Across the different code-branches, we keep track of whether it + // makes sense to proceed with the RAG process: + var proceedWithRAG = true; + + // + // When the user wants to bind data sources to the chat, we + // have to check if the data sources are available for the + // selected provider. Also, we have to check if any ERI + // data sources changed their security requirements. + // + List preselectedDataSources = chatThread.DataSourceOptions.PreselectedDataSourceIds.Select(id => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == id)).Where(ds => ds is not null).ToList()!; + var dataSources = await dataSourceService.GetDataSources(provider, preselectedDataSources); + var selectedDataSources = dataSources.SelectedDataSources; + + // + // Should the AI select the data sources? + // + if (chatThread.DataSourceOptions.AutomaticDataSourceSelection) + { + var dataSourceSelectionProcess = new AgenticSrcSelWithDynHeur(); + var result = await dataSourceSelectionProcess.SelectDataSourcesAsync(provider, lastPrompt, chatThread, dataSources, token); + proceedWithRAG = result.ProceedWithRAG; + selectedDataSources = result.SelectedDataSources; + } + else + { + // + // No, the user made the choice manually: + // + var selectedDataSourceInfo = selectedDataSources.Select(ds => ds.Name).Aggregate((a, b) => $"'{a}', '{b}'"); + logger.LogInformation($"The user selected the data sources manually. {selectedDataSources.Count} data source(s) are selected: {selectedDataSourceInfo}."); + } + + if(selectedDataSources.Count == 0) + { + logger.LogWarning("No data sources are selected. 
The RAG process is skipped."); + proceedWithRAG = false; + } + + // + // Trigger the retrieval part of the (R)AG process: + // + var dataContexts = new List(); + if (proceedWithRAG) + { + // + // We kick off the retrieval process for each data source in parallel: + // + var retrievalTasks = new List>>(selectedDataSources.Count); + foreach (var dataSource in selectedDataSources) + retrievalTasks.Add(dataSource.RetrieveDataAsync(lastPrompt, chatThread, token)); + + // + // Wait for all retrieval tasks to finish: + // + foreach (var retrievalTask in retrievalTasks) + { + try + { + dataContexts.AddRange(await retrievalTask); + } + catch (Exception e) + { + logger.LogError(e, "An error occurred during the retrieval process."); + } + } + } + + // + // Perform the augmentation of the R(A)G process: + // + if (proceedWithRAG) + { + + } + } + + return chatThread; + } + + #endregion +} \ No newline at end of file