diff --git a/.codecov.yml b/.codecov.yml
new file mode 100644
index 00000000..3b187dd0
--- /dev/null
+++ b/.codecov.yml
@@ -0,0 +1,13 @@
+comment: false
+
+coverage:
+  status:
+    project:
+      default:
+        threshold: 60% # make CI green
+    patch:
+      default:
+        threshold: 60% # make CI green
+
+ignore: # ignore code coverage on following paths
+  - "**/tests"
\ No newline at end of file
diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml
new file mode 100644
index 00000000..6e67a3ac
--- /dev/null
+++ b/.github/workflows/coverage.yml
@@ -0,0 +1,44 @@
+name: Coverage-Tarpaulin
+
+env:
+  CARGO_TERM_COLOR: always
+  SQLX_OFFLINE: true
+
+on:
+  push:
+    branches: [ master ]
+  # Uncomment to allow PRs to trigger the workflow
+  # pull_request:
+  #   branches: [ master ]
+jobs:
+  citarp:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+
+      # Start Docker Compose
+      - name: Start Docker Compose
+        run: docker-compose up -d
+
+      - name: Install cargo tarpaulin
+        uses: taiki-e/install-action@cargo-tarpaulin
+      - name: Generate code coverage
+        run: |
+          cargo tarpaulin --verbose --all-features --timeout 120 --out xml
+        env:
+          BACKBLAZE_BUCKET_ID: ${{ secrets.BACKBLAZE_BUCKET_ID }}
+          BACKBLAZE_KEY: ${{ secrets.BACKBLAZE_KEY }}
+          BACKBLAZE_KEY_ID: ${{ secrets.BACKBLAZE_KEY_ID }}
+          S3_ACCESS_TOKEN: ${{ secrets.S3_ACCESS_TOKEN }}
+          S3_SECRET: ${{ secrets.S3_SECRET }}
+          S3_URL: ${{ secrets.S3_URL }}
+          S3_REGION: ${{ secrets.S3_REGION }}
+          S3_BUCKET_NAME: ${{ secrets.S3_BUCKET_NAME }}
+          SQLX_OFFLINE: true
+          DATABASE_URL: postgresql://labrinth:labrinth@localhost/postgres
+
+      - name: Upload to codecov.io
+        uses: codecov/codecov-action@v2
+        with:
+          # token: ${{secrets.CODECOV_TOKEN}} # not required for public repos
+          fail_ci_if_error: true
diff --git a/.gitignore b/.gitignore
index 0fcb7d0a..516893c6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+codecov.json
+
 # Created by https://www.gitignore.io/api/rust,clion
 # Edit at https://www.gitignore.io/?templates=rust,clion
 
diff --git a/Cargo.toml b/Cargo.toml
index bb38733c..040ed0c2 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -109,3 +109,9 @@ derive-new = "0.5.9"
 
 [dev-dependencies]
 actix-http = "3.4.0"
+
+[profile.dev]
+opt-level = 0 # Minimal optimization, speeds up compilation
+lto = false # Disables Link Time Optimization
+incremental = true # Enables incremental compilation
+codegen-units = 16 # Higher number can improve compile times but reduce runtime performance
diff --git a/migrations/20231122111700_adds_missing_loader_field_loaders.sql b/migrations/20231122111700_adds_missing_loader_field_loaders.sql
new file mode 100644
index 00000000..8747a2cb
--- /dev/null
+++ b/migrations/20231122111700_adds_missing_loader_field_loaders.sql
@@ -0,0 +1,45 @@
+
+-- Adds missing fields to loader_fields_loaders
+INSERT INTO loader_fields_loaders (loader_id, loader_field_id)
+SELECT l.id, lf.id FROM loaders l CROSS JOIN loader_fields lf WHERE lf.field = 'game_versions'
+AND l.loader = ANY( ARRAY['forge', 'fabric', 'quilt', 'modloader','rift','liteloader', 'neoforge'])
+ON CONFLICT (loader_id, loader_field_id) DO NOTHING;
+
+-- Fixes mrpack variants being added to the wrong enum
+-- Luckily, mrpack variants are the only ones set to 2 without metadata
+UPDATE loader_field_enum_values SET enum_id = 3 WHERE enum_id = 2 AND metadata IS NULL;
+
+-- Because it was mislabeled, version_fields for mrpack_loaders were set to null.
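+-- Illustration (assumed data): a modpack version whose mrpack_loaders value is
+-- 'fabric' currently has version_fields.enum_value = NULL. Step 1 below points it
+-- at the matching loader_field_enum_values row, step 2 attaches the 'mrpack'
+-- loader, and step 3 drops the stale loader links. After running the migration,
+-- this manual sanity check (not part of the migration) should return zero rows:
+--   SELECT vf.version_id
+--   FROM version_fields vf
+--   JOIN loader_fields lf ON lf.id = vf.field_id
+--   WHERE lf.field = 'mrpack_loaders' AND vf.enum_value IS NULL;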
+-- 1) Update version_fields corresponding to mrpack_loaders to the correct enum_value +UPDATE version_fields vf +SET enum_value = subquery.lfev_id +FROM ( + SELECT vf.version_id, vf.field_id, lfev.id AS lfev_id + FROM version_fields vf + LEFT JOIN versions v ON v.id = vf.version_id + LEFT JOIN loaders_versions lv ON v.id = lv.version_id + LEFT JOIN loaders l ON l.id = lv.loader_id + LEFT JOIN loader_fields lf ON lf.id = vf.field_id + LEFT JOIN loader_field_enum_values lfev ON lfev.value = l.loader AND lf.enum_type = lfev.enum_id + WHERE lf.field = 'mrpack_loaders' AND vf.enum_value IS NULL +) AS subquery +WHERE vf.version_id = subquery.version_id AND vf.field_id = subquery.field_id; + +-- 2) Set those versions to mrpack as their version +INSERT INTO loaders_versions (version_id, loader_id) +SELECT DISTINCT vf.version_id, l.id +FROM version_fields vf +LEFT JOIN loader_fields lf ON lf.id = vf.field_id +CROSS JOIN loaders l +WHERE lf.field = 'mrpack_loaders' +AND l.loader = 'mrpack' +ON CONFLICT DO NOTHING; + +-- 3) Delete the old versions that had mrpack added to them +DELETE FROM loaders_versions lv +WHERE lv.loader_id != (SELECT id FROM loaders WHERE loader = 'mrpack') +AND lv.version_id IN ( + SELECT version_id + FROM loaders_versions + WHERE loader_id = (SELECT id FROM loaders WHERE loader = 'mrpack') +); diff --git a/src/database/models/categories.rs b/src/database/models/categories.rs index 95d054f2..6205fab8 100644 --- a/src/database/models/categories.rs +++ b/src/database/models/categories.rs @@ -90,6 +90,8 @@ impl Category { where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { + let mut redis = redis.connect().await?; + let res: Option> = redis .get_deserialized_from_json(TAGS_NAMESPACE, "category") .await?; @@ -155,6 +157,8 @@ impl DonationPlatform { where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { + let mut redis = redis.connect().await?; + let res: Option> = redis .get_deserialized_from_json(TAGS_NAMESPACE, "donation_platform") .await?; @@ -209,6 +213,8 @@ impl ReportType { where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { + let mut redis = redis.connect().await?; + let res: Option> = redis .get_deserialized_from_json(TAGS_NAMESPACE, "report_type") .await?; @@ -257,6 +263,8 @@ impl ProjectType { where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { + let mut redis = redis.connect().await?; + let res: Option> = redis .get_deserialized_from_json(TAGS_NAMESPACE, "project_type") .await?; diff --git a/src/database/models/collection_item.rs b/src/database/models/collection_item.rs index d000e2ce..4a4f7424 100644 --- a/src/database/models/collection_item.rs +++ b/src/database/models/collection_item.rs @@ -157,6 +157,8 @@ impl Collection { { use futures::TryStreamExt; + let mut redis = redis.connect().await?; + if collection_ids.is_empty() { return Ok(Vec::new()); } @@ -166,7 +168,10 @@ impl Collection { if !collection_ids.is_empty() { let collections = redis - .multi_get::(COLLECTIONS_NAMESPACE, collection_ids.iter().map(|x| x.0)) + .multi_get::( + COLLECTIONS_NAMESPACE, + collection_ids.iter().map(|x| x.0.to_string()), + ) .await?; for collection in collections { @@ -240,6 +245,8 @@ impl Collection { } pub async fn clear_cache(id: CollectionId, redis: &RedisPool) -> Result<(), DatabaseError> { + let mut redis = redis.connect().await?; + redis.delete(COLLECTIONS_NAMESPACE, id.0).await?; Ok(()) } diff --git a/src/database/models/flow_item.rs b/src/database/models/flow_item.rs index fe81e4a8..22d30895 100644 --- a/src/database/models/flow_item.rs +++ 
b/src/database/models/flow_item.rs @@ -58,6 +58,8 @@ impl Flow { expires: Duration, redis: &RedisPool, ) -> Result { + let mut redis = redis.connect().await?; + let flow = ChaCha20Rng::from_entropy() .sample_iter(&Alphanumeric) .take(32) @@ -71,6 +73,8 @@ impl Flow { } pub async fn get(id: &str, redis: &RedisPool) -> Result, DatabaseError> { + let mut redis = redis.connect().await?; + redis.get_deserialized_from_json(FLOWS_NAMESPACE, id).await } @@ -91,6 +95,8 @@ impl Flow { } pub async fn remove(id: &str, redis: &RedisPool) -> Result, DatabaseError> { + let mut redis = redis.connect().await?; + redis.delete(FLOWS_NAMESPACE, id).await?; Ok(Some(())) } diff --git a/src/database/models/image_item.rs b/src/database/models/image_item.rs index 34badd65..68477304 100644 --- a/src/database/models/image_item.rs +++ b/src/database/models/image_item.rs @@ -180,6 +180,7 @@ impl Image { { use futures::TryStreamExt; + let mut redis = redis.connect().await?; if image_ids.is_empty() { return Ok(Vec::new()); } @@ -191,7 +192,7 @@ impl Image { if !image_ids.is_empty() { let images = redis - .multi_get::(IMAGES_NAMESPACE, image_ids) + .multi_get::(IMAGES_NAMESPACE, image_ids.iter().map(|x| x.to_string())) .await?; for image in images { if let Some(image) = image.and_then(|x| serde_json::from_str::(&x).ok()) { @@ -246,6 +247,8 @@ impl Image { } pub async fn clear_cache(id: ImageId, redis: &RedisPool) -> Result<(), DatabaseError> { + let mut redis = redis.connect().await?; + redis.delete(IMAGES_NAMESPACE, id.0).await?; Ok(()) } diff --git a/src/database/models/loader_fields.rs b/src/database/models/loader_fields.rs index 6fc06f72..87edf8cc 100644 --- a/src/database/models/loader_fields.rs +++ b/src/database/models/loader_fields.rs @@ -44,6 +44,7 @@ impl Game { where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { + let mut redis = redis.connect().await?; let cached_games: Option> = redis .get_deserialized_from_json(GAMES_LIST_NAMESPACE, "games") .await?; @@ -96,6 +97,7 @@ impl Loader { where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { + let mut redis = redis.connect().await?; let cached_id: Option = redis.get_deserialized_from_json(LOADER_ID, name).await?; if let Some(cached_id) = cached_id { return Ok(Some(LoaderId(cached_id))); @@ -125,6 +127,7 @@ impl Loader { where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { + let mut redis = redis.connect().await?; let cached_loaders: Option> = redis .get_deserialized_from_json(LOADERS_LIST_NAMESPACE, "all") .await?; @@ -321,9 +324,11 @@ impl LoaderField { { type RedisLoaderFieldTuple = (LoaderId, Vec); + let mut redis = redis.connect().await?; + let mut loader_ids = loader_ids.to_vec(); let cached_fields: Vec = redis - .multi_get::(LOADER_FIELDS_NAMESPACE, loader_ids.iter().map(|x| x.0)) + .multi_get::(LOADER_FIELDS_NAMESPACE, loader_ids.iter().map(|x| x.0)) .await? 
.into_iter() .flatten() @@ -402,6 +407,8 @@ impl LoaderFieldEnum { where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { + let mut redis = redis.connect().await?; + let cached_enum = redis .get_deserialized_from_json(LOADER_FIELD_ENUMS_ID_NAMESPACE, enum_name) .await?; @@ -491,12 +498,13 @@ impl LoaderFieldEnumValue { where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { + let mut redis = redis.connect().await?; let mut found_enums = Vec::new(); let mut remaining_enums: Vec = loader_field_enum_ids.to_vec(); if !remaining_enums.is_empty() { let enums = redis - .multi_get::( + .multi_get::( LOADER_FIELD_ENUM_VALUES_NAMESPACE, loader_field_enum_ids.iter().map(|x| x.0), ) diff --git a/src/database/models/notification_item.rs b/src/database/models/notification_item.rs index 2b15a4bd..2bc89fec 100644 --- a/src/database/models/notification_item.rs +++ b/src/database/models/notification_item.rs @@ -174,8 +174,10 @@ impl Notification { where E: sqlx::Executor<'a, Database = sqlx::Postgres> + Copy, { + let mut redis = redis.connect().await?; + let cached_notifications: Option> = redis - .get_deserialized_from_json(USER_NOTIFICATIONS_NAMESPACE, user_id.0) + .get_deserialized_from_json(USER_NOTIFICATIONS_NAMESPACE, &user_id.0.to_string()) .await?; if let Some(notifications) = cached_notifications { @@ -319,6 +321,8 @@ impl Notification { user_ids: impl IntoIterator, redis: &RedisPool, ) -> Result<(), DatabaseError> { + let mut redis = redis.connect().await?; + redis .delete_many( user_ids diff --git a/src/database/models/organization_item.rs b/src/database/models/organization_item.rs index f92622df..137d7ae0 100644 --- a/src/database/models/organization_item.rs +++ b/src/database/models/organization_item.rs @@ -103,6 +103,8 @@ impl Organization { { use futures::stream::TryStreamExt; + let mut redis = redis.connect().await?; + if organization_strings.is_empty() { return Ok(Vec::new()); } @@ -120,11 +122,12 @@ impl Organization { organization_ids.append( &mut redis - .multi_get::( + .multi_get::( ORGANIZATIONS_TITLES_NAMESPACE, organization_strings .iter() - .map(|x| x.to_string().to_lowercase()), + .map(|x| x.to_string().to_lowercase()) + .collect::>(), ) .await? 
.into_iter() @@ -134,7 +137,10 @@ impl Organization { if !organization_ids.is_empty() { let organizations = redis - .multi_get::(ORGANIZATIONS_NAMESPACE, organization_ids) + .multi_get::( + ORGANIZATIONS_NAMESPACE, + organization_ids.iter().map(|x| x.to_string()), + ) .await?; for organization in organizations { @@ -197,8 +203,8 @@ impl Organization { redis .set( ORGANIZATIONS_TITLES_NAMESPACE, - organization.title.to_lowercase(), - organization.id.0, + &organization.title.to_lowercase(), + &organization.id.0.to_string(), None, ) .await?; @@ -318,6 +324,8 @@ impl Organization { title: Option, redis: &RedisPool, ) -> Result<(), super::DatabaseError> { + let mut redis = redis.connect().await?; + redis .delete_many([ (ORGANIZATIONS_NAMESPACE, Some(id.0.to_string())), diff --git a/src/database/models/pat_item.rs b/src/database/models/pat_item.rs index fc2432ae..9352d637 100644 --- a/src/database/models/pat_item.rs +++ b/src/database/models/pat_item.rs @@ -89,6 +89,8 @@ impl PersonalAccessToken { { use futures::TryStreamExt; + let mut redis = redis.connect().await?; + if pat_strings.is_empty() { return Ok(Vec::new()); } @@ -106,7 +108,7 @@ impl PersonalAccessToken { pat_ids.append( &mut redis - .multi_get::( + .multi_get::( PATS_TOKENS_NAMESPACE, pat_strings.iter().map(|x| x.to_string()), ) @@ -118,7 +120,7 @@ impl PersonalAccessToken { if !pat_ids.is_empty() { let pats = redis - .multi_get::(PATS_NAMESPACE, pat_ids) + .multi_get::(PATS_NAMESPACE, pat_ids.iter().map(|x| x.to_string())) .await?; for pat in pats { if let Some(pat) = @@ -174,8 +176,8 @@ impl PersonalAccessToken { redis .set( PATS_TOKENS_NAMESPACE, - pat.access_token.clone(), - pat.id.0, + &pat.access_token, + &pat.id.0.to_string(), None, ) .await?; @@ -194,8 +196,10 @@ impl PersonalAccessToken { where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { + let mut redis = redis.connect().await?; + let res = redis - .get_deserialized_from_json::, _>(PATS_USERS_NAMESPACE, user_id.0) + .get_deserialized_from_json::>(PATS_USERS_NAMESPACE, &user_id.0.to_string()) .await?; if let Some(res) = res { @@ -220,8 +224,8 @@ impl PersonalAccessToken { redis .set( PATS_USERS_NAMESPACE, - user_id.0, - serde_json::to_string(&db_pats)?, + &user_id.0.to_string(), + &serde_json::to_string(&db_pats)?, None, ) .await?; @@ -232,6 +236,8 @@ impl PersonalAccessToken { clear_pats: Vec<(Option, Option, Option)>, redis: &RedisPool, ) -> Result<(), DatabaseError> { + let mut redis = redis.connect().await?; + if clear_pats.is_empty() { return Ok(()); } diff --git a/src/database/models/project_item.rs b/src/database/models/project_item.rs index 61dd2464..6be0f01b 100644 --- a/src/database/models/project_item.rs +++ b/src/database/models/project_item.rs @@ -513,6 +513,8 @@ impl Project { return Ok(Vec::new()); } + let mut redis = redis.connect().await?; + let mut found_projects = Vec::new(); let mut remaining_strings = project_strings .iter() @@ -526,7 +528,7 @@ impl Project { project_ids.append( &mut redis - .multi_get::( + .multi_get::( PROJECTS_SLUGS_NAMESPACE, project_strings.iter().map(|x| x.to_string().to_lowercase()), ) @@ -537,7 +539,10 @@ impl Project { ); if !project_ids.is_empty() { let projects = redis - .multi_get::(PROJECTS_NAMESPACE, project_ids) + .multi_get::( + PROJECTS_NAMESPACE, + project_ids.iter().map(|x| x.to_string()), + ) .await?; for project in projects { if let Some(project) = @@ -686,8 +691,8 @@ impl Project { redis .set( PROJECTS_SLUGS_NAMESPACE, - slug.to_lowercase(), - project.inner.id.0, + &slug.to_lowercase(), + 
&project.inner.id.0.to_string(), None, ) .await?; @@ -709,8 +714,13 @@ impl Project { { type Dependencies = Vec<(Option, Option, Option)>; + let mut redis = redis.connect().await?; + let dependencies = redis - .get_deserialized_from_json::(PROJECTS_DEPENDENCIES_NAMESPACE, id.0) + .get_deserialized_from_json::( + PROJECTS_DEPENDENCIES_NAMESPACE, + &id.0.to_string(), + ) .await?; if let Some(dependencies) = dependencies { return Ok(dependencies); @@ -755,6 +765,8 @@ impl Project { clear_dependencies: Option, redis: &RedisPool, ) -> Result<(), DatabaseError> { + let mut redis = redis.connect().await?; + redis .delete_many([ (PROJECTS_NAMESPACE, Some(id.0.to_string())), diff --git a/src/database/models/session_item.rs b/src/database/models/session_item.rs index ff9a874e..f27af5bb 100644 --- a/src/database/models/session_item.rs +++ b/src/database/models/session_item.rs @@ -130,6 +130,8 @@ impl Session { { use futures::TryStreamExt; + let mut redis = redis.connect().await?; + if session_strings.is_empty() { return Ok(Vec::new()); } @@ -147,7 +149,7 @@ impl Session { session_ids.append( &mut redis - .multi_get::( + .multi_get::( SESSIONS_IDS_NAMESPACE, session_strings.iter().map(|x| x.to_string()), ) @@ -159,7 +161,10 @@ impl Session { if !session_ids.is_empty() { let sessions = redis - .multi_get::(SESSIONS_NAMESPACE, session_ids) + .multi_get::( + SESSIONS_NAMESPACE, + session_ids.iter().map(|x| x.to_string()), + ) .await?; for session in sessions { if let Some(session) = @@ -218,8 +223,8 @@ impl Session { redis .set( SESSIONS_IDS_NAMESPACE, - session.session.clone(), - session.id.0, + &session.session, + &session.id.0.to_string(), None, ) .await?; @@ -238,8 +243,13 @@ impl Session { where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { + let mut redis = redis.connect().await?; + let res = redis - .get_deserialized_from_json::, _>(SESSIONS_USERS_NAMESPACE, user_id.0) + .get_deserialized_from_json::>( + SESSIONS_USERS_NAMESPACE, + &user_id.0.to_string(), + ) .await?; if let Some(res) = res { @@ -272,6 +282,8 @@ impl Session { clear_sessions: Vec<(Option, Option, Option)>, redis: &RedisPool, ) -> Result<(), DatabaseError> { + let mut redis = redis.connect().await?; + if clear_sessions.is_empty() { return Ok(()); } diff --git a/src/database/models/team_item.rs b/src/database/models/team_item.rs index a513aefe..a0a92f70 100644 --- a/src/database/models/team_item.rs +++ b/src/database/models/team_item.rs @@ -197,18 +197,23 @@ impl TeamMember { where E: sqlx::Executor<'a, Database = sqlx::Postgres> + Copy, { + use futures::stream::TryStreamExt; + if team_ids.is_empty() { return Ok(Vec::new()); } - use futures::stream::TryStreamExt; + let mut redis = redis.connect().await?; let mut team_ids_parsed: Vec = team_ids.iter().map(|x| x.0).collect(); let mut found_teams = Vec::new(); let teams = redis - .multi_get::(TEAMS_NAMESPACE, team_ids_parsed.clone()) + .multi_get::( + TEAMS_NAMESPACE, + team_ids_parsed.iter().map(|x| x.to_string()), + ) .await?; for team_raw in teams { @@ -271,6 +276,7 @@ impl TeamMember { } pub async fn clear_cache(id: TeamId, redis: &RedisPool) -> Result<(), super::DatabaseError> { + let mut redis = redis.connect().await?; redis.delete(TEAMS_NAMESPACE, id.0).await?; Ok(()) } diff --git a/src/database/models/user_item.rs b/src/database/models/user_item.rs index 5ab27abe..8230ff58 100644 --- a/src/database/models/user_item.rs +++ b/src/database/models/user_item.rs @@ -134,6 +134,8 @@ impl User { { use futures::TryStreamExt; + let mut redis = redis.connect().await?; + if 
users_strings.is_empty() { return Ok(Vec::new()); } @@ -151,7 +153,7 @@ impl User { user_ids.append( &mut redis - .multi_get::( + .multi_get::( USER_USERNAMES_NAMESPACE, users_strings.iter().map(|x| x.to_string().to_lowercase()), ) @@ -163,7 +165,7 @@ impl User { if !user_ids.is_empty() { let users = redis - .multi_get::(USERS_NAMESPACE, user_ids) + .multi_get::(USERS_NAMESPACE, user_ids.iter().map(|x| x.to_string())) .await?; for user in users { if let Some(user) = user.and_then(|x| serde_json::from_str::(&x).ok()) { @@ -239,8 +241,8 @@ impl User { redis .set( USER_USERNAMES_NAMESPACE, - user.username.to_lowercase(), - user.id.0, + &user.username.to_lowercase(), + &user.id.0.to_string(), None, ) .await?; @@ -278,8 +280,13 @@ impl User { { use futures::stream::TryStreamExt; + let mut redis = redis.connect().await?; + let cached_projects = redis - .get_deserialized_from_json::, _>(USERS_PROJECTS_NAMESPACE, user_id.0) + .get_deserialized_from_json::>( + USERS_PROJECTS_NAMESPACE, + &user_id.0.to_string(), + ) .await?; if let Some(projects) = cached_projects { @@ -384,6 +391,8 @@ impl User { user_ids: &[(UserId, Option)], redis: &RedisPool, ) -> Result<(), DatabaseError> { + let mut redis = redis.connect().await?; + redis .delete_many(user_ids.iter().flat_map(|(id, username)| { [ @@ -402,6 +411,8 @@ impl User { user_ids: &[UserId], redis: &RedisPool, ) -> Result<(), DatabaseError> { + let mut redis = redis.connect().await?; + redis .delete_many( user_ids diff --git a/src/database/models/version_item.rs b/src/database/models/version_item.rs index b01106e0..3e3fc3ec 100644 --- a/src/database/models/version_item.rs +++ b/src/database/models/version_item.rs @@ -492,18 +492,27 @@ impl Version { where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { + use futures::stream::TryStreamExt; + if version_ids.is_empty() { return Ok(Vec::new()); } - use futures::stream::TryStreamExt; + let mut redis = redis.connect().await?; let mut version_ids_parsed: Vec = version_ids.iter().map(|x| x.0).collect(); let mut found_versions = Vec::new(); let versions = redis - .multi_get::(VERSIONS_NAMESPACE, version_ids_parsed.clone()) + .multi_get::( + VERSIONS_NAMESPACE, + version_ids_parsed + .clone() + .iter() + .map(|x| x.to_string()) + .collect::>(), + ) .await?; for version in versions { @@ -721,18 +730,20 @@ impl Version { where E: sqlx::Executor<'a, Database = sqlx::Postgres> + Copy, { + use futures::stream::TryStreamExt; + + let mut redis = redis.connect().await?; + if hashes.is_empty() { return Ok(Vec::new()); } - use futures::stream::TryStreamExt; - let mut file_ids_parsed = hashes.to_vec(); let mut found_files = Vec::new(); let files = redis - .multi_get::( + .multi_get::( VERSION_FILES_NAMESPACE, file_ids_parsed .iter() @@ -829,6 +840,8 @@ impl Version { version: &QueryVersion, redis: &RedisPool, ) -> Result<(), DatabaseError> { + let mut redis = redis.connect().await?; + redis .delete_many( iter::once((VERSIONS_NAMESPACE, Some(version.inner.id.0.to_string()))).chain( diff --git a/src/database/redis.rs b/src/database/redis.rs index 2a517264..f121e3e9 100644 --- a/src/database/redis.rs +++ b/src/database/redis.rs @@ -1,6 +1,7 @@ use super::models::DatabaseError; use deadpool_redis::{Config, Runtime}; -use redis::{cmd, FromRedisValue, ToRedisArgs}; +use itertools::Itertools; +use redis::{cmd, Cmd}; use std::fmt::Display; const DEFAULT_EXPIRY: i64 = 1800; // 30 minutes @@ -11,6 +12,11 @@ pub struct RedisPool { meta_namespace: String, } +pub struct RedisConnection { + pub connection: deadpool_redis::Connection, 
+ meta_namespace: String, +} + impl RedisPool { // initiate a new redis pool // testing pool uses a hashmap to mimic redis behaviour for very small data sizes (ie: tests) @@ -35,32 +41,39 @@ impl RedisPool { } } - pub async fn set( - &self, + pub async fn connect(&self) -> Result { + Ok(RedisConnection { + connection: self.pool.get().await?, + meta_namespace: self.meta_namespace.clone(), + }) + } +} + +impl RedisConnection { + pub async fn set( + &mut self, namespace: &str, - id: T1, - data: T2, + id: &str, + data: &str, expiry: Option, - ) -> Result<(), DatabaseError> - where - T1: Display, - T2: ToRedisArgs, - { - let mut redis_connection = self.pool.get().await?; - - cmd("SET") - .arg(format!("{}_{}:{}", self.meta_namespace, namespace, id)) - .arg(data) - .arg("EX") - .arg(expiry.unwrap_or(DEFAULT_EXPIRY)) - .query_async::<_, ()>(&mut redis_connection) - .await?; - + ) -> Result<(), DatabaseError> { + let mut cmd = cmd("SET"); + redis_args( + &mut cmd, + vec![ + format!("{}_{}:{}", self.meta_namespace, namespace, id), + data.to_string(), + "EX".to_string(), + expiry.unwrap_or(DEFAULT_EXPIRY).to_string(), + ] + .as_slice(), + ); + redis_execute(&mut cmd, &mut self.connection).await?; Ok(()) } pub async fn set_serialized_to_json( - &self, + &mut self, namespace: &str, id: Id, data: D, @@ -70,92 +83,116 @@ impl RedisPool { Id: Display, D: serde::Serialize, { - self.set(namespace, id, serde_json::to_string(&data)?, expiry) - .await + self.set( + namespace, + &id.to_string(), + &serde_json::to_string(&data)?, + expiry, + ) + .await } - pub async fn get(&self, namespace: &str, id: Id) -> Result, DatabaseError> - where - Id: Display, - R: FromRedisValue, - { - let mut redis_connection = self.pool.get().await?; - - let res = cmd("GET") - .arg(format!("{}_{}:{}", self.meta_namespace, namespace, id)) - .query_async::<_, Option>(&mut redis_connection) - .await?; + pub async fn get( + &mut self, + namespace: &str, + id: &str, + ) -> Result, DatabaseError> { + let mut cmd = cmd("GET"); + redis_args( + &mut cmd, + vec![format!("{}_{}:{}", self.meta_namespace, namespace, id)].as_slice(), + ); + let res = redis_execute(&mut cmd, &mut self.connection).await?; Ok(res) } - pub async fn get_deserialized_from_json( - &self, + pub async fn get_deserialized_from_json( + &mut self, namespace: &str, - id: Id, + id: &str, ) -> Result, DatabaseError> where - Id: Display, R: for<'a> serde::Deserialize<'a>, { Ok(self - .get::(namespace, id) + .get(namespace, id) .await? 
.and_then(|x| serde_json::from_str(&x).ok())) } - pub async fn multi_get( - &self, + pub async fn multi_get( + &mut self, namespace: &str, - ids: impl IntoIterator, + ids: impl IntoIterator, ) -> Result>, DatabaseError> where - T1: Display, - R: FromRedisValue, + R: for<'a> serde::Deserialize<'a>, { - let mut redis_connection = self.pool.get().await?; - let res = cmd("MGET") - .arg( - ids.into_iter() - .map(|x| format!("{}_{}:{}", self.meta_namespace, namespace, x)) - .collect::>(), - ) - .query_async::<_, Vec>>(&mut redis_connection) - .await?; - Ok(res) + let mut cmd = cmd("MGET"); + + redis_args( + &mut cmd, + &ids.into_iter() + .map(|x| format!("{}_{}:{}", self.meta_namespace, namespace, x)) + .collect_vec(), + ); + let res: Vec> = redis_execute(&mut cmd, &mut self.connection).await?; + Ok(res + .into_iter() + .map(|x| x.and_then(|x| serde_json::from_str(&x).ok())) + .collect()) } - pub async fn delete(&self, namespace: &str, id: T1) -> Result<(), DatabaseError> + pub async fn delete(&mut self, namespace: &str, id: T1) -> Result<(), DatabaseError> where T1: Display, { - let mut redis_connection = self.pool.get().await?; - - cmd("DEL") - .arg(format!("{}_{}:{}", self.meta_namespace, namespace, id)) - .query_async::<_, ()>(&mut redis_connection) - .await?; - + let mut cmd = cmd("DEL"); + redis_args( + &mut cmd, + vec![format!("{}_{}:{}", self.meta_namespace, namespace, id)].as_slice(), + ); + redis_execute(&mut cmd, &mut self.connection).await?; Ok(()) } pub async fn delete_many( - &self, + &mut self, iter: impl IntoIterator)>, ) -> Result<(), DatabaseError> { let mut cmd = cmd("DEL"); let mut any = false; for (namespace, id) in iter { if let Some(id) = id { - cmd.arg(format!("{}_{}:{}", self.meta_namespace, namespace, id)); + redis_args( + &mut cmd, + [format!("{}_{}:{}", self.meta_namespace, namespace, id)].as_slice(), + ); any = true; } } if any { - let mut redis_connection = self.pool.get().await?; - cmd.query_async::<_, ()>(&mut redis_connection).await?; + redis_execute(&mut cmd, &mut self.connection).await?; } Ok(()) } } + +pub fn redis_args(cmd: &mut Cmd, args: &[String]) { + for arg in args { + cmd.arg(arg); + } +} + +pub async fn redis_execute( + cmd: &mut Cmd, + redis: &mut deadpool_redis::Connection, +) -> Result +where + T: redis::FromRedisValue, +{ + let res = cmd.query_async::<_, T>(redis).await?; + Ok(res) +} diff --git a/src/main.rs b/src/main.rs index 4b25580e..7aff4d60 100644 --- a/src/main.rs +++ b/src/main.rs @@ -17,6 +17,7 @@ pub struct Pepper { pub pepper: String, } +#[cfg(not(tarpaulin_include))] #[actix_rt::main] async fn main() -> std::io::Result<()> { dotenvy::dotenv().ok(); diff --git a/src/routes/v2/projects.rs b/src/routes/v2/projects.rs index 8755b95c..c1a75db7 100644 --- a/src/routes/v2/projects.rs +++ b/src/routes/v2/projects.rs @@ -117,14 +117,10 @@ pub async fn random_projects_get( let response = v3::projects::random_projects_get(web::Query(count), pool.clone(), redis.clone()).await?; // Convert response to V2 format - match v2_reroute::extract_ok_json::(response).await { + match v2_reroute::extract_ok_json::>(response).await { Ok(project) => { - let version_item = match project.versions.first() { - Some(vid) => version_item::Version::get((*vid).into(), &**pool, &redis).await?, - None => None, - }; - let project = LegacyProject::from(project, version_item); - Ok(HttpResponse::Ok().json(project)) + let legacy_projects = LegacyProject::from_many(project, &**pool, &redis).await?; + Ok(HttpResponse::Ok().json(legacy_projects)) } Err(response) => 
Ok(response), } diff --git a/src/routes/v2/version_creation.rs b/src/routes/v2/version_creation.rs index 3e1de740..4bfa9613 100644 --- a/src/routes/v2/version_creation.rs +++ b/src/routes/v2/version_creation.rs @@ -100,14 +100,33 @@ pub async fn version_create( fields.insert("client_side".to_string(), json!("required")); fields.insert("server_side".to_string(), json!("optional")); - // TODO: Some kind of handling here to ensure project type is fine. - // We expect the version uploaded to be of loader type modpack, but there might not be a way to check here for that. - // After all, theoretically, they could be creating a genuine 'fabric' mod, and modpack no longer carries information on whether its a mod or modpack, - // as those are out to the versions. + // Handle project type via file extension prediction + let mut project_type = None; + for file_part in &legacy_create.file_parts { + if let Some(ext) = file_part.split('.').last() { + match ext { + "mrpack" => { + project_type = Some("modpack"); + break; + } + // No other type matters + _ => {} + } + break; + } + } - // Ideally this would, if the project 'should' be a modpack: - // - change the loaders to mrpack only - // - add loader fields to the project for the corresponding loaders + // Modpacks now use the "mrpack" loader, and loaders are converted to loader fields. + // Setting of 'project_type' directly is removed, it's loader-based now. + if project_type == Some("modpack") { + fields.insert("mrpack_loaders".to_string(), json!(legacy_create.loaders)); + } + + let loaders = if project_type == Some("modpack") { + vec![Loader("mrpack".to_string())] + } else { + legacy_create.loaders + }; Ok(v3::version_creation::InitialVersionData { project_id: legacy_create.project_id, @@ -117,7 +136,7 @@ pub async fn version_create( version_body: legacy_create.version_body, dependencies: legacy_create.dependencies, release_channel: legacy_create.release_channel, - loaders: legacy_create.loaders, + loaders, featured: legacy_create.featured, primary_file: legacy_create.primary_file, status: legacy_create.status, diff --git a/src/routes/v3/analytics_get.rs b/src/routes/v3/analytics_get.rs index dc31c69c..ee12e02e 100644 --- a/src/routes/v3/analytics_get.rs +++ b/src/routes/v3/analytics_get.rs @@ -1,8 +1,10 @@ use super::ApiError; +use crate::database; use crate::database::redis::RedisPool; +use crate::models::teams::ProjectPermissions; use crate::{ - auth::{filter_authorized_projects, filter_authorized_versions, get_user_from_headers}, - database::models::{project_item, user_item, version_item}, + auth::get_user_from_headers, + database::models::user_item, models::{ ids::{ base62_impl::{parse_base62, to_base62}, @@ -351,6 +353,7 @@ pub async fn revenue_get( .try_into() .map_err(|_| ApiError::InvalidInput("Invalid resolution_minutes".to_string()))?; // Get the revenue data + let project_ids = project_ids.unwrap_or_default(); let payouts_values = sqlx::query!( " SELECT mod_id, SUM(amount) amount_sum, DATE_BIN($4::interval, created, TIMESTAMP '2001-01-01') AS interval_start @@ -358,7 +361,7 @@ pub async fn revenue_get( WHERE mod_id = ANY($1) AND created BETWEEN $2 AND $3 GROUP by mod_id, interval_start ORDER BY interval_start ", - &project_ids.unwrap_or_default().into_iter().map(|x| x.0 as i64).collect::>(), + &project_ids.iter().map(|x| x.0 as i64).collect::>(), start_date, end_date, duration, @@ -366,7 +369,10 @@ pub async fn revenue_get( .fetch_all(&**pool) .await?; - let mut hm = HashMap::new(); + let mut hm: HashMap<_, _> = project_ids + 
.into_iter() + .map(|x| (x.to_string(), HashMap::new())) + .collect::>(); for value in payouts_values { if let Some(mod_id) = value.mod_id { if let Some(amount) = value.amount_sum { @@ -559,7 +565,7 @@ async fn filter_allowed_ids( )); } - // If no project_ids or version_ids are provided, we default to all projects the user has access to + // If no project_ids or version_ids are provided, we default to all projects the user has *public* access to if project_ids.is_none() && version_ids.is_none() { project_ids = Some( user_item::User::get_projects(user.id.into(), &***pool, redis) @@ -572,35 +578,154 @@ async fn filter_allowed_ids( // Convert String list to list of ProjectIds or VersionIds // - Filter out unauthorized projects/versions + let project_ids = if let Some(project_strings) = project_ids { + let projects_data = + database::models::Project::get_many(&project_strings, &***pool, redis).await?; - let project_ids = if let Some(project_ids) = project_ids { - // Submitted project_ids are filtered by the user's permissions - let ids = project_ids + let team_ids = projects_data .iter() - .map(|id| Ok(ProjectId(parse_base62(id)?).into())) - .collect::, ApiError>>()?; - let projects = project_item::Project::get_many_ids(&ids, &***pool, redis).await?; - let ids: Vec = filter_authorized_projects(projects, &Some(user.clone()), pool) - .await? + .map(|x| x.inner.team_id) + .collect::>(); + let team_members = + database::models::TeamMember::get_from_team_full_many(&team_ids, &***pool, redis) + .await?; + + let organization_ids = projects_data + .iter() + .filter_map(|x| x.inner.organization_id) + .collect::>(); + let organizations = + database::models::Organization::get_many_ids(&organization_ids, &***pool, redis) + .await?; + + let organization_team_ids = organizations + .iter() + .map(|x| x.team_id) + .collect::>(); + let organization_team_members = database::models::TeamMember::get_from_team_full_many( + &organization_team_ids, + &***pool, + redis, + ) + .await?; + + let ids = projects_data .into_iter() - .map(|x| x.id) + .filter(|project| { + let team_member = team_members + .iter() + .find(|x| x.team_id == project.inner.team_id && x.user_id == user.id.into()); + + let organization = project + .inner + .organization_id + .and_then(|oid| organizations.iter().find(|x| x.id == oid)); + + let organization_team_member = if let Some(organization) = organization { + organization_team_members + .iter() + .find(|x| x.team_id == organization.team_id && x.user_id == user.id.into()) + } else { + None + }; + + let permissions = ProjectPermissions::get_permissions_by_role( + &user.role, + &team_member.cloned(), + &organization_team_member.cloned(), + ) + .unwrap_or_default(); + + permissions.contains(ProjectPermissions::VIEW_ANALYTICS) + }) + .map(|x| x.inner.id.into()) .collect::>(); + Some(ids) } else { None }; + let version_ids = if let Some(version_ids) = version_ids { // Submitted version_ids are filtered by the user's permissions let ids = version_ids .iter() .map(|id| Ok(VersionId(parse_base62(id)?).into())) .collect::, ApiError>>()?; - let versions = version_item::Version::get_many(&ids, &***pool, redis).await?; - let ids: Vec = filter_authorized_versions(versions, &Some(user), pool) - .await? 
+ let versions_data = database::models::Version::get_many(&ids, &***pool, redis).await?; + let project_ids = versions_data + .iter() + .map(|x| x.inner.project_id) + .collect::>(); + + let projects_data = + database::models::Project::get_many_ids(&project_ids, &***pool, redis).await?; + + let team_ids = projects_data + .iter() + .map(|x| x.inner.team_id) + .collect::>(); + let team_members = + database::models::TeamMember::get_from_team_full_many(&team_ids, &***pool, redis) + .await?; + + let organization_ids = projects_data + .iter() + .filter_map(|x| x.inner.organization_id) + .collect::>(); + let organizations = + database::models::Organization::get_many_ids(&organization_ids, &***pool, redis) + .await?; + + let organization_team_ids = organizations + .iter() + .map(|x| x.team_id) + .collect::>(); + let organization_team_members = database::models::TeamMember::get_from_team_full_many( + &organization_team_ids, + &***pool, + redis, + ) + .await?; + + let ids = projects_data + .into_iter() + .filter(|project| { + let team_member = team_members + .iter() + .find(|x| x.team_id == project.inner.team_id && x.user_id == user.id.into()); + + let organization = project + .inner + .organization_id + .and_then(|oid| organizations.iter().find(|x| x.id == oid)); + + let organization_team_member = if let Some(organization) = organization { + organization_team_members + .iter() + .find(|x| x.team_id == organization.team_id && x.user_id == user.id.into()) + } else { + None + }; + + let permissions = ProjectPermissions::get_permissions_by_role( + &user.role, + &team_member.cloned(), + &organization_team_member.cloned(), + ) + .unwrap_or_default(); + + permissions.contains(ProjectPermissions::VIEW_ANALYTICS) + }) + .map(|x| x.inner.id) + .collect::>(); + + let ids = versions_data .into_iter() - .map(|x| x.id) + .filter(|version| ids.contains(&version.inner.project_id)) + .map(|x| x.inner.id.into()) .collect::>(); + Some(ids) } else { None diff --git a/src/util/mod.rs b/src/util/mod.rs index 5729d570..03512d3e 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -7,6 +7,7 @@ pub mod env; pub mod ext; pub mod guards; pub mod img; +pub mod redis; pub mod routes; pub mod validate; pub mod webhook; diff --git a/src/util/redis.rs b/src/util/redis.rs new file mode 100644 index 00000000..b5d33219 --- /dev/null +++ b/src/util/redis.rs @@ -0,0 +1,18 @@ +use redis::Cmd; + +pub fn redis_args(cmd: &mut Cmd, args: &[String]) { + for arg in args { + cmd.arg(arg); + } +} + +pub async fn redis_execute( + cmd: &mut Cmd, + redis: &mut deadpool_redis::Connection, +) -> Result +where + T: redis::FromRedisValue, +{ + let res = cmd.query_async::<_, T>(redis).await?; + Ok(res) +} diff --git a/tests/analytics.rs b/tests/analytics.rs index bc3d80d4..c1f7806d 100644 --- a/tests/analytics.rs +++ b/tests/analytics.rs @@ -1,8 +1,11 @@ +use actix_web::test; use chrono::{DateTime, Duration, Utc}; -use common::database::*; use common::environment::TestEnvironment; +use common::permissions::PermissionsTest; +use common::{database::*, permissions::PermissionsTestContext}; use itertools::Itertools; use labrinth::models::ids::base62_impl::parse_base62; +use labrinth::models::teams::ProjectPermissions; use rust_decimal::{prelude::ToPrimitive, Decimal}; mod common; @@ -70,6 +73,7 @@ pub async fn analytics_revenue() { let analytics = api .get_analytics_revenue_deserialized( vec![&alpha_project_id], + false, None, None, None, @@ -99,6 +103,7 @@ pub async fn analytics_revenue() { let analytics = api .get_analytics_revenue_deserialized( 
vec![&alpha_project_id], + false, Some(Utc::now() - Duration::days(801)), None, None, @@ -133,3 +138,92 @@ fn to_f64_rounded_up(d: Decimal) -> f64 { fn to_f64_vec_rounded_up(d: Vec) -> Vec { d.into_iter().map(to_f64_rounded_up).collect_vec() } + +#[actix_rt::test] +pub async fn permissions_analytics_revenue() { + let test_env = TestEnvironment::build(None).await; + + let alpha_project_id = test_env + .dummy + .as_ref() + .unwrap() + .project_alpha + .project_id + .clone(); + let alpha_version_id = test_env + .dummy + .as_ref() + .unwrap() + .project_alpha + .version_id + .clone(); + let alpha_team_id = test_env + .dummy + .as_ref() + .unwrap() + .project_alpha + .team_id + .clone(); + + let view_analytics = ProjectPermissions::VIEW_ANALYTICS; + + // first, do check with a project + let req_gen = |ctx: &PermissionsTestContext| { + let projects_string = serde_json::to_string(&vec![ctx.project_id]).unwrap(); + let projects_string = urlencoding::encode(&projects_string); + test::TestRequest::get().uri(&format!( + "/v3/analytics/revenue?project_ids={projects_string}&resolution_minutes=5", + )) + }; + + PermissionsTest::new(&test_env) + .with_failure_codes(vec![200, 401]) + .with_200_json_checks( + // On failure, should have 0 projects returned + |value: &serde_json::Value| { + let value = value.as_object().unwrap(); + assert_eq!(value.len(), 0); + }, + // On success, should have 1 project returned + |value: &serde_json::Value| { + let value = value.as_object().unwrap(); + assert_eq!(value.len(), 1); + }, + ) + .simple_project_permissions_test(view_analytics, req_gen) + .await + .unwrap(); + + // Now with a version + // Need to use alpha + let req_gen = |_: &PermissionsTestContext| { + let versions_string = serde_json::to_string(&vec![alpha_version_id.clone()]).unwrap(); + let versions_string = urlencoding::encode(&versions_string); + test::TestRequest::get().uri(&format!( + "/v3/analytics/revenue?version_ids={versions_string}&resolution_minutes=5", + )) + }; + + PermissionsTest::new(&test_env) + .with_failure_codes(vec![200, 401]) + .with_existing_project(&alpha_project_id, &alpha_team_id) + .with_user(FRIEND_USER_ID, FRIEND_USER_PAT, true) + .with_200_json_checks( + // On failure, should have 0 versions returned + |value: &serde_json::Value| { + let value = value.as_object().unwrap(); + assert_eq!(value.len(), 0); + }, + // On success, should have 1 versions returned + |value: &serde_json::Value| { + let value = value.as_object().unwrap(); + assert_eq!(value.len(), 1); + }, + ) + .simple_project_permissions_test(view_analytics, req_gen) + .await + .unwrap(); + + // Cleanup test db + test_env.cleanup().await; +} diff --git a/tests/common/api_v3/project.rs b/tests/common/api_v3/project.rs index b4365d9c..2ffc1d7a 100644 --- a/tests/common/api_v3/project.rs +++ b/tests/common/api_v3/project.rs @@ -204,13 +204,21 @@ impl ApiV3 { pub async fn get_analytics_revenue( &self, id_or_slugs: Vec<&str>, + ids_are_version_ids: bool, start_date: Option>, end_date: Option>, resolution_minutes: Option, pat: &str, ) -> ServiceResponse { - let projects_string = serde_json::to_string(&id_or_slugs).unwrap(); - let projects_string = urlencoding::encode(&projects_string); + let pv_string = if ids_are_version_ids { + let version_string: String = serde_json::to_string(&id_or_slugs).unwrap(); + let version_string = urlencoding::encode(&version_string); + format!("version_ids={}", version_string) + } else { + let projects_string: String = serde_json::to_string(&id_or_slugs).unwrap(); + let projects_string = 
urlencoding::encode(&projects_string); + format!("project_ids={}", projects_string) + }; let mut extra_args = String::new(); if let Some(start_date) = start_date { @@ -230,9 +238,7 @@ impl ApiV3 { } let req = test::TestRequest::get() - .uri(&format!( - "/v3/analytics/revenue?{projects_string}{extra_args}", - )) + .uri(&format!("/v3/analytics/revenue?{pv_string}{extra_args}",)) .append_header(("Authorization", pat)) .to_request(); @@ -242,13 +248,21 @@ impl ApiV3 { pub async fn get_analytics_revenue_deserialized( &self, id_or_slugs: Vec<&str>, + ids_are_version_ids: bool, start_date: Option>, end_date: Option>, resolution_minutes: Option, pat: &str, ) -> HashMap> { let resp = self - .get_analytics_revenue(id_or_slugs, start_date, end_date, resolution_minutes, pat) + .get_analytics_revenue( + id_or_slugs, + ids_are_version_ids, + start_date, + end_date, + resolution_minutes, + pat, + ) .await; assert_eq!(resp.status(), 200); test::read_body_json(resp).await diff --git a/tests/common/permissions.rs b/tests/common/permissions.rs index 1bb2e20a..c960b72f 100644 --- a/tests/common/permissions.rs +++ b/tests/common/permissions.rs @@ -1,4 +1,5 @@ #![allow(dead_code)] +use actix_http::StatusCode; use actix_web::test::{self, TestRequest}; use itertools::Itertools; use labrinth::models::teams::{OrganizationPermissions, ProjectPermissions}; @@ -45,6 +46,12 @@ pub struct PermissionsTest<'a> { // The codes that is allow to be returned if the scope is not present. // (for instance, we might expect a 401, but not a 400) allowed_failure_codes: Vec, + + // Closures that check the JSON body of the response for failure and success cases. + // These are used to perform more complex tests than just checking the status code. + // (eg: checking that the response contains the correct data) + failure_json_check: Option>, + success_json_check: Option>, } pub struct PermissionsTestContext<'a> { @@ -71,6 +78,8 @@ impl<'a> PermissionsTest<'a> { project_team_id: None, organization_team_id: None, allowed_failure_codes: vec![401, 404], + failure_json_check: None, + success_json_check: None, } } @@ -87,6 +96,20 @@ impl<'a> PermissionsTest<'a> { self } + // Set check closures for the JSON body of the response + // These are used to perform more complex tests than just checking the status code. + // If not set, no checks will be performed (and the status code is the only check). + // This is useful if, say, both expected status codes are 200. 
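+    // Usage sketch (mirroring the checks in tests/analytics.rs; the closure
+    // bodies here are illustrative): the endpoint returns 200 either way, so
+    // the two closures distinguish the cases by key count:
+    //
+    //   PermissionsTest::new(&test_env)
+    //       .with_failure_codes(vec![200, 401])
+    //       .with_200_json_checks(
+    //           |value| assert_eq!(value.as_object().unwrap().len(), 0), // no permission
+    //           |value| assert_eq!(value.as_object().unwrap().len(), 1), // has permission
+    //       )
+    //       .simple_project_permissions_test(view_analytics, req_gen);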
+ pub fn with_200_json_checks( + mut self, + failure_json_check: impl Fn(&serde_json::Value) + Send + 'static, + success_json_check: impl Fn(&serde_json::Value) + Send + 'static, + ) -> Self { + self.failure_json_check = Some(Box::new(failure_json_check)); + self.success_json_check = Some(Box::new(success_json_check)); + self + } + // Set the user ID to use // (eg: a moderator, or friend) // remove_user: Whether or not the user ID should be removed from the project/organization team after the test @@ -181,6 +204,11 @@ impl<'a> PermissionsTest<'a> { resp.status().as_u16() )); } + if resp.status() == StatusCode::OK { + if let Some(failure_json_check) = &self.failure_json_check { + failure_json_check(&test::read_body_json(resp).await); + } + } // Failure test- logged in on a non-team user let request = req_gen(&PermissionsTestContext { @@ -202,6 +230,11 @@ impl<'a> PermissionsTest<'a> { resp.status().as_u16() )); } + if resp.status() == StatusCode::OK { + if let Some(failure_json_check) = &self.failure_json_check { + failure_json_check(&test::read_body_json(resp).await); + } + } // Failure test- logged in with EVERY non-relevant permission let request = req_gen(&PermissionsTestContext { @@ -223,6 +256,11 @@ impl<'a> PermissionsTest<'a> { resp.status().as_u16() )); } + if resp.status() == StatusCode::OK { + if let Some(failure_json_check) = &self.failure_json_check { + failure_json_check(&test::read_body_json(resp).await); + } + } // Patch user's permissions to success permissions modify_user_team_permissions( @@ -250,6 +288,11 @@ impl<'a> PermissionsTest<'a> { resp.status().as_u16() )); } + if resp.status() == StatusCode::OK { + if let Some(success_json_check) = &self.success_json_check { + success_json_check(&test::read_body_json(resp).await); + } + } // If the remove_user flag is set, remove the user from the project // Relevant for existing projects/users diff --git a/tests/project.rs b/tests/project.rs index 138acac5..40c9cd30 100644 --- a/tests/project.rs +++ b/tests/project.rs @@ -40,20 +40,21 @@ async fn test_get_project() { assert_eq!(versions[0], json!(alpha_version_id)); // Confirm that the request was cached + let mut redis_pool = test_env.db.redis_pool.connect().await.unwrap(); assert_eq!( - test_env - .db - .redis_pool - .get::(PROJECTS_SLUGS_NAMESPACE, alpha_project_slug) + redis_pool + .get(PROJECTS_SLUGS_NAMESPACE, alpha_project_slug) .await - .unwrap(), + .unwrap() + .and_then(|x| x.parse::().ok()), Some(parse_base62(alpha_project_id).unwrap() as i64) ); - let cached_project = test_env - .db - .redis_pool - .get::(PROJECTS_NAMESPACE, parse_base62(alpha_project_id).unwrap()) + let cached_project = redis_pool + .get( + PROJECTS_NAMESPACE, + &parse_base62(alpha_project_id).unwrap().to_string(), + ) .await .unwrap() .unwrap(); @@ -249,22 +250,21 @@ async fn test_add_remove_project() { assert_eq!(resp.status(), 204); // Confirm that the project is gone from the cache + let mut redis_pool = test_env.db.redis_pool.connect().await.unwrap(); assert_eq!( - test_env - .db - .redis_pool - .get::(PROJECTS_SLUGS_NAMESPACE, "demo") + redis_pool + .get(PROJECTS_SLUGS_NAMESPACE, "demo") .await - .unwrap(), + .unwrap() + .and_then(|x| x.parse::().ok()), None ); assert_eq!( - test_env - .db - .redis_pool - .get::(PROJECTS_SLUGS_NAMESPACE, id) + redis_pool + .get(PROJECTS_SLUGS_NAMESPACE, &id) .await - .unwrap(), + .unwrap() + .and_then(|x| x.parse::().ok()), None ); diff --git a/tests/search.rs b/tests/search.rs index 36483547..120aedd6 100644 --- a/tests/search.rs +++ b/tests/search.rs @@ 
-20,7 +20,7 @@ mod common; #[actix_rt::test] async fn search_projects() { // Test setup and dummy data - let test_env = TestEnvironment::build(Some(8)).await; + let test_env = TestEnvironment::build(Some(10)).await; let api = &test_env.v3; let test_name = test_env.db.database_name.clone(); diff --git a/tests/v2/project.rs b/tests/v2/project.rs index 609b8481..7e56d3a6 100644 --- a/tests/v2/project.rs +++ b/tests/v2/project.rs @@ -221,22 +221,21 @@ async fn test_add_remove_project() { assert_eq!(resp.status(), 204); // Confirm that the project is gone from the cache + let mut redis_conn = test_env.db.redis_pool.connect().await.unwrap(); assert_eq!( - test_env - .db - .redis_pool - .get::(PROJECTS_SLUGS_NAMESPACE, "demo") + redis_conn + .get(PROJECTS_SLUGS_NAMESPACE, "demo") .await - .unwrap(), + .unwrap() + .map(|x| x.parse::().unwrap()), None ); assert_eq!( - test_env - .db - .redis_pool - .get::(PROJECTS_SLUGS_NAMESPACE, id) + redis_conn + .get(PROJECTS_SLUGS_NAMESPACE, &id) .await - .unwrap(), + .unwrap() + .map(|x| x.parse::().unwrap()), None ); diff --git a/tests/v2/search.rs b/tests/v2/search.rs index fbe39ca6..1e3ccbdf 100644 --- a/tests/v2/search.rs +++ b/tests/v2/search.rs @@ -17,7 +17,7 @@ async fn search_projects() { // It should drastically simplify this function // Test setup and dummy data - let test_env = TestEnvironment::build(Some(8)).await; + let test_env = TestEnvironment::build(Some(10)).await; let api = &test_env.v2; let test_name = test_env.db.database_name.clone(); diff --git a/tests/version.rs b/tests/version.rs index 665cbf58..57c5c710 100644 --- a/tests/version.rs +++ b/tests/version.rs @@ -33,10 +33,12 @@ async fn test_get_version() { assert_eq!(&version.project_id.to_string(), alpha_project_id); assert_eq!(&version.id.to_string(), alpha_version_id); - let cached_project = test_env - .db - .redis_pool - .get::(VERSIONS_NAMESPACE, parse_base62(alpha_version_id).unwrap()) + let mut redis_conn = test_env.db.redis_pool.connect().await.unwrap(); + let cached_project = redis_conn + .get( + VERSIONS_NAMESPACE, + &parse_base62(alpha_version_id).unwrap().to_string(), + ) .await .unwrap() .unwrap();
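Taken together, the model-layer changes in this diff follow one pattern: `RedisPool` no longer executes commands itself; each logical operation first checks out a `RedisConnection` via `connect()` and passes ids pre-rendered as strings. A minimal sketch of the new calling convention, with the namespace string and id type standing in for the real constants and id newtypes:

```rust
use labrinth::database::models::DatabaseError;
use labrinth::database::redis::RedisPool;

// Illustrative helper only; "collections" and u64 are placeholders for the
// namespace constants and id wrappers used in src/database/models.
async fn fetch_cached(redis: &RedisPool, id: u64) -> Result<Option<String>, DatabaseError> {
    // Check out one pooled connection for the whole logical operation...
    let mut conn = redis.connect().await?;
    // ...then pass the id as a string, per the new &str-based API.
    conn.get("collections", &id.to_string()).await
}
```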
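`multi_get` likewise drops its `FromRedisValue`/`ToRedisArgs` generics: keys arrive as strings and values come back serde-deserialized, with misses and undecodable entries collapsing to `None`. A sketch under the assumption that the cached values were written with `set_serialized_to_json`; `CachedProject` is a hypothetical stand-in for the real model structs:

```rust
use labrinth::database::models::DatabaseError;
use labrinth::database::redis::RedisConnection;
use serde::Deserialize;

#[derive(Deserialize)]
struct CachedProject {
    id: u64, // illustrative shape only
}

// Batch-fetch cached JSON blobs; the element type drives deserialization.
async fn fetch_many(
    conn: &mut RedisConnection,
    ids: &[u64],
) -> Result<Vec<Option<CachedProject>>, DatabaseError> {
    conn.multi_get("projects", ids.iter().map(|x| x.to_string()))
        .await
}
```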
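The command plumbing itself is now two small helpers, duplicated in src/database/redis.rs and src/util/redis.rs: `redis_args` pushes pre-rendered string arguments onto a `Cmd`, and `redis_execute` runs it. A sketch of composing them directly, mirroring `RedisConnection::set` above; the error type is an assumption (the return annotation in the util module), and the key layout follows the `{meta}_{namespace}:{id}` scheme used throughout:

```rust
use labrinth::util::redis::{redis_args, redis_execute};
use redis::cmd;

// Mirrors RedisConnection::set: SET <key> <value> EX <seconds>.
async fn set_with_expiry(
    conn: &mut deadpool_redis::Connection,
    key: String,
    value: String,
    expiry_secs: i64,
) -> Result<(), redis::RedisError> { // error type assumed for this sketch
    let mut set_cmd = cmd("SET");
    redis_args(
        &mut set_cmd,
        &[key, value, "EX".to_string(), expiry_secs.to_string()],
    );
    redis_execute::<()>(&mut set_cmd, conn).await
}
```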
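On the v2 compatibility side, version_creation.rs now predicts the project type purely from the uploaded file extension: an `.mrpack` file marks the version as a modpack, its declared loaders move into the `mrpack_loaders` field, and the loader list is replaced with `mrpack`. A condensed restatement of that branch, as a hypothetical helper:

```rust
// Hypothetical helper: returns Some("modpack") when the first uploaded file
// part carries the .mrpack extension, mirroring the loop in the diff (which
// breaks after inspecting the first file part).
fn predicted_project_type(file_parts: &[String]) -> Option<&'static str> {
    file_parts
        .first()
        .and_then(|part| part.rsplit('.').next())
        .filter(|ext| *ext == "mrpack")
        .map(|_| "modpack")
}
```

For example, `predicted_project_type(&["my-pack.mrpack".to_string()])` yields `Some("modpack")`, while a `.jar` upload yields `None` and keeps the submitted loaders.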
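In analytics_get.rs, `filter_allowed_ids` stops deferring to `filter_authorized_projects`/`filter_authorized_versions` and resolves team and organization membership explicitly, so it can require `VIEW_ANALYTICS` specifically rather than mere visibility. Both the project branch and the version branch end in the same predicate; a hypothetical extraction of it (helper name and exact parameter types are assumptions inferred from the call sites above, not part of the PR):

```rust
use labrinth::database::models::TeamMember;
use labrinth::models::teams::ProjectPermissions;
use labrinth::models::users::Role;

// Hypothetical helper capturing the duplicated check: a project's ids are
// kept only when the requesting user can view its analytics.
fn can_view_analytics(
    role: &Role,
    team_member: &Option<TeamMember>,
    organization_team_member: &Option<TeamMember>,
) -> bool {
    ProjectPermissions::get_permissions_by_role(role, team_member, organization_team_member)
        .unwrap_or_default()
        .contains(ProjectPermissions::VIEW_ANALYTICS)
}
```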
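`revenue_get` also changes shape slightly: the result map is pre-seeded with every requested project id, so projects with no payout rows still appear as empty objects instead of being silently dropped. The permissions test added in tests/analytics.rs relies on this to count keys on both the success and failure paths. The seeding as a standalone illustration, with `f64` standing in for the `Decimal` amounts and `u64` for the project id type:

```rust
use std::collections::HashMap;

fn seed_revenue_map(project_ids: Vec<u64>) -> HashMap<String, HashMap<String, f64>> {
    // Every requested id gets an (initially empty) bucket up front.
    project_ids
        .into_iter()
        .map(|id| (id.to_string(), HashMap::new()))
        .collect()
}
```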
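Finally, the tests adapt to the string-only cache API: `get` now returns `Option<String>`, so the cached slug-to-id entry is parsed back before comparison. A sketch of the updated assertion from test_get_project (the parse target is assumed to be `i64`, matching the `as i64` on the expected value; `test_env` and the constants come from the surrounding test):

```rust
let mut redis_conn = test_env.db.redis_pool.connect().await.unwrap();
let cached_id = redis_conn
    .get(PROJECTS_SLUGS_NAMESPACE, alpha_project_slug)
    .await
    .unwrap()
    .and_then(|x| x.parse::<i64>().ok()); // cached value is a stringified id
assert_eq!(cached_id, Some(parse_base62(alpha_project_id).unwrap() as i64));
```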