Added capabilities for some open source models

This commit is contained in:
Thorsten Sommer 2025-05-11 12:50:47 +02:00
parent df6757f396
commit b2b0b11009
Signed by: tsommer
GPG Key ID: 371BBA77A02C0108
9 changed files with 167 additions and 47 deletions

View File

@ -0,0 +1,150 @@
namespace AIStudio.Provider;
public static class CapabilitiesOpenSource
{
public static IReadOnlyCollection<Capability> GetCapabilities(Model model)
{
var modelName = model.Id.ToLowerInvariant().AsSpan();
//
// Checking for names in the case of open source models is a hard task.
// Let's assume we want to check for the llama 3.1 405b model.
//
// Here is a not complete list of how providers name this model:
// - Fireworks: accounts/fireworks/models/llama-v3p1-405b-instruct
// - Hugging Face -> Nebius AI Studio: meta-llama/Meta-Llama-3.1-405B-Instruct
// - Groq: llama-3.1-405b-instruct
// - LM Studio: llama-3.1-405b-instruct
// - Helmholtz Blablador: 1 - Llama3 405 the best general model
// - GWDG: Llama 3.1 405B Instruct
//
//
// Meta llama models:
//
if (modelName.IndexOf("llama") is not -1)
{
if (modelName.IndexOf("llama4") is not -1 ||
modelName.IndexOf("llama 4") is not -1 ||
modelName.IndexOf("llama-4") is not -1 ||
modelName.IndexOf("llama-v4") is not -1)
return
[
Capability.TEXT_INPUT, Capability.MULTIPLE_IMAGE_INPUT,
Capability.TEXT_OUTPUT,
Capability.FUNCTION_CALLING,
];
// The old vision models cannot do function calling:
if (modelName.IndexOf("vision") is not -1)
return [Capability.TEXT_INPUT, Capability.MULTIPLE_IMAGE_INPUT, Capability.TEXT_OUTPUT];
//
// All models >= 3.1 are able to do function calling:
//
if (modelName.IndexOf("llama3.") is not -1 ||
modelName.IndexOf("llama 3.") is not -1 ||
modelName.IndexOf("llama-3.") is not -1 ||
modelName.IndexOf("llama-v3p") is not -1)
return
[
Capability.TEXT_INPUT,
Capability.TEXT_OUTPUT,
Capability.FUNCTION_CALLING,
];
// All other llama models can only do text input and output:
return [Capability.TEXT_INPUT, Capability.TEXT_OUTPUT];
}
//
// DeepSeek models:
//
if (modelName.IndexOf("deepseek") is not -1)
{
if(modelName.IndexOf("deepseek-r1") is not -1 ||
modelName.IndexOf("deepseek r1") is not -1)
return [Capability.TEXT_INPUT, Capability.TEXT_OUTPUT, Capability.ALWAYS_REASONING];
return [Capability.TEXT_INPUT, Capability.TEXT_OUTPUT];
}
//
// Qwen models:
//
if (modelName.IndexOf("qwen") is not -1 || modelName.IndexOf("qwq") is not -1)
{
if (modelName.IndexOf("qwq") is not -1)
return [Capability.TEXT_INPUT, Capability.TEXT_OUTPUT, Capability.ALWAYS_REASONING];
return [Capability.TEXT_INPUT, Capability.TEXT_OUTPUT];
}
//
// Mistral models:
//
if (modelName.IndexOf("mistral") is not -1 ||
modelName.IndexOf("pixtral") is not -1)
{
if(modelName.IndexOf("pixtral") is not -1)
return
[
Capability.TEXT_INPUT, Capability.MULTIPLE_IMAGE_INPUT,
Capability.TEXT_OUTPUT,
Capability.FUNCTION_CALLING
];
if (modelName.IndexOf("3.1") is not -1)
return
[
Capability.TEXT_INPUT, Capability.MULTIPLE_IMAGE_INPUT,
Capability.TEXT_OUTPUT,
Capability.FUNCTION_CALLING
];
// Default:
return
[
Capability.TEXT_INPUT,
Capability.TEXT_OUTPUT,
Capability.FUNCTION_CALLING
];
}
//
// Grok models:
//
if (modelName.IndexOf("grok") is not -1)
{
if(modelName.IndexOf("-vision-") is not -1)
return
[
Capability.TEXT_INPUT, Capability.MULTIPLE_IMAGE_INPUT,
Capability.TEXT_OUTPUT,
];
if(modelName.StartsWith("grok-3-mini"))
return
[
Capability.TEXT_INPUT,
Capability.TEXT_OUTPUT,
Capability.ALWAYS_REASONING, Capability.FUNCTION_CALLING,
];
if(modelName.StartsWith("grok-3"))
return
[
Capability.TEXT_INPUT,
Capability.TEXT_OUTPUT,
Capability.FUNCTION_CALLING,
];
}
// Default:
return [Capability.TEXT_INPUT, Capability.TEXT_OUTPUT];
}
}

View File

@ -107,5 +107,7 @@ public class ProviderFireworks(ILogger logger) : BaseProvider("https://api.firew
return Task.FromResult(Enumerable.Empty<Model>()); return Task.FromResult(Enumerable.Empty<Model>());
} }
public override IReadOnlyCollection<Capability> GetModelCapabilities(Model model) => CapabilitiesOpenSource.GetCapabilities(model);
#endregion #endregion
} }

View File

@ -108,6 +108,8 @@ public sealed class ProviderGWDG(ILogger logger) : BaseProvider("https://chat-ai
return models.Where(model => model.Id.StartsWith("e5-", StringComparison.InvariantCultureIgnoreCase)); return models.Where(model => model.Id.StartsWith("e5-", StringComparison.InvariantCultureIgnoreCase));
} }
public override IReadOnlyCollection<Capability> GetModelCapabilities(Model model) => CapabilitiesOpenSource.GetCapabilities(model);
#endregion #endregion
private async Task<IEnumerable<Model>> LoadModels(CancellationToken token, string? apiKeyProvisional = null) private async Task<IEnumerable<Model>> LoadModels(CancellationToken token, string? apiKeyProvisional = null)

View File

@ -110,6 +110,8 @@ public class ProviderGroq(ILogger logger) : BaseProvider("https://api.groq.com/o
return Task.FromResult(Enumerable.Empty<Model>()); return Task.FromResult(Enumerable.Empty<Model>());
} }
public override IReadOnlyCollection<Capability> GetModelCapabilities(Model model) => CapabilitiesOpenSource.GetCapabilities(model);
#endregion #endregion
private async Task<IEnumerable<Model>> LoadModels(CancellationToken token, string? apiKeyProvisional = null) private async Task<IEnumerable<Model>> LoadModels(CancellationToken token, string? apiKeyProvisional = null)

View File

@ -112,6 +112,8 @@ public sealed class ProviderHelmholtz(ILogger logger) : BaseProvider("https://ap
model.Id.Contains("gritlm", StringComparison.InvariantCultureIgnoreCase)); model.Id.Contains("gritlm", StringComparison.InvariantCultureIgnoreCase));
} }
public override IReadOnlyCollection<Capability> GetModelCapabilities(Model model) => CapabilitiesOpenSource.GetCapabilities(model);
#endregion #endregion
private async Task<IEnumerable<Model>> LoadModels(CancellationToken token, string? apiKeyProvisional = null) private async Task<IEnumerable<Model>> LoadModels(CancellationToken token, string? apiKeyProvisional = null)

View File

@ -111,5 +111,7 @@ public sealed class ProviderHuggingFace : BaseProvider
return Task.FromResult(Enumerable.Empty<Model>()); return Task.FromResult(Enumerable.Empty<Model>());
} }
public override IReadOnlyCollection<Capability> GetModelCapabilities(Model model) => CapabilitiesOpenSource.GetCapabilities(model);
#endregion #endregion
} }

View File

@ -165,13 +165,7 @@ public sealed class ProviderMistral(ILogger logger) : BaseProvider("https://api.
]; ];
// Default: // Default:
return return CapabilitiesOpenSource.GetCapabilities(model);
[
Capability.TEXT_INPUT,
Capability.TEXT_OUTPUT,
Capability.FUNCTION_CALLING,
];
} }
#endregion #endregion

View File

@ -88,7 +88,6 @@ public sealed class ProviderSelfHosted(ILogger logger, Host host, string hostnam
} }
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
public override async Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default) public override async Task<IEnumerable<Provider.Model>> GetTextModels(string? apiKeyProvisional = null, CancellationToken token = default)
{ {
try try
@ -140,6 +139,8 @@ public sealed class ProviderSelfHosted(ILogger logger, Host host, string hostnam
} }
} }
public override IReadOnlyCollection<Capability> GetModelCapabilities(Provider.Model model) => CapabilitiesOpenSource.GetCapabilities(model);
#endregion #endregion
private async Task<IEnumerable<Provider.Model>> LoadModels(string[] ignorePhrases, string[] filterPhrases, CancellationToken token, string? apiKeyProvisional = null) private async Task<IEnumerable<Provider.Model>> LoadModels(string[] ignorePhrases, string[] filterPhrases, CancellationToken token, string? apiKeyProvisional = null)

View File

@ -111,42 +111,7 @@ public sealed class ProviderX(ILogger logger) : BaseProvider("https://api.x.ai/v
return Task.FromResult<IEnumerable<Model>>([]); return Task.FromResult<IEnumerable<Model>>([]);
} }
public override IReadOnlyCollection<Capability> GetModelCapabilities(Model model) public override IReadOnlyCollection<Capability> GetModelCapabilities(Model model) => CapabilitiesOpenSource.GetCapabilities(model);
{
var modelName = model.Id.ToLowerInvariant().AsSpan();
if(modelName.IndexOf("-vision-") is not -1)
return
[
Capability.TEXT_INPUT, Capability.MULTIPLE_IMAGE_INPUT,
Capability.TEXT_OUTPUT,
];
if(modelName.StartsWith("grok-3-mini"))
return
[
Capability.TEXT_INPUT,
Capability.TEXT_OUTPUT,
Capability.ALWAYS_REASONING, Capability.FUNCTION_CALLING,
];
if(modelName.StartsWith("grok-3"))
return
[
Capability.TEXT_INPUT,
Capability.TEXT_OUTPUT,
Capability.FUNCTION_CALLING,
];
// Default capabilities:
return
[
Capability.TEXT_INPUT,
Capability.TEXT_OUTPUT,
];
}
#endregion #endregion