mirror of
https://github.com/MindWorkAI/AI-Studio.git
synced 2025-04-28 15:39:46 +00:00
Added the retrieval context validation
This commit is contained in:
parent
3485d3cf5d
commit
b2e3ddf555
@ -1,7 +1,9 @@
|
|||||||
using System.Text;
|
using System.Text;
|
||||||
|
|
||||||
|
using AIStudio.Agents;
|
||||||
using AIStudio.Chat;
|
using AIStudio.Chat;
|
||||||
using AIStudio.Provider;
|
using AIStudio.Provider;
|
||||||
|
using AIStudio.Settings;
|
||||||
|
|
||||||
namespace AIStudio.Tools.RAG.AugmentationProcesses;
|
namespace AIStudio.Tools.RAG.AugmentationProcesses;
|
||||||
|
|
||||||
@ -22,6 +24,8 @@ public sealed class AugmentationOne : IAugmentationProcess
|
|||||||
public async Task<ChatThread> ProcessAsync(IProvider provider, IContent lastPrompt, ChatThread chatThread, IReadOnlyList<IRetrievalContext> retrievalContexts, CancellationToken token = default)
|
public async Task<ChatThread> ProcessAsync(IProvider provider, IContent lastPrompt, ChatThread chatThread, IReadOnlyList<IRetrievalContext> retrievalContexts, CancellationToken token = default)
|
||||||
{
|
{
|
||||||
var logger = Program.SERVICE_PROVIDER.GetService<ILogger<AugmentationOne>>()!;
|
var logger = Program.SERVICE_PROVIDER.GetService<ILogger<AugmentationOne>>()!;
|
||||||
|
var settings = Program.SERVICE_PROVIDER.GetService<SettingsManager>()!;
|
||||||
|
|
||||||
if(retrievalContexts.Count == 0)
|
if(retrievalContexts.Count == 0)
|
||||||
{
|
{
|
||||||
logger.LogWarning("No retrieval contexts were issued. Skipping the augmentation process.");
|
logger.LogWarning("No retrieval contexts were issued. Skipping the augmentation process.");
|
||||||
@ -29,6 +33,27 @@ public sealed class AugmentationOne : IAugmentationProcess
|
|||||||
}
|
}
|
||||||
|
|
||||||
var numTotalRetrievalContexts = retrievalContexts.Count;
|
var numTotalRetrievalContexts = retrievalContexts.Count;
|
||||||
|
|
||||||
|
// Want the user to validate all retrieval contexts?
|
||||||
|
if (settings.ConfigurationData.AgentRetrievalContextValidation.EnableRetrievalContextValidation)
|
||||||
|
{
|
||||||
|
// Let's get the validation agent & set up its provider:
|
||||||
|
var validationAgent = Program.SERVICE_PROVIDER.GetService<AgentRetrievalContextValidation>()!;
|
||||||
|
validationAgent.SetLLMProvider(provider);
|
||||||
|
|
||||||
|
// Let's validate all retrieval contexts:
|
||||||
|
var validationResults = await validationAgent.ValidateRetrievalContextsAsync(lastPrompt, chatThread, retrievalContexts, token);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Now, filter the retrieval contexts to the most relevant ones:
|
||||||
|
//
|
||||||
|
var targetWindow = validationResults.DetermineTargetWindow(TargetWindowStrategy.TOP10_BETTER_THAN_GUESSING);
|
||||||
|
var threshold = validationResults.GetConfidenceThreshold(targetWindow);
|
||||||
|
|
||||||
|
// Filter the retrieval contexts:
|
||||||
|
retrievalContexts = validationResults.Where(x => x.RetrievalContext is not null && x.Confidence >= threshold).Select(x => x.RetrievalContext!).ToList();
|
||||||
|
}
|
||||||
|
|
||||||
logger.LogInformation($"Starting the augmentation process over {numTotalRetrievalContexts:###,###,###,###} retrieval contexts.");
|
logger.LogInformation($"Starting the augmentation process over {numTotalRetrievalContexts:###,###,###,###} retrieval contexts.");
|
||||||
|
|
||||||
//
|
//
|
||||||
|
Loading…
Reference in New Issue
Block a user