mirror of
https://github.com/MindWorkAI/AI-Studio.git
synced 2025-04-28 10:39:47 +00:00
Implemented a parallel execution for the validation of many retrieval contexts
This commit is contained in:
parent
a9f821b67e
commit
3485d3cf5d
@ -149,20 +149,68 @@ public sealed class AgentRetrievalContextValidation (ILogger<AgentRetrievalConte
|
||||
this.ProviderSettings = agentProvider;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Validate all retrieval contexts against the last user and the system prompt.
|
||||
/// </summary>
|
||||
/// <param name="lastPrompt">The last user prompt.</param>
|
||||
/// <param name="chatThread">The chat thread.</param>
|
||||
/// <param name="retrievalContexts">All retrieval contexts to validate.</param>
|
||||
/// <param name="token">The cancellation token.</param>
|
||||
/// <returns>The validation results.</returns>
|
||||
public async Task<IReadOnlyList<RetrievalContextValidationResult>> ValidateRetrievalContextsAsync(IContent lastPrompt, ChatThread chatThread, IReadOnlyList<IRetrievalContext> retrievalContexts, CancellationToken token = default)
|
||||
{
|
||||
// Check if the retrieval context validation is enabled:
|
||||
if (!this.SettingsManager.ConfigurationData.AgentRetrievalContextValidation.EnableRetrievalContextValidation)
|
||||
return [];
|
||||
|
||||
// Prepare the list of validation tasks:
|
||||
var validationTasks = new List<Task<RetrievalContextValidationResult>>(retrievalContexts.Count);
|
||||
|
||||
// Read the number of parallel validations:
|
||||
var numParallelValidations = this.SettingsManager.ConfigurationData.AgentRetrievalContextValidation.NumParallelValidations;
|
||||
numParallelValidations = Math.Max(1, numParallelValidations);
|
||||
|
||||
// Use a semaphore to limit the number of parallel validations:
|
||||
using var semaphore = new SemaphoreSlim(numParallelValidations);
|
||||
foreach (var retrievalContext in retrievalContexts)
|
||||
{
|
||||
// Wait for an available slot in the semaphore:
|
||||
await semaphore.WaitAsync(token);
|
||||
|
||||
// Start the next validation task:
|
||||
validationTasks.Add(this.ValidateRetrievalContextAsync(lastPrompt, chatThread, retrievalContext, token, semaphore));
|
||||
}
|
||||
|
||||
// Wait for all validation tasks to complete:
|
||||
return await Task.WhenAll(validationTasks);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Validates the retrieval context against the last user and the system prompt.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// Probably, you have a lot of retrieval contexts to validate. In this case, you
|
||||
/// can call this method in parallel for each retrieval context.
|
||||
/// can call this method in parallel for each retrieval context. You might use
|
||||
/// the ValidateRetrievalContextsAsync method to validate all retrieval contexts.
|
||||
/// </remarks>
|
||||
/// <param name="lastPrompt">The last user prompt.</param>
|
||||
/// <param name="chatThread">The chat thread.</param>
|
||||
/// <param name="dataContext">The retrieval context to validate.</param>
|
||||
/// <param name="retrievalContext">The retrieval context to validate.</param>
|
||||
/// <param name="token">The cancellation token.</param>
|
||||
/// <param name="semaphore">The optional semaphore to limit the number of parallel validations.</param>
|
||||
/// <returns>The validation result.</returns>
|
||||
public async Task<RetrievalContextValidationResult> ValidateRetrievalContextAsync(IContent lastPrompt, ChatThread chatThread, IRetrievalContext dataContext, CancellationToken token = default)
|
||||
public async Task<RetrievalContextValidationResult> ValidateRetrievalContextAsync(IContent lastPrompt, ChatThread chatThread, IRetrievalContext retrievalContext, CancellationToken token = default, SemaphoreSlim? semaphore = null)
|
||||
{
|
||||
try
|
||||
{
|
||||
//
|
||||
// Check if the validation was canceled. This could happen when the user
|
||||
// canceled the validation process or when the validation process took
|
||||
// too long:
|
||||
//
|
||||
if(token.IsCancellationRequested)
|
||||
return new(false, "The validation was canceled.", 1.0f, retrievalContext);
|
||||
|
||||
//
|
||||
// 1. Prepare the current system and user prompts as input for the agent:
|
||||
//
|
||||
@ -180,14 +228,14 @@ public sealed class AgentRetrievalContextValidation (ILogger<AgentRetrievalConte
|
||||
if (string.IsNullOrWhiteSpace(lastPromptContent))
|
||||
{
|
||||
logger.LogWarning("The last prompt is empty. The AI cannot validate the retrieval context.");
|
||||
return new(false, "The last prompt was empty.", 1.0f, dataContext);
|
||||
return new(false, "The last prompt was empty.", 1.0f, retrievalContext);
|
||||
}
|
||||
|
||||
//
|
||||
// 2. Prepare the retrieval context for the agent:
|
||||
//
|
||||
var additionalData = new Dictionary<string, string>();
|
||||
var markdownRetrievalContext = await dataContext.AsMarkdown(token: token);
|
||||
var markdownRetrievalContext = await retrievalContext.AsMarkdown(token: token);
|
||||
additionalData.Add("retrievalContext", markdownRetrievalContext);
|
||||
|
||||
//
|
||||
@ -222,7 +270,7 @@ public sealed class AgentRetrievalContextValidation (ILogger<AgentRetrievalConte
|
||||
if (aiResponse.Content is null)
|
||||
{
|
||||
logger.LogWarning("The agent did not return a response.");
|
||||
return new(false, "The agent did not return a response.", 1.0f, dataContext);
|
||||
return new(false, "The agent did not return a response.", 1.0f, retrievalContext);
|
||||
}
|
||||
|
||||
switch (aiResponse)
|
||||
@ -247,26 +295,32 @@ public sealed class AgentRetrievalContextValidation (ILogger<AgentRetrievalConte
|
||||
try
|
||||
{
|
||||
var result = JsonSerializer.Deserialize<RetrievalContextValidationResult>(json, JSON_SERIALIZER_OPTIONS);
|
||||
return result with { RetrievalContext = dataContext };
|
||||
return result with { RetrievalContext = retrievalContext };
|
||||
}
|
||||
catch
|
||||
{
|
||||
logger.LogWarning("The agent answered with an invalid or unexpected JSON format.");
|
||||
return new(false, "The agent answered with an invalid or unexpected JSON format.", 1.0f, dataContext);
|
||||
return new(false, "The agent answered with an invalid or unexpected JSON format.", 1.0f, retrievalContext);
|
||||
}
|
||||
}
|
||||
|
||||
case { ContentType: ContentType.TEXT }:
|
||||
logger.LogWarning("The agent answered with an unexpected inner content type.");
|
||||
return new(false, "The agent answered with an unexpected inner content type.", 1.0f, dataContext);
|
||||
return new(false, "The agent answered with an unexpected inner content type.", 1.0f, retrievalContext);
|
||||
|
||||
case { ContentType: ContentType.NONE }:
|
||||
logger.LogWarning("The agent did not return a response.");
|
||||
return new(false, "The agent did not return a response.", 1.0f, dataContext);
|
||||
return new(false, "The agent did not return a response.", 1.0f, retrievalContext);
|
||||
|
||||
default:
|
||||
logger.LogWarning($"The agent answered with an unexpected content type '{aiResponse.ContentType}'.");
|
||||
return new(false, $"The agent answered with an unexpected content type '{aiResponse.ContentType}'.", 1.0f, dataContext);
|
||||
return new(false, $"The agent answered with an unexpected content type '{aiResponse.ContentType}'.", 1.0f, retrievalContext);
|
||||
}
|
||||
}
|
||||
finally
|
||||
{
|
||||
// Release the semaphore slot:
|
||||
semaphore?.Release();
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user