diff --git a/app/MindWork AI Studio/Provider/Anthropic/ProviderAnthropic.cs b/app/MindWork AI Studio/Provider/Anthropic/ProviderAnthropic.cs index 5eb8fe2b..49a0e6ea 100644 --- a/app/MindWork AI Studio/Provider/Anthropic/ProviderAnthropic.cs +++ b/app/MindWork AI Studio/Provider/Anthropic/ProviderAnthropic.cs @@ -29,6 +29,9 @@ public sealed class ProviderAnthropic() : BaseProvider(LLMProviders.ANTHROPIC, " // Parse the API parameters: var apiParameters = this.ParseAdditionalApiParameters("system"); + var maxTokens = 4_096; + if (TryPopIntParameter(apiParameters, "max_tokens", out var parsedMaxTokens)) + maxTokens = parsedMaxTokens; // Build the list of messages: var messages = await chatThread.Blocks.BuildMessagesAsync( @@ -73,7 +76,7 @@ public sealed class ProviderAnthropic() : BaseProvider(LLMProviders.ANTHROPIC, " Messages = [..messages], System = chatThread.PrepareSystemPrompt(settingsManager), - MaxTokens = apiParameters.TryGetValue("max_tokens", out var value) && value is int intValue ? intValue : 4_096, + MaxTokens = maxTokens, // Right now, we only support streaming completions: Stream = true, @@ -188,4 +191,4 @@ public sealed class ProviderAnthropic() : BaseProvider(LLMProviders.ANTHROPIC, " var modelResponse = await response.Content.ReadFromJsonAsync(JSON_SERIALIZER_OPTIONS, token); return modelResponse.Data; } -} \ No newline at end of file +} diff --git a/app/MindWork AI Studio/Provider/BaseProvider.cs b/app/MindWork AI Studio/Provider/BaseProvider.cs index 4acefc62..9b729824 100644 --- a/app/MindWork AI Studio/Provider/BaseProvider.cs +++ b/app/MindWork AI Studio/Provider/BaseProvider.cs @@ -731,7 +731,7 @@ public abstract class BaseProvider : IProvider, ISecretId /// Optional list of keys to remove from the final dictionary /// (case-insensitive). The parameters stream, model, and messages are removed by default. protected IDictionary ParseAdditionalApiParameters( - params List keysToRemove) + params string[] keysToRemove) { if(string.IsNullOrWhiteSpace(this.AdditionalJsonApiParameters)) return new Dictionary(); @@ -744,14 +744,23 @@ public abstract class BaseProvider : IProvider, ISecretId var dict = ConvertToDictionary(jsonDoc); // Some keys are always removed because we set them: - keysToRemove.Add("stream"); - keysToRemove.Add("model"); - keysToRemove.Add("messages"); + var removeSet = new HashSet(StringComparer.OrdinalIgnoreCase); + if (keysToRemove.Length > 0) + removeSet.UnionWith(keysToRemove); + + removeSet.Add("stream"); + removeSet.Add("model"); + removeSet.Add("messages"); // Remove the specified keys (case-insensitive): - var removeSet = new HashSet(keysToRemove, StringComparer.OrdinalIgnoreCase); - foreach (var key in removeSet) - dict.Remove(key); + if (removeSet.Count > 0) + { + foreach (var key in dict.Keys.ToList()) + { + if (removeSet.Contains(key)) + dict.Remove(key); + } + } return dict; } @@ -761,6 +770,85 @@ public abstract class BaseProvider : IProvider, ISecretId return new Dictionary(); } } + + protected static bool TryPopIntParameter(IDictionary parameters, string key, out int value) + { + value = default; + if (!TryPopParameter(parameters, key, out var raw) || raw is null) + return false; + + switch (raw) + { + case int i: + value = i; + return true; + + case long l when l is >= int.MinValue and <= int.MaxValue: + value = (int)l; + return true; + + case double d when d is >= int.MinValue and <= int.MaxValue: + value = (int)d; + return true; + + case decimal m when m is >= int.MinValue and <= int.MaxValue: + value = (int)m; + return true; + } + + return false; + } + + protected static bool TryPopBoolParameter(IDictionary parameters, string key, out bool value) + { + value = default; + if (!TryPopParameter(parameters, key, out var raw) || raw is null) + return false; + + switch (raw) + { + case bool b: + value = b; + return true; + + case string s when bool.TryParse(s, out var parsed): + value = parsed; + return true; + + case int i: + value = i != 0; + return true; + + case long l: + value = l != 0; + return true; + + case double d: + value = Math.Abs(d) > double.Epsilon; + return true; + + case decimal m: + value = m != 0; + return true; + } + + return false; + } + + private static bool TryPopParameter(IDictionary parameters, string key, out object? value) + { + value = null; + if (parameters.Count == 0) + return false; + + var foundKey = parameters.Keys.FirstOrDefault(k => string.Equals(k, key, StringComparison.OrdinalIgnoreCase)); + if (foundKey is null) + return false; + + value = parameters[foundKey]; + parameters.Remove(foundKey); + return true; + } private static IDictionary ConvertToDictionary(JsonElement element) { @@ -785,4 +873,4 @@ public abstract class BaseProvider : IProvider, ISecretId _ => string.Empty, }; -} \ No newline at end of file +} diff --git a/app/MindWork AI Studio/Provider/Mistral/ChatRequest.cs b/app/MindWork AI Studio/Provider/Mistral/ChatRequest.cs index 01a45a89..1d42081f 100644 --- a/app/MindWork AI Studio/Provider/Mistral/ChatRequest.cs +++ b/app/MindWork AI Studio/Provider/Mistral/ChatRequest.cs @@ -14,11 +14,12 @@ public readonly record struct ChatRequest( string Model, IList Messages, bool Stream, - int RandomSeed, + [property: JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + int? RandomSeed, bool SafePrompt = false ) { // Attention: The "required" modifier is not supported for [JsonExtensionData]. [JsonExtensionData] public IDictionary AdditionalApiParameters { get; init; } = new Dictionary(); -} \ No newline at end of file +} diff --git a/app/MindWork AI Studio/Provider/Mistral/ProviderMistral.cs b/app/MindWork AI Studio/Provider/Mistral/ProviderMistral.cs index f4cb07f4..485729fb 100644 --- a/app/MindWork AI Studio/Provider/Mistral/ProviderMistral.cs +++ b/app/MindWork AI Studio/Provider/Mistral/ProviderMistral.cs @@ -36,6 +36,8 @@ public sealed class ProviderMistral() : BaseProvider(LLMProviders.MISTRAL, "http // Parse the API parameters: var apiParameters = this.ParseAdditionalApiParameters(); + var safePrompt = TryPopBoolParameter(apiParameters, "safe_prompt", out var parsedSafePrompt) && parsedSafePrompt; + var randomSeed = TryPopIntParameter(apiParameters, "random_seed", out var parsedRandomSeed) ? parsedRandomSeed : (int?)null; // Build the list of messages: var messages = await chatThread.Blocks.BuildMessagesUsingDirectImageUrlAsync(this.Provider, chatModel); @@ -52,7 +54,8 @@ public sealed class ProviderMistral() : BaseProvider(LLMProviders.MISTRAL, "http // Right now, we only support streaming completions: Stream = true, - SafePrompt = apiParameters.TryGetValue("safe_prompt", out var value) && value is true, + RandomSeed = randomSeed, + SafePrompt = safePrompt, AdditionalApiParameters = apiParameters }, JSON_SERIALIZER_OPTIONS); @@ -165,4 +168,4 @@ public sealed class ProviderMistral() : BaseProvider(LLMProviders.MISTRAL, "http var modelResponse = await response.Content.ReadFromJsonAsync(token); return modelResponse; } -} \ No newline at end of file +}