Skip to content

Commit

Permalink
ForwardedMiddleware (#347)
Browse files Browse the repository at this point in the history
  • Loading branch information
trowski authored Apr 22, 2023
1 parent d416c11 commit b86f0d4
Show file tree
Hide file tree
Showing 12 changed files with 387 additions and 28 deletions.
2 changes: 1 addition & 1 deletion composer.json
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
"amphp/hpack": "^3",
"amphp/http": "^2",
"amphp/pipeline": "^1",
"amphp/socket": "^2",
"amphp/socket": "^2.1",
"amphp/sync": "^2",
"league/uri": "^6",
"league/uri-interfaces": "^2.3",
Expand Down
7 changes: 6 additions & 1 deletion examples/hello-world.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
use Amp\ByteStream;
use Amp\Http\HttpStatus;
use Amp\Http\Server\DefaultErrorHandler;
use Amp\Http\Server\Driver\SocketClientFactory;
use Amp\Http\Server\Request;
use Amp\Http\Server\RequestHandler;
use Amp\Http\Server\Response;
Expand All @@ -32,7 +33,11 @@
$logger->pushHandler($logHandler);
$logger->useLoggingLoopDetection(false);

$server = SocketHttpServer::createForBehindProxy($logger);
$server = new SocketHttpServer(
$logger,
new Socket\ResourceServerSocketFactory(),
new SocketClientFactory($logger),
);

$server->expose("0.0.0.0:1337");
$server->expose("[::]:1337");
Expand Down
2 changes: 1 addition & 1 deletion examples/state.php
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
$logger = new Logger('server');
$logger->pushHandler($logHandler);

$server = SocketHttpServer::createForBehindProxy($logger);
$server = SocketHttpServer::createForDirectAccess($logger);

$server->expose("0.0.0.0:1337");
$server->expose("[::]:1337");
Expand Down
2 changes: 1 addition & 1 deletion src/Driver/ConnectionLimitingClientFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ final class ConnectionLimitingClientFactory implements ClientFactory
public function __construct(
private readonly ClientFactory $clientFactory,
private readonly PsrLogger $logger,
private readonly int $connectionsPerIpLimit = 10,
private readonly int $connectionsPerIpLimit,
) {
}

Expand Down
3 changes: 1 addition & 2 deletions src/Driver/Http2Driver.php
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ final class Http2Driver extends AbstractHttpDriver implements Http2Processor

private int $initialWindowSize = self::DEFAULT_WINDOW_SIZE;

/** @var positive-int */
private int $maxFrameSize = self::DEFAULT_MAX_FRAME_SIZE;

private bool $allowsPush;
Expand Down Expand Up @@ -569,7 +570,6 @@ private function writeBufferedData(int $id): void

if ($length > $this->maxFrameSize) {
$split = \str_split($stream->buffer, $this->maxFrameSize);
\assert(\is_array($split)); // For Psalm
$stream->buffer = \array_pop($split);
foreach ($split as $part) {
$this->writeFrame($part, Http2Parser::DATA, Http2Parser::NO_FLAG, $id);
Expand Down Expand Up @@ -626,7 +626,6 @@ private function writeHeaders(string $headers, int $type, int $flags, int $id):
// Header frames must be sent as one contiguous block without frames from any other stream being
// interleaved between due to HPack. See https://datatracker.ietf.org/doc/html/rfc7540#section-4.3
$split = \str_split($headers, $this->maxFrameSize);
\assert(\is_array($split)); // For Psalm
$headers = \array_pop($split);

$writeFrame = $this->writeFrame(...);
Expand Down
27 changes: 27 additions & 0 deletions src/Middleware/Forwarded.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
<?php declare(strict_types=1);

namespace Amp\Http\Server\Middleware;

use Amp\Socket\InternetAddress;

final class Forwarded
{
/**
* @param array<non-empty-string, string|null> $fields
*/
public function __construct(
private readonly InternetAddress $for,
private readonly array $fields,
) {
}

public function getFor(): InternetAddress
{
return $this->for;
}

public function getField(string $name): ?string
{
return $this->fields[$name] ?? null;
}
}
20 changes: 20 additions & 0 deletions src/Middleware/ForwardedHeaderType.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
<?php declare(strict_types=1);

namespace Amp\Http\Server\Middleware;

enum ForwardedHeaderType
{
case Forwarded;
case XForwardedFor;

/**
* @return non-empty-string
*/
public function getHeaderName(): string
{
return match ($this) {
self::Forwarded => 'forwarded',
self::XForwardedFor => 'x-forwarded-for',
};
}
}
137 changes: 137 additions & 0 deletions src/Middleware/ForwardedMiddleware.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
<?php declare(strict_types=1);

namespace Amp\Http\Server\Middleware;

use Amp\Cache\LocalCache;
use Amp\Http;
use Amp\Http\Server\Middleware;
use Amp\Http\Server\Request;
use Amp\Http\Server\RequestHandler;
use Amp\Http\Server\Response;
use Amp\Socket\CidrMatcher;
use Amp\Socket\InternetAddress;

final class ForwardedMiddleware implements Middleware
{
/** @var list<CidrMatcher> */
private readonly array $trustedProxies;

/** @var LocalCache<bool> */
private readonly LocalCache $trustedIps;

/**
* @param array<non-empty-string> $trustedProxies Array of IPv4 or IPv6 addresses with an optional subnet mask.
* e.g., '172.18.0.0/24'
* @param positive-int $cacheSize
*/
public function __construct(
private readonly ForwardedHeaderType $headerType,
array $trustedProxies,
int $cacheSize = 1000,
) {
$this->trustedProxies = \array_map(
static fn (string $ip) => new CidrMatcher($ip),
\array_values($trustedProxies),
);

$this->trustedIps = new LocalCache($cacheSize);
}

public function handleRequest(Request $request, RequestHandler $requestHandler): Response
{
$clientAddress = $request->getClient()->getRemoteAddress();

if ($clientAddress instanceof InternetAddress && $this->isTrustedProxy($clientAddress)) {
$request->setAttribute(Forwarded::class, match ($this->headerType) {
ForwardedHeaderType::Forwarded => $this->getForwarded($request),
ForwardedHeaderType::XForwardedFor => $this->getForwardedFor($request),
});
}

return $requestHandler->handleRequest($request);
}

private function isTrustedProxy(InternetAddress $address): bool
{
$ip = $address->getAddress();
$trusted = $this->trustedIps->get($ip);
if ($trusted !== null) {
return $trusted;
}

$trusted = false;
foreach ($this->trustedProxies as $matcher) {
if ($matcher->match($ip)) {
$trusted = true;
break;
}
}

$this->trustedIps->set($ip, $trusted);

return $trusted;
}

private function getForwarded(Request $request): ?Forwarded
{
$headers = Http\parseMultipleHeaderFields($request, 'forwarded');
if (!$headers) {
return null;
}

foreach (\array_reverse($headers) as $header) {
$for = $header['for'] ?? null;
if ($for === null) {
continue;
}

$address = InternetAddress::tryFromString($this->addPortIfMissing($for));
if (!$address || $this->isTrustedProxy($address)) {
continue;
}

return new Forwarded($address, $header);
}

return null;
}

private function addPortIfMissing(string $address): string
{
if (!\str_contains($address, ':') || \str_ends_with($address, ']')) {
$address .= ':0';
}

return $address;
}

private function getForwardedFor(Request $request): ?Forwarded
{
$forwardedFor = Http\splitHeader($request, 'x-forwarded-for');
if (!$forwardedFor) {
return null;
}

$forwardedFor = \array_map(static function (string $ip): string {
if (\str_contains($ip, ':')) {
return '[' . \trim($ip, '[]') . ']:0';
}

return $ip . ':0';
}, $forwardedFor);

/** @var InternetAddress[] $forwardedFor */
$forwardedFor = \array_filter(\array_map(InternetAddress::tryFromString(...), $forwardedFor));
foreach (\array_reverse($forwardedFor) as $for) {
if (!$this->isTrustedProxy($for)) {
return new Forwarded($for, [
'for' => $for->getAddress(),
'host' => $request->getHeader('x-forwarded-host'),
'proto' => $request->getHeader('x-forwarded-proto'),
]);
}
}

return null;
}
}
64 changes: 45 additions & 19 deletions src/SocketHttpServer.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
use Amp\Http\Server\Driver\SocketClientFactory;
use Amp\Http\Server\Middleware\AllowedMethodsMiddleware;
use Amp\Http\Server\Middleware\CompressionMiddleware;
use Amp\Http\Server\Middleware\ForwardedHeaderType;
use Amp\Http\Server\Middleware\ForwardedMiddleware;
use Amp\Socket\BindContext;
use Amp\Socket\ResourceServerSocketFactory;
use Amp\Socket\ServerSocket;
Expand Down Expand Up @@ -52,14 +54,13 @@ final class SocketHttpServer implements HttpServer
/**
* Creates an instance appropriate for direct access by the public.
*
* @param CompressionMiddleware|null $compressionMiddleware Use null to disable compression.
* @param positive-int $connectionLimit Default is {@see self::DEFAULT_CONNECTION_LIMIT}.
* @param positive-int $connectionLimitPerIp Default is {@see self::DEFAULT_CONNECTIONS_PER_IP_LIMIT}.
* @param list<non-empty-string>|null $allowedMethods Use null to disable request method filtering.
*/
public static function createForDirectAccess(
PsrLogger $logger,
?CompressionMiddleware $compressionMiddleware = new CompressionMiddleware(),
bool $enableCompression = true,
int $connectionLimit = self::DEFAULT_CONNECTION_LIMIT,
int $connectionLimitPerIp = self::DEFAULT_CONNECTIONS_PER_IP_LIMIT,
?array $allowedMethods = AllowedMethodsMiddleware::DEFAULT_ALLOWED_METHODS,
Expand All @@ -80,11 +81,16 @@ public static function createForDirectAccess(
$connectionLimitPerIp,
));

$middleware = [];
if ($enableCompression && $compressionMiddleware = self::createCompressionMiddleware($logger)) {
$middleware[] = $compressionMiddleware;
}

return new self(
$logger,
$serverSocketFactory,
$clientFactory,
$compressionMiddleware,
$middleware,
$allowedMethods,
$httpDriverFactory,
);
Expand All @@ -99,26 +105,54 @@ public static function createForDirectAccess(
*/
public static function createForBehindProxy(
PsrLogger $logger,
?array $allowedMethods = null,
ForwardedHeaderType $headerType,
array $trustedProxies,
bool $enableCompression = true,
?array $allowedMethods = AllowedMethodsMiddleware::DEFAULT_ALLOWED_METHODS,
?HttpDriverFactory $httpDriverFactory = null,
): self {
$middleware = [];

$middleware[] = new ForwardedMiddleware($headerType, $trustedProxies);

if ($enableCompression && $compressionMiddleware = self::createCompressionMiddleware($logger)) {
$middleware[] = $compressionMiddleware;
}

return new self(
$logger,
new ResourceServerSocketFactory(),
new SocketClientFactory($logger),
allowedMethods: $allowedMethods,
httpDriverFactory: $httpDriverFactory,
$middleware,
$allowedMethods,
$httpDriverFactory,
);
}

private static function createCompressionMiddleware(PsrLogger $logger): ?CompressionMiddleware
{
if (!\extension_loaded('zlib')) {
$logger->warning(
'The zlib extension is not loaded which prevents using compression. ' .
'Either activate the zlib extension or set $enableCompression to false'
);

return null;
}

return new CompressionMiddleware();
}

/**
* @param array<Middleware> $middleware Default middlewares. You may also use {@see Middleware\stack()} before
* passing the {@see RequestHandler} to {@see self::start()}.
* @param list<non-empty-string>|null $allowedMethods Use null to disable request method filtering.
*/
public function __construct(
private readonly PsrLogger $logger,
private readonly ServerSocketFactory $serverSocketFactory,
private readonly ClientFactory $clientFactory,
private readonly ?CompressionMiddleware $compressionMiddleware = null,
private readonly array $middleware = [],
private readonly ?array $allowedMethods = AllowedMethodsMiddleware::DEFAULT_ALLOWED_METHODS,
?HttpDriverFactory $httpDriverFactory = null,
) {
Expand Down Expand Up @@ -195,17 +229,7 @@ public function start(RequestHandler $requestHandler, ErrorHandler $errorHandler
$this->logger->warning("The 'xdebug' extension is loaded, which has a major impact on performance.");
}

if ($this->compressionMiddleware) {
if (!\extension_loaded('zlib')) {
$this->logger->warning(
"The zlib extension is not loaded which prevents using compression. " .
"Either activate the zlib extension or disable compression in the server's options."
);
} else {
$this->logger->notice('Response compression enabled.');
$requestHandler = Middleware\stack($requestHandler, $this->compressionMiddleware);
}
}
$requestHandler = Middleware\stack($requestHandler, ...$this->middleware);

if ($this->allowedMethods !== null) {
$this->logger->notice(\sprintf(
Expand Down Expand Up @@ -243,6 +267,7 @@ public function start(RequestHandler $requestHandler, ErrorHandler $errorHandler
$this->httpDriverFactory->getApplicationLayerProtocols(),
);

/** @psalm-suppress PropertyTypeCoercion */
$this->servers[] = $this->serverSocketFactory->listen(
$address,
$bindContext?->withTlsContext($tlsContext),
Expand All @@ -258,7 +283,8 @@ public function start(RequestHandler $requestHandler, ErrorHandler $errorHandler

$this->logger->info("Listening on {$scheme}://{$serverName}/");

EventLoop::queue($this->accept(...), $server, $requestHandler, $errorHandler);
// Using short-closure to avoid Psalm bug when using a first-class callable here.
EventLoop::queue(fn () => $this->accept($server, $requestHandler, $errorHandler));
}
} catch (\Throwable $exception) {
try {
Expand Down
Loading

0 comments on commit b86f0d4

Please sign in to comment.