From b05d44c2f42b7d7e11daacba9787b2b024ff1759 Mon Sep 17 00:00:00 2001 From: Thorsten Sommer Date: Mon, 17 Feb 2025 14:09:54 +0100 Subject: [PATCH] Implemented the generic retrieval process --- app/MindWork AI Studio/Chat/ContentText.cs | 24 ++++- app/MindWork AI Studio/Chat/IContent.cs | 11 +++ .../Settings/DataModel/DataSourceERI_V1.cs | 88 +++++++++++++++++++ .../DataModel/DataSourceLocalDirectory.cs | 10 +++ .../Settings/DataModel/DataSourceLocalFile.cs | 10 +++ .../Settings/IDataSource.cs | 11 +++ 6 files changed, 153 insertions(+), 1 deletion(-) diff --git a/app/MindWork AI Studio/Chat/ContentText.cs b/app/MindWork AI Studio/Chat/ContentText.cs index e9f51e54..75912e3e 100644 --- a/app/MindWork AI Studio/Chat/ContentText.cs +++ b/app/MindWork AI Studio/Chat/ContentText.cs @@ -4,6 +4,7 @@ using AIStudio.Agents; using AIStudio.Components; using AIStudio.Provider; using AIStudio.Settings; +using AIStudio.Tools.RAG; using AIStudio.Tools.Services; namespace AIStudio.Chat; @@ -199,9 +200,30 @@ public sealed class ContentText : IContent // // Trigger the retrieval part of the (R)AG process: // + var dataContexts = new List(); if (proceedWithRAG) { - + // + // We kick off the retrieval process for each data source in parallel: + // + var retrievalTasks = new List>>(selectedDataSources.Count); + foreach (var dataSource in selectedDataSources) + retrievalTasks.Add(dataSource.RetrieveDataAsync(lastPrompt, chatThread, token)); + + // + // Wait for all retrieval tasks to finish: + // + foreach (var retrievalTask in retrievalTasks) + { + try + { + dataContexts.AddRange(await retrievalTask); + } + catch (Exception e) + { + logger.LogError(e, "An error occurred during the retrieval process."); + } + } } // diff --git a/app/MindWork AI Studio/Chat/IContent.cs b/app/MindWork AI Studio/Chat/IContent.cs index 08e8817b..8ca94025 100644 --- a/app/MindWork AI Studio/Chat/IContent.cs +++ b/app/MindWork AI Studio/Chat/IContent.cs @@ -42,4 +42,15 @@ public interface IContent /// Uses the provider to create the content. /// public Task CreateFromProviderAsync(IProvider provider, Model chatModel, IContent? lastPrompt, ChatThread? chatChatThread, CancellationToken token = default); + + /// + /// Returns the corresponding ERI content type. + /// + public Tools.ERIClient.DataModel.ContentType ToERIContentType => this switch + { + ContentText => Tools.ERIClient.DataModel.ContentType.TEXT, + ContentImage => Tools.ERIClient.DataModel.ContentType.IMAGE, + + _ => Tools.ERIClient.DataModel.ContentType.UNKNOWN, + }; } \ No newline at end of file diff --git a/app/MindWork AI Studio/Settings/DataModel/DataSourceERI_V1.cs b/app/MindWork AI Studio/Settings/DataModel/DataSourceERI_V1.cs index 328161e1..b591de00 100644 --- a/app/MindWork AI Studio/Settings/DataModel/DataSourceERI_V1.cs +++ b/app/MindWork AI Studio/Settings/DataModel/DataSourceERI_V1.cs @@ -1,7 +1,14 @@ // ReSharper disable InconsistentNaming using AIStudio.Assistants.ERI; +using AIStudio.Chat; +using AIStudio.Tools.ERIClient; using AIStudio.Tools.ERIClient.DataModel; +using AIStudio.Tools.RAG; +using AIStudio.Tools.Services; + +using ChatThread = AIStudio.Chat.ChatThread; +using ContentType = AIStudio.Tools.ERIClient.DataModel.ContentType; namespace AIStudio.Settings.DataModel; @@ -43,4 +50,85 @@ public readonly record struct DataSourceERI_V1 : IERIDataSource /// public ERIVersion Version { get; init; } = ERIVersion.V1; + + /// + public async Task> RetrieveDataAsync(IContent lastPrompt, ChatThread thread, CancellationToken token = default) + { + // Important: Do not dispose the RustService here, as it is a singleton. + var rustService = Program.SERVICE_PROVIDER.GetRequiredService(); + var logger = Program.SERVICE_PROVIDER.GetRequiredService>(); + + using var eriClient = ERIClientFactory.Get(this.Version, this)!; + var authResponse = await eriClient.AuthenticateAsync(this, rustService, token); + if (authResponse.Successful) + { + var retrievalRequest = new RetrievalRequest + { + LatestUserPromptType = lastPrompt.ToERIContentType, + LatestUserPrompt = lastPrompt switch + { + ContentText text => text.Text, + ContentImage image => await image.AsBase64(token), + _ => string.Empty + }, + + Thread = await thread.ToERIChatThread(token), + MaxMatches = 10, + RetrievalProcessId = null, // The ERI server selects the retrieval process when multiple processes are available + Parameters = null, // The ERI server selects useful default parameters + }; + + var retrievalResponse = await eriClient.ExecuteRetrievalAsync(retrievalRequest, token); + if(retrievalResponse is { Successful: true, Data: not null }) + { + // + // Next, we have to transform the ERI context back to our generic retrieval context: + // + var genericRetrievalContexts = new List(retrievalResponse.Data.Count); + foreach (var eriContext in retrievalResponse.Data) + { + switch (eriContext.Type) + { + case ContentType.TEXT: + genericRetrievalContexts.Add(new RetrievalTextContext + { + Path = eriContext.Path ?? string.Empty, + Type = eriContext.ToRetrievalContentType(), + Links = eriContext.Links, + Category = RetrievalContentCategory.TEXT, + MatchedText = eriContext.MatchedContent, + DataSourceName = eriContext.Name, + SurroundingContent = eriContext.SurroundingContent, + }); + break; + + case ContentType.IMAGE: + genericRetrievalContexts.Add(new RetrievalImageContext + { + Path = eriContext.Path ?? string.Empty, + Type = eriContext.ToRetrievalContentType(), + Links = eriContext.Links, + Source = eriContext.MatchedContent, + Category = RetrievalContentCategory.IMAGE, + SourceType = ContentImageSource.BASE64, + DataSourceName = eriContext.Name, + }); + break; + + default: + logger.LogWarning($"The ERI context type '{eriContext.Type}' is not supported yet."); + break; + } + } + + return genericRetrievalContexts; + } + + logger.LogWarning($"Was not able to retrieve data from the ERI data source '{this.Name}'. Message: {retrievalResponse.Message}"); + return []; + } + + logger.LogWarning($"Was not able to authenticate with the ERI data source '{this.Name}'. Message: {authResponse.Message}"); + return []; + } } \ No newline at end of file diff --git a/app/MindWork AI Studio/Settings/DataModel/DataSourceLocalDirectory.cs b/app/MindWork AI Studio/Settings/DataModel/DataSourceLocalDirectory.cs index 61c30d93..d81e30db 100644 --- a/app/MindWork AI Studio/Settings/DataModel/DataSourceLocalDirectory.cs +++ b/app/MindWork AI Studio/Settings/DataModel/DataSourceLocalDirectory.cs @@ -1,3 +1,6 @@ +using AIStudio.Chat; +using AIStudio.Tools.RAG; + namespace AIStudio.Settings.DataModel; /// @@ -27,6 +30,13 @@ public readonly record struct DataSourceLocalDirectory : IInternalDataSource /// public DataSourceSecurity SecurityPolicy { get; init; } = DataSourceSecurity.NOT_SPECIFIED; + /// + public Task> RetrieveDataAsync(IContent lastPrompt, ChatThread thread, CancellationToken token = default) + { + IReadOnlyList retrievalContext = new List(); + return Task.FromResult(retrievalContext); + } + /// /// The path to the directory. /// diff --git a/app/MindWork AI Studio/Settings/DataModel/DataSourceLocalFile.cs b/app/MindWork AI Studio/Settings/DataModel/DataSourceLocalFile.cs index 571fb0a8..5788a2a6 100644 --- a/app/MindWork AI Studio/Settings/DataModel/DataSourceLocalFile.cs +++ b/app/MindWork AI Studio/Settings/DataModel/DataSourceLocalFile.cs @@ -1,3 +1,6 @@ +using AIStudio.Chat; +using AIStudio.Tools.RAG; + namespace AIStudio.Settings.DataModel; /// @@ -27,6 +30,13 @@ public readonly record struct DataSourceLocalFile : IInternalDataSource /// public DataSourceSecurity SecurityPolicy { get; init; } = DataSourceSecurity.NOT_SPECIFIED; + /// + public Task> RetrieveDataAsync(IContent lastPrompt, ChatThread thread, CancellationToken token = default) + { + IReadOnlyList retrievalContext = new List(); + return Task.FromResult(retrievalContext); + } + /// /// The path to the file. /// diff --git a/app/MindWork AI Studio/Settings/IDataSource.cs b/app/MindWork AI Studio/Settings/IDataSource.cs index 72f4ad3c..7ee47e1c 100644 --- a/app/MindWork AI Studio/Settings/IDataSource.cs +++ b/app/MindWork AI Studio/Settings/IDataSource.cs @@ -1,6 +1,8 @@ using System.Text.Json.Serialization; +using AIStudio.Chat; using AIStudio.Settings.DataModel; +using AIStudio.Tools.RAG; namespace AIStudio.Settings; @@ -37,4 +39,13 @@ public interface IDataSource /// Which data security policy is applied to this data source? /// public DataSourceSecurity SecurityPolicy { get; init; } + + /// + /// Perform the data retrieval process. + /// + /// The last prompt from the chat. + /// The chat thread. + /// The cancellation token. + /// The retrieved data context. + public Task> RetrieveDataAsync(IContent lastPrompt, ChatThread thread, CancellationToken token = default); } \ No newline at end of file