Skip to content

Commit

Permalink
Merge pull request #134 from 56quarters/resolv
Browse files Browse the repository at this point in the history
Read DNS settings from a resolv.conf file
  • Loading branch information
56quarters authored May 17, 2024
2 parents 22ba6b7 + afbfa60 commit 459328e
Show file tree
Hide file tree
Showing 15 changed files with 762 additions and 182 deletions.
13 changes: 3 additions & 10 deletions mtop-client/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,7 @@ impl TryFrom<&HashMap<String, String>> for Slabs {
// $active_slabs + 1.
let mut ids = BTreeSet::new();
for k in value.keys() {
let key_id: Option<u64> = k
.split_once(':')
.map(|(raw, _rest)| raw)
.and_then(|raw| raw.parse().ok());
let key_id: Option<u64> = k.split_once(':').map(|(raw, _rest)| raw).and_then(|raw| raw.parse().ok());

if let Some(id) = key_id {
ids.insert(id);
Expand Down Expand Up @@ -1287,9 +1284,7 @@ mod test {
macro_rules! test_store_command_success {
($method:ident, $verb:expr) => {
let (mut rx, mut client) = client!("STORED\r\n");
let res = client
.$method(&Key::one("test").unwrap(), 0, 300, "val".as_bytes())
.await;
let res = client.$method(&Key::one("test").unwrap(), 0, 300, "val".as_bytes()).await;

assert!(res.is_ok());
let bytes = rx.recv().await.unwrap();
Expand All @@ -1301,9 +1296,7 @@ mod test {
macro_rules! test_store_command_error {
($method:ident, $verb:expr) => {
let (mut rx, mut client) = client!("NOT_STORED\r\n");
let res = client
.$method(&Key::one("test").unwrap(), 0, 300, "val".as_bytes())
.await;
let res = client.$method(&Key::one("test").unwrap(), 0, 300, "val".as_bytes()).await;

assert!(res.is_err());
let err = res.unwrap_err();
Expand Down
149 changes: 86 additions & 63 deletions mtop-client/src/discovery.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::core::MtopError;
use crate::dns::{DnsClient, RecordData};
use crate::dns::{DnsClient, Name, Record, RecordClass, RecordData, RecordType};
use std::cmp::Ordering;
use std::fmt;
use std::net::{IpAddr, SocketAddr};
Expand Down Expand Up @@ -60,47 +60,25 @@ impl AsRef<str> for ServerID {
/// An individual server that is part of a Memcached cluster.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct Server {
repr: ServerRepr,
}

#[derive(Debug, Clone, Eq, PartialEq, Hash)]
enum ServerRepr {
Resolved(ServerID, ServerName<'static>, SocketAddr),
Unresolved(ServerID, ServerName<'static>),
id: ServerID,
name: ServerName<'static>,
}

impl Server {
pub fn from_id(id: ServerID, name: ServerName<'static>) -> Self {
Self {
repr: ServerRepr::Unresolved(id, name),
}
}

pub fn from_addr(addr: SocketAddr, name: ServerName<'static>) -> Self {
Self {
repr: ServerRepr::Resolved(ServerID::from(addr), name, addr),
}
pub fn new(id: ServerID, name: ServerName<'static>) -> Self {
Self { id, name }
}

pub fn id(&self) -> ServerID {
match &self.repr {
ServerRepr::Resolved(id, _, _) => id.clone(),
ServerRepr::Unresolved(id, _) => id.clone(),
}
self.id.clone()
}

pub fn server_name(&self) -> ServerName<'static> {
match &self.repr {
ServerRepr::Resolved(_, name, _) => name.clone(),
ServerRepr::Unresolved(_, name) => name.clone(),
}
self.name.clone()
}

pub fn address(&self) -> String {
match &self.repr {
ServerRepr::Resolved(_, _, addr) => addr.to_string(),
ServerRepr::Unresolved(id, _) => id.to_string(),
}
self.id.to_string()
}
}

Expand Down Expand Up @@ -134,44 +112,56 @@ impl DiscoveryDefault {

pub async fn resolve_by_proto(&self, name: &str) -> Result<Vec<Server>, MtopError> {
if name.starts_with(DNS_A_PREFIX) {
Ok(self.resolve_a(name.trim_start_matches(DNS_A_PREFIX)).await?)
Ok(self.resolve_a_aaaa(name.trim_start_matches(DNS_A_PREFIX)).await?)
} else if name.starts_with(DNS_SRV_PREFIX) {
Ok(self.resolve_srv(name.trim_start_matches(DNS_SRV_PREFIX)).await?)
} else {
Ok(self.resolve_a(name).await?.pop().into_iter().collect())
Ok(self.resolve_a_aaaa(name).await?.pop().into_iter().collect())
}
}

async fn resolve_srv(&self, name: &str) -> Result<Vec<Server>, MtopError> {
let server_name = Self::server_name(name)?;
let (host_name, port) = Self::host_and_port(name)?;
let mut out = Vec::new();
let (host, port) = Self::host_and_port(name)?;
let server_name = Self::server_name(host)?;
let name = host.parse()?;

let res = self.client.resolve_srv(host_name).await?;
for a in res.answers() {
let target = if let RecordData::SRV(srv) = a.rdata() {
srv.target().to_string()
} else {
tracing::warn!(message = "unexpected record data for answer", name = host_name, answer = ?a);
continue;
};
let server_id = ServerID::from((target, port));
let server = Server::from_id(server_id, server_name.clone());
out.push(server);
}
let res = self.client.resolve(name, RecordType::SRV, RecordClass::INET).await?;
Ok(self.servers_from_answers(port, &server_name, res.answers()))
}

async fn resolve_a_aaaa(&self, name: &str) -> Result<Vec<Server>, MtopError> {
let (host, port) = Self::host_and_port(name)?;
let server_name = Self::server_name(host)?;
let name: Name = host.parse()?;

let res = self.client.resolve(name.clone(), RecordType::A, RecordClass::INET).await?;
let mut out = self.servers_from_answers(port, &server_name, res.answers());

let res = self.client.resolve(name, RecordType::AAAA, RecordClass::INET).await?;
out.extend(self.servers_from_answers(port, &server_name, res.answers()));

Ok(out)
}

async fn resolve_a(&self, name: &str) -> Result<Vec<Server>, MtopError> {
let server_name = Self::server_name(name)?;

fn servers_from_answers(&self, port: u16, server_name: &ServerName, answers: &[Record]) -> Vec<Server> {
let mut out = Vec::new();
for addr in tokio::net::lookup_host(name).await? {
out.push(Server::from_addr(addr, server_name.clone()));

for answer in answers {
let server_id = match answer.rdata() {
RecordData::A(data) => ServerID::from(SocketAddr::new(IpAddr::V4(data.addr()), port)),
RecordData::AAAA(data) => ServerID::from(SocketAddr::new(IpAddr::V6(data.addr()), port)),
RecordData::SRV(data) => ServerID::from((data.target().to_string(), port)),
_ => {
tracing::warn!(message = "unexpected record data for answer", answer = ?answer);
continue;
}
};

let server = Server::new(server_id, server_name.to_owned());
out.push(server);
}

Ok(out)
out
}

fn host_and_port(name: &str) -> Result<(&str, u16), MtopError> {
Expand All @@ -185,27 +175,26 @@ impl DiscoveryDefault {
// IPv6 addresses use brackets around them to disambiguate them from a port number.
// Since we're parsing the host and port, strip the brackets because they aren't
// needed or valid to include in a TLS ServerName.
.map(|(hostname, port)| (hostname.trim_start_matches('[').trim_end_matches(']'), port))
.and_then(|(hostname, port)| {
port.parse().map(|p| (hostname, p)).map_err(|e| {
.map(|(host, port)| (host.trim_start_matches('[').trim_end_matches(']'), port))
.and_then(|(host, port)| {
port.parse().map(|p| (host, p)).map_err(|e| {
MtopError::configuration_cause(format!("unable to parse port number from '{}'", name), e)
})
})
}

fn server_name(name: &str) -> Result<ServerName<'static>, MtopError> {
Self::host_and_port(name).and_then(|(host, _)| {
ServerName::try_from(host)
.map(|s| s.to_owned())
.map_err(|e| MtopError::configuration_cause(format!("invalid server name '{}'", host), e))
})
fn server_name(host: &str) -> Result<ServerName<'static>, MtopError> {
ServerName::try_from(host)
.map(|s| s.to_owned())
.map_err(|e| MtopError::configuration_cause(format!("invalid server name '{}'", host), e))
}
}

#[cfg(test)]
mod test {
use super::ServerID;
use super::{Server, ServerID};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use webpki::types::ServerName;

#[test]
fn test_server_id_from_ipv4_addr() {
Expand Down Expand Up @@ -241,4 +230,38 @@ mod test {
let id = ServerID::from(pair);
assert_eq!("cache.example.com:11211", id.to_string());
}

#[test]
fn test_server_resolved_id() {
let addr = SocketAddr::from(([127, 0, 0, 1], 11211));
let id = ServerID::from(addr);
let name = ServerName::try_from("cache.example.com").unwrap();
let server = Server::new(id, name);
assert_eq!("127.0.0.1:11211", server.id().to_string());
}

#[test]
fn test_server_resolved_address() {
let addr = SocketAddr::from(([127, 0, 0, 1], 11211));
let id = ServerID::from(addr);
let name = ServerName::try_from("cache.example.com").unwrap();
let server = Server::new(id, name);
assert_eq!("127.0.0.1:11211", server.address());
}

#[test]
fn test_server_unresolved_id() {
let id = ServerID::from(("cache01.example.com", 11211));
let name = ServerName::try_from("cache.example.com").unwrap();
let server = Server::new(id, name);
assert_eq!("cache01.example.com:11211", server.id().to_string());
}

#[test]
fn test_server_unresolved_address() {
let id = ServerID::from(("cache01.example.com", 11211));
let name = ServerName::try_from("cache.example.com").unwrap();
let server = Server::new(id, name);
assert_eq!("cache01.example.com:11211", server.address());
}
}
93 changes: 74 additions & 19 deletions mtop-client/src/dns/client.rs
Original file line number Diff line number Diff line change
@@ -1,44 +1,79 @@
use crate::core::MtopError;
use crate::dns::core::RecordType;
use crate::dns::core::{RecordClass, RecordType};
use crate::dns::message::{Flags, Message, MessageId, Question};
use crate::dns::name::Name;
use crate::dns::resolv::ResolvConf;
use crate::timeout::Timeout;
use std::io::Cursor;
use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::net::UdpSocket;

const DEFAULT_RECV_BUF: usize = 512;

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct DnsClient {
local: SocketAddr,
server: SocketAddr,
config: ResolvConf,
server: AtomicUsize,
}

impl DnsClient {
pub fn new(local: SocketAddr, server: SocketAddr) -> Self {
Self { local, server }
}

pub async fn exchange(&self, msg: &Message) -> Result<Message, MtopError> {
let id = msg.id();
let sock = self.connect_udp().await?;
self.send_udp(&sock, msg).await?;
self.recv_udp(&sock, id).await
/// Create a new DnsClient that will use a local address to open UDP or TCP
/// connections and behavior based on a resolv.conf configuration file.
pub fn new(local: SocketAddr, config: ResolvConf) -> Self {
Self {
local,
config,
server: AtomicUsize::new(0),
}
}

pub async fn resolve_srv(&self, name: &str) -> Result<Message, MtopError> {
let n = Name::from_str(name)?;
/// Perform a DNS lookup with the configured nameservers.
///
/// Timeouts and network errors will result in up to one additional attempt
/// to perform a DNS lookup when using the default configuration.
pub async fn resolve(&self, name: Name, rtype: RecordType, rclass: RecordClass) -> Result<Message, MtopError> {
let full = name.to_fqdn();
let id = MessageId::random();
let flags = Flags::default().set_recursion_desired();
let msg = Message::new(id, flags).add_question(Question::new(n, RecordType::SRV));
let question = Question::new(full, rtype).set_qclass(rclass);
let message = Message::new(id, flags).add_question(question);

let mut attempt = 0;
loop {
match self.exchange(&message, usize::from(attempt)).await {
Ok(v) => return Ok(v),
Err(e) => {
if attempt + 1 >= self.config.options.attempts {
return Err(e);
}

self.exchange(&msg).await
tracing::debug!(message = "retrying failed query", attempt = attempt + 1, max_attempts = self.config.options.attempts, err = %e);
attempt += 1;
}
}
}
}
async fn exchange(&self, msg: &Message, attempt: usize) -> Result<Message, MtopError> {
let id = msg.id();
let server = self.nameserver(attempt);

async fn connect_udp(&self) -> Result<UdpSocket, MtopError> {
// Wrap creating the socket, sending, and receiving in an async block
// so that we can apply a single timeout to all operations and ensure
// we have access to the nameserver to make the timeout message useful.
async {
let sock = self.connect_udp(server).await?;
self.send_udp(&sock, msg).await?;
self.recv_udp(&sock, id).await
}
.timeout(self.config.options.timeout, format!("client.exchange {}", server))
.await
}

async fn connect_udp(&self, server: SocketAddr) -> Result<UdpSocket, MtopError> {
let sock = UdpSocket::bind(&self.local).await?;
sock.connect(&self.server).await?;
sock.connect(server).await?;
Ok(sock)
}

Expand All @@ -59,4 +94,24 @@ impl DnsClient {
}
}
}

fn nameserver(&self, attempt: usize) -> SocketAddr {
let idx = if self.config.options.rotate {
self.server.fetch_add(1, Ordering::Relaxed)
} else {
attempt
};

self.config.nameservers[idx % self.config.nameservers.len()]
}
}

impl Clone for DnsClient {
fn clone(&self) -> Self {
Self {
local: self.local,
config: self.config.clone(),
server: AtomicUsize::new(0),
}
}
}
Loading

0 comments on commit 459328e

Please sign in to comment.