diff --git a/app/MindWork AI Studio/Provider/HuggingFace/ProviderHuggingFace.cs b/app/MindWork AI Studio/Provider/HuggingFace/ProviderHuggingFace.cs new file mode 100644 index 00000000..d29e5b32 --- /dev/null +++ b/app/MindWork AI Studio/Provider/HuggingFace/ProviderHuggingFace.cs @@ -0,0 +1,110 @@ +using System.Net.Http.Headers; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; + +using AIStudio.Chat; +using AIStudio.Provider.OpenAI; +using AIStudio.Settings; + +namespace AIStudio.Provider.HuggingFace; + +public sealed class ProviderHuggingFace(ILogger logger) : BaseProvider("https://router.huggingface.co/cerebras/v1/", logger) +{ + #region Implementation of IProvider + + /// + public override string Id => LLMProviders.HUGGINGFACE.ToName(); + + /// + public override string InstanceName { get; set; } = "HuggingFace"; + + /// + public override async IAsyncEnumerable StreamChatCompletion(Model chatModel, ChatThread chatThread, SettingsManager settingsManager, [EnumeratorCancellation] CancellationToken token = default) + { + // Get the API key: + var requestedSecret = await RUST_SERVICE.GetAPIKey(this); + if(!requestedSecret.Success) + yield break; + + // Prepare the system prompt: + var systemPrompt = new Message + { + Role = "system", + Content = chatThread.PrepareSystemPrompt(settingsManager, chatThread, this.logger), + }; + + // Prepare the HuggingFace HTTP chat request: + var huggingfaceChatRequest = JsonSerializer.Serialize(new ChatRequest + { + Model = chatModel.Id, + + // Build the messages: + // - First of all the system prompt + // - Then none-empty user and AI messages + Messages = [systemPrompt, ..chatThread.Blocks.Where(n => n.ContentType is ContentType.TEXT && !string.IsNullOrWhiteSpace((n.Content as ContentText)?.Text)).Select(n => new Message + { + Role = n.Role switch + { + ChatRole.USER => "user", + ChatRole.AI => "assistant", + ChatRole.AGENT => "assistant", + ChatRole.SYSTEM => "system", + + _ => "user", + }, + + Content = n.Content switch + { + ContentText text => text.Text, + _ => string.Empty, + } + }).ToList()], + Stream = true, + }, JSON_SERIALIZER_OPTIONS); + + async Task RequestBuilder() + { + // Build the HTTP post request: + var request = new HttpRequestMessage(HttpMethod.Post, "chat/completions"); + + // Set the authorization header: + request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", await requestedSecret.Secret.Decrypt(ENCRYPTION)); + + // Set the content: + request.Content = new StringContent(huggingfaceChatRequest, Encoding.UTF8, "application/json"); + return request; + } + + await foreach (var content in this.StreamChatCompletionInternal("HuggingFace", RequestBuilder, token)) + yield return content; + } + + #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + /// + public override async IAsyncEnumerable StreamImageCompletion(Model imageModel, string promptPositive, string promptNegative = FilterOperator.String.Empty, ImageURL referenceImageURL = default, [EnumeratorCancellation] CancellationToken token = default) + { + yield break; + } + #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously + + /// + public override Task> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) + { + return Task.FromResult(Enumerable.Empty()); + } + + /// + public override Task> GetImageModels(string? apiKeyProvisional = null, CancellationToken token = default) + { + return Task.FromResult(Enumerable.Empty()); + } + + /// + public override Task> GetEmbeddingModels(string? apiKeyProvisional = null, CancellationToken token = default) + { + return Task.FromResult(Enumerable.Empty()); + } + + #endregion +} \ No newline at end of file diff --git a/app/MindWork AI Studio/Provider/LLMProviders.cs b/app/MindWork AI Studio/Provider/LLMProviders.cs index 1c65835f..118d68aa 100644 --- a/app/MindWork AI Studio/Provider/LLMProviders.cs +++ b/app/MindWork AI Studio/Provider/LLMProviders.cs @@ -17,6 +17,7 @@ public enum LLMProviders FIREWORKS = 5, GROQ = 6, + HUGGINGFACE = 13, SELF_HOSTED = 4, diff --git a/app/MindWork AI Studio/Provider/LLMProvidersExtensions.cs b/app/MindWork AI Studio/Provider/LLMProvidersExtensions.cs index c516ce7d..3a2a69d8 100644 --- a/app/MindWork AI Studio/Provider/LLMProvidersExtensions.cs +++ b/app/MindWork AI Studio/Provider/LLMProvidersExtensions.cs @@ -6,6 +6,7 @@ using AIStudio.Provider.Google; using AIStudio.Provider.Groq; using AIStudio.Provider.GWDG; using AIStudio.Provider.Helmholtz; +using AIStudio.Provider.HuggingFace; using AIStudio.Provider.Mistral; using AIStudio.Provider.OpenAI; using AIStudio.Provider.SelfHosted; @@ -37,6 +38,7 @@ public static class LLMProvidersExtensions LLMProviders.GROQ => "Groq", LLMProviders.FIREWORKS => "Fireworks.ai", + LLMProviders.HUGGINGFACE => "Hugging Face", LLMProviders.SELF_HOSTED => "Self-hosted", @@ -58,6 +60,9 @@ public static class LLMProvidersExtensions LLMProviders.FIREWORKS => Confidence.USA_NOT_TRUSTED.WithRegion("America, U.S.").WithSources("https://fireworks.ai/terms-of-service").WithLevel(settingsManager.GetConfiguredConfidenceLevel(llmProvider)), + // Not trusted, because huggingface only routes you to a third-party-provider and we can't make sure they do not use your data + LLMProviders.HUGGINGFACE => Confidence.USA_NOT_TRUSTED.WithRegion("America, U.S.").WithSources("https://huggingface.co/terms-of-service").WithLevel(settingsManager.GetConfiguredConfidenceLevel(llmProvider)), + LLMProviders.OPEN_AI => Confidence.USA_NO_TRAINING.WithRegion("America, U.S.").WithSources( "https://platform.openai.com/docs/models/default-usage-policies-by-endpoint", "https://openai.com/policies/terms-of-use/", @@ -112,6 +117,7 @@ public static class LLMProvidersExtensions LLMProviders.X => false, LLMProviders.GWDG => false, LLMProviders.DEEP_SEEK => false, + LLMProviders.HUGGINGFACE => false, // // Self-hosted providers are treated as a special case anyway. @@ -159,6 +165,7 @@ public static class LLMProvidersExtensions LLMProviders.GROQ => new ProviderGroq(logger) { InstanceName = instanceName }, LLMProviders.FIREWORKS => new ProviderFireworks(logger) { InstanceName = instanceName }, + LLMProviders.HUGGINGFACE => new ProviderHuggingFace(logger) { InstanceName = instanceName }, LLMProviders.SELF_HOSTED => new ProviderSelfHosted(logger, host, hostname) { InstanceName = instanceName }, @@ -187,6 +194,7 @@ public static class LLMProvidersExtensions LLMProviders.GROQ => "https://console.groq.com/", LLMProviders.FIREWORKS => "https://fireworks.ai/login", + LLMProviders.HUGGINGFACE => "https://huggingface.co/login", LLMProviders.HELMHOLTZ => "https://sdlaml.pages.jsc.fz-juelich.de/ai/guides/blablador_api_access/#step-1-register-on-gitlab", LLMProviders.GWDG => "https://docs.hpc.gwdg.de/services/saia/index.html#api-request", @@ -205,6 +213,7 @@ public static class LLMProvidersExtensions LLMProviders.FIREWORKS => "https://fireworks.ai/account/billing", LLMProviders.DEEP_SEEK => "https://platform.deepseek.com/usage", LLMProviders.ALIBABA_CLOUD => "https://usercenter2-intl.aliyun.com/billing", + LLMProviders.HUGGINGFACE => "https://huggingface.co/settings/billing", _ => string.Empty, }; @@ -220,6 +229,7 @@ public static class LLMProvidersExtensions LLMProviders.GOOGLE => true, LLMProviders.DEEP_SEEK => true, LLMProviders.ALIBABA_CLOUD => true, + LLMProviders.HUGGINGFACE => true, _ => false, }; @@ -227,12 +237,14 @@ public static class LLMProvidersExtensions public static string GetModelsOverviewURL(this LLMProviders provider) => provider switch { LLMProviders.FIREWORKS => "https://fireworks.ai/models?show=Serverless", + LLMProviders.HUGGINGFACE => "https://huggingface.co/models?inference_provider=all", _ => string.Empty, }; public static bool IsLLMModelProvidedManually(this LLMProviders provider) => provider switch { LLMProviders.FIREWORKS => true, + LLMProviders.HUGGINGFACE => true, _ => false, }; @@ -268,6 +280,7 @@ public static class LLMProvidersExtensions LLMProviders.FIREWORKS => true, LLMProviders.HELMHOLTZ => true, LLMProviders.GWDG => true, + LLMProviders.HUGGINGFACE => true, LLMProviders.SELF_HOSTED => host is Host.OLLAMA, @@ -288,6 +301,7 @@ public static class LLMProvidersExtensions LLMProviders.FIREWORKS => true, LLMProviders.HELMHOLTZ => true, LLMProviders.GWDG => true, + LLMProviders.HUGGINGFACE => true, _ => false, };