mirror of
				https://github.com/MindWorkAI/AI-Studio.git
				synced 2025-11-04 11:40:21 +00:00 
			
		
		
		
	Implemented a parallel execution for the validation of many retrieval contexts
This commit is contained in:
		
							parent
							
								
									a9f821b67e
								
							
						
					
					
						commit
						3485d3cf5d
					
				@ -149,20 +149,68 @@ public sealed class AgentRetrievalContextValidation (ILogger<AgentRetrievalConte
 | 
				
			|||||||
        this.ProviderSettings = agentProvider;
 | 
					        this.ProviderSettings = agentProvider;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
 | 
					    /// <summary>
 | 
				
			||||||
 | 
					    /// Validate all retrieval contexts against the last user and the system prompt.
 | 
				
			||||||
 | 
					    /// </summary>
 | 
				
			||||||
 | 
					    /// <param name="lastPrompt">The last user prompt.</param>
 | 
				
			||||||
 | 
					    /// <param name="chatThread">The chat thread.</param>
 | 
				
			||||||
 | 
					    /// <param name="retrievalContexts">All retrieval contexts to validate.</param>
 | 
				
			||||||
 | 
					    /// <param name="token">The cancellation token.</param>
 | 
				
			||||||
 | 
					    /// <returns>The validation results.</returns>
 | 
				
			||||||
 | 
					    public async Task<IReadOnlyList<RetrievalContextValidationResult>> ValidateRetrievalContextsAsync(IContent lastPrompt, ChatThread chatThread, IReadOnlyList<IRetrievalContext> retrievalContexts, CancellationToken token = default)
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        // Check if the retrieval context validation is enabled:
 | 
				
			||||||
 | 
					        if (!this.SettingsManager.ConfigurationData.AgentRetrievalContextValidation.EnableRetrievalContextValidation)
 | 
				
			||||||
 | 
					            return [];
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        // Prepare the list of validation tasks:
 | 
				
			||||||
 | 
					        var validationTasks = new List<Task<RetrievalContextValidationResult>>(retrievalContexts.Count);
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        // Read the number of parallel validations:
 | 
				
			||||||
 | 
					        var numParallelValidations = this.SettingsManager.ConfigurationData.AgentRetrievalContextValidation.NumParallelValidations;
 | 
				
			||||||
 | 
					        numParallelValidations = Math.Max(1, numParallelValidations);
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        // Use a semaphore to limit the number of parallel validations:
 | 
				
			||||||
 | 
					        using var semaphore = new SemaphoreSlim(numParallelValidations);
 | 
				
			||||||
 | 
					        foreach (var retrievalContext in retrievalContexts)
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            // Wait for an available slot in the semaphore:
 | 
				
			||||||
 | 
					            await semaphore.WaitAsync(token);
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					            // Start the next validation task:
 | 
				
			||||||
 | 
					            validationTasks.Add(this.ValidateRetrievalContextAsync(lastPrompt, chatThread, retrievalContext, token, semaphore));
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // Wait for all validation tasks to complete:
 | 
				
			||||||
 | 
					        return await Task.WhenAll(validationTasks);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /// <summary>
 | 
					    /// <summary>
 | 
				
			||||||
    /// Validates the retrieval context against the last user and the system prompt.
 | 
					    /// Validates the retrieval context against the last user and the system prompt.
 | 
				
			||||||
    /// </summary>
 | 
					    /// </summary>
 | 
				
			||||||
    /// <remarks>
 | 
					    /// <remarks>
 | 
				
			||||||
    /// Probably, you have a lot of retrieval contexts to validate. In this case, you
 | 
					    /// Probably, you have a lot of retrieval contexts to validate. In this case, you
 | 
				
			||||||
    /// can call this method in parallel for each retrieval context.
 | 
					    /// can call this method in parallel for each retrieval context. You might use
 | 
				
			||||||
 | 
					    /// the ValidateRetrievalContextsAsync method to validate all retrieval contexts.
 | 
				
			||||||
    /// </remarks>
 | 
					    /// </remarks>
 | 
				
			||||||
    /// <param name="lastPrompt">The last user prompt.</param>
 | 
					    /// <param name="lastPrompt">The last user prompt.</param>
 | 
				
			||||||
    /// <param name="chatThread">The chat thread.</param>
 | 
					    /// <param name="chatThread">The chat thread.</param>
 | 
				
			||||||
    /// <param name="dataContext">The retrieval context to validate.</param>
 | 
					    /// <param name="retrievalContext">The retrieval context to validate.</param>
 | 
				
			||||||
    /// <param name="token">The cancellation token.</param>
 | 
					    /// <param name="token">The cancellation token.</param>
 | 
				
			||||||
 | 
					    /// <param name="semaphore">The optional semaphore to limit the number of parallel validations.</param>
 | 
				
			||||||
    /// <returns>The validation result.</returns>
 | 
					    /// <returns>The validation result.</returns>
 | 
				
			||||||
    public async Task<RetrievalContextValidationResult> ValidateRetrievalContextAsync(IContent lastPrompt, ChatThread chatThread, IRetrievalContext dataContext, CancellationToken token = default)
 | 
					    public async Task<RetrievalContextValidationResult> ValidateRetrievalContextAsync(IContent lastPrompt, ChatThread chatThread, IRetrievalContext retrievalContext, CancellationToken token = default, SemaphoreSlim? semaphore = null)
 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
 | 
					        try
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            //
 | 
				
			||||||
 | 
					            // Check if the validation was canceled. This could happen when the user
 | 
				
			||||||
 | 
					            // canceled the validation process or when the validation process took
 | 
				
			||||||
 | 
					            // too long:
 | 
				
			||||||
 | 
					            //
 | 
				
			||||||
 | 
					            if(token.IsCancellationRequested)
 | 
				
			||||||
 | 
					                return new(false, "The validation was canceled.", 1.0f, retrievalContext);
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
            //
 | 
					            //
 | 
				
			||||||
            // 1. Prepare the current system and user prompts as input for the agent: 
 | 
					            // 1. Prepare the current system and user prompts as input for the agent: 
 | 
				
			||||||
            //
 | 
					            //
 | 
				
			||||||
@ -180,14 +228,14 @@ public sealed class AgentRetrievalContextValidation (ILogger<AgentRetrievalConte
 | 
				
			|||||||
            if (string.IsNullOrWhiteSpace(lastPromptContent))
 | 
					            if (string.IsNullOrWhiteSpace(lastPromptContent))
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
                logger.LogWarning("The last prompt is empty. The AI cannot validate the retrieval context.");
 | 
					                logger.LogWarning("The last prompt is empty. The AI cannot validate the retrieval context.");
 | 
				
			||||||
            return new(false, "The last prompt was empty.", 1.0f, dataContext);
 | 
					                return new(false, "The last prompt was empty.", 1.0f, retrievalContext);
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            //
 | 
					            //
 | 
				
			||||||
            // 2. Prepare the retrieval context for the agent:
 | 
					            // 2. Prepare the retrieval context for the agent:
 | 
				
			||||||
            //
 | 
					            //
 | 
				
			||||||
            var additionalData = new Dictionary<string, string>();
 | 
					            var additionalData = new Dictionary<string, string>();
 | 
				
			||||||
        var markdownRetrievalContext = await dataContext.AsMarkdown(token: token);
 | 
					            var markdownRetrievalContext = await retrievalContext.AsMarkdown(token: token);
 | 
				
			||||||
            additionalData.Add("retrievalContext", markdownRetrievalContext);
 | 
					            additionalData.Add("retrievalContext", markdownRetrievalContext);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            //
 | 
					            //
 | 
				
			||||||
@ -219,10 +267,10 @@ public sealed class AgentRetrievalContextValidation (ILogger<AgentRetrievalConte
 | 
				
			|||||||
                },
 | 
					                },
 | 
				
			||||||
            }, additionalData);
 | 
					            }, additionalData);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if(aiResponse.Content is null)
 | 
					            if (aiResponse.Content is null)
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
                logger.LogWarning("The agent did not return a response.");
 | 
					                logger.LogWarning("The agent did not return a response.");
 | 
				
			||||||
            return new(false, "The agent did not return a response.", 1.0f, dataContext);
 | 
					                return new(false, "The agent did not return a response.", 1.0f, retrievalContext);
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            switch (aiResponse)
 | 
					            switch (aiResponse)
 | 
				
			||||||
@ -247,26 +295,32 @@ public sealed class AgentRetrievalContextValidation (ILogger<AgentRetrievalConte
 | 
				
			|||||||
                    try
 | 
					                    try
 | 
				
			||||||
                    {
 | 
					                    {
 | 
				
			||||||
                        var result = JsonSerializer.Deserialize<RetrievalContextValidationResult>(json, JSON_SERIALIZER_OPTIONS);
 | 
					                        var result = JsonSerializer.Deserialize<RetrievalContextValidationResult>(json, JSON_SERIALIZER_OPTIONS);
 | 
				
			||||||
                    return result with { RetrievalContext = dataContext };
 | 
					                        return result with { RetrievalContext = retrievalContext };
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                    catch
 | 
					                    catch
 | 
				
			||||||
                    {
 | 
					                    {
 | 
				
			||||||
                        logger.LogWarning("The agent answered with an invalid or unexpected JSON format.");
 | 
					                        logger.LogWarning("The agent answered with an invalid or unexpected JSON format.");
 | 
				
			||||||
                    return new(false, "The agent answered with an invalid or unexpected JSON format.", 1.0f, dataContext);
 | 
					                        return new(false, "The agent answered with an invalid or unexpected JSON format.", 1.0f, retrievalContext);
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                case { ContentType: ContentType.TEXT }:
 | 
					                case { ContentType: ContentType.TEXT }:
 | 
				
			||||||
                    logger.LogWarning("The agent answered with an unexpected inner content type.");
 | 
					                    logger.LogWarning("The agent answered with an unexpected inner content type.");
 | 
				
			||||||
                return new(false, "The agent answered with an unexpected inner content type.", 1.0f, dataContext);
 | 
					                    return new(false, "The agent answered with an unexpected inner content type.", 1.0f, retrievalContext);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                case { ContentType: ContentType.NONE }:
 | 
					                case { ContentType: ContentType.NONE }:
 | 
				
			||||||
                    logger.LogWarning("The agent did not return a response.");
 | 
					                    logger.LogWarning("The agent did not return a response.");
 | 
				
			||||||
                return new(false, "The agent did not return a response.", 1.0f, dataContext);
 | 
					                    return new(false, "The agent did not return a response.", 1.0f, retrievalContext);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                default:
 | 
					                default:
 | 
				
			||||||
                    logger.LogWarning($"The agent answered with an unexpected content type '{aiResponse.ContentType}'.");
 | 
					                    logger.LogWarning($"The agent answered with an unexpected content type '{aiResponse.ContentType}'.");
 | 
				
			||||||
                return new(false, $"The agent answered with an unexpected content type '{aiResponse.ContentType}'.", 1.0f, dataContext);
 | 
					                    return new(false, $"The agent answered with an unexpected content type '{aiResponse.ContentType}'.", 1.0f, retrievalContext);
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        finally
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            // Release the semaphore slot:
 | 
				
			||||||
 | 
					            semaphore?.Release();
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user