From 05e1c6633056ef3de080d58ceb93faccd21c0def Mon Sep 17 00:00:00 2001 From: Thorsten Sommer Date: Fri, 14 Feb 2025 12:55:15 +0100 Subject: [PATCH] Refactored data service to handle different provider types --- .../Tools/Services/DataSourceService.cs | 40 ++++++++++++++++--- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/app/MindWork AI Studio/Tools/Services/DataSourceService.cs b/app/MindWork AI Studio/Tools/Services/DataSourceService.cs index 43c44b02..de593088 100644 --- a/app/MindWork AI Studio/Tools/Services/DataSourceService.cs +++ b/app/MindWork AI Studio/Tools/Services/DataSourceService.cs @@ -1,4 +1,6 @@ using AIStudio.Assistants.ERI; +using AIStudio.Provider; +using AIStudio.Provider.SelfHosted; using AIStudio.Settings; using AIStudio.Settings.DataModel; using AIStudio.Tools.ERIClient; @@ -41,6 +43,34 @@ public sealed class DataSourceService return new([], []); } + return await this.GetDataSources(selectedLLMProvider.IsSelfHosted, previousSelectedDataSources); + } + + /// + /// Returns a list of data sources that are allowed for the selected LLM provider. + /// It also returns the data sources selected before when they are still allowed. + /// + /// The selected LLM provider. + /// The data sources selected before. + /// The allowed data sources and the data sources selected before -- when they are still allowed. + public async Task GetDataSources(IProvider selectedLLMProvider, IReadOnlyCollection? previousSelectedDataSources = null) + { + // + // Case: Somehow the selected LLM provider was not set. The default provider + // does not mean anything. We cannot filter the data sources by any means. + // We return an empty list. Better safe than sorry. + // + if (selectedLLMProvider is NoProvider) + { + this.logger.LogWarning("The selected LLM provider is the default provider. We cannot filter the data sources by any means."); + return new([], []); + } + + return await this.GetDataSources(selectedLLMProvider is ProviderSelfHosted, previousSelectedDataSources); + } + + private async Task GetDataSources(bool usingSelfHostedProvider, IReadOnlyCollection? previousSelectedDataSources = null) + { var allDataSources = this.settingsManager.ConfigurationData.DataSources; var filteredDataSources = new List(allDataSources.Count); var filteredSelectedDataSources = new List(previousSelectedDataSources?.Count ?? 0); @@ -48,7 +78,7 @@ public sealed class DataSourceService // Start all checks in parallel: foreach (var source in allDataSources) - tasks.Add(this.CheckOneDataSource(source, selectedLLMProvider)); + tasks.Add(this.CheckOneDataSource(source, usingSelfHostedProvider)); // Wait for all checks and collect the results: foreach (var task in tasks) @@ -65,7 +95,7 @@ public sealed class DataSourceService return new(filteredDataSources, filteredSelectedDataSources); } - private async Task CheckOneDataSource(IDataSource source, AIStudio.Settings.Provider selectedLLMProvider) + private async Task CheckOneDataSource(IDataSource source, bool usingSelfHostedProvider) { // // Unfortunately, we have to live-check any ERI source for its security requirements. @@ -110,7 +140,7 @@ public sealed class DataSourceService // Case: The data source allows any provider type. We want to use a self-hosted provider. // There is no issue with this source. Accept it. // - if(selectedLLMProvider.IsSelfHosted) + if(usingSelfHostedProvider) return source; // @@ -148,14 +178,14 @@ public sealed class DataSourceService // Case: The data source requires a self-hosted provider. We want to use a self-hosted provider. // There is no issue with this source. Accept it. // - case DataSourceSecurity.SELF_HOSTED when selectedLLMProvider.IsSelfHosted: + case DataSourceSecurity.SELF_HOSTED when usingSelfHostedProvider: return source; // // Case: The data source requires a self-hosted provider. We want to use a cloud provider. // We skip this source. // - case DataSourceSecurity.SELF_HOSTED when !selectedLLMProvider.IsSelfHosted: + case DataSourceSecurity.SELF_HOSTED when !usingSelfHostedProvider: this.logger.LogWarning($"The data source '{source.Name}' (id={source.Id}) requires a self-hosted provider. We skip this source."); return null;