// ReSharper disable InconsistentNaming
using AIStudio.Assistants.ERI;
using AIStudio.Chat;
using AIStudio.Tools.ERIClient;
using AIStudio.Tools.ERIClient.DataModel;
using AIStudio.Tools.RAG;
using AIStudio.Tools.Services;
using ChatThread = AIStudio.Chat.ChatThread;
using ContentType = AIStudio.Tools.ERIClient.DataModel.ContentType;
namespace AIStudio.Settings.DataModel;
/// <summary>
/// An external data source, accessed via an ERI server, cf. https://github.com/MindWorkAI/ERI.
/// </summary>
public readonly record struct DataSourceERI_V1 : IERIDataSource
{
    /// <summary>
    /// The maximum number of matches requested from the ERI server for a single retrieval.
    /// </summary>
    private const int MAX_MATCHES = 10;

    public DataSourceERI_V1()
    {
    }

    /// <inheritdoc />
    public uint Num { get; init; }

    /// <inheritdoc />
    public string Id { get; init; } = Guid.Empty.ToString();

    /// <inheritdoc />
    public string Name { get; init; } = string.Empty;

    /// <inheritdoc />
    public DataSourceType Type { get; init; } = DataSourceType.NONE;

    /// <inheritdoc />
    public string Hostname { get; init; } = string.Empty;

    /// <inheritdoc />
    public int Port { get; init; }

    /// <inheritdoc />
    public AuthMethod AuthMethod { get; init; } = AuthMethod.NONE;

    /// <inheritdoc />
    public string Username { get; init; } = string.Empty;

    /// <inheritdoc />
    public DataSourceSecurity SecurityPolicy { get; init; } = DataSourceSecurity.NOT_SPECIFIED;

    /// <inheritdoc />
    public ERIVersion Version { get; init; } = ERIVersion.V1;

    /// <inheritdoc />
    /// <remarks>
    /// Authenticates against the ERI server, sends the latest user prompt together with the chat
    /// thread as a retrieval request, and maps the returned ERI contexts to generic retrieval
    /// contexts. Every failure (no client, authentication, retrieval) is logged as a warning and
    /// yields an empty list. Context types other than text and image are skipped with a warning.
    /// </remarks>
    public async Task<IReadOnlyList<IRetrievalContext>> RetrieveDataAsync(IContent lastPrompt, ChatThread thread, CancellationToken token = default)
    {
        // Important: Do not dispose the RustService here, as it is a singleton.
        var rustService = Program.SERVICE_PROVIDER.GetRequiredService<RustService>();
        var logger = Program.SERVICE_PROVIDER.GetRequiredService<ILogger<DataSourceERI_V1>>();

        // ERIClientFactory.Get may return null (presumably for an ERI version without a client
        // implementation — verify against the factory). Handle that gracefully instead of
        // dereferencing null below; 'using' itself tolerates a null resource.
        using var eriClient = ERIClientFactory.Get(this.Version, this);
        if (eriClient is null)
        {
            logger.LogWarning($"There is no ERI client available for version '{this.Version}' of the ERI data source '{this.Name}'.");
            return [];
        }

        var authResponse = await eriClient.AuthenticateAsync(this, rustService, token);
        if (!authResponse.Successful)
        {
            logger.LogWarning($"Was not able to authenticate with the ERI data source '{this.Name}'. Message: {authResponse.Message}");
            return [];
        }

        var retrievalRequest = new RetrievalRequest
        {
            LatestUserPromptType = lastPrompt.ToERIContentType,
            LatestUserPrompt = lastPrompt switch
            {
                ContentText text => text.Text,
                ContentImage image => await image.AsBase64(token),
                _ => string.Empty
            },
            Thread = await thread.ToERIChatThread(token),
            MaxMatches = MAX_MATCHES,
            RetrievalProcessId = null, // The ERI server selects the retrieval process when multiple processes are available
            Parameters = null, // The ERI server selects useful default parameters
        };

        var retrievalResponse = await eriClient.ExecuteRetrievalAsync(retrievalRequest, token);

        // Bind the non-null payload directly so the mapping below cannot see a null list.
        if (retrievalResponse is not { Successful: true, Data: { } eriContexts })
        {
            logger.LogWarning($"Was not able to retrieve data from the ERI data source '{this.Name}'. Message: {retrievalResponse.Message}");
            return [];
        }

        //
        // Next, we have to transform the ERI context back to our generic retrieval context:
        //
        var genericRetrievalContexts = new List<IRetrievalContext>(eriContexts.Count);
        foreach (var eriContext in eriContexts)
        {
            switch (eriContext.Type)
            {
                case ContentType.TEXT:
                    genericRetrievalContexts.Add(new RetrievalTextContext
                    {
                        Path = eriContext.Path ?? string.Empty,
                        Type = eriContext.ToRetrievalContentType(),
                        Links = eriContext.Links,
                        Category = eriContext.Type.ToRetrievalContentCategory(),
                        MatchedText = eriContext.MatchedContent,
                        DataSourceName = eriContext.Name,
                        SurroundingContent = eriContext.SurroundingContent,
                    });
                    break;

                case ContentType.IMAGE:
                    // For images, the matched content is used as a base64-encoded image source.
                    genericRetrievalContexts.Add(new RetrievalImageContext
                    {
                        Path = eriContext.Path ?? string.Empty,
                        Type = eriContext.ToRetrievalContentType(),
                        Links = eriContext.Links,
                        Source = eriContext.MatchedContent,
                        Category = eriContext.Type.ToRetrievalContentCategory(),
                        SourceType = ContentImageSource.BASE64,
                        DataSourceName = eriContext.Name,
                    });
                    break;

                default:
                    logger.LogWarning($"The ERI context type '{eriContext.Type}' is not supported yet.");
                    break;
            }
        }

        return genericRetrievalContexts;
    }
}