diff --git a/app/MindWork AI Studio/Tools/RAG/AugmentationProcesses/AugmentationOne.cs b/app/MindWork AI Studio/Tools/RAG/AugmentationProcesses/AugmentationOne.cs index fda8cf37..e5a2b931 100644 --- a/app/MindWork AI Studio/Tools/RAG/AugmentationProcesses/AugmentationOne.cs +++ b/app/MindWork AI Studio/Tools/RAG/AugmentationProcesses/AugmentationOne.cs @@ -1,7 +1,9 @@ using System.Text; +using AIStudio.Agents; using AIStudio.Chat; using AIStudio.Provider; +using AIStudio.Settings; namespace AIStudio.Tools.RAG.AugmentationProcesses; @@ -22,13 +24,36 @@ public sealed class AugmentationOne : IAugmentationProcess public async Task ProcessAsync(IProvider provider, IContent lastPrompt, ChatThread chatThread, IReadOnlyList retrievalContexts, CancellationToken token = default) { var logger = Program.SERVICE_PROVIDER.GetService>()!; + var settings = Program.SERVICE_PROVIDER.GetService()!; + if(retrievalContexts.Count == 0) { logger.LogWarning("No retrieval contexts were issued. Skipping the augmentation process."); return chatThread; } - + var numTotalRetrievalContexts = retrievalContexts.Count; + + // Want the user to validate all retrieval contexts? + if (settings.ConfigurationData.AgentRetrievalContextValidation.EnableRetrievalContextValidation) + { + // Let's get the validation agent & set up its provider: + var validationAgent = Program.SERVICE_PROVIDER.GetService()!; + validationAgent.SetLLMProvider(provider); + + // Let's validate all retrieval contexts: + var validationResults = await validationAgent.ValidateRetrievalContextsAsync(lastPrompt, chatThread, retrievalContexts, token); + + // + // Now, filter the retrieval contexts to the most relevant ones: + // + var targetWindow = validationResults.DetermineTargetWindow(TargetWindowStrategy.TOP10_BETTER_THAN_GUESSING); + var threshold = validationResults.GetConfidenceThreshold(targetWindow); + + // Filter the retrieval contexts: + retrievalContexts = validationResults.Where(x => x.RetrievalContext is not null && x.Confidence >= threshold).Select(x => x.RetrievalContext!).ToList(); + } + logger.LogInformation($"Starting the augmentation process over {numTotalRetrievalContexts:###,###,###,###} retrieval contexts."); //