Skip to content

Commit

Permalink
Merge pull request #133 from 56quarters/srv
Browse files Browse the repository at this point in the history
Allow DNS SRV records to be use to discover cache servers
  • Loading branch information
56quarters authored Apr 6, 2024
2 parents 6ba89cc + 4fd1377 commit 22ba6b7
Show file tree
Hide file tree
Showing 13 changed files with 473 additions and 269 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion mtop-client/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::core::{Key, Meta, MtopError, SlabItems, Slabs, Stats, Value};
use crate::pool::{MemcachedPool, PooledMemcached, Server, ServerID};
use crate::discovery::{Server, ServerID};
use crate::pool::{MemcachedPool, PooledMemcached};
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::Hasher;
Expand Down
244 changes: 244 additions & 0 deletions mtop-client/src/discovery.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
use crate::core::MtopError;
use crate::dns::{DnsClient, RecordData};
use std::cmp::Ordering;
use std::fmt;
use std::net::{IpAddr, SocketAddr};
use webpki::types::ServerName;

const DNS_A_PREFIX: &str = "dns+";
const DNS_SRV_PREFIX: &str = "dnssrv+";

/// Unique ID for a server in a Memcached cluster.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)]
#[repr(transparent)]
pub struct ServerID(String);

impl ServerID {
fn from_host_port<S>(host: S, port: u16) -> Self
where
S: AsRef<str>,
{
let host = host.as_ref();
if let Ok(ip) = host.parse::<IpAddr>() {
Self(SocketAddr::from((ip, port)).to_string())
} else {
Self(format!("{}:{}", host, port))
}
}
}

impl From<(&str, u16)> for ServerID {
fn from(value: (&str, u16)) -> Self {
Self::from_host_port(value.0, value.1)
}
}

impl From<(String, u16)> for ServerID {
fn from(value: (String, u16)) -> Self {
Self::from_host_port(value.0, value.1)
}
}

impl From<SocketAddr> for ServerID {
fn from(value: SocketAddr) -> Self {
Self(value.to_string())
}
}

impl fmt::Display for ServerID {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}

impl AsRef<str> for ServerID {
fn as_ref(&self) -> &str {
&self.0
}
}

/// An individual server that is part of a Memcached cluster.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct Server {
repr: ServerRepr,
}

#[derive(Debug, Clone, Eq, PartialEq, Hash)]
enum ServerRepr {
Resolved(ServerID, ServerName<'static>, SocketAddr),
Unresolved(ServerID, ServerName<'static>),
}

impl Server {
pub fn from_id(id: ServerID, name: ServerName<'static>) -> Self {
Self {
repr: ServerRepr::Unresolved(id, name),
}
}

pub fn from_addr(addr: SocketAddr, name: ServerName<'static>) -> Self {
Self {
repr: ServerRepr::Resolved(ServerID::from(addr), name, addr),
}
}

pub fn id(&self) -> ServerID {
match &self.repr {
ServerRepr::Resolved(id, _, _) => id.clone(),
ServerRepr::Unresolved(id, _) => id.clone(),
}
}

pub fn server_name(&self) -> ServerName<'static> {
match &self.repr {
ServerRepr::Resolved(_, name, _) => name.clone(),
ServerRepr::Unresolved(_, name) => name.clone(),
}
}

pub fn address(&self) -> String {
match &self.repr {
ServerRepr::Resolved(_, _, addr) => addr.to_string(),
ServerRepr::Unresolved(id, _) => id.to_string(),
}
}
}

impl PartialOrd for Server {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

impl Ord for Server {
fn cmp(&self, other: &Self) -> Ordering {
self.id().cmp(&other.id())
}
}

impl fmt::Display for Server {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.id())
}
}

#[derive(Debug, Clone)]
pub struct DiscoveryDefault {
client: DnsClient,
}

impl DiscoveryDefault {
pub fn new(client: DnsClient) -> Self {
Self { client }
}

pub async fn resolve_by_proto(&self, name: &str) -> Result<Vec<Server>, MtopError> {
if name.starts_with(DNS_A_PREFIX) {
Ok(self.resolve_a(name.trim_start_matches(DNS_A_PREFIX)).await?)
} else if name.starts_with(DNS_SRV_PREFIX) {
Ok(self.resolve_srv(name.trim_start_matches(DNS_SRV_PREFIX)).await?)
} else {
Ok(self.resolve_a(name).await?.pop().into_iter().collect())
}
}

async fn resolve_srv(&self, name: &str) -> Result<Vec<Server>, MtopError> {
let server_name = Self::server_name(name)?;
let (host_name, port) = Self::host_and_port(name)?;
let mut out = Vec::new();

let res = self.client.resolve_srv(host_name).await?;
for a in res.answers() {
let target = if let RecordData::SRV(srv) = a.rdata() {
srv.target().to_string()
} else {
tracing::warn!(message = "unexpected record data for answer", name = host_name, answer = ?a);
continue;
};
let server_id = ServerID::from((target, port));
let server = Server::from_id(server_id, server_name.clone());
out.push(server);
}

Ok(out)
}

async fn resolve_a(&self, name: &str) -> Result<Vec<Server>, MtopError> {
let server_name = Self::server_name(name)?;

let mut out = Vec::new();
for addr in tokio::net::lookup_host(name).await? {
out.push(Server::from_addr(addr, server_name.clone()));
}

Ok(out)
}

fn host_and_port(name: &str) -> Result<(&str, u16), MtopError> {
name.rsplit_once(':')
.ok_or_else(|| {
MtopError::configuration(format!(
"invalid server name '{}', must be of the form 'host:port'",
name
))
})
// IPv6 addresses use brackets around them to disambiguate them from a port number.
// Since we're parsing the host and port, strip the brackets because they aren't
// needed or valid to include in a TLS ServerName.
.map(|(hostname, port)| (hostname.trim_start_matches('[').trim_end_matches(']'), port))
.and_then(|(hostname, port)| {
port.parse().map(|p| (hostname, p)).map_err(|e| {
MtopError::configuration_cause(format!("unable to parse port number from '{}'", name), e)
})
})
}

fn server_name(name: &str) -> Result<ServerName<'static>, MtopError> {
Self::host_and_port(name).and_then(|(host, _)| {
ServerName::try_from(host)
.map(|s| s.to_owned())
.map_err(|e| MtopError::configuration_cause(format!("invalid server name '{}'", host), e))
})
}
}

#[cfg(test)]
mod test {
use super::ServerID;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};

#[test]
fn test_server_id_from_ipv4_addr() {
let addr = SocketAddr::from((Ipv4Addr::new(127, 1, 1, 1), 11211));
let id = ServerID::from(addr);
assert_eq!("127.1.1.1:11211", id.to_string());
}

#[test]
fn test_server_id_from_ipv6_addr() {
let addr = SocketAddr::from((Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 11211));
let id = ServerID::from(addr);
assert_eq!("[::1]:11211", id.to_string());
}

#[test]
fn test_server_id_from_ipv4_pair() {
let pair = ("10.1.1.22", 11212);
let id = ServerID::from(pair);
assert_eq!("10.1.1.22:11212", id.to_string());
}

#[test]
fn test_server_id_from_ipv6_pair() {
let pair = ("::1", 11212);
let id = ServerID::from(pair);
assert_eq!("[::1]:11212", id.to_string());
}

#[test]
fn test_server_id_from_host_pair() {
let pair = ("cache.example.com", 11211);
let id = ServerID::from(pair);
assert_eq!("cache.example.com:11211", id.to_string());
}
}
62 changes: 62 additions & 0 deletions mtop-client/src/dns/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use crate::core::MtopError;
use crate::dns::core::RecordType;
use crate::dns::message::{Flags, Message, MessageId, Question};
use crate::dns::name::Name;
use std::io::Cursor;
use std::net::SocketAddr;
use std::str::FromStr;
use tokio::net::UdpSocket;

const DEFAULT_RECV_BUF: usize = 512;

#[derive(Debug, Clone)]
pub struct DnsClient {
local: SocketAddr,
server: SocketAddr,
}

impl DnsClient {
pub fn new(local: SocketAddr, server: SocketAddr) -> Self {
Self { local, server }
}

pub async fn exchange(&self, msg: &Message) -> Result<Message, MtopError> {
let id = msg.id();
let sock = self.connect_udp().await?;
self.send_udp(&sock, msg).await?;
self.recv_udp(&sock, id).await
}

pub async fn resolve_srv(&self, name: &str) -> Result<Message, MtopError> {
let n = Name::from_str(name)?;
let id = MessageId::random();
let flags = Flags::default().set_recursion_desired();
let msg = Message::new(id, flags).add_question(Question::new(n, RecordType::SRV));

self.exchange(&msg).await
}

async fn connect_udp(&self) -> Result<UdpSocket, MtopError> {
let sock = UdpSocket::bind(&self.local).await?;
sock.connect(&self.server).await?;
Ok(sock)
}

async fn send_udp(&self, socket: &UdpSocket, msg: &Message) -> Result<(), MtopError> {
let mut buf = Vec::new();
msg.write_network_bytes(&mut buf)?;
Ok(socket.send(&buf).await.map(|_| ())?)
}

async fn recv_udp(&self, socket: &UdpSocket, id: MessageId) -> Result<Message, MtopError> {
let mut buf = vec![0_u8; DEFAULT_RECV_BUF];
loop {
let n = socket.recv(&mut buf).await?;
let cur = Cursor::new(&buf[0..n]);
let msg = Message::read_network_bytes(cur)?;
if msg.id() == id {
return Ok(msg);
}
}
}
}
Loading

0 comments on commit 22ba6b7

Please sign in to comment.