// ReSharper disable InconsistentNaming using AIStudio.Assistants.ERI; using AIStudio.Chat; using AIStudio.Tools.ERIClient; using AIStudio.Tools.ERIClient.DataModel; using AIStudio.Tools.PluginSystem; using AIStudio.Tools.RAG; using AIStudio.Tools.Services; using Lua; using ChatThread = AIStudio.Chat.ChatThread; using ContentType = AIStudio.Tools.ERIClient.DataModel.ContentType; namespace AIStudio.Settings.DataModel; /// /// An external data source, accessed via an ERI server, cf. https://github.com/MindWorkAI/ERI. /// public readonly record struct DataSourceERI_V1 : IERIDataSource { private static readonly ILogger LOGGER = Program.LOGGER_FACTORY.CreateLogger(); public DataSourceERI_V1() { } /// public uint Num { get; init; } /// public string Id { get; init; } = Guid.Empty.ToString(); /// public string Name { get; init; } = string.Empty; /// public DataSourceType Type { get; init; } = DataSourceType.NONE; /// public string Hostname { get; init; } = string.Empty; /// public int Port { get; init; } /// public AuthMethod AuthMethod { get; init; } = AuthMethod.NONE; /// public string Username { get; init; } = string.Empty; /// public DataSourceERIUsernamePasswordMode UsernamePasswordMode { get; init; } = DataSourceERIUsernamePasswordMode.USER_MANAGED; /// public DataSourceSecurity SecurityPolicy { get; init; } = DataSourceSecurity.NOT_SPECIFIED; /// public bool IsEnterpriseConfiguration { get; init; } /// public Guid EnterpriseConfigurationPluginId { get; init; } = Guid.Empty; /// public ERIVersion Version { get; init; } = ERIVersion.V1; /// public string SelectedRetrievalId { get; init; } = string.Empty; /// public ushort MaxMatches { get; init; } = 10; /// public async Task> RetrieveDataAsync(IContent lastUserPrompt, ChatThread thread, CancellationToken token = default) { // Important: Do not dispose the RustService here, as it is a singleton. var rustService = Program.SERVICE_PROVIDER.GetRequiredService(); var logger = Program.SERVICE_PROVIDER.GetRequiredService>(); using var eriClient = ERIClientFactory.Get(this.Version, this)!; var authResponse = await eriClient.AuthenticateAsync(rustService, cancellationToken: token); if (authResponse.Successful) { var retrievalRequest = new RetrievalRequest { LatestUserPromptType = lastUserPrompt.ToERIContentType, LatestUserPrompt = lastUserPrompt switch { ContentText text => text.Text, ContentImage image => await image.TryAsBase64(token) is (success: true, { } base64Image) ? base64Image : string.Empty, _ => string.Empty }, Thread = await thread.ToERIChatThread(token), MaxMatches = this.MaxMatches, RetrievalProcessId = this.SelectedRetrievalId, Parameters = null, // The ERI server selects useful default parameters }; var retrievalResponse = await eriClient.ExecuteRetrievalAsync(retrievalRequest, token); if(retrievalResponse is { Successful: true, Data: not null }) { // // Next, we have to transform the ERI context back to our generic retrieval context: // var genericRetrievalContexts = new List(retrievalResponse.Data.Count); foreach (var eriContext in retrievalResponse.Data) { switch (eriContext.Type) { case ContentType.TEXT: genericRetrievalContexts.Add(new RetrievalTextContext { Path = eriContext.Path ?? string.Empty, Type = eriContext.ToRetrievalContentType(), Links = eriContext.Links, Category = eriContext.Type.ToRetrievalContentCategory(), MatchedText = eriContext.MatchedContent, DataSourceName = eriContext.Name, SurroundingContent = eriContext.SurroundingContent, }); break; case ContentType.IMAGE: genericRetrievalContexts.Add(new RetrievalImageContext { Path = eriContext.Path ?? string.Empty, Type = eriContext.ToRetrievalContentType(), Links = eriContext.Links, Source = eriContext.MatchedContent, Category = eriContext.Type.ToRetrievalContentCategory(), SourceType = ContentImageSource.BASE64, DataSourceName = eriContext.Name, }); break; default: logger.LogWarning($"The ERI context type '{eriContext.Type}' is not supported yet."); break; } } return genericRetrievalContexts; } logger.LogWarning($"Was not able to retrieve data from the ERI data source '{this.Name}'. Message: {retrievalResponse.Message}"); return []; } logger.LogWarning($"Was not able to authenticate with the ERI data source '{this.Name}'. Message: {authResponse.Message}"); return []; } public static bool TryParseConfiguration(int idx, LuaTable table, Guid configPluginId, out DataSourceERI_V1 dataSource) { dataSource = default; if (!table.TryGetValue("Id", out var idValue) || !idValue.TryRead(out var idText) || !Guid.TryParse(idText, out var id)) { LOGGER.LogWarning($"The configured data source {idx} does not contain a valid ID. The ID must be a valid GUID. (Plugin ID: {configPluginId})"); return false; } if (!table.TryGetValue("Name", out var nameValue) || !nameValue.TryRead(out var name) || string.IsNullOrWhiteSpace(name)) { LOGGER.LogWarning($"The configured data source {idx} does not contain a valid name. (Plugin ID: {configPluginId})"); return false; } if (!table.TryGetValue("Type", out var typeValue) || !typeValue.TryRead(out var typeText) || !Enum.TryParse(typeText, true, out var type) || type is not DataSourceType.ERI_V1) { LOGGER.LogWarning($"The configured data source {idx} does not contain a supported data source type. Only ERI_V1 is supported. (Plugin ID: {configPluginId})"); return false; } if (!table.TryGetValue("Hostname", out var hostnameValue) || !hostnameValue.TryRead(out var hostname) || string.IsNullOrWhiteSpace(hostname)) { LOGGER.LogWarning($"The configured data source {idx} does not contain a valid hostname. (Plugin ID: {configPluginId})"); return false; } if (!table.TryGetValue("Port", out var portValue) || !portValue.TryRead(out var port) || port is < 1 or > 65535) { LOGGER.LogWarning($"The configured data source {idx} does not contain a valid port. (Plugin ID: {configPluginId})"); return false; } if (!table.TryGetValue("AuthMethod", out var authMethodValue) || !authMethodValue.TryRead(out var authMethodText) || !Enum.TryParse(authMethodText, true, out var authMethod)) { LOGGER.LogWarning($"The configured data source {idx} does not contain a valid auth method. (Plugin ID: {configPluginId})"); return false; } if (!table.TryGetValue("SecurityPolicy", out var securityPolicyValue) || !securityPolicyValue.TryRead(out var securityPolicyText) || !Enum.TryParse(securityPolicyText, true, out var securityPolicy)) { LOGGER.LogWarning($"The configured data source {idx} does not contain a valid security policy. (Plugin ID: {configPluginId})"); return false; } if (securityPolicy is DataSourceSecurity.NOT_SPECIFIED) { LOGGER.LogWarning($"The configured data source {idx} must specify a security policy. (Plugin ID: {configPluginId})"); return false; } if (!table.TryGetValue("SelectedRetrievalId", out var selectedRetrievalIdValue) || !selectedRetrievalIdValue.TryRead(out var selectedRetrievalId) || string.IsNullOrWhiteSpace(selectedRetrievalId)) { LOGGER.LogWarning($"The configured data source {idx} must specify a selected retrieval ID. (Plugin ID: {configPluginId})"); return false; } if (!table.TryGetValue("MaxMatches", out var maxMatchesValue) || !maxMatchesValue.TryRead(out var maxMatches) || maxMatches is < 1 or > ushort.MaxValue) { LOGGER.LogWarning($"The configured data source {idx} does not contain a valid maximum number of matches. (Plugin ID: {configPluginId})"); return false; } var username = string.Empty; var usernamePasswordMode = DataSourceERIUsernamePasswordMode.USER_MANAGED; if (table.TryGetValue("UsernamePasswordMode", out var usernamePasswordModeValue) && usernamePasswordModeValue.TryRead(out var usernamePasswordModeText)) { if (!Enum.TryParse(usernamePasswordModeText, true, out usernamePasswordMode)) { LOGGER.LogWarning($"The configured data source {idx} does not contain a valid username/password mode. (Plugin ID: {configPluginId})"); return false; } if (usernamePasswordMode is DataSourceERIUsernamePasswordMode.USER_MANAGED) { LOGGER.LogWarning($"The configured data source {idx} uses the user-managed username/password mode. This mode is not allowed in configuration plugins. (Plugin ID: {configPluginId})"); return false; } } if (authMethod is AuthMethod.USERNAME_PASSWORD) { if (!table.TryGetValue("UsernamePasswordMode", out _) || usernamePasswordMode is DataSourceERIUsernamePasswordMode.USER_MANAGED) { LOGGER.LogWarning($"The configured data source {idx} must specify an organization-managed username/password mode. (Plugin ID: {configPluginId})"); return false; } if (usernamePasswordMode is DataSourceERIUsernamePasswordMode.SHARED_USERNAME_AND_PASSWORD && (!table.TryGetValue("Username", out var usernameValue) || !usernameValue.TryRead(out username) || string.IsNullOrWhiteSpace(username))) { LOGGER.LogWarning($"The configured data source {idx} must specify a username. (Plugin ID: {configPluginId})"); return false; } } dataSource = new DataSourceERI_V1 { Num = 0, Id = id.ToString(), Name = name, Type = DataSourceType.ERI_V1, Hostname = CleanHostname(hostname), Port = port, AuthMethod = authMethod, Username = username, UsernamePasswordMode = usernamePasswordMode, SecurityPolicy = securityPolicy, Version = ERIVersion.V1, SelectedRetrievalId = selectedRetrievalId, MaxMatches = (ushort)maxMatches, IsEnterpriseConfiguration = true, EnterpriseConfigurationPluginId = configPluginId, }; return TryQueueEnterpriseSecret(idx, table, configPluginId, dataSource); } /// /// Exports the ERI v1 data source configuration as a Lua configuration section. /// /// Optional encrypted token or password to include in the export. /// The organization-managed username/password mode to export. /// A Lua configuration section string. public string ExportAsConfigurationSection(string? encryptedSecret = null, DataSourceERIUsernamePasswordMode usernamePasswordMode = DataSourceERIUsernamePasswordMode.USER_MANAGED) { var secretLine = string.Empty; var usernamePasswordModeLine = string.Empty; var usernameLine = string.Empty; switch (this.AuthMethod) { case AuthMethod.TOKEN: secretLine = CreateSecretLine("Token", encryptedSecret); break; case AuthMethod.USERNAME_PASSWORD: if (usernamePasswordMode is DataSourceERIUsernamePasswordMode.USER_MANAGED) usernamePasswordMode = DataSourceERIUsernamePasswordMode.OS_USERNAME_SHARED_PASSWORD; usernamePasswordModeLine = $""" ["UsernamePasswordMode"] = "{usernamePasswordMode}", """; if (usernamePasswordMode is DataSourceERIUsernamePasswordMode.SHARED_USERNAME_AND_PASSWORD) { var username = string.IsNullOrWhiteSpace(this.Username) ? "" : this.Username; usernameLine = $""" ["Username"] = "{LuaTools.EscapeLuaString(username)}", """; } secretLine = CreateSecretLine("Password", encryptedSecret); break; } return $$""" CONFIG["DATA_SOURCES"][#CONFIG["DATA_SOURCES"]+1] = { ["Id"] = "{{Guid.NewGuid().ToString()}}", ["Name"] = "{{LuaTools.EscapeLuaString(this.Name)}}", ["Type"] = "ERI_V1", ["Hostname"] = "{{LuaTools.EscapeLuaString(this.Hostname)}}", ["Port"] = {{this.Port}}, ["AuthMethod"] = "{{this.AuthMethod}}", {{usernamePasswordModeLine}} {{usernameLine}} {{secretLine}} ["SecurityPolicy"] = "{{this.SecurityPolicy}}", ["SelectedRetrievalId"] = "{{LuaTools.EscapeLuaString(this.SelectedRetrievalId)}}", ["MaxMatches"] = {{this.MaxMatches}}, } """; } private static bool TryQueueEnterpriseSecret(int idx, LuaTable table, Guid configPluginId, DataSourceERI_V1 dataSource) { var secretFieldName = dataSource.AuthMethod switch { AuthMethod.TOKEN => "Token", AuthMethod.USERNAME_PASSWORD => "Password", _ => string.Empty, }; if (string.IsNullOrWhiteSpace(secretFieldName)) return true; if (!table.TryGetValue(secretFieldName, out var secretValue) || !secretValue.TryRead(out var encryptedSecret) || string.IsNullOrWhiteSpace(encryptedSecret)) { LOGGER.LogWarning($"The configured data source {idx} does not contain a valid encrypted {secretFieldName}. (Plugin ID: {configPluginId})"); return false; } if (!EnterpriseEncryption.IsEncrypted(encryptedSecret)) { LOGGER.LogWarning($"The configured data source {idx} contains a plaintext {secretFieldName}. Only encrypted secrets (starting with 'ENC:v1:') are supported. (Plugin ID: {configPluginId})"); return false; } var encryption = PluginFactory.EnterpriseEncryption; if (encryption?.IsAvailable != true) { LOGGER.LogWarning($"The configured data source {idx} contains an encrypted {secretFieldName}, but no encryption secret is configured. (Plugin ID: {configPluginId})"); return false; } if (!encryption.TryDecrypt(encryptedSecret, out var decryptedSecret)) { LOGGER.LogWarning($"Failed to decrypt the {secretFieldName} for data source {idx}. The encryption secret may be incorrect. (Plugin ID: {configPluginId})"); return false; } PendingEnterpriseSecrets.Add(new( $"{ISecretId.ENTERPRISE_KEY_PREFIX}::{dataSource.Id}", dataSource.Name, decryptedSecret, SecretStoreType.DATA_SOURCE)); LOGGER.LogDebug($"Successfully decrypted the {secretFieldName} for data source {idx}. It will be stored in the OS keyring. (Plugin ID: {configPluginId})"); return true; } private static string CreateSecretLine(string fieldName, string? encryptedSecret) { if (string.IsNullOrWhiteSpace(encryptedSecret)) return string.Empty; return $""" ["{fieldName}"] = "{LuaTools.EscapeLuaString(encryptedSecret)}", """; } private static string CleanHostname(string hostname) { var cleanedHostname = hostname.Trim(); return cleanedHostname.EndsWith('/') ? cleanedHostname[..^1] : cleanedHostname; } }