Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optionally restrict the range of ephemeral ports #63

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/aioice/ice.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
import socket
import threading
from itertools import count
from typing import Dict, List, Optional, Set, Text, Tuple, Union, cast
from typing import Dict, Iterable, List, Optional, Set, Text, Tuple, Union, cast

import netifaces

from . import mdns, stun, turn
from .candidate import Candidate, candidate_foundation, candidate_priority
from .utils import random_string
from .utils import create_datagram_endpoint, random_string

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -297,6 +297,7 @@ class Connection:
:param use_ipv4: Whether to use IPv4 candidates.
:param use_ipv6: Whether to use IPv6 candidates.
:param transport_policy: Transport policy.
:param ephemeral_ports: Set of allowed ephemeral local ports to bind to.
"""

def __init__(
Expand All @@ -312,6 +313,7 @@ def __init__(
use_ipv4: bool = True,
use_ipv6: bool = True,
transport_policy: TransportPolicy = TransportPolicy.ALL,
ephemeral_ports: Optional[Iterable[int]] = None,
) -> None:
self.ice_controlling = ice_controlling
#: Local username, automatically set to a random value.
Expand Down Expand Up @@ -357,6 +359,7 @@ def __init__(
self._tie_breaker = secrets.randbits(64)
self._use_ipv4 = use_ipv4
self._use_ipv6 = use_ipv6
self._ephemeral_ports = ephemeral_ports

if (
stun_server is None
Expand Down Expand Up @@ -876,16 +879,14 @@ async def get_component_candidates(
self, component: int, addresses: List[str], timeout: int = 5
) -> List[Candidate]:
candidates = []
loop = asyncio.get_event_loop()

# gather host candidates
host_protocols = []
for address in addresses:
# create transport
try:
transport, protocol = await loop.create_datagram_endpoint(
lambda: StunProtocol(self), local_addr=(address, 0)
)
transport, protocol = await create_datagram_endpoint(
lambda: StunProtocol(self), local_address=address, local_ports=self._ephemeral_ports)
sock = transport.get_extra_info("socket")
if sock is not None:
sock.setsockopt(
Expand Down
35 changes: 35 additions & 0 deletions src/aioice/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import asyncio
import os
import random
import secrets
import string
from typing import Iterable, Optional, Tuple


def random_string(length: int) -> str:
Expand All @@ -10,3 +13,35 @@ def random_string(length: int) -> str:

def random_transaction_id() -> bytes:
return os.urandom(12)


async def create_datagram_endpoint(protocol_factory,
remote_addr: Tuple[str, int] = None,
local_address: str = None,
local_ports: Optional[Iterable[int]] = None,
):
"""
Asynchronousley create a datagram endpoint.

:param protocol_factory: Callable returning a protocol instance.
:param remote_addr: Remote address and port.
:param local_address: Local address to bind to.
:param local_ports: Set of allowed local ports to bind to.
"""
if local_ports is not None:
ports = list(local_ports)
random.shuffle(ports)
else:
ports = (0,)
loop = asyncio.get_event_loop()
for port in ports:
try:
transport, protocol = await loop.create_datagram_endpoint(
protocol_factory, remote_addr=remote_addr, local_addr=(local_address, port)
)
return transport, protocol
except OSError as exc:
if port == ports[-1]:
# this was the last port, give up
raise exc
raise ValueError("local_ports must not be empty")
57 changes: 57 additions & 0 deletions tests/test_ice.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import functools
import os
import random
import socket
import unittest
from unittest import mock
Expand Down Expand Up @@ -1249,6 +1250,62 @@ async def test_repr(self):
conn._id = 1
self.assertEqual(repr(conn), "Connection(1)")

@asynctest
async def test_connection_ephemeral_ports(self):
addresses = ["127.0.0.1"]

# Let the OS pick a random port - should always yield a candidate
conn1 = ice.Connection(ice_controlling=True)
c = await conn1.get_component_candidates(0, addresses)
self.assertTrue(c[0].port >= 1 and c[0].port <= 65535)

# Try opening a new connection with the same port - should never yield candidates
conn2 = ice.Connection(ice_controlling=True, ephemeral_ports=[c[0].port])
c = await conn2.get_component_candidates(0, addresses)
self.assertEqual(len(c), 0) # port already in use, no candidates
await conn1.close()

# Empty set of ports - illegal argument
conn3 = ice.Connection(ice_controlling=True, ephemeral_ports=[])
with self.assertRaises(ValueError):
await conn3.get_component_candidates(0, addresses)

# Range of 100 ports
lower = random.randint(1024, 65536 - 100)
upper = lower + 100
ports = set(range(lower, upper)) - set([5353])

# Exhaust the range of ports - should always yield candidates
conns = []
for i in range(0, len(ports)):
conn = ice.Connection(ice_controlling=True, ephemeral_ports=ports)
c = await conn.get_component_candidates(i, addresses)
if c:
self.assertTrue(c[0].port >= lower and c[0].port < upper)
conns.append(conn)
self.assertGreaterEqual(len(conns), len(ports) - 1) # account for at most 1 port in use by another process

# Open one more connection from the same range - should never yield candidates
conn = ice.Connection(ice_controlling=True, ephemeral_ports=ports)
c = await conn.get_component_candidates(0, addresses)
self.assertEqual(len(c), 0) # all ports are exhausted, no candidates

# Close one connection and try again - should always yield a candidate
await conns.pop().close()
conn = ice.Connection(ice_controlling=True, ephemeral_ports=ports)
c = await conn.get_component_candidates(0, addresses)
self.assertTrue(c[0].port >= lower and c[0].port < upper)
await conn.close()

# cleanup
for conn in conns:
await conn.close()

# Bind to wildcard local address - should always yield a candidate
conn = ice.Connection(ice_controlling=True)
c = await conn.get_component_candidates(0, [None])
self.assertTrue(c[0].port >= 1 and c[0].port <= 65535)
await conn.close()

class StunProtocolTest(unittest.TestCase):
@asynctest
Expand Down