Skip to content

Commit

Permalink
Add prepend_length to Message.to_wire(). (#1001)
Browse files Browse the repository at this point in the history
If a caller passes prepend_length=True, the wire format will include the
2 byte encoded message length before the message itself.  This is useful
for callers planning to send the message over TCP, DoT, and DoQ.
  • Loading branch information
bwelling authored Nov 3, 2023
1 parent defa13e commit 2fdf696
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 15 deletions.
12 changes: 5 additions & 7 deletions dns/asyncquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,14 +292,12 @@ async def send_tcp(
"""

if isinstance(what, dns.message.Message):
wire = what.to_wire()
tcpmsg = what.to_wire(prepend_length=True)
else:
wire = what
l = len(wire)
# copying the wire into tcpmsg is inefficient, but lets us
# avoid writev() or doing a short write that would get pushed
# onto the net
tcpmsg = struct.pack("!H", l) + wire
# copying the wire into tcpmsg is inefficient, but lets us
# avoid writev() or doing a short write that would get pushed
# onto the net
tcpmsg = len(what).to_bytes(2, 'big') + what
sent_time = time.time()
await sock.sendall(tcpmsg, _timeout(expiration, sent_time))
return (len(tcpmsg), sent_time)
Expand Down
10 changes: 9 additions & 1 deletion dns/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ def to_wire(
max_size: int = 0,
multi: bool = False,
tsig_ctx: Optional[Any] = None,
prepend_length : bool = False,
**kw: Dict[str, Any],
) -> bytes:
"""Return a string containing the message in DNS compressed wire
Expand All @@ -549,6 +550,10 @@ def to_wire(
*tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the
ongoing TSIG context, used when signing zone transfers.
*prepend_length", a ``bool``, should be set to ``True`` if the caller
wants the message length prepended to the message itself. This is
useful for messages sent over TCP, TLS (DoT), or QUIC (DoQ).
Raises ``dns.exception.TooBig`` if *max_size* was exceeded.
Returns a ``bytes``.
Expand Down Expand Up @@ -598,7 +603,10 @@ def to_wire(
r.write_header()
if multi:
self.tsig_ctx = ctx
return r.get_wire()
wire = r.get_wire()
if prepend_length:
wire = len(wire).to_bytes(2, 'big') + wire
return wire

@staticmethod
def _make_tsig(
Expand Down
12 changes: 5 additions & 7 deletions dns/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,14 +864,12 @@ def send_tcp(
"""

if isinstance(what, dns.message.Message):
wire = what.to_wire()
tcpmsg = what.to_wire(prepend_length=True)
else:
wire = what
l = len(wire)
# copying the wire into tcpmsg is inefficient, but lets us
# avoid writev() or doing a short write that would get pushed
# onto the net
tcpmsg = struct.pack("!H", l) + wire
# copying the wire into tcpmsg is inefficient, but lets us
# avoid writev() or doing a short write that would get pushed
# onto the net
tcpmsg = len(what).to_bytes(2, 'big') + what
sent_time = time.time()
_net_write(sock, tcpmsg, expiration)
return (len(tcpmsg), sent_time)
Expand Down

0 comments on commit 2fdf696

Please sign in to comment.