mirror of
https://github.com/MindWorkAI/AI-Studio.git
synced 2026-06-27 17:16:28 +00:00
Resolve the target host's addresses before connecting, then bind the HTTP connection to those validated IPs.
This prevents the request from re-resolving the host after validation.
This commit is contained in:
parent
e50c67182c
commit
cf6256c215
@ -1,9 +1,7 @@
|
|||||||
using System.Net;
|
using System.Net;
|
||||||
using System.Net.Http;
|
|
||||||
using System.Net.Http.Headers;
|
using System.Net.Http.Headers;
|
||||||
|
using System.Net.Sockets;
|
||||||
using HtmlAgilityPack;
|
using HtmlAgilityPack;
|
||||||
|
|
||||||
using ReverseMarkdown;
|
using ReverseMarkdown;
|
||||||
|
|
||||||
namespace AIStudio.Tools;
|
namespace AIStudio.Tools;
|
||||||
@ -44,13 +42,20 @@ public sealed class HTMLParser
|
|||||||
return innerHtml;
|
return innerHtml;
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task<HTMLParserWebPage> LoadWebPageAsync(Uri url, CancellationToken token = default, int timeoutSeconds = 30, Func<Uri, CancellationToken, Task>? validateUrlAsync = null)
|
public async Task<HTMLParserWebPage> LoadWebPageAsync(Uri url, CancellationToken token = default, int timeoutSeconds = 30, Func<Uri, CancellationToken, Task<IReadOnlyList<IPAddress>>>? resolveUrlAddressesAsync = null)
|
||||||
{
|
{
|
||||||
using var handler = new HttpClientHandler
|
using var handler = new SocketsHttpHandler
|
||||||
{
|
{
|
||||||
AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate | DecompressionMethods.Brotli,
|
AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate | DecompressionMethods.Brotli,
|
||||||
AllowAutoRedirect = false,
|
AllowAutoRedirect = false,
|
||||||
};
|
};
|
||||||
|
if (resolveUrlAddressesAsync is not null)
|
||||||
|
{
|
||||||
|
// The callback binds the request to a vetted target IP; a proxy would change the endpoint being connected to.
|
||||||
|
handler.UseProxy = false;
|
||||||
|
handler.ConnectCallback = async (context, connectionToken) => await ConnectToResolvedAddressAsync(context, resolveUrlAddressesAsync, connectionToken);
|
||||||
|
}
|
||||||
|
|
||||||
using var httpClient = new HttpClient(handler)
|
using var httpClient = new HttpClient(handler)
|
||||||
{
|
{
|
||||||
Timeout = Timeout.InfiniteTimeSpan,
|
Timeout = Timeout.InfiniteTimeSpan,
|
||||||
@ -61,8 +66,7 @@ public sealed class HTMLParser
|
|||||||
var currentUrl = url;
|
var currentUrl = url;
|
||||||
for (var redirectCount = 0; redirectCount <= MAX_REDIRECTS; redirectCount++)
|
for (var redirectCount = 0; redirectCount <= MAX_REDIRECTS; redirectCount++)
|
||||||
{
|
{
|
||||||
if (validateUrlAsync is not null)
|
ValidateHttpOrHttpsUrl(currentUrl);
|
||||||
await validateUrlAsync(currentUrl, timeoutCts.Token);
|
|
||||||
|
|
||||||
using var request = CreateRequest(currentUrl);
|
using var request = CreateRequest(currentUrl);
|
||||||
using var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, timeoutCts.Token);
|
using var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, timeoutCts.Token);
|
||||||
@ -101,6 +105,58 @@ public sealed class HTMLParser
|
|||||||
throw new HttpRequestException($"The server returned more than {MAX_REDIRECTS} redirects for '{url}'.");
|
throw new HttpRequestException($"The server returned more than {MAX_REDIRECTS} redirects for '{url}'.");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
/// Ensures that the given URL is an absolute URL using the HTTP or HTTPS scheme.
/// </summary>
/// <param name="url">The URL that is about to be requested.</param>
/// <exception cref="HttpRequestException">
/// Thrown when the URL is relative or uses any scheme other than HTTP or HTTPS.
/// </exception>
private static void ValidateHttpOrHttpsUrl(Uri url)
{
    // Uri.Scheme throws InvalidOperationException for relative URIs. Reject them here,
    // so that every unsupported URL surfaces as the same exception type.
    if (!url.IsAbsoluteUri)
        throw new HttpRequestException($"Unsupported relative URL '{url}'.");

    if (url.Scheme.Equals(Uri.UriSchemeHttp, StringComparison.OrdinalIgnoreCase) ||
        url.Scheme.Equals(Uri.UriSchemeHttps, StringComparison.OrdinalIgnoreCase))
        return;

    throw new HttpRequestException($"Unsupported URL scheme '{url.Scheme}' for '{url}'.");
}
|
||||||
|
|
||||||
|
/// <summary>
/// Connection callback for <see cref="SocketsHttpHandler"/>: opens the TCP connection only to
/// the addresses returned by <paramref name="resolveUrlAddressesAsync"/> for the request URL.
/// </summary>
/// <param name="context">The connection context provided by the handler; supplies the initial request and the target port.</param>
/// <param name="resolveUrlAddressesAsync">Callback that returns the IP addresses to connect to for the request URL.</param>
/// <param name="token">Cancels the address resolution and the connection attempts.</param>
/// <returns>A network stream that owns the connected socket.</returns>
/// <exception cref="HttpRequestException">
/// Thrown when the request has no URL, the callback returns no addresses,
/// or none of the returned addresses could be connected to.
/// </exception>
private static async ValueTask<Stream> ConnectToResolvedAddressAsync(
    SocketsHttpConnectionContext context,
    Func<Uri, CancellationToken, Task<IReadOnlyList<IPAddress>>> resolveUrlAddressesAsync,
    CancellationToken token)
{
    var requestUri = context.InitialRequestMessage.RequestUri ??
                     throw new HttpRequestException("The HTTP request did not contain a target URL.");

    var addresses = await resolveUrlAddressesAsync(requestUri, token);
    if (addresses.Count == 0)
        throw new HttpRequestException($"The host '{requestUri.Host}' did not resolve to an IP address.");

    List<SocketException> connectionErrors = [];
    foreach (var address in addresses.Distinct())
    {
        Socket? socket = null;
        try
        {
            // Create the socket inside the try block: the constructor itself throws a
            // SocketException when the address family is unavailable on this machine.
            // That must not prevent trying the remaining addresses.
            socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
            socket.NoDelay = true;

            await socket.ConnectAsync(new IPEndPoint(address, context.DnsEndPoint.Port), token);

            // The stream takes ownership; disposing the stream closes the socket.
            return new NetworkStream(socket, ownsSocket: true);
        }
        catch (SocketException exception)
        {
            // Remember the failure and continue with the next address.
            connectionErrors.Add(exception);
            socket?.Dispose();
        }
        catch
        {
            // Anything else (e.g., cancellation) aborts the attempt; do not leak the socket.
            socket?.Dispose();
            throw;
        }
    }

    Exception innerException = connectionErrors.Count == 1
        ? connectionErrors[0]
        : new AggregateException(connectionErrors);

    throw new HttpRequestException($"Could not connect to a validated address for '{requestUri.Host}'.", innerException);
}
|
||||||
|
|
||||||
private static HttpRequestMessage CreateRequest(Uri url)
|
private static HttpRequestMessage CreateRequest(Uri url)
|
||||||
{
|
{
|
||||||
var request = new HttpRequestMessage(HttpMethod.Get, url);
|
var request = new HttpRequestMessage(HttpMethod.Get, url);
|
||||||
|
|||||||
@ -111,7 +111,7 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser, ILogger<ReadWebPageTo
|
|||||||
url,
|
url,
|
||||||
token,
|
token,
|
||||||
timeoutSeconds,
|
timeoutSeconds,
|
||||||
async (candidateUrl, validationToken) => await this.ValidateUrlAccessAsync(candidateUrl, allowedPrivateHosts, context.ProviderConfidence, validationToken));
|
async (candidateUrl, validationToken) => await this.ResolveValidatedUrlAddressesAsync(candidateUrl, allowedPrivateHosts, context.ProviderConfidence, validationToken));
|
||||||
}
|
}
|
||||||
catch (OperationCanceledException) when (!token.IsCancellationRequested)
|
catch (OperationCanceledException) when (!token.IsCancellationRequested)
|
||||||
{
|
{
|
||||||
@ -119,6 +119,9 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser, ILogger<ReadWebPageTo
|
|||||||
}
|
}
|
||||||
catch (HttpRequestException exception)
|
catch (HttpRequestException exception)
|
||||||
{
|
{
|
||||||
|
if (FindBlockedException(exception) is { } blockedException)
|
||||||
|
throw blockedException;
|
||||||
|
|
||||||
throw new InvalidOperationException($"Loading the web page failed: {exception.Message}", exception);
|
throw new InvalidOperationException($"Loading the web page failed: {exception.Message}", exception);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -183,7 +186,24 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser, ILogger<ReadWebPageTo
|
|||||||
return $"{rawResult[..MAX_TRACE_LENGTH]}...";
|
return $"{rawResult[..MAX_TRACE_LENGTH]}...";
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task ValidateUrlAccessAsync(
|
/// <summary>
/// Searches an exception and its nested exceptions for a <see cref="ToolExecutionBlockedException"/>.
/// </summary>
/// <param name="exception">The exception at which the search starts.</param>
/// <returns>The first blocked exception found, or <c>null</c> when there is none.</returns>
private static ToolExecutionBlockedException? FindBlockedException(Exception exception)
{
    switch (exception)
    {
        case ToolExecutionBlockedException blocked:
            return blocked;

        case AggregateException aggregate:
            // Search every aggregated exception in order; stop at the first match.
            var nestedMatch = aggregate.InnerExceptions
                .Select(FindBlockedException)
                .FirstOrDefault(candidate => candidate is not null);

            if (nestedMatch is not null)
                return nestedMatch;

            break;
    }

    // Continue along the regular inner-exception chain.
    if (exception.InnerException is { } nested)
        return FindBlockedException(nested);

    return null;
}
|
||||||
|
|
||||||
|
private async Task<IReadOnlyList<IPAddress>> ResolveValidatedUrlAddressesAsync(
|
||||||
Uri url,
|
Uri url,
|
||||||
IReadOnlyList<AllowedPrivateHostPattern> allowedPrivateHosts,
|
IReadOnlyList<AllowedPrivateHostPattern> allowedPrivateHosts,
|
||||||
ConfidenceLevel providerConfidence,
|
ConfidenceLevel providerConfidence,
|
||||||
@ -203,13 +223,13 @@ public sealed class ReadWebPageTool(HTMLParser htmlParser, ILogger<ReadWebPageTo
|
|||||||
throw new ToolExecutionBlockedException("Local, link-local, multicast, and unspecified network addresses are not supported.");
|
throw new ToolExecutionBlockedException("Local, link-local, multicast, and unspecified network addresses are not supported.");
|
||||||
|
|
||||||
if (!addresses.Any(IsNonPublicAddress))
|
if (!addresses.Any(IsNonPublicAddress))
|
||||||
return;
|
return addresses;
|
||||||
|
|
||||||
if (!IsAllowedPrivateHost(url.Host, allowedPrivateHosts))
|
if (!IsAllowedPrivateHost(url.Host, allowedPrivateHosts))
|
||||||
throw new ToolExecutionBlockedException("Private or local-network web page URLs are not supported unless their host is explicitly allowed.");
|
throw new ToolExecutionBlockedException("Private or local-network web page URLs are not supported unless their host is explicitly allowed.");
|
||||||
|
|
||||||
if (providerConfidence >= ConfidenceLevel.HIGH)
|
if (providerConfidence >= ConfidenceLevel.HIGH)
|
||||||
return;
|
return addresses;
|
||||||
|
|
||||||
await this.ReportPrivateHostProviderBlockAsync(url, providerConfidence);
|
await this.ReportPrivateHostProviderBlockAsync(url, providerConfidence);
|
||||||
throw new ToolExecutionBlockedException("This private or VPN web page requires a High-confidence provider.");
|
throw new ToolExecutionBlockedException("This private or VPN web page requires a High-confidence provider.");
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user