From 2fdf696ceebffcc99007798bad1dec135b464756 Mon Sep 17 00:00:00 2001 From: Brian Wellington Date: Fri, 3 Nov 2023 08:12:37 -0700 Subject: [PATCH] Add prepend_length to Message.to_wire(). (#1001) If a caller passes prepend_length=True, the wire format will include the 2 byte encoded message length before the message itself. This is useful for callers planning to send the message over TCP, DoT, and DoQ. --- dns/asyncquery.py | 12 +++++------- dns/message.py | 10 +++++++++- dns/query.py | 12 +++++------- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/dns/asyncquery.py b/dns/asyncquery.py index ecf9c1a5f..13d317d14 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -292,14 +292,12 @@ async def send_tcp( """ if isinstance(what, dns.message.Message): - wire = what.to_wire() + tcpmsg = what.to_wire(prepend_length=True) else: - wire = what - l = len(wire) - # copying the wire into tcpmsg is inefficient, but lets us - # avoid writev() or doing a short write that would get pushed - # onto the net - tcpmsg = struct.pack("!H", l) + wire + # copying the wire into tcpmsg is inefficient, but lets us + # avoid writev() or doing a short write that would get pushed + # onto the net + tcpmsg = len(what).to_bytes(2, 'big') + what sent_time = time.time() await sock.sendall(tcpmsg, _timeout(expiration, sent_time)) return (len(tcpmsg), sent_time) diff --git a/dns/message.py b/dns/message.py index 70049efd3..54611d6fc 100644 --- a/dns/message.py +++ b/dns/message.py @@ -527,6 +527,7 @@ def to_wire( max_size: int = 0, multi: bool = False, tsig_ctx: Optional[Any] = None, + prepend_length : bool = False, **kw: Dict[str, Any], ) -> bytes: """Return a string containing the message in DNS compressed wire @@ -549,6 +550,10 @@ def to_wire( *tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the ongoing TSIG context, used when signing zone transfers. + *prepend_length", a ``bool``, should be set to ``True`` if the caller + wants the message length prepended to the message itself. This is + useful for messages sent over TCP, TLS (DoT), or QUIC (DoQ). + Raises ``dns.exception.TooBig`` if *max_size* was exceeded. Returns a ``bytes``. @@ -598,7 +603,10 @@ def to_wire( r.write_header() if multi: self.tsig_ctx = ctx - return r.get_wire() + wire = r.get_wire() + if prepend_length: + wire = len(wire).to_bytes(2, 'big') + wire + return wire @staticmethod def _make_tsig( diff --git a/dns/query.py b/dns/query.py index 0d7112515..a5fc6023d 100644 --- a/dns/query.py +++ b/dns/query.py @@ -864,14 +864,12 @@ def send_tcp( """ if isinstance(what, dns.message.Message): - wire = what.to_wire() + tcpmsg = what.to_wire(prepend_length=True) else: - wire = what - l = len(wire) - # copying the wire into tcpmsg is inefficient, but lets us - # avoid writev() or doing a short write that would get pushed - # onto the net - tcpmsg = struct.pack("!H", l) + wire + # copying the wire into tcpmsg is inefficient, but lets us + # avoid writev() or doing a short write that would get pushed + # onto the net + tcpmsg = len(what).to_bytes(2, 'big') + what sent_time = time.time() _net_write(sock, tcpmsg, expiration) return (len(tcpmsg), sent_time)