diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index df4eeda..ed4520f 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -14,6 +14,9 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + # Temporary while https://github.com/rust-lang/rust-clippy/issues/12014 is affecting stable + - name: Rustup + run: rustup update nightly && rustup default nightly && rustup component add clippy rustfmt - name: Versions run: cargo --version && rustc --version - name: Build diff --git a/CHANGELOG.md b/CHANGELOG.md index 151bf02..4d69aa8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## v0.7.1 - unreleased - Add default 5 second timeout to network operations done by `mtop`. #90 +- Add `add` and `replace` commands to `mc. #95 - TLS related dependency updates. #93 ## v0.7.0 - 2023-11-28 diff --git a/mtop-client/src/core.rs b/mtop-client/src/core.rs index 0d05573..5adad2a 100644 --- a/mtop-client/src/core.rs +++ b/mtop-client/src/core.rs @@ -8,6 +8,8 @@ use std::str::FromStr; use std::time::Duration; use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter, Lines}; +const MAX_KEY_LENGTH: usize = 250; + #[derive(Debug, Default, PartialEq, Clone)] pub struct Stats { // Server info @@ -598,6 +600,7 @@ pub enum ProtocolErrorKind { Busy, Client, NotFound, + NotStored, Server, Syntax, } @@ -609,6 +612,7 @@ impl fmt::Display for ProtocolErrorKind { Self::Busy => "BUSY".fmt(f), Self::Client => "CLIENT_ERROR".fmt(f), Self::NotFound => "NOT_FOUND".fmt(f), + Self::NotStored => "NOT_STORED".fmt(f), Self::Server => "SERVER_ERROR".fmt(f), Self::Syntax => "ERROR".fmt(f), } @@ -635,9 +639,11 @@ impl error::Error for ProtocolError {} #[derive(Debug, Eq, PartialEq, Clone)] enum Command<'a> { + Add(&'a str, u64, u32, &'a [u8]), CrawlerMetadump, Delete(&'a str), Gets(&'a [String]), + Replace(&'a str, u64, u32, &'a [u8]), Stats, StatsItems, StatsSlabs, @@ -648,32 +654,34 @@ enum Command<'a> { impl<'a> From> for Vec { fn from(value: Command<'a>) -> Self { - let buf = match value { + match value { + Command::Add(key, flags, ttl, data) => storage_command("add", key, flags, ttl, data), Command::CrawlerMetadump => "lru_crawler metadump hash\r\n".to_owned().into_bytes(), Command::Delete(key) => format!("delete {}\r\n", key).into_bytes(), Command::Gets(keys) => format!("gets {}\r\n", keys.join(" ")).into_bytes(), + Command::Replace(key, flags, ttl, data) => storage_command("replace", key, flags, ttl, data), Command::Stats => "stats\r\n".to_owned().into_bytes(), Command::StatsItems => "stats items\r\n".to_owned().into_bytes(), Command::StatsSlabs => "stats slabs\r\n".to_owned().into_bytes(), - Command::Set(key, flags, ttl, data) => { - let mut set = Vec::with_capacity(key.len() + data.len() + 32); - io::Write::write_all( - &mut set, - format!("set {} {} {} {}\r\n", key, flags, ttl, data.len()).as_bytes(), - ) - .unwrap(); - io::Write::write_all(&mut set, data).unwrap(); - io::Write::write_all(&mut set, "\r\n".as_bytes()).unwrap(); - set - } + Command::Set(key, flags, ttl, data) => storage_command("set", key, flags, ttl, data), Command::Touch(key, ttl) => format!("touch {} {}\r\n", key, ttl).into_bytes(), Command::Version => "version\r\n".to_owned().into_bytes(), - }; - - buf + } } } +fn storage_command(verb: &str, key: &str, flags: u64, ttl: u32, data: &[u8]) -> Vec { + let mut bytes = Vec::with_capacity(key.len() + data.len() + 32); + io::Write::write_all( + &mut bytes, + format!("{} {} {} {} {}\r\n", verb, key, flags, ttl, data.len()).as_bytes(), + ) + .unwrap(); + io::Write::write_all(&mut bytes, data).unwrap(); + io::Write::write_all(&mut bytes, "\r\n".as_bytes()).unwrap(); + bytes +} + pub struct Memcached { read: Lines>>, write: BufWriter>, @@ -829,6 +837,10 @@ impl Memcached { return Err(MtopError::internal("missing required keys")); } + if !validate_keys(keys) { + return Err(MtopError::internal("invalid keys")); + } + self.send(Command::Gets(keys)).await?; let mut out = HashMap::with_capacity(keys.len()); @@ -894,8 +906,53 @@ impl Memcached { } /// Store the provided item in the cache, regardless of whether it already exists. - pub async fn set(&mut self, key: String, flags: u64, ttl: u32, data: Vec) -> Result<(), MtopError> { - self.send(Command::Set(&key, flags, ttl, &data)).await?; + pub async fn set(&mut self, key: K, flags: u64, ttl: u32, data: V) -> Result<(), MtopError> + where + K: AsRef, + V: AsRef<[u8]>, + { + if !validate_key(key.as_ref()) { + return Err(MtopError::internal("invalid key")); + } + + self.send(Command::Set(key.as_ref(), flags, ttl, data.as_ref())).await?; + if let Some(v) = self.read.next_line().await? { + Self::parse_simple_response(&v, "STORED") + } else { + Err(MtopError::internal("unexpected empty response")) + } + } + + /// Store the provided item in the cache only if it does not already exist. + pub async fn add(&mut self, key: K, flags: u64, ttl: u32, data: V) -> Result<(), MtopError> + where + K: AsRef, + V: AsRef<[u8]>, + { + if !validate_key(key.as_ref()) { + return Err(MtopError::internal("invalid key")); + } + + self.send(Command::Add(key.as_ref(), flags, ttl, data.as_ref())).await?; + if let Some(v) = self.read.next_line().await? { + Self::parse_simple_response(&v, "STORED") + } else { + Err(MtopError::internal("unexpected empty response")) + } + } + + /// Store the provided item in the cache only if it already exists. + pub async fn replace(&mut self, key: K, flags: u64, ttl: u32, data: V) -> Result<(), MtopError> + where + K: AsRef, + V: AsRef<[u8]>, + { + if !validate_key(key.as_ref()) { + return Err(MtopError::internal("invalid key")); + } + + self.send(Command::Replace(key.as_ref(), flags, ttl, data.as_ref())) + .await?; if let Some(v) = self.read.next_line().await? { Self::parse_simple_response(&v, "STORED") } else { @@ -904,8 +961,15 @@ impl Memcached { } /// Update the TTL of an item in the cache if it exists, return an error otherwise. - pub async fn touch(&mut self, key: String, ttl: u32) -> Result<(), MtopError> { - self.send(Command::Touch(&key, ttl)).await?; + pub async fn touch(&mut self, key: K, ttl: u32) -> Result<(), MtopError> + where + K: AsRef, + { + if !validate_key(key.as_ref()) { + return Err(MtopError::internal("invalid key")); + } + + self.send(Command::Touch(key.as_ref(), ttl)).await?; if let Some(v) = self.read.next_line().await? { Self::parse_simple_response(&v, "TOUCHED") } else { @@ -914,8 +978,15 @@ impl Memcached { } /// Delete an item in the cache if it exists, return an error otherwise. - pub async fn delete(&mut self, key: String) -> Result<(), MtopError> { - self.send(Command::Delete(&key)).await?; + pub async fn delete(&mut self, key: K) -> Result<(), MtopError> + where + K: AsRef, + { + if !validate_key(key.as_ref()) { + return Err(MtopError::internal("invalid key")); + } + + self.send(Command::Delete(key.as_ref())).await?; if let Some(v) = self.read.next_line().await? { Self::parse_simple_response(&v, "DELETED") } else { @@ -942,6 +1013,7 @@ impl Memcached { (Some("ERROR"), None) => (ProtocolErrorKind::Syntax, None), (Some("ERROR"), Some(msg)) => (ProtocolErrorKind::Syntax, Some(msg.to_owned())), (Some("NOT_FOUND"), None) => (ProtocolErrorKind::NotFound, None), + (Some("NOT_STORED"), None) => (ProtocolErrorKind::NotStored, None), (Some("SERVER_ERROR"), Some(msg)) => (ProtocolErrorKind::Server, Some(msg.to_owned())), _ => return None, @@ -963,10 +1035,89 @@ impl fmt::Debug for Memcached { } } +/// Return true if the key is legal to use with Memcached, false otherwise +fn validate_key(key: &str) -> bool { + if key.len() > MAX_KEY_LENGTH { + return false; + } + + for c in key.chars() { + if !c.is_ascii() || c.is_ascii_whitespace() || c.is_ascii_control() { + return false; + } + } + + true +} + +/// Return true if all keys are legal to use with Memcached, false otherwise +fn validate_keys(keys: &[String]) -> bool { + for key in keys { + if !validate_key(key) { + return false; + } + } + + true +} + #[cfg(test)] mod test { - use super::{ErrorKind, Memcached, Meta, Slab, SlabItem, SlabItems}; - use std::io::Cursor; + use super::{validate_key, ErrorKind, Memcached, Meta, Slab, SlabItem, SlabItems, MAX_KEY_LENGTH}; + use std::io::{Cursor, Error}; + use std::pin::Pin; + use std::task::{Context, Poll}; + use tokio::io::AsyncWrite; + use tokio::sync::mpsc::{self, UnboundedSender}; + + #[test] + fn test_validate_key_length() { + let key = "abc".repeat(MAX_KEY_LENGTH); + assert!(!validate_key(&key)); + } + + #[test] + fn test_validate_key_non_ascii() { + let key = "🤦"; + assert!(!validate_key(key)); + } + + #[test] + fn test_validate_key_whitespace() { + let key = "some thing"; + assert!(!validate_key(key)) + } + + #[test] + fn test_validate_key_control_char() { + let key = "\x7F"; + assert!(!validate_key(key)); + } + + #[test] + fn test_validate_key_success() { + let key = "a-long-but-reasonable-key"; + assert!(validate_key(key)); + } + + struct WriteAdapter { + tx: UnboundedSender>, + } + + impl AsyncWrite for WriteAdapter { + fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + self.tx.send(buf.to_owned()).unwrap(); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + } /// Create a new `Memcached` instance to read the provided server response. macro_rules! client { @@ -981,8 +1132,19 @@ mod test { }) } + /// Create a new receiver channel and `Memcached` instance to read the provided server + /// response. Anything written by the client is able to be read from the receiver channel. + macro_rules! client_rw { + ($($line:expr),+ $(,)?) => ({ + let (tx, rx) = mpsc::unbounded_channel(); + let mut reads = Vec::new(); + $(reads.extend_from_slice($line.as_bytes());)+ + (rx, Memcached::new(Cursor::new(reads), WriteAdapter { tx })) + }) + } + #[tokio::test] - async fn test_get_no_keys() { + async fn test_get_no_key() { let mut client = client!(); let res = client.get(&[]).await; @@ -991,6 +1153,16 @@ mod test { assert_eq!(ErrorKind::Internal, err.kind()); } + #[tokio::test] + async fn test_get_bad_key() { + let mut client = client!(); + + let res = client.get(&["bad key".repeat(MAX_KEY_LENGTH)]).await; + assert!(res.is_err()); + let err = res.unwrap_err(); + assert_eq!(ErrorKind::Internal, err.kind()); + } + #[tokio::test] async fn test_get_error() { let mut client = client!("SERVER_ERROR backend failure\r\n"); @@ -1036,6 +1208,159 @@ mod test { assert_eq!(2, val2.cas); } + macro_rules! test_store_command_success { + ($method:ident, $verb:expr) => { + let (mut rx, mut client) = client_rw!("STORED\r\n"); + let res = client.$method("test", 0, 300, "val".as_bytes()).await; + + assert!(res.is_ok()); + let bytes = rx.recv().await.unwrap(); + let command = String::from_utf8(bytes).unwrap(); + assert_eq!(concat!($verb, " test 0 300 3\r\nval\r\n"), command); + }; + } + + macro_rules! test_store_command_bad_key { + ($method:ident) => { + let mut client = client!(); + let res = client.$method("bad key", 0, 300, "val".as_bytes()).await; + + assert!(res.is_err()); + let err = res.unwrap_err(); + assert_eq!(ErrorKind::Internal, err.kind()); + }; + } + + macro_rules! test_store_command_error { + ($method:ident, $verb:expr) => { + let (mut rx, mut client) = client_rw!("NOT_STORED\r\n"); + let res = client.$method("test", 0, 300, "val".as_bytes()).await; + + assert!(res.is_err()); + let err = res.unwrap_err(); + assert_eq!(ErrorKind::Protocol, err.kind()); + + let bytes = rx.recv().await.unwrap(); + let command = String::from_utf8(bytes).unwrap(); + assert_eq!(concat!($verb, " test 0 300 3\r\nval\r\n"), command); + }; + } + + #[tokio::test] + async fn test_set_success() { + test_store_command_success!(set, "set"); + } + + #[tokio::test] + async fn test_set_bad_key() { + test_store_command_bad_key!(set); + } + + #[tokio::test] + async fn test_set_error() { + test_store_command_error!(set, "set"); + } + + #[tokio::test] + async fn test_add_success() { + test_store_command_success!(add, "add"); + } + + #[tokio::test] + async fn test_add_bad_key() { + test_store_command_bad_key!(add); + } + + #[tokio::test] + async fn test_add_error() { + test_store_command_error!(add, "add"); + } + + #[tokio::test] + async fn test_replace_success() { + test_store_command_success!(replace, "replace"); + } + + #[tokio::test] + async fn test_replace_bad_key() { + test_store_command_bad_key!(replace); + } + + #[tokio::test] + async fn test_replace_error() { + test_store_command_error!(replace, "replace"); + } + + #[tokio::test] + async fn test_touch_success() { + let (mut rx, mut client) = client_rw!("TOUCHED\r\n"); + let res = client.touch("test", 300).await; + + assert!(res.is_ok()); + let bytes = rx.recv().await.unwrap(); + let command = String::from_utf8(bytes).unwrap(); + assert_eq!("touch test 300\r\n", command); + } + + #[tokio::test] + async fn test_touch_bad_key() { + let mut client = client!(); + let res = client.touch("bad key", 300).await; + + assert!(res.is_err()); + let err = res.unwrap_err(); + assert_eq!(ErrorKind::Internal, err.kind()); + } + + #[tokio::test] + async fn test_touch_error() { + let (mut rx, mut client) = client_rw!("NOT_FOUND\r\n"); + let res = client.touch("test", 300).await; + + assert!(res.is_err()); + let err = res.unwrap_err(); + assert_eq!(ErrorKind::Protocol, err.kind()); + + let bytes = rx.recv().await.unwrap(); + let command = String::from_utf8(bytes).unwrap(); + assert_eq!("touch test 300\r\n", command); + } + + #[tokio::test] + async fn test_delete_success() { + let (mut rx, mut client) = client_rw!("DELETED\r\n"); + let res = client.delete("test").await; + + assert!(res.is_ok()); + let bytes = rx.recv().await.unwrap(); + let command = String::from_utf8(bytes).unwrap(); + assert_eq!("delete test\r\n", command); + } + + #[tokio::test] + async fn test_delete_bad_key() { + let mut client = client!(); + let res = client.delete("bad key").await; + + assert!(res.is_err()); + let err = res.unwrap_err(); + assert_eq!(ErrorKind::Internal, err.kind()); + } + + #[tokio::test] + async fn test_delete_error() { + let (mut rx, mut client) = client_rw!("NOT_FOUND\r\n"); + let res = client.delete("test").await; + + assert!(res.is_err()); + let err = res.unwrap_err(); + assert_eq!(ErrorKind::Protocol, err.kind()); + + let bytes = rx.recv().await.unwrap(); + let command = String::from_utf8(bytes).unwrap(); + assert_eq!("delete test\r\n", command); + } + #[tokio::test] async fn test_stats_empty() { let mut client = client!("END\r\n"); diff --git a/mtop/src/bin/mc.rs b/mtop/src/bin/mc.rs index cdc4ca9..e4bcc39 100644 --- a/mtop/src/bin/mc.rs +++ b/mtop/src/bin/mc.rs @@ -54,14 +54,33 @@ struct McConfig { #[derive(Debug, Subcommand)] enum Action { + Add(AddCommand), Delete(DeleteCommand), Get(GetCommand), Keys(KeysCommand), + Replace(ReplaceCommand), Set(SetCommand), Touch(TouchCommand), Check(CheckCommand), } +/// Add a value to the cache only if it does not already exist. +/// +/// The value will be read from standard input. You can use shell pipes or redirects to set +/// the contents of files as values. +#[derive(Debug, Args)] +struct AddCommand { + /// Key of the item to add the value for. + #[arg(required = true)] + key: String, + + /// TTL to set for the item, in seconds. If the TTL is longer than the number of seconds + /// in 30 days, it will be treated as a UNIX timestamp, setting the item to expire at a + /// particular date/time. + #[arg(required = true)] + ttl: u32, +} + /// Run health checks against the cache. #[derive(Debug, Args)] struct CheckCommand { @@ -106,6 +125,23 @@ struct KeysCommand { details: bool, } +/// Replace a value in the cache only if it already exists. +/// +/// The value will be read from standard input. You can use shell pipes or redirects to set +/// the contents of files as values. +#[derive(Debug, Args)] +struct ReplaceCommand { + /// Key of the item to replace the value for. + #[arg(required = true)] + key: String, + + /// TTL to set for the item, in seconds. If the TTL is longer than the number of seconds + /// in 30 days, it will be treated as a UNIX timestamp, setting the item to expire at a + /// particular date/time. + #[arg(required = true)] + ttl: u32, +} + /// Set a value in the cache. /// /// The value will be read from standard input. You can use shell pipes or redirects to set @@ -170,6 +206,17 @@ async fn main() -> Result<(), Box> { }); match opts.mode { + Action::Add(c) => { + let buf = read_input().await.unwrap_or_else(|e| { + tracing::error!(message = "unable to read item data from stdin", error = %e); + process::exit(1); + }); + + if let Err(e) = client.add(&c.key, 0, c.ttl, &buf).await { + tracing::error!(message = "unable to add item", key = c.key, host = opts.host, error = %e); + process::exit(1); + } + } Action::Check(c) => { let checker = Checker::new( &pool, @@ -182,7 +229,7 @@ async fn main() -> Result<(), Box> { } } Action::Delete(c) => { - if let Err(e) = client.delete(c.key.clone()).await { + if let Err(e) = client.delete(&c.key).await { tracing::error!(message = "unable to delete item", key = c.key, host = opts.host, error = %e); process::exit(1); } @@ -210,19 +257,30 @@ async fn main() -> Result<(), Box> { tracing::warn!(message = "error writing output", error = %e); } } + Action::Replace(c) => { + let buf = read_input().await.unwrap_or_else(|e| { + tracing::error!(message = "unable to read item data from stdin", error = %e); + process::exit(1); + }); + + if let Err(e) = client.replace(&c.key, 0, c.ttl, &buf).await { + tracing::error!(message = "unable to replace item", key = c.key, host = opts.host, error = %e); + process::exit(1); + } + } Action::Set(c) => { let buf = read_input().await.unwrap_or_else(|e| { tracing::error!(message = "unable to read item data from stdin", error = %e); process::exit(1); }); - if let Err(e) = client.set(c.key.clone(), 0, c.ttl, buf).await { + if let Err(e) = client.set(&c.key, 0, c.ttl, &buf).await { tracing::error!(message = "unable to set item", key = c.key, host = opts.host, error = %e); process::exit(1); } } Action::Touch(c) => { - if let Err(e) = client.touch(c.key.clone(), c.ttl).await { + if let Err(e) = client.touch(&c.key, c.ttl).await { tracing::error!(message = "unable to touch item", key = c.key, host = opts.host, error = %e); process::exit(1); }