diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3757ded0..93e6a72d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -4,6 +4,8 @@ on: push: branches: [master] pull_request: + branches: + - master env: CARGO_TERM_COLOR: always @@ -21,23 +23,34 @@ jobs: steps: - uses: actions/checkout@v2 + + # Start Docker Compose + - name: Start Docker Compose + run: docker-compose up -d + - uses: actions-rs/toolchain@v1 name: Install toolchain with: profile: minimal toolchain: ${{ matrix.rust }} override: true - - name: Cache build artifacts - id: cache-build + + # Cache dependencies and build artifacts + - name: Cache build artifacts and dependencies uses: actions/cache@v2 with: - path: target/** - key: ${{ runner.os }}-build-cache-${{ matrix.rust }} + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + - uses: actions-rs/cargo@v1 with: command: build env: SQLX_OFFLINE: true + - uses: actions-rs/cargo@v1 with: command: test @@ -50,4 +63,5 @@ jobs: S3_URL: ${{ secrets.S3_URL }} S3_REGION: ${{ secrets.S3_REGION }} S3_BUCKET_NAME: ${{ secrets.S3_BUCKET_NAME }} - SQLX_OFFLINE: true \ No newline at end of file + SQLX_OFFLINE: true + DATABASE_URL: postgresql://labrinth:labrinth@localhost/postgres diff --git a/Cargo.lock b/Cargo.lock index cb780278..030e2fc7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -83,9 +83,9 @@ dependencies = [ [[package]] name = "actix-http" -version = "3.3.1" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2079246596c18b4a33e274ae10c0e50613f4d32a4198e09c7b93771013fed74" +checksum = "a92ef85799cba03f76e4f7c10f533e66d87c9a7e7055f3391f09000ad8351bc9" dependencies = [ "actix-codec", "actix-rt", @@ -93,7 +93,7 @@ dependencies = [ "actix-utils", "ahash 0.8.3", "base64 0.21.2", - "bitflags 1.3.2", + "bitflags 2.4.0", "brotli", "bytes", "bytestring", @@ -597,9 +597,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.3.3" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630be753d4e58660abd17930c71b647fe46c27ea6b63cc59e1e3851406972e42" +checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" [[package]] name = "bitvec" @@ -2230,6 +2230,7 @@ dependencies = [ "actix", "actix-cors", "actix-files", + "actix-http", "actix-multipart", "actix-rt", "actix-web", @@ -3596,7 +3597,7 @@ version = "0.38.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac5ffa1efe7548069688cd7028f32591853cd7b5b756d41bcffd2353e4fc75b4" dependencies = [ - "bitflags 2.3.3", + "bitflags 2.4.0", "errno", "libc", "linux-raw-sys 0.4.3", diff --git a/Cargo.toml b/Cargo.toml index 755adc38..8e464ad5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -91,4 +91,7 @@ color-thief = "0.2.2" woothee = "0.13.0" -lettre = "0.10.4" \ No newline at end of file +lettre = "0.10.4" + +[dev-dependencies] +actix-http = "3.4.0" diff --git a/sqlx-data.json b/sqlx-data.json index 46c3d051..eabec5b7 100644 --- a/sqlx-data.json +++ b/sqlx-data.json @@ -1,5 +1,26 @@ { "db": "PostgreSQL", + "009bce5eee6ed65d9dc0899a4e24da528507a3f00b7ec997fa9ccdd7599655b1": { + "describe": { + "columns": [ + { + "name": "id", + "ordinal": 0, + "type_info": "Int8" + } + ], + "nullable": [ + false + ], + "parameters": { + "Left": [ + "Int8", + "Text" + ] + } + }, + "query": "\n SELECT m.id FROM organizations o\n INNER JOIN mods m ON 
m.organization_id = o.id\n WHERE (o.id = $1 AND $1 IS NOT NULL) OR (o.title = $2 AND $2 IS NOT NULL)\n " + }, "010cafcafb6adc25b00e3c81d844736b0245e752a90334c58209d8a02536c800": { "describe": { "columns": [], @@ -3669,6 +3690,19 @@ }, "query": "\n SELECT n.id FROM notifications n\n WHERE n.user_id = $1\n " }, + "7b6b76f383adcbe2afbd2a2e87e66fd2a0d9d05b68b27823c1395e7cc3b8c0a2": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Varchar", + "Int8" + ] + } + }, + "query": "\n UPDATE collections\n SET status = $1\n WHERE (id = $2)\n " + }, "7c0cdacf0898155c94008a96a0b918550df4475b9e3362a926d4d00e001880c1": { "describe": { "columns": [ @@ -3821,19 +3855,6 @@ }, "query": "\n SELECT name FROM side_types\n " }, - "86049f204c9eda5241403d22b5f8ffe13b258ddfffb81a1a9ee8602e21c64723": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Varchar", - "Int8" - ] - } - }, - "query": "\n UPDATE collections\n SET status = $1\n WHERE (id = $2)\n " - }, "868ee76d507cc9e94cd3c2e44770faff127e2b3c5f49b8100a9a37ac4d7b1f1d": { "describe": { "columns": [], @@ -6131,27 +6152,6 @@ }, "query": "\n UPDATE versions\n SET featured = $1\n WHERE (id = $2)\n " }, - "e60561aeefbc2bed1f77ff4bbca763b5be84bd6bc3eff75ca57e3590be286d45": { - "describe": { - "columns": [ - { - "name": "id", - "ordinal": 0, - "type_info": "Int8" - } - ], - "nullable": [ - false - ], - "parameters": { - "Left": [ - "Int8", - "Text" - ] - } - }, - "query": "\n SELECT m.id FROM organizations o\n LEFT JOIN mods m ON m.id = o.id\n WHERE (o.id = $1 AND $1 IS NOT NULL) OR (o.title = $2 AND $2 IS NOT NULL)\n " - }, "e60ea75112db37d3e73812e21b1907716e4762e06aa883af878e3be82e3f87d3": { "describe": { "columns": [ diff --git a/src/auth/flows.rs b/src/auth/flows.rs index 8b13524b..03771312 100644 --- a/src/auth/flows.rs +++ b/src/auth/flows.rs @@ -3,16 +3,17 @@ use crate::auth::session::issue_session; use crate::auth::validate::get_user_record_from_bearer_token; use crate::auth::{get_user_from_headers, AuthenticationError}; use crate::database::models::flow_item::Flow; +use crate::database::redis::RedisPool; use crate::file_hosting::FileHost; use crate::models::ids::base62_impl::{parse_base62, to_base62}; use crate::models::ids::random_base62_rng; use crate::models::pats::Scopes; use crate::models::users::{Badges, Role}; -use crate::parse_strings_from_var; use crate::queue::session::AuthQueue; use crate::queue::socket::ActiveSockets; use crate::routes::ApiError; use crate::util::captcha::check_turnstile_captcha; +use crate::util::env::parse_strings_from_var; use crate::util::ext::{get_image_content_type, get_image_ext}; use crate::util::validate::{validation_errors_to_string, RE_URL_SAFE}; use actix_web::web::{scope, Data, Payload, Query, ServiceConfig}; @@ -54,7 +55,7 @@ pub fn config(cfg: &mut ServiceConfig) { ); } -#[derive(Serialize, Deserialize, Default, Eq, PartialEq, Clone, Copy)] +#[derive(Serialize, Deserialize, Default, Eq, PartialEq, Clone, Copy, Debug)] #[serde(rename_all = "lowercase")] pub enum AuthProvider { #[default] @@ -84,7 +85,7 @@ impl TempUser { transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, client: &PgPool, file_host: &Arc, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result { if let Some(email) = &self.email { if crate::database::models::User::get_email(email, client) @@ -907,7 +908,7 @@ pub async fn init( req: HttpRequest, Query(info): Query, // callback url client: Data, - redis: Data, + redis: Data, session_queue: Data, ) -> Result { let url = 
url::Url::parse(&info.url).map_err(|_| AuthenticationError::Url)?; @@ -959,7 +960,7 @@ pub async fn ws_init( Query(info): Query, body: Payload, db: Data>, - redis: Data, + redis: Data, ) -> Result { let (res, session, _msg_stream) = actix_ws::handle(&req, body)?; @@ -967,7 +968,7 @@ pub async fn ws_init( mut ws_stream: actix_ws::Session, info: WsInit, db: Data>, - redis: Data, + redis: Data, ) -> Result<(), Closed> { let flow = Flow::OAuth { user_id: None, @@ -1003,7 +1004,7 @@ pub async fn auth_callback( active_sockets: Data>, client: Data, file_host: Data>, - redis: Data, + redis: Data, ) -> Result { let state_string = query .get("state") @@ -1210,7 +1211,7 @@ pub struct DeleteAuthProvider { pub async fn delete_auth_provider( req: HttpRequest, pool: Data, - redis: Data, + redis: Data, delete_provider: web::Json, session_queue: Data, ) -> Result { @@ -1297,7 +1298,7 @@ pub struct NewAccount { pub async fn create_account_with_password( req: HttpRequest, pool: Data, - redis: Data, + redis: Data, new_account: web::Json, ) -> Result { new_account @@ -1414,7 +1415,7 @@ pub struct Login { pub async fn login_password( req: HttpRequest, pool: Data, - redis: Data, + redis: Data, login: web::Json, ) -> Result { if !check_turnstile_captcha(&req, &login.challenge).await? { @@ -1478,7 +1479,7 @@ async fn validate_2fa_code( secret: String, allow_backup: bool, user_id: crate::database::models::UserId, - redis: &deadpool_redis::Pool, + redis: &RedisPool, pool: &PgPool, transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, ) -> Result { @@ -1530,7 +1531,7 @@ async fn validate_2fa_code( pub async fn login_2fa( req: HttpRequest, pool: Data, - redis: Data, + redis: Data, login: web::Json, ) -> Result { let flow = Flow::get(&login.flow, &redis) @@ -1577,7 +1578,7 @@ pub async fn login_2fa( pub async fn begin_2fa_flow( req: HttpRequest, pool: Data, - redis: Data, + redis: Data, session_queue: Data, ) -> Result { let user = get_user_from_headers( @@ -1616,7 +1617,7 @@ pub async fn begin_2fa_flow( pub async fn finish_2fa_flow( req: HttpRequest, pool: Data, - redis: Data, + redis: Data, login: web::Json, session_queue: Data, ) -> Result { @@ -1739,7 +1740,7 @@ pub struct Remove2FA { pub async fn remove_2fa( req: HttpRequest, pool: Data, - redis: Data, + redis: Data, login: web::Json, session_queue: Data, ) -> Result { @@ -1821,7 +1822,7 @@ pub struct ResetPassword { pub async fn reset_password_begin( req: HttpRequest, pool: Data, - redis: Data, + redis: Data, reset_password: web::Json, ) -> Result { if !check_turnstile_captcha(&req, &reset_password.challenge).await? 
{ @@ -1866,7 +1867,7 @@ pub struct ChangePassword { pub async fn change_password( req: HttpRequest, pool: Data, - redis: Data, + redis: Data, change_password: web::Json, session_queue: Data, ) -> Result { @@ -2007,7 +2008,7 @@ pub struct SetEmail { pub async fn set_email( req: HttpRequest, pool: Data, - redis: Data, + redis: Data, email: web::Json, session_queue: Data, ) -> Result { @@ -2073,7 +2074,7 @@ pub async fn set_email( pub async fn resend_verify_email( req: HttpRequest, pool: Data, - redis: Data, + redis: Data, session_queue: Data, ) -> Result { let user = get_user_from_headers( @@ -2118,7 +2119,7 @@ pub struct VerifyEmail { #[post("email/verify")] pub async fn verify_email( pool: Data, - redis: Data, + redis: Data, email: web::Json, ) -> Result { let flow = Flow::get(&email.flow, &redis).await?; @@ -2168,7 +2169,7 @@ pub async fn verify_email( pub async fn subscribe_newsletter( req: HttpRequest, pool: Data, - redis: Data, + redis: Data, session_queue: Data, ) -> Result { let user = get_user_from_headers( diff --git a/src/auth/pats.rs b/src/auth/pats.rs index c38f428b..b8b2d918 100644 --- a/src/auth/pats.rs +++ b/src/auth/pats.rs @@ -4,6 +4,7 @@ use crate::database::models::generate_pat_id; use crate::auth::get_user_from_headers; use crate::routes::ApiError; +use crate::database::redis::RedisPool; use actix_web::web::{self, Data}; use actix_web::{delete, get, patch, post, HttpRequest, HttpResponse}; use chrono::{DateTime, Utc}; @@ -30,7 +31,7 @@ pub fn config(cfg: &mut web::ServiceConfig) { pub async fn get_pats( req: HttpRequest, pool: Data, - redis: Data, + redis: Data, session_queue: Data, ) -> Result { let user = get_user_from_headers( @@ -73,14 +74,14 @@ pub async fn create_pat( req: HttpRequest, info: web::Json, pool: Data, - redis: Data, + redis: Data, session_queue: Data, ) -> Result { info.0 .validate() .map_err(|err| ApiError::InvalidInput(validation_errors_to_string(err, None)))?; - if info.scopes.restricted() { + if info.scopes.is_restricted() { return Err(ApiError::InvalidInput( "Invalid scopes requested!".to_string(), )); @@ -159,7 +160,7 @@ pub async fn edit_pat( id: web::Path<(String,)>, info: web::Json, pool: Data, - redis: Data, + redis: Data, session_queue: Data, ) -> Result { let user = get_user_from_headers( @@ -180,7 +181,7 @@ pub async fn edit_pat( let mut transaction = pool.begin().await?; if let Some(scopes) = &info.scopes { - if scopes.restricted() { + if scopes.is_restricted() { return Err(ApiError::InvalidInput( "Invalid scopes requested!".to_string(), )); @@ -248,7 +249,7 @@ pub async fn delete_pat( req: HttpRequest, id: web::Path<(String,)>, pool: Data, - redis: Data, + redis: Data, session_queue: Data, ) -> Result { let user = get_user_from_headers( diff --git a/src/auth/session.rs b/src/auth/session.rs index 43931aa9..7d1b7d85 100644 --- a/src/auth/session.rs +++ b/src/auth/session.rs @@ -2,6 +2,7 @@ use crate::auth::{get_user_from_headers, AuthenticationError}; use crate::database::models::session_item::Session as DBSession; use crate::database::models::session_item::SessionBuilder; use crate::database::models::UserId; +use crate::database::redis::RedisPool; use crate::models::pats::Scopes; use crate::models::sessions::Session; use crate::queue::session::AuthQueue; @@ -86,7 +87,7 @@ pub async fn issue_session( req: HttpRequest, user_id: UserId, transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result { let metadata = get_session_metadata(&req).await?; @@ -132,7 +133,7 @@ pub async fn 
issue_session( pub async fn list( req: HttpRequest, pool: Data, - redis: Data, + redis: Data, session_queue: Data, ) -> Result { let current_user = get_user_from_headers( @@ -167,7 +168,7 @@ pub async fn delete( info: web::Path<(String,)>, req: HttpRequest, pool: Data, - redis: Data, + redis: Data, session_queue: Data, ) -> Result { let current_user = get_user_from_headers( @@ -206,7 +207,7 @@ pub async fn delete( pub async fn refresh( req: HttpRequest, pool: Data, - redis: Data, + redis: Data, session_queue: Data, ) -> Result { let current_user = get_user_from_headers(&req, &**pool, &redis, &session_queue, None) diff --git a/src/auth/validate.rs b/src/auth/validate.rs index 8589e176..34a0d128 100644 --- a/src/auth/validate.rs +++ b/src/auth/validate.rs @@ -2,6 +2,7 @@ use crate::auth::flows::AuthProvider; use crate::auth::session::get_session_metadata; use crate::auth::AuthenticationError; use crate::database::models::user_item; +use crate::database::redis::RedisPool; use crate::models::pats::Scopes; use crate::models::users::{Role, User, UserId, UserPayoutData}; use crate::queue::session::AuthQueue; @@ -12,7 +13,7 @@ use reqwest::header::{HeaderValue, AUTHORIZATION}; pub async fn get_user_from_headers<'a, E>( req: &HttpRequest, executor: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, session_queue: &AuthQueue, required_scopes: Option<&[Scopes]>, ) -> Result<(Scopes, User), AuthenticationError> @@ -82,7 +83,7 @@ pub async fn get_user_record_from_bearer_token<'a, 'b, E>( req: &HttpRequest, token: Option<&str>, executor: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, session_queue: &AuthQueue, ) -> Result, AuthenticationError> where @@ -140,7 +141,7 @@ where session_queue.add_session(session.id, metadata).await; } - user.map(|x| (Scopes::ALL, x)) + user.map(|x| (Scopes::all(), x)) } Some(("github", _)) | Some(("gho", _)) | Some(("ghp", _)) => { let user = AuthProvider::GitHub.get_user(token).await?; @@ -153,7 +154,7 @@ where ) .await?; - user.map(|x| (Scopes::NOT_RESTRICTED, x)) + user.map(|x| ((Scopes::all() ^ Scopes::restricted()), x)) } _ => return Err(AuthenticationError::InvalidAuthMethod), }; @@ -163,13 +164,14 @@ where pub async fn check_is_moderator_from_headers<'a, 'b, E>( req: &HttpRequest, executor: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, session_queue: &AuthQueue, + required_scopes: Option<&[Scopes]>, ) -> Result where E: sqlx::Executor<'a, Database = sqlx::Postgres> + Copy, { - let user = get_user_from_headers(req, executor, redis, session_queue, None) + let user = get_user_from_headers(req, executor, redis, session_queue, required_scopes) .await? 
.1; diff --git a/src/database/mod.rs b/src/database/mod.rs index 9c51cd17..2bba7dca 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,5 +1,6 @@ pub mod models; mod postgres_database; +pub mod redis; pub use models::Image; pub use models::Project; pub use models::Version; diff --git a/src/database/models/categories.rs b/src/database/models/categories.rs index 4a99a750..6bca5379 100644 --- a/src/database/models/categories.rs +++ b/src/database/models/categories.rs @@ -1,13 +1,13 @@ +use crate::database::redis::RedisPool; + use super::ids::*; use super::DatabaseError; use chrono::DateTime; use chrono::Utc; use futures::TryStreamExt; -use redis::cmd; use serde::{Deserialize, Serialize}; const TAGS_NAMESPACE: &str = "tags"; -const DEFAULT_EXPIRY: i64 = 1800; // 30 minutes pub struct ProjectType { pub id: ProjectTypeId, @@ -98,17 +98,12 @@ impl Category { Ok(result.map(|r| CategoryId(r.id))) } - pub async fn list<'a, E>( - exec: E, - redis: &deadpool_redis::Pool, - ) -> Result, DatabaseError> + pub async fn list<'a, E>(exec: E, redis: &RedisPool) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { - let mut redis = redis.get().await?; - let res = cmd("GET") - .arg(format!("{}:category", TAGS_NAMESPACE)) - .query_async::<_, Option>(&mut redis) + let res = redis + .get::(TAGS_NAMESPACE, "category") .await? .and_then(|x| serde_json::from_str::>(&x).ok()); @@ -137,12 +132,13 @@ impl Category { .try_collect::>() .await?; - cmd("SET") - .arg(format!("{}:category", TAGS_NAMESPACE)) - .arg(serde_json::to_string(&result)?) - .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set( + TAGS_NAMESPACE, + "category", + serde_json::to_string(&result)?, + None, + ) .await?; Ok(result) @@ -167,17 +163,12 @@ impl Loader { Ok(result.map(|r| LoaderId(r.id))) } - pub async fn list<'a, E>( - exec: E, - redis: &deadpool_redis::Pool, - ) -> Result, DatabaseError> + pub async fn list<'a, E>(exec: E, redis: &RedisPool) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { - let mut redis = redis.get().await?; - let res = cmd("GET") - .arg(format!("{}:loader", TAGS_NAMESPACE)) - .query_async::<_, Option>(&mut redis) + let res = redis + .get::(TAGS_NAMESPACE, "loader") .await? .and_then(|x| serde_json::from_str::>(&x).ok()); @@ -212,12 +203,13 @@ impl Loader { .try_collect::>() .await?; - cmd("SET") - .arg(format!("{}:loader", TAGS_NAMESPACE)) - .arg(serde_json::to_string(&result)?) - .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set( + TAGS_NAMESPACE, + "loader", + serde_json::to_string(&result)?, + None, + ) .await?; Ok(result) @@ -256,17 +248,12 @@ impl GameVersion { Ok(result.map(|r| GameVersionId(r.id))) } - pub async fn list<'a, E>( - exec: E, - redis: &deadpool_redis::Pool, - ) -> Result, DatabaseError> + pub async fn list<'a, E>(exec: E, redis: &RedisPool) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { - let mut redis = redis.get().await?; - let res = cmd("GET") - .arg(format!("{}:game_version", TAGS_NAMESPACE)) - .query_async::<_, Option>(&mut redis) + let res = redis + .get::(TAGS_NAMESPACE, "game_version") .await? .and_then(|x| serde_json::from_str::>(&x).ok()); @@ -291,14 +278,14 @@ impl GameVersion { .try_collect::>() .await?; - cmd("SET") - .arg(format!("{}:game_version", TAGS_NAMESPACE)) - .arg(serde_json::to_string(&result)?) 
- .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set( + TAGS_NAMESPACE, + "game_version", + serde_json::to_string(&result)?, + None, + ) .await?; - Ok(result) } @@ -306,7 +293,7 @@ impl GameVersion { version_type_option: Option<&str>, major_option: Option, exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -408,15 +395,13 @@ impl DonationPlatform { pub async fn list<'a, E>( exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { - let mut redis = redis.get().await?; - let res = cmd("GET") - .arg(format!("{}:donation_platform", TAGS_NAMESPACE)) - .query_async::<_, Option>(&mut redis) + let res = redis + .get::(TAGS_NAMESPACE, "donation_platform") .await? .and_then(|x| serde_json::from_str::>(&x).ok()); @@ -440,12 +425,13 @@ impl DonationPlatform { .try_collect::>() .await?; - cmd("SET") - .arg(format!("{}:donation_platform", TAGS_NAMESPACE)) - .arg(serde_json::to_string(&result)?) - .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set( + TAGS_NAMESPACE, + "donation_platform", + serde_json::to_string(&result)?, + None, + ) .await?; Ok(result) @@ -470,17 +456,12 @@ impl ReportType { Ok(result.map(|r| ReportTypeId(r.id))) } - pub async fn list<'a, E>( - exec: E, - redis: &deadpool_redis::Pool, - ) -> Result, DatabaseError> + pub async fn list<'a, E>(exec: E, redis: &RedisPool) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { - let mut redis = redis.get().await?; - let res = cmd("GET") - .arg(format!("{}:report_type", TAGS_NAMESPACE)) - .query_async::<_, Option>(&mut redis) + let res = redis + .get::(TAGS_NAMESPACE, "report_type") .await? .and_then(|x| serde_json::from_str::>(&x).ok()); @@ -498,12 +479,13 @@ impl ReportType { .try_collect::>() .await?; - cmd("SET") - .arg(format!("{}:report_type", TAGS_NAMESPACE)) - .arg(serde_json::to_string(&result)?) - .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set( + TAGS_NAMESPACE, + "report_type", + serde_json::to_string(&result)?, + None, + ) .await?; Ok(result) @@ -528,17 +510,12 @@ impl ProjectType { Ok(result.map(|r| ProjectTypeId(r.id))) } - pub async fn list<'a, E>( - exec: E, - redis: &deadpool_redis::Pool, - ) -> Result, DatabaseError> + pub async fn list<'a, E>(exec: E, redis: &RedisPool) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { - let mut redis = redis.get().await?; - let res = cmd("GET") - .arg(format!("{}:project_type", TAGS_NAMESPACE)) - .query_async::<_, Option>(&mut redis) + let res = redis + .get::(TAGS_NAMESPACE, "project_type") .await? .and_then(|x| serde_json::from_str::>(&x).ok()); @@ -556,12 +533,13 @@ impl ProjectType { .try_collect::>() .await?; - cmd("SET") - .arg(format!("{}:project_type", TAGS_NAMESPACE)) - .arg(serde_json::to_string(&result)?) 
- .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set( + TAGS_NAMESPACE, + "project_type", + serde_json::to_string(&result)?, + None, + ) .await?; Ok(result) @@ -586,17 +564,12 @@ impl SideType { Ok(result.map(|r| SideTypeId(r.id))) } - pub async fn list<'a, E>( - exec: E, - redis: &deadpool_redis::Pool, - ) -> Result, DatabaseError> + pub async fn list<'a, E>(exec: E, redis: &RedisPool) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { - let mut redis = redis.get().await?; - let res = cmd("GET") - .arg(format!("{}:side_type", TAGS_NAMESPACE)) - .query_async::<_, Option>(&mut redis) + let res = redis + .get::(TAGS_NAMESPACE, "side_type") .await? .and_then(|x| serde_json::from_str::>(&x).ok()); @@ -614,12 +587,13 @@ impl SideType { .try_collect::>() .await?; - cmd("SET") - .arg(format!("{}:side_type", TAGS_NAMESPACE)) - .arg(serde_json::to_string(&result)?) - .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set( + TAGS_NAMESPACE, + "side_type", + serde_json::to_string(&result)?, + None, + ) .await?; Ok(result) diff --git a/src/database/models/collection_item.rs b/src/database/models/collection_item.rs index 0500ee81..12ff7838 100644 --- a/src/database/models/collection_item.rs +++ b/src/database/models/collection_item.rs @@ -1,13 +1,12 @@ use super::ids::*; use crate::database::models; use crate::database::models::DatabaseError; +use crate::database::redis::RedisPool; use crate::models::collections::CollectionStatus; use chrono::{DateTime, Utc}; -use redis::cmd; use serde::{Deserialize, Serialize}; const COLLECTIONS_NAMESPACE: &str = "collections"; -const DEFAULT_EXPIRY: i64 = 1800; // 30 minutes #[derive(Clone)] pub struct CollectionBuilder { @@ -102,7 +101,7 @@ impl Collection { pub async fn remove( id: CollectionId, transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> { let collection = Self::get(id, &mut *transaction, redis).await?; @@ -138,7 +137,7 @@ impl Collection { pub async fn get<'a, 'b, E>( id: CollectionId, executor: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -151,7 +150,7 @@ impl Collection { pub async fn get_many<'a, E>( collection_ids: &[CollectionId], exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -162,20 +161,12 @@ impl Collection { return Ok(Vec::new()); } - let mut redis = redis.get().await?; - let mut found_collections = Vec::new(); let mut remaining_collections: Vec = collection_ids.to_vec(); if !collection_ids.is_empty() { - let collections = cmd("MGET") - .arg( - collection_ids - .iter() - .map(|x| format!("{}:{}", COLLECTIONS_NAMESPACE, x.0)) - .collect::>(), - ) - .query_async::<_, Vec>>(&mut redis) + let collections = redis + .multi_get::(COLLECTIONS_NAMESPACE, collection_ids.iter().map(|x| x.0)) .await?; for collection in collections { @@ -233,14 +224,14 @@ impl Collection { .await?; for collection in db_collections { - cmd("SET") - .arg(format!("{}:{}", COLLECTIONS_NAMESPACE, collection.id.0)) - .arg(serde_json::to_string(&collection)?) 
- .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set( + COLLECTIONS_NAMESPACE, + collection.id.0, + serde_json::to_string(&collection)?, + None, + ) .await?; - found_collections.push(collection); } } @@ -248,16 +239,8 @@ impl Collection { Ok(found_collections) } - pub async fn clear_cache( - id: CollectionId, - redis: &deadpool_redis::Pool, - ) -> Result<(), DatabaseError> { - let mut redis = redis.get().await?; - let mut cmd = cmd("DEL"); - - cmd.arg(format!("{}:{}", COLLECTIONS_NAMESPACE, id.0)); - cmd.query_async::<_, ()>(&mut redis).await?; - + pub async fn clear_cache(id: CollectionId, redis: &RedisPool) -> Result<(), DatabaseError> { + redis.delete(COLLECTIONS_NAMESPACE, id.0).await?; Ok(()) } } diff --git a/src/database/models/flow_item.rs b/src/database/models/flow_item.rs index f55fa9b0..d9e8cfa3 100644 --- a/src/database/models/flow_item.rs +++ b/src/database/models/flow_item.rs @@ -1,12 +1,12 @@ use super::ids::*; use crate::auth::flows::AuthProvider; use crate::database::models::DatabaseError; +use crate::database::redis::RedisPool; use chrono::Duration; use rand::distributions::Alphanumeric; use rand::Rng; use rand_chacha::rand_core::SeedableRng; use rand_chacha::ChaCha20Rng; -use redis::cmd; use serde::{Deserialize, Serialize}; const FLOWS_NAMESPACE: &str = "flows"; @@ -40,50 +40,32 @@ impl Flow { pub async fn insert( &self, expires: Duration, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result { - let mut redis = redis.get().await?; - let flow = ChaCha20Rng::from_entropy() .sample_iter(&Alphanumeric) .take(32) .map(char::from) .collect::(); - cmd("SET") - .arg(format!("{}:{}", FLOWS_NAMESPACE, flow)) - .arg(serde_json::to_string(&self)?) - .arg("EX") - .arg(expires.num_seconds()) - .query_async::<_, ()>(&mut redis) + redis + .set( + FLOWS_NAMESPACE, + &flow, + serde_json::to_string(&self)?, + Some(expires.num_seconds()), + ) .await?; - Ok(flow) } - pub async fn get( - id: &str, - redis: &deadpool_redis::Pool, - ) -> Result, DatabaseError> { - let mut redis = redis.get().await?; - - let res = cmd("GET") - .arg(format!("{}:{}", FLOWS_NAMESPACE, id)) - .query_async::<_, Option>(&mut redis) - .await?; - + pub async fn get(id: &str, redis: &RedisPool) -> Result, DatabaseError> { + let res = redis.get::(FLOWS_NAMESPACE, id).await?; Ok(res.and_then(|x| serde_json::from_str(&x).ok())) } - pub async fn remove( - id: &str, - redis: &deadpool_redis::Pool, - ) -> Result, DatabaseError> { - let mut redis = redis.get().await?; - let mut cmd = cmd("DEL"); - cmd.arg(format!("{}:{}", FLOWS_NAMESPACE, id)); - cmd.query_async::<_, ()>(&mut redis).await?; - + pub async fn remove(id: &str, redis: &RedisPool) -> Result, DatabaseError> { + redis.delete(FLOWS_NAMESPACE, id).await?; Ok(Some(())) } } diff --git a/src/database/models/image_item.rs b/src/database/models/image_item.rs index fd6d0abb..45f42583 100644 --- a/src/database/models/image_item.rs +++ b/src/database/models/image_item.rs @@ -1,11 +1,10 @@ use super::ids::*; +use crate::database::redis::RedisPool; use crate::{database::models::DatabaseError, models::images::ImageContext}; use chrono::{DateTime, Utc}; -use redis::cmd; use serde::{Deserialize, Serialize}; const IMAGES_NAMESPACE: &str = "images"; -const DEFAULT_EXPIRY: i64 = 1800; // 30 minutes #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Image { @@ -58,7 +57,7 @@ impl Image { pub async fn remove( id: ImageId, transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> 
Result, DatabaseError> { let image = Self::get(id, &mut *transaction, redis).await?; @@ -161,7 +160,7 @@ impl Image { pub async fn get<'a, 'b, E>( id: ImageId, executor: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -174,7 +173,7 @@ impl Image { pub async fn get_many<'a, E>( image_ids: &[ImageId], exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -185,24 +184,15 @@ impl Image { return Ok(Vec::new()); } - let mut redis = redis.get().await?; - let mut found_images = Vec::new(); let mut remaining_ids = image_ids.to_vec(); let image_ids = image_ids.iter().map(|x| x.0).collect::>(); if !image_ids.is_empty() { - let images = cmd("MGET") - .arg( - image_ids - .iter() - .map(|x| format!("{}:{}", IMAGES_NAMESPACE, x)) - .collect::>(), - ) - .query_async::<_, Vec>>(&mut redis) + let images = redis + .multi_get::(IMAGES_NAMESPACE, image_ids) .await?; - for image in images { if let Some(image) = image.and_then(|x| serde_json::from_str::(&x).ok()) { remaining_ids.retain(|x| image.id.0 != x.0); @@ -245,14 +235,14 @@ impl Image { .await?; for image in db_images { - cmd("SET") - .arg(format!("{}:{}", IMAGES_NAMESPACE, image.id.0)) - .arg(serde_json::to_string(&image)?) - .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set( + IMAGES_NAMESPACE, + image.id.0, + serde_json::to_string(&image)?, + None, + ) .await?; - found_images.push(image); } } @@ -260,16 +250,8 @@ impl Image { Ok(found_images) } - pub async fn clear_cache( - id: ImageId, - redis: &deadpool_redis::Pool, - ) -> Result<(), DatabaseError> { - let mut redis = redis.get().await?; - let mut cmd = cmd("DEL"); - - cmd.arg(format!("{}:{}", IMAGES_NAMESPACE, id.0)); - cmd.query_async::<_, ()>(&mut redis).await?; - + pub async fn clear_cache(id: ImageId, redis: &RedisPool) -> Result<(), DatabaseError> { + redis.delete(IMAGES_NAMESPACE, id.0).await?; Ok(()) } } diff --git a/src/database/models/organization_item.rs b/src/database/models/organization_item.rs index 5a52558a..15e880cd 100644 --- a/src/database/models/organization_item.rs +++ b/src/database/models/organization_item.rs @@ -1,14 +1,14 @@ -use crate::models::ids::base62_impl::{parse_base62, to_base62}; +use crate::{ + database::redis::RedisPool, + models::ids::base62_impl::{parse_base62, to_base62}, +}; use super::{ids::*, TeamMember}; -use redis::cmd; use serde::{Deserialize, Serialize}; const ORGANIZATIONS_NAMESPACE: &str = "organizations"; const ORGANIZATIONS_TITLES_NAMESPACE: &str = "organizations_titles"; -const DEFAULT_EXPIRY: i64 = 1800; - #[derive(Deserialize, Serialize, Clone, Debug)] /// An organization of users who together control one or more projects and organizations. 
pub struct Organization { @@ -55,7 +55,7 @@ impl Organization { pub async fn get<'a, E>( string: &str, exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, super::DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -68,7 +68,7 @@ impl Organization { pub async fn get_id<'a, 'b, E>( id: OrganizationId, exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, super::DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -81,7 +81,7 @@ impl Organization { pub async fn get_many_ids<'a, 'b, E>( organization_ids: &[OrganizationId], exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, super::DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -96,7 +96,7 @@ impl Organization { pub async fn get_many<'a, E, T: ToString>( organization_strings: &[T], exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, super::DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -107,8 +107,6 @@ impl Organization { return Ok(Vec::new()); } - let mut redis = redis.get().await?; - let mut found_organizations = Vec::new(); let mut remaining_strings = organization_strings .iter() @@ -121,20 +119,13 @@ impl Organization { .collect::>(); organization_ids.append( - &mut cmd("MGET") - .arg( + &mut redis + .multi_get::( + ORGANIZATIONS_TITLES_NAMESPACE, organization_strings .iter() - .map(|x| { - format!( - "{}:{}", - ORGANIZATIONS_TITLES_NAMESPACE, - x.to_string().to_lowercase() - ) - }) - .collect::>(), + .map(|x| x.to_string().to_lowercase()), ) - .query_async::<_, Vec>>(&mut redis) .await? .into_iter() .flatten() @@ -142,14 +133,8 @@ impl Organization { ); if !organization_ids.is_empty() { - let organizations = cmd("MGET") - .arg( - organization_ids - .iter() - .map(|x| format!("{}:{}", ORGANIZATIONS_NAMESPACE, x)) - .collect::>(), - ) - .query_async::<_, Vec>>(&mut redis) + let organizations = redis + .multi_get::(ORGANIZATIONS_NAMESPACE, organization_ids) .await?; for organization in organizations { @@ -201,25 +186,23 @@ impl Organization { .await?; for organization in organizations { - cmd("SET") - .arg(format!("{}:{}", ORGANIZATIONS_NAMESPACE, organization.id.0)) - .arg(serde_json::to_string(&organization)?) 
- .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set( + ORGANIZATIONS_NAMESPACE, + organization.id.0, + serde_json::to_string(&organization)?, + None, + ) .await?; - - cmd("SET") - .arg(format!( - "{}:{}", + redis + .set( ORGANIZATIONS_TITLES_NAMESPACE, - organization.title.to_lowercase() - )) - .arg(organization.id.0) - .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + organization.title.to_lowercase(), + organization.id.0, + None, + ) .await?; + found_organizations.push(organization); } } @@ -265,7 +248,7 @@ impl Organization { pub async fn remove( id: OrganizationId, transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, super::DatabaseError> { use futures::TryStreamExt; @@ -333,20 +316,17 @@ impl Organization { pub async fn clear_cache( id: OrganizationId, title: Option, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result<(), super::DatabaseError> { - let mut redis = redis.get().await?; - let mut cmd = cmd("DEL"); - cmd.arg(format!("{}:{}", ORGANIZATIONS_NAMESPACE, id.0)); - if let Some(title) = title { - cmd.arg(format!( - "{}:{}", - ORGANIZATIONS_TITLES_NAMESPACE, - title.to_lowercase() - )); - } - cmd.query_async::<_, ()>(&mut redis).await?; - + redis + .delete_many([ + (ORGANIZATIONS_NAMESPACE, Some(id.0.to_string())), + ( + ORGANIZATIONS_TITLES_NAMESPACE, + title.map(|x| x.to_lowercase()), + ), + ]) + .await?; Ok(()) } } diff --git a/src/database/models/pat_item.rs b/src/database/models/pat_item.rs index f8ff23d1..ac1a17e9 100644 --- a/src/database/models/pat_item.rs +++ b/src/database/models/pat_item.rs @@ -1,17 +1,16 @@ use super::ids::*; use crate::database::models::DatabaseError; +use crate::database::redis::RedisPool; use crate::models::ids::base62_impl::{parse_base62, to_base62}; use crate::models::pats::Scopes; use chrono::{DateTime, Utc}; -use redis::cmd; use serde::{Deserialize, Serialize}; const PATS_NAMESPACE: &str = "pats"; const PATS_TOKENS_NAMESPACE: &str = "pats_tokens"; const PATS_USERS_NAMESPACE: &str = "pats_users"; -const DEFAULT_EXPIRY: i64 = 1800; // 30 minutes -#[derive(Deserialize, Serialize)] +#[derive(Deserialize, Serialize, Clone, Debug)] pub struct PersonalAccessToken { pub id: PatId, pub name: String, @@ -55,7 +54,7 @@ impl PersonalAccessToken { pub async fn get<'a, E, T: ToString>( id: T, exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -68,7 +67,7 @@ impl PersonalAccessToken { pub async fn get_many_ids<'a, E>( pat_ids: &[PatId], exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -83,7 +82,7 @@ impl PersonalAccessToken { pub async fn get_many<'a, E, T: ToString>( pat_strings: &[T], exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -94,8 +93,6 @@ impl PersonalAccessToken { return Ok(Vec::new()); } - let mut redis = redis.get().await?; - let mut found_pats = Vec::new(); let mut remaining_strings = pat_strings .iter() @@ -108,14 +105,11 @@ impl PersonalAccessToken { .collect::>(); pat_ids.append( - &mut cmd("MGET") - .arg( - pat_strings - .iter() - .map(|x| format!("{}:{}", PATS_TOKENS_NAMESPACE, x.to_string())) - .collect::>(), + &mut redis + .multi_get::( + PATS_TOKENS_NAMESPACE, + pat_strings.iter().map(|x| x.to_string()), ) - 
.query_async::<_, Vec>>(&mut redis) .await? .into_iter() .flatten() @@ -123,16 +117,9 @@ impl PersonalAccessToken { ); if !pat_ids.is_empty() { - let pats = cmd("MGET") - .arg( - pat_ids - .iter() - .map(|x| format!("{}:{}", PATS_NAMESPACE, x)) - .collect::>(), - ) - .query_async::<_, Vec>>(&mut redis) + let pats = redis + .multi_get::(PATS_NAMESPACE, pat_ids) .await?; - for pat in pats { if let Some(pat) = pat.and_then(|x| serde_json::from_str::(&x).ok()) @@ -181,20 +168,16 @@ impl PersonalAccessToken { .await?; for pat in db_pats { - cmd("SET") - .arg(format!("{}:{}", PATS_NAMESPACE, pat.id.0)) - .arg(serde_json::to_string(&pat)?) - .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set(PATS_NAMESPACE, pat.id.0, serde_json::to_string(&pat)?, None) .await?; - - cmd("SET") - .arg(format!("{}:{}", PATS_TOKENS_NAMESPACE, pat.access_token)) - .arg(pat.id.0) - .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set( + PATS_TOKENS_NAMESPACE, + pat.access_token.clone(), + pat.id.0, + None, + ) .await?; found_pats.push(pat); } @@ -206,15 +189,13 @@ impl PersonalAccessToken { pub async fn get_user_pats<'a, E>( user_id: UserId, exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { - let mut redis = redis.get().await?; - let res = cmd("GET") - .arg(format!("{}:{}", PATS_USERS_NAMESPACE, user_id.0)) - .query_async::<_, Option>(&mut redis) + let res = redis + .get::(PATS_USERS_NAMESPACE, user_id.0) .await? .and_then(|x| serde_json::from_str::>(&x).ok()); @@ -237,41 +218,34 @@ impl PersonalAccessToken { .try_collect::>() .await?; - cmd("SET") - .arg(format!("{}:{}", PATS_USERS_NAMESPACE, user_id.0)) - .arg(serde_json::to_string(&db_pats)?) 
- .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set( + PATS_USERS_NAMESPACE, + user_id.0, + serde_json::to_string(&db_pats)?, + None, + ) .await?; - Ok(db_pats) } pub async fn clear_cache( clear_pats: Vec<(Option, Option, Option)>, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result<(), DatabaseError> { if clear_pats.is_empty() { return Ok(()); } - let mut redis = redis.get().await?; - let mut cmd = cmd("DEL"); - - for (id, token, user_id) in clear_pats { - if let Some(id) = id { - cmd.arg(format!("{}:{}", PATS_NAMESPACE, id.0)); - } - if let Some(token) = token { - cmd.arg(format!("{}:{}", PATS_TOKENS_NAMESPACE, token)); - } - if let Some(user_id) = user_id { - cmd.arg(format!("{}:{}", PATS_USERS_NAMESPACE, user_id.0)); - } - } - - cmd.query_async::<_, ()>(&mut redis).await?; + redis + .delete_many(clear_pats.into_iter().flat_map(|(id, token, user_id)| { + [ + (PATS_NAMESPACE, id.map(|i| i.0.to_string())), + (PATS_TOKENS_NAMESPACE, token), + (PATS_USERS_NAMESPACE, user_id.map(|i| i.0.to_string())), + ] + })) + .await?; Ok(()) } diff --git a/src/database/models/project_item.rs b/src/database/models/project_item.rs index a7d85679..f841f934 100644 --- a/src/database/models/project_item.rs +++ b/src/database/models/project_item.rs @@ -1,16 +1,15 @@ use super::ids::*; use crate::database::models; use crate::database::models::DatabaseError; +use crate::database::redis::RedisPool; use crate::models::ids::base62_impl::{parse_base62, to_base62}; use crate::models::projects::{MonetizationStatus, ProjectStatus}; use chrono::{DateTime, Utc}; -use redis::cmd; use serde::{Deserialize, Serialize}; -const PROJECTS_NAMESPACE: &str = "projects"; -const PROJECTS_SLUGS_NAMESPACE: &str = "projects_slugs"; +pub const PROJECTS_NAMESPACE: &str = "projects"; +pub const PROJECTS_SLUGS_NAMESPACE: &str = "projects_slugs"; const PROJECTS_DEPENDENCIES_NAMESPACE: &str = "projects_dependencies"; -const DEFAULT_EXPIRY: i64 = 1800; // 30 minutes #[derive(Clone, Debug, Serialize, Deserialize)] pub struct DonationUrl { @@ -299,7 +298,7 @@ impl Project { pub async fn remove( id: ProjectId, transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> { let project = Self::get_id(id, &mut *transaction, redis).await?; @@ -433,7 +432,7 @@ impl Project { pub async fn get<'a, 'b, E>( string: &str, executor: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -446,7 +445,7 @@ impl Project { pub async fn get_id<'a, 'b, E>( id: ProjectId, executor: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -459,7 +458,7 @@ impl Project { pub async fn get_many_ids<'a, E>( project_ids: &[ProjectId], exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -474,7 +473,7 @@ impl Project { pub async fn get_many<'a, E, T: ToString>( project_strings: &[T], exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -485,8 +484,6 @@ impl Project { return Ok(Vec::new()); } - let mut redis = redis.get().await?; - let mut found_projects = Vec::new(); let mut remaining_strings = project_strings .iter() @@ -499,20 +496,11 @@ impl Project { .collect::>(); project_ids.append( 
- &mut cmd("MGET") - .arg( - project_strings - .iter() - .map(|x| { - format!( - "{}:{}", - PROJECTS_SLUGS_NAMESPACE, - x.to_string().to_lowercase() - ) - }) - .collect::>(), + &mut redis + .multi_get::( + PROJECTS_SLUGS_NAMESPACE, + project_strings.iter().map(|x| x.to_string().to_lowercase()), ) - .query_async::<_, Vec>>(&mut redis) .await? .into_iter() .flatten() @@ -520,16 +508,9 @@ impl Project { ); if !project_ids.is_empty() { - let projects = cmd("MGET") - .arg( - project_ids - .iter() - .map(|x| format!("{}:{}", PROJECTS_NAMESPACE, x)) - .collect::>(), - ) - .query_async::<_, Vec>>(&mut redis) + let projects = redis + .multi_get::(PROJECTS_NAMESPACE, project_ids) .await?; - for project in projects { if let Some(project) = project.and_then(|x| serde_json::from_str::(&x).ok()) @@ -551,7 +532,6 @@ impl Project { .flat_map(|x| parse_base62(&x.to_string()).ok()) .map(|x| x as i64) .collect(); - let db_projects: Vec = sqlx::query!( " SELECT m.id id, m.project_type project_type, m.title title, m.description description, m.downloads downloads, m.follows follows, @@ -672,25 +652,22 @@ impl Project { .await?; for project in db_projects { - cmd("SET") - .arg(format!("{}:{}", PROJECTS_NAMESPACE, project.inner.id.0)) - .arg(serde_json::to_string(&project)?) - .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set( + PROJECTS_NAMESPACE, + project.inner.id.0, + serde_json::to_string(&project)?, + None, + ) .await?; - if let Some(slug) = &project.inner.slug { - cmd("SET") - .arg(format!( - "{}:{}", + redis + .set( PROJECTS_SLUGS_NAMESPACE, - slug.to_lowercase() - )) - .arg(project.inner.id.0) - .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + slug.to_lowercase(), + project.inner.id.0, + None, + ) .await?; } found_projects.push(project); @@ -703,7 +680,7 @@ impl Project { pub async fn get_dependencies<'a, E>( id: ProjectId, exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, Option, Option)>, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -712,13 +689,9 @@ impl Project { use futures::stream::TryStreamExt; - let mut redis = redis.get().await?; - - let dependencies = cmd("GET") - .arg(format!("{}:{}", PROJECTS_DEPENDENCIES_NAMESPACE, id.0)) - .query_async::<_, Option>(&mut redis) + let dependencies = redis + .get::(PROJECTS_DEPENDENCIES_NAMESPACE, id.0) .await?; - if let Some(dependencies) = dependencies.and_then(|x| serde_json::from_str::(&x).ok()) { @@ -752,14 +725,14 @@ impl Project { .try_collect::() .await?; - cmd("SET") - .arg(format!("{}:{}", PROJECTS_DEPENDENCIES_NAMESPACE, id.0)) - .arg(serde_json::to_string(&dependencies)?) 
- .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set( + PROJECTS_DEPENDENCIES_NAMESPACE, + id.0, + serde_json::to_string(&dependencies)?, + None, + ) .await?; - Ok(dependencies) } @@ -817,25 +790,22 @@ impl Project { id: ProjectId, slug: Option, clear_dependencies: Option, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result<(), DatabaseError> { - let mut redis = redis.get().await?; - let mut cmd = cmd("DEL"); - - cmd.arg(format!("{}:{}", PROJECTS_NAMESPACE, id.0)); - if let Some(slug) = slug { - cmd.arg(format!( - "{}:{}", - PROJECTS_SLUGS_NAMESPACE, - slug.to_lowercase() - )); - } - if clear_dependencies.unwrap_or(false) { - cmd.arg(format!("{}:{}", PROJECTS_DEPENDENCIES_NAMESPACE, id.0)); - } - - cmd.query_async::<_, ()>(&mut redis).await?; - + redis + .delete_many([ + (PROJECTS_NAMESPACE, Some(id.0.to_string())), + (PROJECTS_SLUGS_NAMESPACE, slug.map(|x| x.to_lowercase())), + ( + PROJECTS_DEPENDENCIES_NAMESPACE, + if clear_dependencies.unwrap_or(false) { + Some(id.0.to_string()) + } else { + None + }, + ), + ]) + .await?; Ok(()) } } diff --git a/src/database/models/session_item.rs b/src/database/models/session_item.rs index e1f1843c..3cf7d2b8 100644 --- a/src/database/models/session_item.rs +++ b/src/database/models/session_item.rs @@ -1,14 +1,13 @@ use super::ids::*; use crate::database::models::DatabaseError; +use crate::database::redis::RedisPool; use crate::models::ids::base62_impl::{parse_base62, to_base62}; use chrono::{DateTime, Utc}; -use redis::cmd; use serde::{Deserialize, Serialize}; const SESSIONS_NAMESPACE: &str = "sessions"; const SESSIONS_IDS_NAMESPACE: &str = "sessions_ids"; const SESSIONS_USERS_NAMESPACE: &str = "sessions_users"; -const DEFAULT_EXPIRY: i64 = 1800; // 30 minutes pub struct SessionBuilder { pub session: String, @@ -83,7 +82,7 @@ impl Session { pub async fn get<'a, E, T: ToString>( id: T, exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -96,7 +95,7 @@ impl Session { pub async fn get_id<'a, 'b, E>( id: SessionId, executor: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -109,7 +108,7 @@ impl Session { pub async fn get_many_ids<'a, E>( session_ids: &[SessionId], exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -124,7 +123,7 @@ impl Session { pub async fn get_many<'a, E, T: ToString>( session_strings: &[T], exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -135,8 +134,6 @@ impl Session { return Ok(Vec::new()); } - let mut redis = redis.get().await?; - let mut found_sessions = Vec::new(); let mut remaining_strings = session_strings .iter() @@ -149,14 +146,11 @@ impl Session { .collect::>(); session_ids.append( - &mut cmd("MGET") - .arg( - session_strings - .iter() - .map(|x| format!("{}:{}", SESSIONS_IDS_NAMESPACE, x.to_string())) - .collect::>(), + &mut redis + .multi_get::( + SESSIONS_IDS_NAMESPACE, + session_strings.iter().map(|x| x.to_string()), ) - .query_async::<_, Vec>>(&mut redis) .await? 
.into_iter() .flatten() @@ -164,16 +158,9 @@ impl Session { ); if !session_ids.is_empty() { - let sessions = cmd("MGET") - .arg( - session_ids - .iter() - .map(|x| format!("{}:{}", SESSIONS_NAMESPACE, x)) - .collect::>(), - ) - .query_async::<_, Vec>>(&mut redis) + let sessions = redis + .multi_get::(SESSIONS_NAMESPACE, session_ids) .await?; - for session in sessions { if let Some(session) = session.and_then(|x| serde_json::from_str::(&x).ok()) @@ -225,20 +212,21 @@ impl Session { .await?; for session in db_sessions { - cmd("SET") - .arg(format!("{}:{}", SESSIONS_NAMESPACE, session.id.0)) - .arg(serde_json::to_string(&session)?) - .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set( + SESSIONS_NAMESPACE, + session.id.0, + serde_json::to_string(&session)?, + None, + ) .await?; - - cmd("SET") - .arg(format!("{}:{}", SESSIONS_IDS_NAMESPACE, session.session)) - .arg(session.id.0) - .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set( + SESSIONS_IDS_NAMESPACE, + session.session.clone(), + session.id.0, + None, + ) .await?; found_sessions.push(session); } @@ -250,15 +238,13 @@ impl Session { pub async fn get_user_sessions<'a, E>( user_id: UserId, exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { - let mut redis = redis.get().await?; - let res = cmd("GET") - .arg(format!("{}:{}", SESSIONS_USERS_NAMESPACE, user_id.0)) - .query_async::<_, Option>(&mut redis) + let res = redis + .get::(SESSIONS_USERS_NAMESPACE, user_id.0) .await? .and_then(|x| serde_json::from_str::>(&x).ok()); @@ -281,12 +267,13 @@ impl Session { .try_collect::>() .await?; - cmd("SET") - .arg(format!("{}:{}", SESSIONS_USERS_NAMESPACE, user_id.0)) - .arg(serde_json::to_string(&db_sessions)?) 
- .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set( + SESSIONS_USERS_NAMESPACE, + user_id.0, + serde_json::to_string(&db_sessions)?, + None, + ) .await?; Ok(db_sessions) @@ -294,29 +281,25 @@ impl Session { pub async fn clear_cache( clear_sessions: Vec<(Option, Option, Option)>, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result<(), DatabaseError> { if clear_sessions.is_empty() { return Ok(()); } - let mut redis = redis.get().await?; - let mut cmd = cmd("DEL"); - - for (id, session, user_id) in clear_sessions { - if let Some(id) = id { - cmd.arg(format!("{}:{}", SESSIONS_NAMESPACE, id.0)); - } - if let Some(session) = session { - cmd.arg(format!("{}:{}", SESSIONS_IDS_NAMESPACE, session)); - } - if let Some(user_id) = user_id { - cmd.arg(format!("{}:{}", SESSIONS_USERS_NAMESPACE, user_id.0)); - } - } - - cmd.query_async::<_, ()>(&mut redis).await?; - + redis + .delete_many( + clear_sessions + .into_iter() + .flat_map(|(id, session, user_id)| { + [ + (SESSIONS_NAMESPACE, id.map(|i| i.0.to_string())), + (SESSIONS_IDS_NAMESPACE, session), + (SESSIONS_USERS_NAMESPACE, user_id.map(|i| i.0.to_string())), + ] + }), + ) + .await?; Ok(()) } diff --git a/src/database/models/team_item.rs b/src/database/models/team_item.rs index f092f8dc..31d60b20 100644 --- a/src/database/models/team_item.rs +++ b/src/database/models/team_item.rs @@ -1,12 +1,13 @@ use super::{ids::*, Organization, Project}; -use crate::models::teams::{OrganizationPermissions, ProjectPermissions}; +use crate::{ + database::redis::RedisPool, + models::teams::{OrganizationPermissions, ProjectPermissions}, +}; use itertools::Itertools; -use redis::cmd; use rust_decimal::Decimal; use serde::{Deserialize, Serialize}; const TEAMS_NAMESPACE: &str = "teams"; -const DEFAULT_EXPIRY: i64 = 1800; pub struct TeamBuilder { pub members: Vec, @@ -145,7 +146,7 @@ impl TeamMember { pub async fn get_from_team_full<'a, 'b, E>( id: TeamId, executor: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, super::DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres> + Copy, @@ -156,7 +157,7 @@ impl TeamMember { pub async fn get_from_team_full_many<'a, E>( team_ids: &[TeamId], exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, super::DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres> + Copy, @@ -169,18 +170,10 @@ impl TeamMember { let mut team_ids_parsed: Vec = team_ids.iter().map(|x| x.0).collect(); - let mut redis = redis.get().await?; - let mut found_teams = Vec::new(); - let teams = cmd("MGET") - .arg( - team_ids_parsed - .iter() - .map(|x| format!("{}:{}", TEAMS_NAMESPACE, x)) - .collect::>(), - ) - .query_async::<_, Vec>>(&mut redis) + let teams = redis + .multi_get::(TEAMS_NAMESPACE, team_ids_parsed.clone()) .await?; for team_raw in teams { @@ -232,14 +225,14 @@ impl TeamMember { for (id, members) in &teams.into_iter().group_by(|x| x.team_id) { let mut members = members.collect::>(); - cmd("SET") - .arg(format!("{}:{}", TEAMS_NAMESPACE, id.0)) - .arg(serde_json::to_string(&members)?) 
- .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set( + TEAMS_NAMESPACE, + id.0, + serde_json::to_string(&members)?, + None, + ) .await?; - found_teams.append(&mut members); } } @@ -247,16 +240,8 @@ impl TeamMember { Ok(found_teams) } - pub async fn clear_cache( - id: TeamId, - redis: &deadpool_redis::Pool, - ) -> Result<(), super::DatabaseError> { - let mut redis = redis.get().await?; - cmd("DEL") - .arg(format!("{}:{}", TEAMS_NAMESPACE, id.0)) - .query_async::<_, ()>(&mut redis) - .await?; - + pub async fn clear_cache(id: TeamId, redis: &RedisPool) -> Result<(), super::DatabaseError> { + redis.delete(TEAMS_NAMESPACE, id.0).await?; Ok(()) } diff --git a/src/database/models/thread_item.rs b/src/database/models/thread_item.rs index 091eece3..c81b2db4 100644 --- a/src/database/models/thread_item.rs +++ b/src/database/models/thread_item.rs @@ -2,7 +2,7 @@ use super::ids::*; use crate::database::models::DatabaseError; use crate::models::threads::{MessageBody, ThreadType}; use chrono::{DateTime, Utc}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; pub struct ThreadBuilder { pub type_: ThreadType, @@ -11,7 +11,7 @@ pub struct ThreadBuilder { pub report_id: Option, } -#[derive(Clone)] +#[derive(Clone, Serialize)] pub struct Thread { pub id: ThreadId, @@ -30,7 +30,7 @@ pub struct ThreadMessageBuilder { pub thread_id: ThreadId, } -#[derive(Deserialize, Clone)] +#[derive(Serialize, Deserialize, Clone)] pub struct ThreadMessage { pub id: ThreadMessageId, pub thread_id: ThreadId, diff --git a/src/database/models/user_item.rs b/src/database/models/user_item.rs index a46456e8..5f732e2c 100644 --- a/src/database/models/user_item.rs +++ b/src/database/models/user_item.rs @@ -1,17 +1,16 @@ use super::ids::{ProjectId, UserId}; use super::CollectionId; use crate::database::models::DatabaseError; +use crate::database::redis::RedisPool; use crate::models::ids::base62_impl::{parse_base62, to_base62}; use crate::models::users::{Badges, RecipientType, RecipientWallet}; use chrono::{DateTime, Utc}; -use redis::cmd; use rust_decimal::Decimal; use serde::{Deserialize, Serialize}; const USERS_NAMESPACE: &str = "users"; const USER_USERNAMES_NAMESPACE: &str = "users_usernames"; // const USERS_PROJECTS_NAMESPACE: &str = "users_projects"; -const DEFAULT_EXPIRY: i64 = 1800; // 30 minutes #[derive(Deserialize, Serialize, Clone, Debug)] pub struct User { @@ -87,7 +86,7 @@ impl User { pub async fn get<'a, 'b, E>( string: &str, executor: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -100,7 +99,7 @@ impl User { pub async fn get_id<'a, 'b, E>( id: UserId, executor: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -113,7 +112,7 @@ impl User { pub async fn get_many_ids<'a, E>( user_ids: &[UserId], exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -128,7 +127,7 @@ impl User { pub async fn get_many<'a, E, T: ToString>( users_strings: &[T], exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -139,8 +138,6 @@ impl User { return Ok(Vec::new()); } - let mut redis = redis.get().await?; - let mut found_users = Vec::new(); let mut remaining_strings = users_strings .iter() @@ -153,20 +150,11 @@ impl User { 
.collect::>(); user_ids.append( - &mut cmd("MGET") - .arg( - users_strings - .iter() - .map(|x| { - format!( - "{}:{}", - USER_USERNAMES_NAMESPACE, - x.to_string().to_lowercase() - ) - }) - .collect::>(), + &mut redis + .multi_get::( + USER_USERNAMES_NAMESPACE, + users_strings.iter().map(|x| x.to_string().to_lowercase()), ) - .query_async::<_, Vec>>(&mut redis) .await? .into_iter() .flatten() @@ -174,16 +162,9 @@ impl User { ); if !user_ids.is_empty() { - let users = cmd("MGET") - .arg( - user_ids - .iter() - .map(|x| format!("{}:{}", USERS_NAMESPACE, x)) - .collect::>(), - ) - .query_async::<_, Vec>>(&mut redis) + let users = redis + .multi_get::(USERS_NAMESPACE, user_ids) .await?; - for user in users { if let Some(user) = user.and_then(|x| serde_json::from_str::(&x).ok()) { remaining_strings.retain(|x| { @@ -252,24 +233,21 @@ impl User { .await?; for user in db_users { - cmd("SET") - .arg(format!("{}:{}", USERS_NAMESPACE, user.id.0)) - .arg(serde_json::to_string(&user)?) - .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set( + USERS_NAMESPACE, + user.id.0, + serde_json::to_string(&user)?, + None, + ) .await?; - - cmd("SET") - .arg(format!( - "{}:{}", + redis + .set( USER_USERNAMES_NAMESPACE, - user.username.to_lowercase() - )) - .arg(user.id.0) - .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + user.username.to_lowercase(), + user.id.0, + None, + ) .await?; found_users.push(user); } @@ -371,24 +349,19 @@ impl User { pub async fn clear_caches( user_ids: &[(UserId, Option)], - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result<(), DatabaseError> { - let mut redis = redis.get().await?; - let mut cmd = cmd("DEL"); - - for (id, username) in user_ids { - cmd.arg(format!("{}:{}", USERS_NAMESPACE, id.0)); - if let Some(username) = username { - cmd.arg(format!( - "{}:{}", - USER_USERNAMES_NAMESPACE, - username.to_lowercase() - )); - } - } - - cmd.query_async::<_, ()>(&mut redis).await?; - + redis + .delete_many(user_ids.into_iter().flat_map(|(id, username)| { + [ + (USERS_NAMESPACE, Some(id.0.to_string())), + ( + USER_USERNAMES_NAMESPACE, + username.clone().map(|i| i.to_lowercase()), + ), + ] + })) + .await?; Ok(()) } @@ -396,7 +369,7 @@ impl User { id: UserId, full: bool, transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> { let user = Self::get_id(id, &mut *transaction, redis).await?; diff --git a/src/database/models/version_item.rs b/src/database/models/version_item.rs index 9dcdeb3d..f917b20d 100644 --- a/src/database/models/version_item.rs +++ b/src/database/models/version_item.rs @@ -1,16 +1,16 @@ use super::ids::*; use super::DatabaseError; +use crate::database::redis::RedisPool; use crate::models::projects::{FileType, VersionStatus}; use chrono::{DateTime, Utc}; use itertools::Itertools; -use redis::cmd; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; use std::collections::HashMap; +use std::iter; const VERSIONS_NAMESPACE: &str = "versions"; const VERSION_FILES_NAMESPACE: &str = "versions_files"; -const DEFAULT_EXPIRY: i64 = 1800; // 30 minutes #[derive(Clone)] pub struct VersionBuilder { @@ -78,7 +78,7 @@ impl DependencyBuilder { } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct VersionFileBuilder { pub url: String, pub filename: String, @@ -130,7 +130,7 @@ impl VersionFileBuilder { } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct HashBuilder { pub algorithm: String, pub hash: Vec, @@ -263,7 +263,7 @@ impl 
Version { pub async fn remove_full( id: VersionId, - redis: &deadpool_redis::Pool, + redis: &RedisPool, transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, ) -> Result, DatabaseError> { let result = Self::get(id, &mut *transaction, redis).await?; @@ -398,7 +398,7 @@ impl Version { pub async fn get<'a, 'b, E>( id: VersionId, executor: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -411,7 +411,7 @@ impl Version { pub async fn get_many<'a, E>( version_ids: &[VersionId], exec: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, @@ -424,18 +424,10 @@ impl Version { let mut version_ids_parsed: Vec = version_ids.iter().map(|x| x.0).collect(); - let mut redis = redis.get().await?; - let mut found_versions = Vec::new(); - let versions = cmd("MGET") - .arg( - version_ids_parsed - .iter() - .map(|x| format!("{}:{}", VERSIONS_NAMESPACE, x)) - .collect::>(), - ) - .query_async::<_, Vec>>(&mut redis) + let versions = redis + .multi_get::(VERSIONS_NAMESPACE, version_ids_parsed.clone()) .await?; for version in versions { @@ -588,12 +580,13 @@ impl Version { .await?; for version in db_versions { - cmd("SET") - .arg(format!("{}:{}", VERSIONS_NAMESPACE, version.inner.id.0)) - .arg(serde_json::to_string(&version)?) - .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set( + VERSIONS_NAMESPACE, + version.inner.id.0, + serde_json::to_string(&version)?, + None, + ) .await?; found_versions.push(version); @@ -608,7 +601,7 @@ impl Version { hash: String, version_id: Option, executor: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres> + Copy, @@ -625,7 +618,7 @@ impl Version { algorithm: String, hashes: &[String], executor: E, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres> + Copy, @@ -638,18 +631,16 @@ impl Version { let mut file_ids_parsed = hashes.to_vec(); - let mut redis = redis.get().await?; - let mut found_files = Vec::new(); - let files = cmd("MGET") - .arg( + let files = redis + .multi_get::( + VERSION_FILES_NAMESPACE, file_ids_parsed .iter() - .map(|hash| format!("{}:{}_{}", VERSION_FILES_NAMESPACE, algorithm, hash)) + .map(|hash| format!("{}_{}", algorithm, hash)) .collect::>(), ) - .query_async::<_, Vec>>(&mut redis) .await?; for file in files { @@ -726,12 +717,13 @@ impl Version { } for (key, mut files) in save_files { - cmd("SET") - .arg(format!("{}:{}", VERSION_FILES_NAMESPACE, key)) - .arg(serde_json::to_string(&files)?) 
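Version files are cached under a composite `{algorithm}_{hash}` id, and the wrapper prepends `{meta_namespace}_{namespace}:` to every key. One side effect worth noting: with the production default of an empty `meta_namespace`, every new key starts with a bare underscore, so keys written by the old code (`versions_files:...`) are orphaned rather than reused, effectively a cold cache after deploy. A small self-checking sketch of the layout:

```rust
// Self-checking sketch of the key scheme in src/database/redis.rs below.
fn cache_key(meta_namespace: &str, namespace: &str, id: impl std::fmt::Display) -> String {
    format!("{}_{}:{}", meta_namespace, namespace, id)
}

fn main() {
    let id = format!("{}_{}", "sha1", "deadbeef"); // composite version-file id
    // Production passes meta_namespace = None, i.e. "" - note the leading '_'.
    assert_eq!(
        cache_key("", "versions_files", &id),
        "_versions_files:sha1_deadbeef"
    );
}
```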
- .arg("EX") - .arg(DEFAULT_EXPIRY) - .query_async::<_, ()>(&mut redis) + redis + .set( + VERSION_FILES_NAMESPACE, + key, + serde_json::to_string(&files)?, + None, + ) .await?; found_files.append(&mut files); @@ -743,22 +735,19 @@ impl Version { pub async fn clear_cache( version: &QueryVersion, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result<(), DatabaseError> { - let mut redis = redis.get().await?; - - let mut cmd = cmd("DEL"); - - cmd.arg(format!("{}:{}", VERSIONS_NAMESPACE, version.inner.id.0)); - - for file in &version.files { - for (algo, hash) in &file.hashes { - cmd.arg(format!("{}:{}_{}", VERSION_FILES_NAMESPACE, algo, hash)); - } - } - - cmd.query_async::<_, ()>(&mut redis).await?; - + redis + .delete_many( + iter::once((VERSIONS_NAMESPACE, Some(version.inner.id.0.to_string()))).chain( + version.files.iter().flat_map(|file| { + file.hashes.iter().map(|(algo, hash)| { + (VERSION_FILES_NAMESPACE, Some(format!("{}_{}", algo, hash))) + }) + }), + ), + ) + .await?; Ok(()) } } diff --git a/src/database/redis.rs b/src/database/redis.rs new file mode 100644 index 00000000..35a17c5f --- /dev/null +++ b/src/database/redis.rs @@ -0,0 +1,128 @@ +use super::models::DatabaseError; +use deadpool_redis::{Config, Runtime}; +use redis::{cmd, FromRedisValue, ToRedisArgs}; +use std::fmt::Display; + +const DEFAULT_EXPIRY: i64 = 1800; // 30 minutes + +#[derive(Clone)] +pub struct RedisPool { + pool: deadpool_redis::Pool, + meta_namespace: String, +} + +impl RedisPool { + // initiate a new redis pool + // testing pool uses a hashmap to mimic redis behaviour for very small data sizes (ie: tests) + // PANICS: production pool will panic if redis url is not set + pub fn new(meta_namespace: Option) -> Self { + let redis_pool = Config::from_url(dotenvy::var("REDIS_URL").expect("Redis URL not set")) + .builder() + .expect("Error building Redis pool") + .max_size( + dotenvy::var("DATABASE_MAX_CONNECTIONS") + .ok() + .and_then(|x| x.parse().ok()) + .unwrap_or(10000), + ) + .runtime(Runtime::Tokio1) + .build() + .expect("Redis connection failed"); + + RedisPool { + pool: redis_pool, + meta_namespace: meta_namespace.unwrap_or("".to_string()), + } + } + + pub async fn set( + &self, + namespace: &str, + id: T1, + data: T2, + expiry: Option, + ) -> Result<(), DatabaseError> + where + T1: Display, + T2: ToRedisArgs, + { + let mut redis_connection = self.pool.get().await?; + + cmd("SET") + .arg(format!("{}_{}:{}", self.meta_namespace, namespace, id)) + .arg(data) + .arg("EX") + .arg(expiry.unwrap_or(DEFAULT_EXPIRY)) + .query_async::<_, ()>(&mut redis_connection) + .await?; + + Ok(()) + } + + pub async fn get(&self, namespace: &str, id: T1) -> Result, DatabaseError> + where + T1: Display, + R: FromRedisValue, + { + let mut redis_connection = self.pool.get().await?; + + let res = cmd("GET") + .arg(format!("{}_{}:{}", self.meta_namespace, namespace, id)) + .query_async::<_, Option>(&mut redis_connection) + .await?; + Ok(res) + } + + pub async fn multi_get( + &self, + namespace: &str, + ids: impl IntoIterator, + ) -> Result>, DatabaseError> + where + T1: Display, + R: FromRedisValue, + { + let mut redis_connection = self.pool.get().await?; + let res = cmd("MGET") + .arg( + ids.into_iter() + .map(|x| format!("{}_{}:{}", self.meta_namespace, namespace, x)) + .collect::>(), + ) + .query_async::<_, Vec>>(&mut redis_connection) + .await?; + Ok(res) + } + + pub async fn delete(&self, namespace: &str, id: T1) -> Result<(), DatabaseError> + where + T1: Display, + { + let mut redis_connection = 
self.pool.get().await?; + + cmd("DEL") + .arg(format!("{}_{}:{}", self.meta_namespace, namespace, id)) + .query_async::<_, ()>(&mut redis_connection) + .await?; + + Ok(()) + } + + pub async fn delete_many( + &self, + iter: impl IntoIterator)>, + ) -> Result<(), DatabaseError> +where { + let mut redis_connection = self.pool.get().await?; + + let mut cmd = cmd("DEL"); + for (namespace, id) in iter { + if let Some(id) = id { + cmd.arg(format!("{}_{}:{}", self.meta_namespace, namespace, id)); + } + } + cmd.query_async::<_, ()>(&mut redis_connection).await?; + + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 00000000..01ff0bcd --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,413 @@ +use std::sync::Arc; + +use actix_web::web; +use database::redis::RedisPool; +use log::{info, warn}; +use queue::{ + analytics::AnalyticsQueue, download::DownloadQueue, payouts::PayoutsQueue, session::AuthQueue, + socket::ActiveSockets, +}; +use scheduler::Scheduler; +use sqlx::Postgres; +use tokio::sync::{Mutex, RwLock}; + +extern crate clickhouse as clickhouse_crate; +use clickhouse_crate::Client; +use util::cors::default_cors; + +use crate::{ + queue::payouts::process_payout, + search::indexing::index_projects, + util::env::{parse_strings_from_var, parse_var}, +}; + +pub mod auth; +pub mod clickhouse; +pub mod database; +pub mod file_hosting; +pub mod models; +pub mod queue; +pub mod ratelimit; +pub mod routes; +pub mod scheduler; +pub mod search; +pub mod util; +pub mod validate; + +#[derive(Clone)] +pub struct Pepper { + pub pepper: String, +} + +#[derive(Clone)] +pub struct LabrinthConfig { + pub pool: sqlx::Pool, + pub redis_pool: RedisPool, + pub clickhouse: Client, + pub file_host: Arc, + pub maxmind: Arc, + pub scheduler: Arc, + pub ip_salt: Pepper, + pub search_config: search::SearchConfig, + pub download_queue: web::Data, + pub session_queue: web::Data, + pub payouts_queue: web::Data>, + pub analytics_queue: Arc, + pub active_sockets: web::Data>, +} + +pub fn app_setup( + pool: sqlx::Pool, + redis_pool: RedisPool, + clickhouse: &mut Client, + file_host: Arc, + maxmind: Arc, +) -> LabrinthConfig { + info!( + "Starting Labrinth on {}", + dotenvy::var("BIND_ADDR").unwrap() + ); + + let search_config = search::SearchConfig { + address: dotenvy::var("MEILISEARCH_ADDR").unwrap(), + key: dotenvy::var("MEILISEARCH_KEY").unwrap(), + }; + + let mut scheduler = scheduler::Scheduler::new(); + + // The interval in seconds at which the local database is indexed + // for searching. Defaults to 1 hour if unset. 
+ let local_index_interval = + std::time::Duration::from_secs(parse_var("LOCAL_INDEX_INTERVAL").unwrap_or(3600)); + + let pool_ref = pool.clone(); + let search_config_ref = search_config.clone(); + scheduler.run(local_index_interval, move || { + let pool_ref = pool_ref.clone(); + let search_config_ref = search_config_ref.clone(); + async move { + info!("Indexing local database"); + let result = index_projects(pool_ref, &search_config_ref).await; + if let Err(e) = result { + warn!("Local project indexing failed: {:?}", e); + } + info!("Done indexing local database"); + } + }); + + // Changes statuses of scheduled projects/versions + let pool_ref = pool.clone(); + // TODO: Clear cache when these are run + scheduler.run(std::time::Duration::from_secs(60 * 5), move || { + let pool_ref = pool_ref.clone(); + info!("Releasing scheduled versions/projects!"); + + async move { + let projects_results = sqlx::query!( + " + UPDATE mods + SET status = requested_status + WHERE status = $1 AND approved < CURRENT_DATE AND requested_status IS NOT NULL + ", + crate::models::projects::ProjectStatus::Scheduled.as_str(), + ) + .execute(&pool_ref) + .await; + + if let Err(e) = projects_results { + warn!("Syncing scheduled releases for projects failed: {:?}", e); + } + + let versions_results = sqlx::query!( + " + UPDATE versions + SET status = requested_status + WHERE status = $1 AND date_published < CURRENT_DATE AND requested_status IS NOT NULL + ", + crate::models::projects::VersionStatus::Scheduled.as_str(), + ) + .execute(&pool_ref) + .await; + + if let Err(e) = versions_results { + warn!("Syncing scheduled releases for versions failed: {:?}", e); + } + + info!("Finished releasing scheduled versions/projects"); + } + }); + + scheduler::schedule_versions(&mut scheduler, pool.clone()); + + let download_queue = web::Data::new(DownloadQueue::new()); + + let pool_ref = pool.clone(); + let download_queue_ref = download_queue.clone(); + scheduler.run(std::time::Duration::from_secs(60 * 5), move || { + let pool_ref = pool_ref.clone(); + let download_queue_ref = download_queue_ref.clone(); + + async move { + info!("Indexing download queue"); + let result = download_queue_ref.index(&pool_ref).await; + if let Err(e) = result { + warn!("Indexing download queue failed: {:?}", e); + } + info!("Done indexing download queue"); + } + }); + + let session_queue = web::Data::new(AuthQueue::new()); + + let pool_ref = pool.clone(); + let redis_ref = redis_pool.clone(); + let session_queue_ref = session_queue.clone(); + scheduler.run(std::time::Duration::from_secs(60 * 30), move || { + let pool_ref = pool_ref.clone(); + let redis_ref = redis_ref.clone(); + let session_queue_ref = session_queue_ref.clone(); + + async move { + info!("Indexing sessions queue"); + let result = session_queue_ref.index(&pool_ref, &redis_ref).await; + if let Err(e) = result { + warn!("Indexing sessions queue failed: {:?}", e); + } + info!("Done indexing sessions queue"); + } + }); + + let reader = maxmind.clone(); + { + let reader_ref = reader.clone(); + scheduler.run(std::time::Duration::from_secs(60 * 60 * 24), move || { + let reader_ref = reader_ref.clone(); + + async move { + info!("Downloading MaxMind GeoLite2 country database"); + let result = reader_ref.index().await; + if let Err(e) = result { + warn!( + "Downloading MaxMind GeoLite2 country database failed: {:?}", + e + ); + } + info!("Done downloading MaxMind GeoLite2 country database"); + } + }); + } + info!("Downloading MaxMind GeoLite2 country database"); + + let analytics_queue = 
Arc::new(AnalyticsQueue::new()); + { + let client_ref = clickhouse.clone(); + let analytics_queue_ref = analytics_queue.clone(); + scheduler.run(std::time::Duration::from_secs(60 * 5), move || { + let client_ref = client_ref.clone(); + let analytics_queue_ref = analytics_queue_ref.clone(); + + async move { + info!("Indexing analytics queue"); + let result = analytics_queue_ref.index(client_ref).await; + if let Err(e) = result { + warn!("Indexing analytics queue failed: {:?}", e); + } + info!("Done indexing analytics queue"); + } + }); + } + + { + let pool_ref = pool.clone(); + let redis_ref = redis_pool.clone(); + let client_ref = clickhouse.clone(); + scheduler.run(std::time::Duration::from_secs(60 * 60 * 6), move || { + let pool_ref = pool_ref.clone(); + let redis_ref = redis_ref.clone(); + let client_ref = client_ref.clone(); + + async move { + info!("Started running payouts"); + let result = process_payout(&pool_ref, &redis_ref, &client_ref).await; + if let Err(e) = result { + warn!("Payouts run failed: {:?}", e); + } + info!("Done running payouts"); + } + }); + } + + let ip_salt = Pepper { + pepper: models::ids::Base62Id(models::ids::random_base62(11)).to_string(), + }; + + let payouts_queue = web::Data::new(Mutex::new(PayoutsQueue::new())); + let active_sockets = web::Data::new(RwLock::new(ActiveSockets::default())); + + LabrinthConfig { + pool, + redis_pool, + clickhouse: clickhouse.clone(), + file_host, + maxmind, + scheduler: Arc::new(scheduler), + ip_salt, + download_queue, + search_config, + session_queue, + payouts_queue, + analytics_queue, + active_sockets, + } +} + +pub fn app_config(cfg: &mut web::ServiceConfig, labrinth_config: LabrinthConfig) { + cfg.app_data( + web::FormConfig::default() + .error_handler(|err, _req| routes::ApiError::Validation(err.to_string()).into()), + ) + .app_data( + web::PathConfig::default() + .error_handler(|err, _req| routes::ApiError::Validation(err.to_string()).into()), + ) + .app_data( + web::QueryConfig::default() + .error_handler(|err, _req| routes::ApiError::Validation(err.to_string()).into()), + ) + .app_data( + web::JsonConfig::default() + .error_handler(|err, _req| routes::ApiError::Validation(err.to_string()).into()), + ) + .app_data(web::Data::new(labrinth_config.redis_pool.clone())) + .app_data(web::Data::new(labrinth_config.pool.clone())) + .app_data(web::Data::new(labrinth_config.file_host.clone())) + .app_data(web::Data::new(labrinth_config.search_config.clone())) + .app_data(labrinth_config.download_queue.clone()) + .app_data(labrinth_config.session_queue.clone()) + .app_data(labrinth_config.payouts_queue.clone()) + .app_data(web::Data::new(labrinth_config.ip_salt.clone())) + .app_data(web::Data::new(labrinth_config.analytics_queue.clone())) + .app_data(web::Data::new(labrinth_config.clickhouse.clone())) + .app_data(web::Data::new(labrinth_config.maxmind.clone())) + .app_data(labrinth_config.active_sockets.clone()) + .configure(routes::v2::config) + .configure(routes::v3::config) + .configure(routes::root_config) + .default_service(web::get().wrap(default_cors()).to(routes::not_found)); +} + +// This is so that env vars not used immediately don't panic at runtime +pub fn check_env_vars() -> bool { + let mut failed = false; + + fn check_var(var: &'static str) -> bool { + let check = parse_var::(var).is_none(); + if check { + warn!( + "Variable `{}` missing in dotenv or not of type `{}`", + var, + std::any::type_name::() + ); + } + check + } + + failed |= check_var::("SITE_URL"); + failed |= check_var::("CDN_URL"); + failed |= 
check_var::("LABRINTH_ADMIN_KEY"); + failed |= check_var::("RATE_LIMIT_IGNORE_KEY"); + failed |= check_var::("DATABASE_URL"); + failed |= check_var::("MEILISEARCH_ADDR"); + failed |= check_var::("MEILISEARCH_KEY"); + failed |= check_var::("REDIS_URL"); + failed |= check_var::("BIND_ADDR"); + failed |= check_var::("SELF_ADDR"); + + failed |= check_var::("STORAGE_BACKEND"); + + let storage_backend = dotenvy::var("STORAGE_BACKEND").ok(); + match storage_backend.as_deref() { + Some("backblaze") => { + failed |= check_var::("BACKBLAZE_KEY_ID"); + failed |= check_var::("BACKBLAZE_KEY"); + failed |= check_var::("BACKBLAZE_BUCKET_ID"); + } + Some("s3") => { + failed |= check_var::("S3_ACCESS_TOKEN"); + failed |= check_var::("S3_SECRET"); + failed |= check_var::("S3_URL"); + failed |= check_var::("S3_REGION"); + failed |= check_var::("S3_BUCKET_NAME"); + } + Some("local") => { + failed |= check_var::("MOCK_FILE_PATH"); + } + Some(backend) => { + warn!("Variable `STORAGE_BACKEND` contains an invalid value: {}. Expected \"backblaze\", \"s3\", or \"local\".", backend); + failed |= true; + } + _ => { + warn!("Variable `STORAGE_BACKEND` is not set!"); + failed |= true; + } + } + + failed |= check_var::("LOCAL_INDEX_INTERVAL"); + failed |= check_var::("VERSION_INDEX_INTERVAL"); + + if parse_strings_from_var("WHITELISTED_MODPACK_DOMAINS").is_none() { + warn!("Variable `WHITELISTED_MODPACK_DOMAINS` missing in dotenv or not a json array of strings"); + failed |= true; + } + + if parse_strings_from_var("ALLOWED_CALLBACK_URLS").is_none() { + warn!("Variable `ALLOWED_CALLBACK_URLS` missing in dotenv or not a json array of strings"); + failed |= true; + } + + failed |= check_var::("PAYPAL_API_URL"); + failed |= check_var::("PAYPAL_CLIENT_ID"); + failed |= check_var::("PAYPAL_CLIENT_SECRET"); + + failed |= check_var::("GITHUB_CLIENT_ID"); + failed |= check_var::("GITHUB_CLIENT_SECRET"); + failed |= check_var::("GITLAB_CLIENT_ID"); + failed |= check_var::("GITLAB_CLIENT_SECRET"); + failed |= check_var::("DISCORD_CLIENT_ID"); + failed |= check_var::("DISCORD_CLIENT_SECRET"); + failed |= check_var::("MICROSOFT_CLIENT_ID"); + failed |= check_var::("MICROSOFT_CLIENT_SECRET"); + failed |= check_var::("GOOGLE_CLIENT_ID"); + failed |= check_var::("GOOGLE_CLIENT_SECRET"); + failed |= check_var::("STEAM_API_KEY"); + + failed |= check_var::("TURNSTILE_SECRET"); + + failed |= check_var::("SMTP_USERNAME"); + failed |= check_var::("SMTP_PASSWORD"); + failed |= check_var::("SMTP_HOST"); + + failed |= check_var::("SITE_VERIFY_EMAIL_PATH"); + failed |= check_var::("SITE_RESET_PASSWORD_PATH"); + + failed |= check_var::("BEEHIIV_PUBLICATION_ID"); + failed |= check_var::("BEEHIIV_API_KEY"); + + if parse_strings_from_var("ANALYTICS_ALLOWED_ORIGINS").is_none() { + warn!( + "Variable `ANALYTICS_ALLOWED_ORIGINS` missing in dotenv or not a json array of strings" + ); + failed |= true; + } + + failed |= check_var::("CLICKHOUSE_URL"); + failed |= check_var::("CLICKHOUSE_USER"); + failed |= check_var::("CLICKHOUSE_PASSWORD"); + failed |= check_var::("CLICKHOUSE_DATABASE"); + + failed |= check_var::("MAXMIND_LICENSE_KEY"); + + failed |= check_var::("PAYOUTS_BUDGET"); + + failed +} diff --git a/src/main.rs b/src/main.rs index 5a6aed60..e0d0e0ff 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,34 +1,15 @@ -use crate::file_hosting::S3Host; -use crate::queue::analytics::AnalyticsQueue; -use crate::queue::download::DownloadQueue; -use crate::queue::payouts::{process_payout, PayoutsQueue}; -use crate::queue::session::AuthQueue; -use 
crate::queue::socket::ActiveSockets; -use crate::ratelimit::errors::ARError; -use crate::ratelimit::memory::{MemoryStore, MemoryStoreActor}; -use crate::ratelimit::middleware::RateLimiter; -use crate::util::cors::default_cors; -use crate::util::env::{parse_strings_from_var, parse_var}; -use actix_web::{web, App, HttpServer}; -use deadpool_redis::{Config, Runtime}; +use actix_web::{App, HttpServer}; use env_logger::Env; -use log::{error, info, warn}; -use search::indexing::index_projects; -use std::sync::Arc; -use tokio::sync::{Mutex, RwLock}; +use labrinth::database::redis::RedisPool; +use labrinth::file_hosting::S3Host; +use labrinth::ratelimit::errors::ARError; +use labrinth::ratelimit::memory::{MemoryStore, MemoryStoreActor}; +use labrinth::ratelimit::middleware::RateLimiter; +use labrinth::util::env::parse_var; +use labrinth::{check_env_vars, clickhouse, database, file_hosting, queue}; +use log::{error, info}; -mod auth; -mod clickhouse; -mod database; -mod file_hosting; -mod models; -mod queue; -mod ratelimit; -mod routes; -mod scheduler; -mod search; -mod util; -mod validate; +use std::sync::Arc; #[derive(Clone)] pub struct Pepper { @@ -63,11 +44,6 @@ async fn main() -> std::io::Result<()> { dotenvy::var("BIND_ADDR").unwrap() ); - let search_config = search::SearchConfig { - address: dotenvy::var("MEILISEARCH_ADDR").unwrap(), - key: dotenvy::var("MEILISEARCH_KEY").unwrap(), - }; - database::check_for_migrations() .await .expect("An error occurred while running migrations."); @@ -78,18 +54,7 @@ async fn main() -> std::io::Result<()> { .expect("Database connection failed"); // Redis connector - let redis_pool = Config::from_url(dotenvy::var("REDIS_URL").expect("Redis URL not set")) - .builder() - .expect("Error building Redis pool") - .max_size( - dotenvy::var("DATABASE_MAX_CONNECTIONS") - .ok() - .and_then(|x| x.parse().ok()) - .unwrap_or(10000), - ) - .runtime(Runtime::Tokio1) - .build() - .expect("Redis connection failed"); + let redis_pool = RedisPool::new(None); let storage_backend = dotenvy::var("STORAGE_BACKEND").unwrap_or_else(|_| "local".to_string()); @@ -116,184 +81,23 @@ async fn main() -> std::io::Result<()> { _ => panic!("Invalid storage backend specified. Aborting startup!"), }; - let mut scheduler = scheduler::Scheduler::new(); - - // The interval in seconds at which the local database is indexed - // for searching. Defaults to 1 hour if unset. 
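`RedisPool::new` takes the `meta_namespace` that prefixes every key; `main` passes `None` for production. A sketch of both call sites, where the test-style namespace is a hypothetical illustration (only the doc comment in `redis.rs` hints at test use):

```rust
use labrinth::database::redis::RedisPool;

fn main() {
    // Assumes REDIS_URL (and optionally DATABASE_MAX_CONNECTIONS) are set;
    // by design RedisPool::new panics otherwise, like the old inline builder.
    let _production = RedisPool::new(None);

    // Hypothetical test-side call: a unique meta namespace keeps keys from
    // parallel runs apart, since every key becomes "{meta}_{namespace}:{id}".
    let _tests = RedisPool::new(Some(format!("test-{}", std::process::id())));
}
```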
- let local_index_interval = - std::time::Duration::from_secs(parse_var("LOCAL_INDEX_INTERVAL").unwrap_or(3600)); - - let pool_ref = pool.clone(); - let search_config_ref = search_config.clone(); - scheduler.run(local_index_interval, move || { - let pool_ref = pool_ref.clone(); - let search_config_ref = search_config_ref.clone(); - async move { - info!("Indexing local database"); - let result = index_projects(pool_ref, &search_config_ref).await; - if let Err(e) = result { - warn!("Local project indexing failed: {:?}", e); - } - info!("Done indexing local database"); - } - }); - - // Changes statuses of scheduled projects/versions - let pool_ref = pool.clone(); - // TODO: Clear cache when these are run - scheduler.run(std::time::Duration::from_secs(60 * 5), move || { - let pool_ref = pool_ref.clone(); - info!("Releasing scheduled versions/projects!"); - - async move { - let projects_results = sqlx::query!( - " - UPDATE mods - SET status = requested_status - WHERE status = $1 AND approved < CURRENT_DATE AND requested_status IS NOT NULL - ", - crate::models::projects::ProjectStatus::Scheduled.as_str(), - ) - .execute(&pool_ref) - .await; - - if let Err(e) = projects_results { - warn!("Syncing scheduled releases for projects failed: {:?}", e); - } - - let versions_results = sqlx::query!( - " - UPDATE versions - SET status = requested_status - WHERE status = $1 AND date_published < CURRENT_DATE AND requested_status IS NOT NULL - ", - crate::models::projects::VersionStatus::Scheduled.as_str(), - ) - .execute(&pool_ref) - .await; - - if let Err(e) = versions_results { - warn!("Syncing scheduled releases for versions failed: {:?}", e); - } - - info!("Finished releasing scheduled versions/projects"); - } - }); - - scheduler::schedule_versions(&mut scheduler, pool.clone()); - - let download_queue = web::Data::new(DownloadQueue::new()); - - let pool_ref = pool.clone(); - let download_queue_ref = download_queue.clone(); - scheduler.run(std::time::Duration::from_secs(60 * 5), move || { - let pool_ref = pool_ref.clone(); - let download_queue_ref = download_queue_ref.clone(); - - async move { - info!("Indexing download queue"); - let result = download_queue_ref.index(&pool_ref).await; - if let Err(e) = result { - warn!("Indexing download queue failed: {:?}", e); - } - info!("Done indexing download queue"); - } - }); - - let session_queue = web::Data::new(AuthQueue::new()); - - let pool_ref = pool.clone(); - let redis_ref = redis_pool.clone(); - let session_queue_ref = session_queue.clone(); - scheduler.run(std::time::Duration::from_secs(60 * 30), move || { - let pool_ref = pool_ref.clone(); - let redis_ref = redis_ref.clone(); - let session_queue_ref = session_queue_ref.clone(); - - async move { - info!("Indexing sessions queue"); - let result = session_queue_ref.index(&pool_ref, &redis_ref).await; - if let Err(e) = result { - warn!("Indexing sessions queue failed: {:?}", e); - } - info!("Done indexing sessions queue"); - } - }); - info!("Initializing clickhouse connection"); - let clickhouse = clickhouse::init_client().await.unwrap(); - - let reader = Arc::new(queue::maxmind::MaxMindIndexer::new().await.unwrap()); - { - let reader_ref = reader.clone(); - scheduler.run(std::time::Duration::from_secs(60 * 60 * 24), move || { - let reader_ref = reader_ref.clone(); - - async move { - info!("Downloading MaxMind GeoLite2 country database"); - let result = reader_ref.index().await; - if let Err(e) = result { - warn!( - "Downloading MaxMind GeoLite2 country database failed: {:?}", - e - ); - } - info!("Done 
downloading MaxMind GeoLite2 country database"); - } - }); - } - info!("Downloading MaxMind GeoLite2 country database"); - - let analytics_queue = Arc::new(AnalyticsQueue::new()); - { - let client_ref = clickhouse.clone(); - let analytics_queue_ref = analytics_queue.clone(); - scheduler.run(std::time::Duration::from_secs(60 * 5), move || { - let client_ref = client_ref.clone(); - let analytics_queue_ref = analytics_queue_ref.clone(); - - async move { - info!("Indexing analytics queue"); - let result = analytics_queue_ref.index(client_ref).await; - if let Err(e) = result { - warn!("Indexing analytics queue failed: {:?}", e); - } - info!("Done indexing analytics queue"); - } - }); - } - - { - let pool_ref = pool.clone(); - let redis_ref = redis_pool.clone(); - let client_ref = clickhouse.clone(); - scheduler.run(std::time::Duration::from_secs(60 * 60 * 6), move || { - let pool_ref = pool_ref.clone(); - let redis_ref = redis_ref.clone(); - let client_ref = client_ref.clone(); - - async move { - info!("Started running payouts"); - let result = process_payout(&pool_ref, &redis_ref, &client_ref).await; - if let Err(e) = result { - warn!("Payouts run failed: {:?}", e); - } - info!("Done running payouts"); - } - }); - } - - let ip_salt = Pepper { - pepper: models::ids::Base62Id(models::ids::random_base62(11)).to_string(), - }; + let mut clickhouse = clickhouse::init_client().await.unwrap(); - let payouts_queue = web::Data::new(Mutex::new(PayoutsQueue::new())); - let active_sockets = web::Data::new(RwLock::new(ActiveSockets::default())); + let maxmind_reader = Arc::new(queue::maxmind::MaxMindIndexer::new().await.unwrap()); let store = MemoryStore::new(); info!("Starting Actix HTTP server!"); + let labrinth_config = labrinth::app_setup( + pool.clone(), + redis_pool.clone(), + &mut clickhouse, + file_host.clone(), + maxmind_reader.clone(), + ); + // Init App HttpServer::new(move || { App::new() @@ -320,160 +124,9 @@ async fn main() -> std::io::Result<()> { .with_ignore_key(dotenvy::var("RATE_LIMIT_IGNORE_KEY").ok()), ) .wrap(sentry_actix::Sentry::new()) - .app_data( - web::FormConfig::default().error_handler(|err, _req| { - routes::ApiError::Validation(err.to_string()).into() - }), - ) - .app_data( - web::PathConfig::default().error_handler(|err, _req| { - routes::ApiError::Validation(err.to_string()).into() - }), - ) - .app_data( - web::QueryConfig::default().error_handler(|err, _req| { - routes::ApiError::Validation(err.to_string()).into() - }), - ) - .app_data( - web::JsonConfig::default().error_handler(|err, _req| { - routes::ApiError::Validation(err.to_string()).into() - }), - ) - .app_data(web::Data::new(redis_pool.clone())) - .app_data(web::Data::new(pool.clone())) - .app_data(web::Data::new(file_host.clone())) - .app_data(web::Data::new(search_config.clone())) - .app_data(download_queue.clone()) - .app_data(session_queue.clone()) - .app_data(payouts_queue.clone()) - .app_data(web::Data::new(ip_salt.clone())) - .app_data(web::Data::new(analytics_queue.clone())) - .app_data(web::Data::new(clickhouse.clone())) - .app_data(web::Data::new(reader.clone())) - .app_data(active_sockets.clone()) - .configure(routes::v2::config) - .configure(routes::v3::config) - .configure(routes::root_config) - .default_service(web::get().wrap(default_cors()).to(routes::not_found)) + .configure(|cfg| labrinth::app_config(cfg, labrinth_config.clone())) }) .bind(dotenvy::var("BIND_ADDR").unwrap())? 
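With routing and app data moved behind `app_setup`/`app_config`, any other binary or test harness can mount the same route tree; the new `actix-http` dev-dependency and the CI `DATABASE_URL` suggest that is the intent. A hedged sketch of such an embedder (`serve` is hypothetical, not part of this diff):

```rust
use actix_web::{App, HttpServer};

// Hypothetical embedder: `config` comes from labrinth::app_setup(...) exactly
// as in main() above; the factory closure hands each worker its own clone.
async fn serve(config: labrinth::LabrinthConfig) -> std::io::Result<()> {
    HttpServer::new(move || {
        App::new().configure(|cfg| labrinth::app_config(cfg, config.clone()))
    })
    .bind(("127.0.0.1", 8080))?
    .run()
    .await
}
```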
.run() .await } - -// This is so that env vars not used immediately don't panic at runtime -fn check_env_vars() -> bool { - let mut failed = false; - - fn check_var(var: &'static str) -> bool { - let check = parse_var::(var).is_none(); - if check { - warn!( - "Variable `{}` missing in dotenv or not of type `{}`", - var, - std::any::type_name::() - ); - } - check - } - - failed |= check_var::("SITE_URL"); - failed |= check_var::("CDN_URL"); - failed |= check_var::("LABRINTH_ADMIN_KEY"); - failed |= check_var::("RATE_LIMIT_IGNORE_KEY"); - failed |= check_var::("DATABASE_URL"); - failed |= check_var::("MEILISEARCH_ADDR"); - failed |= check_var::("MEILISEARCH_KEY"); - failed |= check_var::("REDIS_URL"); - failed |= check_var::("BIND_ADDR"); - failed |= check_var::("SELF_ADDR"); - - failed |= check_var::("STORAGE_BACKEND"); - - let storage_backend = dotenvy::var("STORAGE_BACKEND").ok(); - match storage_backend.as_deref() { - Some("backblaze") => { - failed |= check_var::("BACKBLAZE_KEY_ID"); - failed |= check_var::("BACKBLAZE_KEY"); - failed |= check_var::("BACKBLAZE_BUCKET_ID"); - } - Some("s3") => { - failed |= check_var::("S3_ACCESS_TOKEN"); - failed |= check_var::("S3_SECRET"); - failed |= check_var::("S3_URL"); - failed |= check_var::("S3_REGION"); - failed |= check_var::("S3_BUCKET_NAME"); - } - Some("local") => { - failed |= check_var::("MOCK_FILE_PATH"); - } - Some(backend) => { - warn!("Variable `STORAGE_BACKEND` contains an invalid value: {}. Expected \"backblaze\", \"s3\", or \"local\".", backend); - failed |= true; - } - _ => { - warn!("Variable `STORAGE_BACKEND` is not set!"); - failed |= true; - } - } - - failed |= check_var::("LOCAL_INDEX_INTERVAL"); - failed |= check_var::("VERSION_INDEX_INTERVAL"); - - if parse_strings_from_var("WHITELISTED_MODPACK_DOMAINS").is_none() { - warn!("Variable `WHITELISTED_MODPACK_DOMAINS` missing in dotenv or not a json array of strings"); - failed |= true; - } - - if parse_strings_from_var("ALLOWED_CALLBACK_URLS").is_none() { - warn!("Variable `ALLOWED_CALLBACK_URLS` missing in dotenv or not a json array of strings"); - failed |= true; - } - - failed |= check_var::("PAYPAL_API_URL"); - failed |= check_var::("PAYPAL_CLIENT_ID"); - failed |= check_var::("PAYPAL_CLIENT_SECRET"); - - failed |= check_var::("GITHUB_CLIENT_ID"); - failed |= check_var::("GITHUB_CLIENT_SECRET"); - failed |= check_var::("GITLAB_CLIENT_ID"); - failed |= check_var::("GITLAB_CLIENT_SECRET"); - failed |= check_var::("DISCORD_CLIENT_ID"); - failed |= check_var::("DISCORD_CLIENT_SECRET"); - failed |= check_var::("MICROSOFT_CLIENT_ID"); - failed |= check_var::("MICROSOFT_CLIENT_SECRET"); - failed |= check_var::("GOOGLE_CLIENT_ID"); - failed |= check_var::("GOOGLE_CLIENT_SECRET"); - failed |= check_var::("STEAM_API_KEY"); - - failed |= check_var::("TURNSTILE_SECRET"); - - failed |= check_var::("SMTP_USERNAME"); - failed |= check_var::("SMTP_PASSWORD"); - failed |= check_var::("SMTP_HOST"); - - failed |= check_var::("SITE_VERIFY_EMAIL_PATH"); - failed |= check_var::("SITE_RESET_PASSWORD_PATH"); - - failed |= check_var::("BEEHIIV_PUBLICATION_ID"); - failed |= check_var::("BEEHIIV_API_KEY"); - - if parse_strings_from_var("ANALYTICS_ALLOWED_ORIGINS").is_none() { - warn!( - "Variable `ANALYTICS_ALLOWED_ORIGINS` missing in dotenv or not a json array of strings" - ); - failed |= true; - } - - failed |= check_var::("CLICKHOUSE_URL"); - failed |= check_var::("CLICKHOUSE_USER"); - failed |= check_var::("CLICKHOUSE_PASSWORD"); - failed |= check_var::("CLICKHOUSE_DATABASE"); - - failed |= 
check_var::("MAXMIND_LICENSE_KEY"); - - failed |= check_var::("PAYOUTS_BUDGET"); - - failed -} diff --git a/src/models/pack.rs b/src/models/pack.rs index 67bb7c26..682d40dc 100644 --- a/src/models/pack.rs +++ b/src/models/pack.rs @@ -1,5 +1,4 @@ -use crate::models::projects::SideType; -use crate::parse_strings_from_var; +use crate::{models::projects::SideType, util::env::parse_strings_from_var}; use serde::{Deserialize, Serialize}; use validator::Validate; diff --git a/src/models/pats.rs b/src/models/pats.rs index 313a7614..5d3f65ca 100644 --- a/src/models/pats.rs +++ b/src/models/pats.rs @@ -51,7 +51,7 @@ bitflags::bitflags! { const VERSION_READ = 1 << 15; // write to a version's data (metadata, files, etc) const VERSION_WRITE = 1 << 16; - // delete a project + // delete a version const VERSION_DELETE = 1 << 17; // create a report @@ -103,26 +103,26 @@ bitflags::bitflags! { // delete an organization const ORGANIZATION_DELETE = 1 << 38; - const ALL = 0b111111111111111111111111111111111111111; - const NOT_RESTRICTED = 0b1111111100000011111111111111100111; const NONE = 0b0; } } impl Scopes { // these scopes cannot be specified in a personal access token - pub fn restricted(&self) -> bool { - self.contains( - Scopes::PAT_CREATE - | Scopes::PAT_READ - | Scopes::PAT_WRITE - | Scopes::PAT_DELETE - | Scopes::SESSION_READ - | Scopes::SESSION_DELETE - | Scopes::USER_AUTH_WRITE - | Scopes::USER_DELETE - | Scopes::PERFORM_ANALYTICS, - ) + pub fn restricted() -> Scopes { + Scopes::PAT_CREATE + | Scopes::PAT_READ + | Scopes::PAT_WRITE + | Scopes::PAT_DELETE + | Scopes::SESSION_READ + | Scopes::SESSION_DELETE + | Scopes::USER_AUTH_WRITE + | Scopes::USER_DELETE + | Scopes::PERFORM_ANALYTICS + } + + pub fn is_restricted(&self) -> bool { + self.intersects(Self::restricted()) } } diff --git a/src/models/users.rs b/src/models/users.rs index 7b1a2a98..4b2a0e90 100644 --- a/src/models/users.rs +++ b/src/models/users.rs @@ -4,7 +4,7 @@ use chrono::{DateTime, Utc}; use rust_decimal::Decimal; use serde::{Deserialize, Serialize}; -#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Debug)] #[serde(from = "Base62Id")] #[serde(into = "Base62Id")] pub struct UserId(pub u64); @@ -35,7 +35,7 @@ impl Default for Badges { } } -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct User { pub id: UserId, pub username: String, @@ -57,7 +57,7 @@ pub struct User { pub github_id: Option, } -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct UserPayoutData { pub balance: Decimal, pub payout_wallet: Option, @@ -156,7 +156,7 @@ impl From for User { } } -#[derive(Serialize, Deserialize, PartialEq, Eq, Clone)] +#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)] #[serde(rename_all = "lowercase")] pub enum Role { Developer, diff --git a/src/queue/payouts.rs b/src/queue/payouts.rs index 57df7054..73924e9d 100644 --- a/src/queue/payouts.rs +++ b/src/queue/payouts.rs @@ -1,6 +1,6 @@ -use crate::models::projects::MonetizationStatus; use crate::routes::ApiError; use crate::util::env::parse_var; +use crate::{database::redis::RedisPool, models::projects::MonetizationStatus}; use base64::Engine; use chrono::{DateTime, Datelike, Duration, Utc, Weekday}; use rust_decimal::Decimal; @@ -203,7 +203,7 @@ impl PayoutsQueue { pub async fn process_payout( pool: &PgPool, - redis: &deadpool_redis::Pool, + redis: &RedisPool, client: &clickhouse::Client, ) -> Result<(), 
ApiError> { let start: DateTime = DateTime::from_utc( diff --git a/src/queue/session.rs b/src/queue/session.rs index eb76ec39..bbc2896e 100644 --- a/src/queue/session.rs +++ b/src/queue/session.rs @@ -2,6 +2,7 @@ use crate::auth::session::SessionMetadata; use crate::database::models::pat_item::PersonalAccessToken; use crate::database::models::session_item::Session; use crate::database::models::{DatabaseError, PatId, SessionId, UserId}; +use crate::database::redis::RedisPool; use chrono::Utc; use sqlx::PgPool; use std::collections::{HashMap, HashSet}; @@ -42,11 +43,7 @@ impl AuthQueue { std::mem::replace(&mut queue, HashSet::with_capacity(len)) } - pub async fn index( - &self, - pool: &PgPool, - redis: &deadpool_redis::Pool, - ) -> Result<(), DatabaseError> { + pub async fn index(&self, pool: &PgPool, redis: &RedisPool) -> Result<(), DatabaseError> { let session_queue = self.take_sessions().await; let pat_queue = self.take_pats().await; diff --git a/src/ratelimit/memory.rs b/src/ratelimit/memory.rs index 60c4abf0..2e786835 100644 --- a/src/ratelimit/memory.rs +++ b/src/ratelimit/memory.rs @@ -20,7 +20,7 @@ impl MemoryStore { /// /// # Example /// ```rust - /// use actix_ratelimit::MemoryStore; + /// use labrinth::ratelimit::memory::MemoryStore; /// /// let store = MemoryStore::new(); /// ``` diff --git a/src/routes/analytics.rs b/src/routes/analytics.rs index 04203291..5e06b4c5 100644 --- a/src/routes/analytics.rs +++ b/src/routes/analytics.rs @@ -1,11 +1,12 @@ use crate::auth::get_user_from_headers; +use crate::database::redis::RedisPool; use crate::models::analytics::{PageView, Playtime}; use crate::models::pats::Scopes; +use crate::queue::analytics::AnalyticsQueue; use crate::queue::maxmind::MaxMindIndexer; use crate::queue::session::AuthQueue; use crate::routes::ApiError; use crate::util::env::parse_strings_from_var; -use crate::AnalyticsQueue; use actix_web::{post, web}; use actix_web::{HttpRequest, HttpResponse}; use chrono::Utc; @@ -63,7 +64,7 @@ pub async fn page_view_ingest( session_queue: web::Data, url_input: web::Json, pool: web::Data, - redis: web::Data, + redis: web::Data, ) -> Result { let user = get_user_from_headers(&req, &**pool, &redis, &session_queue, None) .await @@ -169,7 +170,7 @@ pub async fn playtime_ingest( session_queue: web::Data, playtime_input: web::Json>, pool: web::Data, - redis: web::Data, + redis: web::Data, ) -> Result { let (_, user) = get_user_from_headers( &req, diff --git a/src/routes/maven.rs b/src/routes/maven.rs index f8d0927e..e5641106 100644 --- a/src/routes/maven.rs +++ b/src/routes/maven.rs @@ -1,6 +1,7 @@ use crate::database::models::categories::Loader; use crate::database::models::project_item::QueryProject; use crate::database::models::version_item::{QueryFile, QueryVersion}; +use crate::database::redis::RedisPool; use crate::models::pats::Scopes; use crate::models::projects::{ProjectId, VersionId}; use crate::queue::session::AuthQueue; @@ -71,7 +72,7 @@ pub async fn maven_metadata( req: HttpRequest, params: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let project_id = params.into_inner().0; @@ -156,7 +157,7 @@ async fn find_version( project: &QueryProject, vcoords: &String, pool: &PgPool, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, ApiError> { let id_option = crate::models::ids::base62_impl::parse_base62(vcoords) .ok() @@ -245,7 +246,7 @@ pub async fn version_file( req: HttpRequest, params: web::Path<(String, String, String)>, pool: web::Data, - 
redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let (project_id, vnum, file) = params.into_inner(); @@ -306,7 +307,7 @@ pub async fn version_file_sha1( req: HttpRequest, params: web::Path<(String, String, String)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let (project_id, vnum, file) = params.into_inner(); @@ -348,7 +349,7 @@ pub async fn version_file_sha512( req: HttpRequest, params: web::Path<(String, String, String)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let (project_id, vnum, file) = params.into_inner(); diff --git a/src/routes/updates.rs b/src/routes/updates.rs index d3674f35..004621a9 100644 --- a/src/routes/updates.rs +++ b/src/routes/updates.rs @@ -6,6 +6,7 @@ use sqlx::PgPool; use crate::auth::{filter_authorized_versions, get_user_from_headers, is_authorized}; use crate::database; +use crate::database::redis::RedisPool; use crate::models::pats::Scopes; use crate::models::projects::VersionType; use crate::queue::session::AuthQueue; @@ -32,7 +33,7 @@ pub async fn forge_updates( web::Query(neo): web::Query, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { const ERROR: &str = "The specified project does not exist!"; diff --git a/src/routes/v2/admin.rs b/src/routes/v2/admin.rs index 4b0f193f..be4db052 100644 --- a/src/routes/v2/admin.rs +++ b/src/routes/v2/admin.rs @@ -1,13 +1,14 @@ use crate::auth::validate::get_user_record_from_bearer_token; +use crate::database::redis::RedisPool; use crate::models::analytics::Download; use crate::models::ids::ProjectId; use crate::models::pats::Scopes; use crate::queue::analytics::AnalyticsQueue; +use crate::queue::download::DownloadQueue; use crate::queue::maxmind::MaxMindIndexer; use crate::queue::session::AuthQueue; use crate::routes::ApiError; use crate::util::guards::admin_key_guard; -use crate::DownloadQueue; use actix_web::{patch, web, HttpRequest, HttpResponse}; use chrono::Utc; use serde::Deserialize; @@ -37,7 +38,7 @@ pub struct DownloadBody { pub async fn count_download( req: HttpRequest, pool: web::Data, - redis: web::Data, + redis: web::Data, maxmind: web::Data>, analytics_queue: web::Data>, session_queue: web::Data, diff --git a/src/routes/v2/analytics_get.rs b/src/routes/v2/analytics_get.rs index d09932a9..11d0f293 100644 --- a/src/routes/v2/analytics_get.rs +++ b/src/routes/v2/analytics_get.rs @@ -1,10 +1,5 @@ use super::ApiError; -use actix_web::{get, web, HttpRequest, HttpResponse}; -use chrono::{Duration, NaiveDate, Utc}; -use serde::{Deserialize, Serialize}; -use sqlx::PgPool; -use std::collections::HashMap; - +use crate::database::redis::RedisPool; use crate::{ auth::{filter_authorized_projects, filter_authorized_versions, get_user_from_headers}, database::models::{project_item, user_item, version_item}, @@ -17,6 +12,11 @@ use crate::{ }, queue::session::AuthQueue, }; +use actix_web::{get, web, HttpRequest, HttpResponse}; +use chrono::{Duration, NaiveDate, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::PgPool; +use std::collections::HashMap; pub fn config(cfg: &mut web::ServiceConfig) { cfg.service( @@ -70,7 +70,7 @@ pub async fn playtimes_get( data: web::Query, session_queue: web::Data, pool: web::Data, - redis: web::Data, + redis: web::Data, ) -> Result { let user_option = get_user_from_headers( &req, @@ -153,7 +153,7 @@ pub async fn views_get( data: web::Query, session_queue: web::Data, pool: 
web::Data, - redis: web::Data, + redis: web::Data, ) -> Result { let user_option = get_user_from_headers( &req, @@ -236,7 +236,7 @@ pub async fn downloads_get( data: web::Query, session_queue: web::Data, pool: web::Data, - redis: web::Data, + redis: web::Data, ) -> Result { let user_option = get_user_from_headers( &req, @@ -322,7 +322,7 @@ pub async fn countries_downloads_get( data: web::Query, session_queue: web::Data, pool: web::Data, - redis: web::Data, + redis: web::Data, ) -> Result { let user_option = get_user_from_headers( &req, @@ -406,7 +406,7 @@ pub async fn countries_views_get( data: web::Query, session_queue: web::Data, pool: web::Data, - redis: web::Data, + redis: web::Data, ) -> Result { let user_option = get_user_from_headers( &req, @@ -476,7 +476,7 @@ async fn filter_allowed_ids( version_ids: Option>, user_option: Option, pool: &web::Data, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result<(Option>, Option>), ApiError> { if project_ids.is_some() && version_ids.is_some() { return Err(ApiError::InvalidInput( diff --git a/src/routes/v2/collections.rs b/src/routes/v2/collections.rs index b4920588..01372b0e 100644 --- a/src/routes/v2/collections.rs +++ b/src/routes/v2/collections.rs @@ -1,7 +1,7 @@ use crate::auth::checks::{filter_authorized_collections, is_authorized_collection}; use crate::auth::get_user_from_headers; -use crate::database; use crate::database::models::{collection_item, generate_collection_id, project_item}; +use crate::database::redis::RedisPool; use crate::file_hosting::FileHost; use crate::models::collections::{Collection, CollectionStatus}; use crate::models::ids::base62_impl::parse_base62; @@ -11,6 +11,7 @@ use crate::queue::session::AuthQueue; use crate::routes::ApiError; use crate::util::routes::read_from_payload; use crate::util::validate::validation_errors_to_string; +use crate::{database, models}; use actix_web::web::Data; use actix_web::{delete, get, patch, post, web, HttpRequest, HttpResponse}; use chrono::Utc; @@ -56,7 +57,7 @@ pub async fn collection_create( req: HttpRequest, collection_create_data: web::Json, client: Data, - redis: Data, + redis: Data, session_queue: Data, ) -> Result { let collection_create_data = collection_create_data.into_inner(); @@ -130,7 +131,7 @@ pub async fn collections_get( req: HttpRequest, web::Query(ids): web::Query, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let ids = serde_json::from_str::>(&ids.ids)?; @@ -162,7 +163,7 @@ pub async fn collection_get( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let string = info.into_inner().0; @@ -208,19 +209,18 @@ pub async fn collection_edit( info: web::Path<(String,)>, pool: web::Data, new_collection: web::Json, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { - let user_option = get_user_from_headers( + let user = get_user_from_headers( &req, &**pool, &redis, &session_queue, Some(&[Scopes::COLLECTION_WRITE]), ) - .await - .map(|x| x.1) - .ok(); + .await? + .1; new_collection .validate() @@ -231,7 +231,7 @@ pub async fn collection_edit( let result = database::models::Collection::get(id, &**pool, &redis).await?; if let Some(collection_item) = result { - if !is_authorized_collection(&collection_item, &user_option).await? 
{ + if !can_modify_collection(&collection_item, &user) { return Ok(HttpResponse::Unauthorized().body("")); } @@ -268,27 +268,25 @@ pub async fn collection_edit( } if let Some(status) = &new_collection.status { - if let Some(user) = user_option { - if !(user.role.is_mod() - || collection_item.status.is_approved() && status.can_be_requested()) - { - return Err(ApiError::CustomAuthentication( - "You don't have permission to set this status!".to_string(), - )); - } - - sqlx::query!( - " - UPDATE collections - SET status = $1 - WHERE (id = $2) - ", - status.to_string(), - id as database::models::ids::CollectionId, - ) - .execute(&mut *transaction) - .await?; + if !(user.role.is_mod() + || collection_item.status.is_approved() && status.can_be_requested()) + { + return Err(ApiError::CustomAuthentication( + "You don't have permission to set this status!".to_string(), + )); } + + sqlx::query!( + " + UPDATE collections + SET status = $1 + WHERE (id = $2) + ", + status.to_string(), + id as database::models::ids::CollectionId, + ) + .execute(&mut *transaction) + .await?; } if let Some(new_project_ids) = &new_collection.new_projects { @@ -348,23 +346,22 @@ pub async fn collection_icon_edit( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, file_host: web::Data>, mut payload: web::Payload, session_queue: web::Data, ) -> Result { if let Some(content_type) = crate::util::ext::get_image_content_type(&ext.ext) { let cdn_url = dotenvy::var("CDN_URL")?; - let user_option = get_user_from_headers( + let user = get_user_from_headers( &req, &**pool, &redis, &session_queue, Some(&[Scopes::COLLECTION_WRITE]), ) - .await - .map(|x| x.1) - .ok(); + .await? + .1; let string = info.into_inner().0; let id = database::models::CollectionId(parse_base62(&string)? as i64); @@ -374,7 +371,7 @@ pub async fn collection_icon_edit( ApiError::InvalidInput("The specified collection does not exist!".to_string()) })?; - if !is_authorized_collection(&collection_item, &user_option).await? { + if !can_modify_collection(&collection_item, &user) { return Ok(HttpResponse::Unauthorized().body("")); } @@ -434,20 +431,20 @@ pub async fn delete_collection_icon( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, file_host: web::Data>, session_queue: web::Data, ) -> Result { - let user_option = get_user_from_headers( + let user = get_user_from_headers( &req, &**pool, &redis, &session_queue, Some(&[Scopes::COLLECTION_WRITE]), ) - .await - .map(|x| x.1) - .ok(); + .await? + .1; + let string = info.into_inner().0; let id = database::models::CollectionId(parse_base62(&string)? as i64); let collection_item = database::models::Collection::get(id, &**pool, &redis) @@ -455,7 +452,7 @@ pub async fn delete_collection_icon( .ok_or_else(|| { ApiError::InvalidInput("The specified collection does not exist!".to_string()) })?; - if !is_authorized_collection(&collection_item, &user_option).await? { + if !can_modify_collection(&collection_item, &user) { return Ok(HttpResponse::Unauthorized().body("")); } @@ -493,19 +490,18 @@ pub async fn collection_delete( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { - let user_option = get_user_from_headers( + let user = get_user_from_headers( &req, &**pool, &redis, &session_queue, Some(&[Scopes::COLLECTION_DELETE]), ) - .await - .map(|x| x.1) - .ok(); + .await? 
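These collection hunks tighten authentication: the old code mapped a failed token check to `None` and deferred to `is_authorized_collection`, while the new code propagates the failure (a 401) before touching the database, then applies the pure owner-or-moderator predicate `can_modify_collection` defined at the end of the file. A stand-alone sketch of that predicate, with stand-in types for the real models:

```rust
// Stand-ins for database::models::Collection and models::users::User.
struct User { id: i64, is_mod: bool }
struct Collection { user_id: i64 }

// Mirrors can_modify_collection below: the owner or any moderator may edit.
fn can_modify_collection(collection: &Collection, user: &User) -> bool {
    collection.user_id == user.id || user.is_mod
}

fn main() {
    let owner = User { id: 7, is_mod: false };
    let moderator = User { id: 1, is_mod: true };
    let stranger = User { id: 9, is_mod: false };
    let collection = Collection { user_id: 7 };

    assert!(can_modify_collection(&collection, &owner));
    assert!(can_modify_collection(&collection, &moderator));
    assert!(!can_modify_collection(&collection, &stranger));
}
```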
+ .1; let string = info.into_inner().0; let id = database::models::CollectionId(parse_base62(&string)? as i64); @@ -514,7 +510,7 @@ pub async fn collection_delete( .ok_or_else(|| { ApiError::InvalidInput("The specified collection does not exist!".to_string()) })?; - if !is_authorized_collection(&collection, &user_option).await? { + if !can_modify_collection(&collection, &user) { return Ok(HttpResponse::Unauthorized().body("")); } let mut transaction = pool.begin().await?; @@ -531,3 +527,10 @@ pub async fn collection_delete( Ok(HttpResponse::NotFound().body("")) } } + +fn can_modify_collection( + collection: &database::models::Collection, + user: &models::users::User, +) -> bool { + collection.user_id == user.id.into() || user.role.is_mod() +} diff --git a/src/routes/v2/images.rs b/src/routes/v2/images.rs index a945d1e7..0d1eecbb 100644 --- a/src/routes/v2/images.rs +++ b/src/routes/v2/images.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use crate::auth::{get_user_from_headers, is_authorized, is_authorized_version}; use crate::database; use crate::database::models::{project_item, report_item, thread_item, version_item}; +use crate::database::redis::RedisPool; use crate::file_hosting::FileHost; use crate::models::ids::{ThreadMessageId, VersionId}; use crate::models::images::{Image, ImageContext}; @@ -41,7 +42,7 @@ pub async fn images_add( file_host: web::Data>, mut payload: web::Payload, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { if let Some(content_type) = crate::util::ext::get_image_content_type(&data.ext) { diff --git a/src/routes/v2/moderation.rs b/src/routes/v2/moderation.rs index e1d6e995..ebebf654 100644 --- a/src/routes/v2/moderation.rs +++ b/src/routes/v2/moderation.rs @@ -1,8 +1,9 @@ use super::ApiError; -use crate::auth::check_is_moderator_from_headers; use crate::database; +use crate::database::redis::RedisPool; use crate::models::projects::ProjectStatus; use crate::queue::session::AuthQueue; +use crate::{auth::check_is_moderator_from_headers, models::pats::Scopes}; use actix_web::{get, web, HttpRequest, HttpResponse}; use serde::Deserialize; use sqlx::PgPool; @@ -25,11 +26,18 @@ fn default_count() -> i16 { pub async fn get_projects( req: HttpRequest, pool: web::Data, - redis: web::Data, + redis: web::Data, count: web::Query, session_queue: web::Data, ) -> Result { - check_is_moderator_from_headers(&req, &**pool, &redis, &session_queue).await?; + check_is_moderator_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::PROJECT_READ]), + ) + .await?; use futures::stream::TryStreamExt; diff --git a/src/routes/v2/notifications.rs b/src/routes/v2/notifications.rs index b0a0940d..8923de57 100644 --- a/src/routes/v2/notifications.rs +++ b/src/routes/v2/notifications.rs @@ -1,5 +1,6 @@ use crate::auth::get_user_from_headers; use crate::database; +use crate::database::redis::RedisPool; use crate::models::ids::NotificationId; use crate::models::notifications::Notification; use crate::models::pats::Scopes; @@ -17,7 +18,7 @@ pub fn config(cfg: &mut web::ServiceConfig) { cfg.service( web::scope("notification") .service(notification_get) - .service(notifications_read) + .service(notification_read) .service(notification_delete), ); } @@ -32,7 +33,7 @@ pub async fn notifications_get( req: HttpRequest, web::Query(ids): web::Query, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( @@ -72,7 +73,7 @@ pub async fn notification_get( req: HttpRequest, 
info: web::Path<(NotificationId,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( @@ -106,7 +107,7 @@ pub async fn notification_read( req: HttpRequest, info: web::Path<(NotificationId,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( @@ -149,7 +150,7 @@ pub async fn notification_delete( req: HttpRequest, info: web::Path<(NotificationId,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( @@ -192,7 +193,7 @@ pub async fn notifications_read( req: HttpRequest, web::Query(ids): web::Query, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( @@ -237,7 +238,7 @@ pub async fn notifications_delete( req: HttpRequest, web::Query(ids): web::Query, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( diff --git a/src/routes/v2/organizations.rs b/src/routes/v2/organizations.rs index 427fe5ee..754d1a1e 100644 --- a/src/routes/v2/organizations.rs +++ b/src/routes/v2/organizations.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use crate::auth::{filter_authorized_projects, get_user_from_headers}; use crate::database::models::team_item::TeamMember; use crate::database::models::{generate_organization_id, team_item, Organization}; +use crate::database::redis::RedisPool; use crate::file_hosting::FileHost; use crate::models::ids::base62_impl::parse_base62; use crate::models::organizations::OrganizationId; @@ -39,16 +40,14 @@ pub fn config(cfg: &mut web::ServiceConfig) { #[derive(Deserialize, Validate)] pub struct NewOrganization { - #[validate(length(min = 3, max = 256))] - pub description: String, #[validate( length(min = 3, max = 64), regex = "crate::util::validate::RE_URL_SAFE" )] // Title of the organization, also used as slug pub title: String, - #[serde(default = "crate::models::teams::ProjectPermissions::default")] - pub default_project_permissions: ProjectPermissions, + #[validate(length(min = 3, max = 256))] + pub description: String, } #[post("organization")] @@ -56,7 +55,7 @@ pub async fn organization_create( req: HttpRequest, new_organization: web::Json, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let current_user = get_user_from_headers( @@ -143,7 +142,7 @@ pub async fn organization_get( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let id = info.into_inner().0; @@ -208,7 +207,7 @@ pub async fn organizations_get( req: HttpRequest, web::Query(ids): web::Query, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let ids = serde_json::from_str::>(&ids.ids)?; @@ -289,7 +288,6 @@ pub struct OrganizationEdit { )] // Title of the organization, also used as slug pub title: Option, - pub default_project_permissions: Option, } #[patch("{id}")] @@ -298,7 +296,7 @@ pub async fn organizations_edit( info: web::Path<(String,)>, new_organization: web::Json, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( @@ -434,7 +432,7 @@ pub async fn organization_delete( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: 
web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( @@ -498,7 +496,7 @@ pub async fn organization_projects_get( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let info = info.into_inner().0; @@ -507,7 +505,7 @@ pub async fn organization_projects_get( &**pool, &redis, &session_queue, - Some(&[Scopes::ORGANIZATION_READ]), + Some(&[Scopes::ORGANIZATION_READ, Scopes::PROJECT_READ]), ) .await .map(|x| x.1) @@ -519,7 +517,7 @@ pub async fn organization_projects_get( let project_ids = sqlx::query!( " SELECT m.id FROM organizations o - LEFT JOIN mods m ON m.id = o.id + INNER JOIN mods m ON m.organization_id = o.id WHERE (o.id = $1 AND $1 IS NOT NULL) OR (o.title = $2 AND $2 IS NOT NULL) ", possible_organization_id.map(|x| x as i64), @@ -547,7 +545,7 @@ pub async fn organization_projects_add( info: web::Path<(String,)>, project_info: web::Json, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let info = info.into_inner().0; @@ -649,7 +647,7 @@ pub async fn organization_projects_remove( req: HttpRequest, info: web::Path<(String, String)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let (organization_id, project_id) = info.into_inner(); @@ -743,7 +741,7 @@ pub async fn organization_icon_edit( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, file_host: web::Data>, mut payload: web::Payload, session_queue: web::Data, @@ -848,7 +846,7 @@ pub async fn delete_organization_icon( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, file_host: web::Data>, session_queue: web::Data, ) -> Result { diff --git a/src/routes/v2/project_creation.rs b/src/routes/v2/project_creation.rs index a0e057d8..0eed7c25 100644 --- a/src/routes/v2/project_creation.rs +++ b/src/routes/v2/project_creation.rs @@ -2,6 +2,7 @@ use super::version_creation::InitialVersionData; use crate::auth::{get_user_from_headers, AuthenticationError}; use crate::database::models::thread_item::ThreadBuilder; use crate::database::models::{self, image_item}; +use crate::database::redis::RedisPool; use crate::file_hosting::{FileHost, FileHostingError}; use crate::models::error::ApiError; use crate::models::ids::ImageId; @@ -283,7 +284,7 @@ pub async fn project_create( req: HttpRequest, mut payload: Multipart, client: Data, - redis: Data, + redis: Data, file_host: Data>, session_queue: Data, ) -> Result { @@ -354,7 +355,7 @@ async fn project_create_inner( file_host: &dyn FileHost, uploaded_files: &mut Vec, pool: &PgPool, - redis: &deadpool_redis::Pool, + redis: &RedisPool, session_queue: &AuthQueue, ) -> Result { // The base URL for files uploaded to backblaze @@ -405,7 +406,6 @@ async fn project_create_inner( "`data` field must come before file fields", ))); } - let mut data = Vec::new(); while let Some(chunk) = field.next().await { data.extend_from_slice(&chunk.map_err(CreateError::MultipartError)?); diff --git a/src/routes/v2/projects.rs b/src/routes/v2/projects.rs index 4d30b207..50967487 100644 --- a/src/routes/v2/projects.rs +++ b/src/routes/v2/projects.rs @@ -3,6 +3,7 @@ use crate::database; use crate::database::models::image_item; use crate::database::models::notification_item::NotificationBuilder; use crate::database::models::thread_item::ThreadMessageBuilder; +use crate::database::redis::RedisPool; use 
crate::file_hosting::FileHost; use crate::models; use crate::models::ids::base62_impl::parse_base62; @@ -79,7 +80,7 @@ pub struct RandomProjects { pub async fn random_projects_get( web::Query(count): web::Query, pool: web::Data, - redis: web::Data, + redis: web::Data, ) -> Result { count .validate() @@ -119,7 +120,7 @@ pub async fn projects_get( req: HttpRequest, web::Query(ids): web::Query, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let ids = serde_json::from_str::>(&ids.ids)?; @@ -146,13 +147,12 @@ pub async fn project_get( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let string = info.into_inner().0; let project_data = database::models::Project::get(&string, &**pool, &redis).await?; - let user_option = get_user_from_headers( &req, &**pool, @@ -177,7 +177,7 @@ pub async fn project_get( pub async fn project_get_check( info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, ) -> Result { let slug = info.into_inner().0; @@ -203,7 +203,7 @@ pub async fn dependency_list( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let string = info.into_inner().0; @@ -275,7 +275,7 @@ pub async fn dependency_list( } } -#[derive(Deserialize, Validate)] +#[derive(Serialize, Deserialize, Validate)] pub struct EditProject { #[validate( length(min = 3, max = 64), @@ -381,7 +381,7 @@ pub async fn project_edit( pool: web::Data, config: web::Data, new_project: web::Json, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( @@ -997,7 +997,6 @@ pub async fn project_edit( .execute(&mut *transaction) .await?; } - if let Some(donations) = &new_project.donation_urls { if !perms.contains(ProjectPermissions::EDIT_DETAILS) { return Err(ApiError::CustomAuthentication( @@ -1244,7 +1243,7 @@ pub async fn projects_edit( web::Query(ids): web::Query, pool: web::Data, bulk_edit_project: web::Json, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( @@ -1622,7 +1621,7 @@ pub async fn project_schedule( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, scheduling_data: web::Json, ) -> Result { @@ -1724,7 +1723,7 @@ pub async fn project_icon_edit( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, file_host: web::Data>, mut payload: web::Payload, session_queue: web::Data, @@ -1840,7 +1839,7 @@ pub async fn delete_project_icon( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, file_host: web::Data>, session_queue: web::Data, ) -> Result { @@ -1943,7 +1942,7 @@ pub async fn add_gallery_item( web::Query(item): web::Query, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, file_host: web::Data>, mut payload: web::Payload, session_queue: web::Data, @@ -2106,7 +2105,7 @@ pub async fn edit_gallery_item( web::Query(item): web::Query, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( @@ -2269,7 +2268,7 @@ pub async fn delete_gallery_item( web::Query(item): web::Query, info: web::Path<(String,)>, pool: web::Data, - redis: 
web::Data, + redis: web::Data, file_host: web::Data>, session_queue: web::Data, ) -> Result { @@ -2375,7 +2374,7 @@ pub async fn project_delete( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, config: web::Data, session_queue: web::Data, ) -> Result { @@ -2465,7 +2464,7 @@ pub async fn project_follow( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( @@ -2544,7 +2543,7 @@ pub async fn project_unfollow( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( diff --git a/src/routes/v2/reports.rs b/src/routes/v2/reports.rs index 90960e30..c0eba9c3 100644 --- a/src/routes/v2/reports.rs +++ b/src/routes/v2/reports.rs @@ -2,6 +2,7 @@ use crate::auth::{check_is_moderator_from_headers, get_user_from_headers}; use crate::database; use crate::database::models::image_item; use crate::database::models::thread_item::{ThreadBuilder, ThreadMessageBuilder}; +use crate::database::redis::RedisPool; use crate::models::ids::ImageId; use crate::models::ids::{base62_impl::parse_base62, ProjectId, UserId, VersionId}; use crate::models::images::{Image, ImageContext}; @@ -44,7 +45,7 @@ pub async fn report_create( req: HttpRequest, pool: web::Data, mut body: web::Payload, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let mut transaction = pool.begin().await?; @@ -235,7 +236,7 @@ fn default_all() -> bool { pub async fn reports( req: HttpRequest, pool: web::Data, - redis: web::Data, + redis: web::Data, count: web::Query, session_queue: web::Data, ) -> Result { @@ -310,7 +311,7 @@ pub async fn reports_get( req: HttpRequest, web::Query(ids): web::Query, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let report_ids: Vec = @@ -345,7 +346,7 @@ pub async fn reports_get( pub async fn report_get( req: HttpRequest, pool: web::Data, - redis: web::Data, + redis: web::Data, info: web::Path<(crate::models::reports::ReportId,)>, session_queue: web::Data, ) -> Result { @@ -385,7 +386,7 @@ pub struct EditReport { pub async fn report_edit( req: HttpRequest, pool: web::Data, - redis: web::Data, + redis: web::Data, info: web::Path<(crate::models::reports::ReportId,)>, session_queue: web::Data, edit_report: web::Json, @@ -404,7 +405,7 @@ pub async fn report_edit( let report = crate::database::models::report_item::Report::get(id, &**pool).await?; if let Some(report) = report { - if !user.role.is_mod() && report.user_id != Some(user.id.into()) { + if !user.role.is_mod() && report.reporter != user.id.into() { return Ok(HttpResponse::NotFound().body("")); } @@ -492,10 +493,17 @@ pub async fn report_delete( req: HttpRequest, pool: web::Data, info: web::Path<(crate::models::reports::ReportId,)>, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { - check_is_moderator_from_headers(&req, &**pool, &redis, &session_queue).await?; + check_is_moderator_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::REPORT_DELETE]), + ) + .await?; let mut transaction = pool.begin().await?; diff --git a/src/routes/v2/tags.rs b/src/routes/v2/tags.rs index 9307ae3e..56ffaac5 100644 --- a/src/routes/v2/tags.rs +++ b/src/routes/v2/tags.rs @@ -1,6 +1,7 @@ use super::ApiError; use crate::database::models; use 
crate::database::models::categories::{DonationPlatform, ProjectType, ReportType, SideType}; +use crate::database::redis::RedisPool; use actix_web::{get, web, HttpResponse}; use chrono::{DateTime, Utc}; use models::categories::{Category, GameVersion, Loader}; @@ -32,7 +33,7 @@ pub struct CategoryData { #[get("category")] pub async fn category_list( pool: web::Data, - redis: web::Data, + redis: web::Data, ) -> Result { let results = Category::list(&**pool, &redis) .await? @@ -58,7 +59,7 @@ pub struct LoaderData { #[get("loader")] pub async fn loader_list( pool: web::Data, - redis: web::Data, + redis: web::Data, ) -> Result { let mut results = Loader::list(&**pool, &redis) .await? @@ -94,7 +95,7 @@ pub struct GameVersionQuery { pub async fn game_version_list( pool: web::Data, query: web::Query, - redis: web::Data, + redis: web::Data, ) -> Result { let results: Vec = if query.type_.is_some() || query.major.is_some() { GameVersion::list_filter(query.type_.as_deref(), query.major, &**pool, &redis).await? @@ -172,7 +173,7 @@ pub struct DonationPlatformQueryData { #[get("donation_platform")] pub async fn donation_platform_list( pool: web::Data, - redis: web::Data, + redis: web::Data, ) -> Result { let results: Vec = DonationPlatform::list(&**pool, &redis) .await? @@ -188,7 +189,7 @@ pub async fn donation_platform_list( #[get("report_type")] pub async fn report_type_list( pool: web::Data, - redis: web::Data, + redis: web::Data, ) -> Result { let results = ReportType::list(&**pool, &redis).await?; Ok(HttpResponse::Ok().json(results)) @@ -197,7 +198,7 @@ pub async fn report_type_list( #[get("project_type")] pub async fn project_type_list( pool: web::Data, - redis: web::Data, + redis: web::Data, ) -> Result { let results = ProjectType::list(&**pool, &redis).await?; Ok(HttpResponse::Ok().json(results)) @@ -206,7 +207,7 @@ pub async fn project_type_list( #[get("side_type")] pub async fn side_type_list( pool: web::Data, - redis: web::Data, + redis: web::Data, ) -> Result { let results = SideType::list(&**pool, &redis).await?; Ok(HttpResponse::Ok().json(results)) diff --git a/src/routes/v2/teams.rs b/src/routes/v2/teams.rs index 34985ae3..866ee436 100644 --- a/src/routes/v2/teams.rs +++ b/src/routes/v2/teams.rs @@ -2,6 +2,7 @@ use crate::auth::{get_user_from_headers, is_authorized}; use crate::database::models::notification_item::NotificationBuilder; use crate::database::models::team_item::TeamAssociationId; use crate::database::models::{Organization, Team, TeamMember}; +use crate::database::redis::RedisPool; use crate::database::Project; use crate::models::notifications::NotificationBody; use crate::models::pats::Scopes; @@ -37,7 +38,7 @@ pub async fn team_members_get_project( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let string = info.into_inner().0; @@ -116,7 +117,7 @@ pub async fn team_members_get_organization( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let string = info.into_inner().0; @@ -182,7 +183,7 @@ pub async fn team_members_get( req: HttpRequest, info: web::Path<(TeamId,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let id = info.into_inner().0; @@ -244,7 +245,7 @@ pub async fn teams_get( req: HttpRequest, web::Query(ids): web::Query, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { use 
itertools::Itertools; @@ -309,7 +310,7 @@ pub async fn join_team( req: HttpRequest, info: web::Path<(TeamId,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let team_id = info.into_inner().0.into(); @@ -389,7 +390,7 @@ pub async fn add_team_member( info: web::Path<(TeamId,)>, pool: web::Data, new_member: web::Json, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let team_id = info.into_inner().0.into(); @@ -452,7 +453,6 @@ pub async fn add_team_member( let organization_permissions = OrganizationPermissions::get_permissions_by_role(¤t_user.role, &member) .unwrap_or_default(); - println!("{:?}", organization_permissions); if !organization_permissions.contains(OrganizationPermissions::MANAGE_INVITES) { return Err(ApiError::CustomAuthentication( "You don't have permission to invite users to this organization".to_string(), @@ -571,7 +571,7 @@ pub async fn edit_team_member( info: web::Path<(TeamId, UserId)>, pool: web::Data, edit_member: web::Json, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let ids = info.into_inner(); @@ -724,7 +724,7 @@ pub async fn transfer_ownership( info: web::Path<(TeamId,)>, pool: web::Data, new_owner: web::Json, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let id = info.into_inner().0; @@ -822,7 +822,7 @@ pub async fn remove_team_member( req: HttpRequest, info: web::Path<(TeamId, UserId)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let ids = info.into_inner(); diff --git a/src/routes/v2/threads.rs b/src/routes/v2/threads.rs index c2e6c096..af2a5782 100644 --- a/src/routes/v2/threads.rs +++ b/src/routes/v2/threads.rs @@ -5,6 +5,7 @@ use crate::database; use crate::database::models::image_item; use crate::database::models::notification_item::NotificationBuilder; use crate::database::models::thread_item::ThreadMessageBuilder; +use crate::database::redis::RedisPool; use crate::file_hosting::FileHost; use crate::models::ids::ThreadMessageId; use crate::models::images::{Image, ImageContext}; @@ -83,7 +84,7 @@ pub async fn filter_authorized_threads( threads: Vec, user: &User, pool: &web::Data, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result, ApiError> { let user_id: database::models::UserId = user.id.into(); @@ -225,7 +226,7 @@ pub async fn thread_get( req: HttpRequest, info: web::Path<(ThreadId,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let string = info.into_inner().0.into(); @@ -276,7 +277,7 @@ pub async fn threads_get( req: HttpRequest, web::Query(ids): web::Query, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( @@ -313,7 +314,7 @@ pub async fn thread_send_message( info: web::Path<(ThreadId,)>, pool: web::Data, new_message: web::Json, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( @@ -508,10 +509,17 @@ pub async fn thread_send_message( pub async fn moderation_inbox( req: HttpRequest, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { - let user = check_is_moderator_from_headers(&req, &**pool, &redis, &session_queue).await?; + let user = check_is_moderator_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::THREAD_READ]), + ) + .await?; let ids = sqlx::query!( " @@ -527,7 
+535,6 @@ pub async fn moderation_inbox( let threads_data = database::models::Thread::get_many(&ids, &**pool).await?; let threads = filter_authorized_threads(threads_data, &user, &pool, &redis).await?; - Ok(HttpResponse::Ok().json(threads)) } @@ -536,10 +543,17 @@ pub async fn thread_read( req: HttpRequest, info: web::Path<(ThreadId,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { - check_is_moderator_from_headers(&req, &**pool, &redis, &session_queue).await?; + check_is_moderator_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::THREAD_READ]), + ) + .await?; let id = info.into_inner().0; let mut transaction = pool.begin().await?; @@ -565,7 +579,7 @@ pub async fn message_delete( req: HttpRequest, info: web::Path<(ThreadMessageId,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, file_host: web::Data>, ) -> Result { diff --git a/src/routes/v2/users.rs b/src/routes/v2/users.rs index 6adfe6a8..bda564a7 100644 --- a/src/routes/v2/users.rs +++ b/src/routes/v2/users.rs @@ -1,5 +1,6 @@ use crate::auth::{get_user_from_headers, AuthenticationError}; use crate::database::models::User; +use crate::database::redis::RedisPool; use crate::file_hosting::FileHost; use crate::models::collections::{Collection, CollectionStatus}; use crate::models::notifications::Notification; @@ -46,7 +47,7 @@ pub fn config(cfg: &mut web::ServiceConfig) { pub async fn user_auth_get( req: HttpRequest, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let (scopes, mut user) = get_user_from_headers( @@ -66,17 +67,7 @@ pub async fn user_auth_get( user.payout_data = None; } - Ok(HttpResponse::Ok().json( - get_user_from_headers( - &req, - &**pool, - &redis, - &session_queue, - Some(&[Scopes::USER_READ]), - ) - .await? 
- .1, - )) + Ok(HttpResponse::Ok().json(user)) } #[derive(Serialize, Deserialize)] @@ -88,7 +79,7 @@ pub struct UserIds { pub async fn users_get( web::Query(ids): web::Query, pool: web::Data, - redis: web::Data, + redis: web::Data, ) -> Result { let user_ids = serde_json::from_str::>(&ids.ids)?; @@ -103,7 +94,7 @@ pub async fn users_get( pub async fn user_get( info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, ) -> Result { let user_data = User::get(&info.into_inner().0, &**pool, &redis).await?; @@ -120,7 +111,7 @@ pub async fn projects_list( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( @@ -164,7 +155,7 @@ pub async fn collections_list( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( @@ -250,7 +241,7 @@ pub async fn user_edit( info: web::Path<(String,)>, new_user: web::Json, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let (scopes, user) = get_user_from_headers( @@ -471,7 +462,7 @@ pub async fn user_icon_edit( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, file_host: web::Data>, mut payload: web::Payload, session_queue: web::Data, @@ -560,7 +551,7 @@ pub async fn user_delete( info: web::Path<(String,)>, pool: web::Data, removal_type: web::Query, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( @@ -608,7 +599,7 @@ pub async fn user_follows( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( @@ -664,7 +655,7 @@ pub async fn user_notifications( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( @@ -712,7 +703,7 @@ pub async fn user_payouts( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( @@ -797,7 +788,7 @@ pub async fn user_payouts_request( pool: web::Data, data: web::Json, payouts_queue: web::Data>, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let mut payouts_queue = payouts_queue.lock().await; diff --git a/src/routes/v2/version_creation.rs b/src/routes/v2/version_creation.rs index af3375de..80fc895d 100644 --- a/src/routes/v2/version_creation.rs +++ b/src/routes/v2/version_creation.rs @@ -5,6 +5,7 @@ use crate::database::models::version_item::{ DependencyBuilder, VersionBuilder, VersionFileBuilder, }; use crate::database::models::{self, image_item, Organization}; +use crate::database::redis::RedisPool; use crate::file_hosting::FileHost; use crate::models::images::{Image, ImageContext, ImageId}; use crate::models::notifications::NotificationBody; @@ -89,7 +90,7 @@ pub async fn version_create( req: HttpRequest, mut payload: Multipart, client: Data, - redis: Data, + redis: Data, file_host: Data>, session_queue: Data, ) -> Result { @@ -129,7 +130,7 @@ async fn version_create_inner( req: HttpRequest, payload: &mut Multipart, transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, - redis: &deadpool_redis::Pool, + redis: 
&RedisPool, file_host: &dyn FileHost, uploaded_files: &mut Vec, pool: &PgPool, @@ -507,7 +508,7 @@ pub async fn upload_file_to_version( url_data: web::Path<(VersionId,)>, mut payload: Multipart, client: Data, - redis: Data, + redis: Data, file_host: Data>, session_queue: web::Data, ) -> Result { @@ -551,7 +552,7 @@ async fn upload_file_to_version_inner( payload: &mut Multipart, client: Data, transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, - redis: Data, + redis: Data, file_host: &dyn FileHost, uploaded_files: &mut Vec, version_id: models::VersionId, @@ -729,6 +730,9 @@ async fn upload_file_to_version_inner( } } + // Clear version cache + models::Version::clear_cache(&version, &redis).await?; + Ok(HttpResponse::NoContent().body("")) } diff --git a/src/routes/v2/version_file.rs b/src/routes/v2/version_file.rs index a4612e1e..171788b1 100644 --- a/src/routes/v2/version_file.rs +++ b/src/routes/v2/version_file.rs @@ -3,6 +3,7 @@ use crate::auth::{ filter_authorized_projects, filter_authorized_versions, get_user_from_headers, is_authorized_version, }; +use crate::database::redis::RedisPool; use crate::models::ids::VersionId; use crate::models::pats::Scopes; use crate::models::projects::VersionType; @@ -21,7 +22,8 @@ pub fn config(cfg: &mut web::ServiceConfig) { .service(delete_file) .service(get_version_from_hash) .service(download_version) - .service(get_update_from_hash), + .service(get_update_from_hash) + .service(get_projects_from_hashes), ); cfg.service( @@ -32,7 +34,7 @@ pub fn config(cfg: &mut web::ServiceConfig) { ); } -#[derive(Deserialize)] +#[derive(Serialize, Deserialize)] pub struct HashQuery { #[serde(default = "default_algorithm")] pub algorithm: String, @@ -49,7 +51,7 @@ pub async fn get_version_from_hash( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, hash_query: web::Query, session_queue: web::Data, ) -> Result { @@ -63,7 +65,6 @@ pub async fn get_version_from_hash( .await .map(|x| x.1) .ok(); - let hash = info.into_inner().0.to_lowercase(); let file = database::models::Version::get_file_from_hash( hash_query.algorithm.clone(), @@ -73,10 +74,8 @@ pub async fn get_version_from_hash( &redis, ) .await?; - if let Some(file) = file { let version = database::models::Version::get(file.version_id, &**pool, &redis).await?; - if let Some(version) = version { if !is_authorized_version(&version.inner, &user_option, &pool).await? 
{ return Ok(HttpResponse::NotFound().body("")); @@ -102,7 +101,7 @@ pub async fn download_version( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, hash_query: web::Query, session_queue: web::Data, ) -> Result { @@ -152,7 +151,7 @@ pub async fn delete_file( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, hash_query: web::Query, session_queue: web::Data, ) -> Result { @@ -274,7 +273,7 @@ pub async fn get_update_from_hash( req: HttpRequest, info: web::Path<(String,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, hash_query: web::Query, update_data: web::Json, session_queue: web::Data, @@ -343,6 +342,7 @@ pub async fn get_update_from_hash( // Requests above with multiple versions below #[derive(Deserialize)] pub struct FileHashes { + #[serde(default = "default_algorithm")] pub algorithm: String, pub hashes: Vec, } @@ -352,7 +352,7 @@ pub struct FileHashes { pub async fn get_versions_from_hashes( req: HttpRequest, pool: web::Data, - redis: web::Data, + redis: web::Data, file_data: web::Json, session_queue: web::Data, ) -> Result { @@ -400,7 +400,7 @@ pub async fn get_versions_from_hashes( pub async fn get_projects_from_hashes( req: HttpRequest, pool: web::Data, - redis: web::Data, + redis: web::Data, file_data: web::Json, session_queue: web::Data, ) -> Result { @@ -409,7 +409,7 @@ pub async fn get_projects_from_hashes( &**pool, &redis, &session_queue, - Some(&[Scopes::VERSION_READ]), + Some(&[Scopes::PROJECT_READ, Scopes::VERSION_READ]), ) .await .map(|x| x.1) @@ -447,6 +447,7 @@ pub async fn get_projects_from_hashes( #[derive(Deserialize)] pub struct ManyUpdateData { + #[serde(default = "default_algorithm")] pub algorithm: String, pub hashes: Vec, pub loaders: Option>, @@ -458,7 +459,7 @@ pub struct ManyUpdateData { pub async fn update_files( req: HttpRequest, pool: web::Data, - redis: web::Data, + redis: web::Data, update_data: web::Json, session_queue: web::Data, ) -> Result { @@ -550,6 +551,7 @@ pub struct FileUpdateData { #[derive(Deserialize)] pub struct ManyFileUpdateData { + #[serde(default = "default_algorithm")] pub algorithm: String, pub hashes: Vec, } @@ -558,7 +560,7 @@ pub struct ManyFileUpdateData { pub async fn update_individual_files( req: HttpRequest, pool: web::Data, - redis: web::Data, + redis: web::Data, update_data: web::Json, session_queue: web::Data, ) -> Result { diff --git a/src/routes/v2/versions.rs b/src/routes/v2/versions.rs index e7dce53b..cfaa9da4 100644 --- a/src/routes/v2/versions.rs +++ b/src/routes/v2/versions.rs @@ -4,6 +4,7 @@ use crate::auth::{ }; use crate::database; use crate::database::models::{image_item, Organization}; +use crate::database::redis::RedisPool; use crate::models; use crate::models::ids::base62_impl::parse_base62; use crate::models::images::ImageContext; @@ -49,7 +50,7 @@ pub async fn version_list( info: web::Path<(String,)>, web::Query(filters): web::Query, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let string = info.into_inner().0; @@ -170,7 +171,7 @@ pub async fn version_project_get( req: HttpRequest, info: web::Path<(String, String)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let id = info.into_inner(); @@ -221,7 +222,7 @@ pub async fn versions_get( req: HttpRequest, web::Query(ids): web::Query, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let version_ids = 
serde_json::from_str::>(&ids.ids)? @@ -251,7 +252,7 @@ pub async fn version_get( req: HttpRequest, info: web::Path<(models::ids::VersionId,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let id = info.into_inner().0; @@ -318,7 +319,7 @@ pub async fn version_edit( req: HttpRequest, info: web::Path<(models::ids::VersionId,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, new_version: web::Json, session_queue: web::Data, ) -> Result { @@ -738,7 +739,7 @@ pub async fn version_schedule( req: HttpRequest, info: web::Path<(models::ids::VersionId,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, scheduling_data: web::Json, session_queue: web::Data, ) -> Result { @@ -835,7 +836,7 @@ pub async fn version_delete( req: HttpRequest, info: web::Path<(models::ids::VersionId,)>, pool: web::Data, - redis: web::Data, + redis: web::Data, session_queue: web::Data, ) -> Result { let user = get_user_from_headers( diff --git a/src/util/img.rs b/src/util/img.rs index 99574e22..54fe3604 100644 --- a/src/util/img.rs +++ b/src/util/img.rs @@ -1,11 +1,11 @@ -use color_thief::ColorFormat; -use image::imageops::FilterType; -use image::{EncodableLayout, ImageError}; - use crate::database; use crate::database::models::image_item; +use crate::database::redis::RedisPool; use crate::models::images::ImageContext; use crate::routes::ApiError; +use color_thief::ColorFormat; +use image::imageops::FilterType; +use image::{EncodableLayout, ImageError}; pub fn get_color_from_img(data: &[u8]) -> Result, ImageError> { let image = image::load_from_memory(data)? @@ -26,7 +26,7 @@ pub async fn delete_unused_images( context: ImageContext, reference_strings: Vec<&str>, transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, - redis: &deadpool_redis::Pool, + redis: &RedisPool, ) -> Result<(), ApiError> { let uploaded_images = database::models::Image::get_many_contexted(context, transaction).await?; diff --git a/src/util/webhook.rs b/src/util/webhook.rs index 040b2eb0..8b5b5a65 100644 --- a/src/util/webhook.rs +++ b/src/util/webhook.rs @@ -1,4 +1,5 @@ use crate::database::models::categories::GameVersion; +use crate::database::redis::RedisPool; use crate::models::projects::ProjectId; use crate::routes::ApiError; use chrono::{DateTime, Utc}; @@ -72,7 +73,7 @@ const PLUGIN_LOADERS: &[&str] = &[ pub async fn send_discord_webhook( project_id: ProjectId, pool: &PgPool, - redis: &deadpool_redis::Pool, + redis: &RedisPool, webhook_url: String, message: Option, ) -> Result<(), ApiError> { diff --git a/tests/common/actix.rs b/tests/common/actix.rs new file mode 100644 index 00000000..03935e50 --- /dev/null +++ b/tests/common/actix.rs @@ -0,0 +1,82 @@ +use actix_web::test::TestRequest; +use bytes::{Bytes, BytesMut}; + +// Multipart functionality (actix-test does not innately support multipart) +#[derive(Debug, Clone)] +pub struct MultipartSegment { + pub name: String, + pub filename: Option, + pub content_type: Option, + pub data: MultipartSegmentData, +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub enum MultipartSegmentData { + Text(String), + Binary(Vec), +} + +pub trait AppendsMultipart { + fn set_multipart(self, data: Vec) -> Self; +} + +impl AppendsMultipart for TestRequest { + fn set_multipart(self, data: Vec) -> Self { + let (boundary, payload) = generate_multipart(data); + self.append_header(( + "Content-Type", + format!("multipart/form-data; boundary={}", boundary), + )) + .set_payload(payload) + } +} + +fn generate_multipart(data: Vec) -> (String, Bytes) { + 
let mut boundary = String::from("----WebKitFormBoundary"); + boundary.push_str(&rand::random::().to_string()); + boundary.push_str(&rand::random::().to_string()); + boundary.push_str(&rand::random::().to_string()); + + let mut payload = BytesMut::new(); + + for segment in data { + payload.extend_from_slice( + format!( + "--{boundary}\r\nContent-Disposition: form-data; name=\"{name}\"", + boundary = boundary, + name = segment.name + ) + .as_bytes(), + ); + + if let Some(filename) = &segment.filename { + payload.extend_from_slice( + format!("; filename=\"{filename}\"", filename = filename).as_bytes(), + ); + } + if let Some(content_type) = &segment.content_type { + payload.extend_from_slice( + format!( + "\r\nContent-Type: {content_type}", + content_type = content_type + ) + .as_bytes(), + ); + } + payload.extend_from_slice(b"\r\n\r\n"); + + match &segment.data { + MultipartSegmentData::Text(text) => { + payload.extend_from_slice(text.as_bytes()); + } + MultipartSegmentData::Binary(binary) => { + payload.extend_from_slice(binary); + } + } + payload.extend_from_slice(b"\r\n"); + } + payload.extend_from_slice(format!("--{boundary}--\r\n", boundary = boundary).as_bytes()); + + (boundary, Bytes::from(payload)) +} diff --git a/tests/common/database.rs b/tests/common/database.rs new file mode 100644 index 00000000..63535125 --- /dev/null +++ b/tests/common/database.rs @@ -0,0 +1,134 @@ +#![allow(dead_code)] + +use labrinth::database::redis::RedisPool; +use sqlx::{postgres::PgPoolOptions, PgPool}; +use std::time::Duration; +use url::Url; + +// The dummy test database adds a fair bit of 'dummy' data to test with. +// Some constants are used to refer to that data, and are described here. +// The rest can be accessed in the TestEnvironment 'dummy' field. + +// The user IDs are as follows: +pub const ADMIN_USER_ID: &str = "1"; +pub const MOD_USER_ID: &str = "2"; +pub const USER_USER_ID: &str = "3"; // This is the 'main' user ID, and is used for most tests. +pub const FRIEND_USER_ID: &str = "4"; // Functions like USER_USER_ID, but is used where tests need a second, friendly user (ie: teams, etc) +pub const ENEMY_USER_ID: &str = "5"; // Functions like USER_USER_ID, but is used to simulate an unrelated user who should be denied access to certain things + +pub const ADMIN_USER_ID_PARSED: i64 = 1; +pub const MOD_USER_ID_PARSED: i64 = 2; +pub const USER_USER_ID_PARSED: i64 = 3; +pub const FRIEND_USER_ID_PARSED: i64 = 4; +pub const ENEMY_USER_ID_PARSED: i64 = 5; + +// These are full-scoped PATs - as if the user were logged in (including illegal scopes). +pub const ADMIN_USER_PAT: &str = "mrp_patadmin"; +pub const MOD_USER_PAT: &str = "mrp_patmoderator"; +pub const USER_USER_PAT: &str = "mrp_patuser"; +pub const FRIEND_USER_PAT: &str = "mrp_patfriend"; +pub const ENEMY_USER_PAT: &str = "mrp_patenemy"; + +pub struct TemporaryDatabase { + pub pool: PgPool, + pub redis_pool: RedisPool, + pub database_name: String, +} + +impl TemporaryDatabase { + // Creates a temporary database like sqlx::test does + // 1. Logs into the main database + // 2. Creates a new randomly generated database + // 3. Runs migrations on the new database + // 4. (Optionally, by using create_with_dummy) adds dummy data to the database + // If a db is created with create_with_dummy, it must be cleaned up with cleanup. + // This means that dbs will only 'remain' if a test fails (for examination of the db), and will be cleaned up otherwise.
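+ //
+ // Typical usage in a test:
+ //   let db = TemporaryDatabase::create().await;
+ //   // ... run the test against db.pool / db.redis_pool ...
+ //   db.cleanup().await;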
+ pub async fn create() -> Self { + let temp_database_name = generate_random_database_name(); + println!("Creating temporary database: {}", &temp_database_name); + + let database_url = dotenvy::var("DATABASE_URL").expect("No database URL"); + let mut url = Url::parse(&database_url).expect("Invalid database URL"); + let pool = PgPool::connect(&database_url) + .await + .expect("Connection to database failed"); + + // Create the temporary database + let create_db_query = format!("CREATE DATABASE {}", &temp_database_name); + + sqlx::query(&create_db_query) + .execute(&pool) + .await + .expect("Database creation failed"); + + pool.close().await; + + // Modify the URL to switch to the temporary database + url.set_path(&format!("/{}", &temp_database_name)); + let temp_db_url = url.to_string(); + + let pool = PgPoolOptions::new() + .min_connections(0) + .max_connections(4) + .max_lifetime(Some(Duration::from_secs(60 * 60))) + .connect(&temp_db_url) + .await + .expect("Connection to temporary database failed"); + + // Performs migrations + let migrations = sqlx::migrate!("./migrations"); + migrations.run(&pool).await.expect("Migrations failed"); + + // Gets new Redis pool + let redis_pool = RedisPool::new(Some(temp_database_name.clone())); + + Self { + pool, + database_name: temp_database_name, + redis_pool, + } + } + + // Deletes the temporary database + // If a temporary db is created, it must be cleaned up with cleanup. + // This means that dbs will only 'remain' if a test fails (for examination of the db), and will be cleaned up otherwise. + pub async fn cleanup(mut self) { + let database_url = dotenvy::var("DATABASE_URL").expect("No database URL"); + self.pool.close().await; + + self.pool = PgPool::connect(&database_url) + .await + .expect("Connection to main database failed"); + + // Forcibly terminate all existing connections to this version of the temporary database + // We are done and deleting it, so we don't need them anymore + let terminate_query = format!( + "SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity WHERE datname = '{}' AND pid <> pg_backend_pid()", + &self.database_name + ); + sqlx::query(&terminate_query) + .execute(&self.pool) + .await + .unwrap(); + + // Execute the deletion query + let drop_db_query = format!("DROP DATABASE IF EXISTS {}", &self.database_name); + sqlx::query(&drop_db_query) + .execute(&self.pool) + .await + .expect("Database deletion failed"); + } +} + +fn generate_random_database_name() -> String { + // Generate a unique name for the temporary database: + // a random string starting with "labrinth_tests_db_", + // with a 6-digit random number appended.
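+ // e.g. "labrinth_tests_db_412973" (digits here are illustrative)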
+ let mut database_name = String::from("labrinth_tests_db_"); + database_name.push_str(&rand::random::().to_string()[..6]); + database_name +} diff --git a/tests/common/dummy_data.rs b/tests/common/dummy_data.rs new file mode 100644 index 00000000..d3cd9667 --- /dev/null +++ b/tests/common/dummy_data.rs @@ -0,0 +1,229 @@ +use actix_web::test::{self, TestRequest}; +use labrinth::{models::projects::Project, models::projects::Version}; +use serde_json::json; +use sqlx::Executor; + +use crate::common::{ + actix::AppendsMultipart, + database::{MOD_USER_PAT, USER_USER_PAT}, +}; + +use super::{ + actix::{MultipartSegment, MultipartSegmentData}, + environment::TestEnvironment, +}; + +pub struct DummyData { + pub alpha_team_id: String, + pub beta_team_id: String, + + pub alpha_project_id: String, + pub beta_project_id: String, + + pub alpha_project_slug: String, + pub beta_project_slug: String, + + pub alpha_version_id: String, + pub beta_version_id: String, + + pub alpha_thread_id: String, + pub beta_thread_id: String, + + pub alpha_file_hash: String, + pub beta_file_hash: String, +} + +pub async fn add_dummy_data(test_env: &TestEnvironment) -> DummyData { + // Adds basic dummy data to the database directly with SQL (users, PATs) + let pool = &test_env.db.pool.clone(); + pool.execute(include_str!("../files/dummy_data.sql")) + .await + .unwrap(); + + let (alpha_project, alpha_version) = add_project_alpha(test_env).await; + let (beta_project, beta_version) = add_project_beta(test_env).await; + + DummyData { + alpha_team_id: alpha_project.team.to_string(), + beta_team_id: beta_project.team.to_string(), + + alpha_project_id: alpha_project.id.to_string(), + beta_project_id: beta_project.id.to_string(), + + alpha_project_slug: alpha_project.slug.unwrap(), + beta_project_slug: beta_project.slug.unwrap(), + + alpha_version_id: alpha_version.id.to_string(), + beta_version_id: beta_version.id.to_string(), + + alpha_thread_id: alpha_project.thread_id.to_string(), + beta_thread_id: beta_project.thread_id.to_string(), + + alpha_file_hash: alpha_version.files[0].hashes["sha1"].clone(), + beta_file_hash: beta_version.files[0].hashes["sha1"].clone(), + } +} + +pub async fn add_project_alpha(test_env: &TestEnvironment) -> (Project, Version) { + // Adds a dummy project, version, and thread through the v2 API + // Generate test project data. + let json_data = json!( + { + "title": "Test Project Alpha", + "slug": "alpha", + "description": "A dummy project for testing with.", + "body": "This project is approved, and versions are listed.", + "client_side": "required", + "server_side": "optional", + "initial_versions": [{ + "file_parts": ["dummy-project-alpha.jar"], + "version_number": "1.2.3", + "version_title": "start", + "dependencies": [], + "game_versions": ["1.20.1"] , + "release_channel": "release", + "loaders": ["fabric"], + "featured": true + }], + "categories": [], + "license_id": "MIT" + } + ); + + // Basic json + let json_segment = MultipartSegment { + name: "data".to_string(), + filename: None, + content_type: Some("application/json".to_string()), + data: MultipartSegmentData::Text(serde_json::to_string(&json_data).unwrap()), + }; + + // Basic file + let file_segment = MultipartSegment { + name: "dummy-project-alpha.jar".to_string(), + filename: Some("dummy-project-alpha.jar".to_string()), + content_type: Some("application/java-archive".to_string()), + data: MultipartSegmentData::Binary( + include_bytes!("../../tests/files/dummy-project-alpha.jar").to_vec(), + ), + }; + + // Add a project.
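+ // Note: segment order matters here - the creation route requires the
+ // `data` JSON segment to arrive before any file segments.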
+ let req = TestRequest::post() + .uri("/v2/project") + .append_header(("Authorization", USER_USER_PAT)) + .set_multipart(vec![json_segment.clone(), file_segment.clone()]) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 200); + + // Approve as a moderator. + let req = TestRequest::patch() + .uri("/v2/project/alpha") + .append_header(("Authorization", MOD_USER_PAT)) + .set_json(json!( + { + "status": "approved" + } + )) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 204); + + // Get project + let req = TestRequest::get() + .uri("/v2/project/alpha") + .append_header(("Authorization", USER_USER_PAT)) + .to_request(); + let resp = test_env.call(req).await; + let project: Project = test::read_body_json(resp).await; + + // Get project's versions + let req = TestRequest::get() + .uri("/v2/project/alpha/version") + .append_header(("Authorization", USER_USER_PAT)) + .to_request(); + let resp = test_env.call(req).await; + let versions: Vec = test::read_body_json(resp).await; + let version = versions.into_iter().next().unwrap(); + + (project, version) +} + +pub async fn add_project_beta(test_env: &TestEnvironment) -> (Project, Version) { + // Adds a dummy project, version, and thread through the v2 API + // Generate test project data. + let json_data = json!( + { + "title": "Test Project Beta", + "slug": "beta", + "description": "A dummy project for testing with.", + "body": "This project is not-yet-approved, and versions are draft.", + "client_side": "required", + "server_side": "optional", + "initial_versions": [{ + "file_parts": ["dummy-project-beta.jar"], + "version_number": "1.2.3", + "version_title": "start", + "status": "unlisted", + "requested_status": "unlisted", + "dependencies": [], + "game_versions": ["1.20.1"] , + "release_channel": "release", + "loaders": ["fabric"], + "featured": true + }], + "status": "private", + "requested_status": "private", + "categories": [], + "license_id": "MIT" + } + ); + + // Basic json + let json_segment = MultipartSegment { + name: "data".to_string(), + filename: None, + content_type: Some("application/json".to_string()), + data: MultipartSegmentData::Text(serde_json::to_string(&json_data).unwrap()), + }; + + // Basic file + let file_segment = MultipartSegment { + name: "dummy-project-beta.jar".to_string(), + filename: Some("dummy-project-beta.jar".to_string()), + content_type: Some("application/java-archive".to_string()), + data: MultipartSegmentData::Binary( + include_bytes!("../../tests/files/dummy-project-beta.jar").to_vec(), + ), + }; + + // Add a project.
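+ // Note: unlike alpha, beta is created as a private project with unlisted
+ // versions, and is not submitted for moderator approval afterwards.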
+ let req = TestRequest::post() + .uri("/v2/project") + .append_header(("Authorization", USER_USER_PAT)) + .set_multipart(vec![json_segment.clone(), file_segment.clone()]) + .to_request(); + let resp = test_env.call(req).await; + + assert_eq!(resp.status(), 200); + + // Get project + let req = TestRequest::get() + .uri("/v2/project/beta") + .append_header(("Authorization", USER_USER_PAT)) + .to_request(); + let resp = test_env.call(req).await; + let project: Project = test::read_body_json(resp).await; + + // Get project's versions + let req = TestRequest::get() + .uri("/v2/project/beta/version") + .append_header(("Authorization", USER_USER_PAT)) + .to_request(); + let resp = test_env.call(req).await; + let versions: Vec = test::read_body_json(resp).await; + let version = versions.into_iter().next().unwrap(); + + (project, version) +} diff --git a/tests/common/environment.rs b/tests/common/environment.rs new file mode 100644 index 00000000..bcf5c686 --- /dev/null +++ b/tests/common/environment.rs @@ -0,0 +1,71 @@ +#![allow(dead_code)] + +use super::{database::TemporaryDatabase, dummy_data}; +use crate::common::setup; +use actix_web::{dev::ServiceResponse, test, App}; + +// A complete test environment, with a test actix app and a database. +// Must be called in an #[actix_rt::test] context. It also simulates a +// temporary sqlx db like #[sqlx::test] would. +// Use .call(req) on it directly to make a test call as if test::call_service(req) were being used. +pub struct TestEnvironment { + test_app: Box, + pub db: TemporaryDatabase, + + pub dummy: Option, +} + +impl TestEnvironment { + pub async fn build_with_dummy() -> Self { + let mut test_env = Self::build().await; + let dummy = dummy_data::add_dummy_data(&test_env).await; + test_env.dummy = Some(dummy); + test_env + } + + pub async fn build() -> Self { + let db = TemporaryDatabase::create().await; + let labrinth_config = setup(&db).await; + let app = App::new().configure(|cfg| labrinth::app_config(cfg, labrinth_config.clone())); + let test_app = test::init_service(app).await; + Self { + test_app: Box::new(test_app), + db, + dummy: None, + } + } + pub async fn cleanup(self) { + self.db.cleanup().await; + } + + pub async fn call(&self, req: actix_http::Request) -> ServiceResponse { + self.test_app.call(req).await.unwrap() + } +} + +trait LocalService { + fn call( + &self, + req: actix_http::Request, + ) -> std::pin::Pin< + Box>>, + >; +} +impl LocalService for S +where + S: actix_web::dev::Service< + actix_http::Request, + Response = ServiceResponse, + Error = actix_web::Error, + >, + S::Future: 'static, +{ + fn call( + &self, + req: actix_http::Request, + ) -> std::pin::Pin< + Box>>, + > { + Box::pin(self.call(req)) + } +} diff --git a/tests/common/mod.rs b/tests/common/mod.rs new file mode 100644 index 00000000..cde6fc8d --- /dev/null +++ b/tests/common/mod.rs @@ -0,0 +1,40 @@ +use labrinth::{check_env_vars, clickhouse}; +use labrinth::{file_hosting, queue, LabrinthConfig}; +use std::sync::Arc; + +use self::database::TemporaryDatabase; + +pub mod actix; +pub mod database; +pub mod dummy_data; +pub mod environment; +pub mod pats; +pub mod scopes; + +// Testing equivalent to 'setup' function, producing a LabrinthConfig +// If making a test, you should probably use environment::TestEnvironment::build_with_dummy() (which calls this) +pub async fn setup(db: &TemporaryDatabase) -> LabrinthConfig { + println!("Setting up labrinth config"); + + dotenvy::dotenv().ok(); + + if check_env_vars() { + println!("Some environment variables are missing!"); 
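+ // Setup continues regardless; a test that actually needs a missing
+ // variable will fail when that variable is read.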
+ } + + let pool = db.pool.clone(); + let redis_pool = db.redis_pool.clone(); + let file_host: Arc = + Arc::new(file_hosting::MockHost::new()); + let mut clickhouse = clickhouse::init_client().await.unwrap(); + + let maxmind_reader = Arc::new(queue::maxmind::MaxMindIndexer::new().await.unwrap()); + + labrinth::app_setup( + pool.clone(), + redis_pool.clone(), + &mut clickhouse, + file_host.clone(), + maxmind_reader.clone(), + ) +} diff --git a/tests/common/pats.rs b/tests/common/pats.rs new file mode 100644 index 00000000..d63517cf --- /dev/null +++ b/tests/common/pats.rs @@ -0,0 +1,30 @@ +#![allow(dead_code)] + +use chrono::Utc; +use labrinth::{ + database::{self, models::generate_pat_id}, + models::pats::Scopes, +}; + +use super::database::TemporaryDatabase; + +// Creates a PAT with the given scopes, and returns the access token +// Interfacing with the db directly, rather than using a route, +// allows us to test with scopes that are not allowed to be created by PATs +pub async fn create_test_pat(scopes: Scopes, user_id: i64, db: &TemporaryDatabase) -> String { + let mut transaction = db.pool.begin().await.unwrap(); + let id = generate_pat_id(&mut transaction).await.unwrap(); + let pat = database::models::pat_item::PersonalAccessToken { + id, + name: format!("test_pat_{}", scopes.bits()), + access_token: format!("mrp_{}", id.0), + scopes, + user_id: database::models::ids::UserId(user_id), + created: Utc::now(), + expires: Utc::now() + chrono::Duration::days(1), + last_used: None, + }; + pat.insert(&mut transaction).await.unwrap(); + transaction.commit().await.unwrap(); + pat.access_token +} diff --git a/tests/common/scopes.rs b/tests/common/scopes.rs new file mode 100644 index 00000000..44a4b7df --- /dev/null +++ b/tests/common/scopes.rs @@ -0,0 +1,124 @@ +#![allow(dead_code)] +use actix_web::test::{self, TestRequest}; +use labrinth::models::pats::Scopes; + +use super::{database::USER_USER_ID_PARSED, environment::TestEnvironment, pats::create_test_pat}; + +// A reusable test type that works for any scope test against an endpoint that: +// - returns a known 'expected_failure_code' if the scope is not present (defaults to 401) +// - returns a 200-299 if the scope is present +// - returns failure and success JSON bodies for requests that are 200 (for performing non-simple follow-up tests on) +// This uses a builder format, so you can chain methods to set the parameters to non-defaults (most will probably not need to be set). +pub struct ScopeTest<'a> { + test_env: &'a TestEnvironment, + // Scopes expected to fail on this test. By default, this is all scopes except the success scopes. + // (To ensure we have isolated the scope we are testing) + failure_scopes: Option, + // User ID to use for the PATs. By default, this is the USER_USER_ID_PARSED constant. + user_id: i64, + // The code that is expected to be returned if the scope is not present.
By default, this is 401 (Unauthorized) + expected_failure_code: u16, +} + +impl<'a> ScopeTest<'a> { + pub fn new(test_env: &'a TestEnvironment) -> Self { + Self { + test_env, + failure_scopes: None, + user_id: USER_USER_ID_PARSED, + expected_failure_code: 401, + } + } + + // Set non-standard failure scopes + // If not set, it will be set to all scopes except the success scopes + // (eg: if a combination of scopes is needed, but you want to make sure that the endpoint does not work with all-but-one of them) + pub fn with_failure_scopes(mut self, scopes: Scopes) -> Self { + self.failure_scopes = Some(scopes); + self + } + + // Set the user ID to use + // (eg: a moderator, or friend) + pub fn with_user_id(mut self, user_id: i64) -> Self { + self.user_id = user_id; + self + } + + // If a non-401 code is expected. + // (eg: a 404 for a hidden resource, or 200 for a resource with hidden values deeper in) + pub fn with_failure_code(mut self, code: u16) -> Self { + self.expected_failure_code = code; + self + } + + // Call the endpoint generated by req_gen twice, once with a PAT with the failure scopes, and once with the success scopes. + // success_scopes : the scopes that we are testing that should succeed + // returns a tuple of (failure_body, success_body) + // Returns a String error on an unexpected status code, allowing unwrapping in tests. + pub async fn test( + &self, + req_gen: T, + success_scopes: Scopes, + ) -> Result<(serde_json::Value, serde_json::Value), String> + where + T: Fn() -> TestRequest, + { + // First, create a PAT with failure scopes + let failure_scopes = self + .failure_scopes + .unwrap_or(Scopes::all() ^ success_scopes); + let access_token_all_others = + create_test_pat(failure_scopes, self.user_id, &self.test_env.db).await; + + // Create a PAT with the success scopes + let access_token = create_test_pat(success_scopes, self.user_id, &self.test_env.db).await; + + // Perform the test twice, once with each PAT. + // The first time, we expect a 401 (or the known failure code). + let req = req_gen() + .append_header(("Authorization", access_token_all_others.as_str())) + .to_request(); + let resp = self.test_env.call(req).await; + + if resp.status().as_u16() != self.expected_failure_code { + return Err(format!( + "Expected failure code {}, got {}", + self.expected_failure_code, + resp.status().as_u16() + )); + } + + let failure_body = if resp.status() == 200 + && resp.headers().contains_key("Content-Type") + && resp.headers().get("Content-Type").unwrap() == "application/json" + { + test::read_body_json(resp).await + } else { + serde_json::Value::Null + }; + + // The second time, we expect a success code + let req = req_gen() + .append_header(("Authorization", access_token.as_str())) + .to_request(); + let resp = self.test_env.call(req).await; + + if !(resp.status().is_success() || resp.status().is_redirection()) { + return Err(format!( + "Expected success code, got {}", + resp.status().as_u16() + )); + } + + let success_body = if resp.status() == 200 + && resp.headers().contains_key("Content-Type") + && resp.headers().get("Content-Type").unwrap() == "application/json" + { + test::read_body_json(resp).await + } else { + serde_json::Value::Null + }; + Ok((failure_body, success_body)) + } +} diff --git a/tests/files/200x200.png b/tests/files/200x200.png new file mode 100644 index 00000000..bb923179 Binary files /dev/null and b/tests/files/200x200.png differ diff --git a/tests/files/basic-mod-different.jar b/tests/files/basic-mod-different.jar new file mode 100644 index
00000000..616131ae Binary files /dev/null and b/tests/files/basic-mod-different.jar differ diff --git a/tests/files/basic-mod.jar b/tests/files/basic-mod.jar new file mode 100644 index 00000000..0987832e Binary files /dev/null and b/tests/files/basic-mod.jar differ diff --git a/tests/files/dummy-project-alpha.jar b/tests/files/dummy-project-alpha.jar new file mode 100644 index 00000000..61f82078 Binary files /dev/null and b/tests/files/dummy-project-alpha.jar differ diff --git a/tests/files/dummy-project-beta.jar b/tests/files/dummy-project-beta.jar new file mode 100644 index 00000000..1b072b20 Binary files /dev/null and b/tests/files/dummy-project-beta.jar differ diff --git a/tests/files/dummy_data.sql b/tests/files/dummy_data.sql new file mode 100644 index 00000000..59391f48 --- /dev/null +++ b/tests/files/dummy_data.sql @@ -0,0 +1,36 @@ +-- Dummy test data for use in tests. +-- IDs are listed as integers, followed by their equivalent base 62 representation. + +-- Inserts 5 dummy users for testing, with slight differences +-- 'Friend' and 'enemy' function like 'user', but we can use them to simulate 'other' users that may or may not be able to access certain things +-- IDs 1-5, 1-5 +INSERT INTO users (id, username, name, email, role) VALUES (1, 'admin', 'Administrator Test', 'admin@modrinth.com', 'admin'); +INSERT INTO users (id, username, name, email, role) VALUES (2, 'moderator', 'Moderator Test', 'moderator@modrinth.com', 'moderator'); +INSERT INTO users (id, username, name, email, role) VALUES (3, 'user', 'User Test', 'user@modrinth.com', 'developer'); +INSERT INTO users (id, username, name, email, role) VALUES (4, 'friend', 'Friend Test', 'friend@modrinth.com', 'developer'); +INSERT INTO users (id, username, name, email, role) VALUES (5, 'enemy', 'Enemy Test', 'enemy@modrinth.com', 'developer'); + +-- Full PATs for each user, each with all scopes +-- These are not legal PATs, as they contain all scopes - they mimic the permissions of a logged-in user +-- IDs: 50-54, o p q r s +INSERT INTO pats (id, user_id, name, access_token, scopes, expires) VALUES (50, 1, 'admin-pat', 'mrp_patadmin', B'11111111111111111111111111111111111'::BIGINT, '2030-08-18 15:48:58.435729+00'); +INSERT INTO pats (id, user_id, name, access_token, scopes, expires) VALUES (51, 2, 'moderator-pat', 'mrp_patmoderator', B'11111111111111111111111111111111111'::BIGINT, '2030-08-18 15:48:58.435729+00'); +INSERT INTO pats (id, user_id, name, access_token, scopes, expires) VALUES (52, 3, 'user-pat', 'mrp_patuser', B'11111111111111111111111111111111111'::BIGINT, '2030-08-18 15:48:58.435729+00'); +INSERT INTO pats (id, user_id, name, access_token, scopes, expires) VALUES (53, 4, 'friend-pat', 'mrp_patfriend', B'11111111111111111111111111111111111'::BIGINT, '2030-08-18 15:48:58.435729+00'); +INSERT INTO pats (id, user_id, name, access_token, scopes, expires) VALUES (54, 5, 'enemy-pat', 'mrp_patenemy', B'11111111111111111111111111111111111'::BIGINT, '2030-08-18 15:48:58.435729+00'); + +-- Sample game versions, loaders, categories +INSERT INTO game_versions (id, version, type, created) +VALUES (20000, '1.20.1', 'release', timezone('utc', now())); + +INSERT INTO loaders (id, loader) VALUES (1, 'fabric'); +INSERT INTO loaders_project_types (joining_loader_id, joining_project_type_id) VALUES (1,1); +INSERT INTO loaders_project_types (joining_loader_id, joining_project_type_id) VALUES (1,2); + +INSERT INTO categories (id, category, project_type) VALUES (1, 'combat', 1); +INSERT INTO categories (id, category, project_type) VALUES
(2, 'decoration', 1); +INSERT INTO categories (id, category, project_type) VALUES (3, 'economy', 1); + +INSERT INTO categories (id, category, project_type) VALUES (4, 'combat', 2); +INSERT INTO categories (id, category, project_type) VALUES (5, 'decoration', 2); +INSERT INTO categories (id, category, project_type) VALUES (6, 'economy', 2); \ No newline at end of file diff --git a/tests/files/simple-zip.zip b/tests/files/simple-zip.zip new file mode 100644 index 00000000..20bf64b8 Binary files /dev/null and b/tests/files/simple-zip.zip differ diff --git a/tests/pats.rs b/tests/pats.rs new file mode 100644 index 00000000..98da30ec --- /dev/null +++ b/tests/pats.rs @@ -0,0 +1,292 @@ +use actix_web::test; +use chrono::{Duration, Utc}; +use common::database::*; +use labrinth::models::pats::Scopes; +use serde_json::json; + +use crate::common::environment::TestEnvironment; + +// importing common module. +mod common; + +// Full pat test: +// - create a PAT and ensure it can be used for the scope +// - ensure access token is not returned for any PAT in GET +// - ensure PAT can be patched to change scopes +// - ensure PAT can be patched to change expiry +// - ensure expired PATs cannot be used +// - ensure PATs can be deleted +#[actix_rt::test] +pub async fn pat_full_test() { + let test_env = TestEnvironment::build_with_dummy().await; + + // Create a PAT for a full test + let req = test::TestRequest::post() + .uri("/v2/pat") + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!({ + "scopes": Scopes::COLLECTION_CREATE, // Collection create as an easily tested example + "name": "test_pat_scopes Test", + "expires": Utc::now() + Duration::days(1), + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status().as_u16(), 200); + let success: serde_json::Value = test::read_body_json(resp).await; + let id = success["id"].as_str().unwrap(); + + // Has access token and correct scopes + assert!(success["access_token"].as_str().is_some()); + assert_eq!( + success["scopes"].as_u64().unwrap(), + Scopes::COLLECTION_CREATE.bits() + ); + let access_token = success["access_token"].as_str().unwrap(); + + // Get PAT again + let req = test::TestRequest::get() + .append_header(("Authorization", USER_USER_PAT)) + .uri("/v2/pat") + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status().as_u16(), 200); + let success: serde_json::Value = test::read_body_json(resp).await; + + // Ensure access token is NOT returned for any PATs + for pat in success.as_array().unwrap() { + assert!(pat["access_token"].as_str().is_none()); + } + + // Create mock test for using PAT + let mock_pat_test = |token: &str| { + let token = token.to_string(); + async { + let req = test::TestRequest::post() + .uri("/v2/collection") + .append_header(("Authorization", token)) + .set_json(json!({ + "title": "Test Collection 1", + "description": "Test Collection Description" + })) + .to_request(); + let resp = test_env.call(req).await; + resp.status().as_u16() + } + }; + + assert_eq!(mock_pat_test(access_token).await, 200); + + // Change scopes and test again + let req = test::TestRequest::patch() + .uri(&format!("/v2/pat/{}", id)) + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!({ + "scopes": 0, + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status().as_u16(), 204); + assert_eq!(mock_pat_test(access_token).await, 401); // No longer works + + // Change scopes back, set the expiry to the near future, and test again once it lapses + let req =
test::TestRequest::patch() + .uri(&format!("/v2/pat/{}", id)) + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!({ + "scopes": Scopes::COLLECTION_CREATE, + "expires": Utc::now() + Duration::seconds(1), // expires in 1 second + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status().as_u16(), 204); + + // Wait 1 second before testing again for expiry + tokio::time::sleep(Duration::seconds(1).to_std().unwrap()).await; + assert_eq!(mock_pat_test(access_token).await, 401); // No longer works + + // Change everything back to normal and test again + let req = test::TestRequest::patch() + .uri(&format!("/v2/pat/{}", id)) + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!({ + "expires": Utc::now() + Duration::days(1), // no longer expired! + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status().as_u16(), 204); + assert_eq!(mock_pat_test(access_token).await, 200); // Works again + + // Patching to a bad expiry should fail + let req = test::TestRequest::patch() + .uri(&format!("/v2/pat/{}", id)) + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!({ + "expires": Utc::now() - Duration::days(1), // Past + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status().as_u16(), 400); + + // Similar to above with PAT creation, patching to a bad scope should fail + for i in 0..64 { + let scope = Scopes::from_bits_truncate(1 << i); + if !Scopes::all().contains(scope) { + continue; + } + + let req = test::TestRequest::patch() + .uri(&format!("/v2/pat/{}", id)) + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!({ + "scopes": scope.bits(), + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!( + resp.status().as_u16(), + if scope.is_restricted() { 400 } else { 204 } + ); + } + + // Delete PAT + let req = test::TestRequest::delete() + .append_header(("Authorization", USER_USER_PAT)) + .uri(&format!("/v2/pat/{}", id)) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status().as_u16(), 204); + + // Cleanup test db + test_env.cleanup().await; +} + +// Test illegal PAT setting, both in POST and PATCH +#[actix_rt::test] +pub async fn bad_pats() { + let test_env = TestEnvironment::build_with_dummy().await; + + // Creating a PAT with no name should fail + let req = test::TestRequest::post() + .uri("/v2/pat") + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!({ + "scopes": Scopes::COLLECTION_CREATE, // Collection create as an easily tested example + "expires": Utc::now() + Duration::days(1), + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status().as_u16(), 400); + + // Name too short or too long should fail + for name in ["n", "this_name_is_too_long".repeat(16).as_str()] { + let req = test::TestRequest::post() + .uri("/v2/pat") + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!({ + "name": name, + "scopes": Scopes::COLLECTION_CREATE, // Collection create as an easily tested example + "expires": Utc::now() + Duration::days(1), + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status().as_u16(), 400); + } + + // Creating a PAT with an expiry in the past should fail + let req = test::TestRequest::post() + .uri("/v2/pat") + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!({ + "scopes": Scopes::COLLECTION_CREATE, // Collection create as an easily tested example + "name": "test_pat_scopes Test", + "expires": 
Utc::now() - Duration::days(1), + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status().as_u16(), 400); + + // Make a PAT with each scope, with the result varying by whether that scope is restricted + for i in 0..64 { + let scope = Scopes::from_bits_truncate(1 << i); + if !Scopes::all().contains(scope) { + continue; + } + let req = test::TestRequest::post() + .uri("/v2/pat") + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!({ + "scopes": scope.bits(), + "name": format!("test_pat_scopes Name {}", i), + "expires": Utc::now() + Duration::days(1), + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!( + resp.status().as_u16(), + if scope.is_restricted() { 400 } else { 200 } + ); + } + + // Create a 'good' PAT for patching + let req = test::TestRequest::post() + .uri("/v2/pat") + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!({ + "scopes": Scopes::COLLECTION_CREATE, + "name": "test_pat_scopes Test", + "expires": Utc::now() + Duration::days(1), + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status().as_u16(), 200); + let success: serde_json::Value = test::read_body_json(resp).await; + let id = success["id"].as_str().unwrap(); + + // Patching to a bad name should fail + for name in ["n", "this_name_is_too_long".repeat(16).as_str()] { + let req = test::TestRequest::patch() + .uri(&format!("/v2/pat/{}", id)) + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!({ + "name": name, + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status().as_u16(), 400); + } + + // Patching to a bad expiry should fail + let req = test::TestRequest::patch() + .uri(&format!("/v2/pat/{}", id)) + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!({ + "expires": Utc::now() - Duration::days(1), // Past + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status().as_u16(), 400); + + // Similar to above with PAT creation, patching to a bad scope should fail + for i in 0..64 { + let scope = Scopes::from_bits_truncate(1 << i); + if !Scopes::all().contains(scope) { + continue; + } + + let req = test::TestRequest::patch() + .uri(&format!("/v2/pat/{}", id)) + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!({ + "scopes": scope.bits(), + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!( + resp.status().as_u16(), + if scope.is_restricted() { 400 } else { 204 } + ); + } + + // Cleanup test db + test_env.cleanup().await; +} diff --git a/tests/project.rs b/tests/project.rs new file mode 100644 index 00000000..215bcb66 --- /dev/null +++ b/tests/project.rs @@ -0,0 +1,461 @@ +use actix_web::test; +use labrinth::database::models::project_item::{PROJECTS_NAMESPACE, PROJECTS_SLUGS_NAMESPACE}; +use labrinth::models::ids::base62_impl::parse_base62; +use serde_json::json; + +use crate::common::database::*; + +use crate::common::{actix::AppendsMultipart, environment::TestEnvironment}; + +// importing common module.
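+// (The common module supplies the shared plumbing these tests lean on:
+// TestEnvironment, which boots the app against a TemporaryDatabase, the
+// dummy-data PAT constants such as USER_USER_PAT, and the AppendsMultipart
+// helper. Every test follows the same request/assert cycle; a minimal sketch,
+// with a hypothetical slug purely for illustration:
+//
+//     let test_env = TestEnvironment::build_with_dummy().await;
+//     let req = test::TestRequest::get()
+//         .uri("/v2/project/some-slug")
+//         .append_header(("Authorization", USER_USER_PAT))
+//         .to_request();
+//     let resp = test_env.call(req).await;
+//     assert_eq!(resp.status(), 200);
+//     test_env.cleanup().await;
+// )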
+mod common; + +#[actix_rt::test] +async fn test_get_project() { + // Test setup and dummy data + let test_env = TestEnvironment::build_with_dummy().await; + let alpha_project_id = &test_env.dummy.as_ref().unwrap().alpha_project_id; + let beta_project_id = &test_env.dummy.as_ref().unwrap().beta_project_id; + let alpha_project_slug = &test_env.dummy.as_ref().unwrap().alpha_project_slug; + let alpha_version_id = &test_env.dummy.as_ref().unwrap().alpha_version_id; + + // Perform request on dummy data + let req = test::TestRequest::get() + .uri(&format!("/v2/project/{alpha_project_id}")) + .append_header(("Authorization", USER_USER_PAT)) + .to_request(); + let resp = test_env.call(req).await; + let status = resp.status(); + let body: serde_json::Value = test::read_body_json(resp).await; + + assert_eq!(status, 200); + assert_eq!(body["id"], json!(alpha_project_id)); + assert_eq!(body["slug"], json!(alpha_project_slug)); + let versions = body["versions"].as_array().unwrap(); + assert!(!versions.is_empty()); + assert_eq!(versions[0], json!(alpha_version_id)); + + // Confirm that the project was cached + assert_eq!( + test_env + .db + .redis_pool + .get::<i64, _>(PROJECTS_SLUGS_NAMESPACE, alpha_project_slug) + .await + .unwrap(), + Some(parse_base62(alpha_project_id).unwrap() as i64) + ); + + let cached_project = test_env + .db + .redis_pool + .get::<String, _>(PROJECTS_NAMESPACE, parse_base62(alpha_project_id).unwrap()) + .await + .unwrap() + .unwrap(); + let cached_project: serde_json::Value = serde_json::from_str(&cached_project).unwrap(); + assert_eq!(cached_project["inner"]["slug"], json!(alpha_project_slug)); + + // Make the request again; this time it should hit the cache + let req = test::TestRequest::get() + .uri(&format!("/v2/project/{alpha_project_id}")) + .append_header(("Authorization", USER_USER_PAT)) + .to_request(); + let resp = test_env.call(req).await; + let status = resp.status(); + assert_eq!(status, 200); + + let body: serde_json::Value = test::read_body_json(resp).await; + assert_eq!(body["id"], json!(alpha_project_id)); + assert_eq!(body["slug"], json!(alpha_project_slug)); + + // Request should fail on non-existent project + let req = test::TestRequest::get() + .uri("/v2/project/nonexistent") + .append_header(("Authorization", USER_USER_PAT)) + .to_request(); + + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 404); + + // Similarly, request should fail on non-authorized user, on a yet-to-be-approved or hidden project, with a 404 (hiding the existence of the project) + let req = test::TestRequest::get() + .uri(&format!("/v2/project/{beta_project_id}")) + .append_header(("Authorization", ENEMY_USER_PAT)) + .to_request(); + + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 404); + + // Cleanup test db + test_env.cleanup().await; +} + +#[actix_rt::test] +async fn test_add_remove_project() { + // Test setup and dummy data + let test_env = TestEnvironment::build_with_dummy().await; + + // Generate test project data.
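+    // A project-creation request is a multipart upload: a JSON segment named
+    // "data" describing the project, followed by one binary segment per entry
+    // in "initial_versions[].file_parts", matched by segment name. A sketch of
+    // the payload assembled below (values abridged):
+    //
+    //     data:          { "slug": "demo", ..., "initial_versions":
+    //                        [{ "file_parts": ["basic-mod.jar"], ... }] }
+    //     basic-mod.jar: <application/java-archive bytes>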
+ let mut json_data = json!( + { + "title": "Test_Add_Project project", + "slug": "demo", + "description": "Example description.", + "body": "Example body.", + "client_side": "required", + "server_side": "optional", + "initial_versions": [{ + "file_parts": ["basic-mod.jar"], + "version_number": "1.2.3", + "version_title": "start", + "dependencies": [], + "game_versions": ["1.20.1"], + "release_channel": "release", + "loaders": ["fabric"], + "featured": true + }], + "categories": [], + "license_id": "MIT" + } + ); + + // Basic json + let json_segment = common::actix::MultipartSegment { + name: "data".to_string(), + filename: None, + content_type: Some("application/json".to_string()), + data: common::actix::MultipartSegmentData::Text(serde_json::to_string(&json_data).unwrap()), + }; + + // Basic json, with a different file + json_data["initial_versions"][0]["file_parts"][0] = json!("basic-mod-different.jar"); + let json_diff_file_segment = common::actix::MultipartSegment { + data: common::actix::MultipartSegmentData::Text(serde_json::to_string(&json_data).unwrap()), + ..json_segment.clone() + }; + + // Basic json, with a different file, and a different slug + json_data["slug"] = json!("new_demo"); + json_data["initial_versions"][0]["file_parts"][0] = json!("basic-mod-different.jar"); + let json_diff_slug_file_segment = common::actix::MultipartSegment { + data: common::actix::MultipartSegmentData::Text(serde_json::to_string(&json_data).unwrap()), + ..json_segment.clone() + }; + + // Basic file + let file_segment = common::actix::MultipartSegment { + name: "basic-mod.jar".to_string(), + filename: Some("basic-mod.jar".to_string()), + content_type: Some("application/java-archive".to_string()), + data: common::actix::MultipartSegmentData::Binary( + include_bytes!("../tests/files/basic-mod.jar").to_vec(), + ), + }; + + // Differently named file, with the same content (for hash testing) + let file_diff_name_segment = common::actix::MultipartSegment { + name: "basic-mod-different.jar".to_string(), + filename: Some("basic-mod-different.jar".to_string()), + content_type: Some("application/java-archive".to_string()), + data: common::actix::MultipartSegmentData::Binary( + include_bytes!("../tests/files/basic-mod.jar").to_vec(), + ), + }; + + // Differently named file, with different content + let file_diff_name_content_segment = common::actix::MultipartSegment { + name: "basic-mod-different.jar".to_string(), + filename: Some("basic-mod-different.jar".to_string()), + content_type: Some("application/java-archive".to_string()), + data: common::actix::MultipartSegmentData::Binary( + include_bytes!("../tests/files/basic-mod-different.jar").to_vec(), + ), + }; + + // Add a project - the simple case should work.
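+    // (The duplicate-upload failures asserted further down hinge on file
+    // content, not file names: the same sha1 fingerprint this test computes
+    // later via
+    //     sha1::Sha1::from(bytes).digest().to_string()
+    // identifies a re-upload of identical bytes under a new name, while
+    // different bytes under a familiar name still count as a new file.)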
+ let req = test::TestRequest::post() + .uri("/v2/project") + .append_header(("Authorization", USER_USER_PAT)) + .set_multipart(vec![json_segment.clone(), file_segment.clone()]) + .to_request(); + let resp = test_env.call(req).await; + + let status = resp.status(); + assert_eq!(status, 200); + + // Get the project we just made, and confirm that it's correct + let req = test::TestRequest::get() + .uri("/v2/project/demo") + .append_header(("Authorization", USER_USER_PAT)) + .to_request(); + + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 200); + + let body: serde_json::Value = test::read_body_json(resp).await; + let versions = body["versions"].as_array().unwrap(); + assert!(versions.len() == 1); + let uploaded_version_id = &versions[0]; + + // Check the uploaded file to ensure it was stored and correctly identifies its version + let hash = sha1::Sha1::from(include_bytes!("../tests/files/basic-mod.jar")) + .digest() + .to_string(); + let req = test::TestRequest::get() + .uri(&format!("/v2/version_file/{hash}?algorithm=sha1")) + .append_header(("Authorization", USER_USER_PAT)) + .to_request(); + + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 200); + + let body: serde_json::Value = test::read_body_json(resp).await; + let file_version_id = &body["id"]; + assert_eq!(&file_version_id, &uploaded_version_id); + + // Reusing with a different slug and the same file should fail + // Even if that file is named differently + let req = test::TestRequest::post() + .uri("/v2/project") + .append_header(("Authorization", USER_USER_PAT)) + .set_multipart(vec![ + json_diff_slug_file_segment.clone(), // Different slug, different file name + file_diff_name_segment.clone(), // Different file name, same content + ]) + .to_request(); + + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 400); + + // Reusing with the same slug and a different file should fail + let req = test::TestRequest::post() + .uri("/v2/project") + .append_header(("Authorization", USER_USER_PAT)) + .set_multipart(vec![ + json_diff_file_segment.clone(), // Same slug, different file name + file_diff_name_content_segment.clone(), // Different file name, different content + ]) + .to_request(); + + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 400); + + // Different slug, different file should succeed + let req = test::TestRequest::post() + .uri("/v2/project") + .append_header(("Authorization", USER_USER_PAT)) + .set_multipart(vec![ + json_diff_slug_file_segment.clone(), // Different slug, different file name + file_diff_name_content_segment.clone(), // Different file name, different content + ]) + .to_request(); + + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 200); + + // Get + let req = test::TestRequest::get() + .uri("/v2/project/demo") + .append_header(("Authorization", USER_USER_PAT)) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 200); + let body: serde_json::Value = test::read_body_json(resp).await; + let id = body["id"].to_string(); + + // Remove the project + let req = test::TestRequest::delete() + .uri("/v2/project/demo") + .append_header(("Authorization", USER_USER_PAT)) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 204); + + // Confirm that the project is gone from the cache + assert_eq!( + test_env + .db + .redis_pool + .get::<i64, _>(PROJECTS_SLUGS_NAMESPACE, "demo") + .await + .unwrap(), + None + ); + assert_eq!( + test_env + .db + .redis_pool + .get::<i64, _>(PROJECTS_SLUGS_NAMESPACE, id) + .await
.unwrap(), + None + ); + + // Old slug no longer works + let req = test::TestRequest::get() + .uri("/v2/project/demo") + .append_header(("Authorization", USER_USER_PAT)) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 404); + + // Cleanup test db + test_env.cleanup().await; +} + +#[actix_rt::test] +pub async fn test_patch_project() { + let test_env = TestEnvironment::build_with_dummy().await; + let alpha_project_slug = &test_env.dummy.as_ref().unwrap().alpha_project_slug; + let beta_project_slug = &test_env.dummy.as_ref().unwrap().beta_project_slug; + + // First, we do some patch requests that should fail. + // Failure because the user is not authorized. + let req = test::TestRequest::patch() + .uri(&format!("/v2/project/{alpha_project_slug}")) + .append_header(("Authorization", ENEMY_USER_PAT)) + .set_json(json!({ + "title": "Test_Add_Project project - test 1", + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 401); + + // Failure because we are setting URL fields to invalid urls. + for url_type in ["issues_url", "source_url", "wiki_url", "discord_url"] { + let req = test::TestRequest::patch() + .uri(&format!("/v2/project/{alpha_project_slug}")) + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!({ + url_type: "w.fake.url", + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 400); + } + + // Failure because these are illegal requested statuses for a normal user. + for req in ["unknown", "processing", "withheld", "scheduled"] { + let req = test::TestRequest::patch() + .uri(&format!("/v2/project/{alpha_project_slug}")) + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!({ + "requested_status": req, + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 400); + } + + // Failure because these should not be able to be set by a non-mod + for key in ["moderation_message", "moderation_message_body"] { + let req = test::TestRequest::patch() + .uri(&format!("/v2/project/{alpha_project_slug}")) + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!({ + key: "test", + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 401); + + // (should work for a mod, though) + let req = test::TestRequest::patch() + .uri(&format!("/v2/project/{alpha_project_slug}")) + .append_header(("Authorization", MOD_USER_PAT)) + .set_json(json!({ + key: "test", + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 204); + } + + // Failure because the slug is already taken. + let req = test::TestRequest::patch() + .uri(&format!("/v2/project/{alpha_project_slug}")) + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!({ + "slug": beta_project_slug, // the other dummy project has this slug + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 400); + + // Not allowed to directly set status, as 'beta_project_slug' (the other project) is "processing" and cannot have its status changed like this. + let req = test::TestRequest::patch() + .uri(&format!("/v2/project/{beta_project_slug}")) + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!({ + "status": "private" + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 401); + + // Successful request to patch many fields.
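+    // (A successful PATCH returns 204 No Content, so the new state is verified
+    // separately below: the project is re-fetched under its new slug and every
+    // patched field is asserted, while the old slug is expected to 404.)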
+ let req = test::TestRequest::patch() + .uri(&format!("/v2/project/{alpha_project_slug}")) + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!({ + "slug": "newslug", + "title": "New successful title", + "description": "New successful description", + "body": "New successful body", + "categories": ["combat"], + "license_id": "MIT", + "issues_url": "https://github.com", + "discord_url": "https://discord.gg", + "wiki_url": "https://wiki.com", + "client_side": "optional", + "server_side": "required", + "donation_urls": [{ + "id": "patreon", + "platform": "Patreon", + "url": "https://patreon.com" + }] + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 204); + + // Old slug no longer works + let req = test::TestRequest::get() + .uri(&format!("/v2/project/{alpha_project_slug}")) + .append_header(("Authorization", USER_USER_PAT)) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 404); + + // Old slug no longer works + let req = test::TestRequest::get() + .uri("/v2/project/newslug") + .append_header(("Authorization", USER_USER_PAT)) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 200); + + let body: serde_json::Value = test::read_body_json(resp).await; + assert_eq!(body["slug"], json!("newslug")); + assert_eq!(body["title"], json!("New successful title")); + assert_eq!(body["description"], json!("New successful description")); + assert_eq!(body["body"], json!("New successful body")); + assert_eq!(body["categories"], json!(["combat"])); + assert_eq!(body["license"]["id"], json!("MIT")); + assert_eq!(body["issues_url"], json!("https://github.com")); + assert_eq!(body["discord_url"], json!("https://discord.gg")); + assert_eq!(body["wiki_url"], json!("https://wiki.com")); + assert_eq!(body["client_side"], json!("optional")); + assert_eq!(body["server_side"], json!("required")); + assert_eq!( + body["donation_urls"][0]["url"], + json!("https://patreon.com") + ); + + // Cleanup test db + test_env.cleanup().await; +} + +// TODO: Missing routes on projects +// TODO: using permissions/scopes, can we SEE projects existence that we are not allowed to? (ie 401 instead of 404) diff --git a/tests/scopes.rs b/tests/scopes.rs new file mode 100644 index 00000000..806905ab --- /dev/null +++ b/tests/scopes.rs @@ -0,0 +1,1331 @@ +use actix_web::test::{self, TestRequest}; +use bytes::Bytes; +use chrono::{Duration, Utc}; +use common::actix::AppendsMultipart; +use labrinth::models::pats::Scopes; +use serde_json::json; + +use crate::common::{database::*, environment::TestEnvironment, scopes::ScopeTest}; + +// importing common module. 
+mod common; + +// For each scope, we (using ScopeTest): +// - create a PAT with a given set of scopes for a function +// - create a PAT with all other scopes for a function +// - test the function with the PAT with the given scopes +// - test the function with the PAT with all other scopes + +// Test for users, emails, and payout scopes (not user auth scope or notifs) +#[actix_rt::test] +async fn user_scopes() { + // Test setup and dummy data + let test_env = TestEnvironment::build_with_dummy().await; + + // User reading + let read_user = Scopes::USER_READ; + let req_gen = || TestRequest::get().uri("/v2/user"); + let (_, success) = ScopeTest::new(&test_env) + .test(req_gen, read_user) + .await + .unwrap(); + assert!(success["email"].as_str().is_none()); // email should not be present + assert!(success["payout_data"].as_object().is_none()); // payout should not be present + + // Email reading + let read_email = Scopes::USER_READ | Scopes::USER_READ_EMAIL; + let req_gen = || TestRequest::get().uri("/v2/user"); + let (_, success) = ScopeTest::new(&test_env) + .test(req_gen, read_email) + .await + .unwrap(); + assert_eq!(success["email"], json!("user@modrinth.com")); // email should be present + + // Payout reading + let read_payout = Scopes::USER_READ | Scopes::PAYOUTS_READ; + let req_gen = || TestRequest::get().uri("/v2/user"); + let (_, success) = ScopeTest::new(&test_env) + .test(req_gen, read_payout) + .await + .unwrap(); + assert!(success["payout_data"].as_object().is_some()); // payout should be present + + // User writing + // We use the Admin PAT for this test, on the 'user' user + let write_user = Scopes::USER_WRITE; + let req_gen = || { + TestRequest::patch().uri("/v2/user/user").set_json(json!( { + // Do not include 'username', so as not to change the rest of the tests + "name": "NewName", + "bio": "New bio", + "location": "New location", + "role": "admin", + "badges": 5, + // Do not include payout info, different scope + })) + }; + ScopeTest::new(&test_env) + .with_user_id(ADMIN_USER_ID_PARSED) + .test(req_gen, write_user) + .await + .unwrap(); + + // User payout info writing + let failure_write_user_payout = Scopes::all() ^ Scopes::PAYOUTS_WRITE; // Failure case should include USER_WRITE + let write_user_payout = Scopes::USER_WRITE | Scopes::PAYOUTS_WRITE; + let req_gen = || { + TestRequest::patch().uri("/v2/user/user").set_json(json!( { + "payout_data": { + "payout_wallet": "paypal", + "payout_wallet_type": "email", + "payout_address": "test@modrinth.com" + } + })) + }; + ScopeTest::new(&test_env) + .with_failure_scopes(failure_write_user_payout) + .test(req_gen, write_user_payout) + .await + .unwrap(); + + // User deletion + // (The failure case runs first, and this is the last check in this function, so it is safe to delete the user on the success run) + let delete_user = Scopes::USER_DELETE; + let req_gen = || TestRequest::delete().uri("/v2/user/enemy"); + ScopeTest::new(&test_env) + .with_user_id(ENEMY_USER_ID_PARSED) + .test(req_gen, delete_user) + .await + .unwrap(); + + // Cleanup test db + test_env.cleanup().await; +} + +// Notifications +#[actix_rt::test] +pub async fn notifications_scopes() { + let test_env = TestEnvironment::build_with_dummy().await; + let alpha_team_id = &test_env.dummy.as_ref().unwrap().alpha_team_id.clone(); + + // We will invite user 'friend' to the project team, and use the resulting invite as the notification under test + let req = TestRequest::post() + .uri(&format!("/v2/team/{alpha_team_id}/members")) + .append_header(("Authorization", USER_USER_PAT)) + 
.set_json(json!( { + "user_id": FRIEND_USER_ID // friend + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 204); + + // Notification get + let read_notifications = Scopes::NOTIFICATION_READ; + let req_gen = + || test::TestRequest::get().uri(&format!("/v2/user/{FRIEND_USER_ID}/notifications")); + let (_, success) = ScopeTest::new(&test_env) + .with_user_id(FRIEND_USER_ID_PARSED) + .test(req_gen, read_notifications) + .await + .unwrap(); + let notification_id = success.as_array().unwrap()[0]["id"].as_str().unwrap(); + + let req_gen = || { + test::TestRequest::get().uri(&format!( + "/v2/notifications?ids=[{uri}]", + uri = urlencoding::encode(&format!("\"{notification_id}\"")) + )) + }; + ScopeTest::new(&test_env) + .with_user_id(FRIEND_USER_ID_PARSED) + .test(req_gen, read_notifications) + .await + .unwrap(); + + let req_gen = || test::TestRequest::get().uri(&format!("/v2/notification/{notification_id}")); + ScopeTest::new(&test_env) + .with_user_id(FRIEND_USER_ID_PARSED) + .test(req_gen, read_notifications) + .await + .unwrap(); + + // Notification mark as read + let write_notifications = Scopes::NOTIFICATION_WRITE; + let req_gen = || { + test::TestRequest::patch().uri(&format!( + "/v2/notifications?ids=[{uri}]", + uri = urlencoding::encode(&format!("\"{notification_id}\"")) + )) + }; + ScopeTest::new(&test_env) + .with_user_id(FRIEND_USER_ID_PARSED) + .test(req_gen, write_notifications) + .await + .unwrap(); + + let req_gen = || test::TestRequest::patch().uri(&format!("/v2/notification/{notification_id}")); + ScopeTest::new(&test_env) + .with_user_id(FRIEND_USER_ID_PARSED) + .test(req_gen, write_notifications) + .await + .unwrap(); + + // Notification delete + let req_gen = + || test::TestRequest::delete().uri(&format!("/v2/notification/{notification_id}")); + ScopeTest::new(&test_env) + .with_user_id(FRIEND_USER_ID_PARSED) + .test(req_gen, write_notifications) + .await + .unwrap(); + + // Mass notification delete + // We invite mod, get the notification ID, and do mass delete using that + let req = test::TestRequest::post() + .uri(&format!("/v2/team/{alpha_team_id}/members")) + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!( { + "user_id": MOD_USER_ID // mod + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 204); + let read_notifications = Scopes::NOTIFICATION_READ; + let req_gen = || test::TestRequest::get().uri(&format!("/v2/user/{MOD_USER_ID}/notifications")); + let (_, success) = ScopeTest::new(&test_env) + .with_user_id(MOD_USER_ID_PARSED) + .test(req_gen, read_notifications) + .await + .unwrap(); + let notification_id = success.as_array().unwrap()[0]["id"].as_str().unwrap(); + + let req_gen = || { + test::TestRequest::delete().uri(&format!( + "/v2/notifications?ids=[{uri}]", + uri = urlencoding::encode(&format!("\"{notification_id}\"")) + )) + }; + ScopeTest::new(&test_env) + .with_user_id(MOD_USER_ID_PARSED) + .test(req_gen, write_notifications) + .await + .unwrap(); + + // Cleanup test db + test_env.cleanup().await; +} + +// Project version creation scopes +#[actix_rt::test] +pub async fn project_version_create_scopes() { + let test_env = TestEnvironment::build_with_dummy().await; + + // Create project + let create_project = Scopes::PROJECT_CREATE; + let json_data = json!( + { + "title": "Test_Add_Project project", + "slug": "demo", + "description": "Example description.", + "body": "Example body.", + "client_side": "required", + "server_side": "optional", + "initial_versions": [{ + 
"file_parts": ["basic-mod.jar"], + "version_number": "1.2.3", + "version_title": "start", + "dependencies": [], + "game_versions": ["1.20.1"] , + "release_channel": "release", + "loaders": ["fabric"], + "featured": true + }], + "categories": [], + "license_id": "MIT" + } + ); + let json_segment = common::actix::MultipartSegment { + name: "data".to_string(), + filename: None, + content_type: Some("application/json".to_string()), + data: common::actix::MultipartSegmentData::Text(serde_json::to_string(&json_data).unwrap()), + }; + let file_segment = common::actix::MultipartSegment { + name: "basic-mod.jar".to_string(), + filename: Some("basic-mod.jar".to_string()), + content_type: Some("application/java-archive".to_string()), + data: common::actix::MultipartSegmentData::Binary( + include_bytes!("../tests/files/basic-mod.jar").to_vec(), + ), + }; + + let req_gen = || { + test::TestRequest::post() + .uri("/v2/project") + .set_multipart(vec![json_segment.clone(), file_segment.clone()]) + }; + let (_, success) = ScopeTest::new(&test_env) + .test(req_gen, create_project) + .await + .unwrap(); + let project_id = success["id"].as_str().unwrap(); + + // Add version to project + let create_version = Scopes::VERSION_CREATE; + let json_data = json!( + { + "project_id": project_id, + "file_parts": ["basic-mod-different.jar"], + "version_number": "1.2.3.4", + "version_title": "start", + "dependencies": [], + "game_versions": ["1.20.1"] , + "release_channel": "release", + "loaders": ["fabric"], + "featured": true + } + ); + let json_segment = common::actix::MultipartSegment { + name: "data".to_string(), + filename: None, + content_type: Some("application/json".to_string()), + data: common::actix::MultipartSegmentData::Text(serde_json::to_string(&json_data).unwrap()), + }; + let file_segment = common::actix::MultipartSegment { + name: "basic-mod-different.jar".to_string(), + filename: Some("basic-mod.jar".to_string()), + content_type: Some("application/java-archive".to_string()), + data: common::actix::MultipartSegmentData::Binary( + include_bytes!("../tests/files/basic-mod-different.jar").to_vec(), + ), + }; + + let req_gen = || { + test::TestRequest::post() + .uri("/v2/version") + .set_multipart(vec![json_segment.clone(), file_segment.clone()]) + }; + ScopeTest::new(&test_env) + .test(req_gen, create_version) + .await + .unwrap(); + + // Cleanup test db + test_env.cleanup().await; +} + +// Project management scopes +#[actix_rt::test] +pub async fn project_version_reads_scopes() { + let test_env = TestEnvironment::build_with_dummy().await; + let beta_project_id = &test_env.dummy.as_ref().unwrap().beta_project_id.clone(); + let beta_version_id = &test_env.dummy.as_ref().unwrap().beta_version_id.clone(); + let alpha_team_id = &test_env.dummy.as_ref().unwrap().alpha_team_id.clone(); + let beta_file_hash = &test_env.dummy.as_ref().unwrap().beta_file_hash.clone(); + + // Project reading + // Uses 404 as the expected failure code (or 200 and an empty list for mass reads) + let read_project = Scopes::PROJECT_READ; + let req_gen = || test::TestRequest::get().uri(&format!("/v2/project/{beta_project_id}")); + ScopeTest::new(&test_env) + .with_failure_code(404) + .test(req_gen, read_project) + .await + .unwrap(); + + let req_gen = + || test::TestRequest::get().uri(&format!("/v2/project/{beta_project_id}/dependencies")); + ScopeTest::new(&test_env) + .with_failure_code(404) + .test(req_gen, read_project) + .await + .unwrap(); + + let req_gen = || { + test::TestRequest::get().uri(&format!( + "/v2/projects?ids=[{uri}]", 
+ uri = urlencoding::encode(&format!("\"{beta_project_id}\"")) + )) + }; + let (failure, success) = ScopeTest::new(&test_env) + .with_failure_code(200) + .test(req_gen, read_project) + .await + .unwrap(); + assert!(failure.as_array().unwrap().is_empty()); + assert!(!success.as_array().unwrap().is_empty()); + + // Team project reading + let req_gen = + || test::TestRequest::get().uri(&format!("/v2/project/{beta_project_id}/members")); + ScopeTest::new(&test_env) + .with_failure_code(404) + .test(req_gen, read_project) + .await + .unwrap(); + + // Get team members + // In this case, as these are public endpoints, logging in only is relevant to showing permissions + // So for our test project (with 1 user, 'user') we will check the permissions before and after having the scope. + let req_gen = || test::TestRequest::get().uri(&format!("/v2/team/{alpha_team_id}/members")); + let (failure, success) = ScopeTest::new(&test_env) + .with_failure_code(200) + .test(req_gen, read_project) + .await + .unwrap(); + assert!(!failure.as_array().unwrap()[0].as_object().unwrap()["permissions"].is_number()); + assert!(success.as_array().unwrap()[0].as_object().unwrap()["permissions"].is_number()); + + let req_gen = || { + test::TestRequest::get().uri(&format!( + "/v2/teams?ids=[{uri}]", + uri = urlencoding::encode(&format!("\"{alpha_team_id}\"")) + )) + }; + let (failure, success) = ScopeTest::new(&test_env) + .with_failure_code(200) + .test(req_gen, read_project) + .await + .unwrap(); + assert!(!failure.as_array().unwrap()[0].as_array().unwrap()[0] + .as_object() + .unwrap()["permissions"] + .is_number()); + assert!(success.as_array().unwrap()[0].as_array().unwrap()[0] + .as_object() + .unwrap()["permissions"] + .is_number()); + + // User project reading + // Test user has two projects, one public and one private + let req_gen = || test::TestRequest::get().uri(&format!("/v2/user/{USER_USER_ID}/projects")); + let (failure, success) = ScopeTest::new(&test_env) + .with_failure_code(200) + .test(req_gen, read_project) + .await + .unwrap(); + assert!(!failure + .as_array() + .unwrap() + .iter() + .any(|x| x["status"] == "processing")); + assert!(success + .as_array() + .unwrap() + .iter() + .any(|x| x["status"] == "processing")); + + // Project metadata reading + let req_gen = || { + test::TestRequest::get().uri(&format!( + "/maven/maven/modrinth/{beta_project_id}/maven-metadata.xml" + )) + }; + ScopeTest::new(&test_env) + .with_failure_code(404) + .test(req_gen, read_project) + .await + .unwrap(); + + // Version reading + // First, set version to hidden (which is when the scope is required to read it) + let read_version = Scopes::VERSION_READ; + let req = test::TestRequest::patch() + .uri(&format!("/v2/version/{beta_version_id}")) + .append_header(("Authorization", USER_USER_PAT)) + .set_json(json!({ + "status": "draft" + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 204); + + let req_gen = || test::TestRequest::get().uri(&format!("/v2/version_file/{beta_file_hash}")); + ScopeTest::new(&test_env) + .with_failure_code(404) + .test(req_gen, read_version) + .await + .unwrap(); + + let req_gen = + || test::TestRequest::get().uri(&format!("/v2/version_file/{beta_file_hash}/download")); + ScopeTest::new(&test_env) + .with_failure_code(404) + .test(req_gen, read_version) + .await + .unwrap(); + + // TODO: Should this be /POST? 
Looks like /GET + // TODO: this scope doesn't actually affect anything, because the Project::get_id contained within disallows hidden versions, which is the point of this scope + // let req_gen = || { + // test::TestRequest::post() + // .uri(&format!("/v2/version_file/{beta_file_hash}/update")) + // .set_json(json!({})) + // }; + // ScopeTest::new(&test_env).with_failure_code(404).test(req_gen, read_version).await.unwrap(); + + // TODO: Should this be /POST? Looks like /GET + let req_gen = || { + test::TestRequest::post() + .uri("/v2/version_files") + .set_json(json!({ + "hashes": [beta_file_hash] + })) + }; + let (failure, success) = ScopeTest::new(&test_env) + .with_failure_code(200) + .test(req_gen, read_version) + .await + .unwrap(); + assert!(!failure.as_object().unwrap().contains_key(beta_file_hash)); + assert!(success.as_object().unwrap().contains_key(beta_file_hash)); + + // Update version file + // TODO: Should this be /POST? Looks like /GET + // TODO: this scope doesn't actually affect anything, because the Project::get_id contained within disallows hidden versions, which is the point of this scope + + // let req_gen = || { + // test::TestRequest::post() + // .uri(&format!("/v2/version_files/update_individual")) + // .set_json(json!({ + // "hashes": [{ + // "hash": beta_file_hash, + // }] + // })) + // }; + // let (failure, success) = ScopeTest::new(&test_env).with_failure_code(200).test(req_gen, read_version).await.unwrap(); + // assert!(!failure.as_object().unwrap().contains_key(beta_file_hash)); + // assert!(success.as_object().unwrap().contains_key(beta_file_hash)); + + // Update version file + // TODO: this scope doesn't actually affect anything, because the Project::get_id contained within disallows hidden versions, which is the point of this scope + // let req_gen = || { + // test::TestRequest::post() + // .uri(&format!("/v2/version_files/update")) + // .set_json(json!({ + // "hashes": [beta_file_hash] + // })) + // }; + // let (failure, success) = ScopeTest::new(&test_env).with_failure_code(200).test(req_gen, read_version).await.unwrap(); + // assert!(!failure.as_object().unwrap().contains_key(beta_file_hash)); + // assert!(success.as_object().unwrap().contains_key(beta_file_hash)); + + // Both project and version reading + let read_project_and_version = Scopes::PROJECT_READ | Scopes::VERSION_READ; + let req_gen = + || test::TestRequest::get().uri(&format!("/v2/project/{beta_project_id}/version")); + ScopeTest::new(&test_env) + .with_failure_code(404) + .test(req_gen, read_project_and_version) + .await + .unwrap(); + + // TODO: fails for the same reason as above + // let req_gen = || { + // test::TestRequest::get() + // .uri(&format!("/v2/project/{beta_project_id}/version/{beta_version_id}")) + // }; + // ScopeTest::new(&test_env).with_failure_code(404).test(req_gen, read_project_and_version).await.unwrap(); + + // Cleanup test db + test_env.cleanup().await; +} + +// Project writing +#[actix_rt::test] +pub async fn project_write_scopes() { + // Test setup and dummy data + let test_env = TestEnvironment::build_with_dummy().await; + let beta_project_id = &test_env.dummy.as_ref().unwrap().beta_project_id.clone(); + let alpha_team_id = &test_env.dummy.as_ref().unwrap().alpha_team_id.clone(); + + // Projects writing + let write_project = Scopes::PROJECT_WRITE; + let req_gen = || { + test::TestRequest::patch() + .uri(&format!("/v2/project/{beta_project_id}")) + .set_json(json!( + { + "title": "test_project_version_write_scopes Title", + } + )) + }; + ScopeTest::new(&test_env) + 
.test(req_gen, write_project) + .await + .unwrap(); + + let req_gen = || { + test::TestRequest::patch() + .uri(&format!( + "/v2/projects?ids=[{uri}]", + uri = urlencoding::encode(&format!("\"{beta_project_id}\"")) + )) + .set_json(json!( + { + "description": "test_project_version_write_scopes Description", + } + )) + }; + ScopeTest::new(&test_env) + .test(req_gen, write_project) + .await + .unwrap(); + + // Approve beta as private so we can schedule it + let req = test::TestRequest::patch() + .uri(&format!("/v2/project/{beta_project_id}")) + .append_header(("Authorization", MOD_USER_PAT)) + .set_json(json!({ + "status": "private" + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 204); + + let req_gen = || { + test::TestRequest::post() + .uri(&format!("/v2/project/{beta_project_id}/schedule")) // beta_project_id is now private, so we can schedule it + .set_json(json!( + { + "requested_status": "private", + "time": Utc::now() + Duration::days(1), + } + )) + }; + ScopeTest::new(&test_env) + .test(req_gen, write_project) + .await + .unwrap(); + + // Icons and gallery images + let req_gen = || { + test::TestRequest::patch() + .uri(&format!("/v2/project/{beta_project_id}/icon?ext=png")) + .set_payload(Bytes::from( + include_bytes!("../tests/files/200x200.png") as &[u8] + )) + }; + ScopeTest::new(&test_env) + .test(req_gen, write_project) + .await + .unwrap(); + + let req_gen = + || test::TestRequest::delete().uri(&format!("/v2/project/{beta_project_id}/icon")); + ScopeTest::new(&test_env) + .test(req_gen, write_project) + .await + .unwrap(); + + let req_gen = || { + test::TestRequest::post() + .uri(&format!( + "/v2/project/{beta_project_id}/gallery?ext=png&featured=true" + )) + .set_payload(Bytes::from( + include_bytes!("../tests/files/200x200.png") as &[u8] + )) + }; + ScopeTest::new(&test_env) + .test(req_gen, write_project) + .await + .unwrap(); + + // Get project, as we need the gallery image url + let req_gen = test::TestRequest::get() + .uri(&format!("/v2/project/{beta_project_id}")) + .append_header(("Authorization", USER_USER_PAT)) + .to_request(); + let resp = test_env.call(req_gen).await; + let project: serde_json::Value = test::read_body_json(resp).await; + let gallery_url = project["gallery"][0]["url"].as_str().unwrap(); + + let req_gen = || { + test::TestRequest::patch().uri(&format!( + "/v2/project/{beta_project_id}/gallery?url={gallery_url}" + )) + }; + ScopeTest::new(&test_env) + .test(req_gen, write_project) + .await + .unwrap(); + + let req_gen = || { + test::TestRequest::delete().uri(&format!( + "/v2/project/{beta_project_id}/gallery?url={gallery_url}" + )) + }; + ScopeTest::new(&test_env) + .test(req_gen, write_project) + .await + .unwrap(); + + // Team scopes - add user 'friend' + let req_gen = || { + test::TestRequest::post() + .uri(&format!("/v2/team/{alpha_team_id}/members")) + .set_json(json!({ + "user_id": FRIEND_USER_ID + })) + }; + ScopeTest::new(&test_env) + .test(req_gen, write_project) + .await + .unwrap(); + + // Accept team invite as 'friend' + let req_gen = || test::TestRequest::post().uri(&format!("/v2/team/{alpha_team_id}/join")); + ScopeTest::new(&test_env) + .with_user_id(FRIEND_USER_ID_PARSED) + .test(req_gen, write_project) + .await + .unwrap(); + + // Patch 'friend' team member permissions + let req_gen = || { + test::TestRequest::patch() + .uri(&format!( + "/v2/team/{alpha_team_id}/members/{FRIEND_USER_ID}" + )) + .set_json(json!({ + "permissions": 1 + })) + }; + ScopeTest::new(&test_env) + .test(req_gen, write_project) + .await + 
.unwrap(); + + // Transfer ownership to 'friend' + let req_gen = || { + test::TestRequest::patch() + .uri(&format!("/v2/team/{alpha_team_id}/owner")) + .set_json(json!({ + "user_id": FRIEND_USER_ID + })) + }; + ScopeTest::new(&test_env) + .test(req_gen, write_project) + .await + .unwrap(); + + // Now as 'friend', delete 'user' + let req_gen = || { + test::TestRequest::delete().uri(&format!("/v2/team/{alpha_team_id}/members/{USER_USER_ID}")) + }; + ScopeTest::new(&test_env) + .with_user_id(FRIEND_USER_ID_PARSED) + .test(req_gen, write_project) + .await + .unwrap(); + + // Delete project + // TODO: this route is currently broken, + // because the Project::get_id contained within Project::remove doesn't include hidden versions, meaning that if there + // is a hidden version, it will fail to delete the project (with a 500 error, as the versions of a project are not all deleted) + // let delete_version = Scopes::PROJECT_DELETE; + // let req_gen = || { + //     test::TestRequest::delete() + //         .uri(&format!("/v2/project/{beta_project_id}")) + // }; + // ScopeTest::new(&test_env).test(req_gen, delete_version).await.unwrap(); + + // Cleanup test db + test_env.cleanup().await; +} + +// Version write +#[actix_rt::test] +pub async fn version_write_scopes() { + // Test setup and dummy data + let test_env = TestEnvironment::build_with_dummy().await; + let alpha_version_id = &test_env.dummy.as_ref().unwrap().alpha_version_id.clone(); + let beta_version_id = &test_env.dummy.as_ref().unwrap().beta_version_id.clone(); + let alpha_file_hash = &test_env.dummy.as_ref().unwrap().alpha_file_hash.clone(); + + let write_version = Scopes::VERSION_WRITE; + + // Set the beta version to unlisted so we can schedule it + let req = test::TestRequest::patch() + .uri(&format!("/v2/version/{beta_version_id}")) + .append_header(("Authorization", MOD_USER_PAT)) + .set_json(json!({ + "status": "unlisted" + })) + .to_request(); + let resp = test_env.call(req).await; + assert_eq!(resp.status(), 204); + + // Schedule version + let req_gen = || { + test::TestRequest::post() + .uri(&format!("/v2/version/{beta_version_id}/schedule")) // beta_version_id is an *approved* version, so we can schedule it + .set_json(json!( + { + "requested_status": "archived", + "time": Utc::now() + Duration::days(1), + } + )) + }; + ScopeTest::new(&test_env) + .test(req_gen, write_version) + .await + .unwrap(); + + // Patch version + let req_gen = || { + test::TestRequest::patch() + .uri(&format!("/v2/version/{alpha_version_id}")) + .set_json(json!( + { + "version_title": "test_version_write_scopes Title", + } + )) + }; + ScopeTest::new(&test_env) + .test(req_gen, write_version) + .await + .unwrap(); + + // Generate test file upload data.
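+    // Adding a file to an existing version reuses the multipart convention from
+    // project creation: a JSON "data" segment (here just a "file_types" map of
+    // segment name to file type), followed by the matching binary segment.
+    // A sketch of the payload assembled below:
+    //
+    //     data:           { "file_types": { "simple-zip.zip": "required-resource-pack" } }
+    //     simple-zip.zip: <application/zip bytes>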
+ // Basic json + let json_segment = common::actix::MultipartSegment { + name: "data".to_string(), + filename: None, + content_type: Some("application/json".to_string()), + data: common::actix::MultipartSegmentData::Text( + serde_json::to_string(&json!( + { + "file_types": { + "simple-zip.zip": "required-resource-pack" + }, + } + )) + .unwrap(), + ), + }; + + // Zip file to upload as an additional version file + let content_segment = common::actix::MultipartSegment { + name: "simple-zip.zip".to_string(), + filename: Some("simple-zip.zip".to_string()), + content_type: Some("application/zip".to_string()), + data: common::actix::MultipartSegmentData::Binary( + include_bytes!("../tests/files/simple-zip.zip").to_vec(), + ), + }; + + // Upload version file + let req_gen = || { + test::TestRequest::post() + .uri(&format!("/v2/version/{alpha_version_id}/file")) + .set_multipart(vec![json_segment.clone(), content_segment.clone()]) + }; + ScopeTest::new(&test_env) + .test(req_gen, write_version) + .await + .unwrap(); + + // Delete version file + // TODO: Should this scope be VERSION_DELETE? + let req_gen = || { + test::TestRequest::delete().uri(&format!("/v2/version_file/{alpha_file_hash}")) + // Delete from alpha_version_id: we just uploaded a second file to it, so it can spare this one + }; + ScopeTest::new(&test_env) + .test(req_gen, write_version) + .await + .unwrap(); + + // Delete version + let delete_version = Scopes::VERSION_DELETE; + let req_gen = || test::TestRequest::delete().uri(&format!("/v2/version/{alpha_version_id}")); + ScopeTest::new(&test_env) + .test(req_gen, delete_version) + .await + .unwrap(); + + // Cleanup test db + test_env.cleanup().await; +} + +// Report scopes +#[actix_rt::test] +pub async fn report_scopes() { + // Test setup and dummy data + let test_env = TestEnvironment::build_with_dummy().await; + let beta_project_id = &test_env.dummy.as_ref().unwrap().beta_project_id.clone(); + + // Create report + let report_create = Scopes::REPORT_CREATE; + let req_gen = || { + test::TestRequest::post().uri("/v2/report").set_json(json!({ + "report_type": "copyright", + "item_id": beta_project_id, + "item_type": "project", + "body": "This is a reupload of my mod, ", + })) + }; + ScopeTest::new(&test_env) + .test(req_gen, report_create) + .await + .unwrap(); + + // Get reports + let report_read = Scopes::REPORT_READ; + let req_gen = || test::TestRequest::get().uri("/v2/report"); + let (_, success) = ScopeTest::new(&test_env) + .test(req_gen, report_read) + .await + .unwrap(); + let report_id = success.as_array().unwrap()[0]["id"].as_str().unwrap(); + + let req_gen = || test::TestRequest::get().uri(&format!("/v2/report/{}", report_id)); + ScopeTest::new(&test_env) + .test(req_gen, report_read) + .await + .unwrap(); + + let req_gen = || { + test::TestRequest::get().uri(&format!( + "/v2/reports?ids=[{}]", + urlencoding::encode(&format!("\"{}\"", report_id)) + )) + }; + ScopeTest::new(&test_env) + .test(req_gen, report_read) + .await + .unwrap(); + + // Edit report + let report_edit = Scopes::REPORT_WRITE; + let req_gen = || { + test::TestRequest::patch() + .uri(&format!("/v2/report/{}", report_id)) + .set_json(json!({ + "body": "This is a reupload of my mod, G8!", + })) + }; + ScopeTest::new(&test_env) + .test(req_gen, report_edit) + .await + .unwrap(); + + // Delete report + // We use a moderator PAT here, as only moderators can delete reports + let report_delete = Scopes::REPORT_DELETE; + let req_gen = || test::TestRequest::delete().uri(&format!("/v2/report/{}", report_id)); + 
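+    // (ScopeTest mints a fresh PAT per run via create_test_pat, which writes
+    // straight to the database, so this check can carry scopes that the public
+    // PAT-creation route would reject as restricted.)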
ScopeTest::new(&test_env) + .with_user_id(MOD_USER_ID_PARSED) + .test(req_gen, report_delete) + .await + .unwrap(); + + // Cleanup test db + test_env.cleanup().await; +} + +// Thread scopes +#[actix_rt::test] +pub async fn thread_scopes() { + // Test setup and dummy data + let test_env = TestEnvironment::build_with_dummy().await; + let alpha_thread_id = &test_env.dummy.as_ref().unwrap().alpha_thread_id.clone(); + let beta_thread_id = &test_env.dummy.as_ref().unwrap().beta_thread_id.clone(); + + // Thread read + let thread_read = Scopes::THREAD_READ; + let req_gen = || test::TestRequest::get().uri(&format!("/v2/thread/{alpha_thread_id}")); + ScopeTest::new(&test_env) + .test(req_gen, thread_read) + .await + .unwrap(); + + let req_gen = || { + test::TestRequest::get().uri(&format!( + "/v2/threads?ids=[{}]", + urlencoding::encode(&format!("\"{}\"", "U")) + )) + }; + ScopeTest::new(&test_env) + .test(req_gen, thread_read) + .await + .unwrap(); + + // Thread write (to also push to moderator inbox) + let thread_write = Scopes::THREAD_WRITE; + let req_gen = || { + test::TestRequest::post() + .uri(&format!("/v2/thread/{beta_thread_id}")) + .set_json(json!({ + "body": { + "type": "text", + "body": "test_thread_scopes Body" + } + })) + }; + ScopeTest::new(&test_env) + .with_user_id(USER_USER_ID_PARSED) + .test(req_gen, thread_write) + .await + .unwrap(); + + // Check moderation inbox + // Uses moderator PAT, as only moderators can see the moderation inbox + let req_gen = || test::TestRequest::get().uri("/v2/thread/inbox"); + let (_, success) = ScopeTest::new(&test_env) + .with_user_id(MOD_USER_ID_PARSED) + .test(req_gen, thread_read) + .await + .unwrap(); + let thread = success.as_array().unwrap()[0].as_object().unwrap(); + let thread_id = thread["id"].as_str().unwrap(); + + // Moderator 'read' thread + // Uses moderator PAT, as only moderators can see the moderation inbox + let req_gen = || test::TestRequest::post().uri(&format!("/v2/thread/{thread_id}/read")); + ScopeTest::new(&test_env) + .with_user_id(MOD_USER_ID_PARSED) + .test(req_gen, thread_read) + .await + .unwrap(); + + // Delete that message + // First, get message id + let req_gen = test::TestRequest::get() + .uri(&format!("/v2/thread/{thread_id}")) + .append_header(("Authorization", USER_USER_PAT)) + .to_request(); + let resp = test_env.call(req_gen).await; + let success: serde_json::Value = test::read_body_json(resp).await; + let thread_messages = success.as_object().unwrap()["messages"].as_array().unwrap(); + let thread_message_id = thread_messages[0].as_object().unwrap()["id"] + .as_str() + .unwrap(); + let req_gen = || test::TestRequest::delete().uri(&format!("/v2/message/{thread_message_id}")); + ScopeTest::new(&test_env) + .with_user_id(MOD_USER_ID_PARSED) + .test(req_gen, thread_write) + .await + .unwrap(); + + // Cleanup test db + test_env.cleanup().await; +} + +// Pat scopes +#[actix_rt::test] +pub async fn pat_scopes() { + let test_env = TestEnvironment::build_with_dummy().await; + + // Pat create + let pat_create = Scopes::PAT_CREATE; + let req_gen = || { + test::TestRequest::post().uri("/v2/pat").set_json(json!({ + "scopes": 1, + "name": "test_pat_scopes Name", + "expires": Utc::now() + Duration::days(1), + })) + }; + let (_, success) = ScopeTest::new(&test_env) + .test(req_gen, pat_create) + .await + .unwrap(); + let pat_id = success["id"].as_str().unwrap(); + + // Pat write + let pat_write = Scopes::PAT_WRITE; + let req_gen = || { + test::TestRequest::patch() + .uri(&format!("/v2/pat/{pat_id}")) + .set_json(json!({})) + }; 
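+    // (The body here is an empty JSON object - a no-op edit - which is enough
+    // to show that PAT_WRITE is required even for an edit that changes nothing.)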
+    // Thread read
+    let thread_read = Scopes::THREAD_READ;
+    let req_gen = || test::TestRequest::get().uri(&format!("/v2/thread/{alpha_thread_id}"));
+    ScopeTest::new(&test_env)
+        .test(req_gen, thread_read)
+        .await
+        .unwrap();
+
+    let req_gen = || {
+        test::TestRequest::get().uri(&format!(
+            "/v2/threads?ids=[{}]",
+            urlencoding::encode(&format!("\"{}\"", "U"))
+        ))
+    };
+    ScopeTest::new(&test_env)
+        .test(req_gen, thread_read)
+        .await
+        .unwrap();
+
+    // Thread write (to also push to moderator inbox)
+    let thread_write = Scopes::THREAD_WRITE;
+    let req_gen = || {
+        test::TestRequest::post()
+            .uri(&format!("/v2/thread/{beta_thread_id}"))
+            .set_json(json!({
+                "body": {
+                    "type": "text",
+                    "body": "test_thread_scopes Body"
+                }
+            }))
+    };
+    ScopeTest::new(&test_env)
+        .with_user_id(USER_USER_ID_PARSED)
+        .test(req_gen, thread_write)
+        .await
+        .unwrap();
+
+    // Check moderation inbox
+    // Uses moderator PAT, as only moderators can see the moderation inbox
+    let req_gen = || test::TestRequest::get().uri("/v2/thread/inbox");
+    let (_, success) = ScopeTest::new(&test_env)
+        .with_user_id(MOD_USER_ID_PARSED)
+        .test(req_gen, thread_read)
+        .await
+        .unwrap();
+    let thread = success.as_array().unwrap()[0].as_object().unwrap();
+    let thread_id = thread["id"].as_str().unwrap();
+
+    // Moderator 'read' thread
+    // Uses moderator PAT, as only moderators can act on the moderation inbox
+    let req_gen = || test::TestRequest::post().uri(&format!("/v2/thread/{thread_id}/read"));
+    ScopeTest::new(&test_env)
+        .with_user_id(MOD_USER_ID_PARSED)
+        .test(req_gen, thread_read)
+        .await
+        .unwrap();
+
+    // Delete that message
+    // First, get the message id
+    let req_gen = test::TestRequest::get()
+        .uri(&format!("/v2/thread/{thread_id}"))
+        .append_header(("Authorization", USER_USER_PAT))
+        .to_request();
+    let resp = test_env.call(req_gen).await;
+    let success: serde_json::Value = test::read_body_json(resp).await;
+    let thread_messages = success.as_object().unwrap()["messages"].as_array().unwrap();
+    let thread_message_id = thread_messages[0].as_object().unwrap()["id"]
+        .as_str()
+        .unwrap();
+    let req_gen = || test::TestRequest::delete().uri(&format!("/v2/message/{thread_message_id}"));
+    ScopeTest::new(&test_env)
+        .with_user_id(MOD_USER_ID_PARSED)
+        .test(req_gen, thread_write)
+        .await
+        .unwrap();
+
+    // Cleanup test db
+    test_env.cleanup().await;
+}
+
+// PAT scopes
+#[actix_rt::test]
+pub async fn pat_scopes() {
+    let test_env = TestEnvironment::build_with_dummy().await;
+
+    // PAT create
+    let pat_create = Scopes::PAT_CREATE;
+    let req_gen = || {
+        test::TestRequest::post().uri("/v2/pat").set_json(json!({
+            "scopes": 1,
+            "name": "test_pat_scopes Name",
+            "expires": Utc::now() + Duration::days(1),
+        }))
+    };
+    let (_, success) = ScopeTest::new(&test_env)
+        .test(req_gen, pat_create)
+        .await
+        .unwrap();
+    let pat_id = success["id"].as_str().unwrap();
+
+    // PAT write
+    let pat_write = Scopes::PAT_WRITE;
+    let req_gen = || {
+        test::TestRequest::patch()
+            .uri(&format!("/v2/pat/{pat_id}"))
+            .set_json(json!({}))
+    };
+    ScopeTest::new(&test_env)
+        .test(req_gen, pat_write)
+        .await
+        .unwrap();
+
+    // PAT read
+    let pat_read = Scopes::PAT_READ;
+    let req_gen = || test::TestRequest::get().uri("/v2/pat");
+    ScopeTest::new(&test_env)
+        .test(req_gen, pat_read)
+        .await
+        .unwrap();
+
+    // PAT delete
+    let pat_delete = Scopes::PAT_DELETE;
+    let req_gen = || test::TestRequest::delete().uri(&format!("/v2/pat/{pat_id}"));
+    ScopeTest::new(&test_env)
+        .test(req_gen, pat_delete)
+        .await
+        .unwrap();
+
+    // Cleanup test db
+    test_env.cleanup().await;
+}
+
+// Collection scopes
+#[actix_rt::test]
+pub async fn collections_scopes() {
+    // Test setup and dummy data
+    let test_env = TestEnvironment::build_with_dummy().await;
+    let alpha_project_id = &test_env.dummy.as_ref().unwrap().alpha_project_id.clone();
+
+    // Create collection
+    let collection_create = Scopes::COLLECTION_CREATE;
+    let req_gen = || {
+        test::TestRequest::post()
+            .uri("/v2/collection")
+            .set_json(json!({
+                "title": "Test Collection",
+                "description": "Test Collection Description",
+                "projects": [alpha_project_id]
+            }))
+    };
+    let (_, success) = ScopeTest::new(&test_env)
+        .test(req_gen, collection_create)
+        .await
+        .unwrap();
+    let collection_id = success["id"].as_str().unwrap();
+
+    // Patch collection
+    // Collections always initialize to public, so we patch it to private
+    // before testing the GET routes
+    let collection_write = Scopes::COLLECTION_WRITE;
+    let req_gen = || {
+        test::TestRequest::patch()
+            .uri(&format!("/v2/collection/{collection_id}"))
+            .set_json(json!({
+                "title": "Test Collection patch",
+                "status": "private",
+            }))
+    };
+    ScopeTest::new(&test_env)
+        .test(req_gen, collection_write)
+        .await
+        .unwrap();
+
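+    // The collection is now private, so requests without COLLECTION_READ are
+    // expected to 404 (rather than 401) on the direct GET -- presumably so the
+    // collection's existence is not leaked -- and to yield empty result sets
+    // on the list endpoints below.
+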
+    // Read collection
+    let collection_read = Scopes::COLLECTION_READ;
+    let req_gen = || test::TestRequest::get().uri(&format!("/v2/collection/{}", collection_id));
+    ScopeTest::new(&test_env)
+        .with_failure_code(404)
+        .test(req_gen, collection_read)
+        .await
+        .unwrap();
+
+    let req_gen = || {
+        test::TestRequest::get().uri(&format!(
+            "/v2/collections?ids=[{}]",
+            urlencoding::encode(&format!("\"{}\"", collection_id))
+        ))
+    };
+    let (failure, success) = ScopeTest::new(&test_env)
+        .with_failure_code(200)
+        .test(req_gen, collection_read)
+        .await
+        .unwrap();
+    assert_eq!(failure.as_array().unwrap().len(), 0);
+    assert_eq!(success.as_array().unwrap().len(), 1);
+
+    let req_gen = || test::TestRequest::get().uri(&format!("/v2/user/{USER_USER_ID}/collections"));
+    let (failure, success) = ScopeTest::new(&test_env)
+        .with_failure_code(200)
+        .test(req_gen, collection_read)
+        .await
+        .unwrap();
+    assert_eq!(failure.as_array().unwrap().len(), 0);
+    assert_eq!(success.as_array().unwrap().len(), 1);
+
+    // Icon upload and removal both require COLLECTION_WRITE
+    let req_gen = || {
+        test::TestRequest::patch()
+            .uri(&format!("/v2/collection/{collection_id}/icon?ext=png"))
+            .set_payload(Bytes::from(
+                include_bytes!("../tests/files/200x200.png") as &[u8]
+            ))
+    };
+    ScopeTest::new(&test_env)
+        .test(req_gen, collection_write)
+        .await
+        .unwrap();
+
+    let req_gen =
+        || test::TestRequest::delete().uri(&format!("/v2/collection/{collection_id}/icon"));
+    ScopeTest::new(&test_env)
+        .test(req_gen, collection_write)
+        .await
+        .unwrap();
+
+    // Cleanup test db
+    test_env.cleanup().await;
+}
+
+// Organization scopes (plus a couple of PROJECT_WRITE routes that are only
+// allowed for organizations)
+#[actix_rt::test]
+pub async fn organization_scopes() {
+    // Test setup and dummy data
+    let test_env = TestEnvironment::build_with_dummy().await;
+    let beta_project_id = &test_env.dummy.as_ref().unwrap().beta_project_id.clone();
+
+    // Create organization
+    let organization_create = Scopes::ORGANIZATION_CREATE;
+    let req_gen = || {
+        test::TestRequest::post()
+            .uri("/v2/organization")
+            .set_json(json!({
+                "title": "TestOrg",
+                "description": "TestOrg Description",
+            }))
+    };
+    let (_, success) = ScopeTest::new(&test_env)
+        .test(req_gen, organization_create)
+        .await
+        .unwrap();
+    let organization_id = success["id"].as_str().unwrap();
+
+    // Patch organization
+    let organization_edit = Scopes::ORGANIZATION_WRITE;
+    let req_gen = || {
+        test::TestRequest::patch()
+            .uri(&format!("/v2/organization/{organization_id}"))
+            .set_json(json!({
+                "description": "TestOrg Patch Description",
+            }))
+    };
+    ScopeTest::new(&test_env)
+        .test(req_gen, organization_edit)
+        .await
+        .unwrap();
+
+    // Icon upload and removal both require ORGANIZATION_WRITE
+    let req_gen = || {
+        test::TestRequest::patch()
+            .uri(&format!("/v2/organization/{organization_id}/icon?ext=png"))
+            .set_payload(Bytes::from(
+                include_bytes!("../tests/files/200x200.png") as &[u8]
+            ))
+    };
+    ScopeTest::new(&test_env)
+        .test(req_gen, organization_edit)
+        .await
+        .unwrap();
+
+    let req_gen =
+        || test::TestRequest::delete().uri(&format!("/v2/organization/{organization_id}/icon"));
+    ScopeTest::new(&test_env)
+        .test(req_gen, organization_edit)
+        .await
+        .unwrap();
+
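+    // Adding a project to an organization requires both PROJECT_WRITE and
+    // ORGANIZATION_WRITE. The failure PAT below is therefore granted every
+    // scope *except* ORGANIZATION_WRITE, isolating the organization scope as
+    // the one under test:
+    //
+    //     Scopes::all() ^ Scopes::ORGANIZATION_WRITE
+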
+    // Add project
+    let organization_project_edit = Scopes::PROJECT_WRITE | Scopes::ORGANIZATION_WRITE;
+    let req_gen = || {
+        test::TestRequest::post()
+            .uri(&format!("/v2/organization/{organization_id}/projects"))
+            .set_json(json!({
+                "project_id": beta_project_id
+            }))
+    };
+    ScopeTest::new(&test_env)
+        .with_failure_scopes(Scopes::all() ^ Scopes::ORGANIZATION_WRITE)
+        .test(req_gen, organization_project_edit)
+        .await
+        .unwrap();
+
+    // Organization reads
+    // Without ORGANIZATION_READ the request still succeeds, but the members'
+    // `permissions` field is redacted (null)
+    let organization_read = Scopes::ORGANIZATION_READ;
+    let req_gen = || test::TestRequest::get().uri(&format!("/v2/organization/{organization_id}"));
+    let (failure, success) = ScopeTest::new(&test_env)
+        .with_failure_code(200)
+        .test(req_gen, organization_read)
+        .await
+        .unwrap();
+    assert!(
+        failure.as_object().unwrap()["members"].as_array().unwrap()[0]
+            .as_object()
+            .unwrap()["permissions"]
+            .is_null()
+    );
+    assert!(
+        !success.as_object().unwrap()["members"].as_array().unwrap()[0]
+            .as_object()
+            .unwrap()["permissions"]
+            .is_null()
+    );
+
+    let req_gen = || {
+        test::TestRequest::get().uri(&format!(
+            "/v2/organizations?ids=[{}]",
+            urlencoding::encode(&format!("\"{}\"", organization_id))
+        ))
+    };
+    let (failure, success) = ScopeTest::new(&test_env)
+        .with_failure_code(200)
+        .test(req_gen, organization_read)
+        .await
+        .unwrap();
+    assert!(
+        failure.as_array().unwrap()[0].as_object().unwrap()["members"]
+            .as_array()
+            .unwrap()[0]
+            .as_object()
+            .unwrap()["permissions"]
+            .is_null()
+    );
+    assert!(
+        !success.as_array().unwrap()[0].as_object().unwrap()["members"]
+            .as_array()
+            .unwrap()[0]
+            .as_object()
+            .unwrap()["permissions"]
+            .is_null()
+    );
+
+    let organization_project_read = Scopes::PROJECT_READ | Scopes::ORGANIZATION_READ;
+    let req_gen =
+        || test::TestRequest::get().uri(&format!("/v2/organization/{organization_id}/projects"));
+    let (failure, success) = ScopeTest::new(&test_env)
+        .with_failure_code(200)
+        .with_failure_scopes(Scopes::all() ^ Scopes::ORGANIZATION_READ)
+        .test(req_gen, organization_project_read)
+        .await
+        .unwrap();
+    assert!(failure.as_array().unwrap().is_empty());
+    assert!(!success.as_array().unwrap().is_empty());
+
+    // Remove project (now that we've checked the reads)
+    let req_gen = || {
+        test::TestRequest::delete().uri(&format!(
+            "/v2/organization/{organization_id}/projects/{beta_project_id}"
+        ))
+    };
+    ScopeTest::new(&test_env)
+        .with_failure_scopes(Scopes::all() ^ Scopes::ORGANIZATION_WRITE)
+        .test(req_gen, organization_project_edit)
+        .await
+        .unwrap();
+
+    // Delete organization
+    let organization_delete = Scopes::ORGANIZATION_DELETE;
+    let req_gen =
+        || test::TestRequest::delete().uri(&format!("/v2/organization/{organization_id}"));
+    ScopeTest::new(&test_env)
+        .test(req_gen, organization_delete)
+        .await
+        .unwrap();
+
+    // Cleanup test db
+    test_env.cleanup().await;
+}
+
+// TODO: Analytics scopes
+
+// TODO: User authentication and session scopes
+
+// TODO: Some hash/version-file functions
+
+// TODO: Meta PAT routes
+
+// TODO: Image scopes
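+
+// Each of the above would be expected to follow the same ScopeTest pattern;
+// a sketch (the route and scope here are illustrative assumptions, not final):
+//
+//     let req_gen = || test::TestRequest::get().uri("/v2/analytics/revenue");
+//     ScopeTest::new(&test_env)
+//         .test(req_gen, Scopes::ANALYTICS)
+//         .await
+//         .unwrap();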