Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make DNS client used by server discovery generic #174

Merged
merged 1 commit into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mtop-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ pub trait Selector {
///
/// See https://en.wikipedia.org/wiki/Rendezvous_hashing
#[derive(Debug)]
pub struct SelectorRendezvous {
pub struct RendezvousSelector {
servers: Vec<Server>,
}

impl SelectorRendezvous {
impl RendezvousSelector {
/// Create a new instance with the provided initial server list.
pub fn new(servers: Vec<Server>) -> Self {
Self { servers }
Expand All @@ -82,7 +82,7 @@ impl SelectorRendezvous {
}
}

impl Selector for SelectorRendezvous {
impl Selector for RendezvousSelector {
async fn servers(&self) -> Vec<Server> {
self.servers.clone()
}
Expand Down Expand Up @@ -223,7 +223,7 @@ impl Default for MemcachedClientConfig {
/// Memcached client that operates on multiple servers, pooling connections
/// to them, and sharding keys via a `Selector` implementation.
#[derive(Debug)]
pub struct MemcachedClient<S = SelectorRendezvous, F = TcpFactory>
pub struct MemcachedClient<S = RendezvousSelector, F = TcpFactory>
where
S: Selector + Send + Sync + 'static,
F: ClientFactory<Server, Memcached> + Send + Sync + 'static,
Expand Down
80 changes: 31 additions & 49 deletions mtop-client/src/discovery.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::core::MtopError;
use crate::dns::{DnsClient, Message, Name, Record, RecordClass, RecordData, RecordType};
use crate::dns::{DefaultDnsClient, DnsClient, Name, Record, RecordClass, RecordData, RecordType};
use rustls_pki_types::ServerName;
use std::cmp::Ordering;
use std::fmt;
Expand Down Expand Up @@ -100,28 +100,23 @@ impl fmt::Display for Server {
}
}

/// Trait to represent our DNS client for easier testing
trait AsyncDnsClient {
async fn resolve(&self, name: Name, rtype: RecordType, rclass: RecordClass) -> Result<Message, MtopError>;
}

impl AsyncDnsClient for &DnsClient {
async fn resolve(&self, name: Name, rtype: RecordType, rclass: RecordClass) -> Result<Message, MtopError> {
DnsClient::resolve(self, name, rtype, rclass).await
}
}

/// Service discovery implementation for finding Memcached servers using DNS.
///
/// Different types of DNS records and different behaviors are used based on the
/// presence of specific prefixes for hostnames. See `resolve_by_proto` for details.
#[derive(Debug)]
pub struct DiscoveryDefault {
client: DnsClient,
pub struct Discovery<C = DefaultDnsClient>
where
C: DnsClient + Send + Sync + 'static,
{
client: C,
}

impl DiscoveryDefault {
pub fn new(client: DnsClient) -> Self {
impl<C> Discovery<C>
where
C: DnsClient + Send + Sync + 'static,
{
pub fn new(client: C) -> Self {
Self { client }
}

Expand All @@ -139,49 +134,36 @@ impl DiscoveryDefault {
/// * No prefix with a non-IP address will use the host as a Memcached server.
/// Resolution of the host to an IP address will happen at connection time using the
/// system resolver.
pub async fn resolve(&self, name: &str) -> Result<Vec<Server>, MtopError> {
Self::resolve_by_proto(&self.client, name).await
}

async fn resolve_by_proto<C>(client: C, name: &str) -> Result<Vec<Server>, MtopError>
where
C: AsyncDnsClient,
{
pub async fn resolve_by_proto(&self, name: &str) -> Result<Vec<Server>, MtopError> {
if name.starts_with(DNS_A_PREFIX) {
Ok(Self::resolve_a_aaaa(client, name.trim_start_matches(DNS_A_PREFIX)).await?)
Ok(self.resolve_a_aaaa(name.trim_start_matches(DNS_A_PREFIX)).await?)
} else if name.starts_with(DNS_SRV_PREFIX) {
Ok(Self::resolve_srv(client, name.trim_start_matches(DNS_SRV_PREFIX)).await?)
Ok(self.resolve_srv(name.trim_start_matches(DNS_SRV_PREFIX)).await?)
} else if let Ok(addr) = name.parse::<SocketAddr>() {
Ok(Self::resolv_socket_addr(name, addr)?)
} else {
Ok(Self::resolv_bare_host(name)?)
}
}

async fn resolve_srv<C>(client: C, name: &str) -> Result<Vec<Server>, MtopError>
where
C: AsyncDnsClient,
{
async fn resolve_srv(&self, name: &str) -> Result<Vec<Server>, MtopError> {
let (host, port) = Self::host_and_port(name)?;
let server_name = Self::server_name(host)?;
let name = host.parse()?;

let res = client.resolve(name, RecordType::SRV, RecordClass::INET).await?;
let res = self.client.resolve(name, RecordType::SRV, RecordClass::INET).await?;
Ok(Self::servers_from_answers(port, &server_name, res.answers()))
}

async fn resolve_a_aaaa<C>(client: C, name: &str) -> Result<Vec<Server>, MtopError>
where
C: AsyncDnsClient,
{
async fn resolve_a_aaaa(&self, name: &str) -> Result<Vec<Server>, MtopError> {
let (host, port) = Self::host_and_port(name)?;
let server_name = Self::server_name(host)?;
let name: Name = host.parse()?;

let res = client.resolve(name.clone(), RecordType::A, RecordClass::INET).await?;
let res = self.client.resolve(name.clone(), RecordType::A, RecordClass::INET).await?;
let mut out = Self::servers_from_answers(port, &server_name, res.answers());

let res = client.resolve(name, RecordType::AAAA, RecordClass::INET).await?;
let res = self.client.resolve(name, RecordType::AAAA, RecordClass::INET).await?;
out.extend(Self::servers_from_answers(port, &server_name, res.answers()));

Ok(out)
Expand Down Expand Up @@ -248,11 +230,11 @@ impl DiscoveryDefault {

#[cfg(test)]
mod test {
use super::{AsyncDnsClient, DiscoveryDefault, Server, ServerID};
use super::{Discovery, Server, ServerID};
use crate::core::MtopError;
use crate::dns::{
Flags, Message, MessageId, Name, Question, Record, RecordClass, RecordData, RecordDataA, RecordDataAAAA,
RecordDataSRV, RecordType,
DnsClient, Flags, Message, MessageId, Name, Question, Record, RecordClass, RecordData, RecordDataA,
RecordDataAAAA, RecordDataSRV, RecordType,
};
use rustls_pki_types::ServerName;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
Expand Down Expand Up @@ -340,7 +322,7 @@ mod test {
}
}

impl AsyncDnsClient for &MockDnsClient {
impl DnsClient for MockDnsClient {
async fn resolve(&self, _name: Name, _rtype: RecordType, _rclass: RecordClass) -> Result<Message, MtopError> {
let mut responses = self.responses.lock().await;
let res = responses.pop().unwrap();
Expand Down Expand Up @@ -385,9 +367,8 @@ mod test {
);

let client = MockDnsClient::new(vec![response_a, response_aaaa]);
let servers = DiscoveryDefault::resolve_by_proto(&client, "dns+example.com:11211")
.await
.unwrap();
let discovery = Discovery::new(client);
let servers = discovery.resolve_by_proto("dns+example.com:11211").await.unwrap();
let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();

let id_a = ServerID::from(("10.1.1.1", 11211));
Expand Down Expand Up @@ -430,9 +411,8 @@ mod test {
);

let client = MockDnsClient::new(vec![response]);
let servers = DiscoveryDefault::resolve_by_proto(&client, "dnssrv+_cache.example.com:11211")
.await
.unwrap();
let discovery = Discovery::new(client);
let servers = discovery.resolve_by_proto("dnssrv+_cache.example.com:11211").await.unwrap();
let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();

let id1 = ServerID::from(("cache01.example.com.", 11211));
Expand All @@ -448,7 +428,8 @@ mod test {
let addr: SocketAddr = "127.0.0.2:11211".parse().unwrap();

let client = MockDnsClient::new(vec![]);
let servers = DiscoveryDefault::resolve_by_proto(&client, name).await.unwrap();
let discovery = Discovery::new(client);
let servers = discovery.resolve_by_proto(name).await.unwrap();
let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();

let id = ServerID::from(addr);
Expand All @@ -460,7 +441,8 @@ mod test {
let name = "localhost:11211";

let client = MockDnsClient::new(vec![]);
let servers = DiscoveryDefault::resolve_by_proto(&client, name).await.unwrap();
let discovery = Discovery::new(client);
let servers = discovery.resolve_by_proto(name).await.unwrap();
let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();

let id = ServerID::from(("localhost", 11211));
Expand Down
81 changes: 51 additions & 30 deletions mtop-client/src/dns/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::net::tcp_connect;
use crate::pool::{ClientFactory, ClientPool, ClientPoolConfig};
use crate::timeout::Timeout;
use std::fmt::{self, Formatter};
use std::future::Future;
use std::io::{self, Cursor, Error};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::pin::Pin;
Expand Down Expand Up @@ -56,15 +57,38 @@ impl Default for DnsClientConfig {
}
}

/// Client for performing DNS queries and returning the results.
///
/// There is currently only a single non-test implementation because this
/// trait exists to make testing consumers easier.
pub trait DnsClient {
fn resolve(
&self,
name: Name,
rtype: RecordType,
rclass: RecordClass,
) -> impl Future<Output = Result<Message, MtopError>>;
}

/// Implementation of a `DnsClient` that uses UDP with TCP fallback.
///
/// Supports nameserver rotation, retries, timeouts, and pooling of client
/// connections. Names are assumed to already be fully qualified, meaning
/// that they are not combined with a search domain.
///
/// Timeouts are handled by the client itself and so callers should _not_
/// add a timeout on the `resolve` method. Note that timeouts are per-network
/// operation. This means that a single call to `resolve` make take longer
/// than the timeout since failed network operations are retried.
#[derive(Debug)]
pub struct DnsClient {
pub struct DefaultDnsClient {
config: DnsClientConfig,
server_idx: AtomicUsize,
udp_pool: ClientPool<SocketAddr, UdpClient, UdpFactory>,
tcp_pool: ClientPool<SocketAddr, TcpClient, TcpFactory>,
}

impl DnsClient {
impl DefaultDnsClient {
/// Create a new DnsClient that will resolve names using UDP or TCP connections
/// and behavior based on a resolv.conf configuration file.
pub fn new(config: DnsClientConfig) -> Self {
Expand All @@ -85,34 +109,6 @@ impl DnsClient {
tcp_pool: ClientPool::new(tcp_config, TcpFactory),
}
}

/// Perform a DNS lookup with the configured nameservers.
///
/// Timeouts and network errors will result in up to one additional attempt
/// to perform a DNS lookup when using the default configuration.
pub async fn resolve(&self, name: Name, rtype: RecordType, rclass: RecordClass) -> Result<Message, MtopError> {
let full = name.to_fqdn();
let id = MessageId::random();
let flags = Flags::default().set_recursion_desired();
let question = Question::new(full, rtype).set_qclass(rclass);
let message = Message::new(id, flags).add_question(question);

let mut attempt = 0;
loop {
match self.exchange(&message, usize::from(attempt)).await {
Ok(v) => return Ok(v),
Err(e) => {
if attempt + 1 >= self.config.attempts {
return Err(e);
}

tracing::debug!(message = "retrying failed query", attempt = attempt + 1, max_attempts = self.config.attempts, err = %e);
attempt += 1;
}
}
}
}

async fn exchange(&self, msg: &Message, attempt: usize) -> Result<Message, MtopError> {
let server = self.nameserver(attempt);

Expand Down Expand Up @@ -159,6 +155,31 @@ impl DnsClient {
}
}

impl DnsClient for DefaultDnsClient {
async fn resolve(&self, name: Name, rtype: RecordType, rclass: RecordClass) -> Result<Message, MtopError> {
let full = name.to_fqdn();
let id = MessageId::random();
let flags = Flags::default().set_recursion_desired();
let question = Question::new(full, rtype).set_qclass(rclass);
let message = Message::new(id, flags).add_question(question);

let mut attempt = 0;
loop {
match self.exchange(&message, usize::from(attempt)).await {
Ok(v) => return Ok(v),
Err(e) => {
if attempt + 1 >= self.config.attempts {
return Err(e);
}

tracing::debug!(message = "retrying failed query", attempt = attempt + 1, max_attempts = self.config.attempts, err = %e);
attempt += 1;
}
}
}
}
}

/// Client for sending and receiving DNS messages over read and write streams,
/// usually a TCP connection. Messages are sent with a two byte prefix that
/// indicates the size of the message. Responses are expected to have the same
Expand Down
2 changes: 1 addition & 1 deletion mtop-client/src/dns/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ mod name;
mod rdata;
mod resolv;

pub use crate::dns::client::{DnsClient, DnsClientConfig};
pub use crate::dns::client::{DefaultDnsClient, DnsClient, DnsClientConfig};
pub use crate::dns::core::{RecordClass, RecordType};
pub use crate::dns::message::{Flags, Message, MessageId, Operation, Question, Record, ResponseCode};
pub use crate::dns::name::Name;
Expand Down
4 changes: 2 additions & 2 deletions mtop-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ mod pool;
mod timeout;

pub use crate::client::{
MemcachedClient, MemcachedClientConfig, Selector, SelectorRendezvous, ServersResponse, TcpFactory, ValuesResponse,
MemcachedClient, MemcachedClientConfig, RendezvousSelector, Selector, ServersResponse, TcpFactory, ValuesResponse,
};
pub use crate::core::{
ErrorKind, Key, Memcached, Meta, MtopError, ProtocolError, ProtocolErrorKind, Slab, SlabItem, SlabItems, Slabs,
Stats, Value,
};
pub use crate::discovery::{DiscoveryDefault, Server, ServerID};
pub use crate::discovery::{Discovery, Server, ServerID};
pub use crate::net::TlsConfig;
pub use crate::pool::{ClientFactory, PooledClient};
pub use crate::timeout::{Timed, Timeout};
2 changes: 1 addition & 1 deletion mtop/src/bin/dns.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use clap::{Args, Parser, Subcommand, ValueHint};
use mtop::ping::{Bundle, DnsPinger};
use mtop::{profile, sig};
use mtop_client::dns::{Flags, Message, MessageId, Name, Question, Record, RecordClass, RecordType};
use mtop_client::dns::{DnsClient, Flags, Message, MessageId, Name, Question, Record, RecordClass, RecordType};
use std::fmt::Write;
use std::io::Cursor;
use std::net::SocketAddr;
Expand Down
Loading