diff --git a/Cargo.lock b/Cargo.lock index d628d5a2..1ebca561 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,31 +2,6 @@ # It is not intended for manual editing. version = 3 -[[package]] -name = "actix" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cba56612922b907719d4a01cf11c8d5b458e7d3dba946d0435f20f58d6795ed2" -dependencies = [ - "actix-macros", - "actix-rt", - "actix_derive", - "bitflags 2.4.1", - "bytes", - "crossbeam-channel", - "futures-core", - "futures-sink", - "futures-task", - "futures-util", - "log", - "once_cell", - "parking_lot", - "pin-project-lite", - "smallvec", - "tokio", - "tokio-util", -] - [[package]] name = "actix-codec" version = "0.5.1" @@ -309,17 +284,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "actix_derive" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c7db3d5a9718568e4cf4a537cfd7070e6e6ff7481510d0237fb529ac850f6d3" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.48", -] - [[package]] name = "addr2line" version = "0.21.0" @@ -1021,15 +985,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "crossbeam-channel" -version = "0.5.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "176dc175b78f56c0f321911d9c8eb2b77a78a4860b9c19db83835fea1a46649b" -dependencies = [ - "crossbeam-utils", -] - [[package]] name = "crossbeam-deque" version = "0.8.5" @@ -1803,6 +1758,26 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "governor" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68a7f542ee6b35af73b06abc0dad1c1bae89964e4e253bc4b587b91c9637867b" +dependencies = [ + "cfg-if", + "dashmap", + "futures", + "futures-timer", + "no-std-compat", + "nonzero_ext", + "parking_lot", + "portable-atomic", + "quanta", + "rand", + "smallvec", + "spinning_top", +] + [[package]] name = "h2" version = "0.3.23" @@ -2304,7 +2279,6 @@ dependencies = [ name = "labrinth" version = "2.7.0" dependencies = [ - "actix", "actix-cors", "actix-files", "actix-http", @@ -2330,6 +2304,8 @@ dependencies = [ "flate2", "futures", "futures-timer", + "futures-util", + "governor", "hex", "hmac 0.11.0", "hyper", @@ -2731,6 +2707,12 @@ dependencies = [ "libc", ] +[[package]] +name = "no-std-compat" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c" + [[package]] name = "nom" version = "7.1.3" @@ -2741,6 +2723,12 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nonzero_ext" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" + [[package]] name = "num-bigint-dig" version = "0.8.4" @@ -3132,6 +3120,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "portable-atomic" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" + [[package]] name = "powerfmt" version = "0.2.0" @@ -3252,6 +3246,21 @@ dependencies = [ "bytemuck", ] +[[package]] +name = "quanta" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ca0b7bac0b97248c40bb77288fc52029cf1459c0461ea1b05ee32ccf011de2c" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi", + "web-sys", + "winapi", +] + [[package]] name = "quick-error" version = "2.0.1" @@ -3339,6 +3348,15 @@ dependencies = [ "getrandom", ] +[[package]] +name = "raw-cpuid" +version = "11.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d86a7c4638d42c44551f4791a20e687dbb4c3de1f33c43dd71e355cd429def1" +dependencies = [ + "bitflags 2.4.1", +] + [[package]] name = "rayon" version = "1.8.0" @@ -4204,6 +4222,15 @@ dependencies = [ "lock_api", ] +[[package]] +name = "spinning_top" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300" +dependencies = [ + "lock_api", +] + [[package]] name = "spki" version = "0.7.3" diff --git a/Cargo.toml b/Cargo.toml index d2deedac..092a15e1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,6 @@ name = "labrinth" path = "src/main.rs" [dependencies] -actix = "0.13.1" actix-web = "4.4.1" actix-rt = "2.9.0" actix-multipart = "0.6.1" @@ -19,12 +18,14 @@ actix-cors = "0.7.0" actix-ws = "0.2.5" actix-files = "0.6.5" actix-web-prom = "0.7.0" +governor = "0.6.3" tokio = { version = "1.35.1", features = ["sync"] } tokio-stream = "0.1.14" futures = "0.3.30" futures-timer = "3.0.2" +futures-util = "0.3.30" async-trait = "0.1.70" dashmap = "5.4.0" lazy_static = "1.4.0" diff --git a/src/lib.rs b/src/lib.rs index 05827e7e..a1d7a5b5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,6 @@ +use std::num::NonZeroU32; use std::sync::Arc; +use std::time::Duration; use actix_web::web; use database::redis::RedisPool; @@ -6,12 +8,13 @@ use log::{info, warn}; use queue::{ analytics::AnalyticsQueue, payouts::PayoutsQueue, session::AuthQueue, socket::ActiveSockets, }; -use scheduler::Scheduler; use sqlx::Postgres; use tokio::sync::RwLock; extern crate clickhouse as clickhouse_crate; use clickhouse_crate::Client; +use governor::{Quota, RateLimiter}; +use governor::middleware::StateInformationMiddleware; use util::cors::default_cors; use crate::queue::moderation::AutomatedModerationQueue; @@ -20,6 +23,7 @@ use crate::{ search::indexing::index_projects, util::env::{parse_strings_from_var, parse_var}, }; +use crate::util::ratelimit::KeyedRateLimiter; pub mod auth; pub mod clickhouse; @@ -27,7 +31,6 @@ pub mod database; pub mod file_hosting; pub mod models; pub mod queue; -pub mod ratelimit; pub mod routes; pub mod scheduler; pub mod search; @@ -46,7 +49,7 @@ pub struct LabrinthConfig { pub clickhouse: Client, pub file_host: Arc, pub maxmind: Arc, - pub scheduler: Arc, + pub scheduler: Arc, pub ip_salt: Pepper, pub search_config: search::SearchConfig, pub session_queue: web::Data, @@ -54,6 +57,7 @@ pub struct LabrinthConfig { pub analytics_queue: Arc, pub active_sockets: web::Data>, pub automated_moderation_queue: web::Data, + pub rate_limiter: KeyedRateLimiter, } pub fn app_setup( @@ -82,6 +86,25 @@ pub fn app_setup( let mut scheduler = scheduler::Scheduler::new(); + let limiter: KeyedRateLimiter = Arc::new( + RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(300).unwrap())) + .with_middleware::(), + ); + let limiter_clone = Arc::clone(&limiter); + scheduler.run(Duration::from_secs(60), move || { + info!( + "Clearing ratelimiter, storage size: {}", + limiter_clone.len() + ); + limiter_clone.retain_recent(); + info!( + "Done clearing ratelimiter, storage size: {}", + limiter_clone.len() + ); + + async move {} + }); + // The interval in seconds at which the local database is indexed // for searching. Defaults to 1 hour if unset. let local_index_interval = @@ -255,6 +278,7 @@ pub fn app_setup( analytics_queue, active_sockets, automated_moderation_queue, + rate_limiter: limiter, } } diff --git a/src/main.rs b/src/main.rs index bbdf6b7f..5cd49379 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,16 +1,17 @@ use actix_web::{App, HttpServer}; use actix_web_prom::PrometheusMetricsBuilder; use env_logger::Env; +use governor::middleware::StateInformationMiddleware; +use governor::{Quota, RateLimiter}; use labrinth::database::redis::RedisPool; use labrinth::file_hosting::S3Host; -use labrinth::ratelimit::errors::ARError; -use labrinth::ratelimit::memory::{MemoryStore, MemoryStoreActor}; -use labrinth::ratelimit::middleware::RateLimiter; use labrinth::search; -use labrinth::util::env::parse_var; +use labrinth::util::ratelimit::{KeyedRateLimiter, RateLimit}; use labrinth::{check_env_vars, clickhouse, database, file_hosting, queue}; use log::{error, info}; +use std::num::NonZeroU32; use std::sync::Arc; +use std::time::Duration; #[cfg(feature = "jemalloc")] #[global_allocator] @@ -90,17 +91,14 @@ async fn main() -> std::io::Result<()> { let maxmind_reader = Arc::new(queue::maxmind::MaxMindIndexer::new().await.unwrap()); - let store = MemoryStore::new(); - let prometheus = PrometheusMetricsBuilder::new("labrinth") .endpoint("/metrics") .build() .expect("Failed to create prometheus metrics middleware"); let search_config = search::SearchConfig::new(None); - info!("Starting Actix HTTP server!"); - let labrinth_config = labrinth::app_setup( + let mut labrinth_config = labrinth::app_setup( pool.clone(), redis_pool.clone(), search_config.clone(), @@ -109,32 +107,14 @@ async fn main() -> std::io::Result<()> { maxmind_reader.clone(), ); + info!("Starting Actix HTTP server!"); + // Init App HttpServer::new(move || { App::new() .wrap(prometheus.clone()) + .wrap(RateLimit(Arc::clone(&labrinth_config.rate_limiter))) .wrap(actix_web::middleware::Compress::default()) - .wrap( - RateLimiter::new(MemoryStoreActor::from(store.clone()).start()) - .with_identifier(|req| { - let connection_info = req.connection_info(); - let ip = - String::from(if parse_var("CLOUDFLARE_INTEGRATION").unwrap_or(false) { - if let Some(header) = req.headers().get("CF-Connecting-IP") { - header.to_str().map_err(|_| ARError::Identification)? - } else { - connection_info.peer_addr().ok_or(ARError::Identification)? - } - } else { - connection_info.peer_addr().ok_or(ARError::Identification)? - }); - - Ok(ip) - }) - .with_interval(std::time::Duration::from_secs(60)) - .with_max_requests(300) - .with_ignore_key(dotenvy::var("RATE_LIMIT_IGNORE_KEY").ok()), - ) .wrap(sentry_actix::Sentry::new()) .configure(|cfg| labrinth::app_config(cfg, labrinth_config.clone())) }) diff --git a/src/ratelimit/errors.rs b/src/ratelimit/errors.rs deleted file mode 100644 index f06ba48c..00000000 --- a/src/ratelimit/errors.rs +++ /dev/null @@ -1,52 +0,0 @@ -//! Errors that can occur during middleware processing stage -use crate::models::error::ApiError; -use actix_web::ResponseError; -use log::*; -use thiserror::Error; - -/// Custom error type. Useful for logging and debugging different kinds of errors. -/// This type can be converted to Actix Error, which defaults to -/// InternalServerError -/// -#[derive(Debug, Error)] -pub enum ARError { - /// Read/Write error on store - #[error("read/write operation failed: {0}")] - ReadWrite(String), - - /// Identifier error - #[error("client identification failed")] - Identification, - /// Limited Error - #[error("You are being rate-limited. Please wait {reset} seconds. {remaining}/{max_requests} remaining.")] - Limited { - max_requests: usize, - remaining: usize, - reset: u64, - }, -} - -impl ResponseError for ARError { - fn error_response(&self) -> actix_web::HttpResponse { - match self { - Self::Limited { - max_requests, - remaining, - reset, - } => { - let mut response = actix_web::HttpResponse::TooManyRequests(); - response.insert_header(("x-ratelimit-limit", max_requests.to_string())); - response.insert_header(("x-ratelimit-remaining", remaining.to_string())); - response.insert_header(("x-ratelimit-reset", reset.to_string())); - response.json(ApiError { - error: "ratelimit_error", - description: self.to_string(), - }) - } - _ => actix_web::HttpResponse::build(self.status_code()).json(ApiError { - error: "ratelimit_error", - description: self.to_string(), - }), - } - } -} diff --git a/src/ratelimit/memory.rs b/src/ratelimit/memory.rs deleted file mode 100644 index 14280167..00000000 --- a/src/ratelimit/memory.rs +++ /dev/null @@ -1,143 +0,0 @@ -//! In memory store for rate limiting -use actix::prelude::*; -use dashmap::DashMap; -use futures::future::{self}; -use log::*; -use std::sync::Arc; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; - -use crate::ratelimit::errors::ARError; -use crate::ratelimit::{ActorMessage, ActorResponse}; - -/// Type used to create a concurrent hashmap store -#[derive(Clone)] -pub struct MemoryStore { - inner: Arc>, -} - -impl Default for MemoryStore { - fn default() -> Self { - Self::new() - } -} - -impl MemoryStore { - /// Create a new hashmap - /// - /// # Example - /// ```rust - /// use labrinth::ratelimit::memory::MemoryStore; - /// - /// let store = MemoryStore::new(); - /// ``` - pub fn new() -> Self { - debug!("Creating new MemoryStore"); - MemoryStore { - inner: Arc::new(DashMap::::new()), - } - } -} - -/// Actor for memory store -pub struct MemoryStoreActor { - inner: Arc>, -} - -impl From for MemoryStoreActor { - fn from(store: MemoryStore) -> Self { - MemoryStoreActor { inner: store.inner } - } -} - -impl MemoryStoreActor { - /// Starts the memory actor and returns it's address - pub fn start(self) -> Addr { - debug!("Started memory store"); - Supervisor::start(|_| self) - } -} - -impl Actor for MemoryStoreActor { - type Context = Context; -} - -impl Supervised for MemoryStoreActor { - fn restarting(&mut self, _: &mut Self::Context) { - debug!("Restarting memory store"); - } -} - -impl Handler for MemoryStoreActor { - type Result = ActorResponse; - fn handle(&mut self, msg: ActorMessage, ctx: &mut Self::Context) -> Self::Result { - match msg { - ActorMessage::Set { key, value, expiry } => { - debug!("Inserting key {} with expiry {}", &key, &expiry.as_secs()); - let future_key = String::from(&key); - let now = SystemTime::now(); - let now = now.duration_since(UNIX_EPOCH).unwrap(); - self.inner.insert(key, (value, now + expiry)); - ctx.notify_later(ActorMessage::Remove(future_key), expiry); - ActorResponse::Set(Box::pin(future::ready(Ok(())))) - } - ActorMessage::Update { key, value } => match self.inner.get_mut(&key) { - Some(mut c) => { - let val_mut: &mut (usize, Duration) = c.value_mut(); - if val_mut.0 > value { - val_mut.0 -= value; - } else { - val_mut.0 = 0; - } - let new_val = val_mut.0; - ActorResponse::Update(Box::pin(future::ready(Ok(new_val)))) - } - None => ActorResponse::Update(Box::pin(future::ready(Err(ARError::ReadWrite( - "memory store: read failed!".to_string(), - ))))), - }, - ActorMessage::Get(key) => { - if self.inner.contains_key(&key) { - let val = match self.inner.get(&key) { - Some(c) => c, - None => { - return ActorResponse::Get(Box::pin(future::ready(Err( - ARError::ReadWrite("memory store: read failed!".to_string()), - )))) - } - }; - let val = val.value().0; - ActorResponse::Get(Box::pin(future::ready(Ok(Some(val))))) - } else { - ActorResponse::Get(Box::pin(future::ready(Ok(None)))) - } - } - ActorMessage::Expire(key) => { - let c = match self.inner.get(&key) { - Some(d) => d, - None => { - return ActorResponse::Expire(Box::pin(future::ready(Err( - ARError::ReadWrite("memory store: read failed!".to_string()), - )))) - } - }; - let dur = c.value().1; - let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); - let res = dur.checked_sub(now).unwrap_or_else(|| Duration::new(0, 0)); - ActorResponse::Expire(Box::pin(future::ready(Ok(res)))) - } - ActorMessage::Remove(key) => { - debug!("Removing key: {}", &key); - let val = match self.inner.remove::(&key) { - Some(c) => c, - None => { - return ActorResponse::Remove(Box::pin(future::ready(Err( - ARError::ReadWrite("memory store: remove failed!".to_string()), - )))) - } - }; - let val = val.1; - ActorResponse::Remove(Box::pin(future::ready(Ok(val.0)))) - } - } - } -} diff --git a/src/ratelimit/middleware.rs b/src/ratelimit/middleware.rs deleted file mode 100644 index 495dcad5..00000000 --- a/src/ratelimit/middleware.rs +++ /dev/null @@ -1,260 +0,0 @@ -use crate::ratelimit::errors::ARError; -use crate::ratelimit::{ActorMessage, ActorResponse}; -use actix::dev::*; -use actix_web::{ - dev::{Service, ServiceRequest, ServiceResponse, Transform}, - error::Error as AWError, - http::header::{HeaderName, HeaderValue}, -}; -use futures::future::{ok, Ready}; -use log::*; -use std::{ - cell::RefCell, - future::Future, - ops::Fn, - pin::Pin, - rc::Rc, - task::{Context, Poll}, - time::Duration, -}; - -type RateLimiterIdentifier = Rc Result + 'static>>; - -pub struct RateLimiter -where - T: Handler + Send + Sync + 'static, - T::Context: ToEnvelope, -{ - interval: Duration, - max_requests: usize, - store: Addr, - identifier: RateLimiterIdentifier, - ignore_key: Option, -} - -impl RateLimiter -where - T: Handler + Send + Sync + 'static, - ::Context: ToEnvelope, -{ - /// Creates a new instance of `RateLimiter` with the provided address of `StoreActor`. - pub fn new(store: Addr) -> Self { - let identifier = |req: &ServiceRequest| { - let connection_info = req.connection_info(); - let ip = connection_info.peer_addr().ok_or(ARError::Identification)?; - Ok(String::from(ip)) - }; - RateLimiter { - interval: Duration::from_secs(0), - max_requests: 0, - store, - identifier: Rc::new(Box::new(identifier)), - ignore_key: None, - } - } - - /// Specify the interval. The counter for a client is reset after this interval - pub fn with_interval(mut self, interval: Duration) -> Self { - self.interval = interval; - self - } - - /// Specify the maximum number of requests allowed in the given interval. - pub fn with_max_requests(mut self, max_requests: usize) -> Self { - self.max_requests = max_requests; - self - } - - /// Sets key which can be used to bypass rate-limiter - pub fn with_ignore_key(mut self, ignore_key: Option) -> Self { - self.ignore_key = ignore_key; - self - } - - /// Function to get the identifier for the client request - pub fn with_identifier Result + 'static>( - mut self, - identifier: F, - ) -> Self { - self.identifier = Rc::new(Box::new(identifier)); - self - } -} - -impl Transform for RateLimiter -where - T: Handler + Send + Sync + 'static, - T::Context: ToEnvelope, - S: Service, Error = AWError> + 'static, - S::Future: 'static, - B: 'static, -{ - type Response = ServiceResponse; - type Error = S::Error; - type Transform = RateLimitMiddleware; - type InitError = (); - type Future = Ready>; - - fn new_transform(&self, service: S) -> Self::Future { - ok(RateLimitMiddleware { - service: Rc::new(RefCell::new(service)), - store: self.store.clone(), - max_requests: self.max_requests, - interval: self.interval.as_secs(), - identifier: self.identifier.clone(), - ignore_key: self.ignore_key.clone(), - }) - } -} - -/// Service factory for RateLimiter -pub struct RateLimitMiddleware -where - S: 'static, - T: Handler + 'static, -{ - service: Rc>, - store: Addr, - // Exists here for the sole purpose of knowing the max_requests and interval from RateLimiter - max_requests: usize, - interval: u64, - identifier: RateLimiterIdentifier, - ignore_key: Option, -} - -impl Service for RateLimitMiddleware -where - T: Handler + 'static, - S: Service, Error = AWError> + 'static, - S::Future: 'static, - B: 'static, - T::Context: ToEnvelope, -{ - type Response = ServiceResponse; - type Error = S::Error; - type Future = Pin>>>; - - fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - self.service.borrow_mut().poll_ready(cx) - } - - fn call(&self, req: ServiceRequest) -> Self::Future { - let store = self.store.clone(); - let srv = self.service.clone(); - let max_requests = self.max_requests; - let interval = Duration::from_secs(self.interval); - let identifier = self.identifier.clone(); - let ignore_key = self.ignore_key.clone(); - Box::pin(async move { - let identifier: String = (identifier)(&req)?; - - if let Some(ignore_key) = ignore_key { - if let Some(key) = req.headers().get("x-ratelimit-key") { - if key.to_str().ok().unwrap_or_default() == &*ignore_key { - let fut = srv.call(req); - let res = fut.await?; - return Ok(res); - } - } - } - - let remaining: ActorResponse = store - .send(ActorMessage::Get(String::from(&identifier))) - .await - .map_err(|_| ARError::Identification)?; - match remaining { - ActorResponse::Get(opt) => { - let opt = opt.await?; - if let Some(c) = opt { - // Existing entry in store - let expiry = store - .send(ActorMessage::Expire(String::from(&identifier))) - .await - .map_err(|_| ARError::ReadWrite("Setting timeout".to_string()))?; - let reset: Duration = match expiry { - ActorResponse::Expire(dur) => dur.await?, - _ => unreachable!(), - }; - if c == 0 { - info!("Limit exceeded for client: {}", &identifier); - Err(ARError::Limited { - max_requests, - remaining: c, - reset: reset.as_secs(), - } - .into()) - } else { - // Decrement value - let res: ActorResponse = store - .send(ActorMessage::Update { - key: identifier, - value: 1, - }) - .await - .map_err(|_| { - ARError::ReadWrite("Decrementing ratelimit".to_string()) - })?; - let updated_value: usize = match res { - ActorResponse::Update(c) => c.await?, - _ => unreachable!(), - }; - // Execute the request - let fut = srv.call(req); - let mut res = fut.await?; - let headers = res.headers_mut(); - // Safe unwraps, since usize is always convertible to string - headers.insert( - HeaderName::from_static("x-ratelimit-limit"), - HeaderValue::from_str(max_requests.to_string().as_str())?, - ); - headers.insert( - HeaderName::from_static("x-ratelimit-remaining"), - HeaderValue::from_str(updated_value.to_string().as_str())?, - ); - headers.insert( - HeaderName::from_static("x-ratelimit-reset"), - HeaderValue::from_str(reset.as_secs().to_string().as_str())?, - ); - Ok(res) - } - } else { - // New client, create entry in store - let current_value = max_requests - 1; - let res = store - .send(ActorMessage::Set { - key: String::from(&identifier), - value: current_value, - expiry: interval, - }) - .await - .map_err(|_| ARError::ReadWrite("Creating store entry".to_string()))?; - match res { - ActorResponse::Set(c) => c.await?, - _ => unreachable!(), - } - let fut = srv.call(req); - let mut res = fut.await?; - let headers = res.headers_mut(); - // Safe unwraps, since usize is always convertible to string - headers.insert( - HeaderName::from_static("x-ratelimit-limit"), - HeaderValue::from_str(max_requests.to_string().as_str()).unwrap(), - ); - headers.insert( - HeaderName::from_static("x-ratelimit-remaining"), - HeaderValue::from_str(current_value.to_string().as_str()).unwrap(), - ); - headers.insert( - HeaderName::from_static("x-ratelimit-reset"), - HeaderValue::from_str(interval.as_secs().to_string().as_str()).unwrap(), - ); - Ok(res) - } - } - _ => { - unreachable!(); - } - } - }) - } -} diff --git a/src/ratelimit/mod.rs b/src/ratelimit/mod.rs deleted file mode 100644 index 2d659c87..00000000 --- a/src/ratelimit/mod.rs +++ /dev/null @@ -1,64 +0,0 @@ -use std::future::Future; -use std::marker::Send; -use std::pin::Pin; -use std::time::Duration; - -use crate::ratelimit::errors::ARError; -use actix::dev::*; - -pub mod errors; -pub mod memory; -/// The code for this module was directly taken from https://github.com/TerminalWitchcraft/actix-ratelimit -/// with some modifications including upgrading it to Actix 4! -pub mod middleware; - -/// Represents message that can be handled by a `StoreActor` -pub enum ActorMessage { - /// Get the remaining count based on the provided identifier - Get(String), - /// Set the count of the client identified by `key` to `value` valid for `expiry` - Set { - key: String, - value: usize, - expiry: Duration, - }, - /// Change the value of count for the client identified by `key` by `value` - Update { key: String, value: usize }, - /// Get the expiration time for the client. - Expire(String), - /// Remove the client from the store - Remove(String), -} - -impl Message for ActorMessage { - type Result = ActorResponse; -} - -/// Wrapper type for `Pin>` type -pub type Output = Pin> + Send>>; - -/// Represents data returned in response to `Messages` by a `StoreActor` -pub enum ActorResponse { - /// Returned in response to [Messages::Get](enum.Messages.html) - Get(Output>), - /// Returned in response to [Messages::Set](enum.Messages.html) - Set(Output<()>), - /// Returned in response to [Messages::Update](enum.Messages.html) - Update(Output), - /// Returned in response to [Messages::Expire](enum.Messages.html) - Expire(Output), - /// Returned in response to [Messages::Remove](enum.Messages.html) - Remove(Output), -} - -impl MessageResponse for ActorResponse -where - A: Actor, - M: actix::Message, -{ - fn handle(self, _: &mut A::Context, tx: Option>) { - if let Some(tx) = tx { - let _ = tx.send(self); - } - } -} diff --git a/src/routes/mod.rs b/src/routes/mod.rs index 1b77062f..be706988 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -129,6 +129,8 @@ pub enum ApiError { Io(#[from] std::io::Error), #[error("Resource not found")] NotFound, + #[error("You are being rate-limited. Please wait {0} milliseconds. 0/{1} remaining.")] + RateLimitError(u128, u32), } impl ApiError { @@ -160,6 +162,7 @@ impl ApiError { ApiError::NotFound => "not_found", ApiError::Zip(..) => "zip_error", ApiError::Io(..) => "io_error", + ApiError::RateLimitError(..) => "ratelimit_error", }, description: self.to_string(), } @@ -194,6 +197,7 @@ impl actix_web::ResponseError for ApiError { ApiError::NotFound => StatusCode::NOT_FOUND, ApiError::Zip(..) => StatusCode::BAD_REQUEST, ApiError::Io(..) => StatusCode::BAD_REQUEST, + ApiError::RateLimitError(..) => StatusCode::TOO_MANY_REQUESTS, } } diff --git a/src/scheduler.rs b/src/scheduler.rs index 68cb593b..63487882 100644 --- a/src/scheduler.rs +++ b/src/scheduler.rs @@ -19,9 +19,9 @@ impl Scheduler { } pub fn run(&mut self, interval: std::time::Duration, mut task: F) - where - F: FnMut() -> R + Send + 'static, - R: std::future::Future + Send + 'static, + where + F: FnMut() -> R + Send + 'static, + R: std::future::Future + Send + 'static, { let future = IntervalStream::new(actix_rt::time::interval(interval)) .for_each_concurrent(2, move |_| task()); @@ -207,4 +207,4 @@ async fn update_versions( } Ok(()) -} +} \ No newline at end of file diff --git a/src/util/mod.rs b/src/util/mod.rs index 03512d3e..b7271c70 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -7,6 +7,7 @@ pub mod env; pub mod ext; pub mod guards; pub mod img; +pub mod ratelimit; pub mod redis; pub mod routes; pub mod validate; diff --git a/src/util/ratelimit.rs b/src/util/ratelimit.rs new file mode 100644 index 00000000..74c7cf5b --- /dev/null +++ b/src/util/ratelimit.rs @@ -0,0 +1,167 @@ +use governor::clock::{Clock, DefaultClock}; +use governor::{middleware, state, RateLimiter}; +use std::str::FromStr; +use std::sync::Arc; + +use crate::routes::ApiError; +use crate::util::env::parse_var; +use actix_web::{ + body::EitherBody, + dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}, + Error, ResponseError, +}; +use futures_util::future::LocalBoxFuture; +use futures_util::future::{ready, Ready}; + +pub type KeyedRateLimiter = + Arc, DefaultClock, MW>>; + +pub struct RateLimit(pub KeyedRateLimiter); + +impl Transform for RateLimit +where + S: Service, Error = Error>, + S::Future: 'static, + B: 'static, +{ + type Response = ServiceResponse>; + type Error = Error; + type Transform = RateLimitService; + type InitError = (); + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ready(Ok(RateLimitService { + service, + rate_limiter: Arc::clone(&self.0), + })) + } +} + +#[doc(hidden)] +pub struct RateLimitService { + service: S, + rate_limiter: KeyedRateLimiter, +} + +impl Service for RateLimitService +where + S: Service, Error = Error>, + S::Future: 'static, + B: 'static, +{ + type Response = ServiceResponse>; + type Error = Error; + type Future = LocalBoxFuture<'static, Result>; + + forward_ready!(service); + + fn call(&self, req: ServiceRequest) -> Self::Future { + if let Some(key) = req.headers().get("x-ratelimit-key") { + if key.to_str().ok() == dotenvy::var("RATE_LIMIT_IGNORE_KEY").ok().as_deref() { + let res = self.service.call(req); + + return Box::pin(async move { + let service_response = res.await?; + Ok(service_response.map_into_left_body()) + }); + } + } + + let conn_info = req.connection_info().clone(); + let ip = if parse_var("CLOUDFLARE_INTEGRATION").unwrap_or(false) { + if let Some(header) = req.headers().get("CF-Connecting-IP") { + header.to_str().ok() + } else { + conn_info.peer_addr() + } + } else { + conn_info.peer_addr() + }; + + if let Some(ip) = ip { + let ip = ip.to_string(); + + match self.rate_limiter.check_key(&ip) { + Ok(snapshot) => { + let fut = self.service.call(req); + + Box::pin(async move { + match fut.await { + Ok(mut service_response) => { + // Now you have a mutable reference to the ServiceResponse, so you can modify its headers. + let headers = service_response.headers_mut(); + headers.insert( + actix_web::http::header::HeaderName::from_str( + "x-ratelimit-limit", + ) + .unwrap(), + snapshot.quota().burst_size().get().into(), + ); + headers.insert( + actix_web::http::header::HeaderName::from_str( + "x-ratelimit-remaining", + ) + .unwrap(), + snapshot.remaining_burst_capacity().into(), + ); + + headers.insert( + actix_web::http::header::HeaderName::from_str( + "x-ratelimit-reset", + ) + .unwrap(), + snapshot + .quota() + .burst_size_replenished_in() + .as_secs() + .into(), + ); + + // Return the modified response as Ok. + Ok(service_response.map_into_left_body()) + } + Err(e) => { + // Handle error case + Err(e) + } + } + }) + } + Err(negative) => { + let wait_time = negative.wait_time_from(DefaultClock::default().now()); + + let mut response = ApiError::RateLimitError( + wait_time.as_millis(), + negative.quota().burst_size().get(), + ) + .error_response(); + + let headers = response.headers_mut(); + + headers.insert( + actix_web::http::header::HeaderName::from_str("x-ratelimit-limit").unwrap(), + negative.quota().burst_size().get().into(), + ); + headers.insert( + actix_web::http::header::HeaderName::from_str("x-ratelimit-remaining") + .unwrap(), + 0.into(), + ); + headers.insert( + actix_web::http::header::HeaderName::from_str("x-ratelimit-reset").unwrap(), + wait_time.as_secs().into(), + ); + + Box::pin(async { Ok(req.into_response(response.map_into_right_body())) }) + } + } + } else { + let response = + ApiError::CustomAuthentication("Unable to obtain user IP address!".to_string()) + .error_response(); + + Box::pin(async { Ok(req.into_response(response.map_into_right_body())) }) + } + } +}