Execute retrieval processes (#286)

This commit is contained in:
Thorsten Sommer 2025-02-17 14:12:46 +01:00 committed by GitHub
parent 7c59aa11fe
commit 77d427610b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 198 additions and 3 deletions

View File

@ -19,7 +19,7 @@ Things we are currently working on:
- [ ] App: Implement the process to vectorize one local file using embeddings
- [ ] Runtime: Integration of the vector database [LanceDB](https://github.com/lancedb/lancedb)
- [ ] App: Implement the continuous process of vectorizing data
- [x] ~~App: Define a common retrieval context interface for the integration of RAG processes in chats (PR [#281](https://github.com/MindWorkAI/AI-Studio/pull/281), [#284](https://github.com/MindWorkAI/AI-Studio/pull/284))~~
- [x] ~~App: Define a common retrieval context interface for the integration of RAG processes in chats (PR [#281](https://github.com/MindWorkAI/AI-Studio/pull/281), [#284](https://github.com/MindWorkAI/AI-Studio/pull/284), [#286](https://github.com/MindWorkAI/AI-Studio/pull/286))~~
- [ ] App: Define a common augmentation interface for the integration of RAG processes in chats
- [x] ~~App: Integrate data sources in chats (PR [#282](https://github.com/MindWorkAI/AI-Studio/pull/282))~~

View File

@ -1,6 +1,7 @@
using AIStudio.Components;
using AIStudio.Settings;
using AIStudio.Settings.DataModel;
using AIStudio.Tools.ERIClient.DataModel;
namespace AIStudio.Chat;
@ -150,4 +151,46 @@ public sealed record ChatThread
// Remove the block from the chat thread:
this.Blocks.Remove(block);
}
/// <summary>
/// Converts this chat thread into the chat thread representation of the ERI data model.
/// </summary>
/// <remarks>
/// Text blocks are passed on as-is and image blocks are sent as base64 data. Any other
/// content is transmitted as an empty string with the content type UNKNOWN.
/// </remarks>
/// <param name="token">The cancellation token.</param>
/// <returns>The ERI chat thread containing one ERI content block per block of this thread.</returns>
public async Task<Tools.ERIClient.DataModel.ChatThread> ToERIChatThread(CancellationToken token = default)
{
    var eriBlocks = new List<Tools.ERIClient.DataModel.ContentBlock>(this.Blocks.Count);
    foreach (var chatBlock in this.Blocks)
    {
        //
        // Determine the payload and its ERI content type:
        //
        string payload;
        Tools.ERIClient.DataModel.ContentType payloadType;
        switch (chatBlock.Content)
        {
            case ContentImage imageContent:
                payload = await imageContent.AsBase64(token);
                payloadType = Tools.ERIClient.DataModel.ContentType.IMAGE;
                break;

            case ContentText textContent:
                payload = textContent.Text;
                payloadType = Tools.ERIClient.DataModel.ContentType.TEXT;
                break;

            default:
                payload = string.Empty;
                payloadType = Tools.ERIClient.DataModel.ContentType.UNKNOWN;
                break;
        }

        //
        // Map the chat role to the ERI role:
        //
        var eriRole = chatBlock.Role switch
        {
            ChatRole.AI => Role.AI,
            ChatRole.USER => Role.USER,
            ChatRole.AGENT => Role.AGENT,
            ChatRole.SYSTEM => Role.SYSTEM,
            ChatRole.NONE => Role.NONE,

            _ => Role.UNKNOW,
        };

        eriBlocks.Add(new Tools.ERIClient.DataModel.ContentBlock
        {
            Role = eriRole,
            Content = payload,
            Type = payloadType,
        });
    }

    return new Tools.ERIClient.DataModel.ChatThread { ContentBlocks = eriBlocks };
}
}

View File

@ -4,6 +4,7 @@ using AIStudio.Agents;
using AIStudio.Components;
using AIStudio.Provider;
using AIStudio.Settings;
using AIStudio.Tools.RAG;
using AIStudio.Tools.Services;
namespace AIStudio.Chat;
@ -199,9 +200,30 @@ public sealed class ContentText : IContent
//
// Trigger the retrieval part of the (R)AG process:
//
var dataContexts = new List<IRetrievalContext>();
if (proceedWithRAG)
{
//
// We kick off the retrieval process for each data source in parallel:
//
var retrievalTasks = new List<Task<IReadOnlyList<IRetrievalContext>>>(selectedDataSources.Count);
foreach (var dataSource in selectedDataSources)
retrievalTasks.Add(dataSource.RetrieveDataAsync(lastPrompt, chatThread, token));
//
// Wait for all retrieval tasks to finish:
//
foreach (var retrievalTask in retrievalTasks)
{
try
{
dataContexts.AddRange(await retrievalTask);
}
catch (Exception e)
{
logger.LogError(e, "An error occurred during the retrieval process.");
}
}
}
//

View File

@ -42,4 +42,15 @@ public interface IContent
/// Uses the provider to create the content.
/// </summary>
public Task CreateFromProviderAsync(IProvider provider, Model chatModel, IContent? lastPrompt, ChatThread? chatChatThread, CancellationToken token = default);
/// <summary>
/// Gets the ERI content type that corresponds to this content:
/// TEXT for text content, IMAGE for image content, and UNKNOWN otherwise.
/// </summary>
public Tools.ERIClient.DataModel.ContentType ToERIContentType
{
    get
    {
        if (this is ContentText)
            return Tools.ERIClient.DataModel.ContentType.TEXT;

        if (this is ContentImage)
            return Tools.ERIClient.DataModel.ContentType.IMAGE;

        return Tools.ERIClient.DataModel.ContentType.UNKNOWN;
    }
}
}

View File

@ -1,7 +1,14 @@
// ReSharper disable InconsistentNaming
using AIStudio.Assistants.ERI;
using AIStudio.Chat;
using AIStudio.Tools.ERIClient;
using AIStudio.Tools.ERIClient.DataModel;
using AIStudio.Tools.RAG;
using AIStudio.Tools.Services;
using ChatThread = AIStudio.Chat.ChatThread;
using ContentType = AIStudio.Tools.ERIClient.DataModel.ContentType;
namespace AIStudio.Settings.DataModel;
@ -43,4 +50,85 @@ public readonly record struct DataSourceERI_V1 : IERIDataSource
/// <inheritdoc />
public ERIVersion Version { get; init; } = ERIVersion.V1;
/// <inheritdoc />
/// <remarks>
/// Authenticates against the ERI server, executes the retrieval with the latest prompt and
/// the transformed chat thread, and maps the returned ERI contexts to generic retrieval
/// contexts. When no client is available, or when authentication or retrieval fails, a
/// warning is logged and an empty list is returned.
/// </remarks>
public async Task<IReadOnlyList<IRetrievalContext>> RetrieveDataAsync(IContent lastPrompt, ChatThread thread, CancellationToken token = default)
{
    // Important: Do not dispose the RustService here, as it is a singleton.
    var rustService = Program.SERVICE_PROVIDER.GetRequiredService<RustService>();
    var logger = Program.SERVICE_PROVIDER.GetRequiredService<ILogger<DataSourceERI_V1>>();

    // The factory may not provide a client for the configured ERI version. In that case,
    // we skip the retrieval instead of failing with a NullReferenceException:
    using var eriClient = ERIClientFactory.Get(this.Version, this);
    if (eriClient is null)
    {
        logger.LogWarning("Was not able to create an ERI client for the data source '{DataSourceName}' (ERI version: {Version}).", this.Name, this.Version);
        return [];
    }

    var authResponse = await eriClient.AuthenticateAsync(this, rustService, token);
    if (!authResponse.Successful)
    {
        logger.LogWarning("Was not able to authenticate with the ERI data source '{DataSourceName}'. Message: {Message}", this.Name, authResponse.Message);
        return [];
    }

    var retrievalRequest = new RetrievalRequest
    {
        LatestUserPromptType = lastPrompt.ToERIContentType,
        LatestUserPrompt = lastPrompt switch
        {
            ContentText text => text.Text,
            ContentImage image => await image.AsBase64(token),
            _ => string.Empty
        },

        Thread = await thread.ToERIChatThread(token),
        MaxMatches = 10,
        RetrievalProcessId = null, // The ERI server selects the retrieval process when multiple processes are available
        Parameters = null, // The ERI server selects useful default parameters
    };

    var retrievalResponse = await eriClient.ExecuteRetrievalAsync(retrievalRequest, token);
    if (retrievalResponse is { Successful: true, Data: not null })
    {
        //
        // Next, we have to transform the ERI context back to our generic retrieval context:
        //
        var genericRetrievalContexts = new List<IRetrievalContext>(retrievalResponse.Data.Count);
        foreach (var eriContext in retrievalResponse.Data)
        {
            switch (eriContext.Type)
            {
                case ContentType.TEXT:
                    genericRetrievalContexts.Add(new RetrievalTextContext
                    {
                        Path = eriContext.Path ?? string.Empty,
                        Type = eriContext.ToRetrievalContentType(),
                        Links = eriContext.Links,
                        Category = RetrievalContentCategory.TEXT,
                        MatchedText = eriContext.MatchedContent,
                        DataSourceName = eriContext.Name,
                        SurroundingContent = eriContext.SurroundingContent,
                    });
                    break;

                case ContentType.IMAGE:
                    genericRetrievalContexts.Add(new RetrievalImageContext
                    {
                        Path = eriContext.Path ?? string.Empty,
                        Type = eriContext.ToRetrievalContentType(),
                        Links = eriContext.Links,
                        Source = eriContext.MatchedContent,
                        Category = RetrievalContentCategory.IMAGE,
                        SourceType = ContentImageSource.BASE64,
                        DataSourceName = eriContext.Name,
                    });
                    break;

                default:
                    // Contexts of other types are skipped; only a warning is logged:
                    logger.LogWarning("The ERI context type '{ContextType}' is not supported yet.", eriContext.Type);
                    break;
            }
        }

        return genericRetrievalContexts;
    }

    logger.LogWarning("Was not able to retrieve data from the ERI data source '{DataSourceName}'. Message: {Message}", this.Name, retrievalResponse.Message);
    return [];
}
}

View File

@ -1,3 +1,6 @@
using AIStudio.Chat;
using AIStudio.Tools.RAG;
namespace AIStudio.Settings.DataModel;
/// <summary>
@ -27,6 +30,13 @@ public readonly record struct DataSourceLocalDirectory : IInternalDataSource
/// <inheritdoc />
public DataSourceSecurity SecurityPolicy { get; init; } = DataSourceSecurity.NOT_SPECIFIED;
/// <inheritdoc />
public Task<IReadOnlyList<IRetrievalContext>> RetrieveDataAsync(IContent lastPrompt, ChatThread thread, CancellationToken token = default)
{
    // Currently yields no retrieval contexts: a completed task with an empty list is returned.
    return Task.FromResult<IReadOnlyList<IRetrievalContext>>(new List<IRetrievalContext>());
}
/// <summary>
/// The path to the directory.
/// </summary>

View File

@ -1,3 +1,6 @@
using AIStudio.Chat;
using AIStudio.Tools.RAG;
namespace AIStudio.Settings.DataModel;
/// <summary>
@ -27,6 +30,13 @@ public readonly record struct DataSourceLocalFile : IInternalDataSource
/// <inheritdoc />
public DataSourceSecurity SecurityPolicy { get; init; } = DataSourceSecurity.NOT_SPECIFIED;
/// <inheritdoc />
public Task<IReadOnlyList<IRetrievalContext>> RetrieveDataAsync(IContent lastPrompt, ChatThread thread, CancellationToken token = default)
{
    // Currently yields no retrieval contexts: a completed task with an empty list is returned.
    return Task.FromResult<IReadOnlyList<IRetrievalContext>>(new List<IRetrievalContext>());
}
/// <summary>
/// The path to the file.
/// </summary>

View File

@ -1,6 +1,8 @@
using System.Text.Json.Serialization;
using AIStudio.Chat;
using AIStudio.Settings.DataModel;
using AIStudio.Tools.RAG;
namespace AIStudio.Settings;
@ -37,4 +39,13 @@ public interface IDataSource
/// Which data security policy is applied to this data source?
/// </summary>
public DataSourceSecurity SecurityPolicy { get; init; }
/// <summary>
/// Performs the data retrieval process for this data source.
/// </summary>
/// <param name="lastPrompt">The last prompt from the chat.</param>
/// <param name="thread">The chat thread the prompt belongs to.</param>
/// <param name="token">The cancellation token.</param>
/// <returns>
/// A task yielding the list of retrieval contexts obtained from this data source.
/// </returns>
public Task<IReadOnlyList<IRetrievalContext>> RetrieveDataAsync(IContent lastPrompt, ChatThread thread, CancellationToken token = default);
}

View File

@ -4,7 +4,7 @@ namespace AIStudio.Tools.RAG;
public static class RetrievalContentTypeExtensions
{
public static RetrievalContentType ToRetrievalContentType(Context eriContext)
public static RetrievalContentType ToRetrievalContentType(this Context eriContext)
{
//
// Right now, we have to parse the category string along the type enum to