mirror of
https://github.com/MindWorkAI/AI-Studio.git
synced 2026-06-27 17:16:28 +00:00
added the actual filtering and blocking security logic so that only public hosts will be allowed and only high confidence providers are used if request goes to allowed internal hosts
This commit is contained in:
parent
948d5dec27
commit
b1f50b7b5c
@ -11,6 +11,7 @@ namespace AIStudio.Tools;
|
||||
public sealed class HTMLParser
|
||||
{
|
||||
private const string USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) MindWorkAIStudio/1.0";
|
||||
private const int MAX_REDIRECTS = 10;
|
||||
|
||||
private static readonly Config MARKDOWN_PARSER_CONFIG = new()
|
||||
{
|
||||
@ -43,11 +44,12 @@ public sealed class HTMLParser
|
||||
return innerHtml;
|
||||
}
|
||||
|
||||
public async Task<HTMLParserWebPage> LoadWebPageAsync(Uri url, CancellationToken token = default, int timeoutSeconds = 30)
|
||||
public async Task<HTMLParserWebPage> LoadWebPageAsync(Uri url, CancellationToken token = default, int timeoutSeconds = 30, Func<Uri, CancellationToken, Task>? validateUrlAsync = null)
|
||||
{
|
||||
using var handler = new HttpClientHandler
|
||||
{
|
||||
AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate | DecompressionMethods.Brotli,
|
||||
AllowAutoRedirect = false,
|
||||
};
|
||||
using var httpClient = new HttpClient(handler)
|
||||
{
|
||||
@ -55,7 +57,53 @@ public sealed class HTMLParser
|
||||
};
|
||||
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(token);
|
||||
timeoutCts.CancelAfter(TimeSpan.FromSeconds(timeoutSeconds));
|
||||
using var request = new HttpRequestMessage(HttpMethod.Get, url);
|
||||
|
||||
var currentUrl = url;
|
||||
for (var redirectCount = 0; redirectCount <= MAX_REDIRECTS; redirectCount++)
|
||||
{
|
||||
if (validateUrlAsync is not null)
|
||||
await validateUrlAsync(currentUrl, timeoutCts.Token);
|
||||
|
||||
using var request = CreateRequest(currentUrl);
|
||||
using var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, timeoutCts.Token);
|
||||
if (IsRedirect(response.StatusCode))
|
||||
{
|
||||
if (response.Headers.Location is null)
|
||||
throw new HttpRequestException($"The server returned a redirect without a Location header for '{currentUrl}'.", null, response.StatusCode);
|
||||
|
||||
currentUrl = response.Headers.Location.IsAbsoluteUri
|
||||
? response.Headers.Location
|
||||
: new Uri(currentUrl, response.Headers.Location);
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!response.IsSuccessStatusCode)
|
||||
{
|
||||
var statusCode = (int)response.StatusCode;
|
||||
var reasonPhrase = string.IsNullOrWhiteSpace(response.ReasonPhrase) ? "Unknown" : response.ReasonPhrase;
|
||||
throw new HttpRequestException($"The server returned HTTP {statusCode} ({reasonPhrase}) for '{currentUrl}'.", null, response.StatusCode);
|
||||
}
|
||||
|
||||
var html = await response.Content.ReadAsStringAsync(timeoutCts.Token);
|
||||
var document = new HtmlDocument();
|
||||
document.LoadHtml(html);
|
||||
|
||||
return new HTMLParserWebPage
|
||||
{
|
||||
RequestedUrl = url,
|
||||
FinalUrl = response.RequestMessage?.RequestUri ?? currentUrl,
|
||||
ContentType = response.Content.Headers.ContentType?.MediaType ?? string.Empty,
|
||||
Document = document,
|
||||
};
|
||||
}
|
||||
|
||||
throw new HttpRequestException($"The server returned more than {MAX_REDIRECTS} redirects for '{url}'.");
|
||||
}
|
||||
|
||||
private static HttpRequestMessage CreateRequest(Uri url)
|
||||
{
|
||||
var request = new HttpRequestMessage(HttpMethod.Get, url);
|
||||
request.Headers.TryAddWithoutValidation("User-Agent", USER_AGENT);
|
||||
request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/html"));
|
||||
request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/xhtml+xml"));
|
||||
@ -69,28 +117,11 @@ public sealed class HTMLParser
|
||||
request.Headers.TryAddWithoutValidation("Sec-Fetch-Mode", "navigate");
|
||||
request.Headers.TryAddWithoutValidation("Sec-Fetch-Dest", "document");
|
||||
request.Headers.TryAddWithoutValidation("Sec-Fetch-User", "?1");
|
||||
|
||||
using var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, timeoutCts.Token);
|
||||
if (!response.IsSuccessStatusCode)
|
||||
{
|
||||
var statusCode = (int)response.StatusCode;
|
||||
var reasonPhrase = string.IsNullOrWhiteSpace(response.ReasonPhrase) ? "Unknown" : response.ReasonPhrase;
|
||||
throw new HttpRequestException($"The server returned HTTP {statusCode} ({reasonPhrase}) for '{url}'.", null, response.StatusCode);
|
||||
}
|
||||
|
||||
var html = await response.Content.ReadAsStringAsync(token);
|
||||
var document = new HtmlDocument();
|
||||
document.LoadHtml(html);
|
||||
|
||||
return new HTMLParserWebPage
|
||||
{
|
||||
RequestedUrl = url,
|
||||
FinalUrl = response.RequestMessage?.RequestUri ?? url,
|
||||
ContentType = response.Content.Headers.ContentType?.MediaType ?? string.Empty,
|
||||
Document = document,
|
||||
};
|
||||
return request;
|
||||
}
|
||||
|
||||
private static bool IsRedirect(HttpStatusCode statusCode) => (int)statusCode is >= 300 and <= 399;
|
||||
|
||||
public string ExtractTitle(HtmlDocument document)
|
||||
{
|
||||
var title = document.DocumentNode.SelectSingleNode("//title")?.InnerText?.Trim();
|
||||
|
||||
@ -1,17 +1,21 @@
|
||||
using System.Net;
|
||||
using System.Net.Sockets;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Nodes;
|
||||
using AIStudio.Provider;
|
||||
using AIStudio.Tools.PluginSystem;
|
||||
using HtmlAgilityPack;
|
||||
|
||||
namespace AIStudio.Tools.ToolCallingSystem.ToolCallingImplementations;
|
||||
|
||||
public sealed class ReadWebPageTool(HTMLParser htmlParser) : IToolImplementation
|
||||
public sealed class ReadWebPageTool(HTMLParser htmlParser, ILogger<ReadWebPageTool> logger) : IToolImplementation
|
||||
{
|
||||
private static string TB(string fallbackEN) => I18N.I.T(fallbackEN, typeof(ReadWebPageTool).Namespace, nameof(ReadWebPageTool));
|
||||
|
||||
private const int DEFAULT_TIMEOUT_SECONDS = 30;
|
||||
private const int DEFAULT_MAX_CONTENT_CHARACTERS = 12000;
|
||||
private const int MAX_TRACE_LENGTH = 12000;
|
||||
private const string ALLOWED_PRIVATE_HOSTS_SETTING = "allowedPrivateHosts";
|
||||
|
||||
private static readonly string[] REMOVED_NODE_XPATHS =
|
||||
[
|
||||
@ -42,6 +46,7 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser) : IToolImplementation
|
||||
{
|
||||
"timeoutSeconds" => TB("Timeout Seconds"),
|
||||
"maxContentCharacters" => TB("Maximum Content Characters"),
|
||||
ALLOWED_PRIVATE_HOSTS_SETTING => TB("Allowed Private Hosts"),
|
||||
_ => TB(fieldDefinition.Title),
|
||||
};
|
||||
|
||||
@ -49,6 +54,7 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser) : IToolImplementation
|
||||
{
|
||||
"timeoutSeconds" => TB("Optional HTTP timeout for loading a web page in seconds."),
|
||||
"maxContentCharacters" => TB("Optional global truncation limit for extracted Markdown returned to the model."),
|
||||
ALLOWED_PRIVATE_HOSTS_SETTING => TB("Optional host allowlist for private or VPN web pages. Separate host patterns with commas, such as example.de, *.example.de. Allowed private hosts require a High-confidence provider."),
|
||||
_ => TB(fieldDefinition.Description),
|
||||
};
|
||||
|
||||
@ -75,6 +81,15 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser) : IToolImplementation
|
||||
});
|
||||
}
|
||||
|
||||
if (!TryReadAllowedPrivateHostPatterns(settingsValues.GetValueOrDefault(ALLOWED_PRIVATE_HOSTS_SETTING), out _, out var allowlistError))
|
||||
{
|
||||
return Task.FromResult<ToolConfigurationState?>(new ToolConfigurationState
|
||||
{
|
||||
IsConfigured = false,
|
||||
Message = allowlistError,
|
||||
});
|
||||
}
|
||||
|
||||
return Task.FromResult<ToolConfigurationState?>(null);
|
||||
}
|
||||
|
||||
@ -86,11 +101,17 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser) : IToolImplementation
|
||||
|
||||
var timeoutSeconds = ReadOptionalPositiveIntSetting(context.SettingsValues, "timeoutSeconds") ?? DEFAULT_TIMEOUT_SECONDS;
|
||||
var maxContentCharacters = ReadOptionalPositiveIntSetting(context.SettingsValues, "maxContentCharacters") ?? DEFAULT_MAX_CONTENT_CHARACTERS;
|
||||
if (!TryReadAllowedPrivateHostPatterns(context.SettingsValues.GetValueOrDefault(ALLOWED_PRIVATE_HOSTS_SETTING), out var allowedPrivateHosts, out var allowlistError))
|
||||
throw new InvalidOperationException(allowlistError);
|
||||
|
||||
HTMLParserWebPage page;
|
||||
try
|
||||
{
|
||||
page = await htmlParser.LoadWebPageAsync(url, token, timeoutSeconds);
|
||||
page = await htmlParser.LoadWebPageAsync(
|
||||
url,
|
||||
token,
|
||||
timeoutSeconds,
|
||||
async (candidateUrl, validationToken) => await this.ValidateUrlAccessAsync(candidateUrl, allowedPrivateHosts, context.ProviderConfidence, validationToken));
|
||||
}
|
||||
catch (OperationCanceledException) when (!token.IsCancellationRequested)
|
||||
{
|
||||
@ -162,6 +183,178 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser) : IToolImplementation
|
||||
return $"{rawResult[..MAX_TRACE_LENGTH]}...";
|
||||
}
|
||||
|
||||
private async Task ValidateUrlAccessAsync(
|
||||
Uri url,
|
||||
IReadOnlyList<AllowedPrivateHostPattern> allowedPrivateHosts,
|
||||
ConfidenceLevel providerConfidence,
|
||||
CancellationToken token)
|
||||
{
|
||||
if (url is not { Scheme: "http" or "https" })
|
||||
throw new ToolExecutionBlockedException("Only HTTP and HTTPS URLs are supported.");
|
||||
|
||||
if (IsBlockedHostName(url.Host))
|
||||
throw new ToolExecutionBlockedException("Local web page URLs are not supported.");
|
||||
|
||||
var addresses = await ResolveHostAddressesAsync(url, token);
|
||||
if (addresses.Count == 0)
|
||||
throw new InvalidOperationException($"The host '{url.Host}' did not resolve to an IP address.");
|
||||
|
||||
if (addresses.Any(IsNeverAllowedAddress))
|
||||
throw new ToolExecutionBlockedException("Local, link-local, multicast, and unspecified network addresses are not supported.");
|
||||
|
||||
if (!addresses.Any(IsNonPublicAddress))
|
||||
return;
|
||||
|
||||
if (!IsAllowedPrivateHost(url.Host, allowedPrivateHosts))
|
||||
throw new ToolExecutionBlockedException("Private or local-network web page URLs are not supported unless their host is explicitly allowed.");
|
||||
|
||||
if (providerConfidence >= ConfidenceLevel.HIGH)
|
||||
return;
|
||||
|
||||
await this.ReportPrivateHostProviderBlockAsync(url, providerConfidence);
|
||||
throw new ToolExecutionBlockedException("This private or VPN web page requires a High-confidence provider.");
|
||||
}
|
||||
|
||||
private async Task ReportPrivateHostProviderBlockAsync(Uri url, ConfidenceLevel providerConfidence)
|
||||
{
|
||||
logger.LogWarning(
|
||||
"Blocked read_web_page access to allowed private host '{Host}' because provider confidence '{ProviderConfidence}' is below HIGH.",
|
||||
url.Host,
|
||||
providerConfidence);
|
||||
|
||||
await MessageBus.INSTANCE.SendError(new DataErrorMessage(
|
||||
Icons.Material.Filled.Security,
|
||||
TB("The web page was not loaded because private or VPN web pages require a High-confidence provider.")));
|
||||
}
|
||||
|
||||
private static async Task<IReadOnlyList<IPAddress>> ResolveHostAddressesAsync(Uri url, CancellationToken token)
|
||||
{
|
||||
if (IPAddress.TryParse(url.Host, out var parsedAddress))
|
||||
return [NormalizeAddress(parsedAddress)];
|
||||
|
||||
try
|
||||
{
|
||||
return (await Dns.GetHostAddressesAsync(url.DnsSafeHost, token))
|
||||
.Select(NormalizeAddress)
|
||||
.ToList();
|
||||
}
|
||||
catch (SocketException exception)
|
||||
{
|
||||
throw new InvalidOperationException($"The host '{url.Host}' could not be resolved: {exception.Message}", exception);
|
||||
}
|
||||
}
|
||||
|
||||
private static IPAddress NormalizeAddress(IPAddress address) => address.IsIPv4MappedToIPv6 ? address.MapToIPv4() : address;
|
||||
|
||||
private static bool IsBlockedHostName(string host)
|
||||
{
|
||||
var normalizedHost = NormalizeHost(host);
|
||||
return normalizedHost is "localhost" ||
|
||||
normalizedHost.EndsWith(".localhost", StringComparison.Ordinal);
|
||||
}
|
||||
|
||||
private static bool IsAllowedPrivateHost(string host, IReadOnlyList<AllowedPrivateHostPattern> allowedPrivateHosts)
|
||||
{
|
||||
var normalizedHost = NormalizeHost(host);
|
||||
return allowedPrivateHosts.Any(pattern => pattern.IsMatch(normalizedHost));
|
||||
}
|
||||
|
||||
private static string NormalizeHost(string host) => host.Trim().TrimEnd('.').ToLowerInvariant();
|
||||
|
||||
private static bool IsNeverAllowedAddress(IPAddress address)
|
||||
{
|
||||
address = NormalizeAddress(address);
|
||||
if (IPAddress.IsLoopback(address))
|
||||
return true;
|
||||
|
||||
if (address.AddressFamily is AddressFamily.InterNetwork)
|
||||
{
|
||||
var bytes = address.GetAddressBytes();
|
||||
return address.Equals(IPAddress.Any) ||
|
||||
bytes[0] is 0 or 127 or >= 224 ||
|
||||
(bytes[0] == 169 && bytes[1] == 254);
|
||||
}
|
||||
|
||||
if (address.AddressFamily is AddressFamily.InterNetworkV6)
|
||||
{
|
||||
return address.Equals(IPAddress.IPv6Any) ||
|
||||
address.Equals(IPAddress.IPv6None) ||
|
||||
address.Equals(IPAddress.IPv6Loopback) ||
|
||||
address.IsIPv6LinkLocal ||
|
||||
address.IsIPv6Multicast;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private static bool IsNonPublicAddress(IPAddress address)
|
||||
{
|
||||
address = NormalizeAddress(address);
|
||||
if (IsNeverAllowedAddress(address))
|
||||
return true;
|
||||
|
||||
if (address.AddressFamily is AddressFamily.InterNetwork)
|
||||
{
|
||||
var bytes = address.GetAddressBytes();
|
||||
return bytes[0] == 10 || // Private network: 10.0.0.0/8
|
||||
(bytes[0] == 100 && bytes[1] is >= 64 and <= 127) || // Carrier-grade NAT: 100.64.0.0/10
|
||||
(bytes[0] == 172 && bytes[1] is >= 16 and <= 31) || // Private network: 172.16.0.0/12
|
||||
(bytes[0] == 192 && bytes[1] == 168) || // Private network: 192.168.0.0/16
|
||||
(bytes[0] == 192 && bytes[1] == 0 && bytes[2] == 0) || // IETF protocol assignments: 192.0.0.0/24
|
||||
(bytes[0] == 192 && bytes[1] == 0 && bytes[2] == 2) || // Documentation range: 192.0.2.0/24
|
||||
(bytes[0] == 198 && bytes[1] is 18 or 19) || // Benchmark testing range: 198.18.0.0/15
|
||||
(bytes[0] == 198 && bytes[1] == 51 && bytes[2] == 100) || // Documentation range: 198.51.100.0/24
|
||||
(bytes[0] == 203 && bytes[1] == 0 && bytes[2] == 113); // Documentation range: 203.0.113.0/24
|
||||
}
|
||||
|
||||
if (address.AddressFamily is AddressFamily.InterNetworkV6)
|
||||
{
|
||||
var bytes = address.GetAddressBytes();
|
||||
return (bytes[0] & 0xfe) == 0xfc || // Unique local addresses: fc00::/7
|
||||
address.IsIPv6SiteLocal; // Deprecated site-local addresses: fec0::/10
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private static bool TryReadAllowedPrivateHostPatterns(
|
||||
string? rawValue,
|
||||
out List<AllowedPrivateHostPattern> patterns,
|
||||
out string error)
|
||||
{
|
||||
patterns = [];
|
||||
error = string.Empty;
|
||||
|
||||
foreach (var rawPattern in SplitAllowedPrivateHostPatterns(rawValue))
|
||||
{
|
||||
var pattern = NormalizeHost(rawPattern);
|
||||
if (pattern.Contains("://", StringComparison.Ordinal) || pattern.Contains('/'))
|
||||
{
|
||||
error = TB("Allowed private hosts must be host names only, without scheme or path.");
|
||||
return false;
|
||||
}
|
||||
|
||||
var isWildcard = pattern.StartsWith("*.", StringComparison.Ordinal);
|
||||
var host = isWildcard ? pattern[2..] : pattern;
|
||||
if (string.IsNullOrWhiteSpace(host) || Uri.CheckHostName(host) is UriHostNameType.Unknown)
|
||||
{
|
||||
error = string.Format(TB("Allowed private host '{0}' is not valid."), rawPattern);
|
||||
return false;
|
||||
}
|
||||
|
||||
patterns.Add(new AllowedPrivateHostPattern(host, isWildcard));
|
||||
}
|
||||
|
||||
patterns = patterns
|
||||
.Distinct()
|
||||
.ToList();
|
||||
return true;
|
||||
}
|
||||
|
||||
private static IEnumerable<string> SplitAllowedPrivateHostPatterns(string? rawValue) => rawValue?
|
||||
.Split(['\r', '\n', ',', ';'], StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries)
|
||||
.Where(x => !string.IsNullOrWhiteSpace(x)) ?? [];
|
||||
|
||||
private static void RemoveNoiseNodes(HtmlNode rootNode)
|
||||
{
|
||||
foreach (var xpath in REMOVED_NODE_XPATHS)
|
||||
@ -221,4 +414,12 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser) : IToolImplementation
|
||||
error = I18N.I.T($"The setting '{key}' must be a positive integer.", typeof(ReadWebPageTool).Namespace, nameof(ReadWebPageTool));
|
||||
return false;
|
||||
}
|
||||
|
||||
private readonly record struct AllowedPrivateHostPattern(string Host, bool IsWildcard)
|
||||
{
|
||||
public bool IsMatch(string normalizedHost) =>
|
||||
this.IsWildcard
|
||||
? normalizedHost.EndsWith($".{this.Host}", StringComparison.Ordinal) && normalizedHost.Length > this.Host.Length + 1
|
||||
: normalizedHost.Equals(this.Host, StringComparison.Ordinal);
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user