diff --git a/app/MindWork AI Studio/Tools/HTMLParser.cs b/app/MindWork AI Studio/Tools/HTMLParser.cs index 56aceee4..3e86e830 100644 --- a/app/MindWork AI Studio/Tools/HTMLParser.cs +++ b/app/MindWork AI Studio/Tools/HTMLParser.cs @@ -11,6 +11,7 @@ namespace AIStudio.Tools; public sealed class HTMLParser { private const string USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) MindWorkAIStudio/1.0"; + private const int MAX_REDIRECTS = 10; private static readonly Config MARKDOWN_PARSER_CONFIG = new() { @@ -43,11 +44,12 @@ public sealed class HTMLParser return innerHtml; } - public async Task LoadWebPageAsync(Uri url, CancellationToken token = default, int timeoutSeconds = 30) + public async Task LoadWebPageAsync(Uri url, CancellationToken token = default, int timeoutSeconds = 30, Func? validateUrlAsync = null) { using var handler = new HttpClientHandler { AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate | DecompressionMethods.Brotli, + AllowAutoRedirect = false, }; using var httpClient = new HttpClient(handler) { @@ -55,7 +57,53 @@ public sealed class HTMLParser }; using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(token); timeoutCts.CancelAfter(TimeSpan.FromSeconds(timeoutSeconds)); - using var request = new HttpRequestMessage(HttpMethod.Get, url); + + var currentUrl = url; + for (var redirectCount = 0; redirectCount <= MAX_REDIRECTS; redirectCount++) + { + if (validateUrlAsync is not null) + await validateUrlAsync(currentUrl, timeoutCts.Token); + + using var request = CreateRequest(currentUrl); + using var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, timeoutCts.Token); + if (IsRedirect(response.StatusCode)) + { + if (response.Headers.Location is null) + throw new HttpRequestException($"The server returned a redirect without a Location header for '{currentUrl}'.", null, response.StatusCode); + + currentUrl = response.Headers.Location.IsAbsoluteUri + ? response.Headers.Location + : new Uri(currentUrl, response.Headers.Location); + + continue; + } + + if (!response.IsSuccessStatusCode) + { + var statusCode = (int)response.StatusCode; + var reasonPhrase = string.IsNullOrWhiteSpace(response.ReasonPhrase) ? "Unknown" : response.ReasonPhrase; + throw new HttpRequestException($"The server returned HTTP {statusCode} ({reasonPhrase}) for '{currentUrl}'.", null, response.StatusCode); + } + + var html = await response.Content.ReadAsStringAsync(timeoutCts.Token); + var document = new HtmlDocument(); + document.LoadHtml(html); + + return new HTMLParserWebPage + { + RequestedUrl = url, + FinalUrl = response.RequestMessage?.RequestUri ?? currentUrl, + ContentType = response.Content.Headers.ContentType?.MediaType ?? string.Empty, + Document = document, + }; + } + + throw new HttpRequestException($"The server returned more than {MAX_REDIRECTS} redirects for '{url}'."); + } + + private static HttpRequestMessage CreateRequest(Uri url) + { + var request = new HttpRequestMessage(HttpMethod.Get, url); request.Headers.TryAddWithoutValidation("User-Agent", USER_AGENT); request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/html")); request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/xhtml+xml")); @@ -69,28 +117,11 @@ public sealed class HTMLParser request.Headers.TryAddWithoutValidation("Sec-Fetch-Mode", "navigate"); request.Headers.TryAddWithoutValidation("Sec-Fetch-Dest", "document"); request.Headers.TryAddWithoutValidation("Sec-Fetch-User", "?1"); - - using var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, timeoutCts.Token); - if (!response.IsSuccessStatusCode) - { - var statusCode = (int)response.StatusCode; - var reasonPhrase = string.IsNullOrWhiteSpace(response.ReasonPhrase) ? "Unknown" : response.ReasonPhrase; - throw new HttpRequestException($"The server returned HTTP {statusCode} ({reasonPhrase}) for '{url}'.", null, response.StatusCode); - } - - var html = await response.Content.ReadAsStringAsync(token); - var document = new HtmlDocument(); - document.LoadHtml(html); - - return new HTMLParserWebPage - { - RequestedUrl = url, - FinalUrl = response.RequestMessage?.RequestUri ?? url, - ContentType = response.Content.Headers.ContentType?.MediaType ?? string.Empty, - Document = document, - }; + return request; } + private static bool IsRedirect(HttpStatusCode statusCode) => (int)statusCode is >= 300 and <= 399; + public string ExtractTitle(HtmlDocument document) { var title = document.DocumentNode.SelectSingleNode("//title")?.InnerText?.Trim(); diff --git a/app/MindWork AI Studio/Tools/ToolCallingSystem/ToolCallingImplementations/ReadWebPageTool.cs b/app/MindWork AI Studio/Tools/ToolCallingSystem/ToolCallingImplementations/ReadWebPageTool.cs index 4a9c997f..0920789c 100644 --- a/app/MindWork AI Studio/Tools/ToolCallingSystem/ToolCallingImplementations/ReadWebPageTool.cs +++ b/app/MindWork AI Studio/Tools/ToolCallingSystem/ToolCallingImplementations/ReadWebPageTool.cs @@ -1,17 +1,21 @@ +using System.Net; +using System.Net.Sockets; using System.Text.Json; using System.Text.Json.Nodes; +using AIStudio.Provider; using AIStudio.Tools.PluginSystem; using HtmlAgilityPack; namespace AIStudio.Tools.ToolCallingSystem.ToolCallingImplementations; -public sealed class ReadWebPageTool(HTMLParser htmlParser) : IToolImplementation +public sealed class ReadWebPageTool(HTMLParser htmlParser, ILogger logger) : IToolImplementation { private static string TB(string fallbackEN) => I18N.I.T(fallbackEN, typeof(ReadWebPageTool).Namespace, nameof(ReadWebPageTool)); private const int DEFAULT_TIMEOUT_SECONDS = 30; private const int DEFAULT_MAX_CONTENT_CHARACTERS = 12000; private const int MAX_TRACE_LENGTH = 12000; + private const string ALLOWED_PRIVATE_HOSTS_SETTING = "allowedPrivateHosts"; private static readonly string[] REMOVED_NODE_XPATHS = [ @@ -42,6 +46,7 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser) : IToolImplementation { "timeoutSeconds" => TB("Timeout Seconds"), "maxContentCharacters" => TB("Maximum Content Characters"), + ALLOWED_PRIVATE_HOSTS_SETTING => TB("Allowed Private Hosts"), _ => TB(fieldDefinition.Title), }; @@ -49,6 +54,7 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser) : IToolImplementation { "timeoutSeconds" => TB("Optional HTTP timeout for loading a web page in seconds."), "maxContentCharacters" => TB("Optional global truncation limit for extracted Markdown returned to the model."), + ALLOWED_PRIVATE_HOSTS_SETTING => TB("Optional host allowlist for private or VPN web pages. Separate host patterns with commas, such as example.de, *.example.de. Allowed private hosts require a High-confidence provider."), _ => TB(fieldDefinition.Description), }; @@ -75,6 +81,15 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser) : IToolImplementation }); } + if (!TryReadAllowedPrivateHostPatterns(settingsValues.GetValueOrDefault(ALLOWED_PRIVATE_HOSTS_SETTING), out _, out var allowlistError)) + { + return Task.FromResult(new ToolConfigurationState + { + IsConfigured = false, + Message = allowlistError, + }); + } + return Task.FromResult(null); } @@ -86,11 +101,17 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser) : IToolImplementation var timeoutSeconds = ReadOptionalPositiveIntSetting(context.SettingsValues, "timeoutSeconds") ?? DEFAULT_TIMEOUT_SECONDS; var maxContentCharacters = ReadOptionalPositiveIntSetting(context.SettingsValues, "maxContentCharacters") ?? DEFAULT_MAX_CONTENT_CHARACTERS; + if (!TryReadAllowedPrivateHostPatterns(context.SettingsValues.GetValueOrDefault(ALLOWED_PRIVATE_HOSTS_SETTING), out var allowedPrivateHosts, out var allowlistError)) + throw new InvalidOperationException(allowlistError); HTMLParserWebPage page; try { - page = await htmlParser.LoadWebPageAsync(url, token, timeoutSeconds); + page = await htmlParser.LoadWebPageAsync( + url, + token, + timeoutSeconds, + async (candidateUrl, validationToken) => await this.ValidateUrlAccessAsync(candidateUrl, allowedPrivateHosts, context.ProviderConfidence, validationToken)); } catch (OperationCanceledException) when (!token.IsCancellationRequested) { @@ -162,6 +183,178 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser) : IToolImplementation return $"{rawResult[..MAX_TRACE_LENGTH]}..."; } + private async Task ValidateUrlAccessAsync( + Uri url, + IReadOnlyList allowedPrivateHosts, + ConfidenceLevel providerConfidence, + CancellationToken token) + { + if (url is not { Scheme: "http" or "https" }) + throw new ToolExecutionBlockedException("Only HTTP and HTTPS URLs are supported."); + + if (IsBlockedHostName(url.Host)) + throw new ToolExecutionBlockedException("Local web page URLs are not supported."); + + var addresses = await ResolveHostAddressesAsync(url, token); + if (addresses.Count == 0) + throw new InvalidOperationException($"The host '{url.Host}' did not resolve to an IP address."); + + if (addresses.Any(IsNeverAllowedAddress)) + throw new ToolExecutionBlockedException("Local, link-local, multicast, and unspecified network addresses are not supported."); + + if (!addresses.Any(IsNonPublicAddress)) + return; + + if (!IsAllowedPrivateHost(url.Host, allowedPrivateHosts)) + throw new ToolExecutionBlockedException("Private or local-network web page URLs are not supported unless their host is explicitly allowed."); + + if (providerConfidence >= ConfidenceLevel.HIGH) + return; + + await this.ReportPrivateHostProviderBlockAsync(url, providerConfidence); + throw new ToolExecutionBlockedException("This private or VPN web page requires a High-confidence provider."); + } + + private async Task ReportPrivateHostProviderBlockAsync(Uri url, ConfidenceLevel providerConfidence) + { + logger.LogWarning( + "Blocked read_web_page access to allowed private host '{Host}' because provider confidence '{ProviderConfidence}' is below HIGH.", + url.Host, + providerConfidence); + + await MessageBus.INSTANCE.SendError(new DataErrorMessage( + Icons.Material.Filled.Security, + TB("The web page was not loaded because private or VPN web pages require a High-confidence provider."))); + } + + private static async Task> ResolveHostAddressesAsync(Uri url, CancellationToken token) + { + if (IPAddress.TryParse(url.Host, out var parsedAddress)) + return [NormalizeAddress(parsedAddress)]; + + try + { + return (await Dns.GetHostAddressesAsync(url.DnsSafeHost, token)) + .Select(NormalizeAddress) + .ToList(); + } + catch (SocketException exception) + { + throw new InvalidOperationException($"The host '{url.Host}' could not be resolved: {exception.Message}", exception); + } + } + + private static IPAddress NormalizeAddress(IPAddress address) => address.IsIPv4MappedToIPv6 ? address.MapToIPv4() : address; + + private static bool IsBlockedHostName(string host) + { + var normalizedHost = NormalizeHost(host); + return normalizedHost is "localhost" || + normalizedHost.EndsWith(".localhost", StringComparison.Ordinal); + } + + private static bool IsAllowedPrivateHost(string host, IReadOnlyList allowedPrivateHosts) + { + var normalizedHost = NormalizeHost(host); + return allowedPrivateHosts.Any(pattern => pattern.IsMatch(normalizedHost)); + } + + private static string NormalizeHost(string host) => host.Trim().TrimEnd('.').ToLowerInvariant(); + + private static bool IsNeverAllowedAddress(IPAddress address) + { + address = NormalizeAddress(address); + if (IPAddress.IsLoopback(address)) + return true; + + if (address.AddressFamily is AddressFamily.InterNetwork) + { + var bytes = address.GetAddressBytes(); + return address.Equals(IPAddress.Any) || + bytes[0] is 0 or 127 or >= 224 || + (bytes[0] == 169 && bytes[1] == 254); + } + + if (address.AddressFamily is AddressFamily.InterNetworkV6) + { + return address.Equals(IPAddress.IPv6Any) || + address.Equals(IPAddress.IPv6None) || + address.Equals(IPAddress.IPv6Loopback) || + address.IsIPv6LinkLocal || + address.IsIPv6Multicast; + } + + return true; + } + + private static bool IsNonPublicAddress(IPAddress address) + { + address = NormalizeAddress(address); + if (IsNeverAllowedAddress(address)) + return true; + + if (address.AddressFamily is AddressFamily.InterNetwork) + { + var bytes = address.GetAddressBytes(); + return bytes[0] == 10 || // Private network: 10.0.0.0/8 + (bytes[0] == 100 && bytes[1] is >= 64 and <= 127) || // Carrier-grade NAT: 100.64.0.0/10 + (bytes[0] == 172 && bytes[1] is >= 16 and <= 31) || // Private network: 172.16.0.0/12 + (bytes[0] == 192 && bytes[1] == 168) || // Private network: 192.168.0.0/16 + (bytes[0] == 192 && bytes[1] == 0 && bytes[2] == 0) || // IETF protocol assignments: 192.0.0.0/24 + (bytes[0] == 192 && bytes[1] == 0 && bytes[2] == 2) || // Documentation range: 192.0.2.0/24 + (bytes[0] == 198 && bytes[1] is 18 or 19) || // Benchmark testing range: 198.18.0.0/15 + (bytes[0] == 198 && bytes[1] == 51 && bytes[2] == 100) || // Documentation range: 198.51.100.0/24 + (bytes[0] == 203 && bytes[1] == 0 && bytes[2] == 113); // Documentation range: 203.0.113.0/24 + } + + if (address.AddressFamily is AddressFamily.InterNetworkV6) + { + var bytes = address.GetAddressBytes(); + return (bytes[0] & 0xfe) == 0xfc || // Unique local addresses: fc00::/7 + address.IsIPv6SiteLocal; // Deprecated site-local addresses: fec0::/10 + } + + return true; + } + + private static bool TryReadAllowedPrivateHostPatterns( + string? rawValue, + out List patterns, + out string error) + { + patterns = []; + error = string.Empty; + + foreach (var rawPattern in SplitAllowedPrivateHostPatterns(rawValue)) + { + var pattern = NormalizeHost(rawPattern); + if (pattern.Contains("://", StringComparison.Ordinal) || pattern.Contains('/')) + { + error = TB("Allowed private hosts must be host names only, without scheme or path."); + return false; + } + + var isWildcard = pattern.StartsWith("*.", StringComparison.Ordinal); + var host = isWildcard ? pattern[2..] : pattern; + if (string.IsNullOrWhiteSpace(host) || Uri.CheckHostName(host) is UriHostNameType.Unknown) + { + error = string.Format(TB("Allowed private host '{0}' is not valid."), rawPattern); + return false; + } + + patterns.Add(new AllowedPrivateHostPattern(host, isWildcard)); + } + + patterns = patterns + .Distinct() + .ToList(); + return true; + } + + private static IEnumerable SplitAllowedPrivateHostPatterns(string? rawValue) => rawValue? + .Split(['\r', '\n', ',', ';'], StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) + .Where(x => !string.IsNullOrWhiteSpace(x)) ?? []; + private static void RemoveNoiseNodes(HtmlNode rootNode) { foreach (var xpath in REMOVED_NODE_XPATHS) @@ -221,4 +414,12 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser) : IToolImplementation error = I18N.I.T($"The setting '{key}' must be a positive integer.", typeof(ReadWebPageTool).Namespace, nameof(ReadWebPageTool)); return false; } + + private readonly record struct AllowedPrivateHostPattern(string Host, bool IsWildcard) + { + public bool IsMatch(string normalizedHost) => + this.IsWildcard + ? normalizedHost.EndsWith($".{this.Host}", StringComparison.Ordinal) && normalizedHost.Length > this.Host.Length + 1 + : normalizedHost.Equals(this.Host, StringComparison.Ordinal); + } }