-
Notifications
You must be signed in to change notification settings - Fork 34
/
wg-mux-server
executable file
·317 lines (269 loc) · 12.6 KB
/
wg-mux-server
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
#!/usr/bin/env python3
import os, sys, logging, contextlib, asyncio, socket, signal
import math, hashlib, secrets, struct, shelve, base64, ipaddress
import inspect, pathlib as pl, subprocess as sp
class WGMuxConfig:
    """Default settings, used by main() as fallbacks for command-line options."""

    # Passed as blake2s "person" value when hashing --auth-secret in main()
    auth_salt = b'wg-mux-1'

    # Muxer (UDP negotiation socket) defaults
    mux_port = 8739
    mux_attempts = 4
    mux_timeout = 5.0

    # WireGuard defaults
    wg_iface = 'wg'
    wg_port = 2200
    wg_net = '10.215.72.0/24'

    # Range of address numbers within wg_net to hand out,
    # to reserve e.g. 10.215.72.1/24 for server
    n_skip = 2
    n_max = None

    ident_db_path = 'wg-mux-ident.db'
class LogMessage:
    """Log message wrapper that defers str.format interpolation until str()."""

    def __init__(self, fmt, a, k):
        self.fmt = fmt
        self.a = a
        self.k = k

    def __str__(self):
        # Without any arguments the template is returned verbatim
        if not (self.a or self.k):
            return self.fmt
        return self.fmt.format(*self.a, **self.k)
class LogStyleAdapter(logging.LoggerAdapter):
    """Logger adapter whose messages use str.format-style placeholders."""

    def __init__(self, logger, extra=None):
        super().__init__(logger, extra or {})

    def log(self, level, msg, *args, **kws):
        if not self.isEnabledFor(level):
            return
        # exc_info is passed on to the logger, not to the message formatting
        log_kws = {}
        if 'exc_info' in kws:
            log_kws['exc_info'] = kws.pop('exc_info')
        msg, kws = self.process(msg, kws)
        message = LogMessage(msg, args, kws)
        self.logger._log(level, message, (), **log_kws)
def get_logger(name):
    'Return a LogStyleAdapter wrapping the stdlib logger with the given name.'
    return LogStyleAdapter(logging.getLogger(name))

# Module-level logger
log = get_logger('mux-server.main')
def b64_encode(s):
    'Encode bytes as a standard-alphabet base64 string.'
    return base64.standard_b64encode(s).decode()

def b64_decode(s):
    'Decode base64 value, using urlsafe alphabet if it has "-" or "_" in it.'
    if '-' in s or '_' in s:
        return base64.urlsafe_b64decode(s)
    return base64.standard_b64decode(s)

def to_bytes(s):
    'Return bytes as-is, and str()-encoded form of any other value.'
    if isinstance(s, bytes):
        return s
    return str(s).encode()
def str_part(s, sep, default=None):
    """Split s in two on a separator, filling in default for a missing part.

    sep is the separator with "<" before or ">" after it: "<" splits on
    the first occurrence and defaults the left part, otherwise the split is
    on the last occurrence and the right part is defaulted.
    Examples: str_part("user@host", "<@", "root"), str_part("host:port", ":>")
    """
    delim = sep.strip('<>')
    found = delim in s
    if sep.strip(delim) == '<':
        if found:
            return s.split(delim, 1)
        return default, s
    if found:
        return s.rsplit(delim, 1)
    return s, default
def sockopt_resolve(prefix, v):
    """Return name of a socket-module constant with value v, minus the prefix.

    Only names starting with the upper-cased prefix are checked,
    and v itself is returned if none of them has that value.
    """
    prefix = prefix.upper()
    matches = (
        name for name in dir(socket)
        if name.startswith(prefix) and getattr(socket, name) == v
    )
    name = next(matches, None)
    return v if name is None else name[len(prefix):]
def bin_pack(fmt, *args):
    """Extends struct.pack with "z" for auto-length bytes.

    Each "z" in fmt is replaced by "<len>s", with lengths
    taken in order from the bytes values among args.
    """
    sizes = [len(arg) for arg in args if isinstance(arg, bytes)]
    struct_fmt = fmt.replace('z', '{}s').format(*sizes)
    return struct.pack(struct_fmt, *args)
async def aio_task_cancel(task_list):
    """Cancel task(s) and wait for them to finish.

    Accepts either a single awaitable or an iterable of them,
    where empty (falsy) values are skipped.
    """
    if inspect.isawaitable(task_list):
        task_list = [task_list]
    tasks = [task for task in task_list if task]
    # Request cancellation for all tasks first, then wait for each one
    for task in tasks:
        try:
            task.cancel()
        except asyncio.CancelledError:
            pass
    for task in tasks:
        try:
            await task
        except asyncio.CancelledError:
            pass
def retries_within_timeout(
        tries, timeout,
        backoff_func=lambda e, n: ((e**n - 1) / e), slack=1e-2):
    """Return list of delays to make exactly n tries within timeout.

    Delays are backoff_func(e, n) for n in range(tries), where parameter e
    is found by bisection, so that sum of delays is within slack of timeout.

    Search is bounded and always terminates: if backoff_func cannot produce
    delays adding up to timeout (e.g. tries=1 or tries=2 with the default
    function), the closest delays found are returned instead of looping
    forever. No tries give an empty list, and non-positive timeout
    gives all-zero delays.
    """
    if tries <= 0:
        return []
    if timeout <= 0:
        return [0.0] * tries

    def delays_for(e):
        return [backoff_func(e, n) for n in range(tries)]

    # Initial bracket is same as in plain bisection over [0, timeout], but
    # upper bound is grown if delays at that point do not add up to timeout
    lo, hi = 0.0, float(timeout)
    for _ in range(64):
        if sum(delays_for(hi)) >= timeout:
            break
        lo, hi = hi, hi * 2

    delays = delays_for(hi)
    # 200 halvings is far more than float precision can make use of
    for _ in range(200):
        mid = (lo + hi) / 2
        delays = delays_for(mid)
        error = sum(delays) - timeout
        if abs(error) < slack:
            break
        if error > 0:
            hi = mid
        else:
            lo = mid
    return delays
class MuxServerProtocol:
    """asyncio datagram protocol that queues received requests.

    Every received datagram is put into the "requests" queue as
    a (data, addr) tuple, and None is queued when connection is lost.
    """

    transport = None

    def __init__(self, loop):
        # "loop" is only kept for interface compatibility - explicit loop=
        # argument for asyncio.Queue was removed in Python 3.10
        self.requests = asyncio.Queue()
        self.log = get_logger('mux-server.udp')

    def connection_made(self, transport):
        self.log.debug('Connection made')
        self.transport = transport

    def datagram_received(self, data, addr):
        self.log.debug('Received {:,d}B from {!r}', len(data), addr)
        self.requests.put_nowait((data, addr))

    def error_received(self, err):
        # Datagram transports call this on send/receive errors,
        # so it has to be defined to avoid AttributeError there
        self.log.debug('Network error: {}', err)

    def connection_lost(self, err):
        self.log.debug('Connection lost: {}', err or 'done')
        # None is an end-of-requests marker for whatever reads the queue
        self.requests.put_nowait(None)
class AuthError(Exception):
    'Raised by parse_req for requests with invalid structure or MAC.'
def parse_req(auth_secret, req):
    """Parse and authenticate a mux request datagram.

    Request is expected to be: 1-byte ident length, ident, 32-byte peer key,
    16-byte salt, and 64-byte blake2b MAC (keyed with auth_secret and salt)
    of everything before the salt.

    Returns (ident, wg_pk_peer) tuple of base64 strings, or None
    if request is empty, has wrong structure or fails the MAC check.
    """
    if not req:
        return
    try:
        # Length of ident comes from the first byte of the request
        fmt = f'B{req[0]}s32s'
        fields = struct.unpack(fmt + '16s64s', req)
        ident_len, ident, wg_pk_peer, salt, mac = fields
        if len(ident) != ident_len or not (wg_pk_peer and salt and mac):
            raise AuthError('Invalid structure')
        signed = req[:struct.calcsize(fmt)]
        mac_chk = hashlib.blake2b(signed, key=auth_secret, salt=salt).digest()
        if not secrets.compare_digest(mac, mac_chk):
            raise AuthError('MAC mismatch')
    except (struct.error, AuthError) as err:
        log.debug('Failed to parse/auth request value: {}', err)
        return
    return b64_encode(ident), b64_encode(wg_pk_peer)
def build_res(auth_secret, ident, wg_addr, wg_net, wg_port):
    """Build authenticated response datagram for a mux client.

    Payload packs ident length, address version, wg_port, network prefix
    length, ident and packed address, and is followed by a random 16-byte
    salt and blake2b MAC of the payload, keyed with auth_secret and that salt.
    """
    ident_buff = to_bytes(ident)
    salt = os.urandom(16)
    payload = bin_pack(
        '>BBHBzz',
        len(ident_buff), wg_addr.version, wg_port, wg_net.prefixlen,
        ident_buff, wg_addr.packed,
    )
    mac = hashlib.blake2b(payload, key=auth_secret, salt=salt).digest()
    return b''.join([payload, salt, mac])
def ident_repr(ident):
    """Return printable form of a base64-encoded ident.

    Decoded text is shown with "[str]" tag if ident bytes decode as text,
    and base64 value itself with "[b64]" tag otherwise.
    """
    try:
        tag, value = 'str', b64_decode(ident).decode()
    except UnicodeDecodeError:
        tag, value = 'b64', ident
    return '[{}] {!r}'.format(tag, value)
async def mux_send(loop, transport, response, addr, delays):
    """Send response to addr once for each delay, sleeping for it afterwards.

    "loop" is only kept for interface compatibility - explicit loop=
    argument for asyncio.sleep was removed in Python 3.10.
    """
    for delay in delays:
        transport.sendto(response, addr)
        await asyncio.sleep(delay)
async def mux_listen(
        loop, auth_secret,
        sock_af, sock_p, host, port, init_peer, delays):
    """Listen for mux requests on a UDP socket and send responses to them.

    For each request that passes parse_req, init_peer(ident, wg_pk_peer, addr)
    is called, and if it returns (wg_addr, wg_net, wg_port) info, response
    built from it is sent back by a mux_send task, using specified delays.
    Requests for ident which still has such task running are ignored.

    Returns when protocol signals that connection was lost. Pending send
    tasks are cancelled and transport is closed on any exit from the loop.
    """
    responses = dict()
    transport, proto = await loop.create_datagram_endpoint(
        lambda: MuxServerProtocol(loop),
        local_addr=(host, port), family=sock_af, proto=sock_p)
    try:
        while True:
            req_info = await proto.requests.get()
            # None is queued by the protocol when connection is lost
            if req_info is None:
                break
            req, req_addr = req_info
            # parse_req returns None for invalid/unauthenticated requests
            parsed = parse_req(auth_secret, req)
            if not parsed:
                continue
            ident, wg_pk_peer = parsed
            if not ident:
                continue
            if ident in responses:
                if not responses[ident].done():
                    continue
                await responses[ident]
            peer_info = init_peer(ident, wg_pk_peer, req_addr)
            if not peer_info:
                continue
            wg_addr, wg_net, wg_port = peer_info
            response = build_res(auth_secret, ident, wg_addr, wg_net, wg_port)
            responses[ident] = loop.create_task(
                mux_send(loop, transport, response, req_addr, delays))
    finally:
        await aio_task_cancel(responses.values())
        transport.close()
def main(args=None, conf=None):
    """Command-line entry point - parse options and run the mux server.

    args is a list of command-line arguments (sys.argv[1:] if None),
    conf is an object with default option values (WGMuxConfig() if not set).

    With -l/--ident-list, stored ident-address mappings are printed instead
    of starting the server. Otherwise runs until SIGINT/SIGTERM cancels
    the listener, returning None in that case.
    """
    if not conf:
        conf = WGMuxConfig()

    import argparse
    parser = argparse.ArgumentParser(
        description='Multiplexer for WireGuard connections,'
            ' assigning each one unique IP(s) according to provided ident-sting.'
            ' --wg-iface should be pre-configured with listening port,'
            ' have IP address and in UP state for setup connections to work.')

    group = parser.add_argument_group('Bind socket options')
    group.add_argument('bind', nargs='?', default='::',
        help='Host or address (to be resolved via gai) to listen on.'
            ' Can include port, which will override -p/--mux-port option, if specified here.'
            ' Default is to use "::" wildcard IPv4/IPv6 binding.')
    group.add_argument('-p', '--mux-port',
        default=conf.mux_port, type=int, metavar='port',
        help='Local UDP port to listen on for muxer requests from clients (default: %(default)s).'
            ' Can also be specified in the "bind" argument, which overrides this option.')

    group = parser.add_argument_group('Auth/ident options')
    group.add_argument('-s', '--auth-secret', metavar='string',
        help='Any string to use as symmetric secret'
            ' to authenticate both sides on --mux-port (default: %(default)s).'
            ' Must be same for both mux-client and server scripts talking to each other.'
            ' Will be read from stdin, if not specified.')
    group.add_argument('-i', '--ident-db',
        default=conf.ident_db_path, metavar='path',
        help='Path to db to store all the seen clients to, for persistent port allocation.'
            ' Default: %(default)s')
    group.add_argument('-l', '--ident-list',
        action='store_true', help='List stored ident-addr mappings and exit.')

    group = parser.add_argument_group('WireGuard options')
    group.add_argument('--wg-iface', metavar='iface', default=conf.wg_iface,
        help='WireGuard interface name to configure. Default: %(default)s')
    group.add_argument('--wg-port', type=int, metavar='port', default=conf.wg_port,
        help='WireGuard endpoint port to send to mux-clients. Default: %(default)s')
    group.add_argument('--wg-psk', metavar='file',
        help='File with base64-encoded WireGuard pre-shared-secret key to use for connection.')
    group.add_argument('--wg-net', metavar='ip/mask', default=conf.wg_net,
        help='IP/mask network spec to map'
            ' ident numbers to and send to mux-clients. Default: %(default)s')
    group.add_argument('-c', '--wg-cmd', metavar='cmd', default='wg',
        help='"wg" command to run, split on spaces.'
            ' Use e.g. "sudo wg" to have it run via sudo or full path'
            ' for a special binary with suid/capabilities. Default: %(default)s')

    group = parser.add_argument_group('Misc options')
    group.add_argument('-r', '--n-skip',
        type=int, metavar='n', default=conf.n_skip,
        help='Number of first addresses from --wg-net to skip, so that they'
            ' can be reserved for something else (e.g. .0 + .1 for server itself).'
            ' Will only affect peers with new ident strings, not existing ones. Default: %(default)s')
    group.add_argument('-m', '--n-max',
        type=int, metavar='n', default=conf.n_max,
        help='Last number of address from --wg-net to hand out.'
            ' Same idea as with --n-skip, i.e. to have part of range reserved. Default: no limit.')
    group.add_argument('-n', '--attempts',
        type=int, metavar='n', default=conf.mux_attempts,
        help='Number of UDP packets to send from'
            ' --mux-port in response to clients (to offset packet loss). Default: %(default)s')
    group.add_argument('-t', '--timeout',
        type=float, metavar='seconds', default=conf.mux_timeout,
        help='Negotiation response timeout on --mux-port, in seconds. Default: %(default)ss')
    group.add_argument('--dry-run', action='store_true',
        help='Do not change WireGuard configuration, only pretend to.')
    group.add_argument('-d', '--debug', action='store_true', help='Verbose operation mode.')

    opts = parser.parse_args(sys.argv[1:] if args is None else args)

    logging.basicConfig(level=logging.DEBUG if opts.debug else logging.WARNING)

    wg_cmd = opts.wg_cmd.split()
    wg_net = ipaddress.ip_network(opts.wg_net)
    # Range of address numbers (offsets within wg_net) to allocate from
    tun_a, tun_b = 0, wg_net.num_addresses - 1
    if opts.n_skip: tun_a += opts.n_skip
    if opts.n_max: tun_b = opts.n_max
    log.debug('Using wg network {} [{} - {}]', wg_net, wg_net[tun_a], wg_net[tun_b])

    # Context manager makes sure that db gets closed on any exit from here
    with shelve.open(str(pl.Path(opts.ident_db).expanduser()), 'c') as ident_db:
        if opts.ident_list:
            for ident, n in ident_db.items():
                print('n={} addr={} :: {}'.format(n, wg_net[n], ident_repr(ident)))
            return

        if not opts.auth_secret:
            log.debug('Reading --auth-secret from stdin (exact value, incl. spaces and newlines).')
            with open(sys.stdin.fileno(), 'rb') as src: opts.auth_secret = src.read()
            if not opts.auth_secret:
                parser.error('No --auth-secret specified and none provided on stdin.')
        auth_secret = hashlib.blake2s(
            to_bytes(opts.auth_secret), person=conf.auth_salt).digest()

        # Split optional port from the bind argument, incl. "[ipv6-addr]:port" form
        host, port_mux, family = opts.bind, opts.mux_port, 0
        if host.count(':') > 1: host, port_mux = str_part(host, ']:>', port_mux)
        else: host, port_mux = str_part(host, ':>', port_mux)
        if '[' in host: family = socket.AF_INET6
        host, port_mux = host.strip('[]'), int(port_mux)
        try:
            addrinfo = socket.getaddrinfo(
                host, str(port_mux), family=family,
                type=socket.SOCK_DGRAM, proto=socket.IPPROTO_UDP)
            if not addrinfo: raise socket.gaierror(f'No addrinfo for host: {host}')
        except (socket.gaierror, socket.error) as err:
            # parser.error() exits the script with usage info and this message
            parser.error(
                'Failed to resolve socket parameters'
                ' via getaddrinfo: {!r} - {}'.format(
                    (host, port_mux), f'[{err.__class__.__name__}] {err}'))
        sock_af, sock_t, sock_p, _, sock_addr = addrinfo[0]
        log.debug(
            'Resolved mux host:port {!r}:{!r} to endpoint: {} (family: {}, type: {}, proto: {})',
            host, port_mux, sock_addr,
            *(sockopt_resolve(pre, n) for pre, n in [
                ('af_', sock_af), ('sock_', sock_t), ('ipproto_', sock_p)]))
        host, port_mux = sock_addr[:2]

        log.debug(
            'Assigning IPs from network: {} ({} addrs(s))',
            wg_net, wg_net.num_addresses)

        def ns_iter_func():
            # Yields address numbers which were not in db when iteration started
            ns_used = set(ident_db.values())
            for n in range(tun_a, tun_b + 1):
                if n not in ns_used: yield n
        ns_iter = ns_iter_func()

        def init_peer(ident, wg_pk_peer, req_addr):
            # Returns None if there are no free address numbers left for new ident
            n = ident_db.get(ident)
            if n is None:
                try: n = next(ns_iter)
                except StopIteration:
                    log.error(
                        'No more ns to allocate'
                        ' ident: {} (addr={})', ident_repr(ident), req_addr)
                    return
                ident_db[ident] = n
                ident_db.sync()
            wg_addr = wg_net[n]
            log.debug(
                'Allocated [n={}, addr={}] for ident: {} (addr={})',
                n, wg_addr, ident_repr(ident), req_addr)
            cmd = wg_cmd + [
                'set', opts.wg_iface, 'peer', wg_pk_peer, 'allowed-ips', str(wg_addr)]
            if opts.wg_psk: cmd.extend(['preshared-key', opts.wg_psk])
            if not opts.dry_run: sp.run(cmd, check=True)
            else: log.debug('Config for peer: {}', ' '.join(cmd))
            return wg_addr, wg_net, opts.wg_port

        retry_delays = retries_within_timeout(opts.attempts, opts.timeout)
        # New loop is created explicitly, as asyncio.get_event_loop()
        # without a running loop is deprecated since Python 3.10
        with contextlib.closing(asyncio.new_event_loop()) as loop:
            muxer = loop.create_task(mux_listen(
                loop, auth_secret,
                sock_af, sock_p, host, port_mux, init_peer, retry_delays))
            for sig in 'INT TERM'.split():
                loop.add_signal_handler(getattr(signal, f'SIG{sig}'), muxer.cancel)
            try: return loop.run_until_complete(muxer)
            except asyncio.CancelledError: return
if __name__ == '__main__':
    # Return value of main() is used as the exit status of the script
    sys.exit(main())