From b0f060f714b5e1c46675e7febe71b9304d90e47a Mon Sep 17 00:00:00 2001 From: Johan Westerlund Date: Wed, 26 Oct 2022 14:57:56 +0200 Subject: [PATCH] Optionally restrict the range of ephemeral ports --- src/aioice/ice.py | 13 ++++++----- src/aioice/utils.py | 35 ++++++++++++++++++++++++++++++ tests/test_ice.py | 53 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 95 insertions(+), 6 deletions(-) diff --git a/src/aioice/ice.py b/src/aioice/ice.py index f33a8dc..1dee7ec 100644 --- a/src/aioice/ice.py +++ b/src/aioice/ice.py @@ -8,13 +8,13 @@ import socket import threading from itertools import count -from typing import Dict, List, Optional, Set, Text, Tuple, Union, cast +from typing import Dict, Iterable, List, Optional, Set, Text, Tuple, Union, cast import netifaces from . import mdns, stun, turn from .candidate import Candidate, candidate_foundation, candidate_priority -from .utils import random_string +from .utils import create_datagram_endpoint, random_string logger = logging.getLogger(__name__) @@ -282,6 +282,7 @@ class Connection: :param turn_transport: The transport for TURN server, `"udp"` or `"tcp"`. :param use_ipv4: Whether to use IPv4 candidates. :param use_ipv6: Whether to use IPv6 candidates. + :param ephemeral_ports: Set of allowed ephemeral local ports to bind to. """ def __init__( @@ -296,6 +297,7 @@ def __init__( turn_transport: str = "udp", use_ipv4: bool = True, use_ipv6: bool = True, + ephemeral_ports: Optional[Iterable[int]] = None, ) -> None: self.ice_controlling = ice_controlling #: Local username, automatically set to a random value. @@ -340,6 +342,7 @@ def __init__( self._tie_breaker = secrets.randbits(64) self._use_ipv4 = use_ipv4 self._use_ipv6 = use_ipv6 + self._ephemeral_ports = ephemeral_ports @property def local_candidates(self) -> List[Candidate]: @@ -847,16 +850,14 @@ async def get_component_candidates( self, component: int, addresses: List[str], timeout: int = 5 ) -> List[Candidate]: candidates = [] - loop = asyncio.get_event_loop() # gather host candidates host_protocols = [] for address in addresses: # create transport try: - transport, protocol = await loop.create_datagram_endpoint( - lambda: StunProtocol(self), local_addr=(address, 0) - ) + transport, protocol = await create_datagram_endpoint( + lambda: StunProtocol(self), local_address=address, ephemeral_ports=self._ephemeral_ports) sock = transport.get_extra_info("socket") if sock is not None: sock.setsockopt( diff --git a/src/aioice/utils.py b/src/aioice/utils.py index a292edf..09e4f45 100644 --- a/src/aioice/utils.py +++ b/src/aioice/utils.py @@ -1,6 +1,9 @@ +import asyncio import os +import random import secrets import string +from typing import Iterable, Optional, Tuple def random_string(length: int) -> str: @@ -10,3 +13,35 @@ def random_string(length: int) -> str: def random_transaction_id() -> bytes: return os.urandom(12) + + +async def create_datagram_endpoint(protocol_factory, + remote_addr: Tuple[str, int] = None, + local_address: str = None, + ephemeral_ports: Optional[Iterable[int]] = None, +): + """ + Asynchronousley create a datagram endpoint. + + :param protocol_factory: Callable returning a protocol instance. + :param remote_addr: Remote address and port. + :param local_address: Local address to bind to. + :param ephemeral_ports: Set of allowed local ephemeral ports to bind to. + """ + if ephemeral_ports is not None: + ports = list(ephemeral_ports) + random.shuffle(ports) + else: + ports = (0,) + loop = asyncio.get_event_loop() + for port in ports: + try: + transport, protocol = await loop.create_datagram_endpoint( + protocol_factory, remote_addr=remote_addr, local_addr=(local_address, port) + ) + return transport, protocol + except OSError as exc: + if port == ports[-1]: + # this was the last port, give up + raise exc + raise ValueError("ephemeral_ports must not be empty") diff --git a/tests/test_ice.py b/tests/test_ice.py index fb5e308..07f8d3b 100644 --- a/tests/test_ice.py +++ b/tests/test_ice.py @@ -1,6 +1,7 @@ import asyncio import functools import os +import random import socket import unittest from unittest import mock @@ -1161,6 +1162,58 @@ async def test_repr(self): conn._id = 1 self.assertEqual(repr(conn), "Connection(1)") + @asynctest + async def test_connection_ephemeral_ports(self): + addresses = ["127.0.0.1"] + + # Let the OS pick a random port - should never fail + conn1 = ice.Connection(ice_controlling=True) + c = await conn1.get_component_candidates(0, addresses) + self.assertTrue(c[0].port >= 1 and c[0].port <= 65535) + + # Try opening a new connection with the same port - should always fail + conn2 = ice.Connection(ice_controlling=True, ephemeral_ports=[c[0].port]) + c = await conn2.get_component_candidates(0, addresses) + self.assertTrue(len(c) == 0) # port already in use, no candidates + await conn1.close() + + # Empty set of ports - illegal argument + conn3 = ice.Connection(ice_controlling=True, ephemeral_ports=[]) + with self.assertRaises(ValueError) as exc: + await conn3.get_component_candidates(0, addresses) + + # Range of ports + lower = random.randint(1024, 65536 - 100) + upper = lower + 100 + ports = list(range(lower, upper)) + + conn4 = ice.Connection(ice_controlling=True, ephemeral_ports=ports) + c = await conn4.get_component_candidates(0, addresses) + self.assertTrue(c[0].port >= lower and c[0].port < upper) + + # Exhaust the range of ports + conns = [] + for i in range(1, len(ports)): + conn = ice.Connection(ice_controlling=True, ephemeral_ports=ports) + c = await conn.get_component_candidates(i, addresses) + self.assertTrue(c[0].port >= lower and c[0].port < lower + len(ports)) + conns.append(conn) + + # Open one more connection from the same range - should always fail + conn = ice.Connection(ice_controlling=True, ephemeral_ports=ports) + c = await conn.get_component_candidates(0, addresses) + self.assertTrue(len(c) == 0) # all ports are already in use, no candidates + + # Close one connection and try again - should never fail + await conn4.close() + conn = ice.Connection(ice_controlling=True, ephemeral_ports=ports) + c = await conn.get_component_candidates(0, addresses) + self.assertTrue(c[0].port >= lower and c[0].port < lower + len(ports)) + + # cleanup + await conn.close() + for conn in conns: + await conn.close() class StunProtocolTest(unittest.TestCase): @asynctest