Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deduplicate servers by ID when doing service discovery resolution #185

Merged
merged 1 commit into from
Aug 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 58 additions & 10 deletions mtop-client/src/discovery.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::core::MtopError;
use crate::dns::{DefaultDnsClient, DnsClient, Name, Record, RecordClass, RecordData, RecordType};
use crate::dns::{DefaultDnsClient, DnsClient, Message, Name, RecordClass, RecordData, RecordType};
use rustls_pki_types::ServerName;
use std::cmp::Ordering;
use std::collections::HashSet;
use std::fmt;
use std::net::{IpAddr, SocketAddr};

Expand Down Expand Up @@ -152,7 +153,7 @@ where
let name = host.parse()?;

let res = self.client.resolve(name, RecordType::SRV, RecordClass::INET).await?;
Ok(Self::servers_from_answers(port, &server_name, res.answers()))
Ok(Self::servers_from_answers(port, &server_name, &res))
}

async fn resolve_a_aaaa(&self, name: &str) -> Result<Vec<Server>, MtopError> {
Expand All @@ -161,10 +162,10 @@ where
let name: Name = host.parse()?;

let res = self.client.resolve(name.clone(), RecordType::A, RecordClass::INET).await?;
let mut out = Self::servers_from_answers(port, &server_name, res.answers());
let mut out = Self::servers_from_answers(port, &server_name, &res);

let res = self.client.resolve(name, RecordType::AAAA, RecordClass::INET).await?;
out.extend(Self::servers_from_answers(port, &server_name, res.answers()));
out.extend(Self::servers_from_answers(port, &server_name, &res));

Ok(out)
}
Expand All @@ -181,10 +182,10 @@ where
Ok(vec![Server::new(ServerID::from((host, port)), server_name)])
}

fn servers_from_answers(port: u16, server_name: &ServerName, answers: &[Record]) -> Vec<Server> {
let mut out = Vec::new();
fn servers_from_answers(port: u16, server_name: &ServerName, message: &Message) -> Vec<Server> {
let mut ids = HashSet::new();

for answer in answers {
for answer in message.answers() {
let server_id = match answer.rdata() {
RecordData::A(data) => ServerID::from(SocketAddr::new(IpAddr::V4(data.addr()), port)),
RecordData::AAAA(data) => ServerID::from(SocketAddr::new(IpAddr::V6(data.addr()), port)),
Expand All @@ -195,11 +196,14 @@ where
}
};

let server = Server::new(server_id, server_name.to_owned());
out.push(server);
// Insert IDs into a HashSet to deduplicate them. We can potentially end up with duplicates
// when a SRV query returns multiple answers per hostname (such as when each host has more
// than a single port). Because we ignore the port number from the SRV answer we need to
// deduplicate here.
ids.insert(server_id);
}

out
ids.into_iter().map(|id| Server::new(id, server_name.to_owned())).collect()
}

fn host_and_port(name: &str) -> Result<(&str, u16), MtopError> {
Expand Down Expand Up @@ -422,6 +426,48 @@ mod test {
assert!(ids.contains(&id2), "expected {:?} to contain {:?}", ids, id2);
}

#[tokio::test]
async fn test_dns_client_resolve_srv_dupes() {
let response = response_with_answers(
RecordType::SRV,
vec![
Record::new(
Name::from_str("_cache.example.com.").unwrap(),
RecordType::SRV,
RecordClass::INET,
300,
RecordData::SRV(RecordDataSRV::new(
100,
10,
11211,
Name::from_str("cache01.example.com.").unwrap(),
)),
),
Record::new(
Name::from_str("_cache.example.com.").unwrap(),
RecordType::SRV,
RecordClass::INET,
300,
RecordData::SRV(RecordDataSRV::new(
100,
10,
9105,
Name::from_str("cache01.example.com.").unwrap(),
)),
),
],
);

let client = MockDnsClient::new(vec![response]);
let discovery = Discovery::new(client);
let servers = discovery.resolve_by_proto("dnssrv+_cache.example.com:11211").await.unwrap();
let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();

let id = ServerID::from(("cache01.example.com.", 11211));

assert_eq!(ids, vec![id]);
}

#[tokio::test]
async fn test_dns_client_resolve_socket_addr() {
let name = "127.0.0.2:11211";
Expand All @@ -433,6 +479,7 @@ mod test {
let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();

let id = ServerID::from(addr);

assert!(ids.contains(&id), "expected {:?} to contain {:?}", ids, id);
}

Expand All @@ -446,6 +493,7 @@ mod test {
let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();

let id = ServerID::from(("localhost", 11211));

assert!(ids.contains(&id), "expected {:?} to contain {:?}", ids, id);
}
}