diff --git a/Cargo.lock b/Cargo.lock index f14ec9f..4122a75 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -404,6 +404,7 @@ dependencies = [ "rand", "rand_distr", "ratatui", + "rustls-webpki", "tokio", "tracing", "tracing-subscriber", diff --git a/mtop-client/src/client.rs b/mtop-client/src/client.rs index bf65956..4061154 100644 --- a/mtop-client/src/client.rs +++ b/mtop-client/src/client.rs @@ -1,5 +1,6 @@ use crate::core::{Key, Meta, MtopError, SlabItems, Slabs, Stats, Value}; -use crate::pool::{MemcachedPool, PooledMemcached, Server, ServerID}; +use crate::discovery::{Server, ServerID}; +use crate::pool::{MemcachedPool, PooledMemcached}; use std::collections::hash_map::DefaultHasher; use std::collections::HashMap; use std::hash::Hasher; diff --git a/mtop-client/src/discovery.rs b/mtop-client/src/discovery.rs new file mode 100644 index 0000000..662455c --- /dev/null +++ b/mtop-client/src/discovery.rs @@ -0,0 +1,244 @@ +use crate::core::MtopError; +use crate::dns::{DnsClient, RecordData}; +use std::cmp::Ordering; +use std::fmt; +use std::net::{IpAddr, SocketAddr}; +use webpki::types::ServerName; + +const DNS_A_PREFIX: &str = "dns+"; +const DNS_SRV_PREFIX: &str = "dnssrv+"; + +/// Unique ID for a server in a Memcached cluster. +#[derive(Debug, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] +#[repr(transparent)] +pub struct ServerID(String); + +impl ServerID { + fn from_host_port(host: S, port: u16) -> Self + where + S: AsRef, + { + let host = host.as_ref(); + if let Ok(ip) = host.parse::() { + Self(SocketAddr::from((ip, port)).to_string()) + } else { + Self(format!("{}:{}", host, port)) + } + } +} + +impl From<(&str, u16)> for ServerID { + fn from(value: (&str, u16)) -> Self { + Self::from_host_port(value.0, value.1) + } +} + +impl From<(String, u16)> for ServerID { + fn from(value: (String, u16)) -> Self { + Self::from_host_port(value.0, value.1) + } +} + +impl From for ServerID { + fn from(value: SocketAddr) -> Self { + Self(value.to_string()) + } +} + +impl fmt::Display for ServerID { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl AsRef for ServerID { + fn as_ref(&self) -> &str { + &self.0 + } +} + +/// An individual server that is part of a Memcached cluster. +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct Server { + repr: ServerRepr, +} + +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +enum ServerRepr { + Resolved(ServerID, ServerName<'static>, SocketAddr), + Unresolved(ServerID, ServerName<'static>), +} + +impl Server { + pub fn from_id(id: ServerID, name: ServerName<'static>) -> Self { + Self { + repr: ServerRepr::Unresolved(id, name), + } + } + + pub fn from_addr(addr: SocketAddr, name: ServerName<'static>) -> Self { + Self { + repr: ServerRepr::Resolved(ServerID::from(addr), name, addr), + } + } + + pub fn id(&self) -> ServerID { + match &self.repr { + ServerRepr::Resolved(id, _, _) => id.clone(), + ServerRepr::Unresolved(id, _) => id.clone(), + } + } + + pub fn server_name(&self) -> ServerName<'static> { + match &self.repr { + ServerRepr::Resolved(_, name, _) => name.clone(), + ServerRepr::Unresolved(_, name) => name.clone(), + } + } + + pub fn address(&self) -> String { + match &self.repr { + ServerRepr::Resolved(_, _, addr) => addr.to_string(), + ServerRepr::Unresolved(id, _) => id.to_string(), + } + } +} + +impl PartialOrd for Server { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Server { + fn cmp(&self, other: &Self) -> Ordering { + self.id().cmp(&other.id()) + } +} + +impl fmt::Display for Server { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.id()) + } +} + +#[derive(Debug, Clone)] +pub struct DiscoveryDefault { + client: DnsClient, +} + +impl DiscoveryDefault { + pub fn new(client: DnsClient) -> Self { + Self { client } + } + + pub async fn resolve_by_proto(&self, name: &str) -> Result, MtopError> { + if name.starts_with(DNS_A_PREFIX) { + Ok(self.resolve_a(name.trim_start_matches(DNS_A_PREFIX)).await?) + } else if name.starts_with(DNS_SRV_PREFIX) { + Ok(self.resolve_srv(name.trim_start_matches(DNS_SRV_PREFIX)).await?) + } else { + Ok(self.resolve_a(name).await?.pop().into_iter().collect()) + } + } + + async fn resolve_srv(&self, name: &str) -> Result, MtopError> { + let server_name = Self::server_name(name)?; + let (host_name, port) = Self::host_and_port(name)?; + let mut out = Vec::new(); + + let res = self.client.resolve_srv(host_name).await?; + for a in res.answers() { + let target = if let RecordData::SRV(srv) = a.rdata() { + srv.target().to_string() + } else { + tracing::warn!(message = "unexpected record data for answer", name = host_name, answer = ?a); + continue; + }; + let server_id = ServerID::from((target, port)); + let server = Server::from_id(server_id, server_name.clone()); + out.push(server); + } + + Ok(out) + } + + async fn resolve_a(&self, name: &str) -> Result, MtopError> { + let server_name = Self::server_name(name)?; + + let mut out = Vec::new(); + for addr in tokio::net::lookup_host(name).await? { + out.push(Server::from_addr(addr, server_name.clone())); + } + + Ok(out) + } + + fn host_and_port(name: &str) -> Result<(&str, u16), MtopError> { + name.rsplit_once(':') + .ok_or_else(|| { + MtopError::configuration(format!( + "invalid server name '{}', must be of the form 'host:port'", + name + )) + }) + // IPv6 addresses use brackets around them to disambiguate them from a port number. + // Since we're parsing the host and port, strip the brackets because they aren't + // needed or valid to include in a TLS ServerName. + .map(|(hostname, port)| (hostname.trim_start_matches('[').trim_end_matches(']'), port)) + .and_then(|(hostname, port)| { + port.parse().map(|p| (hostname, p)).map_err(|e| { + MtopError::configuration_cause(format!("unable to parse port number from '{}'", name), e) + }) + }) + } + + fn server_name(name: &str) -> Result, MtopError> { + Self::host_and_port(name).and_then(|(host, _)| { + ServerName::try_from(host) + .map(|s| s.to_owned()) + .map_err(|e| MtopError::configuration_cause(format!("invalid server name '{}'", host), e)) + }) + } +} + +#[cfg(test)] +mod test { + use super::ServerID; + use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; + + #[test] + fn test_server_id_from_ipv4_addr() { + let addr = SocketAddr::from((Ipv4Addr::new(127, 1, 1, 1), 11211)); + let id = ServerID::from(addr); + assert_eq!("127.1.1.1:11211", id.to_string()); + } + + #[test] + fn test_server_id_from_ipv6_addr() { + let addr = SocketAddr::from((Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 11211)); + let id = ServerID::from(addr); + assert_eq!("[::1]:11211", id.to_string()); + } + + #[test] + fn test_server_id_from_ipv4_pair() { + let pair = ("10.1.1.22", 11212); + let id = ServerID::from(pair); + assert_eq!("10.1.1.22:11212", id.to_string()); + } + + #[test] + fn test_server_id_from_ipv6_pair() { + let pair = ("::1", 11212); + let id = ServerID::from(pair); + assert_eq!("[::1]:11212", id.to_string()); + } + + #[test] + fn test_server_id_from_host_pair() { + let pair = ("cache.example.com", 11211); + let id = ServerID::from(pair); + assert_eq!("cache.example.com:11211", id.to_string()); + } +} diff --git a/mtop-client/src/dns/client.rs b/mtop-client/src/dns/client.rs new file mode 100644 index 0000000..66e80de --- /dev/null +++ b/mtop-client/src/dns/client.rs @@ -0,0 +1,62 @@ +use crate::core::MtopError; +use crate::dns::core::RecordType; +use crate::dns::message::{Flags, Message, MessageId, Question}; +use crate::dns::name::Name; +use std::io::Cursor; +use std::net::SocketAddr; +use std::str::FromStr; +use tokio::net::UdpSocket; + +const DEFAULT_RECV_BUF: usize = 512; + +#[derive(Debug, Clone)] +pub struct DnsClient { + local: SocketAddr, + server: SocketAddr, +} + +impl DnsClient { + pub fn new(local: SocketAddr, server: SocketAddr) -> Self { + Self { local, server } + } + + pub async fn exchange(&self, msg: &Message) -> Result { + let id = msg.id(); + let sock = self.connect_udp().await?; + self.send_udp(&sock, msg).await?; + self.recv_udp(&sock, id).await + } + + pub async fn resolve_srv(&self, name: &str) -> Result { + let n = Name::from_str(name)?; + let id = MessageId::random(); + let flags = Flags::default().set_recursion_desired(); + let msg = Message::new(id, flags).add_question(Question::new(n, RecordType::SRV)); + + self.exchange(&msg).await + } + + async fn connect_udp(&self) -> Result { + let sock = UdpSocket::bind(&self.local).await?; + sock.connect(&self.server).await?; + Ok(sock) + } + + async fn send_udp(&self, socket: &UdpSocket, msg: &Message) -> Result<(), MtopError> { + let mut buf = Vec::new(); + msg.write_network_bytes(&mut buf)?; + Ok(socket.send(&buf).await.map(|_| ())?) + } + + async fn recv_udp(&self, socket: &UdpSocket, id: MessageId) -> Result { + let mut buf = vec![0_u8; DEFAULT_RECV_BUF]; + loop { + let n = socket.recv(&mut buf).await?; + let cur = Cursor::new(&buf[0..n]); + let msg = Message::read_network_bytes(cur)?; + if msg.id() == id { + return Ok(msg); + } + } + } +} diff --git a/mtop-client/src/dns/message.rs b/mtop-client/src/dns/message.rs index 2b3c20a..d369e00 100644 --- a/mtop-client/src/dns/message.rs +++ b/mtop-client/src/dns/message.rs @@ -4,12 +4,40 @@ use crate::dns::name::Name; use crate::dns::rdata::RecordData; use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; -use std::fmt::Debug; +use std::fmt::{Debug, Formatter}; use std::io::Seek; +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[repr(transparent)] +pub struct MessageId(u16); + +impl MessageId { + pub fn random() -> Self { + Self(rand::random()) + } +} + +impl From for MessageId { + fn from(value: u16) -> Self { + Self(value) + } +} + +impl From for u16 { + fn from(value: MessageId) -> Self { + value.0 + } +} + +impl fmt::Display for MessageId { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + #[derive(Debug, Clone, Eq, PartialEq)] pub struct Message { - id: u16, + id: MessageId, flags: Flags, questions: Vec, answers: Vec, @@ -18,7 +46,7 @@ pub struct Message { } impl Message { - pub fn new(id: u16, flags: Flags) -> Self { + pub fn new(id: MessageId, flags: Flags) -> Self { Self { id, flags, @@ -29,11 +57,11 @@ impl Message { } } - pub fn id(&self) -> u16 { + pub fn id(&self) -> MessageId { self.id } - pub fn set_id(mut self, id: u16) -> Self { + pub fn set_id(mut self, id: MessageId) -> Self { self.id = id; self } @@ -164,7 +192,7 @@ impl Message { #[derive(Debug, Clone, Eq, PartialEq)] struct Header { - id: u16, + id: MessageId, flags: Flags, num_questions: u16, num_answers: u16, @@ -177,7 +205,7 @@ impl Header { where T: WriteBytesExt, { - buf.write_u16::(self.id)?; + buf.write_u16::(self.id.into())?; buf.write_u16::(self.flags.as_u16())?; buf.write_u16::(self.num_questions)?; buf.write_u16::(self.num_answers)?; @@ -189,7 +217,7 @@ impl Header { where T: ReadBytesExt, { - let id = buf.read_u16::()?; + let id = MessageId::from(buf.read_u16::()?); let flags = Flags::try_from(buf.read_u16::()?)?; let num_questions = buf.read_u16::()?; let num_answers = buf.read_u16::()?; @@ -543,7 +571,7 @@ impl Record { #[cfg(test)] mod test { - use super::{Flags, Header, Message, Operation, Question, Record, ResponseCode}; + use super::{Flags, Header, Message, MessageId, Operation, Question, Record, ResponseCode}; use crate::dns::core::{RecordClass, RecordType}; use crate::dns::name::Name; use crate::dns::rdata::{RecordData, RecordDataA, RecordDataSRV}; @@ -578,7 +606,7 @@ mod test { ); let message = Message::new( - 65333, Flags::default() + MessageId::from(65333), Flags::default() .set_response() .set_op_code(Operation::Query) .set_response_code(ResponseCode::NoError)) @@ -722,7 +750,7 @@ mod test { ]); let message = Message::read_network_bytes(cur).unwrap(); - assert_eq!(65333, message.id()); + assert_eq!(MessageId::from(65333), message.id()); assert_eq!( Flags::default() .set_response() @@ -768,7 +796,7 @@ mod test { #[test] fn test_header_write_network_bytes() { let h = Header { - id: 65333, + id: MessageId::from(65333), flags: Flags::default().set_recursion_desired(), num_questions: 1, num_answers: 2, @@ -805,7 +833,7 @@ mod test { ]); let h = Header::read_network_bytes(cur).unwrap(); - assert_eq!(65333, h.id); + assert_eq!(MessageId::from(65333), h.id); assert_eq!(Flags::default().set_recursion_desired(), h.flags); assert_eq!(1, h.num_questions); assert_eq!(2, h.num_answers); diff --git a/mtop-client/src/dns/mod.rs b/mtop-client/src/dns/mod.rs index 3d300a2..d8502bb 100644 --- a/mtop-client/src/dns/mod.rs +++ b/mtop-client/src/dns/mod.rs @@ -1,40 +1,14 @@ +mod client; mod core; mod message; mod name; mod rdata; +pub use crate::dns::client::DnsClient; pub use crate::dns::core::{RecordClass, RecordType}; -pub use crate::dns::message::{Flags, Message, Operation, Question, Record, ResponseCode}; +pub use crate::dns::message::{Flags, Message, MessageId, Operation, Question, Record, ResponseCode}; pub use crate::dns::name::Name; pub use crate::dns::rdata::{ RecordData, RecordDataA, RecordDataAAAA, RecordDataCNAME, RecordDataNS, RecordDataSOA, RecordDataSRV, RecordDataTXT, RecordDataUnknown, }; - -use crate::core::MtopError; -use std::io::Cursor; -use tokio::net::UdpSocket; - -const DEFAULT_RECV_BUF: usize = 512; - -pub fn id() -> u16 { - rand::random() -} - -pub async fn send(sock: &UdpSocket, msg: &Message) -> Result<(), MtopError> { - let mut buf = Vec::new(); - msg.write_network_bytes(&mut buf)?; - Ok(sock.send(&buf).await.map(|_| ())?) -} - -pub async fn recv(sock: &UdpSocket, id: u16) -> Result { - let mut buf = vec![0_u8; DEFAULT_RECV_BUF]; - loop { - let n = sock.recv(&mut buf).await?; - let cur = Cursor::new(&buf[0..n]); - let msg = Message::read_network_bytes(cur)?; - if msg.id() == id { - return Ok(msg); - } - } -} diff --git a/mtop-client/src/lib.rs b/mtop-client/src/lib.rs index db57ddc..8f6cb9f 100644 --- a/mtop-client/src/lib.rs +++ b/mtop-client/src/lib.rs @@ -1,14 +1,15 @@ mod client; mod core; +mod discovery; +pub mod dns; mod pool; mod timeout; -pub mod dns; - pub use crate::client::{MemcachedClient, SelectorRendezvous, ServersResponse, ValuesResponse}; pub use crate::core::{ ErrorKind, Key, Memcached, Meta, MtopError, ProtocolError, ProtocolErrorKind, Slab, SlabItem, SlabItems, Slabs, Stats, Value, }; -pub use crate::pool::{DiscoveryDefault, MemcachedPool, PoolConfig, PooledMemcached, Server, ServerID, TLSConfig}; +pub use crate::discovery::{DiscoveryDefault, Server, ServerID}; +pub use crate::pool::{MemcachedPool, PoolConfig, PooledMemcached, TLSConfig}; pub use crate::timeout::{Timed, Timeout}; diff --git a/mtop-client/src/pool.rs b/mtop-client/src/pool.rs index 48f1f66..3b2b024 100644 --- a/mtop-client/src/pool.rs +++ b/mtop-client/src/pool.rs @@ -1,11 +1,10 @@ use crate::core::{Memcached, MtopError}; -use std::cmp::Ordering; +use crate::discovery::Server; use std::collections::HashMap; use std::fmt::{self, Debug}; use std::fs::File; use std::future::Future; use std::io::{self, BufReader}; -use std::net::SocketAddr; use std::ops::{Deref, DerefMut}; use std::path::PathBuf; use std::sync::Arc; @@ -16,117 +15,6 @@ use tokio_rustls::rustls::{ClientConfig, RootCertStore}; use tokio_rustls::TlsConnector; use webpki::types::{CertificateDer, PrivateKeyDer, ServerName}; -const DNS_HOST_PREFIX: &str = "dns+"; - -/// Unique ID for a server in a Memcached cluster. -#[derive(Debug, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] -#[repr(transparent)] -pub struct ServerID(String); - -impl From for ServerID { - fn from(value: SocketAddr) -> Self { - Self(value.to_string()) - } -} - -impl fmt::Display for ServerID { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - -impl AsRef for ServerID { - fn as_ref(&self) -> &str { - &self.0 - } -} - -/// An individual server that is part of a Memcached cluster. -#[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub struct Server { - id: ServerID, - addr: SocketAddr, - name: ServerName<'static>, -} - -impl Server { - pub fn from(addr: SocketAddr, name: ServerName<'static>) -> Self { - Self { - id: ServerID::from(addr), - addr, - name, - } - } - - pub fn id(&self) -> ServerID { - self.id.clone() - } - - pub fn addr(&self) -> SocketAddr { - self.addr - } - - pub fn server_name(&self) -> ServerName<'static> { - self.name.clone() - } -} - -impl PartialOrd for Server { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for Server { - fn cmp(&self, other: &Self) -> Ordering { - self.id.cmp(&other.id) - } -} - -impl fmt::Display for Server { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.id) - } -} - -fn host_to_server_name(host: &str) -> Result, MtopError> { - ServerName::try_from(host) - .map(|s| s.to_owned()) - .map_err(|e| MtopError::configuration_cause(format!("invalid server name '{}'", host), e)) -} - -#[derive(Debug, Default, Clone)] -pub struct DiscoveryDefault; - -impl DiscoveryDefault { - pub async fn resolve_by_proto(&self, name: &str) -> Result, MtopError> { - if name.starts_with(DNS_HOST_PREFIX) { - Ok(self.resolve_a(name.trim_start_matches(DNS_HOST_PREFIX)).await?) - } else { - Ok(self.resolve_a(name).await?.pop().into_iter().collect()) - } - } - - async fn resolve_a(&self, name: &str) -> Result, MtopError> { - // Names must be of the form hostname:port. The hostname can be an IP address or - // an actual DNS name. We trim leading and trailing brackets from the hostname portion - // since these are used to disambiguate IPv6 addresses from the port number but - // aren't allowed for _just_ an IP address. - let server_name = name - .rsplit_once(':') - .ok_or_else(|| MtopError::configuration(format!("invalid server name '{}'", name))) - .map(|(hostname, _)| hostname.trim_start_matches('[').trim_end_matches(']')) - .and_then(host_to_server_name)?; - - let mut out = Vec::new(); - for addr in tokio::net::lookup_host(name).await? { - out.push(Server::from(addr, server_name.clone())); - } - - Ok(out) - } -} - #[derive(Debug)] pub struct PooledMemcached { inner: Memcached, @@ -172,25 +60,18 @@ pub struct TLSConfig { pub ca_path: Option, pub cert_path: Option, pub key_path: Option, - pub server_name: Option, + pub server_name: Option>, } #[derive(Debug)] pub struct MemcachedPool { connections: Mutex>>, client_config: Option>, - server_name: Option>, config: PoolConfig, } impl MemcachedPool { pub async fn new(handle: Handle, config: PoolConfig) -> Result { - let server_name = if let Some(s) = &config.tls.server_name { - Some(host_to_server_name(s)?) - } else { - None - }; - let client_config = if config.tls.enabled { Some(Arc::new(Self::client_config(handle, &config.tls).await?)) } else { @@ -200,7 +81,6 @@ impl MemcachedPool { Ok(MemcachedPool { connections: Mutex::new(HashMap::new()), client_config, - server_name, config, }) } @@ -299,11 +179,16 @@ impl MemcachedPool { async fn connect(&self, server: &Server) -> Result { if let Some(cfg) = &self.client_config { - let name = self.server_name.clone().unwrap_or_else(|| server.server_name()); + let name = self + .config + .tls + .server_name + .clone() + .unwrap_or_else(|| server.server_name()); tracing::debug!(message = "using server name for TLS validation", server_name = ?name); - tls_connect(server.addr(), name, cfg.clone()).await + tls_connect(server.address(), name, cfg.clone()).await } else { - plain_connect(server.addr()).await + plain_connect(server.address()).await } } @@ -407,7 +292,7 @@ mod test { async fn test_get_new_connection() { let cfg = PoolConfig::default(); let pool = MemcachedPool::new(Handle::current(), cfg).await.unwrap(); - let server = Server::from( + let server = Server::from_addr( "127.0.0.1:11211".parse().unwrap(), ServerName::try_from("localhost").unwrap().to_owned(), ); @@ -428,7 +313,7 @@ mod test { async fn test_get_existing_connection() { let cfg = PoolConfig::default(); let pool = MemcachedPool::new(Handle::current(), cfg).await.unwrap(); - let server = Server::from( + let server = Server::from_addr( "127.0.0.1:11211".parse().unwrap(), ServerName::try_from("localhost").unwrap().to_owned(), ); @@ -454,7 +339,7 @@ mod test { }; let pool = MemcachedPool::new(Handle::current(), cfg).await.unwrap(); - let server = Server::from( + let server = Server::from_addr( "127.0.0.1:11211".parse().unwrap(), ServerName::try_from("localhost").unwrap().to_owned(), ); @@ -478,7 +363,7 @@ mod test { let cfg = PoolConfig::default(); let pool = MemcachedPool::new(Handle::current(), cfg).await.unwrap(); - let server = Server::from( + let server = Server::from_addr( "127.0.0.1:11211".parse().unwrap(), ServerName::try_from("localhost").unwrap().to_owned(), ); diff --git a/mtop/Cargo.toml b/mtop/Cargo.toml index 2374464..8772a44 100644 --- a/mtop/Cargo.toml +++ b/mtop/Cargo.toml @@ -16,6 +16,7 @@ crossterm = "0.27.0" mtop-client = { path = "../mtop-client", version = "0.9.0" } rand = "0.8.5" rand_distr = "0.4.3" +rustls-webpki = "0.102.0" ratatui = "0.26.0" tokio = { version = "1.14.0", features = ["full"] } tracing = "0.1.11" diff --git a/mtop/src/bin/dns.rs b/mtop/src/bin/dns.rs index 75e5911..74b3331 100644 --- a/mtop/src/bin/dns.rs +++ b/mtop/src/bin/dns.rs @@ -1,18 +1,17 @@ use clap::{Args, Parser, Subcommand}; -use mtop_client::dns::{Flags, Message, Name, Question, Record, RecordClass, RecordType}; -use mtop_client::MtopError; +use mtop_client::dns::{DnsClient, Flags, Message, MessageId, Name, Question, Record, RecordClass, RecordType}; use mtop_client::Timeout; use std::fmt::Write; use std::io::Cursor; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::process::ExitCode; use std::str::FromStr; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::ToSocketAddrs; -use tokio::net::UdpSocket; use tracing::Level; -const DEFAULT_HOST: &str = "127.0.0.1:53"; +const DEFAULT_DNS_LOCAL: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0); +const DEFAULT_DNS_SERVER: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 53); const DEFAULT_TIMEOUT_SECS: u64 = 5; const DEFAULT_RECORD_TYPE: RecordType = RecordType::A; const DEFAULT_RECORD_CLASS: RecordClass = RecordClass::INET; @@ -35,9 +34,13 @@ enum Action { /// Perform a DNS query and display the result as dig-like text output. #[derive(Debug, Args)] struct QueryCommand { - /// Server to make the request to in the form 'hostname:port' - #[arg(long, default_value_t = DEFAULT_HOST.to_owned())] - server: String, + /// Local address for DNS requests for service discovery in the form 'address:port' + #[arg(long, default_value_t = DEFAULT_DNS_LOCAL)] + dns_local: SocketAddr, + + /// DNS server for service discovery in the form 'address:port' + #[arg(long, default_value_t = DEFAULT_DNS_SERVER)] + dns_server: SocketAddr, /// Timeout for making requests to a DNS server, in seconds. #[arg(long, default_value_t = DEFAULT_TIMEOUT_SECS)] @@ -96,25 +99,8 @@ async fn main() -> ExitCode { } } -async fn connect(server: A) -> Result -where - A: ToSocketAddrs, -{ - let socket = UdpSocket::bind("0.0.0.0:0").await?; - socket.connect(server).await?; - Ok(socket) -} - async fn run_query(cmd: &QueryCommand) -> ExitCode { let timeout = Duration::from_secs(cmd.timeout_secs); - let socket = match connect(&cmd.server).timeout(timeout, "socket.connect").await { - Ok(s) => s, - Err(e) => { - tracing::error!(message = "unable to open UDP socket", "server" = cmd.server, err = %e); - return ExitCode::FAILURE; - } - }; - let name = match Name::from_str(&cmd.name) { Ok(n) => n, Err(e) => { @@ -123,25 +109,15 @@ async fn run_query(cmd: &QueryCommand) -> ExitCode { } }; - let id = mtop_client::dns::id(); + let id = MessageId::random(); let msg = Message::new(id, Flags::default().set_query().set_recursion_desired()) .add_question(Question::new(name, cmd.rtype).set_qclass(cmd.rclass)); - if let Err(e) = mtop_client::dns::send(&socket, &msg) - .timeout(timeout, "socket.send") - .await - { - tracing::error!(message = "unable to send message", "server" = cmd.server, err = %e); - return ExitCode::FAILURE; - } - - let response = match mtop_client::dns::recv(&socket, id) - .timeout(timeout, "socket.recv") - .await - { + let client = DnsClient::new(cmd.dns_local, cmd.dns_server); + let response = match client.exchange(&msg).timeout(timeout, "client.exchange").await { Ok(r) => r, Err(e) => { - tracing::error!(message = "unable to receive message", "server" = cmd.server, err = %e); + tracing::error!(message = "unable to exchange message", "server" = %cmd.dns_server, err = %e); return ExitCode::FAILURE; } }; @@ -187,7 +163,7 @@ async fn run_write(cmd: &WriteCommand) -> ExitCode { } }; - let id = rand::random(); + let id = MessageId::random(); let msg = Message::new(id, Flags::default().set_query().set_recursion_desired()) .add_question(Question::new(name, cmd.rtype).set_qclass(cmd.rclass)); diff --git a/mtop/src/bin/mc.rs b/mtop/src/bin/mc.rs index a2604c0..ed7ae90 100644 --- a/mtop/src/bin/mc.rs +++ b/mtop/src/bin/mc.rs @@ -2,12 +2,14 @@ use clap::{Args, Parser, Subcommand, ValueHint}; use mtop::bench::{Bencher, Percent, Summary}; use mtop::check::{Checker, TimingBundle}; use mtop::profile::Profiler; +use mtop_client::dns::DnsClient; use mtop_client::{ DiscoveryDefault, MemcachedClient, MemcachedPool, Meta, MtopError, PoolConfig, SelectorRendezvous, Server, TLSConfig, Timeout, Value, }; use std::fs::File; use std::io::Write; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::{Path, PathBuf}; use std::process::ExitCode; use std::time::Duration; @@ -15,7 +17,10 @@ use std::{env, io}; use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter}; use tokio::runtime::Handle; use tracing::{Instrument, Level}; +use webpki::types::{InvalidDnsNameError, ServerName}; +const DEFAULT_DNS_LOCAL: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0); +const DEFAULT_DNS_SERVER: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 53); const DEFAULT_LOG_LEVEL: Level = Level::INFO; const DEFAULT_HOST: &str = "localhost:11211"; const DEFAULT_TIMEOUT_SECS: u64 = 30; @@ -26,10 +31,18 @@ const DEFAULT_CONNECTIONS_PER_HOST: u64 = 4; #[command(name = "mc", version = clap::crate_version!())] struct McConfig { /// Logging verbosity. Allowed values are 'trace', 'debug', 'info', 'warn', and 'error' - /// (case insensitive). + /// (case-insensitive). #[arg(long, default_value_t = DEFAULT_LOG_LEVEL)] log_level: Level, + /// Local address for DNS requests for service discovery in the form 'address:port' + #[arg(long, default_value_t = DEFAULT_DNS_LOCAL)] + dns_local: SocketAddr, + + /// DNS server for service discovery in the form 'address:port' + #[arg(long, default_value_t = DEFAULT_DNS_SERVER)] + dns_server: SocketAddr, + /// Memcached host to connect to in the form 'hostname:port'. #[arg(long, default_value_t = DEFAULT_HOST.to_owned(), value_hint = ValueHint::Hostname)] host: String, @@ -58,8 +71,8 @@ struct McConfig { /// Optional server name to use for validating the server certificate. If not set, the /// hostname of the server is used for checking that the certificate matches the server. - #[arg(long)] - tls_server_name: Option, + #[arg(long, value_parser = parse_server_name)] + tls_server_name: Option>, /// Optional client certificate to use to authenticate with the Memcached server. Note that /// this may or may not be required based on how the Memcached server is configured. @@ -75,6 +88,10 @@ struct McConfig { mode: Action, } +fn parse_server_name(s: &str) -> Result, InvalidDnsNameError> { + ServerName::try_from(s).map(|n| n.to_owned()) +} + #[derive(Debug, Subcommand)] enum Action { Add(AddCommand), @@ -280,7 +297,7 @@ async fn main() -> ExitCode { tracing::subscriber::set_global_default(console_subscriber).expect("failed to initialize console logging"); let timeout = Duration::from_secs(opts.timeout_secs); - let resolver = DiscoveryDefault; + let resolver = DiscoveryDefault::new(DnsClient::new(opts.dns_local, opts.dns_server)); let servers = match resolver .resolve_by_proto(&opts.host) .timeout(timeout, "resolver.resolve_by_proto") diff --git a/mtop/src/bin/mtop.rs b/mtop/src/bin/mtop.rs index fbad4bf..2339f9a 100644 --- a/mtop/src/bin/mtop.rs +++ b/mtop/src/bin/mtop.rs @@ -1,11 +1,13 @@ use clap::{Parser, ValueHint}; use mtop::queue::{BlockingStatsQueue, Host, StatsQueue}; use mtop::ui::{Theme, TAILWIND}; +use mtop_client::dns::DnsClient; use mtop_client::{ DiscoveryDefault, MemcachedClient, MemcachedPool, MtopError, PoolConfig, SelectorRendezvous, Server, TLSConfig, Timeout, }; use std::env; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::PathBuf; use std::process::ExitCode; use std::sync::Arc; @@ -14,9 +16,12 @@ use tokio::runtime::Handle; use tokio::task; use tracing::instrument::WithSubscriber; use tracing::{Instrument, Level}; +use webpki::types::{InvalidDnsNameError, ServerName}; -const DEFAULT_THEME: Theme = TAILWIND; +const DEFAULT_DNS_LOCAL: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0); +const DEFAULT_DNS_SERVER: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 53); const DEFAULT_LOG_LEVEL: Level = Level::INFO; +const DEFAULT_THEME: Theme = TAILWIND; // Update interval of more than a second to minimize the chance that stats returned by the // memcached server have the exact same "time" value (which has one-second granularity). const DEFAULT_STATS_INTERVAL: Duration = Duration::from_millis(1073); @@ -28,10 +33,18 @@ const NUM_MEASUREMENTS: usize = 10; #[command(name = "mtop", version = clap::crate_version!())] struct MtopConfig { /// Logging verbosity. Allowed values are 'trace', 'debug', 'info', 'warn', and 'error' - /// (case insensitive). + /// (case-insensitive). #[arg(long, default_value_t = DEFAULT_LOG_LEVEL)] log_level: Level, + /// Local address for DNS requests for service discovery in the form 'address:port' + #[arg(long, default_value_t = DEFAULT_DNS_LOCAL)] + dns_local: SocketAddr, + + /// DNS server for service discovery in the form 'address:port' + #[arg(long, default_value_t = DEFAULT_DNS_SERVER)] + dns_server: SocketAddr, + /// Timeout for connecting to Memcached and fetching statistics, in seconds. #[arg(long, default_value_t = DEFAULT_TIMEOUT_SECS)] timeout_secs: u64, @@ -56,8 +69,8 @@ struct MtopConfig { /// Optional server name to use for validating the server certificate. If not set, the /// hostname of the server is used for checking that the certificate matches the server. - #[arg(long)] - tls_server_name: Option, + #[arg(long, value_parser = parse_server_name)] + tls_server_name: Option>, /// Optional client certificate to use to authenticate with the Memcached server. Note that /// this may or may not be required based on how the Memcached server is configured. @@ -70,13 +83,21 @@ struct MtopConfig { tls_key: Option, /// Memcached hosts to connect to in the form 'hostname:port'. Must be specified at least - /// once and may be used multiple times (separated by spaces). Hostnames may be prefixed by - /// the string 'dns+'. When prefixed, the hostname will be resolved to A or AAAA records and - /// each IP address will be connected to at the provided port. + /// once and may be used multiple times (separated by spaces). + /// + /// Hostnames may be prefixed by the strings 'dns+' or 'dnssrv+'. When prefixed with 'dns+', + /// the hostname will be resolved to A or AAAA records and each IP address will be connected + /// to at the provided port. When prefixed with 'dnssrv+', the hostname will be resolved to + /// SRV records and the target of each record will be connected to at the provided port. Note + /// that the port from the SRV record is ignored. #[arg(required = true, value_hint = ValueHint::Hostname)] hosts: Vec, } +fn parse_server_name(s: &str) -> Result, InvalidDnsNameError> { + ServerName::try_from(s).map(|n| n.to_owned()) +} + fn default_log_file() -> PathBuf { env::temp_dir().join("mtop").join("mtop.log") } @@ -101,7 +122,7 @@ async fn main() -> ExitCode { let timeout = Duration::from_secs(opts.timeout_secs); let measurements = Arc::new(StatsQueue::new(NUM_MEASUREMENTS)); - let resolver = DiscoveryDefault; + let resolver = DiscoveryDefault::new(DnsClient::new(opts.dns_local, opts.dns_server)); let servers = match expand_hosts(&opts.hosts, &resolver, timeout).await { Ok(v) => v, @@ -111,6 +132,11 @@ async fn main() -> ExitCode { } }; + if servers.is_empty() { + tracing::error!(message = "resolving host names did not return any results", hosts = ?opts.hosts); + return ExitCode::FAILURE; + } + let client = match new_client(&opts, &servers).await { Ok(v) => v, Err(e) => { diff --git a/mtop/src/check.rs b/mtop/src/check.rs index 4787579..26d39e3 100644 --- a/mtop/src/check.rs +++ b/mtop/src/check.rs @@ -1,4 +1,4 @@ -use mtop_client::{DiscoveryDefault, Key, MemcachedClient}; +use mtop_client::{DiscoveryDefault, Key, MemcachedClient, Timeout}; use std::cmp; use std::time::{Duration, Instant}; use tokio::time; @@ -58,43 +58,39 @@ impl<'a> Checker<'a> { let val = VALUE.to_vec(); let dns_start = Instant::now(); - let server = match time::timeout(self.timeout, self.resolver.resolve_by_proto(host)) + let server = match self + .resolver + .resolve_by_proto(host) + .timeout(self.timeout, "resolver.resolve_by_proto") .await - .map(|r| r.map(|mut v| v.pop())) + .map(|mut v| v.pop()) { - Ok(Ok(Some(s))) => s, - Ok(Ok(None)) => { + Ok(Some(s)) => s, + Ok(None) => { tracing::warn!(message = "no addresses for host", host = host); failures.total += 1; failures.dns += 1; continue; } - Ok(Err(e)) => { + Err(e) => { tracing::warn!(message = "failed to resolve host", host = host, err = %e); failures.total += 1; failures.dns += 1; continue; } - Err(_) => { - tracing::warn!(message = "timeout resolving host", host = host); - failures.total += 1; - failures.dns += 1; - continue; - } }; let dns_time = dns_start.elapsed(); let conn_start = Instant::now(); - let mut conn = match time::timeout(self.timeout, self.client.raw_open(&server)).await { - Ok(Ok(v)) => v, - Ok(Err(e)) => { - tracing::warn!(message = "failed to connect to host", host = host, addr = %server.addr(), err = %e); - failures.total += 1; - failures.connections += 1; - continue; - } - Err(_) => { - tracing::warn!(message = "timeout connecting to host", host = host, addr = %server.addr()); + let mut conn = match self + .client + .raw_open(&server) + .timeout(self.timeout, "client.raw_open") + .await + { + Ok(v) => v, + Err(e) => { + tracing::warn!(message = "failed to connect to host", host = host, addr = %server.address(), err = %e); failures.total += 1; failures.connections += 1; continue; @@ -103,16 +99,14 @@ impl<'a> Checker<'a> { let conn_time = conn_start.elapsed(); let set_start = Instant::now(); - match time::timeout(self.timeout, conn.set(&key, 0, 60, &val)).await { - Ok(Ok(_)) => {} - Ok(Err(e)) => { - tracing::warn!(message = "failed to set key", host = host, addr = %server.addr(), err = %e); - failures.total += 1; - failures.sets += 1; - continue; - } - Err(_) => { - tracing::warn!(message = "timeout setting key", host = host, addr = %server.addr()); + match conn + .set(&key, 0, 60, &val) + .timeout(self.timeout, "connection.set") + .await + { + Ok(_) => {} + Err(e) => { + tracing::warn!(message = "failed to set key", host = host, addr = %server.address(), err = %e); failures.total += 1; failures.sets += 1; continue; @@ -121,16 +115,10 @@ impl<'a> Checker<'a> { let set_time = set_start.elapsed(); let get_start = Instant::now(); - match time::timeout(self.timeout, conn.get(&[key])).await { - Ok(Ok(_)) => {} - Ok(Err(e)) => { - tracing::warn!(message = "failed to get key", host = host, addr = %server.addr(), err = %e); - failures.total += 1; - failures.gets += 1; - continue; - } - Err(_) => { - tracing::warn!(message = "timeout getting key", host = host, addr = %server.addr()); + match conn.get(&[key]).timeout(self.timeout, "connection.get").await { + Ok(_) => {} + Err(e) => { + tracing::warn!(message = "failed to get key", host = host, addr = %server.address(), err = %e); failures.total += 1; failures.gets += 1; continue;