diff --git a/composer.json b/composer.json index 42db14c2..62dea24c 100644 --- a/composer.json +++ b/composer.json @@ -36,7 +36,7 @@ "amphp/hpack": "^3", "amphp/http": "^2", "amphp/pipeline": "^1", - "amphp/socket": "^2", + "amphp/socket": "^2.1", "amphp/sync": "^2", "league/uri": "^6", "league/uri-interfaces": "^2.3", diff --git a/examples/hello-world.php b/examples/hello-world.php index acebf670..8074439e 100644 --- a/examples/hello-world.php +++ b/examples/hello-world.php @@ -6,6 +6,7 @@ use Amp\ByteStream; use Amp\Http\HttpStatus; use Amp\Http\Server\DefaultErrorHandler; +use Amp\Http\Server\Driver\SocketClientFactory; use Amp\Http\Server\Request; use Amp\Http\Server\RequestHandler; use Amp\Http\Server\Response; @@ -32,7 +33,11 @@ $logger->pushHandler($logHandler); $logger->useLoggingLoopDetection(false); -$server = SocketHttpServer::createForBehindProxy($logger); +$server = new SocketHttpServer( + $logger, + new Socket\ResourceServerSocketFactory(), + new SocketClientFactory($logger), +); $server->expose("0.0.0.0:1337"); $server->expose("[::]:1337"); diff --git a/examples/state.php b/examples/state.php index 66304db9..bdd55518 100644 --- a/examples/state.php +++ b/examples/state.php @@ -24,7 +24,7 @@ $logger = new Logger('server'); $logger->pushHandler($logHandler); -$server = SocketHttpServer::createForBehindProxy($logger); +$server = SocketHttpServer::createForDirectAccess($logger); $server->expose("0.0.0.0:1337"); $server->expose("[::]:1337"); diff --git a/src/Driver/ConnectionLimitingClientFactory.php b/src/Driver/ConnectionLimitingClientFactory.php index a7b46cb9..0c680476 100644 --- a/src/Driver/ConnectionLimitingClientFactory.php +++ b/src/Driver/ConnectionLimitingClientFactory.php @@ -15,7 +15,7 @@ final class ConnectionLimitingClientFactory implements ClientFactory public function __construct( private readonly ClientFactory $clientFactory, private readonly PsrLogger $logger, - private readonly int $connectionsPerIpLimit = 10, + private readonly int $connectionsPerIpLimit, ) { } diff --git a/src/Driver/Http2Driver.php b/src/Driver/Http2Driver.php index 07bc6aa8..3b94d995 100644 --- a/src/Driver/Http2Driver.php +++ b/src/Driver/Http2Driver.php @@ -73,6 +73,7 @@ final class Http2Driver extends AbstractHttpDriver implements Http2Processor private int $initialWindowSize = self::DEFAULT_WINDOW_SIZE; + /** @var positive-int */ private int $maxFrameSize = self::DEFAULT_MAX_FRAME_SIZE; private bool $allowsPush; @@ -569,7 +570,6 @@ private function writeBufferedData(int $id): void if ($length > $this->maxFrameSize) { $split = \str_split($stream->buffer, $this->maxFrameSize); - \assert(\is_array($split)); // For Psalm $stream->buffer = \array_pop($split); foreach ($split as $part) { $this->writeFrame($part, Http2Parser::DATA, Http2Parser::NO_FLAG, $id); @@ -626,7 +626,6 @@ private function writeHeaders(string $headers, int $type, int $flags, int $id): // Header frames must be sent as one contiguous block without frames from any other stream being // interleaved between due to HPack. See https://datatracker.ietf.org/doc/html/rfc7540#section-4.3 $split = \str_split($headers, $this->maxFrameSize); - \assert(\is_array($split)); // For Psalm $headers = \array_pop($split); $writeFrame = $this->writeFrame(...); diff --git a/src/Middleware/Forwarded.php b/src/Middleware/Forwarded.php new file mode 100644 index 00000000..804172bb --- /dev/null +++ b/src/Middleware/Forwarded.php @@ -0,0 +1,27 @@ + $fields + */ + public function __construct( + private readonly InternetAddress $for, + private readonly array $fields, + ) { + } + + public function getFor(): InternetAddress + { + return $this->for; + } + + public function getField(string $name): ?string + { + return $this->fields[$name] ?? null; + } +} diff --git a/src/Middleware/ForwardedHeaderType.php b/src/Middleware/ForwardedHeaderType.php new file mode 100644 index 00000000..2fde874e --- /dev/null +++ b/src/Middleware/ForwardedHeaderType.php @@ -0,0 +1,20 @@ + 'forwarded', + self::XForwardedFor => 'x-forwarded-for', + }; + } +} diff --git a/src/Middleware/ForwardedMiddleware.php b/src/Middleware/ForwardedMiddleware.php new file mode 100644 index 00000000..276b9b00 --- /dev/null +++ b/src/Middleware/ForwardedMiddleware.php @@ -0,0 +1,137 @@ + */ + private readonly array $trustedProxies; + + /** @var LocalCache */ + private readonly LocalCache $trustedIps; + + /** + * @param array $trustedProxies Array of IPv4 or IPv6 addresses with an optional subnet mask. + * e.g., '172.18.0.0/24' + * @param positive-int $cacheSize + */ + public function __construct( + private readonly ForwardedHeaderType $headerType, + array $trustedProxies, + int $cacheSize = 1000, + ) { + $this->trustedProxies = \array_map( + static fn (string $ip) => new CidrMatcher($ip), + \array_values($trustedProxies), + ); + + $this->trustedIps = new LocalCache($cacheSize); + } + + public function handleRequest(Request $request, RequestHandler $requestHandler): Response + { + $clientAddress = $request->getClient()->getRemoteAddress(); + + if ($clientAddress instanceof InternetAddress && $this->isTrustedProxy($clientAddress)) { + $request->setAttribute(Forwarded::class, match ($this->headerType) { + ForwardedHeaderType::Forwarded => $this->getForwarded($request), + ForwardedHeaderType::XForwardedFor => $this->getForwardedFor($request), + }); + } + + return $requestHandler->handleRequest($request); + } + + private function isTrustedProxy(InternetAddress $address): bool + { + $ip = $address->getAddress(); + $trusted = $this->trustedIps->get($ip); + if ($trusted !== null) { + return $trusted; + } + + $trusted = false; + foreach ($this->trustedProxies as $matcher) { + if ($matcher->match($ip)) { + $trusted = true; + break; + } + } + + $this->trustedIps->set($ip, $trusted); + + return $trusted; + } + + private function getForwarded(Request $request): ?Forwarded + { + $headers = Http\parseMultipleHeaderFields($request, 'forwarded'); + if (!$headers) { + return null; + } + + foreach (\array_reverse($headers) as $header) { + $for = $header['for'] ?? null; + if ($for === null) { + continue; + } + + $address = InternetAddress::tryFromString($this->addPortIfMissing($for)); + if (!$address || $this->isTrustedProxy($address)) { + continue; + } + + return new Forwarded($address, $header); + } + + return null; + } + + private function addPortIfMissing(string $address): string + { + if (!\str_contains($address, ':') || \str_ends_with($address, ']')) { + $address .= ':0'; + } + + return $address; + } + + private function getForwardedFor(Request $request): ?Forwarded + { + $forwardedFor = Http\splitHeader($request, 'x-forwarded-for'); + if (!$forwardedFor) { + return null; + } + + $forwardedFor = \array_map(static function (string $ip): string { + if (\str_contains($ip, ':')) { + return '[' . \trim($ip, '[]') . ']:0'; + } + + return $ip . ':0'; + }, $forwardedFor); + + /** @var InternetAddress[] $forwardedFor */ + $forwardedFor = \array_filter(\array_map(InternetAddress::tryFromString(...), $forwardedFor)); + foreach (\array_reverse($forwardedFor) as $for) { + if (!$this->isTrustedProxy($for)) { + return new Forwarded($for, [ + 'for' => $for->getAddress(), + 'host' => $request->getHeader('x-forwarded-host'), + 'proto' => $request->getHeader('x-forwarded-proto'), + ]); + } + } + + return null; + } +} diff --git a/src/SocketHttpServer.php b/src/SocketHttpServer.php index e22fcb4b..f19680aa 100644 --- a/src/SocketHttpServer.php +++ b/src/SocketHttpServer.php @@ -13,6 +13,8 @@ use Amp\Http\Server\Driver\SocketClientFactory; use Amp\Http\Server\Middleware\AllowedMethodsMiddleware; use Amp\Http\Server\Middleware\CompressionMiddleware; +use Amp\Http\Server\Middleware\ForwardedHeaderType; +use Amp\Http\Server\Middleware\ForwardedMiddleware; use Amp\Socket\BindContext; use Amp\Socket\ResourceServerSocketFactory; use Amp\Socket\ServerSocket; @@ -52,14 +54,13 @@ final class SocketHttpServer implements HttpServer /** * Creates an instance appropriate for direct access by the public. * - * @param CompressionMiddleware|null $compressionMiddleware Use null to disable compression. * @param positive-int $connectionLimit Default is {@see self::DEFAULT_CONNECTION_LIMIT}. * @param positive-int $connectionLimitPerIp Default is {@see self::DEFAULT_CONNECTIONS_PER_IP_LIMIT}. * @param list|null $allowedMethods Use null to disable request method filtering. */ public static function createForDirectAccess( PsrLogger $logger, - ?CompressionMiddleware $compressionMiddleware = new CompressionMiddleware(), + bool $enableCompression = true, int $connectionLimit = self::DEFAULT_CONNECTION_LIMIT, int $connectionLimitPerIp = self::DEFAULT_CONNECTIONS_PER_IP_LIMIT, ?array $allowedMethods = AllowedMethodsMiddleware::DEFAULT_ALLOWED_METHODS, @@ -80,11 +81,16 @@ public static function createForDirectAccess( $connectionLimitPerIp, )); + $middleware = []; + if ($enableCompression && $compressionMiddleware = self::createCompressionMiddleware($logger)) { + $middleware[] = $compressionMiddleware; + } + return new self( $logger, $serverSocketFactory, $clientFactory, - $compressionMiddleware, + $middleware, $allowedMethods, $httpDriverFactory, ); @@ -99,26 +105,54 @@ public static function createForDirectAccess( */ public static function createForBehindProxy( PsrLogger $logger, - ?array $allowedMethods = null, + ForwardedHeaderType $headerType, + array $trustedProxies, + bool $enableCompression = true, + ?array $allowedMethods = AllowedMethodsMiddleware::DEFAULT_ALLOWED_METHODS, ?HttpDriverFactory $httpDriverFactory = null, ): self { + $middleware = []; + + $middleware[] = new ForwardedMiddleware($headerType, $trustedProxies); + + if ($enableCompression && $compressionMiddleware = self::createCompressionMiddleware($logger)) { + $middleware[] = $compressionMiddleware; + } + return new self( $logger, new ResourceServerSocketFactory(), new SocketClientFactory($logger), - allowedMethods: $allowedMethods, - httpDriverFactory: $httpDriverFactory, + $middleware, + $allowedMethods, + $httpDriverFactory, ); } + private static function createCompressionMiddleware(PsrLogger $logger): ?CompressionMiddleware + { + if (!\extension_loaded('zlib')) { + $logger->warning( + 'The zlib extension is not loaded which prevents using compression. ' . + 'Either activate the zlib extension or set $enableCompression to false' + ); + + return null; + } + + return new CompressionMiddleware(); + } + /** + * @param array $middleware Default middlewares. You may also use {@see Middleware\stack()} before + * passing the {@see RequestHandler} to {@see self::start()}. * @param list|null $allowedMethods Use null to disable request method filtering. */ public function __construct( private readonly PsrLogger $logger, private readonly ServerSocketFactory $serverSocketFactory, private readonly ClientFactory $clientFactory, - private readonly ?CompressionMiddleware $compressionMiddleware = null, + private readonly array $middleware = [], private readonly ?array $allowedMethods = AllowedMethodsMiddleware::DEFAULT_ALLOWED_METHODS, ?HttpDriverFactory $httpDriverFactory = null, ) { @@ -195,17 +229,7 @@ public function start(RequestHandler $requestHandler, ErrorHandler $errorHandler $this->logger->warning("The 'xdebug' extension is loaded, which has a major impact on performance."); } - if ($this->compressionMiddleware) { - if (!\extension_loaded('zlib')) { - $this->logger->warning( - "The zlib extension is not loaded which prevents using compression. " . - "Either activate the zlib extension or disable compression in the server's options." - ); - } else { - $this->logger->notice('Response compression enabled.'); - $requestHandler = Middleware\stack($requestHandler, $this->compressionMiddleware); - } - } + $requestHandler = Middleware\stack($requestHandler, ...$this->middleware); if ($this->allowedMethods !== null) { $this->logger->notice(\sprintf( @@ -243,6 +267,7 @@ public function start(RequestHandler $requestHandler, ErrorHandler $errorHandler $this->httpDriverFactory->getApplicationLayerProtocols(), ); + /** @psalm-suppress PropertyTypeCoercion */ $this->servers[] = $this->serverSocketFactory->listen( $address, $bindContext?->withTlsContext($tlsContext), @@ -258,7 +283,8 @@ public function start(RequestHandler $requestHandler, ErrorHandler $errorHandler $this->logger->info("Listening on {$scheme}://{$serverName}/"); - EventLoop::queue($this->accept(...), $server, $requestHandler, $errorHandler); + // Using short-closure to avoid Psalm bug when using a first-class callable here. + EventLoop::queue(fn () => $this->accept($server, $requestHandler, $errorHandler)); } } catch (\Throwable $exception) { try { diff --git a/test/IntegrationTest.php b/test/IntegrationTest.php index b498d859..a4c163fd 100644 --- a/test/IntegrationTest.php +++ b/test/IntegrationTest.php @@ -117,7 +117,8 @@ public function testStreamRequest(): void }); $response = $this->httpClient->request(new ClientRequest( - $this->getAuthority() . "/foo", 'POST', + $this->getAuthority() . "/foo", + 'POST', StreamedContent::fromStream(new ReadableIterableStream($queue->pipe())), )); diff --git a/test/Middleware/ForwardedMiddlewareTest.php b/test/Middleware/ForwardedMiddlewareTest.php new file mode 100644 index 00000000..2b930317 --- /dev/null +++ b/test/Middleware/ForwardedMiddlewareTest.php @@ -0,0 +1,144 @@ +createMock(Client::class); + $client->method('getRemoteAddress') + ->willReturn(new InternetAddress($address, 12345)); + + return $client; + } + + /** + * @param \Closure(Forwarded|null):void $verifier + */ + private function verifyUsing(\Closure $verifier): RequestHandler + { + return new class($verifier) implements RequestHandler { + public function __construct(private readonly \Closure $verifier) + { + } + + public function handleRequest(Request $request): Response + { + ($this->verifier)($request->getAttribute(Forwarded::class)); + return new Response(); + } + }; + } + + public function provideForwardedHeaders(): iterable + { + yield [ + ForwardedHeaderType::Forwarded, + 'For="[2001:db8:cafe::17]:4711"', + new InternetAddress('2001:db8:cafe::17', 4711), + [ + 'for' => '[2001:db8:cafe::17]:4711', + ], + ]; + + yield [ + ForwardedHeaderType::Forwarded, + 'for="[2001:db8:cafe::17]";proto=https;secret=test;by=172.18.0.9', + new InternetAddress('2001:db8:cafe::17', 0), + [ + 'for' => '[2001:db8:cafe::17]', + 'proto' => 'https', + 'secret' => 'test', + 'by' => '172.18.0.9', + ], + ]; + + yield [ + ForwardedHeaderType::Forwarded, + 'for=192.0.2.60;proto=http;by=203.0.113.43', + new InternetAddress('192.0.2.60', 0), + [ + 'for' => '192.0.2.60', + 'proto' => 'http', + 'by' => '203.0.113.43', + ], + ]; + + yield [ + ForwardedHeaderType::Forwarded, + 'for=192.0.2.43, for=198.51.100.17', + new InternetAddress('198.51.100.17', 0), + [ + 'for' => '198.51.100.17', + ], + ]; + + yield [ + ForwardedHeaderType::Forwarded, + 'for="2001:db8:cafe::17"', + null, + ]; + + yield [ + ForwardedHeaderType::XForwardedFor, + '2001:db8:85a3:8d3:1319:8a2e:370:7348', + new InternetAddress('2001:db8:85a3:8d3:1319:8a2e:370:7348', 0), + ]; + + yield [ + ForwardedHeaderType::XForwardedFor, + '203.0.113.195,2001:db8:85a3:8d3:1319:8a2e:370:7348,150.172.238.178', + new InternetAddress('150.172.238.178', 0), + ]; + + yield [ + ForwardedHeaderType::XForwardedFor, + '2001:db8:85a3:8d3:1319:8a2e:370', + null, + ]; + } + + /** + * @dataProvider provideForwardedHeaders + */ + public function testForwarded( + ForwardedHeaderType $type, + string $headerValue, + ?InternetAddress $address, + array $fields = [], + ): void { + $middleware = new ForwardedMiddleware($type, ['172.18.0.0/24']); + + $request = new Request($this->createClient('172.18.0.5'), 'GET', $this->createMock(PsrUri::class)); + $request->setHeader($type->getHeaderName(), $headerValue); + + $middleware->handleRequest($request, $this->verifyUsing(function (?Forwarded $forwarded) use ( + $address, + $fields + ): void { + self::assertSame($address?->getAddress(), $forwarded?->getFor()->getAddress()); + self::assertSame($address?->getPort(), $forwarded?->getFor()->getPort()); + + if (!$forwarded) { + return; + } + + foreach ($fields as $field => $expected) { + self::assertSame($expected, $forwarded->getField($field)); + } + })); + } +} diff --git a/test/test-server.php b/test/test-server.php index 6f1e5533..b77e39de 100644 --- a/test/test-server.php +++ b/test/test-server.php @@ -33,7 +33,7 @@ $logger = new Logger('server'); $logger->pushHandler($logHandler); -$server = new SocketHttpServer($logger); +$server = SocketHttpServer::createForDirectAccess($logger); $server->expose(new Socket\InternetAddress("0.0.0.0", 1338), $context); $server->expose(new Socket\InternetAddress("[::]", 1338), $context); @@ -49,7 +49,7 @@ } return new Response(HttpStatus::OK, [ - "content-type" => "text/plain; charset=utf-8", + "content-type" => "text/plain; charset=utf-8", ], "Hello, World!"); }), new DefaultErrorHandler());