From d74dd9cea1ee94a46f6b818c461732da8e58cc24 Mon Sep 17 00:00:00 2001 From: Nick Pillitteri Date: Wed, 27 Dec 2023 12:52:56 -0500 Subject: [PATCH] Introduce high-level MemcachedClient Create a high-level client that interacts with multiple Memcached servers depending on the key or operation being performed. The client uses rendezvous hashing to determine which server is responsible for any relevant keys. In addition the following improvements have been made: * `Key` type introduced to validate keys before using them * `mc` has been rewritten to clean up resources properly The following things are missing and will done in follow up work: * Tests for newly introduced client and hashing * Type-safe alternative to string hostnames --- CHANGELOG.md | 5 +- Cargo.lock | 44 +-- mtop-client/src/client.rs | 562 +++++++++++++++++++++++++++++++++++++ mtop-client/src/core.rs | 454 ++++++++++++++---------------- mtop-client/src/lib.rs | 8 +- mtop-client/src/pool.rs | 62 ++-- mtop-client/src/timeout.rs | 2 +- mtop/Cargo.toml | 2 +- mtop/src/bin/mc.rs | 328 ++++++++++++++-------- mtop/src/bin/mtop.rs | 225 ++++++++------- mtop/src/check.rs | 16 +- mtop/src/queue.rs | 24 +- mtop/src/tracing.rs | 2 +- 13 files changed, 1215 insertions(+), 519 deletions(-) create mode 100644 mtop-client/src/client.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d69aa8..79653da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,11 @@ # Changelog -## v0.7.1 - unreleased +## v0.8.0 - unreleased - Add default 5 second timeout to network operations done by `mtop`. #90 -- Add `add` and `replace` commands to `mc. #95 +- Add `incr`, `decr`, `add`, and `replace` commands to `mc`. #95 #98 - TLS related dependency updates. #93 +- Create high-level client for operating on multiple servers. #100 ## v0.7.0 - 2023-11-28 diff --git a/Cargo.lock b/Cargo.lock index 36c1fd8..30640db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,9 +19,9 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "ahash" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91429305e9f0a25f6205c5b8e0d2db09e0708a7a6df0f42212bb56c32c8ac97a" +checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01" dependencies = [ "cfg-if", "once_cell", @@ -109,9 +109,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.4.11" +version = "4.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfaff671f6b22ca62406885ece523383b9b64022e341e53e009a62ebc47a45f2" +checksum = "52bdc885e4cacc7f7c9eedc1ef6da641603180c783c41a15c264944deeaab642" dependencies = [ "clap_builder", "clap_derive", @@ -119,9 +119,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.4.11" +version = "4.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a216b506622bb1d316cd51328dce24e07bdff4a6128a47c7e7fad11878d5adbb" +checksum = "fb7fb5e4e979aec3be7791562fcba452f94ad85e954da024396433e0e25a79e9" dependencies = [ "anstyle", "clap_lex", @@ -249,9 +249,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.151" +version = "0.2.152" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" +checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" [[package]] name = "linux-raw-sys" @@ -286,9 +286,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.6.4" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" +checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" [[package]] name = "miniz_oxide" @@ -360,9 +360,9 @@ dependencies = [ [[package]] name = "object" -version = "0.32.1" +version = "0.32.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" dependencies = [ "memchr", ] @@ -416,18 +416,18 @@ checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" [[package]] name = "proc-macro2" -version = "1.0.71" +version = "1.0.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75cb1540fadbd5b8fbccc4dddad2734eba435053f725621c070711a14bb5f4b8" +checksum = "95fc56cda0b5c3325f5fbbd7ff9fda9e02bb00bb3dac51252d2f1bfa1cb8cc8c" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.33" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" dependencies = [ "proc-macro2", ] @@ -518,15 +518,15 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.0.1" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7673e0aa20ee4937c6aacfc12bb8341cfbf054cdd21df6bec5fd0629fe9339b" +checksum = "9e9d979b3ce68192e42760c7810125eb6cf2ea10efae545a156063e61f314e2a" [[package]] name = "rustls-webpki" -version = "0.102.0" +version = "0.102.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de2635c8bc2b88d367767c5de8ea1d8db9af3f6219eba28442242d9ab81d1b89" +checksum = "ef4ca26037c909dedb327b48c3327d0ba91d3dd3c4e05dad328f210ffb68e95b" dependencies = [ "ring", "rustls-pki-types", @@ -636,9 +636,9 @@ checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" [[package]] name = "syn" -version = "2.0.42" +version = "2.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b7d0a2c048d661a1a59fcd7355baa232f7ed34e0ee4df2eef3c1c1c0d3852d8" +checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" dependencies = [ "proc-macro2", "quote", diff --git a/mtop-client/src/client.rs b/mtop-client/src/client.rs new file mode 100644 index 0000000..17c7d44 --- /dev/null +++ b/mtop-client/src/client.rs @@ -0,0 +1,562 @@ +use crate::core::{Key, Meta, MtopError, SlabItems, Slabs, Stats, Value}; +use crate::pool::{MemcachedPool, PooledMemcached, Server}; +use std::collections::hash_map::DefaultHasher; +use std::collections::HashMap; +use std::hash::Hasher; +use std::sync::Arc; +use tokio::runtime::Handle; +use tokio::sync::RwLock; + +// Further reading on rendezvous hashing: +// +// - https://randorithms.com/2020/12/26/rendezvous-hashing.html +// - https://www.snia.org/sites/default/files/SDC15_presentations/dist_sys/Jason_Resch_New_Consistent_Hashings_Rev.pdf +// - https://dgryski.medium.com/consistent-hashing-algorithmic-tradeoffs-ef6b8e2fcae8 +// - https://medium.com/i0exception/rendezvous-hashing-8c00e2fb58b0 +// - https://stackoverflow.com/questions/69841546/consistent-hashing-why-are-vnodes-a-thing +// - https://medium.com/@panchr/dynamic-replication-in-memcached-8939c6f81e7f + +/// Logic for picking a server to "own" a particular cache key that uses +/// rendezvous hashing. +/// +/// See https://en.wikipedia.org/wiki/Rendezvous_hashing +#[derive(Debug)] +pub struct SelectorRendezvous { + servers: RwLock>, +} + +impl SelectorRendezvous { + /// Create a new instance with the provided initial server list + pub fn new(servers: Vec) -> Self { + Self { + servers: RwLock::new(servers), + } + } + + fn score(server: &Server, key: &Key) -> f64 { + let mut hasher = DefaultHasher::new(); + hasher.write(server.name.as_bytes()); + hasher.write(key.as_ref().as_bytes()); + let h = hasher.finish(); + + // Add one to the hash value since we can't take the log of 0 and clamp to 1.0 + // since we negate the log result, which is negative for numbers less than 1.0 + let unit = ((h + 1) as f64 / u64::MAX as f64).min(1.0); + + // Log of 1.0 is 0. Dividing by floating point 0 is infinity in IEEE 754 so we + // just let that happen instead of special casing a unit value of exactly 1.0. + 100_f64 / -unit.ln() + } + + /// Get a copy of all current servers. + pub async fn servers(&self) -> Vec { + let servers = self.servers.read().await; + servers.clone() + } + + /// Get the `Server` that owns the given key, or none if there are no servers. + pub async fn server(&self, key: &Key) -> Option { + let servers = self.servers.read().await; + if servers.is_empty() { + None + } else if servers.len() == 1 { + servers.first().cloned() + } else { + servers + .iter() + .max_by(|x, y| Self::score(x, key).total_cmp(&Self::score(y, key))) + .cloned() + } + } + + /// Update the list of potential servers to pick from. + pub async fn set_servers(&self, servers: Vec) { + let mut current = self.servers.write().await; + *current = servers + } +} + +/// Response for both values and errors from multiple servers, indexed by server. +#[derive(Debug, Default)] +pub struct ServersResponse { + pub values: HashMap, + pub errors: HashMap, +} + +impl ServersResponse { + /// Return true if there are any errors, false otherwise. + pub fn has_errors(&self) -> bool { + !self.errors.is_empty() + } +} + +/// Response for values indexed by key and errors indexed by server. +#[derive(Debug, Default)] +pub struct ValuesResponse { + pub values: HashMap, + pub errors: HashMap, +} + +impl ValuesResponse { + /// Return true if there are any errors, false otherwise. + pub fn has_errors(&self) -> bool { + !self.errors.is_empty() + } +} + +#[derive(Debug)] +pub struct MemcachedClient { + handle: Handle, + selector: SelectorRendezvous, + pool: Arc, +} + +/// Run a method for a particular server in a spawned future. +macro_rules! spawn_for_host { + ($self:ident, $method:ident, $host:expr $(, $args:expr)* $(,)?) => {{ + let pool = $self.pool.clone(); + $self.handle.spawn(async move { + match pool.get($host).await { + Ok(mut conn) => { + let res = conn.$method($($args,)*).await; + pool.put(conn).await; + res + } + Err(e) => Err(e), + } + }) + }}; +} + +/// Run a method on a connection to a particular server based on the hash of a single key. +macro_rules! operation_for_key { + ($self:ident, $method:ident, $key:expr $(, $args:expr)* $(,)?) => {{ + let key = Key::one($key)?; + if let Some(s) = $self.selector.server(&key).await { + let mut conn = $self.pool.get(&s.name).await?; + let res = conn.$method(&key, $($args,)*).await; + $self.pool.put(conn).await; + res + } else { + Err(MtopError::runtime("no servers available")) + } + }}; +} + +/// Run a method on a connection to every server and bucket the results and errors by server. +macro_rules! operation_for_all { + ($self:ident, $method:ident) => {{ + let servers = $self.selector.servers().await; + let tasks = servers + .into_iter() + .map(|s| { + let host = s.name.clone(); + (s, spawn_for_host!($self, $method, &host)) + }) + .collect::>(); + + let mut values = HashMap::with_capacity(tasks.len()); + let mut errors = HashMap::new(); + + for (server, task) in tasks { + match task.await { + Ok(Ok(results)) => { + values.insert(server, results); + } + Ok(Err(e)) => { + errors.insert(server, e); + } + Err(e) => { + errors.insert(server, MtopError::runtime_cause("fetching cluster values", e)); + } + }; + } + + Ok(ServersResponse { values, errors }) + }}; +} + +impl MemcachedClient { + /// Create a new `MemcachedClient` instance. + /// + /// `handle` is used to spawn multiple async tasks to fetch data from servers in + /// parallel. `selector` is used to determine which server "owns" a particular key. + /// `pool` is used for pooling or establishing new connections to each server as + /// needed. + pub fn new(handle: Handle, selector: SelectorRendezvous, pool: MemcachedPool) -> Self { + Self { + handle, + selector, + pool: Arc::new(pool), + } + } + + /// Get a connection to a particular server from the pool if available, otherwise + /// establish a new connection. + pub async fn raw_open(&self, host: &str) -> Result { + self.pool.get(host).await + } + + /// Return a connection to a particular server to the pool if fewer than the configured + /// number of idle connections to that server are currently in the pool, otherwise close + /// it immediately. + pub async fn raw_close(&self, conn: PooledMemcached) { + self.pool.put(conn).await; + } + + /// Get a `Stats` object with the current values of the interesting stats for each server. + /// + /// A future is spawned for each server with results and any errors indexed by server. A + /// pooled connection to each server is used if available, otherwise new connections are + /// established. + pub async fn stats(&self) -> Result, MtopError> { + operation_for_all!(self, stats) + } + + /// Get a `Slabs` object with information about each set of `Slab`s maintained by each server. + /// You can think of each `Slab` as a class of objects that are stored together in memory. Note + /// that `Slab` IDs may not be contiguous based on the size of items actually stored by the server. + /// + /// A future is spawned for each server with results and any errors indexed by server. A + /// pooled connection to each server is used if available, otherwise new connections are + /// established. + pub async fn slabs(&self) -> Result, MtopError> { + operation_for_all!(self, slabs) + } + + /// Get a `SlabsItems` object with information about the `SlabItem` items stored in + /// each slab class maintained by each server. The ID of each `SlabItem` corresponds to a + /// `Slab` maintained by the server. Note that `SlabItem` IDs may not be contiguous based + /// on the size of items actually stored by the server. + /// + /// A future is spawned for each server with results and any errors indexed by server. A + /// pooled connection to each server is used if available, otherwise new connections are + /// established. + pub async fn items(&self) -> Result, MtopError> { + operation_for_all!(self, items) + } + + /// Get a `Meta` object for every item in the cache for each server which includes its key + /// and expiration time as a UNIX timestamp. Expiration time will be `-1` if the item was + /// set with an infinite TTL. + /// + /// A future is spawned for each server with results and any errors indexed by server. A + /// pooled connection to each server is used if available, otherwise new connections are + /// established. + pub async fn metas(&self) -> Result>, MtopError> { + operation_for_all!(self, metas) + } + + /// Send a simple command to verify our connection each known server. + /// + /// A future is spawned for each server with results and any errors indexed by server. A + /// pooled connection to each server is used if available, otherwise new connections are + /// established. + pub async fn ping(&self) -> Result, MtopError> { + operation_for_all!(self, ping) + } + + /// Get a map of the requested keys and their corresponding `Value` in the cache + /// including the key, flags, and data. + /// + /// This method uses a selector implementation to determine which server "owns" each of the + /// provided keys. A future is spawned for each server and the results merged together. A + /// pooled connection to each server is used if available, otherwise new connections are + /// established. + pub async fn get(&self, keys: I) -> Result + where + I: IntoIterator, + K: Into, + { + let keys = Key::many(keys)?; + if keys.is_empty() { + return Ok(ValuesResponse::default()); + } + + let num_keys = keys.len(); + let mut by_server: HashMap> = HashMap::new(); + for key in keys { + if let Some(s) = self.selector.server(&key).await { + let entry = by_server.entry(s).or_default(); + entry.push(key); + } + } + + let tasks = by_server + .into_iter() + .map(|(server, keys)| { + let host = server.name.clone(); + (server, spawn_for_host!(self, get, &host, &keys)) + }) + .collect::>(); + + let mut values = HashMap::with_capacity(num_keys); + let mut errors = HashMap::new(); + + for (server, task) in tasks { + match task.await { + Ok(Ok(results)) => { + values.extend(results); + } + Ok(Err(e)) => { + errors.insert(server, e); + } + Err(e) => { + errors.insert(server, MtopError::runtime_cause("fetching keys", e)); + } + }; + } + + Ok(ValuesResponse { values, errors }) + } + + /// Increment the value of a key by the given delta if the value is numeric returning + /// the new value. Returns an error if the value is not set or _not_ numeric. + /// + /// This method uses a selector implementation to determine which server "owns" the provided + /// key. A pooled connection to the server is used if available, otherwise a new connection + /// is established. + pub async fn incr(&self, key: K, delta: u64) -> Result + where + K: Into, + { + operation_for_key!(self, incr, key, delta) + } + + /// Decrement the value of a key by the given delta if the value is numeric returning + /// the new value with a minimum of 0. Returns an error if the value is not set or _not_ + /// numeric. + /// + /// This method uses a selector implementation to determine which server "owns" the provided + /// key. A pooled connection to the server is used if available, otherwise a new connection + /// is established. + pub async fn decr(&self, key: K, delta: u64) -> Result + where + K: Into, + { + operation_for_key!(self, decr, key, delta) + } + + /// Store the provided item in the cache, regardless of whether it already exists. + /// + /// This method uses a selector implementation to determine which server "owns" the provided + /// key. A pooled connection to the server is used if available, otherwise a new connection + /// is established. + pub async fn set(&self, key: K, flags: u64, ttl: u32, data: V) -> Result<(), MtopError> + where + K: Into, + V: AsRef<[u8]>, + { + operation_for_key!(self, set, key, flags, ttl, data) + } + + /// Store the provided item in the cache only if it does not already exist. + /// + /// This method uses a selector implementation to determine which server "owns" the provided + /// key. A pooled connection to the server is used if available, otherwise a new connection + /// is established. + pub async fn add(&self, key: K, flags: u64, ttl: u32, data: V) -> Result<(), MtopError> + where + K: Into, + V: AsRef<[u8]>, + { + operation_for_key!(self, add, key, flags, ttl, data) + } + + /// Store the provided item in the cache only if it already exists. + /// + /// This method uses a selector implementation to determine which server "owns" the provided + /// key. A pooled connection to the server is used if available, otherwise a new connection + /// is established. + pub async fn replace(&self, key: K, flags: u64, ttl: u32, data: V) -> Result<(), MtopError> + where + K: Into, + V: AsRef<[u8]>, + { + operation_for_key!(self, replace, key, flags, ttl, data) + } + + /// Update the TTL of an item in the cache if it exists, return an error otherwise. + /// + /// This method uses a selector implementation to determine which server "owns" the provided + /// key. A pooled connection to the server is used if available, otherwise a new connection + /// is established. + pub async fn touch(&self, key: K, ttl: u32) -> Result<(), MtopError> + where + K: Into, + { + operation_for_key!(self, touch, key, ttl) + } + + /// Delete an item from the cache if it exists, return an error otherwise. + /// + /// This method uses a selector implementation to determine which server "owns" the provided + /// key. A pooled connection to the server is used if available, otherwise a new connection + /// is established. + pub async fn delete(&self, key: K) -> Result<(), MtopError> + where + K: Into, + { + operation_for_key!(self, delete, key) + } +} + +#[cfg(test)] +mod test { + + // TODO: Actually figure out how to test this without a bunch of boilerplate. + + /////////// + // stats // + /////////// + + #[tokio::test] + async fn test_memcached_client_stats_no_servers() {} + + #[tokio::test] + async fn test_memcached_client_stats_no_errors() {} + + #[tokio::test] + async fn test_memcached_client_stats_some_errors() {} + + /////////// + // slabs // + /////////// + + #[tokio::test] + async fn test_memcached_client_slabs_no_servers() {} + + #[tokio::test] + async fn test_memcached_client_slabs_no_errors() {} + + #[tokio::test] + async fn test_memcached_client_slabs_some_errors() {} + + /////////// + // items // + /////////// + + #[tokio::test] + async fn test_memcached_client_items_no_servers() {} + + #[tokio::test] + async fn test_memcached_client_items_no_errors() {} + + #[tokio::test] + async fn test_memcached_client_items_some_errors() {} + + /////////// + // metas // + /////////// + + #[tokio::test] + async fn test_memcached_client_metas_no_servers() {} + + #[tokio::test] + async fn test_memcached_client_metas_no_errors() {} + + #[tokio::test] + async fn test_memcached_client_metas_some_errors() {} + + ////////// + // ping // + ////////// + + #[tokio::test] + async fn test_memcached_client_ping_no_servers() {} + + #[tokio::test] + async fn test_memcached_client_ping_no_errors() {} + + #[tokio::test] + async fn test_memcached_client_ping_some_errors() {} + + ///////// + // get // + ///////// + + #[tokio::test] + async fn test_memcached_client_get_invalid_keys() {} + + #[tokio::test] + async fn test_memcached_client_get_no_keys() {} + + #[tokio::test] + async fn test_memcached_client_get_no_servers() {} + + #[tokio::test] + async fn test_memcached_client_get_no_errors() {} + + #[tokio::test] + async fn test_memcached_client_get_some_errors() {} + + ////////// + // incr // + ////////// + + #[tokio::test] + async fn test_memcached_client_incr_no_servers() {} + + #[tokio::test] + async fn test_memcached_client_incr_success() {} + + ////////// + // decr // + ////////// + + #[tokio::test] + async fn test_memcached_client_decr_no_servers() {} + + #[tokio::test] + async fn test_memcached_client_decr_success() {} + + ///////// + // set // + ///////// + + #[tokio::test] + async fn test_memcached_client_set_no_servers() {} + + #[tokio::test] + async fn test_memcached_client_set_success() {} + + ///////// + // add // + ///////// + + #[tokio::test] + async fn test_memcached_client_add_no_servers() {} + + #[tokio::test] + async fn test_memcached_client_add_success() {} + + ///////////// + // replace // + ///////////// + + #[tokio::test] + async fn test_memcached_client_replace_no_servers() {} + + #[tokio::test] + async fn test_memcached_client_replace_success() {} + + /////////// + // touch // + /////////// + + #[tokio::test] + async fn test_memcached_client_touch_no_servers() {} + + #[tokio::test] + async fn test_memcached_client_touch_success() {} + + //////////// + // delete // + //////////// + + #[tokio::test] + async fn test_memcached_client_delete_no_servers() {} + + #[tokio::test] + async fn test_memcached_client_delete_success() {} +} diff --git a/mtop-client/src/core.rs b/mtop-client/src/core.rs index 3675d15..af402e3 100644 --- a/mtop-client/src/core.rs +++ b/mtop-client/src/core.rs @@ -1,3 +1,4 @@ +use std::borrow::Borrow; use std::cmp::Ordering; use std::collections::{BTreeSet, HashMap}; use std::error; @@ -8,8 +9,6 @@ use std::str::FromStr; use std::time::Duration; use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter, Lines}; -const MAX_KEY_LENGTH: usize = 250; - #[derive(Debug, Default, PartialEq, Clone)] pub struct Stats { // Server info @@ -639,18 +638,18 @@ impl error::Error for ProtocolError {} #[derive(Debug, Eq, PartialEq, Clone)] enum Command<'a> { - Add(&'a str, u64, u32, &'a [u8]), + Add(&'a Key, u64, u32, &'a [u8]), CrawlerMetadump, - Decr(&'a str, u64), - Delete(&'a str), - Gets(&'a [String]), - Incr(&'a str, u64), - Replace(&'a str, u64, u32, &'a [u8]), + Decr(&'a Key, u64), + Delete(&'a Key), + Gets(&'a [Key]), + Incr(&'a Key, u64), + Replace(&'a Key, u64, u32, &'a [u8]), Stats, StatsItems, StatsSlabs, - Set(&'a str, u64, u32, &'a [u8]), - Touch(&'a str, u32), + Set(&'a Key, u64, u32, &'a [u8]), + Touch(&'a Key, u32), Version, } @@ -674,7 +673,7 @@ impl<'a> From> for Vec { } } -fn storage_command(verb: &str, key: &str, flags: u64, ttl: u32, data: &[u8]) -> Vec { +fn storage_command(verb: &str, key: &Key, flags: u64, ttl: u32, data: &[u8]) -> Vec { let mut bytes = Vec::with_capacity(key.len() + data.len() + 32); io::Write::write_all( &mut bytes, @@ -836,22 +835,8 @@ impl Memcached { /// Get a map of the requested keys and their corresponding `Value` in the cache /// including the key, flags, and data. - pub async fn get(&mut self, keys: I) -> Result, MtopError> - where - I: IntoIterator, - K: Into, - { - let keys: Vec = keys.into_iter().map(|k| k.into()).collect(); - - if keys.is_empty() { - return Err(MtopError::runtime("missing required keys")); - } - - if !validate_keys(&keys) { - return Err(MtopError::runtime("invalid keys")); - } - - self.send(Command::Gets(&keys)).await?; + pub async fn get(&mut self, keys: &[Key]) -> Result, MtopError> { + self.send(Command::Gets(keys)).await?; let mut out = HashMap::with_capacity(keys.len()); loop { @@ -905,15 +890,8 @@ impl Memcached { /// Increment the value of a key by the given delta if the value is numeric returning /// the new value. Returns an error if the value is _not_ numeric. - pub async fn incr(&mut self, key: K, delta: u64) -> Result - where - K: AsRef, - { - if !validate_key(key.as_ref()) { - return Err(MtopError::runtime("invalid key")); - } - - self.send(Command::Incr(key.as_ref(), delta)).await?; + pub async fn incr(&mut self, key: &Key, delta: u64) -> Result { + self.send(Command::Incr(key, delta)).await?; if let Some(v) = self.read.next_line().await? { Self::parse_numeric_response(&v) } else { @@ -923,15 +901,8 @@ impl Memcached { /// Decrement the value of a key by the given delta if the value is numeric returning /// the new value with a minimum of 0. Returns an error if the value is _not_ numeric. - pub async fn decr(&mut self, key: K, delta: u64) -> Result - where - K: AsRef, - { - if !validate_key(key.as_ref()) { - return Err(MtopError::runtime("invalid key")); - } - - self.send(Command::Decr(key.as_ref(), delta)).await?; + pub async fn decr(&mut self, key: &Key, delta: u64) -> Result { + self.send(Command::Decr(key, delta)).await?; if let Some(v) = self.read.next_line().await? { Self::parse_numeric_response(&v) } else { @@ -961,16 +932,11 @@ impl Memcached { } /// Store the provided item in the cache, regardless of whether it already exists. - pub async fn set(&mut self, key: K, flags: u64, ttl: u32, data: V) -> Result<(), MtopError> + pub async fn set(&mut self, key: &Key, flags: u64, ttl: u32, data: V) -> Result<(), MtopError> where - K: AsRef, V: AsRef<[u8]>, { - if !validate_key(key.as_ref()) { - return Err(MtopError::runtime("invalid key")); - } - - self.send(Command::Set(key.as_ref(), flags, ttl, data.as_ref())).await?; + self.send(Command::Set(key, flags, ttl, data.as_ref())).await?; if let Some(v) = self.read.next_line().await? { Self::parse_simple_response(&v, "STORED") } else { @@ -979,16 +945,11 @@ impl Memcached { } /// Store the provided item in the cache only if it does not already exist. - pub async fn add(&mut self, key: K, flags: u64, ttl: u32, data: V) -> Result<(), MtopError> + pub async fn add(&mut self, key: &Key, flags: u64, ttl: u32, data: V) -> Result<(), MtopError> where - K: AsRef, V: AsRef<[u8]>, { - if !validate_key(key.as_ref()) { - return Err(MtopError::runtime("invalid key")); - } - - self.send(Command::Add(key.as_ref(), flags, ttl, data.as_ref())).await?; + self.send(Command::Add(key, flags, ttl, data.as_ref())).await?; if let Some(v) = self.read.next_line().await? { Self::parse_simple_response(&v, "STORED") } else { @@ -997,17 +958,11 @@ impl Memcached { } /// Store the provided item in the cache only if it already exists. - pub async fn replace(&mut self, key: K, flags: u64, ttl: u32, data: V) -> Result<(), MtopError> + pub async fn replace(&mut self, key: &Key, flags: u64, ttl: u32, data: V) -> Result<(), MtopError> where - K: AsRef, V: AsRef<[u8]>, { - if !validate_key(key.as_ref()) { - return Err(MtopError::runtime("invalid key")); - } - - self.send(Command::Replace(key.as_ref(), flags, ttl, data.as_ref())) - .await?; + self.send(Command::Replace(key, flags, ttl, data.as_ref())).await?; if let Some(v) = self.read.next_line().await? { Self::parse_simple_response(&v, "STORED") } else { @@ -1016,15 +971,8 @@ impl Memcached { } /// Update the TTL of an item in the cache if it exists, return an error otherwise. - pub async fn touch(&mut self, key: K, ttl: u32) -> Result<(), MtopError> - where - K: AsRef, - { - if !validate_key(key.as_ref()) { - return Err(MtopError::runtime("invalid key")); - } - - self.send(Command::Touch(key.as_ref(), ttl)).await?; + pub async fn touch(&mut self, key: &Key, ttl: u32) -> Result<(), MtopError> { + self.send(Command::Touch(key, ttl)).await?; if let Some(v) = self.read.next_line().await? { Self::parse_simple_response(&v, "TOUCHED") } else { @@ -1033,15 +981,8 @@ impl Memcached { } /// Delete an item in the cache if it exists, return an error otherwise. - pub async fn delete(&mut self, key: K) -> Result<(), MtopError> - where - K: AsRef, - { - if !validate_key(key.as_ref()) { - return Err(MtopError::runtime("invalid key")); - } - - self.send(Command::Delete(key.as_ref())).await?; + pub async fn delete(&mut self, key: &Key) -> Result<(), MtopError> { + self.send(Command::Delete(key)).await?; if let Some(v) = self.read.next_line().await? { Self::parse_simple_response(&v, "DELETED") } else { @@ -1090,69 +1031,125 @@ impl fmt::Debug for Memcached { } } -/// Return true if the key is legal to use with Memcached, false otherwise -fn validate_key(key: &str) -> bool { - if key.len() > MAX_KEY_LENGTH { - return false; +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +#[repr(transparent)] +pub struct Key(String); + +impl Key { + const MAX_LENGTH: usize = 250; + + pub fn one(val: T) -> Result + where + T: Into, + { + let val = val.into(); + if !Self::is_legal_val(&val) { + Err(MtopError::runtime(format!("invalid key {}", val))) + } else { + Ok(Key(val)) + } + } + + pub fn many(vals: I) -> Result, MtopError> + where + I: IntoIterator, + T: Into, + { + let mut out = Vec::new(); + for val in vals { + out.push(Self::one(val)?); + } + + Ok(out) + } + + pub fn len(&self) -> usize { + self.0.len() } - for c in key.chars() { - if !c.is_ascii() || c.is_ascii_whitespace() || c.is_ascii_control() { + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + fn is_legal_val(val: &str) -> bool { + if val.len() > Self::MAX_LENGTH { return false; } + + for c in val.chars() { + if !c.is_ascii() || c.is_ascii_whitespace() || c.is_ascii_control() { + return false; + } + } + + true } +} - true +impl fmt::Display for Key { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } } -/// Return true if all keys are legal to use with Memcached, false otherwise -fn validate_keys(keys: &[String]) -> bool { - for key in keys { - if !validate_key(key) { - return false; - } +impl AsRef for Key { + fn as_ref(&self) -> &str { + &self.0 } +} - true +impl Borrow for Key { + fn borrow(&self) -> &str { + &self.0 + } } #[cfg(test)] mod test { - use super::{validate_key, ErrorKind, Memcached, Meta, Slab, SlabItem, SlabItems, MAX_KEY_LENGTH}; + use super::{ErrorKind, Key, Memcached, Meta, Slab, SlabItem, SlabItems}; use std::io::{Cursor, Error}; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::AsyncWrite; use tokio::sync::mpsc::{self, UnboundedSender}; + ///////// + // key // + ///////// + #[test] - fn test_validate_key_length() { - let key = "abc".repeat(MAX_KEY_LENGTH); - assert!(!validate_key(&key)); + fn test_key_one_length() { + let val = "abc".repeat(Key::MAX_LENGTH); + let res = Key::one(val); + assert!(res.is_err()); } #[test] - fn test_validate_key_non_ascii() { - let key = "🤦"; - assert!(!validate_key(key)); + fn test_key_one_non_ascii() { + let val = "🤦"; + let res = Key::one(val); + assert!(res.is_err()); } #[test] - fn test_validate_key_whitespace() { - let key = "some thing"; - assert!(!validate_key(key)) + fn test_key_one_whitespace() { + let val = "some thing"; + let res = Key::one(val); + assert!(res.is_err()); } #[test] - fn test_validate_key_control_char() { - let key = "\x7F"; - assert!(!validate_key(key)); + fn test_key_one_control_char() { + let val = "\x7F"; + let res = Key::one(val); + assert!(res.is_err()); } #[test] - fn test_validate_key_success() { - let key = "a-reasonable-key"; - assert!(validate_key(key)); + fn test_key_one_success() { + let val = "a-reasonable-key"; + let res = Key::one(val); + assert!(res.is_ok()); } struct WriteAdapter { @@ -1192,31 +1189,24 @@ mod test { }) } - #[tokio::test] - async fn test_get_no_key() { - let (_rx, mut client) = client!(); - let keys: Vec = vec![]; - let res = client.get(keys).await; - - assert!(res.is_err()); - let err = res.unwrap_err(); - assert_eq!(ErrorKind::Runtime, err.kind()); - } + ///////// + // get // + ///////// #[tokio::test] - async fn test_get_bad_key() { + async fn test_memcached_get_no_key() { let (_rx, mut client) = client!(); + let vals: Vec = vec![]; + let keys = Key::many(vals).unwrap(); + let res = client.get(&keys).await.unwrap(); - let res = client.get(&["bad key".repeat(MAX_KEY_LENGTH)]).await; - assert!(res.is_err()); - let err = res.unwrap_err(); - assert_eq!(ErrorKind::Runtime, err.kind()); + assert!(res.is_empty()); } #[tokio::test] - async fn test_get_error() { + async fn test_memcached_get_error() { let (_rx, mut client) = client!("SERVER_ERROR backend failure\r\n"); - let keys = vec!["foo".to_owned(), "baz".to_owned()]; + let keys = Key::many(vec!["foo", "baz"]).unwrap(); let res = client.get(&keys).await; assert!(res.is_err()); @@ -1225,16 +1215,16 @@ mod test { } #[tokio::test] - async fn test_get_miss() { + async fn test_memcached_get_miss() { let (_rx, mut client) = client!("END\r\n"); - let keys = vec!["foo".to_owned(), "baz".to_owned()]; + let keys = Key::many(vec!["foo", "baz"]).unwrap(); let res = client.get(&keys).await.unwrap(); assert!(res.is_empty()); } #[tokio::test] - async fn test_get_hit() { + async fn test_memcached_get_hit() { let (_rx, mut client) = client!( "VALUE foo 32 3 1\r\n", "bar\r\n", @@ -1242,7 +1232,7 @@ mod test { "qux\r\n", "END\r\n", ); - let keys = vec!["foo".to_owned(), "baz".to_owned()]; + let keys = Key::many(vec!["foo", "baz"]).unwrap(); let res = client.get(&keys).await.unwrap(); let val1 = res.get("foo").unwrap(); @@ -1258,20 +1248,15 @@ mod test { assert_eq!(2, val2.cas); } - #[tokio::test] - async fn test_incr_bad_key() { - let (_rx, mut client) = client!(); - let res = client.incr("bad key", 1).await; - - assert!(res.is_err()); - let err = res.unwrap_err(); - assert_eq!(ErrorKind::Runtime, err.kind()); - } + ////////// + // incr // + ////////// #[tokio::test] - async fn test_incr_bad_val() { + async fn test_memcached_incr_bad_val() { let (mut rx, mut client) = client!("CLIENT_ERROR cannot increment or decrement non-numeric value\r\n"); - let res = client.incr("test", 2).await; + let key = Key::one("test").unwrap(); + let res = client.incr(&key, 2).await; assert!(res.is_err()); let err = res.unwrap_err(); @@ -1283,9 +1268,10 @@ mod test { } #[tokio::test] - async fn test_incr_success() { + async fn test_memcached_incr_success() { let (mut rx, mut client) = client!("3\r\n"); - let res = client.incr("test", 2).await.unwrap(); + let key = Key::one("test").unwrap(); + let res = client.incr(&key, 2).await.unwrap(); assert_eq!(3, res); let bytes = rx.recv().await.unwrap(); @@ -1293,20 +1279,15 @@ mod test { assert_eq!("incr test 2\r\n", command); } - #[tokio::test] - async fn test_decr_bad_key() { - let (_rx, mut client) = client!(); - let res = client.decr("bad key", 1).await; - - assert!(res.is_err()); - let err = res.unwrap_err(); - assert_eq!(ErrorKind::Runtime, err.kind()); - } + ////////// + // decr // + ////////// #[tokio::test] - async fn test_decr_bad_val() { + async fn test_memcached_decr_bad_val() { let (mut rx, mut client) = client!("CLIENT_ERROR cannot increment or decrement non-numeric value\r\n"); - let res = client.decr("test", 1).await; + let key = Key::one("test").unwrap(); + let res = client.decr(&key, 1).await; assert!(res.is_err()); let err = res.unwrap_err(); @@ -1318,9 +1299,10 @@ mod test { } #[tokio::test] - async fn test_decr_success() { + async fn test_memcached_decr_success() { let (mut rx, mut client) = client!("3\r\n"); - let res = client.decr("test", 1).await.unwrap(); + let key = Key::one("test").unwrap(); + let res = client.decr(&key, 1).await.unwrap(); assert_eq!(3, res); let bytes = rx.recv().await.unwrap(); @@ -1331,7 +1313,9 @@ mod test { macro_rules! test_store_command_success { ($method:ident, $verb:expr) => { let (mut rx, mut client) = client!("STORED\r\n"); - let res = client.$method("test", 0, 300, "val".as_bytes()).await; + let res = client + .$method(&Key::one("test").unwrap(), 0, 300, "val".as_bytes()) + .await; assert!(res.is_ok()); let bytes = rx.recv().await.unwrap(); @@ -1340,21 +1324,12 @@ mod test { }; } - macro_rules! test_store_command_bad_key { - ($method:ident) => { - let (_rx, mut client) = client!(); - let res = client.$method("bad key", 0, 300, "val".as_bytes()).await; - - assert!(res.is_err()); - let err = res.unwrap_err(); - assert_eq!(ErrorKind::Runtime, err.kind()); - }; - } - macro_rules! test_store_command_error { ($method:ident, $verb:expr) => { let (mut rx, mut client) = client!("NOT_STORED\r\n"); - let res = client.$method("test", 0, 300, "val".as_bytes()).await; + let res = client + .$method(&Key::one("test").unwrap(), 0, 300, "val".as_bytes()) + .await; assert!(res.is_err()); let err = res.unwrap_err(); @@ -1366,55 +1341,57 @@ mod test { }; } - #[tokio::test] - async fn test_set_success() { - test_store_command_success!(set, "set"); - } + ///////// + // set // + ///////// #[tokio::test] - async fn test_set_bad_key() { - test_store_command_bad_key!(set); + async fn test_memcached_set_success() { + test_store_command_success!(set, "set"); } #[tokio::test] - async fn test_set_error() { + async fn test_memcached_set_error() { test_store_command_error!(set, "set"); } - #[tokio::test] - async fn test_add_success() { - test_store_command_success!(add, "add"); - } + ///////// + // add // + ///////// #[tokio::test] - async fn test_add_bad_key() { - test_store_command_bad_key!(add); + async fn test_memcached_add_success() { + test_store_command_success!(add, "add"); } #[tokio::test] - async fn test_add_error() { + async fn test_memcached_add_error() { test_store_command_error!(add, "add"); } - #[tokio::test] - async fn test_replace_success() { - test_store_command_success!(replace, "replace"); - } + ///////////// + // replace // + ///////////// #[tokio::test] - async fn test_replace_bad_key() { - test_store_command_bad_key!(replace); + async fn test_memcached_replace_success() { + test_store_command_success!(replace, "replace"); } #[tokio::test] - async fn test_replace_error() { + async fn test_memcached_replace_error() { test_store_command_error!(replace, "replace"); } + /////////// + // touch // + /////////// + #[tokio::test] - async fn test_touch_success() { + async fn test_memcached_touch_success() { let (mut rx, mut client) = client!("TOUCHED\r\n"); - let res = client.touch("test", 300).await; + let key = Key::one("test").unwrap(); + let res = client.touch(&key, 300).await; assert!(res.is_ok()); let bytes = rx.recv().await.unwrap(); @@ -1423,19 +1400,10 @@ mod test { } #[tokio::test] - async fn test_touch_bad_key() { - let (_rx, mut client) = client!(); - let res = client.touch("bad key", 300).await; - - assert!(res.is_err()); - let err = res.unwrap_err(); - assert_eq!(ErrorKind::Runtime, err.kind()); - } - - #[tokio::test] - async fn test_touch_error() { + async fn test_memcached_touch_error() { let (mut rx, mut client) = client!("NOT_FOUND\r\n"); - let res = client.touch("test", 300).await; + let key = Key::one("test").unwrap(); + let res = client.touch(&key, 300).await; assert!(res.is_err()); let err = res.unwrap_err(); @@ -1446,10 +1414,15 @@ mod test { assert_eq!("touch test 300\r\n", command); } + //////////// + // delete // + //////////// + #[tokio::test] - async fn test_delete_success() { + async fn test_memcached_delete_success() { let (mut rx, mut client) = client!("DELETED\r\n"); - let res = client.delete("test").await; + let key = Key::one("test").unwrap(); + let res = client.delete(&key).await; assert!(res.is_ok()); let bytes = rx.recv().await.unwrap(); @@ -1458,19 +1431,10 @@ mod test { } #[tokio::test] - async fn test_delete_bad_key() { - let (_rx, mut client) = client!(); - let res = client.delete("bad key").await; - - assert!(res.is_err()); - let err = res.unwrap_err(); - assert_eq!(ErrorKind::Runtime, err.kind()); - } - - #[tokio::test] - async fn test_delete_error() { + async fn test_memcached_delete_error() { let (mut rx, mut client) = client!("NOT_FOUND\r\n"); - let res = client.delete("test").await; + let key = Key::one("test").unwrap(); + let res = client.delete(&key).await; assert!(res.is_err()); let err = res.unwrap_err(); @@ -1481,8 +1445,12 @@ mod test { assert_eq!("delete test\r\n", command); } + /////////// + // stats // + /////////// + #[tokio::test] - async fn test_stats_empty() { + async fn test_memcached_stats_empty() { let (_rx, mut client) = client!("END\r\n"); let res = client.stats().await; @@ -1492,7 +1460,7 @@ mod test { } #[tokio::test] - async fn test_stats_error() { + async fn test_memcached_stats_error() { let (_rx, mut client) = client!("SERVER_ERROR backend failure\r\n"); let res = client.stats().await; @@ -1502,7 +1470,7 @@ mod test { } #[tokio::test] - async fn test_stats_success() { + async fn test_memcached_stats_success() { let (_rx, mut client) = client!( "STAT pid 1525\r\n", "STAT uptime 271984\r\n", @@ -1596,6 +1564,7 @@ mod test { "STAT moves_within_lru 0\r\n", "STAT direct_reclaims 0\r\n", "STAT lru_bumps_dropped 0\r\n", + "END\r\n", ); let res = client.stats().await.unwrap(); @@ -1605,8 +1574,12 @@ mod test { assert_eq!(0, res.get_hits); } + /////////// + // slabs // + /////////// + #[tokio::test] - async fn test_slabs_empty() { + async fn test_memcached_slabs_empty() { let (_rx, mut client) = client!("STAT active_slabs 0\r\n", "STAT total_malloced 0\r\n", "END\r\n"); let res = client.slabs().await.unwrap(); @@ -1614,7 +1587,7 @@ mod test { } #[tokio::test] - async fn test_slabs_error() { + async fn test_memcached_slabs_error() { let (_rx, mut client) = client!("ERROR Too many open connections\r\n"); let res = client.slabs().await; @@ -1624,7 +1597,7 @@ mod test { } #[tokio::test] - async fn test_slabs_success() { + async fn test_memcached_slabs_success() { let (_rx, mut client) = client!( "STAT 6:chunk_size 304\r\n", "STAT 6:chunks_per_page 3449\r\n", @@ -1658,6 +1631,7 @@ mod test { "STAT 7:touch_hits 0\r\n", "STAT active_slabs 2\r\n", "STAT total_malloced 30408704\r\n", + "END\r\n", ); let res = client.slabs().await.unwrap(); @@ -1701,8 +1675,12 @@ mod test { assert_eq!(expected, res.slabs); } + /////////// + // items // + /////////// + #[tokio::test] - async fn test_items_empty() { + async fn test_memcached_items_empty() { let (_rx, mut client) = client!(); let res = client.items().await.unwrap(); @@ -1710,7 +1688,7 @@ mod test { } #[tokio::test] - async fn test_items_error() { + async fn test_memcached_items_error() { let (_rx, mut client) = client!("ERROR Too many open connections\r\n"); let res = client.items().await; @@ -1720,7 +1698,7 @@ mod test { } #[tokio::test] - async fn test_items_success() { + async fn test_memcached_items_success() { let (_rx, mut client) = client!( "STAT items:39:number 3\r\n", "STAT items:39:number_hot 0\r\n", @@ -1791,8 +1769,12 @@ mod test { assert_eq!(expected, res); } + ////////// + // meta // + ////////// + #[tokio::test] - async fn test_metas_empty() { + async fn test_memcached_metas_empty() { let (_rx, mut client) = client!(); let res = client.metas().await.unwrap(); @@ -1800,7 +1782,7 @@ mod test { } #[tokio::test] - async fn test_metas_error() { + async fn test_memcached_metas_error() { let (_rx, mut client) = client!("BUSY crawler is busy\r\n",); let res = client.metas().await; @@ -1810,7 +1792,7 @@ mod test { } #[tokio::test] - async fn test_metas_success() { + async fn test_memcached_metas_success() { let (_rx, mut client) = client!( "key=memcached%2Fmurmur3_hash.c exp=1687216956 la=1687216656 cas=259502 fetch=yes cls=17 size=2912\r\n", "key=memcached%2Fmd5.h exp=1687216956 la=1687216656 cas=259731 fetch=yes cls=17 size=3593\r\n", diff --git a/mtop-client/src/lib.rs b/mtop-client/src/lib.rs index 4932635..ff957d4 100644 --- a/mtop-client/src/lib.rs +++ b/mtop-client/src/lib.rs @@ -1,10 +1,12 @@ +mod client; mod core; mod pool; mod timeout; +pub use crate::client::{MemcachedClient, SelectorRendezvous, ServersResponse, ValuesResponse}; pub use crate::core::{ - ErrorKind, Memcached, Meta, MtopError, ProtocolError, ProtocolErrorKind, Slab, SlabItem, SlabItems, Slabs, Stats, - Value, + ErrorKind, Key, Memcached, Meta, MtopError, ProtocolError, ProtocolErrorKind, Slab, SlabItem, SlabItems, Slabs, + Stats, Value, }; -pub use crate::pool::{MemcachedPool, PoolConfig, PooledMemcached, TLSConfig}; +pub use crate::pool::{MemcachedPool, PoolConfig, PooledMemcached, Server, TLSConfig}; pub use crate::timeout::{Timed, Timeout}; diff --git a/mtop-client/src/pool.rs b/mtop-client/src/pool.rs index 009e247..bd046e0 100644 --- a/mtop-client/src/pool.rs +++ b/mtop-client/src/pool.rs @@ -14,6 +14,24 @@ use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName} use tokio_rustls::rustls::{ClientConfig, RootCertStore}; use tokio_rustls::TlsConnector; +/// An individual server that is part of a Memcached cluster. +/// +/// The name of the server is expected to be a hostname or IP and port combination, +/// something that can be successfully converted to a `SocketAddr`. +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct Server { + pub name: String, +} + +impl Server { + pub fn new(name: S) -> Self + where + S: Into, + { + Self { name: name.into() } + } +} + #[derive(Debug)] pub struct PooledMemcached { inner: Memcached, @@ -46,7 +64,7 @@ impl Default for PoolConfig { fn default() -> Self { Self { max_idle_per_host: 4, - check_on_get: true, + check_on_get: false, check_on_put: true, tls: TLSConfig::default(), } @@ -64,7 +82,7 @@ pub struct TLSConfig { #[derive(Debug)] pub struct MemcachedPool { - clients: Mutex>>, + connections: Mutex>>, client_config: Option>, server: Option>, config: PoolConfig, @@ -85,7 +103,7 @@ impl MemcachedPool { }; Ok(MemcachedPool { - clients: Mutex::new(HashMap::new()), + connections: Mutex::new(HashMap::new()), client_config, server, config, @@ -222,7 +240,7 @@ impl MemcachedPool { where F: Future>, { - let mut map = self.clients.lock().await; + let mut map = self.connections.lock().await; let mut inner = match map.get_mut(host).and_then(|v| v.pop()) { Some(c) => c, None => connect.await?, @@ -241,13 +259,13 @@ impl MemcachedPool { /// Return a connection to the pool if there are currently fewer than `max_idle_per_host` /// connections to the host this client is for. If there are more connections, the returned /// client is closed immediately. - pub async fn put(&self, mut client: PooledMemcached) { - if !self.config.check_on_put || client.ping().await.is_ok() { - let mut map = self.clients.lock().await; - let conns = map.entry(client.host).or_insert_with(Vec::new); + pub async fn put(&self, mut conn: PooledMemcached) { + if !self.config.check_on_put || conn.ping().await.is_ok() { + let mut map = self.connections.lock().await; + let conns = map.entry(conn.host).or_default(); if conns.len() < self.config.max_idle_per_host { - conns.push(client.inner); + conns.push(conn.inner); } } } @@ -304,9 +322,8 @@ mod test { #[tokio::test] async fn test_get_new_connection() { - let pool = MemcachedPool::new(Handle::current(), PoolConfig::default()) - .await - .unwrap(); + let cfg = PoolConfig::default(); + let pool = MemcachedPool::new(Handle::current(), cfg).await.unwrap(); let connect = async { Ok(client!( @@ -322,9 +339,8 @@ mod test { #[tokio::test] async fn test_get_existing_connection() { - let pool = MemcachedPool::new(Handle::current(), PoolConfig::default()) - .await - .unwrap(); + let cfg = PoolConfig::default(); + let pool = MemcachedPool::new(Handle::current(), cfg).await.unwrap(); pool.put(PooledMemcached { host: "localhost:11211".to_owned(), @@ -339,9 +355,14 @@ mod test { #[tokio::test] async fn test_get_dead_connection() { - let pool = MemcachedPool::new(Handle::current(), PoolConfig::default()) - .await - .unwrap(); + let cfg = PoolConfig { + max_idle_per_host: 4, + check_on_get: true, + check_on_put: true, + ..Default::default() + }; + + let pool = MemcachedPool::new(Handle::current(), cfg).await.unwrap(); pool.put(PooledMemcached { host: "localhost:11211".to_owned(), @@ -359,9 +380,8 @@ mod test { #[tokio::test] async fn test_get_error() { - let pool = MemcachedPool::new(Handle::current(), PoolConfig::default()) - .await - .unwrap(); + let cfg = PoolConfig::default(); + let pool = MemcachedPool::new(Handle::current(), cfg).await.unwrap(); let connect = async { Err(MtopError::from(io::Error::new(io::ErrorKind::TimedOut, "timeout"))) }; let res = pool.get_with_connect("localhost:11211", connect).await; diff --git a/mtop-client/src/timeout.rs b/mtop-client/src/timeout.rs index b64dd2a..974d0a1 100644 --- a/mtop-client/src/timeout.rs +++ b/mtop-client/src/timeout.rs @@ -1,4 +1,4 @@ -use crate::MtopError; +use crate::core::MtopError; use pin_project_lite::pin_project; use std::future::Future; use std::pin::Pin; diff --git a/mtop/Cargo.toml b/mtop/Cargo.toml index 34f3797..b72e3fa 100644 --- a/mtop/Cargo.toml +++ b/mtop/Cargo.toml @@ -11,7 +11,7 @@ keywords = ["top", "memcached"] edition = "2021" [dependencies] -clap = { version = "4.1.8", features = ["cargo", "derive", "help", "error-context", "std", "usage", "wrap_help"], default_features = false } +clap = { version = "4.1.8", features = ["cargo", "derive", "help", "error-context", "std", "string", "usage", "wrap_help"], default_features = false } crossterm = "0.27.0" mtop-client = { path = "../mtop-client", version = "0.7.0" } ratatui = "0.24.0" diff --git a/mtop/src/bin/mc.rs b/mtop/src/bin/mc.rs index 23589cc..7d6066d 100644 --- a/mtop/src/bin/mc.rs +++ b/mtop/src/bin/mc.rs @@ -1,15 +1,19 @@ use clap::{Args, Parser, Subcommand, ValueHint}; use mtop::check::{Checker, MeasurementBundle}; -use mtop_client::{MemcachedPool, Meta, PoolConfig, TLSConfig, Value}; +use mtop_client::{ + MemcachedClient, MemcachedPool, Meta, MtopError, PoolConfig, SelectorRendezvous, Server, TLSConfig, Timeout, Value, +}; use std::path::PathBuf; +use std::process::ExitCode; use std::time::Duration; -use std::{env, error, io, process}; +use std::{env, io}; use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter}; use tokio::runtime::Handle; -use tracing::Level; +use tracing::{Instrument, Level}; const DEFAULT_LOG_LEVEL: Level = Level::INFO; const DEFAULT_HOST: &str = "localhost:11211"; +const CONNECT_TIMEOUT: Duration = Duration::from_secs(5); /// mc: memcached command line utility #[derive(Debug, Parser)] @@ -203,132 +207,234 @@ struct TouchCommand { } #[tokio::main] -async fn main() -> Result<(), Box> { +async fn main() -> ExitCode { let opts = McConfig::parse(); - let console_subscriber = mtop::tracing::console_subscriber(opts.log_level)?; + let console_subscriber = + mtop::tracing::console_subscriber(opts.log_level).expect("failed to setup console logging"); tracing::subscriber::set_global_default(console_subscriber).expect("failed to initialize console logging"); - let pool = MemcachedPool::new( + let client = match new_client(&opts).await { + Ok(v) => v, + Err(e) => { + tracing::error!(message = "unable to initialize memcached client", host = opts.host, err = %e); + return ExitCode::FAILURE; + } + }; + + // Hardcoded timeout so that we can ensure the host is actually up. + if let Err(e) = connect(&client, CONNECT_TIMEOUT).await { + tracing::error!(message = "unable to connect", host = opts.host, err = %e); + return ExitCode::FAILURE; + }; + + match &opts.mode { + Action::Add(cmd) => run_add(&opts, cmd, &client).await, + Action::Check(cmd) => run_check(&opts, cmd, &client).await, + Action::Decr(cmd) => run_decr(&opts, cmd, &client).await, + Action::Delete(cmd) => run_delete(&opts, cmd, &client).await, + Action::Get(cmd) => run_get(&opts, cmd, &client).await, + Action::Incr(cmd) => run_incr(&opts, cmd, &client).await, + Action::Keys(cmd) => run_keys(&opts, cmd, &client).await, + Action::Replace(cmd) => run_replace(&opts, cmd, &client).await, + Action::Set(cmd) => run_set(&opts, cmd, &client).await, + Action::Touch(cmd) => run_touch(&opts, cmd, &client).await, + } +} + +async fn new_client(opts: &McConfig) -> Result { + let tls = TLSConfig { + enabled: opts.tls_enabled, + ca_path: opts.tls_ca.clone(), + cert_path: opts.tls_cert.clone(), + key_path: opts.tls_key.clone(), + server_name: opts.tls_server_name.clone(), + }; + + MemcachedPool::new( Handle::current(), PoolConfig { - tls: TLSConfig { - enabled: opts.tls_enabled, - ca_path: opts.tls_ca, - cert_path: opts.tls_cert, - key_path: opts.tls_key, - server_name: opts.tls_server_name, - }, + tls, ..Default::default() }, ) .await - .unwrap_or_else(|e| { - tracing::error!(message = "unable to initialize memcached client", host = opts.host, error = %e); - process::exit(1); - }); - - let mut client = pool.get(&opts.host).await.unwrap_or_else(|e| { - tracing::error!(message = "unable to connect", host = opts.host, error = %e); - process::exit(1); - }); - - match opts.mode { - Action::Add(c) => { - let buf = read_input().await.unwrap_or_else(|e| { - tracing::error!(message = "unable to read item data from stdin", error = %e); - process::exit(1); - }); - - if let Err(e) = client.add(&c.key, 0, c.ttl, &buf).await { - tracing::error!(message = "unable to add item", key = c.key, host = opts.host, error = %e); - process::exit(1); - } - } - Action::Check(c) => { - let checker = Checker::new( - &pool, - Duration::from_millis(c.delay_millis), - Duration::from_secs(c.timeout_secs), - ); - let results = checker.run(&opts.host, Duration::from_secs(c.time_secs)).await; - if let Err(e) = print_check_results(&results).await { - tracing::warn!(message = "error writing output", error = %e); - } - } - Action::Decr(c) => { - if let Err(e) = client.decr(&c.key, c.delta).await { - tracing::error!(message = "unable to decrement value", key = c.key, host = opts.host, error = %e); - process::exit(1); - } - } - Action::Delete(c) => { - if let Err(e) = client.delete(&c.key).await { - tracing::error!(message = "unable to delete item", key = c.key, host = opts.host, error = %e); - process::exit(1); - } - } - Action::Get(c) => { - let results = client.get(&[c.key.clone()]).await.unwrap_or_else(|e| { - tracing::error!(message = "unable to get item", key = c.key, host = opts.host, error = %e); - process::exit(1); - }); - - if let Some(v) = results.get(&c.key) { - if let Err(e) = print_data(v).await { - tracing::warn!(message = "error writing output", error = %e); - } - } + .map(|pool| { + let selector = SelectorRendezvous::new(vec![Server::new(&opts.host)]); + MemcachedClient::new(Handle::current(), selector, pool) + }) +} + +async fn connect(client: &MemcachedClient, timeout: Duration) -> Result<(), MtopError> { + let pings = client + .ping() + .timeout(timeout, "client.ping") + .instrument(tracing::span!(Level::INFO, "client.ping")) + .await?; + + if let Some((_server, err)) = pings.errors.into_iter().next() { + return Err(err); + } + + Ok(()) +} + +async fn run_add(opts: &McConfig, cmd: &AddCommand, client: &MemcachedClient) -> ExitCode { + let buf = match read_input().await { + Ok(v) => v, + Err(e) => { + tracing::error!(message = "unable to read item data from stdin", err = %e); + return ExitCode::FAILURE; } - Action::Incr(c) => { - if let Err(e) = client.incr(&c.key, c.delta).await { - tracing::error!(message = "unable to increment value", key = c.key, host = opts.host, error = %e); - process::exit(1); - } + }; + + if let Err(e) = client.add(&cmd.key, 0, cmd.ttl, &buf).await { + tracing::error!(message = "unable to add item", key = cmd.key, host = opts.host, err = %e); + ExitCode::FAILURE + } else { + ExitCode::SUCCESS + } +} + +async fn run_check(opts: &McConfig, cmd: &CheckCommand, client: &MemcachedClient) -> ExitCode { + let checker = Checker::new( + client, + Duration::from_millis(cmd.delay_millis), + Duration::from_secs(cmd.timeout_secs), + ); + let results = checker.run(&opts.host, Duration::from_secs(cmd.time_secs)).await; + if let Err(e) = print_check_results(&results).await { + tracing::warn!(message = "error writing output", err = %e); + } + + if results.failures.total > 0 { + ExitCode::FAILURE + } else { + ExitCode::SUCCESS + } +} + +async fn run_decr(opts: &McConfig, cmd: &DecrCommand, client: &MemcachedClient) -> ExitCode { + if let Err(e) = client.decr(&cmd.key, cmd.delta).await { + tracing::error!(message = "unable to decrement value", key = cmd.key, host = opts.host, err = %e); + ExitCode::FAILURE + } else { + ExitCode::SUCCESS + } +} + +async fn run_delete(opts: &McConfig, cmd: &DeleteCommand, client: &MemcachedClient) -> ExitCode { + if let Err(e) = client.delete(&cmd.key).await { + tracing::error!(message = "unable to delete item", key = cmd.key, host = opts.host, err = %e); + ExitCode::FAILURE + } else { + ExitCode::SUCCESS + } +} + +async fn run_get(opts: &McConfig, cmd: &GetCommand, client: &MemcachedClient) -> ExitCode { + let response = match client.get(&[cmd.key.clone()]).await { + Ok(v) => v, + Err(e) => { + tracing::error!(message = "unable to get item", key = cmd.key, host = opts.host, err = %e); + return ExitCode::FAILURE; } - Action::Keys(c) => { - let mut metas = client.metas().await.unwrap_or_else(|e| { - tracing::error!(message = "unable to list keys", host = opts.host, error = %e); - process::exit(1); - }); - - metas.sort(); - if let Err(e) = print_keys(&metas, c.details).await { - tracing::warn!(message = "error writing output", error = %e); - } + }; + + if let Some(v) = response.values.get(&cmd.key) { + if let Err(e) = print_data(v).await { + tracing::warn!(message = "error writing output", err = %e); } - Action::Replace(c) => { - let buf = read_input().await.unwrap_or_else(|e| { - tracing::error!(message = "unable to read item data from stdin", error = %e); - process::exit(1); - }); - - if let Err(e) = client.replace(&c.key, 0, c.ttl, &buf).await { - tracing::error!(message = "unable to replace item", key = c.key, host = opts.host, error = %e); - process::exit(1); - } + } + + for (server, e) in response.errors.iter() { + tracing::error!(message = "error fetching value", host = server.name, err = %e); + } + + if response.has_errors() { + ExitCode::FAILURE + } else { + ExitCode::SUCCESS + } +} + +async fn run_incr(opts: &McConfig, cmd: &IncrCommand, client: &MemcachedClient) -> ExitCode { + if let Err(e) = client.incr(&cmd.key, cmd.delta).await { + tracing::error!(message = "unable to increment value", key = cmd.key, host = opts.host, err = %e); + ExitCode::FAILURE + } else { + ExitCode::SUCCESS + } +} + +async fn run_keys(opts: &McConfig, cmd: &KeysCommand, client: &MemcachedClient) -> ExitCode { + let mut response = match client.metas().await { + Ok(v) => v, + Err(e) => { + tracing::error!(message = "unable to list keys", host = opts.host, err = %e); + return ExitCode::FAILURE; } - Action::Set(c) => { - let buf = read_input().await.unwrap_or_else(|e| { - tracing::error!(message = "unable to read item data from stdin", error = %e); - process::exit(1); - }); - - if let Err(e) = client.set(&c.key, 0, c.ttl, &buf).await { - tracing::error!(message = "unable to set item", key = c.key, host = opts.host, error = %e); - process::exit(1); - } + }; + + let mut metas = response.values.remove(&Server::new(&opts.host)).unwrap_or_default(); + metas.sort(); + + if let Err(e) = print_keys(&metas, cmd.details).await { + tracing::warn!(message = "error writing output", err = %e); + } + + for (server, e) in response.errors.iter() { + tracing::error!(message = "error fetching metas", host = server.name, err = %e); + } + + if response.has_errors() { + ExitCode::FAILURE + } else { + ExitCode::SUCCESS + } +} + +async fn run_replace(opts: &McConfig, cmd: &ReplaceCommand, client: &MemcachedClient) -> ExitCode { + let buf = match read_input().await { + Ok(v) => v, + Err(e) => { + tracing::error!(message = "unable to read item data from stdin", err = %e); + return ExitCode::FAILURE; } - Action::Touch(c) => { - if let Err(e) = client.touch(&c.key, c.ttl).await { - tracing::error!(message = "unable to touch item", key = c.key, host = opts.host, error = %e); - process::exit(1); - } + }; + + if let Err(e) = client.replace(&cmd.key, 0, cmd.ttl, &buf).await { + tracing::error!(message = "unable to replace item", key = cmd.key, host = opts.host, err = %e); + ExitCode::FAILURE + } else { + ExitCode::SUCCESS + } +} + +async fn run_set(opts: &McConfig, cmd: &SetCommand, client: &MemcachedClient) -> ExitCode { + let buf = match read_input().await { + Ok(v) => v, + Err(e) => { + tracing::error!(message = "unable to read item data from stdin", err = %e); + return ExitCode::FAILURE; } + }; + + if let Err(e) = client.set(&cmd.key, 0, cmd.ttl, &buf).await { + tracing::error!(message = "unable to set item", key = cmd.key, host = opts.host, err = %e); + ExitCode::FAILURE + } else { + ExitCode::SUCCESS } +} - pool.put(client).await; - Ok(()) +async fn run_touch(opts: &McConfig, cmd: &TouchCommand, client: &MemcachedClient) -> ExitCode { + if let Err(e) = client.touch(&cmd.key, cmd.ttl).await { + tracing::error!(message = "unable to touch item", key = cmd.key, host = opts.host, err = %e); + ExitCode::FAILURE + } else { + ExitCode::SUCCESS + } } async fn read_input() -> io::Result> { diff --git a/mtop/src/bin/mtop.rs b/mtop/src/bin/mtop.rs index 606b923..8197a58 100644 --- a/mtop/src/bin/mtop.rs +++ b/mtop/src/bin/mtop.rs @@ -1,11 +1,14 @@ use clap::{Parser, ValueHint}; use mtop::queue::{BlockingStatsQueue, StatsQueue}; -use mtop_client::{MemcachedPool, MtopError, PoolConfig, SlabItems, Slabs, Stats, TLSConfig, Timeout}; +use mtop_client::{ + MemcachedClient, MemcachedPool, MtopError, PoolConfig, SelectorRendezvous, Server, TLSConfig, Timeout, +}; +use std::env; use std::net::SocketAddr; use std::path::PathBuf; +use std::process::ExitCode; use std::sync::Arc; use std::time::Duration; -use std::{env, error, process}; use tokio::net::ToSocketAddrs; use tokio::runtime::Handle; use tokio::task; @@ -35,9 +38,8 @@ struct MtopConfig { /// File to log errors to since they cannot be logged to the console. If the path is not /// writable, mtop will not start. - /// [default: $TEMP/mtop/mtop.log] - #[arg(long, value_hint = ValueHint::FilePath)] - log_file: Option, + #[arg(long, default_value=default_log_file().into_os_string(), value_hint = ValueHint::FilePath)] + log_file: PathBuf, /// Enable TLS connections to the Memcached server. #[arg(long)] @@ -76,66 +78,61 @@ fn default_log_file() -> PathBuf { } #[tokio::main] -async fn main() -> Result<(), Box> { +async fn main() -> ExitCode { let opts = MtopConfig::parse(); - let console_subscriber = mtop::tracing::console_subscriber(opts.log_level)?; + let console_subscriber = + mtop::tracing::console_subscriber(opts.log_level).expect("failed to setup console logging"); tracing::subscriber::set_global_default(console_subscriber).expect("failed to initialize console logging"); // Create a file subscriber for log messages generated while the UI is running // since we can't log to stdout or stderr. - let log_file = opts.log_file.unwrap_or_else(default_log_file); - let file_subscriber = mtop::tracing::file_subscriber(opts.log_level, log_file) - .map(Arc::new) - .unwrap_or_else(|e| { + let file_subscriber = match mtop::tracing::file_subscriber(opts.log_level, &opts.log_file).map(Arc::new) { + Ok(v) => v, + Err(e) => { tracing::error!(message = "failed to initialize file logging", error = %e); - process::exit(1); - }); + return ExitCode::FAILURE; + } + }; let timeout = Duration::from_secs(opts.timeout_secs); let measurements = Arc::new(StatsQueue::new(NUM_MEASUREMENTS)); - let pool = MemcachedPool::new( - Handle::current(), - PoolConfig { - tls: TLSConfig { - enabled: opts.tls_enabled, - ca_path: opts.tls_ca, - cert_path: opts.tls_cert, - key_path: opts.tls_key, - server_name: opts.tls_server_name, - }, - ..Default::default() - }, - ) - .await - .unwrap_or_else(|e| { - tracing::error!(message = "unable to initialize memcached client", hosts = ?opts.hosts, error = %e); - process::exit(1); - }); // Do DNS lookups on any "dns+" hostnames to expand them to multiple IPs based on A records. - let hosts = expand_hosts(&opts.hosts, timeout).await.unwrap_or_else(|e| { - tracing::error!(message = "unable to resolve host names", hosts = ?opts.hosts, error = %e); - process::exit(1); - }); - - // Run the initial connection to each server once in the main thread to make bad hostnames - // easier to spot. - let update_task = UpdateTask::new(&hosts, pool, measurements.clone(), timeout, Handle::current()); - update_task.connect().await.unwrap_or_else(|e| { - tracing::error!(message = "unable to connect to memcached servers", hosts = ?opts.hosts, error = %e); - process::exit(1); - }); + let hosts = match expand_hosts(&opts.hosts, timeout).await { + Ok(v) => v, + Err(e) => { + tracing::error!(message = "unable to resolve host names", hosts = ?opts.hosts, error = %e); + return ExitCode::FAILURE; + } + }; + + let client = match new_client(&opts, &hosts).await { + Ok(v) => v, + Err(e) => { + tracing::error!(message = "unable to initialize memcached client", hosts = ?hosts, err = %e); + return ExitCode::FAILURE; + } + }; + + let update_task = UpdateTask::new(client, measurements.clone(), timeout); + if let Err(e) = update_task.connect().await { + tracing::error!(message = "unable to connect to memcached servers", hosts = ?hosts, err = %e); + return ExitCode::FAILURE; + } task::spawn( async move { let mut interval = tokio::time::interval(DEFAULT_STATS_INTERVAL); loop { let _ = interval.tick().await; - update_task + if let Err(e) = update_task .update() .instrument(tracing::span!(Level::INFO, "periodic.update")) - .await; + .await + { + tracing::error!(message = "unable to update server metrics", err = %e); + } } } .with_subscriber(file_subscriber.clone()), @@ -157,16 +154,14 @@ async fn main() -> Result<(), Box> { match ui_res { Err(e) => { tracing::error!(message = "unable to run UI in dedicated thread", error = %e); - process::exit(1); + ExitCode::FAILURE } Ok(Err(e)) => { tracing::error!(message = "error setting up terminal or running UI", error = %e); - process::exit(1); + ExitCode::FAILURE } - _ => {} + _ => ExitCode::SUCCESS, } - - Ok(()) } async fn expand_hosts(hosts: &[String], timeout: Duration) -> Result, MtopError> { @@ -199,91 +194,107 @@ where Ok(tokio::net::lookup_host(host).await?) } +async fn new_client(opts: &MtopConfig, hosts: &[String]) -> Result { + let tls = TLSConfig { + enabled: opts.tls_enabled, + ca_path: opts.tls_ca.clone(), + cert_path: opts.tls_cert.clone(), + key_path: opts.tls_key.clone(), + server_name: opts.tls_server_name.clone(), + }; + + MemcachedPool::new( + Handle::current(), + PoolConfig { + tls, + ..Default::default() + }, + ) + .await + .map(|pool| { + let selector = SelectorRendezvous::new(hosts.iter().map(Server::new).collect()); + MemcachedClient::new(Handle::current(), selector, pool) + }) +} + #[derive(Debug)] struct UpdateTask { - hosts: Vec, - pool: Arc, + client: MemcachedClient, queue: Arc, timeout: Duration, - handle: Handle, } impl UpdateTask { - fn new(hosts: &[String], pool: MemcachedPool, queue: Arc, timeout: Duration, handle: Handle) -> Self { - UpdateTask { - hosts: Vec::from(hosts), - pool: Arc::new(pool), - queue, - timeout, - handle, - } + fn new(client: MemcachedClient, queue: Arc, timeout: Duration) -> Self { + UpdateTask { client, queue, timeout } } async fn connect(&self) -> Result<(), MtopError> { - for host in self.hosts.iter() { - let client = self - .pool - .get(host) - .timeout(self.timeout, "client.connect") - .instrument(tracing::span!(Level::INFO, "client.connect")) - .await?; - self.pool.put(client).await; + let pings = self + .client + .ping() + .timeout(self.timeout, "client.ping") + .instrument(tracing::span!(Level::INFO, "client.ping")) + .await?; + + if let Some((_server, err)) = pings.errors.into_iter().next() { + return Err(err); } Ok(()) } - async fn update_host( - host: String, - pool: Arc, - timeout: Duration, - ) -> Result<(Stats, Slabs, SlabItems), MtopError> { - let mut client = pool - .get(&host) - .timeout(timeout, "client.connect") - .instrument(tracing::span!(Level::INFO, "client.connect")) - .await?; - let stats = client + async fn update(&self) -> Result<(), MtopError> { + let stats = self + .client .stats() - .timeout(timeout, "client.stats") + .timeout(self.timeout, "client.stats") .instrument(tracing::span!(Level::INFO, "client.stats")) .await?; - let slabs = client + + let mut slabs = self + .client .slabs() - .timeout(timeout, "client.slabs") + .timeout(self.timeout, "client.slabs") .instrument(tracing::span!(Level::INFO, "client.slabs")) .await?; - let items = client + + let mut items = self + .client .items() - .timeout(timeout, "client.items") + .timeout(self.timeout, "client.items") .instrument(tracing::span!(Level::INFO, "client.items")) .await?; - pool.put(client).await; - Ok((stats, slabs, items)) - } + for (server, stats) in stats.values { + let slabs = match slabs.values.remove(&server) { + Some(v) => v, + None => continue, + }; + + let items = match items.values.remove(&server) { + Some(v) => v, + None => continue, + }; + + self.queue + .insert(server.name, stats, slabs, items) + .instrument(tracing::span!(Level::INFO, "queue.insert")) + .await; + } - async fn update(&self) { - let mut tasks = Vec::with_capacity(self.hosts.len()); - for host in self.hosts.iter() { - tasks.push(( - host, - self.handle - .spawn(Self::update_host(host.clone(), self.pool.clone(), self.timeout)), - )); + for (server, e) in stats.errors { + tracing::warn!(message = "error fetching stats", host = server.name, err = %e); } - for (host, task) in tasks { - match task.await { - Err(e) => tracing::error!(message = "failed to run server update task", host = host, err = %e), - Ok(Err(e)) => tracing::warn!(message = "failed to update stats for server", host = host, err = %e), - Ok(Ok((stats, slabs, items))) => { - self.queue - .insert(host.clone(), stats, slabs, items) - .instrument(tracing::span!(Level::INFO, "queue.insert")) - .await; - } - } + for (server, e) in slabs.errors { + tracing::warn!(message = "error fetching slabs", host = server.name, err = %e); + } + + for (server, e) in items.errors { + tracing::warn!(message = "error fetching items", host = server.name, err = %e); } + + Ok(()) } } diff --git a/mtop/src/check.rs b/mtop/src/check.rs index d65c446..6ff9ba6 100644 --- a/mtop/src/check.rs +++ b/mtop/src/check.rs @@ -1,4 +1,4 @@ -use mtop_client::{MemcachedPool, MtopError}; +use mtop_client::{Key, MemcachedClient, MtopError}; use std::time::{Duration, Instant}; use std::{cmp, fmt}; use tokio::net::ToSocketAddrs; @@ -10,7 +10,7 @@ const VALUE: &[u8] = "test".as_bytes(); /// Repeatedly make connections to a Memcached server to verify connectivity. #[derive(Debug)] pub struct Checker<'a> { - pool: &'a MemcachedPool, + client: &'a MemcachedClient, delay: Duration, timeout: Duration, } @@ -20,8 +20,8 @@ impl<'a> Checker<'a> { /// is the amount of time to wait between each test. `timeout` is how long each individual /// part of the test may take (DNS resolution, connecting, setting a value, and fetching /// a value). - pub fn new(pool: &'a MemcachedPool, delay: Duration, timeout: Duration) -> Self { - Self { pool, delay, timeout } + pub fn new(client: &'a MemcachedClient, delay: Duration, timeout: Duration) -> Self { + Self { client, delay, timeout } } /// Perform connection tests for a particular hosts in a loop and return information @@ -38,7 +38,7 @@ impl<'a> Checker<'a> { let mut failures = Failures::default(); let start = Instant::now(); - // Note that we don't return the connection to the pool each iteration. This ensures + // Note that we don't return the connection to the client each iteration. This ensures // we're creating a new connection each time and thus actually testing the network // when doing the check. loop { @@ -67,7 +67,7 @@ impl<'a> Checker<'a> { let dns_time = dns_start.elapsed(); let conn_start = Instant::now(); - let mut conn = match time::timeout(self.timeout, self.pool.get(&ip_addr)).await { + let mut conn = match time::timeout(self.timeout, self.client.raw_open(&ip_addr)).await { Ok(Ok(v)) => v, Ok(Err(e)) => { tracing::warn!(message = "failed to connect to host", host = host, addr = ip_addr, err = %e); @@ -85,7 +85,7 @@ impl<'a> Checker<'a> { let conn_time = conn_start.elapsed(); let set_start = Instant::now(); - match time::timeout(self.timeout, conn.set(KEY.to_owned(), 0, 60, VALUE.to_vec())).await { + match time::timeout(self.timeout, conn.set(&Key::one(KEY).unwrap(), 0, 60, VALUE.to_vec())).await { Ok(Ok(_)) => {} Ok(Err(e)) => { tracing::warn!(message = "failed to set key", host = host, addr = ip_addr, err = %e); @@ -103,7 +103,7 @@ impl<'a> Checker<'a> { let set_time = set_start.elapsed(); let get_start = Instant::now(); - match time::timeout(self.timeout, conn.get(&[KEY.to_owned()])).await { + match time::timeout(self.timeout, conn.get(&Key::many(vec![KEY]).unwrap())).await { Ok(Ok(_)) => {} Ok(Err(e)) => { tracing::warn!(message = "failed to get key", host = host, addr = ip_addr, err = %e); diff --git a/mtop/src/queue.rs b/mtop/src/queue.rs index c240018..46ca696 100644 --- a/mtop/src/queue.rs +++ b/mtop/src/queue.rs @@ -32,9 +32,12 @@ impl StatsQueue { } } - pub async fn insert(&self, host: String, stats: Stats, slabs: Slabs, items: SlabItems) { + pub async fn insert(&self, host: H, stats: Stats, slabs: Slabs, items: SlabItems) + where + H: Into, + { let mut map = self.queues.lock().await; - let q = map.entry(host).or_insert_with(VecDeque::new); + let q = map.entry(host.into()).or_insert_with(VecDeque::new); if let Some(prev) = q.back() { if stats.uptime == prev.stats.uptime { @@ -81,9 +84,12 @@ impl StatsQueue { } } - pub async fn read_delta(&self, host: &str) -> Option { + pub async fn read_delta(&self, host: H) -> Option + where + H: AsRef, + { let map = self.queues.lock().await; - map.get(host).and_then(|q| match (q.front(), q.back()) { + map.get(host.as_ref()).and_then(|q| match (q.front(), q.back()) { // The delta is only valid if there are more than two entries in the queue. This // avoids division by zero errors (since the time for the entries would be the same). (Some(previous), Some(current)) if q.len() >= 2 => { @@ -110,11 +116,17 @@ impl BlockingStatsQueue { Self { queue, handle } } - pub fn insert(&self, host: String, stats: Stats, slabs: Slabs, items: SlabItems) { + pub fn insert(&self, host: H, stats: Stats, slabs: Slabs, items: SlabItems) + where + H: Into, + { self.handle.block_on(self.queue.insert(host, stats, slabs, items)) } - pub fn read_delta(&self, host: &str) -> Option { + pub fn read_delta(&self, host: H) -> Option + where + H: AsRef, + { self.handle.block_on(self.queue.read_delta(host)) } } diff --git a/mtop/src/tracing.rs b/mtop/src/tracing.rs index fa1cb46..2df0595 100644 --- a/mtop/src/tracing.rs +++ b/mtop/src/tracing.rs @@ -19,7 +19,7 @@ pub fn console_subscriber( #[allow(clippy::type_complexity)] pub fn file_subscriber( level: tracing::Level, - path: PathBuf, + path: &PathBuf, ) -> Result, io::Error> { if let Some(d) = path.parent() { fs::create_dir_all(d)?;