From f9ade330d54b0d698d85db2a919d83b3f0b457f9 Mon Sep 17 00:00:00 2001 From: Bastien Gandouet Date: Thu, 11 Jul 2024 17:50:10 +0200 Subject: [PATCH] Add trusted networks --- tests/middleware/test_proxy_headers.py | 40 ++++++++++++++++++++++++++ uvicorn/middleware/proxy_headers.py | 35 ++++++++++++++++++++-- 2 files changed, 73 insertions(+), 2 deletions(-) diff --git a/tests/middleware/test_proxy_headers.py b/tests/middleware/test_proxy_headers.py index 81e559944..0fed945f1 100644 --- a/tests/middleware/test_proxy_headers.py +++ b/tests/middleware/test_proxy_headers.py @@ -1,5 +1,6 @@ from __future__ import annotations +import ipaddress from typing import TYPE_CHECKING import httpx @@ -58,6 +59,45 @@ async def test_proxy_headers_trusted_hosts(trusted_hosts: list[str] | str, respo assert response.text == response_text +@pytest.mark.anyio +@pytest.mark.parametrize( + ("trusted_networks", "response_text"), + [ + ([ipaddress.IPv4Network("192.168.0.0/24")], "Remote: https://10.0.2.1:0"), + ( + [ + ipaddress.IPv4Network("192.168.0.0/24"), + ipaddress.IPv4Network("10.0.0.0/16"), + ipaddress.IPv6Network("2001:db8::/64"), + ], + "Remote: https://1.2.3.4:0", + ), + ( + [ + ipaddress.IPv4Network("192.168.0.0/24"), + ipaddress.IPv4Network("10.0.0.0/16"), + ], + "Remote: https://2001:db8::1:0", + ), + ], +) +async def test_proxy_headers_trusted_networks( + trusted_networks: list[ipaddress.IPv4Network | ipaddress.IPv6Network], + response_text: str, +) -> None: + app_with_middleware = ProxyHeadersMiddleware(app, trusted_networks=trusted_networks) + transport = httpx.ASGITransport(app=app_with_middleware) # type: ignore + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + headers = { + "X-Forwarded-Proto": "https", + "X-Forwarded-For": "1.2.3.4, 2001:db8::1, 10.0.2.1, 192.168.0.2", + } + response = await client.get("/", headers=headers) + + assert response.status_code == 200 + assert response.text == response_text + + @pytest.mark.anyio @pytest.mark.parametrize( ("trusted_hosts", "response_text"), diff --git a/uvicorn/middleware/proxy_headers.py b/uvicorn/middleware/proxy_headers.py index 45d5518ce..88a24d512 100644 --- a/uvicorn/middleware/proxy_headers.py +++ b/uvicorn/middleware/proxy_headers.py @@ -11,9 +11,34 @@ from __future__ import annotations +import ipaddress from typing import Union, cast -from uvicorn._types import ASGI3Application, ASGIReceiveCallable, ASGISendCallable, HTTPScope, Scope, WebSocketScope +from uvicorn._types import ( + ASGI3Application, + ASGIReceiveCallable, + ASGISendCallable, + HTTPScope, + Scope, + WebSocketScope, +) + + +def _address_to_network( + host: str, +) -> ipaddress.IPv4Network | ipaddress.IPv6Network: + address = ipaddress.ip_address(host) + return ipaddress.ip_network(int(address)) + + +def _networks_contain_address( + networks: list[ipaddress.IPv4Network | ipaddress.IPv6Network], + address: ipaddress.IPv4Address | ipaddress.IPv6Address, +) -> bool: + for network in networks: + if address in network: + return True + return False class ProxyHeadersMiddleware: @@ -21,6 +46,7 @@ def __init__( self, app: ASGI3Application, trusted_hosts: list[str] | str = "127.0.0.1", + trusted_networks: (list[ipaddress.IPv4Network | ipaddress.IPv6Network] | None) = None, ) -> None: self.app = app if isinstance(trusted_hosts, str): @@ -29,12 +55,17 @@ def __init__( self.trusted_hosts = set(trusted_hosts) self.always_trust = "*" in self.trusted_hosts + self.trusted_networks = trusted_networks or [] + + if not self.always_trust: + self.trusted_networks += [_address_to_network(host) for host in self.trusted_hosts] + def get_trusted_client_host(self, x_forwarded_for_hosts: list[str]) -> str | None: if self.always_trust: return x_forwarded_for_hosts[0] for host in reversed(x_forwarded_for_hosts): - if host not in self.trusted_hosts: + if not _networks_contain_address(self.trusted_networks, ipaddress.ip_address(host)): return host return None