Added the actual filtering and blocking security logic so that only public hosts are allowed, and only high-confidence providers are used if a request goes to an allowed internal host.

This commit is contained in:
Nils Kruthoff 2026-05-18 15:33:36 +02:00
parent 948d5dec27
commit b1f50b7b5c
No known key found for this signature in database
GPG Key ID: A5C0151B4DDB172C
2 changed files with 256 additions and 24 deletions

View File

@ -11,6 +11,7 @@ namespace AIStudio.Tools;
public sealed class HTMLParser public sealed class HTMLParser
{ {
private const string USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) MindWorkAIStudio/1.0"; private const string USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) MindWorkAIStudio/1.0";
private const int MAX_REDIRECTS = 10;
private static readonly Config MARKDOWN_PARSER_CONFIG = new() private static readonly Config MARKDOWN_PARSER_CONFIG = new()
{ {
@ -43,11 +44,12 @@ public sealed class HTMLParser
return innerHtml; return innerHtml;
} }
public async Task<HTMLParserWebPage> LoadWebPageAsync(Uri url, CancellationToken token = default, int timeoutSeconds = 30) public async Task<HTMLParserWebPage> LoadWebPageAsync(Uri url, CancellationToken token = default, int timeoutSeconds = 30, Func<Uri, CancellationToken, Task>? validateUrlAsync = null)
{ {
using var handler = new HttpClientHandler using var handler = new HttpClientHandler
{ {
AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate | DecompressionMethods.Brotli, AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate | DecompressionMethods.Brotli,
AllowAutoRedirect = false,
}; };
using var httpClient = new HttpClient(handler) using var httpClient = new HttpClient(handler)
{ {
@ -55,7 +57,53 @@ public sealed class HTMLParser
}; };
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(token); using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(token);
timeoutCts.CancelAfter(TimeSpan.FromSeconds(timeoutSeconds)); timeoutCts.CancelAfter(TimeSpan.FromSeconds(timeoutSeconds));
using var request = new HttpRequestMessage(HttpMethod.Get, url);
var currentUrl = url;
for (var redirectCount = 0; redirectCount <= MAX_REDIRECTS; redirectCount++)
{
if (validateUrlAsync is not null)
await validateUrlAsync(currentUrl, timeoutCts.Token);
using var request = CreateRequest(currentUrl);
using var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, timeoutCts.Token);
if (IsRedirect(response.StatusCode))
{
if (response.Headers.Location is null)
throw new HttpRequestException($"The server returned a redirect without a Location header for '{currentUrl}'.", null, response.StatusCode);
currentUrl = response.Headers.Location.IsAbsoluteUri
? response.Headers.Location
: new Uri(currentUrl, response.Headers.Location);
continue;
}
if (!response.IsSuccessStatusCode)
{
var statusCode = (int)response.StatusCode;
var reasonPhrase = string.IsNullOrWhiteSpace(response.ReasonPhrase) ? "Unknown" : response.ReasonPhrase;
throw new HttpRequestException($"The server returned HTTP {statusCode} ({reasonPhrase}) for '{currentUrl}'.", null, response.StatusCode);
}
var html = await response.Content.ReadAsStringAsync(timeoutCts.Token);
var document = new HtmlDocument();
document.LoadHtml(html);
return new HTMLParserWebPage
{
RequestedUrl = url,
FinalUrl = response.RequestMessage?.RequestUri ?? currentUrl,
ContentType = response.Content.Headers.ContentType?.MediaType ?? string.Empty,
Document = document,
};
}
throw new HttpRequestException($"The server returned more than {MAX_REDIRECTS} redirects for '{url}'.");
}
private static HttpRequestMessage CreateRequest(Uri url)
{
var request = new HttpRequestMessage(HttpMethod.Get, url);
request.Headers.TryAddWithoutValidation("User-Agent", USER_AGENT); request.Headers.TryAddWithoutValidation("User-Agent", USER_AGENT);
request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/html")); request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/html"));
request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/xhtml+xml")); request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/xhtml+xml"));
@ -69,27 +117,10 @@ public sealed class HTMLParser
request.Headers.TryAddWithoutValidation("Sec-Fetch-Mode", "navigate"); request.Headers.TryAddWithoutValidation("Sec-Fetch-Mode", "navigate");
request.Headers.TryAddWithoutValidation("Sec-Fetch-Dest", "document"); request.Headers.TryAddWithoutValidation("Sec-Fetch-Dest", "document");
request.Headers.TryAddWithoutValidation("Sec-Fetch-User", "?1"); request.Headers.TryAddWithoutValidation("Sec-Fetch-User", "?1");
return request;
using var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, timeoutCts.Token);
if (!response.IsSuccessStatusCode)
{
var statusCode = (int)response.StatusCode;
var reasonPhrase = string.IsNullOrWhiteSpace(response.ReasonPhrase) ? "Unknown" : response.ReasonPhrase;
throw new HttpRequestException($"The server returned HTTP {statusCode} ({reasonPhrase}) for '{url}'.", null, response.StatusCode);
} }
var html = await response.Content.ReadAsStringAsync(token); private static bool IsRedirect(HttpStatusCode statusCode) => (int)statusCode is >= 300 and <= 399;
var document = new HtmlDocument();
document.LoadHtml(html);
return new HTMLParserWebPage
{
RequestedUrl = url,
FinalUrl = response.RequestMessage?.RequestUri ?? url,
ContentType = response.Content.Headers.ContentType?.MediaType ?? string.Empty,
Document = document,
};
}
public string ExtractTitle(HtmlDocument document) public string ExtractTitle(HtmlDocument document)
{ {

View File

@ -1,17 +1,21 @@
using System.Net;
using System.Net.Sockets;
using System.Text.Json; using System.Text.Json;
using System.Text.Json.Nodes; using System.Text.Json.Nodes;
using AIStudio.Provider;
using AIStudio.Tools.PluginSystem; using AIStudio.Tools.PluginSystem;
using HtmlAgilityPack; using HtmlAgilityPack;
namespace AIStudio.Tools.ToolCallingSystem.ToolCallingImplementations; namespace AIStudio.Tools.ToolCallingSystem.ToolCallingImplementations;
public sealed class ReadWebPageTool(HTMLParser htmlParser) : IToolImplementation public sealed class ReadWebPageTool(HTMLParser htmlParser, ILogger<ReadWebPageTool> logger) : IToolImplementation
{ {
private static string TB(string fallbackEN) => I18N.I.T(fallbackEN, typeof(ReadWebPageTool).Namespace, nameof(ReadWebPageTool)); private static string TB(string fallbackEN) => I18N.I.T(fallbackEN, typeof(ReadWebPageTool).Namespace, nameof(ReadWebPageTool));
private const int DEFAULT_TIMEOUT_SECONDS = 30; private const int DEFAULT_TIMEOUT_SECONDS = 30;
private const int DEFAULT_MAX_CONTENT_CHARACTERS = 12000; private const int DEFAULT_MAX_CONTENT_CHARACTERS = 12000;
private const int MAX_TRACE_LENGTH = 12000; private const int MAX_TRACE_LENGTH = 12000;
private const string ALLOWED_PRIVATE_HOSTS_SETTING = "allowedPrivateHosts";
private static readonly string[] REMOVED_NODE_XPATHS = private static readonly string[] REMOVED_NODE_XPATHS =
[ [
@ -42,6 +46,7 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser) : IToolImplementation
{ {
"timeoutSeconds" => TB("Timeout Seconds"), "timeoutSeconds" => TB("Timeout Seconds"),
"maxContentCharacters" => TB("Maximum Content Characters"), "maxContentCharacters" => TB("Maximum Content Characters"),
ALLOWED_PRIVATE_HOSTS_SETTING => TB("Allowed Private Hosts"),
_ => TB(fieldDefinition.Title), _ => TB(fieldDefinition.Title),
}; };
@ -49,6 +54,7 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser) : IToolImplementation
{ {
"timeoutSeconds" => TB("Optional HTTP timeout for loading a web page in seconds."), "timeoutSeconds" => TB("Optional HTTP timeout for loading a web page in seconds."),
"maxContentCharacters" => TB("Optional global truncation limit for extracted Markdown returned to the model."), "maxContentCharacters" => TB("Optional global truncation limit for extracted Markdown returned to the model."),
ALLOWED_PRIVATE_HOSTS_SETTING => TB("Optional host allowlist for private or VPN web pages. Separate host patterns with commas, such as example.de, *.example.de. Allowed private hosts require a High-confidence provider."),
_ => TB(fieldDefinition.Description), _ => TB(fieldDefinition.Description),
}; };
@ -75,6 +81,15 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser) : IToolImplementation
}); });
} }
if (!TryReadAllowedPrivateHostPatterns(settingsValues.GetValueOrDefault(ALLOWED_PRIVATE_HOSTS_SETTING), out _, out var allowlistError))
{
return Task.FromResult<ToolConfigurationState?>(new ToolConfigurationState
{
IsConfigured = false,
Message = allowlistError,
});
}
return Task.FromResult<ToolConfigurationState?>(null); return Task.FromResult<ToolConfigurationState?>(null);
} }
@ -86,11 +101,17 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser) : IToolImplementation
var timeoutSeconds = ReadOptionalPositiveIntSetting(context.SettingsValues, "timeoutSeconds") ?? DEFAULT_TIMEOUT_SECONDS; var timeoutSeconds = ReadOptionalPositiveIntSetting(context.SettingsValues, "timeoutSeconds") ?? DEFAULT_TIMEOUT_SECONDS;
var maxContentCharacters = ReadOptionalPositiveIntSetting(context.SettingsValues, "maxContentCharacters") ?? DEFAULT_MAX_CONTENT_CHARACTERS; var maxContentCharacters = ReadOptionalPositiveIntSetting(context.SettingsValues, "maxContentCharacters") ?? DEFAULT_MAX_CONTENT_CHARACTERS;
if (!TryReadAllowedPrivateHostPatterns(context.SettingsValues.GetValueOrDefault(ALLOWED_PRIVATE_HOSTS_SETTING), out var allowedPrivateHosts, out var allowlistError))
throw new InvalidOperationException(allowlistError);
HTMLParserWebPage page; HTMLParserWebPage page;
try try
{ {
page = await htmlParser.LoadWebPageAsync(url, token, timeoutSeconds); page = await htmlParser.LoadWebPageAsync(
url,
token,
timeoutSeconds,
async (candidateUrl, validationToken) => await this.ValidateUrlAccessAsync(candidateUrl, allowedPrivateHosts, context.ProviderConfidence, validationToken));
} }
catch (OperationCanceledException) when (!token.IsCancellationRequested) catch (OperationCanceledException) when (!token.IsCancellationRequested)
{ {
@ -162,6 +183,178 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser) : IToolImplementation
return $"{rawResult[..MAX_TRACE_LENGTH]}..."; return $"{rawResult[..MAX_TRACE_LENGTH]}...";
} }
/// <summary>
/// Decides whether the given URL may be fetched. The checks run in this order:
/// scheme, blocked host names, DNS resolution, never-allowed addresses,
/// non-public addresses vs. the private-host allowlist, and finally the provider confidence.
/// </summary>
/// <param name="url">The URL that is about to be requested.</param>
/// <param name="allowedPrivateHosts">Host patterns that are permitted to resolve to non-public addresses.</param>
/// <param name="providerConfidence">Confidence level of the provider; must be at least HIGH for allowed private hosts.</param>
/// <param name="token">Cancellation token forwarded to the DNS lookup.</param>
/// <exception cref="ToolExecutionBlockedException">Thrown when one of the access rules rejects the URL.</exception>
/// <exception cref="InvalidOperationException">Thrown when the host does not resolve to any address.</exception>
private async Task ValidateUrlAccessAsync(
    Uri url,
    IReadOnlyList<AllowedPrivateHostPattern> allowedPrivateHosts,
    ConfidenceLevel providerConfidence,
    CancellationToken token)
{
    var usesWebScheme = url is { Scheme: "http" or "https" };
    if (!usesWebScheme)
        throw new ToolExecutionBlockedException("Only HTTP and HTTPS URLs are supported.");

    var targetHost = url.Host;
    if (IsBlockedHostName(targetHost))
        throw new ToolExecutionBlockedException("Local web page URLs are not supported.");

    var resolvedAddresses = await ResolveHostAddressesAsync(url, token);
    if (resolvedAddresses.Count == 0)
        throw new InvalidOperationException($"The host '{url.Host}' did not resolve to an IP address.");

    // A single never-allowed address among the results is enough to reject the host.
    if (resolvedAddresses.Any(IsNeverAllowedAddress))
        throw new ToolExecutionBlockedException("Local, link-local, multicast, and unspecified network addresses are not supported.");

    // Purely public hosts need no further checks.
    var touchesNonPublicNetwork = resolvedAddresses.Any(IsNonPublicAddress);
    if (!touchesNonPublicNetwork)
        return;

    if (!IsAllowedPrivateHost(targetHost, allowedPrivateHosts))
        throw new ToolExecutionBlockedException("Private or local-network web page URLs are not supported unless their host is explicitly allowed.");

    // Allowed private hosts additionally require a sufficiently trusted provider.
    if (!(providerConfidence >= ConfidenceLevel.HIGH))
    {
        await this.ReportPrivateHostProviderBlockAsync(url, providerConfidence);
        throw new ToolExecutionBlockedException("This private or VPN web page requires a High-confidence provider.");
    }
}
/// <summary>
/// Logs a warning and sends an error message over the message bus when an allowed
/// private host was blocked because the provider confidence is below HIGH.
/// </summary>
/// <param name="url">The blocked URL; only its host is written to the log.</param>
/// <param name="providerConfidence">The confidence level that was considered insufficient.</param>
private async Task ReportPrivateHostProviderBlockAsync(Uri url, ConfidenceLevel providerConfidence)
{
    var blockedHost = url.Host;
    logger.LogWarning(
        "Blocked read_web_page access to allowed private host '{Host}' because provider confidence '{ProviderConfidence}' is below HIGH.",
        blockedHost,
        providerConfidence);

    var userMessage = TB("The web page was not loaded because private or VPN web pages require a High-confidence provider.");
    var errorMessage = new DataErrorMessage(Icons.Material.Filled.Security, userMessage);
    await MessageBus.INSTANCE.SendError(errorMessage);
}
/// <summary>
/// Returns the normalized IP addresses behind the host of the given URL.
/// A host that parses as an IP literal is returned directly without a DNS lookup;
/// otherwise the host is resolved via DNS.
/// </summary>
/// <param name="url">The URL whose host should be resolved.</param>
/// <param name="token">Cancellation token passed to the DNS lookup.</param>
/// <returns>The normalized addresses of the host.</returns>
/// <exception cref="InvalidOperationException">Thrown when the DNS lookup fails with a <see cref="SocketException"/>.</exception>
private static async Task<IReadOnlyList<IPAddress>> ResolveHostAddressesAsync(Uri url, CancellationToken token)
{
    if (IPAddress.TryParse(url.Host, out var literalAddress))
        return [NormalizeAddress(literalAddress)];

    IPAddress[] dnsResults;
    try
    {
        dnsResults = await Dns.GetHostAddressesAsync(url.DnsSafeHost, token);
    }
    catch (SocketException exception)
    {
        throw new InvalidOperationException($"The host '{url.Host}' could not be resolved: {exception.Message}", exception);
    }

    var normalizedAddresses = new List<IPAddress>(dnsResults.Length);
    foreach (var dnsResult in dnsResults)
        normalizedAddresses.Add(NormalizeAddress(dnsResult));

    return normalizedAddresses;
}
/// <summary>
/// Converts an IPv4-mapped IPv6 address to its IPv4 form; any other address is returned unchanged.
/// </summary>
/// <param name="address">The address to normalize.</param>
/// <returns>The IPv4 equivalent for mapped addresses, otherwise the same instance.</returns>
private static IPAddress NormalizeAddress(IPAddress address)
{
    if (address.IsIPv4MappedToIPv6)
        return address.MapToIPv4();

    return address;
}
/// <summary>
/// Checks whether the host name is "localhost" or ends with ".localhost" after normalization.
/// </summary>
/// <param name="host">The host name to check.</param>
/// <returns><c>true</c> when the host name is a localhost name.</returns>
private static bool IsBlockedHostName(string host)
{
    var candidate = NormalizeHost(host);
    if (string.Equals(candidate, "localhost", StringComparison.Ordinal))
        return true;

    return candidate.EndsWith(".localhost", StringComparison.Ordinal);
}
/// <summary>
/// Checks whether the normalized host matches at least one of the allowed private host patterns.
/// </summary>
/// <param name="host">The host name to check.</param>
/// <param name="allowedPrivateHosts">The configured allowlist patterns.</param>
/// <returns><c>true</c> when any pattern matches the host.</returns>
private static bool IsAllowedPrivateHost(string host, IReadOnlyList<AllowedPrivateHostPattern> allowedPrivateHosts)
{
    var candidate = NormalizeHost(host);
    foreach (var allowedPattern in allowedPrivateHosts)
    {
        if (allowedPattern.IsMatch(candidate))
            return true;
    }

    return false;
}
/// <summary>
/// Normalizes a host name for comparison: trims surrounding whitespace,
/// removes trailing dots, and lower-cases the result with the invariant culture.
/// </summary>
/// <param name="host">The raw host name.</param>
/// <returns>The normalized host name.</returns>
private static string NormalizeHost(string host)
{
    var withoutWhitespace = host.Trim();
    var withoutTrailingDots = withoutWhitespace.TrimEnd('.');
    return withoutTrailingDots.ToLowerInvariant();
}
/// <summary>
/// Checks whether an address belongs to a range that is rejected regardless of the allowlist:
/// loopback, the IPv4 0.x, 127.x, 169.254.x and 224+ ranges, and the IPv6 unspecified,
/// none, loopback, link-local and multicast addresses. Unknown address families are rejected as well.
/// </summary>
/// <param name="address">The address to classify; IPv4-mapped IPv6 addresses are evaluated as IPv4.</param>
/// <returns><c>true</c> when the address must never be contacted.</returns>
private static bool IsNeverAllowedAddress(IPAddress address)
{
    var normalized = NormalizeAddress(address);
    if (IPAddress.IsLoopback(normalized))
        return true;

    switch (normalized.AddressFamily)
    {
        case AddressFamily.InterNetwork:
        {
            if (normalized.Equals(IPAddress.Any))
                return true;

            var octets = normalized.GetAddressBytes();
            var firstOctet = octets[0];
            if (firstOctet == 0 || firstOctet == 127 || firstOctet >= 224)
                return true;

            // Link-local range: 169.254.0.0/16
            return firstOctet == 169 && octets[1] == 254;
        }

        case AddressFamily.InterNetworkV6:
        {
            if (normalized.Equals(IPAddress.IPv6Any))
                return true;

            if (normalized.Equals(IPAddress.IPv6None))
                return true;

            if (normalized.Equals(IPAddress.IPv6Loopback))
                return true;

            return normalized.IsIPv6LinkLocal || normalized.IsIPv6Multicast;
        }

        default:
            // Anything that is neither IPv4 nor IPv6 is rejected.
            return true;
    }
}
/// <summary>
/// Checks whether an address is not a public address. This covers every never-allowed address
/// plus the private, carrier-grade NAT, protocol-assignment, documentation and benchmark IPv4 ranges,
/// and the unique-local and site-local IPv6 ranges.
/// </summary>
/// <param name="address">The address to classify; IPv4-mapped IPv6 addresses are evaluated as IPv4.</param>
/// <returns><c>true</c> when the address is not public.</returns>
private static bool IsNonPublicAddress(IPAddress address)
{
    var normalized = NormalizeAddress(address);
    if (IsNeverAllowedAddress(normalized))
        return true;

    var octets = normalized.GetAddressBytes();
    return normalized.AddressFamily switch
    {
        AddressFamily.InterNetwork => (octets[0], octets[1], octets[2]) switch
        {
            (10, _, _) => true,                    // Private network: 10.0.0.0/8
            (100, >= 64 and <= 127, _) => true,    // Carrier-grade NAT: 100.64.0.0/10
            (172, >= 16 and <= 31, _) => true,     // Private network: 172.16.0.0/12
            (192, 168, _) => true,                 // Private network: 192.168.0.0/16
            (192, 0, 0) => true,                   // IETF protocol assignments: 192.0.0.0/24
            (192, 0, 2) => true,                   // Documentation range: 192.0.2.0/24
            (198, 18 or 19, _) => true,            // Benchmark testing range: 198.18.0.0/15
            (198, 51, 100) => true,                // Documentation range: 198.51.100.0/24
            (203, 0, 113) => true,                 // Documentation range: 203.0.113.0/24
            _ => false,
        },

        // Unique local addresses (fc00::/7) or deprecated site-local addresses (fec0::/10)
        AddressFamily.InterNetworkV6 => (octets[0] & 0xfe) == 0xfc || normalized.IsIPv6SiteLocal,

        _ => true,
    };
}
/// <summary>
/// Parses the raw allowlist setting into host patterns. Each entry must be a bare host name,
/// optionally prefixed with "*." to form a wildcard pattern.
/// </summary>
/// <param name="rawValue">The raw setting value; may be <c>null</c>.</param>
/// <param name="patterns">The parsed patterns; duplicates are removed when parsing succeeds.</param>
/// <param name="error">An error message when parsing fails, otherwise an empty string.</param>
/// <returns><c>true</c> when every entry is valid; <c>false</c> on the first invalid entry.</returns>
private static bool TryReadAllowedPrivateHostPatterns(
    string? rawValue,
    out List<AllowedPrivateHostPattern> patterns,
    out string error)
{
    error = string.Empty;
    patterns = [];

    foreach (var entry in SplitAllowedPrivateHostPatterns(rawValue))
    {
        var normalizedEntry = NormalizeHost(entry);

        // Entries that look like URLs or contain a path are rejected.
        var looksLikeUrl = normalizedEntry.Contains("://", StringComparison.Ordinal);
        if (looksLikeUrl || normalizedEntry.Contains('/'))
        {
            error = TB("Allowed private hosts must be host names only, without scheme or path.");
            return false;
        }

        // A leading "*." marks a wildcard pattern; the remainder is the host to validate.
        var hostPart = normalizedEntry;
        var wildcard = false;
        if (normalizedEntry.StartsWith("*.", StringComparison.Ordinal))
        {
            wildcard = true;
            hostPart = normalizedEntry[2..];
        }

        var isValidHost = !string.IsNullOrWhiteSpace(hostPart)
                          && Uri.CheckHostName(hostPart) is not UriHostNameType.Unknown;
        if (!isValidHost)
        {
            error = string.Format(TB("Allowed private host '{0}' is not valid."), entry);
            return false;
        }

        patterns.Add(new AllowedPrivateHostPattern(hostPart, wildcard));
    }

    var uniquePatterns = patterns.Distinct();
    patterns = uniquePatterns.ToList();
    return true;
}
/// <summary>
/// Splits the raw allowlist setting into trimmed, non-empty entries.
/// Entries may be separated by commas, semicolons, or line breaks.
/// </summary>
/// <param name="rawValue">The raw setting value; may be <c>null</c>.</param>
/// <returns>The individual entries, or an empty sequence when the value is <c>null</c>.</returns>
private static IEnumerable<string> SplitAllowedPrivateHostPatterns(string? rawValue)
{
    if (rawValue is null)
        return [];

    char[] separators = ['\r', '\n', ',', ';'];
    var entries = rawValue.Split(separators, StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
    return entries.Where(entry => !string.IsNullOrWhiteSpace(entry));
}
private static void RemoveNoiseNodes(HtmlNode rootNode) private static void RemoveNoiseNodes(HtmlNode rootNode)
{ {
foreach (var xpath in REMOVED_NODE_XPATHS) foreach (var xpath in REMOVED_NODE_XPATHS)
@ -221,4 +414,12 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser) : IToolImplementation
error = I18N.I.T($"The setting '{key}' must be a positive integer.", typeof(ReadWebPageTool).Namespace, nameof(ReadWebPageTool)); error = I18N.I.T($"The setting '{key}' must be a positive integer.", typeof(ReadWebPageTool).Namespace, nameof(ReadWebPageTool));
return false; return false;
} }
/// <summary>
/// A single allowlist entry for private hosts.
/// </summary>
/// <param name="Host">The host name the pattern refers to, without any "*." prefix.</param>
/// <param name="IsWildcard">When <c>true</c>, the pattern matches subdomains of <paramref name="Host"/> but not the host itself.</param>
private readonly record struct AllowedPrivateHostPattern(string Host, bool IsWildcard)
{
    /// <summary>
    /// Checks whether the given host matches this pattern using ordinal comparison.
    /// </summary>
    /// <param name="normalizedHost">The host name to test, expected to be normalized by the caller.</param>
    /// <returns><c>true</c> when the host matches.</returns>
    public bool IsMatch(string normalizedHost)
    {
        if (!this.IsWildcard)
            return normalizedHost.Equals(this.Host, StringComparison.Ordinal);

        // A wildcard needs at least one character in front of the ".host" suffix.
        var requiredSuffix = "." + this.Host;
        var hasLabelBeforeSuffix = normalizedHost.Length > this.Host.Length + 1;
        return normalizedHost.EndsWith(requiredSuffix, StringComparison.Ordinal) && hasLabelBeforeSuffix;
    }
}
} }