Implemented parallel execution for the validation of many retrieval contexts

Thorsten Sommer 2025-02-22 20:45:08 +01:00
parent a9f821b67e
commit 3485d3cf5d
Signed by: tsommer
GPG Key ID: 371BBA77A02C0108


@@ -149,124 +149,178 @@ public sealed class AgentRetrievalContextValidation (ILogger<AgentRetrievalConte
         this.ProviderSettings = agentProvider;
     }
+    /// <summary>
+    /// Validate all retrieval contexts against the last user and the system prompt.
+    /// </summary>
+    /// <param name="lastPrompt">The last user prompt.</param>
+    /// <param name="chatThread">The chat thread.</param>
+    /// <param name="retrievalContexts">All retrieval contexts to validate.</param>
+    /// <param name="token">The cancellation token.</param>
+    /// <returns>The validation results.</returns>
+    public async Task<IReadOnlyList<RetrievalContextValidationResult>> ValidateRetrievalContextsAsync(IContent lastPrompt, ChatThread chatThread, IReadOnlyList<IRetrievalContext> retrievalContexts, CancellationToken token = default)
+    {
+        // Check if the retrieval context validation is enabled:
+        if (!this.SettingsManager.ConfigurationData.AgentRetrievalContextValidation.EnableRetrievalContextValidation)
+            return [];
+        // Prepare the list of validation tasks:
+        var validationTasks = new List<Task<RetrievalContextValidationResult>>(retrievalContexts.Count);
+        // Read the number of parallel validations:
+        var numParallelValidations = this.SettingsManager.ConfigurationData.AgentRetrievalContextValidation.NumParallelValidations;
+        numParallelValidations = Math.Max(1, numParallelValidations);
+        // Use a semaphore to limit the number of parallel validations:
+        using var semaphore = new SemaphoreSlim(numParallelValidations);
+        foreach (var retrievalContext in retrievalContexts)
+        {
+            // Wait for an available slot in the semaphore:
+            await semaphore.WaitAsync(token);
+            // Start the next validation task:
+            validationTasks.Add(this.ValidateRetrievalContextAsync(lastPrompt, chatThread, retrievalContext, token, semaphore));
+        }
+        // Wait for all validation tasks to complete:
+        return await Task.WhenAll(validationTasks);
+    }
     /// <summary>
     /// Validates the retrieval context against the last user and the system prompt.
     /// </summary>
     /// <remarks>
     /// Probably, you have a lot of retrieval contexts to validate. In this case, you
-    /// can call this method in parallel for each retrieval context.
+    /// can call this method in parallel for each retrieval context. You might use
+    /// the ValidateRetrievalContextsAsync method to validate all retrieval contexts.
     /// </remarks>
     /// <param name="lastPrompt">The last user prompt.</param>
     /// <param name="chatThread">The chat thread.</param>
-    /// <param name="dataContext">The retrieval context to validate.</param>
+    /// <param name="retrievalContext">The retrieval context to validate.</param>
     /// <param name="token">The cancellation token.</param>
+    /// <param name="semaphore">The optional semaphore to limit the number of parallel validations.</param>
     /// <returns>The validation result.</returns>
-    public async Task<RetrievalContextValidationResult> ValidateRetrievalContextAsync(IContent lastPrompt, ChatThread chatThread, IRetrievalContext dataContext, CancellationToken token = default)
+    public async Task<RetrievalContextValidationResult> ValidateRetrievalContextAsync(IContent lastPrompt, ChatThread chatThread, IRetrievalContext retrievalContext, CancellationToken token = default, SemaphoreSlim? semaphore = null)
     {
-        //
-        // 1. Prepare the current system and user prompts as input for the agent:
-        //
-        var lastPromptContent = lastPrompt switch
-        {
-            ContentText text => text.Text,
-            // Image prompts may be empty, e.g., when the image is too large:
-            ContentImage image => await image.AsBase64(token),
-            // Other content types are not supported yet:
-            _ => string.Empty,
-        };
-        if (string.IsNullOrWhiteSpace(lastPromptContent))
-        {
-            logger.LogWarning("The last prompt is empty. The AI cannot validate the retrieval context.");
-            return new(false, "The last prompt was empty.", 1.0f, dataContext);
-        }
-        //
-        // 2. Prepare the retrieval context for the agent:
-        //
-        var additionalData = new Dictionary<string, string>();
-        var markdownRetrievalContext = await dataContext.AsMarkdown(token: token);
-        additionalData.Add("retrievalContext", markdownRetrievalContext);
-        //
-        // 3. Let the agent validate the retrieval context:
-        //
-        var prompt = $"""
-                      The system prompt is:
-                      ```
-                      {chatThread.SystemPrompt}
-                      ```
-                      The user prompt is:
-                      ```
-                      {lastPromptContent}
-                      ```
-                      """;
-        // Call the agent:
-        var aiResponse = await this.ProcessInput(new ContentBlock
-        {
-            Time = DateTimeOffset.UtcNow,
-            ContentType = ContentType.TEXT,
-            Role = ChatRole.USER,
-            Content = new ContentText
-            {
-                Text = prompt,
-            },
-        }, additionalData);
-        if(aiResponse.Content is null)
-        {
-            logger.LogWarning("The agent did not return a response.");
-            return new(false, "The agent did not return a response.", 1.0f, dataContext);
-        }
-        switch (aiResponse)
-        {
-            //
-            // 4. Parse the agent response:
-            //
-            case { ContentType: ContentType.TEXT, Content: ContentText textContent }:
-            {
-                //
-                // What we expect is one JSON object:
-                //
-                var validationJson = textContent.Text;
-                //
-                // We know how bad LLM may be in generating JSON without surrounding text.
-                // Thus, we expect the worst and try to extract the JSON list from the text:
-                //
-                var json = ExtractJson(validationJson);
-                try
-                {
-                    var result = JsonSerializer.Deserialize<RetrievalContextValidationResult>(json, JSON_SERIALIZER_OPTIONS);
-                    return result with { RetrievalContext = dataContext };
-                }
-                catch
-                {
-                    logger.LogWarning("The agent answered with an invalid or unexpected JSON format.");
-                    return new(false, "The agent answered with an invalid or unexpected JSON format.", 1.0f, dataContext);
-                }
-            }
-            case { ContentType: ContentType.TEXT }:
-                logger.LogWarning("The agent answered with an unexpected inner content type.");
-                return new(false, "The agent answered with an unexpected inner content type.", 1.0f, dataContext);
-            case { ContentType: ContentType.NONE }:
-                logger.LogWarning("The agent did not return a response.");
-                return new(false, "The agent did not return a response.", 1.0f, dataContext);
-            default:
-                logger.LogWarning($"The agent answered with an unexpected content type '{aiResponse.ContentType}'.");
-                return new(false, $"The agent answered with an unexpected content type '{aiResponse.ContentType}'.", 1.0f, dataContext);
-        }
+        try
+        {
+            //
+            // Check if the validation was canceled. This could happen when the user
+            // canceled the validation process or when the validation process took
+            // too long:
+            //
+            if(token.IsCancellationRequested)
+                return new(false, "The validation was canceled.", 1.0f, retrievalContext);
+            //
+            // 1. Prepare the current system and user prompts as input for the agent:
+            //
+            var lastPromptContent = lastPrompt switch
+            {
+                ContentText text => text.Text,
+                // Image prompts may be empty, e.g., when the image is too large:
+                ContentImage image => await image.AsBase64(token),
+                // Other content types are not supported yet:
+                _ => string.Empty,
+            };
+            if (string.IsNullOrWhiteSpace(lastPromptContent))
+            {
+                logger.LogWarning("The last prompt is empty. The AI cannot validate the retrieval context.");
+                return new(false, "The last prompt was empty.", 1.0f, retrievalContext);
+            }
+            //
+            // 2. Prepare the retrieval context for the agent:
+            //
+            var additionalData = new Dictionary<string, string>();
+            var markdownRetrievalContext = await retrievalContext.AsMarkdown(token: token);
+            additionalData.Add("retrievalContext", markdownRetrievalContext);
+            //
+            // 3. Let the agent validate the retrieval context:
+            //
+            var prompt = $"""
+                          The system prompt is:
+                          ```
+                          {chatThread.SystemPrompt}
+                          ```
+                          The user prompt is:
+                          ```
+                          {lastPromptContent}
+                          ```
+                          """;
+            // Call the agent:
+            var aiResponse = await this.ProcessInput(new ContentBlock
+            {
+                Time = DateTimeOffset.UtcNow,
+                ContentType = ContentType.TEXT,
+                Role = ChatRole.USER,
+                Content = new ContentText
+                {
+                    Text = prompt,
+                },
+            }, additionalData);
+            if (aiResponse.Content is null)
+            {
+                logger.LogWarning("The agent did not return a response.");
+                return new(false, "The agent did not return a response.", 1.0f, retrievalContext);
+            }
+            switch (aiResponse)
+            {
+                //
+                // 4. Parse the agent response:
+                //
+                case { ContentType: ContentType.TEXT, Content: ContentText textContent }:
+                {
+                    //
+                    // What we expect is one JSON object:
+                    //
+                    var validationJson = textContent.Text;
+                    //
+                    // We know how bad LLM may be in generating JSON without surrounding text.
+                    // Thus, we expect the worst and try to extract the JSON list from the text:
+                    //
+                    var json = ExtractJson(validationJson);
+                    try
+                    {
+                        var result = JsonSerializer.Deserialize<RetrievalContextValidationResult>(json, JSON_SERIALIZER_OPTIONS);
+                        return result with { RetrievalContext = retrievalContext };
+                    }
+                    catch
+                    {
+                        logger.LogWarning("The agent answered with an invalid or unexpected JSON format.");
+                        return new(false, "The agent answered with an invalid or unexpected JSON format.", 1.0f, retrievalContext);
+                    }
+                }
+                case { ContentType: ContentType.TEXT }:
+                    logger.LogWarning("The agent answered with an unexpected inner content type.");
+                    return new(false, "The agent answered with an unexpected inner content type.", 1.0f, retrievalContext);
+                case { ContentType: ContentType.NONE }:
+                    logger.LogWarning("The agent did not return a response.");
+                    return new(false, "The agent did not return a response.", 1.0f, retrievalContext);
+                default:
+                    logger.LogWarning($"The agent answered with an unexpected content type '{aiResponse.ContentType}'.");
+                    return new(false, $"The agent answered with an unexpected content type '{aiResponse.ContentType}'.", 1.0f, retrievalContext);
+            }
+        }
+        finally
+        {
+            // Release the semaphore slot:
+            semaphore?.Release();
+        }
     }
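
The throttling pattern this commit introduces can be reproduced with nothing but standard .NET types. Below is a minimal, self-contained sketch of the same idea; the ThrottledValidationSketch class, the ValidateOneAsync work item, and the limit of four parallel tasks are illustrative placeholders and not part of the commit or of the surrounding code base.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

public static class ThrottledValidationSketch
{
    // Stand-in for a single validation call such as ValidateRetrievalContextAsync.
    // The semaphore slot is released in finally, even on cancellation or failure.
    private static async Task<string> ValidateOneAsync(int contextId, SemaphoreSlim? semaphore, CancellationToken token)
    {
        try
        {
            await Task.Delay(250, token); // simulate the agent call
            return $"context {contextId}: validated";
        }
        finally
        {
            semaphore?.Release();
        }
    }

    // Mirrors the structure of ValidateRetrievalContextsAsync: acquire a slot
    // before starting each task, collect the tasks, then await them all.
    public static async Task<IReadOnlyList<string>> ValidateManyAsync(IReadOnlyList<int> contextIds, int numParallelValidations, CancellationToken token = default)
    {
        numParallelValidations = Math.Max(1, numParallelValidations);
        using var semaphore = new SemaphoreSlim(numParallelValidations);

        var tasks = new List<Task<string>>(contextIds.Count);
        foreach (var id in contextIds)
        {
            // Throttle: wait until one of the slots becomes available.
            await semaphore.WaitAsync(token);
            tasks.Add(ValidateOneAsync(id, semaphore, token));
        }

        return await Task.WhenAll(tasks);
    }

    public static async Task Main()
    {
        var results = await ValidateManyAsync(Enumerable.Range(1, 10).ToList(), numParallelValidations: 4);
        foreach (var line in results)
            Console.WriteLine(line);
    }
}

Releasing the semaphore inside a finally block, as the commit does, guarantees that a slot is returned even when a validation throws or is canceled; without it, a single failed validation would permanently shrink the degree of parallelism.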