mirror of
https://github.com/MindWorkAI/AI-Studio.git
synced 2026-06-27 14:36:27 +00:00
Resolve target host addresses before connecting, then bind HTTP connection to those validated IPs.
Prevents request from re-resolving the host after validation
This commit is contained in:
parent
e50c67182c
commit
cf6256c215
@ -1,9 +1,7 @@
|
||||
using System.Net;
|
||||
using System.Net.Http;
|
||||
using System.Net.Http.Headers;
|
||||
|
||||
using System.Net.Sockets;
|
||||
using HtmlAgilityPack;
|
||||
|
||||
using ReverseMarkdown;
|
||||
|
||||
namespace AIStudio.Tools;
|
||||
@ -44,13 +42,20 @@ public sealed class HTMLParser
|
||||
return innerHtml;
|
||||
}
|
||||
|
||||
public async Task<HTMLParserWebPage> LoadWebPageAsync(Uri url, CancellationToken token = default, int timeoutSeconds = 30, Func<Uri, CancellationToken, Task>? validateUrlAsync = null)
|
||||
public async Task<HTMLParserWebPage> LoadWebPageAsync(Uri url, CancellationToken token = default, int timeoutSeconds = 30, Func<Uri, CancellationToken, Task<IReadOnlyList<IPAddress>>>? resolveUrlAddressesAsync = null)
|
||||
{
|
||||
using var handler = new HttpClientHandler
|
||||
using var handler = new SocketsHttpHandler
|
||||
{
|
||||
AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate | DecompressionMethods.Brotli,
|
||||
AllowAutoRedirect = false,
|
||||
};
|
||||
if (resolveUrlAddressesAsync is not null)
|
||||
{
|
||||
// The callback binds the request to a vetted target IP; a proxy would change the endpoint being connected to.
|
||||
handler.UseProxy = false;
|
||||
handler.ConnectCallback = async (context, connectionToken) => await ConnectToResolvedAddressAsync(context, resolveUrlAddressesAsync, connectionToken);
|
||||
}
|
||||
|
||||
using var httpClient = new HttpClient(handler)
|
||||
{
|
||||
Timeout = Timeout.InfiniteTimeSpan,
|
||||
@ -61,8 +66,7 @@ public sealed class HTMLParser
|
||||
var currentUrl = url;
|
||||
for (var redirectCount = 0; redirectCount <= MAX_REDIRECTS; redirectCount++)
|
||||
{
|
||||
if (validateUrlAsync is not null)
|
||||
await validateUrlAsync(currentUrl, timeoutCts.Token);
|
||||
ValidateHttpOrHttpsUrl(currentUrl);
|
||||
|
||||
using var request = CreateRequest(currentUrl);
|
||||
using var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, timeoutCts.Token);
|
||||
@ -101,6 +105,58 @@ public sealed class HTMLParser
|
||||
throw new HttpRequestException($"The server returned more than {MAX_REDIRECTS} redirects for '{url}'.");
|
||||
}
|
||||
|
||||
private static void ValidateHttpOrHttpsUrl(Uri url)
|
||||
{
|
||||
if (url.Scheme.Equals(Uri.UriSchemeHttp, StringComparison.OrdinalIgnoreCase) ||
|
||||
url.Scheme.Equals(Uri.UriSchemeHttps, StringComparison.OrdinalIgnoreCase))
|
||||
return;
|
||||
|
||||
throw new HttpRequestException($"Unsupported URL scheme '{url.Scheme}' for '{url}'.");
|
||||
}
|
||||
|
||||
private static async ValueTask<Stream> ConnectToResolvedAddressAsync(
|
||||
SocketsHttpConnectionContext context,
|
||||
Func<Uri, CancellationToken, Task<IReadOnlyList<IPAddress>>> resolveUrlAddressesAsync,
|
||||
CancellationToken token)
|
||||
{
|
||||
var requestUri = context.InitialRequestMessage.RequestUri ??
|
||||
throw new HttpRequestException("The HTTP request did not contain a target URL.");
|
||||
|
||||
var addresses = await resolveUrlAddressesAsync(requestUri, token);
|
||||
if (addresses.Count == 0)
|
||||
throw new HttpRequestException($"The host '{requestUri.Host}' did not resolve to an IP address.");
|
||||
|
||||
List<SocketException> connectionErrors = [];
|
||||
foreach (var address in addresses.Distinct())
|
||||
{
|
||||
var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp)
|
||||
{
|
||||
NoDelay = true,
|
||||
};
|
||||
|
||||
try
|
||||
{
|
||||
await socket.ConnectAsync(new IPEndPoint(address, context.DnsEndPoint.Port), token);
|
||||
return new NetworkStream(socket, ownsSocket: true);
|
||||
}
|
||||
catch (SocketException exception)
|
||||
{
|
||||
connectionErrors.Add(exception);
|
||||
socket.Dispose();
|
||||
}
|
||||
catch
|
||||
{
|
||||
socket.Dispose();
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
Exception innerException = connectionErrors.Count == 1
|
||||
? connectionErrors[0]
|
||||
: new AggregateException(connectionErrors);
|
||||
throw new HttpRequestException($"Could not connect to a validated address for '{requestUri.Host}'.", innerException);
|
||||
}
|
||||
|
||||
private static HttpRequestMessage CreateRequest(Uri url)
|
||||
{
|
||||
var request = new HttpRequestMessage(HttpMethod.Get, url);
|
||||
|
||||
@ -111,7 +111,7 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser, ILogger<ReadWebPageTo
|
||||
url,
|
||||
token,
|
||||
timeoutSeconds,
|
||||
async (candidateUrl, validationToken) => await this.ValidateUrlAccessAsync(candidateUrl, allowedPrivateHosts, context.ProviderConfidence, validationToken));
|
||||
async (candidateUrl, validationToken) => await this.ResolveValidatedUrlAddressesAsync(candidateUrl, allowedPrivateHosts, context.ProviderConfidence, validationToken));
|
||||
}
|
||||
catch (OperationCanceledException) when (!token.IsCancellationRequested)
|
||||
{
|
||||
@ -119,6 +119,9 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser, ILogger<ReadWebPageTo
|
||||
}
|
||||
catch (HttpRequestException exception)
|
||||
{
|
||||
if (FindBlockedException(exception) is { } blockedException)
|
||||
throw blockedException;
|
||||
|
||||
throw new InvalidOperationException($"Loading the web page failed: {exception.Message}", exception);
|
||||
}
|
||||
|
||||
@ -183,7 +186,24 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser, ILogger<ReadWebPageTo
|
||||
return $"{rawResult[..MAX_TRACE_LENGTH]}...";
|
||||
}
|
||||
|
||||
private async Task ValidateUrlAccessAsync(
|
||||
private static ToolExecutionBlockedException? FindBlockedException(Exception exception)
|
||||
{
|
||||
if (exception is ToolExecutionBlockedException blockedException)
|
||||
return blockedException;
|
||||
|
||||
if (exception is AggregateException aggregateException)
|
||||
{
|
||||
foreach (var innerException in aggregateException.InnerExceptions)
|
||||
{
|
||||
if (FindBlockedException(innerException) is { } innerBlockedException)
|
||||
return innerBlockedException;
|
||||
}
|
||||
}
|
||||
|
||||
return exception.InnerException is null ? null : FindBlockedException(exception.InnerException);
|
||||
}
|
||||
|
||||
private async Task<IReadOnlyList<IPAddress>> ResolveValidatedUrlAddressesAsync(
|
||||
Uri url,
|
||||
IReadOnlyList<AllowedPrivateHostPattern> allowedPrivateHosts,
|
||||
ConfidenceLevel providerConfidence,
|
||||
@ -203,13 +223,13 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser, ILogger<ReadWebPageTo
|
||||
throw new ToolExecutionBlockedException("Local, link-local, multicast, and unspecified network addresses are not supported.");
|
||||
|
||||
if (!addresses.Any(IsNonPublicAddress))
|
||||
return;
|
||||
return addresses;
|
||||
|
||||
if (!IsAllowedPrivateHost(url.Host, allowedPrivateHosts))
|
||||
throw new ToolExecutionBlockedException("Private or local-network web page URLs are not supported unless their host is explicitly allowed.");
|
||||
|
||||
if (providerConfidence >= ConfidenceLevel.HIGH)
|
||||
return;
|
||||
return addresses;
|
||||
|
||||
await this.ReportPrivateHostProviderBlockAsync(url, providerConfidence);
|
||||
throw new ToolExecutionBlockedException("This private or VPN web page requires a High-confidence provider.");
|
||||
|
||||
Loading…
Reference in New Issue
Block a user