mirror of
				https://github.com/MindWorkAI/AI-Studio.git
				synced 2025-11-04 15:20:20 +00:00 
			
		
		
		
	Added the retrieval context validation
This commit is contained in:
		
							parent
							
								
									3485d3cf5d
								
							
						
					
					
						commit
						b2e3ddf555
					
				@ -1,7 +1,9 @@
 | 
				
			|||||||
using System.Text;
 | 
					using System.Text;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using AIStudio.Agents;
 | 
				
			||||||
using AIStudio.Chat;
 | 
					using AIStudio.Chat;
 | 
				
			||||||
using AIStudio.Provider;
 | 
					using AIStudio.Provider;
 | 
				
			||||||
 | 
					using AIStudio.Settings;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace AIStudio.Tools.RAG.AugmentationProcesses;
 | 
					namespace AIStudio.Tools.RAG.AugmentationProcesses;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -22,6 +24,8 @@ public sealed class AugmentationOne : IAugmentationProcess
 | 
				
			|||||||
    public async Task<ChatThread> ProcessAsync(IProvider provider, IContent lastPrompt, ChatThread chatThread, IReadOnlyList<IRetrievalContext> retrievalContexts, CancellationToken token = default)
 | 
					    public async Task<ChatThread> ProcessAsync(IProvider provider, IContent lastPrompt, ChatThread chatThread, IReadOnlyList<IRetrievalContext> retrievalContexts, CancellationToken token = default)
 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
        var logger = Program.SERVICE_PROVIDER.GetService<ILogger<AugmentationOne>>()!;
 | 
					        var logger = Program.SERVICE_PROVIDER.GetService<ILogger<AugmentationOne>>()!;
 | 
				
			||||||
 | 
					        var settings = Program.SERVICE_PROVIDER.GetService<SettingsManager>()!;
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
        if(retrievalContexts.Count == 0)
 | 
					        if(retrievalContexts.Count == 0)
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
            logger.LogWarning("No retrieval contexts were issued. Skipping the augmentation process.");
 | 
					            logger.LogWarning("No retrieval contexts were issued. Skipping the augmentation process.");
 | 
				
			||||||
@ -29,6 +33,27 @@ public sealed class AugmentationOne : IAugmentationProcess
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        var numTotalRetrievalContexts = retrievalContexts.Count;
 | 
					        var numTotalRetrievalContexts = retrievalContexts.Count;
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        // Want the user to validate all retrieval contexts?
 | 
				
			||||||
 | 
					        if (settings.ConfigurationData.AgentRetrievalContextValidation.EnableRetrievalContextValidation)
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            // Let's get the validation agent & set up its provider:
 | 
				
			||||||
 | 
					            var validationAgent = Program.SERVICE_PROVIDER.GetService<AgentRetrievalContextValidation>()!;
 | 
				
			||||||
 | 
					            validationAgent.SetLLMProvider(provider);
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					            // Let's validate all retrieval contexts:
 | 
				
			||||||
 | 
					            var validationResults = await validationAgent.ValidateRetrievalContextsAsync(lastPrompt, chatThread, retrievalContexts, token);
 | 
				
			||||||
 | 
					         
 | 
				
			||||||
 | 
					            //
 | 
				
			||||||
 | 
					            // Now, filter the retrieval contexts to the most relevant ones:
 | 
				
			||||||
 | 
					            //
 | 
				
			||||||
 | 
					            var targetWindow = validationResults.DetermineTargetWindow(TargetWindowStrategy.TOP10_BETTER_THAN_GUESSING);
 | 
				
			||||||
 | 
					            var threshold = validationResults.GetConfidenceThreshold(targetWindow);
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					            // Filter the retrieval contexts:
 | 
				
			||||||
 | 
					            retrievalContexts = validationResults.Where(x => x.RetrievalContext is not null && x.Confidence >= threshold).Select(x => x.RetrievalContext!).ToList();
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
        logger.LogInformation($"Starting the augmentation process over {numTotalRetrievalContexts:###,###,###,###} retrieval contexts.");
 | 
					        logger.LogInformation($"Starting the augmentation process over {numTotalRetrievalContexts:###,###,###,###} retrieval contexts.");
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        //
 | 
					        //
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user