mirror of
https://github.com/MindWorkAI/AI-Studio.git
synced 2025-04-28 17:19:47 +00:00
Implemented the generic retrieval process
This commit is contained in:
parent
0c029177c0
commit
b05d44c2f4
@ -4,6 +4,7 @@ using AIStudio.Agents;
|
|||||||
using AIStudio.Components;
|
using AIStudio.Components;
|
||||||
using AIStudio.Provider;
|
using AIStudio.Provider;
|
||||||
using AIStudio.Settings;
|
using AIStudio.Settings;
|
||||||
|
using AIStudio.Tools.RAG;
|
||||||
using AIStudio.Tools.Services;
|
using AIStudio.Tools.Services;
|
||||||
|
|
||||||
namespace AIStudio.Chat;
|
namespace AIStudio.Chat;
|
||||||
@ -199,9 +200,30 @@ public sealed class ContentText : IContent
|
|||||||
//
|
//
|
||||||
// Trigger the retrieval part of the (R)AG process:
|
// Trigger the retrieval part of the (R)AG process:
|
||||||
//
|
//
|
||||||
|
var dataContexts = new List<IRetrievalContext>();
|
||||||
if (proceedWithRAG)
|
if (proceedWithRAG)
|
||||||
{
|
{
|
||||||
|
//
|
||||||
|
// We kick off the retrieval process for each data source in parallel:
|
||||||
|
//
|
||||||
|
var retrievalTasks = new List<Task<IReadOnlyList<IRetrievalContext>>>(selectedDataSources.Count);
|
||||||
|
foreach (var dataSource in selectedDataSources)
|
||||||
|
retrievalTasks.Add(dataSource.RetrieveDataAsync(lastPrompt, chatThread, token));
|
||||||
|
|
||||||
|
//
|
||||||
|
// Wait for all retrieval tasks to finish:
|
||||||
|
//
|
||||||
|
foreach (var retrievalTask in retrievalTasks)
|
||||||
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
dataContexts.AddRange(await retrievalTask);
|
||||||
|
}
|
||||||
|
catch (Exception e)
|
||||||
|
{
|
||||||
|
logger.LogError(e, "An error occurred during the retrieval process.");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -42,4 +42,15 @@ public interface IContent
|
|||||||
/// Uses the provider to create the content.
|
/// Uses the provider to create the content.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public Task CreateFromProviderAsync(IProvider provider, Model chatModel, IContent? lastPrompt, ChatThread? chatChatThread, CancellationToken token = default);
|
public Task CreateFromProviderAsync(IProvider provider, Model chatModel, IContent? lastPrompt, ChatThread? chatChatThread, CancellationToken token = default);
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Returns the corresponding ERI content type.
|
||||||
|
/// </summary>
|
||||||
|
public Tools.ERIClient.DataModel.ContentType ToERIContentType => this switch
|
||||||
|
{
|
||||||
|
ContentText => Tools.ERIClient.DataModel.ContentType.TEXT,
|
||||||
|
ContentImage => Tools.ERIClient.DataModel.ContentType.IMAGE,
|
||||||
|
|
||||||
|
_ => Tools.ERIClient.DataModel.ContentType.UNKNOWN,
|
||||||
|
};
|
||||||
}
|
}
|
@ -1,7 +1,14 @@
|
|||||||
// ReSharper disable InconsistentNaming
|
// ReSharper disable InconsistentNaming
|
||||||
|
|
||||||
using AIStudio.Assistants.ERI;
|
using AIStudio.Assistants.ERI;
|
||||||
|
using AIStudio.Chat;
|
||||||
|
using AIStudio.Tools.ERIClient;
|
||||||
using AIStudio.Tools.ERIClient.DataModel;
|
using AIStudio.Tools.ERIClient.DataModel;
|
||||||
|
using AIStudio.Tools.RAG;
|
||||||
|
using AIStudio.Tools.Services;
|
||||||
|
|
||||||
|
using ChatThread = AIStudio.Chat.ChatThread;
|
||||||
|
using ContentType = AIStudio.Tools.ERIClient.DataModel.ContentType;
|
||||||
|
|
||||||
namespace AIStudio.Settings.DataModel;
|
namespace AIStudio.Settings.DataModel;
|
||||||
|
|
||||||
@ -43,4 +50,85 @@ public readonly record struct DataSourceERI_V1 : IERIDataSource
|
|||||||
|
|
||||||
/// <inheritdoc />
|
/// <inheritdoc />
|
||||||
public ERIVersion Version { get; init; } = ERIVersion.V1;
|
public ERIVersion Version { get; init; } = ERIVersion.V1;
|
||||||
|
|
||||||
|
/// <inheritdoc />
|
||||||
|
public async Task<IReadOnlyList<IRetrievalContext>> RetrieveDataAsync(IContent lastPrompt, ChatThread thread, CancellationToken token = default)
|
||||||
|
{
|
||||||
|
// Important: Do not dispose the RustService here, as it is a singleton.
|
||||||
|
var rustService = Program.SERVICE_PROVIDER.GetRequiredService<RustService>();
|
||||||
|
var logger = Program.SERVICE_PROVIDER.GetRequiredService<ILogger<DataSourceERI_V1>>();
|
||||||
|
|
||||||
|
using var eriClient = ERIClientFactory.Get(this.Version, this)!;
|
||||||
|
var authResponse = await eriClient.AuthenticateAsync(this, rustService, token);
|
||||||
|
if (authResponse.Successful)
|
||||||
|
{
|
||||||
|
var retrievalRequest = new RetrievalRequest
|
||||||
|
{
|
||||||
|
LatestUserPromptType = lastPrompt.ToERIContentType,
|
||||||
|
LatestUserPrompt = lastPrompt switch
|
||||||
|
{
|
||||||
|
ContentText text => text.Text,
|
||||||
|
ContentImage image => await image.AsBase64(token),
|
||||||
|
_ => string.Empty
|
||||||
|
},
|
||||||
|
|
||||||
|
Thread = await thread.ToERIChatThread(token),
|
||||||
|
MaxMatches = 10,
|
||||||
|
RetrievalProcessId = null, // The ERI server selects the retrieval process when multiple processes are available
|
||||||
|
Parameters = null, // The ERI server selects useful default parameters
|
||||||
|
};
|
||||||
|
|
||||||
|
var retrievalResponse = await eriClient.ExecuteRetrievalAsync(retrievalRequest, token);
|
||||||
|
if(retrievalResponse is { Successful: true, Data: not null })
|
||||||
|
{
|
||||||
|
//
|
||||||
|
// Next, we have to transform the ERI context back to our generic retrieval context:
|
||||||
|
//
|
||||||
|
var genericRetrievalContexts = new List<IRetrievalContext>(retrievalResponse.Data.Count);
|
||||||
|
foreach (var eriContext in retrievalResponse.Data)
|
||||||
|
{
|
||||||
|
switch (eriContext.Type)
|
||||||
|
{
|
||||||
|
case ContentType.TEXT:
|
||||||
|
genericRetrievalContexts.Add(new RetrievalTextContext
|
||||||
|
{
|
||||||
|
Path = eriContext.Path ?? string.Empty,
|
||||||
|
Type = eriContext.ToRetrievalContentType(),
|
||||||
|
Links = eriContext.Links,
|
||||||
|
Category = RetrievalContentCategory.TEXT,
|
||||||
|
MatchedText = eriContext.MatchedContent,
|
||||||
|
DataSourceName = eriContext.Name,
|
||||||
|
SurroundingContent = eriContext.SurroundingContent,
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
|
||||||
|
case ContentType.IMAGE:
|
||||||
|
genericRetrievalContexts.Add(new RetrievalImageContext
|
||||||
|
{
|
||||||
|
Path = eriContext.Path ?? string.Empty,
|
||||||
|
Type = eriContext.ToRetrievalContentType(),
|
||||||
|
Links = eriContext.Links,
|
||||||
|
Source = eriContext.MatchedContent,
|
||||||
|
Category = RetrievalContentCategory.IMAGE,
|
||||||
|
SourceType = ContentImageSource.BASE64,
|
||||||
|
DataSourceName = eriContext.Name,
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
logger.LogWarning($"The ERI context type '{eriContext.Type}' is not supported yet.");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return genericRetrievalContexts;
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.LogWarning($"Was not able to retrieve data from the ERI data source '{this.Name}'. Message: {retrievalResponse.Message}");
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.LogWarning($"Was not able to authenticate with the ERI data source '{this.Name}'. Message: {authResponse.Message}");
|
||||||
|
return [];
|
||||||
|
}
|
||||||
}
|
}
|
@ -1,3 +1,6 @@
|
|||||||
|
using AIStudio.Chat;
|
||||||
|
using AIStudio.Tools.RAG;
|
||||||
|
|
||||||
namespace AIStudio.Settings.DataModel;
|
namespace AIStudio.Settings.DataModel;
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
@ -27,6 +30,13 @@ public readonly record struct DataSourceLocalDirectory : IInternalDataSource
|
|||||||
/// <inheritdoc />
|
/// <inheritdoc />
|
||||||
public DataSourceSecurity SecurityPolicy { get; init; } = DataSourceSecurity.NOT_SPECIFIED;
|
public DataSourceSecurity SecurityPolicy { get; init; } = DataSourceSecurity.NOT_SPECIFIED;
|
||||||
|
|
||||||
|
/// <inheritdoc />
|
||||||
|
public Task<IReadOnlyList<IRetrievalContext>> RetrieveDataAsync(IContent lastPrompt, ChatThread thread, CancellationToken token = default)
|
||||||
|
{
|
||||||
|
IReadOnlyList<IRetrievalContext> retrievalContext = new List<IRetrievalContext>();
|
||||||
|
return Task.FromResult(retrievalContext);
|
||||||
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// The path to the directory.
|
/// The path to the directory.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
using AIStudio.Chat;
|
||||||
|
using AIStudio.Tools.RAG;
|
||||||
|
|
||||||
namespace AIStudio.Settings.DataModel;
|
namespace AIStudio.Settings.DataModel;
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
@ -27,6 +30,13 @@ public readonly record struct DataSourceLocalFile : IInternalDataSource
|
|||||||
/// <inheritdoc />
|
/// <inheritdoc />
|
||||||
public DataSourceSecurity SecurityPolicy { get; init; } = DataSourceSecurity.NOT_SPECIFIED;
|
public DataSourceSecurity SecurityPolicy { get; init; } = DataSourceSecurity.NOT_SPECIFIED;
|
||||||
|
|
||||||
|
/// <inheritdoc />
|
||||||
|
public Task<IReadOnlyList<IRetrievalContext>> RetrieveDataAsync(IContent lastPrompt, ChatThread thread, CancellationToken token = default)
|
||||||
|
{
|
||||||
|
IReadOnlyList<IRetrievalContext> retrievalContext = new List<IRetrievalContext>();
|
||||||
|
return Task.FromResult(retrievalContext);
|
||||||
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// The path to the file.
|
/// The path to the file.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
using System.Text.Json.Serialization;
|
using System.Text.Json.Serialization;
|
||||||
|
|
||||||
|
using AIStudio.Chat;
|
||||||
using AIStudio.Settings.DataModel;
|
using AIStudio.Settings.DataModel;
|
||||||
|
using AIStudio.Tools.RAG;
|
||||||
|
|
||||||
namespace AIStudio.Settings;
|
namespace AIStudio.Settings;
|
||||||
|
|
||||||
@ -37,4 +39,13 @@ public interface IDataSource
|
|||||||
/// Which data security policy is applied to this data source?
|
/// Which data security policy is applied to this data source?
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public DataSourceSecurity SecurityPolicy { get; init; }
|
public DataSourceSecurity SecurityPolicy { get; init; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Perform the data retrieval process.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="lastPrompt">The last prompt from the chat.</param>
|
||||||
|
/// <param name="thread">The chat thread.</param>
|
||||||
|
/// <param name="token">The cancellation token.</param>
|
||||||
|
/// <returns>The retrieved data context.</returns>
|
||||||
|
public Task<IReadOnlyList<IRetrievalContext>> RetrieveDataAsync(IContent lastPrompt, ChatThread thread, CancellationToken token = default);
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user