From b6563b717213c10e2cc3672d49fb55d05286609f Mon Sep 17 00:00:00 2001 From: Nick Pillitteri Date: Fri, 25 Oct 2024 19:46:14 -0400 Subject: [PATCH] dns: Fix an issue where a buffer was incorrectly reused Fix an issue with the TCP DNS client where a buffer was not cleared between reuse resulting in TCP connections only being used for a single request and discarded (due to corrupt messages). --- mtop-client/src/dns/client.rs | 71 ++++++++++++++++++++++++++++++++++ mtop-client/src/dns/core.rs | 12 ++++++ mtop-client/src/dns/message.rs | 42 +++++++++++++++++++- 3 files changed, 123 insertions(+), 2 deletions(-) diff --git a/mtop-client/src/dns/client.rs b/mtop-client/src/dns/client.rs index 22c879e..c3911cb 100644 --- a/mtop-client/src/dns/client.rs +++ b/mtop-client/src/dns/client.rs @@ -201,6 +201,8 @@ pub struct TcpConnection { read: BufReader>, write: BufWriter>, buffer: Vec, + bytes_read: AtomicUsize, + bytes_written: AtomicUsize, } impl TcpConnection { @@ -213,17 +215,23 @@ impl TcpConnection { read: BufReader::new(Box::new(read)), write: BufWriter::new(Box::new(write)), buffer: Vec::with_capacity(size), + bytes_read: AtomicUsize::new(0), + bytes_written: AtomicUsize::new(0), } } pub async fn exchange(&mut self, msg: &Message) -> Result { // Write the message to a local buffer and then send it, prefixed // with the size of the message. + self.buffer.clear(); msg.write_network_bytes(&mut self.buffer)?; self.write.write_u16(self.buffer.len() as u16).await?; self.write.write_all(&self.buffer).await?; self.write.flush().await?; + // Increment total bytes written including the request size prefix. + self.bytes_written.fetch_add(self.buffer.len() + 2, Ordering::Relaxed); + // Read the prefixed size of the response in big-endian (network) // order and then read exactly that many bytes into our buffer. let sz = self.read.read_u16().await?; @@ -231,6 +239,9 @@ impl TcpConnection { self.buffer.resize(usize::from(sz), 0); self.read.read_exact(&mut self.buffer).await?; + // Increment total bytes read including the response size prefix. + self.bytes_read.fetch_add(self.buffer.len() + 2, Ordering::Relaxed); + let mut cur = Cursor::new(&self.buffer); let res = Message::read_network_bytes(&mut cur)?; if res.id() != msg.id() { @@ -243,6 +254,14 @@ impl TcpConnection { Ok(res) } } + + pub fn bytes_written(&self) -> usize { + self.bytes_written.load(Ordering::Relaxed) + } + + pub fn bytes_read(&self) -> usize { + self.bytes_read.load(Ordering::Relaxed) + } } impl fmt::Debug for TcpConnection { @@ -486,6 +505,58 @@ mod test { assert_eq!(ErrorKind::Runtime, err.kind()); } + #[tokio::test] + async fn test_tcp_client_multiple_messages() { + let write = Vec::new(); + let read = { + let response1 = new_message_bytes(123, true); + let response2 = new_message_bytes(456, true); + let mut bytes = Vec::new(); + bytes.extend(response1); + bytes.extend(response2); + Cursor::new(bytes) + }; + + let mut client = TcpConnection::new(read, write, 512); + + let question = Question::new(Name::from_str("example.com.").unwrap(), RecordType::A); + let message1 = + Message::new(MessageId::from(123), Flags::default().set_recursion_desired()).add_question(question.clone()); + let message2 = + Message::new(MessageId::from(456), Flags::default().set_recursion_desired()).add_question(question.clone()); + + let res1 = client.exchange(&message1).await.unwrap(); + assert_eq!(message1.id(), res1.id()); + assert_eq!(message1.questions()[0], res1.questions()[0]); + assert_eq!( + Record::new( + Name::from_str("example.com.").unwrap(), + RecordType::A, + RecordClass::INET, + 60, + RecordData::A(RecordDataA::new(Ipv4Addr::new(127, 0, 0, 100))), + ), + res1.answers()[0] + ); + + let res2 = client.exchange(&message2).await.unwrap(); + assert_eq!(message2.id(), res2.id()); + assert_eq!(message2.questions()[0], res2.questions()[0]); + assert_eq!( + Record::new( + Name::from_str("example.com.").unwrap(), + RecordType::A, + RecordClass::INET, + 60, + RecordData::A(RecordDataA::new(Ipv4Addr::new(127, 0, 0, 100))), + ), + res2.answers()[0] + ); + + let expected_bytes = message1.size() + message2.size() + 2 + 2; + assert_eq!(expected_bytes, client.bytes_written()); + } + #[tokio::test] async fn test_tcp_client_success() { let write = Vec::new(); diff --git a/mtop-client/src/dns/core.rs b/mtop-client/src/dns/core.rs index a893211..867597f 100644 --- a/mtop-client/src/dns/core.rs +++ b/mtop-client/src/dns/core.rs @@ -17,6 +17,12 @@ pub enum RecordType { Unknown(u16), } +impl RecordType { + pub fn size(&self) -> usize { + 2 + } +} + impl From for RecordType { fn from(value: u16) -> Self { match value { @@ -95,6 +101,12 @@ pub enum RecordClass { Unknown(u16), } +impl RecordClass { + pub fn size(&self) -> usize { + 2 + } +} + impl From for RecordClass { fn from(value: u16) -> Self { match value { diff --git a/mtop-client/src/dns/message.rs b/mtop-client/src/dns/message.rs index d63c7f1..10c61c2 100644 --- a/mtop-client/src/dns/message.rs +++ b/mtop-client/src/dns/message.rs @@ -15,6 +15,10 @@ impl MessageId { pub fn random() -> Self { Self(rand::random()) } + + pub fn size(&self) -> usize { + 2 + } } impl From for MessageId { @@ -57,6 +61,16 @@ impl Message { } } + pub fn size(&self) -> usize { + self.id.size() + + self.flags.size() + + (2 * 4) // lengths of questions, answers, authority, extra + + self.questions.iter().map(|q| q.size()).sum::() + + self.answers.iter().map(|r| r.size()).sum::() + + self.authority.iter().map(|r| r.size()).sum::() + + self.extra.iter().map(|r| r.size()).sum::() + } + pub fn id(&self) -> MessageId { self.id } @@ -246,6 +260,10 @@ impl Flags { const OFFSET_RA: usize = 7; const OFFSET_RC: usize = 0; + pub fn size(&self) -> usize { + 2 + } + pub fn is_query(&self) -> bool { !(self.0 & Self::MASK_QR) > 0 } @@ -445,6 +463,10 @@ impl Question { } } + pub fn size(&self) -> usize { + self.name.size() + self.qtype.size() + self.qclass.size() + } + pub fn set_qclass(mut self, qclass: RecordClass) -> Self { self.qclass = qclass; self @@ -502,6 +524,15 @@ impl Record { } } + pub fn size(&self) -> usize { + self.name.size() + + self.rtype.size() + + self.rclass.size() + + 4 // ttl + + 2 // rdata length + + self.rdata.size() + } + pub fn name(&self) -> &Name { &self.name } @@ -875,6 +906,7 @@ mod test { #[test] fn test_question_write_network_bytes() { let q = Question::new(Name::from_str("example.com.").unwrap(), RecordType::AAAA); + let size = q.size(); let mut cur = Cursor::new(Vec::new()); q.write_network_bytes(&mut cur).unwrap(); let buf = cur.into_inner(); @@ -891,6 +923,7 @@ mod test { ], buf, ); + assert_eq!(size, buf.len()); } #[rustfmt::skip] @@ -906,10 +939,12 @@ mod test { 0, 1, // INET class ]); + let size = cur.get_ref().len(); let q = Question::read_network_bytes(cur).unwrap(); assert_eq!("example.com.", q.name().to_string()); assert_eq!(RecordType::AAAA, q.qtype()); assert_eq!(RecordClass::INET, q.qclass()); + assert_eq!(size, q.size()); } #[rustfmt::skip] @@ -922,6 +957,7 @@ mod test { 300, RecordData::A(RecordDataA::new(Ipv4Addr::new(127, 0, 0, 100))), ); + let size = rr.size(); let mut cur = Cursor::new(Vec::new()); rr.write_network_bytes(&mut cur).unwrap(); let buf = cur.into_inner(); @@ -942,7 +978,8 @@ mod test { 127, 0, 0, 100, // rdata, A address ], buf, - ) + ); + assert_eq!(size, buf.len()); } #[rustfmt::skip] @@ -963,16 +1000,17 @@ mod test { 127, 0, 0, 100, // rdata, A address ]); + let size = cur.get_ref().len(); let rr = Record::read_network_bytes(cur).unwrap(); assert_eq!("www.example.com.", rr.name().to_string()); assert_eq!(RecordType::A, rr.rtype()); assert_eq!(RecordClass::INET, rr.rclass()); assert_eq!(300, rr.ttl()); - if let RecordData::A(rd) = rr.rdata() { assert_eq!(Ipv4Addr::new(127, 0, 0, 100), rd.addr()); } else { panic!("unexpected rdata type: {:?}", rr.rdata()); } + assert_eq!(size, rr.size()); } }