using System.Text.Json;

using AIStudio.Chat;
using AIStudio.Provider;
using AIStudio.Settings;
using AIStudio.Tools.RAG;
using AIStudio.Tools.Services;

namespace AIStudio.Agents;

/// <summary>
/// A worker agent that asks an LLM whether a retrieval context is helpful for
/// processing the current system and user prompt. The LLM answers with one JSON
/// object containing a decision, a reason, and a confidence value.
/// </summary>
public sealed class AgentRetrievalContextValidation (ILogger<AgentRetrievalContextValidation> logger, ILogger<AgentBase> baseLogger, SettingsManager settingsManager, DataSourceService dataSourceService, ThreadSafeRandom rng) : AgentBase(baseLogger, settingsManager, dataSourceService, rng)
{
    #region Overrides of AgentBase

    /// <inheritdoc />
    protected override Type Type => Type.WORKER;

    /// <inheritdoc />
    public override string Id => "Retrieval Context Validation";

    /// <inheritdoc />
    protected override string JobDescription =>
        """
        You receive a system and user prompt as well as a retrieval context as input. Your task is to decide whether this
        retrieval context is helpful in processing the prompts or not. You respond with the decision (true or false),
        your reasoning, and your confidence in this decision.

        Your response is only one JSON object in the following format:

        ```
        {"decision": true, "reason": "Why did you choose this source?", "confidence": 0.87}
        ```

        You express your confidence as a floating-point number between 0.0 (maximum uncertainty) and
        1.0 (you are absolutely certain that this retrieval context is needed).

        The JSON schema is:

        ```
        {
          "$schema": "http://json-schema.org/draft-04/schema#",
          "type": "object",
          "properties": {
            "decision": {
              "type": "boolean"
            },
            "reason": {
              "type": "string"
            },
            "confidence": {
              "type": "number"
            }
          },
          "required": [
            "decision",
            "reason",
            "confidence"
          ]
        }
        ```

        You do not ask any follow-up questions. You do not address the user. Your response consists solely of
        that one JSON object.
        """;

    /// <inheritdoc />
    protected override string SystemPrompt(string retrievalContext) =>
        $"""
        {this.JobDescription}

        {retrievalContext}
        """;

    /// <inheritdoc />
    public override Settings.Provider? ProviderSettings { get; set; }

    /// <summary>
    /// The retrieval context validation agent does not work with context. Use
    /// the process input method instead.
    /// </summary>
    /// <param name="chatThread">The chat thread; it is returned unchanged.</param>
    /// <param name="additionalData">Ignored.</param>
    /// <returns>The chat thread without any changes.</returns>
    public override Task<ChatThread> ProcessContext(ChatThread chatThread, IDictionary<string, string> additionalData) => Task.FromResult(chatThread);

    /// <inheritdoc />
    /// <remarks>
    /// Expects the retrieval context (as text) under the key <c>"retrievalContext"</c> in
    /// <paramref name="additionalData"/>. Returns <c>EMPTY_BLOCK</c> when the input is not
    /// finished text, is blank, or when no retrieval context was provided.
    /// </remarks>
    public override async Task<ContentBlock> ProcessInput(ContentBlock input, IDictionary<string, string> additionalData)
    {
        if (input.Content is not ContentText text)
            return EMPTY_BLOCK;

        // Do not work on content that is still being produced:
        if (text.InitialRemoteWait || text.IsStreaming)
            return EMPTY_BLOCK;

        if (string.IsNullOrWhiteSpace(text.Text))
            return EMPTY_BLOCK;

        if (!additionalData.TryGetValue("retrievalContext", out var retrievalContext) || string.IsNullOrWhiteSpace(retrievalContext))
            return EMPTY_BLOCK;

        // Each call uses its own, fresh thread; the retrieval context becomes part of the system prompt:
        var thread = this.CreateChatThread(this.SystemPrompt(retrievalContext));
        var userRequest = this.AddUserRequest(thread, text.Text);
        await this.AddAIResponseAsync(thread, userRequest.UserPrompt, userRequest.Time);

        // The last block of the thread holds the AI response:
        return thread.Blocks[^1];
    }

    /// <inheritdoc />
    public override Task<bool> MadeDecision(ContentBlock input) => Task.FromResult(true);

    /// <summary>
    /// We do not provide any context. This agent will process many retrieval contexts.
    /// This would block a huge amount of memory.
    /// </summary>
    /// <returns>An empty list.</returns>
    public override IReadOnlyCollection<ContentBlock> GetContext() => [];

    /// <summary>
    /// We do not provide any answers. This agent will process many retrieval contexts.
    /// This would block a huge amount of memory.
    /// </summary>
    /// <returns>An empty list.</returns>
    public override IReadOnlyCollection<ContentBlock> GetAnswers() => [];

    #endregion

    /// <summary>
    /// Sets the LLM provider for the agent.
    /// </summary>
    /// <remarks>
    /// When you have to call the validation in parallel for many retrieval contexts,
    /// you can set the provider once and then call the validation method in parallel.
    /// </remarks>
    /// <param name="provider">The current LLM provider. When the user doesn't preselect an agent provider, the agent uses this provider.</param>
    public void SetLLMProvider(IProvider provider)
    {
        // We start with the provider currently selected by the user:
        var agentProvider = this.SettingsManager.GetPreselectedProvider(Tools.Components.AGENT_RETRIEVAL_CONTEXT_VALIDATION, provider.Id, true);

        // Assign the provider settings to the agent:
        logger.LogInformation($"The agent for the retrieval context validation uses the provider '{agentProvider.InstanceName}' ({agentProvider.UsedLLMProvider.ToName()}, confidence={agentProvider.UsedLLMProvider.GetConfidence(this.SettingsManager).Level.GetName()}).");
        this.ProviderSettings = agentProvider;
    }

    /// <summary>
    /// Validate all retrieval contexts against the last user and the system prompt.
    /// </summary>
    /// <param name="lastPrompt">The last user prompt.</param>
    /// <param name="chatThread">The chat thread.</param>
    /// <param name="retrievalContexts">All retrieval contexts to validate.</param>
    /// <param name="token">The cancellation token.</param>
    /// <returns>
    /// The validation results, in the same order as <paramref name="retrievalContexts"/>.
    /// An empty list when the validation is disabled in the settings.
    /// </returns>
    /// <exception cref="OperationCanceledException">
    /// The token was canceled while waiting for a free validation slot. Validations that
    /// were already started are awaited before the exception is rethrown.
    /// </exception>
    public async Task<IReadOnlyList<RetrievalContextValidationResult>> ValidateRetrievalContextsAsync(IContent lastPrompt, ChatThread chatThread, IReadOnlyList<IRetrievalContext> retrievalContexts, CancellationToken token = default)
    {
        // Check if the retrieval context validation is enabled:
        if (!this.SettingsManager.ConfigurationData.AgentRetrievalContextValidation.EnableRetrievalContextValidation)
            return [];

        // N0 uses the same culture-aware digit grouping as before, but prints "0" instead of an empty string:
        logger.LogInformation($"Validating {retrievalContexts.Count:N0} retrieval contexts.");

        // Prepare the list of validation tasks:
        var validationTasks = new List<Task<RetrievalContextValidationResult>>(retrievalContexts.Count);

        // Read the number of parallel validations:
        var numParallelValidations = 3;
        if (this.SettingsManager.ConfigurationData.AgentRetrievalContextValidation.PreselectAgentOptions)
            numParallelValidations = this.SettingsManager.ConfigurationData.AgentRetrievalContextValidation.NumParallelValidations;

        numParallelValidations = Math.Max(1, numParallelValidations);

        // Use a semaphore to limit the number of parallel validations:
        using var semaphore = new SemaphoreSlim(numParallelValidations);
        try
        {
            foreach (var retrievalContext in retrievalContexts)
            {
                // Wait for an available slot in the semaphore:
                await semaphore.WaitAsync(token);

                // Start the next validation task; it releases its slot when it finishes:
                validationTasks.Add(this.ValidateRetrievalContextAsync(lastPrompt, chatThread, retrievalContext, token, semaphore));
            }
        }
        catch (OperationCanceledException)
        {
            // The already started tasks still release the semaphore in their finally block.
            // They must finish before the semaphore gets disposed by the using declaration;
            // otherwise, their Release() call would throw an unobserved ObjectDisposedException:
            await WaitForAllIgnoringErrorsAsync(validationTasks);
            throw;
        }

        // Wait for all validation tasks to complete:
        return await Task.WhenAll(validationTasks);
    }

    /// <summary>
    /// Waits until all given tasks have completed, regardless of whether they succeeded,
    /// failed, or were canceled.
    /// </summary>
    /// <param name="tasks">The tasks to wait for.</param>
    private static async Task WaitForAllIgnoringErrorsAsync(IEnumerable<Task> tasks)
    {
        try
        {
            await Task.WhenAll(tasks);
        }
        catch
        {
            // Deliberately ignored: the only caller rethrows the original cancellation right
            // afterward. Awaiting Task.WhenAll still observes the individual task exceptions.
        }
    }

    /// <summary>
    /// Validates the retrieval context against the last user and the system prompt.
    /// </summary>
    /// <remarks>
    /// Probably, you have a lot of retrieval contexts to validate. In this case, you
    /// can call this method in parallel for each retrieval context. You might use
    /// the ValidateRetrievalContextsAsync method to validate all retrieval contexts.
    /// </remarks>
    /// <param name="lastPrompt">The last user prompt.</param>
    /// <param name="chatThread">The chat thread.</param>
    /// <param name="retrievalContext">The retrieval context to validate.</param>
    /// <param name="token">The cancellation token.</param>
    /// <param name="semaphore">
    /// The optional semaphore to limit the number of parallel validations. When given, this
    /// method releases it exactly once when it finishes; the caller must have acquired it before.
    /// </param>
    /// <returns>The validation result.</returns>
    public async Task<RetrievalContextValidationResult> ValidateRetrievalContextAsync(IContent lastPrompt, ChatThread chatThread, IRetrievalContext retrievalContext, CancellationToken token = default, SemaphoreSlim? semaphore = null)
    {
        try
        {
            //
            // Check if the validation was canceled. This could happen when the user
            // canceled the validation process or when the validation process took
            // too long:
            //
            if (token.IsCancellationRequested)
                return new(false, "The validation was canceled.", 1.0f, retrievalContext);

            //
            // 1. Prepare the current system and user prompts as input for the agent:
            //
            var lastPromptContent = lastPrompt switch
            {
                ContentText text => text.Text,

                // Image prompts may be empty, e.g., when the image is too large:
                ContentImage image => await image.AsBase64(token),

                // Other content types are not supported yet:
                _ => string.Empty,
            };

            if (string.IsNullOrWhiteSpace(lastPromptContent))
            {
                logger.LogWarning("The last prompt is empty. The AI cannot validate the retrieval context.");
                return new(false, "The last prompt was empty.", 1.0f, retrievalContext);
            }

            //
            // 2. Prepare the retrieval context for the agent:
            //
            var additionalData = new Dictionary<string, string>();
            var markdownRetrievalContext = await retrievalContext.AsMarkdown(token: token);
            additionalData.Add("retrievalContext", markdownRetrievalContext);

            //
            // 3. Let the agent validate the retrieval context:
            //
            var prompt = $"""
                          The system prompt is:

                          ```
                          {chatThread.SystemPrompt}
                          ```

                          The user prompt is:

                          ```
                          {lastPromptContent}
                          ```
                          """;

            // Call the agent:
            var aiResponse = await this.ProcessInput(new ContentBlock
            {
                Time = DateTimeOffset.UtcNow,
                ContentType = ContentType.TEXT,
                Role = ChatRole.USER,
                Content = new ContentText
                {
                    Text = prompt,
                },
            }, additionalData);

            if (aiResponse.Content is null)
            {
                logger.LogWarning("The agent did not return a response.");
                return new(false, "The agent did not return a response.", 1.0f, retrievalContext);
            }

            switch (aiResponse)
            {
                //
                // 4. Parse the agent response:
                //
                case { ContentType: ContentType.TEXT, Content: ContentText textContent }:
                {
                    //
                    // What we expect is one JSON object:
                    //
                    var validationJson = textContent.Text;

                    //
                    // We know how bad LLM may be in generating JSON without surrounding text.
                    // Thus, we expect the worst and try to extract the JSON object from the text:
                    //
                    var json = ExtractJson(validationJson);

                    try
                    {
                        var result = JsonSerializer.Deserialize<RetrievalContextValidationResult>(json, JSON_SERIALIZER_OPTIONS);
                        return result with { RetrievalContext = retrievalContext };
                    }
                    catch (Exception e)
                    {
                        // Keep the exception in the log; it tells us what exactly was wrong with the answer:
                        logger.LogWarning(e, "The agent answered with an invalid or unexpected JSON format.");
                        return new(false, "The agent answered with an invalid or unexpected JSON format.", 1.0f, retrievalContext);
                    }
                }

                case { ContentType: ContentType.TEXT }:
                    logger.LogWarning("The agent answered with an unexpected inner content type.");
                    return new(false, "The agent answered with an unexpected inner content type.", 1.0f, retrievalContext);

                case { ContentType: ContentType.NONE }:
                    logger.LogWarning("The agent did not return a response.");
                    return new(false, "The agent did not return a response.", 1.0f, retrievalContext);

                default:
                    logger.LogWarning($"The agent answered with an unexpected content type '{aiResponse.ContentType}'.");
                    return new(false, $"The agent answered with an unexpected content type '{aiResponse.ContentType}'.", 1.0f, retrievalContext);
            }
        }
        finally
        {
            // Release the semaphore slot:
            semaphore?.Release();
        }
    }

    // A wrapper around the span version, because spans cannot be used inside async methods.
    private static string ExtractJson(string text) => ExtractJson(text.AsSpan()).ToString();

    /// <summary>
    /// Finds the first JSON object that starts like <c>{"d</c> (ignoring whitespace) and returns it.
    /// </summary>
    /// <param name="input">The raw LLM answer.</param>
    /// <returns>The JSON object, or an empty span when none was found.</returns>
    private static ReadOnlySpan<char> ExtractJson(ReadOnlySpan<char> input)
    {
        //
        // 1. Expect the best case ;-)
        //
        if (CheckJsonObjectStart(input))
            return ExtractJsonPart(input);

        //
        // 2. Okay, we have some garbage before the
        //    JSON object. We expected that...
        //
        for (var index = 0; index < input.Length; index++)
        {
            if (input[index] is '{' && CheckJsonObjectStart(input[index..]))
                return ExtractJsonPart(input[index..]);
        }

        return [];
    }

    /// <summary>
    /// Checks whether the area starts with <c>{"d</c>, ignoring whitespace between the symbols.
    /// Returns <c>true</c> as well when the area ends before all symbols were seen.
    /// </summary>
    private static bool CheckJsonObjectStart(ReadOnlySpan<char> area)
    {
        char[] expectedSymbols = ['{', '"', 'd'];
        var symbolIndex = 0;

        foreach (var c in area)
        {
            if (symbolIndex >= expectedSymbols.Length)
                return true;

            if (char.IsWhiteSpace(c))
                continue;

            if (c == expectedSymbols[symbolIndex++])
                continue;

            return false;
        }

        return true;
    }

    /// <summary>
    /// Cuts the input after the closing brace of the JSON object it starts with.
    /// </summary>
    /// <remarks>
    /// Braces inside JSON strings are ignored. Escape sequences such as <c>\"</c> and <c>\\</c>
    /// are honored, so an escaped quote does not end a string. Nested objects are
    /// tracked via the brace depth.
    /// </remarks>
    /// <param name="input">Text starting (after optional whitespace) with a JSON object.</param>
    /// <returns>The JSON object, or an empty span when the object is not terminated.</returns>
    private static ReadOnlySpan<char> ExtractJsonPart(ReadOnlySpan<char> input)
    {
        var insideString = false;
        var isEscaped = false;
        var depth = 0;

        for (var index = 0; index < input.Length; index++)
        {
            var c = input[index];
            if (insideString)
            {
                if (isEscaped)
                    isEscaped = false;      // The escaped character has no special meaning.
                else if (c is '\\')
                    isEscaped = true;       // The next character is part of an escape sequence.
                else if (c is '"')
                    insideString = false;   // An unescaped quote ends the string.

                continue;
            }

            switch (c)
            {
                case '"':
                    insideString = true;
                    break;

                case '{':
                    depth++;
                    break;

                case '}':
                    depth--;
                    if (depth <= 0)
                        return input[..(index + 1)];

                    break;
            }
        }

        return [];
    }
}