mirror of
https://github.com/MindWorkAI/AI-Studio.git
synced 2025-04-28 16:19:48 +00:00
Implemented a parallel execution for the validation of many retrieval contexts
This commit is contained in:
parent
a9f821b67e
commit
3485d3cf5d
@ -149,20 +149,68 @@ public sealed class AgentRetrievalContextValidation (ILogger<AgentRetrievalConte
|
|||||||
this.ProviderSettings = agentProvider;
|
this.ProviderSettings = agentProvider;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Validate all retrieval contexts against the last user and the system prompt.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="lastPrompt">The last user prompt.</param>
|
||||||
|
/// <param name="chatThread">The chat thread.</param>
|
||||||
|
/// <param name="retrievalContexts">All retrieval contexts to validate.</param>
|
||||||
|
/// <param name="token">The cancellation token.</param>
|
||||||
|
/// <returns>The validation results.</returns>
|
||||||
|
public async Task<IReadOnlyList<RetrievalContextValidationResult>> ValidateRetrievalContextsAsync(IContent lastPrompt, ChatThread chatThread, IReadOnlyList<IRetrievalContext> retrievalContexts, CancellationToken token = default)
|
||||||
|
{
|
||||||
|
// Check if the retrieval context validation is enabled:
|
||||||
|
if (!this.SettingsManager.ConfigurationData.AgentRetrievalContextValidation.EnableRetrievalContextValidation)
|
||||||
|
return [];
|
||||||
|
|
||||||
|
// Prepare the list of validation tasks:
|
||||||
|
var validationTasks = new List<Task<RetrievalContextValidationResult>>(retrievalContexts.Count);
|
||||||
|
|
||||||
|
// Read the number of parallel validations:
|
||||||
|
var numParallelValidations = this.SettingsManager.ConfigurationData.AgentRetrievalContextValidation.NumParallelValidations;
|
||||||
|
numParallelValidations = Math.Max(1, numParallelValidations);
|
||||||
|
|
||||||
|
// Use a semaphore to limit the number of parallel validations:
|
||||||
|
using var semaphore = new SemaphoreSlim(numParallelValidations);
|
||||||
|
foreach (var retrievalContext in retrievalContexts)
|
||||||
|
{
|
||||||
|
// Wait for an available slot in the semaphore:
|
||||||
|
await semaphore.WaitAsync(token);
|
||||||
|
|
||||||
|
// Start the next validation task:
|
||||||
|
validationTasks.Add(this.ValidateRetrievalContextAsync(lastPrompt, chatThread, retrievalContext, token, semaphore));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all validation tasks to complete:
|
||||||
|
return await Task.WhenAll(validationTasks);
|
||||||
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Validates the retrieval context against the last user and the system prompt.
|
/// Validates the retrieval context against the last user and the system prompt.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <remarks>
|
/// <remarks>
|
||||||
/// Probably, you have a lot of retrieval contexts to validate. In this case, you
|
/// Probably, you have a lot of retrieval contexts to validate. In this case, you
|
||||||
/// can call this method in parallel for each retrieval context.
|
/// can call this method in parallel for each retrieval context. You might use
|
||||||
|
/// the ValidateRetrievalContextsAsync method to validate all retrieval contexts.
|
||||||
/// </remarks>
|
/// </remarks>
|
||||||
/// <param name="lastPrompt">The last user prompt.</param>
|
/// <param name="lastPrompt">The last user prompt.</param>
|
||||||
/// <param name="chatThread">The chat thread.</param>
|
/// <param name="chatThread">The chat thread.</param>
|
||||||
/// <param name="dataContext">The retrieval context to validate.</param>
|
/// <param name="retrievalContext">The retrieval context to validate.</param>
|
||||||
/// <param name="token">The cancellation token.</param>
|
/// <param name="token">The cancellation token.</param>
|
||||||
|
/// <param name="semaphore">The optional semaphore to limit the number of parallel validations.</param>
|
||||||
/// <returns>The validation result.</returns>
|
/// <returns>The validation result.</returns>
|
||||||
public async Task<RetrievalContextValidationResult> ValidateRetrievalContextAsync(IContent lastPrompt, ChatThread chatThread, IRetrievalContext dataContext, CancellationToken token = default)
|
public async Task<RetrievalContextValidationResult> ValidateRetrievalContextAsync(IContent lastPrompt, ChatThread chatThread, IRetrievalContext retrievalContext, CancellationToken token = default, SemaphoreSlim? semaphore = null)
|
||||||
{
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
//
|
||||||
|
// Check if the validation was canceled. This could happen when the user
|
||||||
|
// canceled the validation process or when the validation process took
|
||||||
|
// too long:
|
||||||
|
//
|
||||||
|
if(token.IsCancellationRequested)
|
||||||
|
return new(false, "The validation was canceled.", 1.0f, retrievalContext);
|
||||||
|
|
||||||
//
|
//
|
||||||
// 1. Prepare the current system and user prompts as input for the agent:
|
// 1. Prepare the current system and user prompts as input for the agent:
|
||||||
//
|
//
|
||||||
@ -180,14 +228,14 @@ public sealed class AgentRetrievalContextValidation (ILogger<AgentRetrievalConte
|
|||||||
if (string.IsNullOrWhiteSpace(lastPromptContent))
|
if (string.IsNullOrWhiteSpace(lastPromptContent))
|
||||||
{
|
{
|
||||||
logger.LogWarning("The last prompt is empty. The AI cannot validate the retrieval context.");
|
logger.LogWarning("The last prompt is empty. The AI cannot validate the retrieval context.");
|
||||||
return new(false, "The last prompt was empty.", 1.0f, dataContext);
|
return new(false, "The last prompt was empty.", 1.0f, retrievalContext);
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// 2. Prepare the retrieval context for the agent:
|
// 2. Prepare the retrieval context for the agent:
|
||||||
//
|
//
|
||||||
var additionalData = new Dictionary<string, string>();
|
var additionalData = new Dictionary<string, string>();
|
||||||
var markdownRetrievalContext = await dataContext.AsMarkdown(token: token);
|
var markdownRetrievalContext = await retrievalContext.AsMarkdown(token: token);
|
||||||
additionalData.Add("retrievalContext", markdownRetrievalContext);
|
additionalData.Add("retrievalContext", markdownRetrievalContext);
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -219,10 +267,10 @@ public sealed class AgentRetrievalContextValidation (ILogger<AgentRetrievalConte
|
|||||||
},
|
},
|
||||||
}, additionalData);
|
}, additionalData);
|
||||||
|
|
||||||
if(aiResponse.Content is null)
|
if (aiResponse.Content is null)
|
||||||
{
|
{
|
||||||
logger.LogWarning("The agent did not return a response.");
|
logger.LogWarning("The agent did not return a response.");
|
||||||
return new(false, "The agent did not return a response.", 1.0f, dataContext);
|
return new(false, "The agent did not return a response.", 1.0f, retrievalContext);
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (aiResponse)
|
switch (aiResponse)
|
||||||
@ -247,26 +295,32 @@ public sealed class AgentRetrievalContextValidation (ILogger<AgentRetrievalConte
|
|||||||
try
|
try
|
||||||
{
|
{
|
||||||
var result = JsonSerializer.Deserialize<RetrievalContextValidationResult>(json, JSON_SERIALIZER_OPTIONS);
|
var result = JsonSerializer.Deserialize<RetrievalContextValidationResult>(json, JSON_SERIALIZER_OPTIONS);
|
||||||
return result with { RetrievalContext = dataContext };
|
return result with { RetrievalContext = retrievalContext };
|
||||||
}
|
}
|
||||||
catch
|
catch
|
||||||
{
|
{
|
||||||
logger.LogWarning("The agent answered with an invalid or unexpected JSON format.");
|
logger.LogWarning("The agent answered with an invalid or unexpected JSON format.");
|
||||||
return new(false, "The agent answered with an invalid or unexpected JSON format.", 1.0f, dataContext);
|
return new(false, "The agent answered with an invalid or unexpected JSON format.", 1.0f, retrievalContext);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case { ContentType: ContentType.TEXT }:
|
case { ContentType: ContentType.TEXT }:
|
||||||
logger.LogWarning("The agent answered with an unexpected inner content type.");
|
logger.LogWarning("The agent answered with an unexpected inner content type.");
|
||||||
return new(false, "The agent answered with an unexpected inner content type.", 1.0f, dataContext);
|
return new(false, "The agent answered with an unexpected inner content type.", 1.0f, retrievalContext);
|
||||||
|
|
||||||
case { ContentType: ContentType.NONE }:
|
case { ContentType: ContentType.NONE }:
|
||||||
logger.LogWarning("The agent did not return a response.");
|
logger.LogWarning("The agent did not return a response.");
|
||||||
return new(false, "The agent did not return a response.", 1.0f, dataContext);
|
return new(false, "The agent did not return a response.", 1.0f, retrievalContext);
|
||||||
|
|
||||||
default:
|
default:
|
||||||
logger.LogWarning($"The agent answered with an unexpected content type '{aiResponse.ContentType}'.");
|
logger.LogWarning($"The agent answered with an unexpected content type '{aiResponse.ContentType}'.");
|
||||||
return new(false, $"The agent answered with an unexpected content type '{aiResponse.ContentType}'.", 1.0f, dataContext);
|
return new(false, $"The agent answered with an unexpected content type '{aiResponse.ContentType}'.", 1.0f, retrievalContext);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
finally
|
||||||
|
{
|
||||||
|
// Release the semaphore slot:
|
||||||
|
semaphore?.Release();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user