Resolve the target host's addresses before connecting, then bind the HTTP connection to those validated IPs.

This prevents the request from re-resolving the host after validation.
This commit is contained in:
Nils Kruthoff 2026-05-18 17:00:05 +02:00
parent e50c67182c
commit cf6256c215
No known key found for this signature in database
GPG Key ID: A5C0151B4DDB172C
2 changed files with 87 additions and 11 deletions

View File

@ -1,9 +1,7 @@
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Net.Sockets;
using HtmlAgilityPack;
using ReverseMarkdown;
namespace AIStudio.Tools;
@ -44,13 +42,20 @@ public sealed class HTMLParser
return innerHtml;
}
public async Task<HTMLParserWebPage> LoadWebPageAsync(Uri url, CancellationToken token = default, int timeoutSeconds = 30, Func<Uri, CancellationToken, Task>? validateUrlAsync = null)
public async Task<HTMLParserWebPage> LoadWebPageAsync(Uri url, CancellationToken token = default, int timeoutSeconds = 30, Func<Uri, CancellationToken, Task<IReadOnlyList<IPAddress>>>? resolveUrlAddressesAsync = null)
{
using var handler = new HttpClientHandler
using var handler = new SocketsHttpHandler
{
AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate | DecompressionMethods.Brotli,
AllowAutoRedirect = false,
};
if (resolveUrlAddressesAsync is not null)
{
// The callback binds the request to a vetted target IP; a proxy would change the endpoint being connected to.
handler.UseProxy = false;
handler.ConnectCallback = async (context, connectionToken) => await ConnectToResolvedAddressAsync(context, resolveUrlAddressesAsync, connectionToken);
}
using var httpClient = new HttpClient(handler)
{
Timeout = Timeout.InfiniteTimeSpan,
@ -61,8 +66,7 @@ public sealed class HTMLParser
var currentUrl = url;
for (var redirectCount = 0; redirectCount <= MAX_REDIRECTS; redirectCount++)
{
if (validateUrlAsync is not null)
await validateUrlAsync(currentUrl, timeoutCts.Token);
ValidateHttpOrHttpsUrl(currentUrl);
using var request = CreateRequest(currentUrl);
using var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, timeoutCts.Token);
@ -101,6 +105,58 @@ public sealed class HTMLParser
throw new HttpRequestException($"The server returned more than {MAX_REDIRECTS} redirects for '{url}'.");
}
/// <summary>
/// Ensures that the given URL uses either the HTTP or the HTTPS scheme.
/// </summary>
/// <param name="url">The URL whose scheme is checked.</param>
/// <exception cref="HttpRequestException">Thrown when the scheme is neither HTTP nor HTTPS.</exception>
private static void ValidateHttpOrHttpsUrl(Uri url)
{
    var scheme = url.Scheme;
    var isHttp = string.Equals(scheme, Uri.UriSchemeHttp, StringComparison.OrdinalIgnoreCase);
    var isHttps = string.Equals(scheme, Uri.UriSchemeHttps, StringComparison.OrdinalIgnoreCase);

    // Anything that is not plain HTTP(S) is rejected before a request is built.
    if (!isHttp && !isHttps)
        throw new HttpRequestException($"Unsupported URL scheme '{url.Scheme}' for '{url}'.");
}
/// <summary>
/// Opens a TCP connection for the pending HTTP request using only the IP addresses
/// returned by <paramref name="resolveUrlAddressesAsync"/>, so the connection is bound
/// to the addresses that callback produced instead of a fresh host name lookup.
/// </summary>
/// <param name="context">The connection context supplied by the HTTP handler; provides the request URL and the target port.</param>
/// <param name="resolveUrlAddressesAsync">Callback that returns the addresses the request may connect to.</param>
/// <param name="token">Token used to cancel address resolution and the connection attempts.</param>
/// <returns>A network stream that owns the connected socket.</returns>
/// <exception cref="HttpRequestException">
/// Thrown when the request has no URL, when the callback returns no addresses,
/// or when none of the returned addresses could be connected to.
/// </exception>
private static async ValueTask<Stream> ConnectToResolvedAddressAsync(
    SocketsHttpConnectionContext context,
    Func<Uri, CancellationToken, Task<IReadOnlyList<IPAddress>>> resolveUrlAddressesAsync,
    CancellationToken token)
{
    var requestUri = context.InitialRequestMessage.RequestUri ??
                     throw new HttpRequestException("The HTTP request did not contain a target URL.");

    var addresses = await resolveUrlAddressesAsync(requestUri, token);
    if (addresses.Count == 0)
        throw new HttpRequestException($"The host '{requestUri.Host}' did not resolve to an IP address.");

    List<SocketException> connectionErrors = [];
    foreach (var address in addresses.Distinct())
    {
        Socket? socket = null;
        try
        {
            // Create and configure the socket inside the try block: the constructor itself can throw
            // a SocketException (for example, when the address family is not supported on this machine).
            // Such a failure must not abort the loop; the next address should still be tried.
            socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
            socket.NoDelay = true;

            await socket.ConnectAsync(new IPEndPoint(address, context.DnsEndPoint.Port), token);
            return new NetworkStream(socket, ownsSocket: true);
        }
        catch (SocketException exception)
        {
            // Remember the failure and continue with the next address.
            connectionErrors.Add(exception);
            socket?.Dispose();
        }
        catch
        {
            // Any other failure (e.g., cancellation) is not retried; release the socket and propagate.
            socket?.Dispose();
            throw;
        }
    }

    Exception innerException = connectionErrors.Count == 1
        ? connectionErrors[0]
        : new AggregateException(connectionErrors);

    throw new HttpRequestException($"Could not connect to a validated address for '{requestUri.Host}'.", innerException);
}
private static HttpRequestMessage CreateRequest(Uri url)
{
var request = new HttpRequestMessage(HttpMethod.Get, url);

View File

@ -111,7 +111,7 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser, ILogger<ReadWebPageTo
url,
token,
timeoutSeconds,
async (candidateUrl, validationToken) => await this.ValidateUrlAccessAsync(candidateUrl, allowedPrivateHosts, context.ProviderConfidence, validationToken));
async (candidateUrl, validationToken) => await this.ResolveValidatedUrlAddressesAsync(candidateUrl, allowedPrivateHosts, context.ProviderConfidence, validationToken));
}
catch (OperationCanceledException) when (!token.IsCancellationRequested)
{
@ -119,6 +119,9 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser, ILogger<ReadWebPageTo
}
catch (HttpRequestException exception)
{
if (FindBlockedException(exception) is { } blockedException)
throw blockedException;
throw new InvalidOperationException($"Loading the web page failed: {exception.Message}", exception);
}
@ -183,7 +186,24 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser, ILogger<ReadWebPageTo
return $"{rawResult[..MAX_TRACE_LENGTH]}...";
}
private async Task ValidateUrlAccessAsync(
/// <summary>
/// Searches an exception and its nested exceptions for a <see cref="ToolExecutionBlockedException"/>.
/// </summary>
/// <param name="exception">The exception at which the search starts.</param>
/// <returns>
/// The first <see cref="ToolExecutionBlockedException"/> found, or <c>null</c> when none is present.
/// The exception itself is checked first, then the members of an <see cref="AggregateException"/>
/// in order, and finally the inner exception.
/// </returns>
private static ToolExecutionBlockedException? FindBlockedException(Exception exception)
{
    switch (exception)
    {
        case ToolExecutionBlockedException direct:
            return direct;

        case AggregateException aggregate:
            // Lazy projection: the search stops at the first member that yields a match.
            var nested = aggregate.InnerExceptions
                .Select(FindBlockedException)
                .FirstOrDefault(candidate => candidate is not null);
            if (nested is not null)
                return nested;
            break;
    }

    return exception.InnerException is { } cause
        ? FindBlockedException(cause)
        : null;
}
private async Task<IReadOnlyList<IPAddress>> ResolveValidatedUrlAddressesAsync(
Uri url,
IReadOnlyList<AllowedPrivateHostPattern> allowedPrivateHosts,
ConfidenceLevel providerConfidence,
@ -203,13 +223,13 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser, ILogger<ReadWebPageTo
throw new ToolExecutionBlockedException("Local, link-local, multicast, and unspecified network addresses are not supported.");
if (!addresses.Any(IsNonPublicAddress))
return;
return addresses;
if (!IsAllowedPrivateHost(url.Host, allowedPrivateHosts))
throw new ToolExecutionBlockedException("Private or local-network web page URLs are not supported unless their host is explicitly allowed.");
if (providerConfidence >= ConfidenceLevel.HIGH)
return;
return addresses;
await this.ReportPrivateHostProviderBlockAsync(url, providerConfidence);
throw new ToolExecutionBlockedException("This private or VPN web page requires a High-confidence provider.");