diff --git a/README.md b/README.md index 4a172cd..a7d08fe 100644 --- a/README.md +++ b/README.md @@ -223,5 +223,5 @@ Steps for releasing new versions of `mtop` are described below. * Update local `master` from Github remote. Make sure to build once with updated versions to update `Cargo.lock`. * Create but do not push a tag of the format `v1.2.3` * Run `cargo package` and `cargo publish` for the `mtop-client` crate. -* Run `cargo pacakge` and `cargo publish` for the `mtop` crate. +* Run `cargo package` and `cargo publish` for the `mtop` crate. * Push tags to all remotes `git push --tags origin`, `git push --tags github` diff --git a/mtop-client/src/discovery.rs b/mtop-client/src/discovery.rs index 7c7d768..6bcd702 100644 --- a/mtop-client/src/discovery.rs +++ b/mtop-client/src/discovery.rs @@ -1,5 +1,5 @@ use crate::core::MtopError; -use crate::dns::{DnsClient, Name, Record, RecordClass, RecordData, RecordType}; +use crate::dns::{DnsClient, Message, Name, Record, RecordClass, RecordData, RecordType}; use std::cmp::Ordering; use std::fmt; use std::net::{IpAddr, SocketAddr}; @@ -100,6 +100,16 @@ impl fmt::Display for Server { } } +trait AsyncDnsClient { + async fn resolve(&self, name: Name, rtype: RecordType, rclass: RecordClass) -> Result; +} + +impl AsyncDnsClient for &DnsClient { + async fn resolve(&self, name: Name, rtype: RecordType, rclass: RecordClass) -> Result { + DnsClient::resolve(self, name, rtype, rclass).await + } +} + #[derive(Debug, Clone)] pub struct DiscoveryDefault { client: DnsClient, @@ -112,46 +122,52 @@ impl DiscoveryDefault { pub async fn resolve_by_proto(&self, name: &str) -> Result, MtopError> { if name.starts_with(DNS_A_PREFIX) { - Ok(self.resolve_a_aaaa(name.trim_start_matches(DNS_A_PREFIX)).await?) + Ok(Self::resolve_a_aaaa(&self.client, name.trim_start_matches(DNS_A_PREFIX)).await?) } else if name.starts_with(DNS_SRV_PREFIX) { - Ok(self.resolve_srv(name.trim_start_matches(DNS_SRV_PREFIX)).await?) + Ok(Self::resolve_srv(&self.client, name.trim_start_matches(DNS_SRV_PREFIX)).await?) } else if let Ok(addr) = name.parse::() { - Ok(self.resolv_socket_addr(name, addr)?) + Ok(Self::resolv_socket_addr(name, addr)?) } else { - Ok(self.resolve_a_aaaa(name).await?.pop().into_iter().collect()) + Ok(Self::resolve_a_aaaa(&self.client, name).await?.pop().into_iter().collect()) } } - async fn resolve_srv(&self, name: &str) -> Result, MtopError> { + async fn resolve_srv(client: C, name: &str) -> Result, MtopError> + where + C: AsyncDnsClient, + { let (host, port) = Self::host_and_port(name)?; let server_name = Self::server_name(host)?; let name = host.parse()?; - let res = self.client.resolve(name, RecordType::SRV, RecordClass::INET).await?; - Ok(self.servers_from_answers(port, &server_name, res.answers())) + let res = client.resolve(name, RecordType::SRV, RecordClass::INET).await?; + Ok(Self::servers_from_answers(port, &server_name, res.answers())) } - async fn resolve_a_aaaa(&self, name: &str) -> Result, MtopError> { + async fn resolve_a_aaaa(client: C, name: &str) -> Result, MtopError> + where + C: AsyncDnsClient, + { let (host, port) = Self::host_and_port(name)?; let server_name = Self::server_name(host)?; let name: Name = host.parse()?; - let res = self.client.resolve(name.clone(), RecordType::A, RecordClass::INET).await?; - let mut out = self.servers_from_answers(port, &server_name, res.answers()); + let res = client.resolve(name.clone(), RecordType::A, RecordClass::INET).await?; + let mut out = Self::servers_from_answers(port, &server_name, res.answers()); - let res = self.client.resolve(name, RecordType::AAAA, RecordClass::INET).await?; - out.extend(self.servers_from_answers(port, &server_name, res.answers())); + let res = client.resolve(name, RecordType::AAAA, RecordClass::INET).await?; + out.extend(Self::servers_from_answers(port, &server_name, res.answers())); Ok(out) } - fn resolv_socket_addr(&self, name: &str, addr: SocketAddr) -> Result, MtopError> { + fn resolv_socket_addr(name: &str, addr: SocketAddr) -> Result, MtopError> { let (host, _port) = Self::host_and_port(name)?; let server_name = Self::server_name(host)?; Ok(vec![Server::new(ServerID::from(addr), server_name)]) } - fn servers_from_answers(&self, port: u16, server_name: &ServerName, answers: &[Record]) -> Vec { + fn servers_from_answers(port: u16, server_name: &ServerName, answers: &[Record]) -> Vec { let mut out = Vec::new(); for answer in answers { @@ -200,8 +216,15 @@ impl DiscoveryDefault { #[cfg(test)] mod test { - use super::{Server, ServerID}; + use super::{AsyncDnsClient, DiscoveryDefault, Server, ServerID}; + use crate::core::MtopError; + use crate::dns::{ + Flags, Message, MessageId, Name, Question, Record, RecordClass, RecordData, RecordDataA, RecordDataAAAA, + RecordDataSRV, RecordType, + }; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; + use std::str::FromStr; + use tokio::sync::Mutex; use webpki::types::ServerName; #[test] @@ -272,4 +295,127 @@ mod test { let server = Server::new(id, name); assert_eq!("cache01.example.com:11211", server.address()); } + + struct MockDnsClient { + responses: Mutex>, + } + + impl MockDnsClient { + fn new(responses: Vec) -> Self { + Self { + responses: Mutex::new(responses), + } + } + } + + impl AsyncDnsClient for &MockDnsClient { + async fn resolve(&self, _name: Name, _rtype: RecordType, _rclass: RecordClass) -> Result { + let mut responses = self.responses.lock().await; + let res = responses.pop().unwrap(); + Ok(res) + } + } + + fn response_with_answers(rtype: RecordType, records: Vec) -> Message { + let flags = Flags::default().set_recursion_desired().set_recursion_available(); + let mut message = Message::new(MessageId::random(), flags) + .add_question(Question::new(Name::from_str("example.com.").unwrap(), rtype)); + + for r in records { + message = message.add_answer(r); + } + + message + } + + #[tokio::test] + async fn test_dns_client_resolve_a_aaaa() { + let response_a = response_with_answers( + RecordType::A, + vec![Record::new( + Name::from_str("example.com.").unwrap(), + RecordType::A, + RecordClass::INET, + 300, + RecordData::A(RecordDataA::new(Ipv4Addr::new(10, 1, 1, 1))), + )], + ); + + let response_aaaa = response_with_answers( + RecordType::AAAA, + vec![Record::new( + Name::from_str("example.com.").unwrap(), + RecordType::AAAA, + RecordClass::INET, + 300, + RecordData::AAAA(RecordDataAAAA::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))), + )], + ); + + let client = MockDnsClient::new(vec![response_a, response_aaaa]); + let servers = DiscoveryDefault::resolve_a_aaaa(&client, "example.com:11211").await.unwrap(); + let ids = servers.iter().map(|s| s.id()).collect::>(); + + let id_a = ServerID::from(("10.1.1.1", 11211)); + let id_aaaa = ServerID::from(("[::1]", 11211)); + + assert!(ids.contains(&id_a), "expected {:?} to contain {:?}", ids, id_a); + assert!(ids.contains(&id_aaaa), "expected {:?} to contain {:?}", ids, id_aaaa); + } + + #[tokio::test] + async fn test_dns_client_resolve_srv() { + let response = response_with_answers( + RecordType::SRV, + vec![ + Record::new( + Name::from_str("_cache.example.com.").unwrap(), + RecordType::SRV, + RecordClass::INET, + 300, + RecordData::SRV(RecordDataSRV::new( + 100, + 10, + 11211, + Name::from_str("cache01.example.com.").unwrap(), + )), + ), + Record::new( + Name::from_str("_cache.example.com.").unwrap(), + RecordType::SRV, + RecordClass::INET, + 300, + RecordData::SRV(RecordDataSRV::new( + 100, + 10, + 11211, + Name::from_str("cache02.example.com.").unwrap(), + )), + ), + ], + ); + + let client = MockDnsClient::new(vec![response]); + let servers = DiscoveryDefault::resolve_srv(&client, "_cache.example.com:11211") + .await + .unwrap(); + let ids = servers.iter().map(|s| s.id()).collect::>(); + + let id1 = ServerID::from(("cache01.example.com.", 11211)); + let id2 = ServerID::from(("cache02.example.com.", 11211)); + + assert!(ids.contains(&id1), "expected {:?} to contain {:?}", ids, id1); + assert!(ids.contains(&id2), "expected {:?} to contain {:?}", ids, id2); + } + + #[test] + fn test_dns_client_resolve_socket_addr() { + let name = "127.0.0.2:11211"; + let addr = "127.0.0.2:11211".parse().unwrap(); + let servers = DiscoveryDefault::resolv_socket_addr(name, addr).unwrap(); + let ids = servers.iter().map(|s| s.id()).collect::>(); + + let id = ServerID::from(addr); + assert!(ids.contains(&id), "expected {:?} to contain {:?}", ids, id); + } } diff --git a/mtop-client/src/dns/rdata.rs b/mtop-client/src/dns/rdata.rs index 08cda5c..86cef11 100644 --- a/mtop-client/src/dns/rdata.rs +++ b/mtop-client/src/dns/rdata.rs @@ -897,7 +897,7 @@ mod test { // to store the length of the segment. In reality, we need 256 bytes for each segment // so having 256 bytes * 256 segments should be an error. let segment = "a".repeat(255); - let data: Vec = (0..256).into_iter().map(|_| segment.clone()).collect(); + let data: Vec = (0..256).map(|_| segment.clone()).collect(); let res = RecordDataTXT::new(data); assert!(res.is_err()); @@ -906,7 +906,7 @@ mod test { #[test] fn test_record_data_txt_new_success() { let segment = "a".repeat(255); - let data: Vec = (0..255).into_iter().map(|_| segment.clone()).collect(); + let data: Vec = (0..255).map(|_| segment.clone()).collect(); let txt = RecordDataTXT::new(data).unwrap(); assert_eq!(65280, txt.size());