Skip to content

Commit

Permalink
Merge pull request #146 from 56quarters/more-tests
Browse files Browse the repository at this point in the history
Add unit tests for some discovery logic
  • Loading branch information
56quarters authored Jun 22, 2024
2 parents 11dadb6 + aa8435e commit 10c6e77
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 19 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -223,5 +223,5 @@ Steps for releasing new versions of `mtop` are described below.
* Update local `master` from Github remote. Make sure to build once with updated versions to update `Cargo.lock`.
* Create but do not push a tag of the format `v1.2.3`
* Run `cargo package` and `cargo publish` for the `mtop-client` crate.
* Run `cargo pacakge` and `cargo publish` for the `mtop` crate.
* Run `cargo package` and `cargo publish` for the `mtop` crate.
* Push tags to all remotes `git push --tags origin`, `git push --tags github`
178 changes: 162 additions & 16 deletions mtop-client/src/discovery.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::core::MtopError;
use crate::dns::{DnsClient, Name, Record, RecordClass, RecordData, RecordType};
use crate::dns::{DnsClient, Message, Name, Record, RecordClass, RecordData, RecordType};
use std::cmp::Ordering;
use std::fmt;
use std::net::{IpAddr, SocketAddr};
Expand Down Expand Up @@ -100,6 +100,16 @@ impl fmt::Display for Server {
}
}

trait AsyncDnsClient {
async fn resolve(&self, name: Name, rtype: RecordType, rclass: RecordClass) -> Result<Message, MtopError>;
}

impl AsyncDnsClient for &DnsClient {
async fn resolve(&self, name: Name, rtype: RecordType, rclass: RecordClass) -> Result<Message, MtopError> {
DnsClient::resolve(self, name, rtype, rclass).await
}
}

#[derive(Debug, Clone)]
pub struct DiscoveryDefault {
client: DnsClient,
Expand All @@ -112,46 +122,52 @@ impl DiscoveryDefault {

pub async fn resolve_by_proto(&self, name: &str) -> Result<Vec<Server>, MtopError> {
if name.starts_with(DNS_A_PREFIX) {
Ok(self.resolve_a_aaaa(name.trim_start_matches(DNS_A_PREFIX)).await?)
Ok(Self::resolve_a_aaaa(&self.client, name.trim_start_matches(DNS_A_PREFIX)).await?)
} else if name.starts_with(DNS_SRV_PREFIX) {
Ok(self.resolve_srv(name.trim_start_matches(DNS_SRV_PREFIX)).await?)
Ok(Self::resolve_srv(&self.client, name.trim_start_matches(DNS_SRV_PREFIX)).await?)
} else if let Ok(addr) = name.parse::<SocketAddr>() {
Ok(self.resolv_socket_addr(name, addr)?)
Ok(Self::resolv_socket_addr(name, addr)?)
} else {
Ok(self.resolve_a_aaaa(name).await?.pop().into_iter().collect())
Ok(Self::resolve_a_aaaa(&self.client, name).await?.pop().into_iter().collect())
}
}

async fn resolve_srv(&self, name: &str) -> Result<Vec<Server>, MtopError> {
async fn resolve_srv<C>(client: C, name: &str) -> Result<Vec<Server>, MtopError>
where
C: AsyncDnsClient,
{
let (host, port) = Self::host_and_port(name)?;
let server_name = Self::server_name(host)?;
let name = host.parse()?;

let res = self.client.resolve(name, RecordType::SRV, RecordClass::INET).await?;
Ok(self.servers_from_answers(port, &server_name, res.answers()))
let res = client.resolve(name, RecordType::SRV, RecordClass::INET).await?;
Ok(Self::servers_from_answers(port, &server_name, res.answers()))
}

async fn resolve_a_aaaa(&self, name: &str) -> Result<Vec<Server>, MtopError> {
async fn resolve_a_aaaa<C>(client: C, name: &str) -> Result<Vec<Server>, MtopError>
where
C: AsyncDnsClient,
{
let (host, port) = Self::host_and_port(name)?;
let server_name = Self::server_name(host)?;
let name: Name = host.parse()?;

let res = self.client.resolve(name.clone(), RecordType::A, RecordClass::INET).await?;
let mut out = self.servers_from_answers(port, &server_name, res.answers());
let res = client.resolve(name.clone(), RecordType::A, RecordClass::INET).await?;
let mut out = Self::servers_from_answers(port, &server_name, res.answers());

let res = self.client.resolve(name, RecordType::AAAA, RecordClass::INET).await?;
out.extend(self.servers_from_answers(port, &server_name, res.answers()));
let res = client.resolve(name, RecordType::AAAA, RecordClass::INET).await?;
out.extend(Self::servers_from_answers(port, &server_name, res.answers()));

Ok(out)
}

fn resolv_socket_addr(&self, name: &str, addr: SocketAddr) -> Result<Vec<Server>, MtopError> {
fn resolv_socket_addr(name: &str, addr: SocketAddr) -> Result<Vec<Server>, MtopError> {
let (host, _port) = Self::host_and_port(name)?;
let server_name = Self::server_name(host)?;
Ok(vec![Server::new(ServerID::from(addr), server_name)])
}

fn servers_from_answers(&self, port: u16, server_name: &ServerName, answers: &[Record]) -> Vec<Server> {
fn servers_from_answers(port: u16, server_name: &ServerName, answers: &[Record]) -> Vec<Server> {
let mut out = Vec::new();

for answer in answers {
Expand Down Expand Up @@ -200,8 +216,15 @@ impl DiscoveryDefault {

#[cfg(test)]
mod test {
use super::{Server, ServerID};
use super::{AsyncDnsClient, DiscoveryDefault, Server, ServerID};
use crate::core::MtopError;
use crate::dns::{
Flags, Message, MessageId, Name, Question, Record, RecordClass, RecordData, RecordDataA, RecordDataAAAA,
RecordDataSRV, RecordType,
};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use std::str::FromStr;
use tokio::sync::Mutex;
use webpki::types::ServerName;

#[test]
Expand Down Expand Up @@ -272,4 +295,127 @@ mod test {
let server = Server::new(id, name);
assert_eq!("cache01.example.com:11211", server.address());
}

struct MockDnsClient {
responses: Mutex<Vec<Message>>,
}

impl MockDnsClient {
fn new(responses: Vec<Message>) -> Self {
Self {
responses: Mutex::new(responses),
}
}
}

impl AsyncDnsClient for &MockDnsClient {
async fn resolve(&self, _name: Name, _rtype: RecordType, _rclass: RecordClass) -> Result<Message, MtopError> {
let mut responses = self.responses.lock().await;
let res = responses.pop().unwrap();
Ok(res)
}
}

fn response_with_answers(rtype: RecordType, records: Vec<Record>) -> Message {
let flags = Flags::default().set_recursion_desired().set_recursion_available();
let mut message = Message::new(MessageId::random(), flags)
.add_question(Question::new(Name::from_str("example.com.").unwrap(), rtype));

for r in records {
message = message.add_answer(r);
}

message
}

#[tokio::test]
async fn test_dns_client_resolve_a_aaaa() {
let response_a = response_with_answers(
RecordType::A,
vec![Record::new(
Name::from_str("example.com.").unwrap(),
RecordType::A,
RecordClass::INET,
300,
RecordData::A(RecordDataA::new(Ipv4Addr::new(10, 1, 1, 1))),
)],
);

let response_aaaa = response_with_answers(
RecordType::AAAA,
vec![Record::new(
Name::from_str("example.com.").unwrap(),
RecordType::AAAA,
RecordClass::INET,
300,
RecordData::AAAA(RecordDataAAAA::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))),
)],
);

let client = MockDnsClient::new(vec![response_a, response_aaaa]);
let servers = DiscoveryDefault::resolve_a_aaaa(&client, "example.com:11211").await.unwrap();
let ids = servers.iter().map(|s| s.id()).collect::<Vec<_>>();

let id_a = ServerID::from(("10.1.1.1", 11211));
let id_aaaa = ServerID::from(("[::1]", 11211));

assert!(ids.contains(&id_a), "expected {:?} to contain {:?}", ids, id_a);
assert!(ids.contains(&id_aaaa), "expected {:?} to contain {:?}", ids, id_aaaa);
}

#[tokio::test]
async fn test_dns_client_resolve_srv() {
let response = response_with_answers(
RecordType::SRV,
vec![
Record::new(
Name::from_str("_cache.example.com.").unwrap(),
RecordType::SRV,
RecordClass::INET,
300,
RecordData::SRV(RecordDataSRV::new(
100,
10,
11211,
Name::from_str("cache01.example.com.").unwrap(),
)),
),
Record::new(
Name::from_str("_cache.example.com.").unwrap(),
RecordType::SRV,
RecordClass::INET,
300,
RecordData::SRV(RecordDataSRV::new(
100,
10,
11211,
Name::from_str("cache02.example.com.").unwrap(),
)),
),
],
);

let client = MockDnsClient::new(vec![response]);
let servers = DiscoveryDefault::resolve_srv(&client, "_cache.example.com:11211")
.await
.unwrap();
let ids = servers.iter().map(|s| s.id()).collect::<Vec<_>>();

let id1 = ServerID::from(("cache01.example.com.", 11211));
let id2 = ServerID::from(("cache02.example.com.", 11211));

assert!(ids.contains(&id1), "expected {:?} to contain {:?}", ids, id1);
assert!(ids.contains(&id2), "expected {:?} to contain {:?}", ids, id2);
}

#[test]
fn test_dns_client_resolve_socket_addr() {
let name = "127.0.0.2:11211";
let addr = "127.0.0.2:11211".parse().unwrap();
let servers = DiscoveryDefault::resolv_socket_addr(name, addr).unwrap();
let ids = servers.iter().map(|s| s.id()).collect::<Vec<_>>();

let id = ServerID::from(addr);
assert!(ids.contains(&id), "expected {:?} to contain {:?}", ids, id);
}
}
4 changes: 2 additions & 2 deletions mtop-client/src/dns/rdata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,7 @@ mod test {
// to store the length of the segment. In reality, we need 256 bytes for each segment
// so having 256 bytes * 256 segments should be an error.
let segment = "a".repeat(255);
let data: Vec<String> = (0..256).into_iter().map(|_| segment.clone()).collect();
let data: Vec<String> = (0..256).map(|_| segment.clone()).collect();
let res = RecordDataTXT::new(data);

assert!(res.is_err());
Expand All @@ -906,7 +906,7 @@ mod test {
#[test]
fn test_record_data_txt_new_success() {
let segment = "a".repeat(255);
let data: Vec<String> = (0..255).into_iter().map(|_| segment.clone()).collect();
let data: Vec<String> = (0..255).map(|_| segment.clone()).collect();
let txt = RecordDataTXT::new(data).unwrap();

assert_eq!(65280, txt.size());
Expand Down

0 comments on commit 10c6e77

Please sign in to comment.