diff --git a/src/database/models/product_item.rs b/src/database/models/product_item.rs index 2ff0b0c7..3ffc9f15 100644 --- a/src/database/models/product_item.rs +++ b/src/database/models/product_item.rs @@ -1,10 +1,14 @@ -use crate::database::models::{DatabaseError, ProductId, ProductPriceId}; +use crate::database::models::{product_item, DatabaseError, ProductId, ProductPriceId}; +use crate::database::redis::RedisPool; use crate::models::billing::{PriceInterval, ProductMetadata}; use dashmap::DashMap; use itertools::Itertools; +use serde::{Deserialize, Serialize}; use std::convert::TryFrom; use std::convert::TryInto; +const PRODUCTS_NAMESPACE: &str = "products"; + pub struct ProductItem { pub id: ProductId, pub metadata: ProductMetadata, @@ -82,6 +86,67 @@ impl ProductItem { } } +#[derive(Deserialize, Serialize)] +pub struct QueryProduct { + pub id: ProductId, + pub metadata: ProductMetadata, + pub unitary: bool, + pub prices: Vec, +} + +impl QueryProduct { + pub async fn list<'a, E>(exec: E, redis: &RedisPool) -> Result, DatabaseError> + where + E: sqlx::Executor<'a, Database = sqlx::Postgres> + Copy, + { + let mut redis = redis.connect().await?; + + let res: Option> = redis + .get_deserialized_from_json(PRODUCTS_NAMESPACE, "all") + .await?; + + if let Some(res) = res { + return Ok(res); + } + + let all_products = product_item::ProductItem::get_all(exec).await?; + let prices = product_item::ProductPriceItem::get_all_products_prices( + &all_products.iter().map(|x| x.id).collect::>(), + exec, + ) + .await?; + + let products = all_products + .into_iter() + .map(|x| QueryProduct { + id: x.id.into(), + metadata: x.metadata, + prices: prices + .remove(&x.id) + .map(|x| x.1) + .unwrap_or_default() + .into_iter() + .map(|x| ProductPriceItem { + id: x.id, + product_id: x.product_id, + interval: x.interval, + price: x.price, + currency_code: x.currency_code, + }) + .collect(), + unitary: x.unitary, + }) + .collect::>(); + + redis + .set_serialized_to_json(PRODUCTS_NAMESPACE, "all", &products, None) + .await?; + + Ok(products) + } +} + +#[derive(Deserialize, Serialize)] pub struct ProductPriceItem { pub id: ProductPriceId, pub product_id: ProductId, diff --git a/src/database/models/user_subscription_item.rs b/src/database/models/user_subscription_item.rs index 1c5844c0..f5bf6090 100644 --- a/src/database/models/user_subscription_item.rs +++ b/src/database/models/user_subscription_item.rs @@ -86,6 +86,17 @@ impl UserSubscriptionItem { Ok(results.into_iter().map(|r| r.into()).collect()) } + pub async fn get_all_expired( + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { + let now = Utc::now(); + let results = select_user_subscriptions_with_predicate!("WHERE expires < $1", now) + .fetch_all(exec) + .await?; + + Ok(results.into_iter().map(|r| r.into()).collect()) + } + pub async fn upsert( &self, transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, diff --git a/src/lib.rs b/src/lib.rs index ca6151c2..6de882ad 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -76,14 +76,16 @@ pub fn app_setup( let automated_moderation_queue = web::Data::new(AutomatedModerationQueue::default()); - let automated_moderation_queue_ref = automated_moderation_queue.clone(); - let pool_ref = pool.clone(); - let redis_pool_ref = redis_pool.clone(); - actix_rt::spawn(async move { - automated_moderation_queue_ref - .task(pool_ref, redis_pool_ref) - .await; - }); + { + let automated_moderation_queue_ref = automated_moderation_queue.clone(); + let pool_ref = pool.clone(); + let redis_pool_ref = redis_pool.clone(); + actix_rt::spawn(async move { + automated_moderation_queue_ref + .task(pool_ref, redis_pool_ref) + .await; + }); + } let mut scheduler = scheduler::Scheduler::new(); @@ -258,21 +260,14 @@ pub fn app_setup( }); } + let stripe_client = stripe::Client::new(dotenvy::var("STRIPE_API_KEY").unwrap()); { let pool_ref = pool.clone(); let redis_ref = redis_pool.clone(); - scheduler.run(std::time::Duration::from_secs(60 * 30), move || { - let pool_ref = pool_ref.clone(); - let redis_ref = redis_ref.clone(); + let stripe_client_ref = stripe_client.clone(); - async move { - info!("Indexing billing queue"); - let result = crate::routes::internal::billing::task(&pool_ref, &redis_ref).await; - if let Err(e) = result { - warn!("Indexing billing queue failed: {:?}", e); - } - info!("Done indexing billing queue"); - } + actix_rt::spawn(async move { + routes::internal::billing::task(stripe_client_ref, pool_ref, redis_ref).await; }); } @@ -283,8 +278,6 @@ pub fn app_setup( let payouts_queue = web::Data::new(PayoutsQueue::new()); let active_sockets = web::Data::new(RwLock::new(ActiveSockets::default())); - let stripe_client = stripe::Client::new(dotenvy::var("STRIPE_API_KEY").unwrap()); - LabrinthConfig { pool, redis_pool, diff --git a/src/routes/internal/billing.rs b/src/routes/internal/billing.rs index 02aa3160..c98589ce 100644 --- a/src/routes/internal/billing.rs +++ b/src/routes/internal/billing.rs @@ -14,9 +14,10 @@ use crate::queue::session::AuthQueue; use crate::routes::ApiError; use actix_web::{delete, get, patch, post, web, HttpRequest, HttpResponse}; use chrono::{Duration, Utc}; +use log::{info, warn}; use serde_with::serde_derive::Deserialize; use sqlx::PgPool; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::str::FromStr; use stripe::{ CreateCustomer, CreatePaymentIntent, CreatePaymentIntentAutomaticPaymentMethods, @@ -45,25 +46,20 @@ pub fn config(cfg: &mut web::ServiceConfig) { ); } -// TODO: cache this #[get("products")] -pub async fn products(pool: web::Data) -> Result { - let products = product_item::ProductItem::get_all(&**pool).await?; - let prices = product_item::ProductPriceItem::get_all_products_prices( - &products.iter().map(|x| x.id).collect::>(), - &**pool, - ) - .await?; +pub async fn products( + pool: web::Data, + redis: web::Data, +) -> Result { + let products = product_item::QueryProduct::list(&**pool, &redis).await?; let products = products .into_iter() .map(|x| Product { id: x.id.into(), metadata: x.metadata, - prices: prices - .remove(&x.id) - .map(|x| x.1) - .unwrap_or_default() + prices: x + .prices .into_iter() .map(|x| ProductPrice { id: x.id.into(), @@ -173,7 +169,15 @@ pub async fn user_customer( .await? .1; - let customer_id = get_or_create_customer(&user, &*stripe_client, &pool, &redis).await?; + let customer_id = get_or_create_customer( + user.id, + user.stripe_customer_id.as_deref(), + user.email.as_deref(), + &*stripe_client, + &pool, + &redis, + ) + .await?; let customer = stripe::Customer::retrieve(&stripe_client, &customer_id, &[]).await?; Ok(HttpResponse::Ok().json(customer)) @@ -236,7 +240,15 @@ pub async fn add_payment_method_flow( .await? .1; - let customer = get_or_create_customer(&user, &*stripe_client, &pool, &redis).await?; + let customer = get_or_create_customer( + user.id, + user.stripe_customer_id.as_deref(), + user.email.as_deref(), + &*stripe_client, + &pool, + &redis, + ) + .await?; let intent = SetupIntent::create( &stripe_client, @@ -290,7 +302,15 @@ pub async fn edit_payment_method( return Err(ApiError::NotFound); }; - let customer = get_or_create_customer(&user, &*stripe_client, &pool, &redis).await?; + let customer = get_or_create_customer( + user.id, + user.stripe_customer_id.as_deref(), + user.email.as_deref(), + &*stripe_client, + &pool, + &redis, + ) + .await?; let payment_method = stripe::PaymentMethod::retrieve(&stripe_client, &payment_method_id, &[]).await?; @@ -347,7 +367,15 @@ pub async fn remove_payment_method( return Err(ApiError::NotFound); }; - let customer = get_or_create_customer(&user, &*stripe_client, &pool, &redis).await?; + let customer = get_or_create_customer( + user.id, + user.stripe_customer_id.as_deref(), + user.email.as_deref(), + &*stripe_client, + &pool, + &redis, + ) + .await?; let payment_method = stripe::PaymentMethod::retrieve(&stripe_client, &payment_method_id, &[]).await?; @@ -467,7 +495,15 @@ pub async fn initiate_payment( ApiError::InvalidInput("Specified product could not be found!".to_string()) })?; - let customer_id = get_or_create_customer(&user, &*stripe_client, &pool, &redis).await?; + let customer_id = get_or_create_customer( + user.id, + user.stripe_customer_id.as_deref(), + user.email.as_deref(), + &*stripe_client, + &pool, + &redis, + ) + .await?; let mut intent = CreatePaymentIntent::new( price.price as i64, @@ -698,6 +734,12 @@ pub async fn stripe_webhook( } } + crate::database::models::user_item::User::clear_caches( + &[(metadata.user.id.into(), None)], + &redis, + ) + .await?; + transaction.commit().await?; } } @@ -825,25 +867,24 @@ pub async fn stripe_webhook( } async fn get_or_create_customer( - user: &crate::models::users::User, + user_id: crate::models::ids::UserId, + stripe_customer_id: Option<&str>, + user_email: Option<&str>, client: &stripe::Client, pool: &PgPool, redis: &RedisPool, ) -> Result { - if let Some(customer_id) = user - .stripe_customer_id - .as_ref() - .and_then(|x| stripe::CustomerId::from_str(x).ok()) + if let Some(customer_id) = stripe_customer_id.and_then(|x| stripe::CustomerId::from_str(x).ok()) { Ok(customer_id) } else { let mut metadata = HashMap::new(); - metadata.insert("modrinth_user_id".to_string(), to_base62(user.id.0)); + metadata.insert("modrinth_user_id".to_string(), to_base62(user_id.0)); let customer = stripe::Customer::create( client, CreateCustomer { - email: user.email.as_deref(), + email: user_email, metadata: Some(metadata), ..Default::default() }, @@ -857,47 +898,190 @@ async fn get_or_create_customer( WHERE id = $2 ", customer.id.as_str(), - user.id.0 as i64 + user_id.0 as i64 ) .execute(&*pool) .await?; - crate::database::models::user_item::User::clear_caches(&[(user.id.into(), None)], &redis) + crate::database::models::user_item::User::clear_caches(&[(user_id.into(), None)], &redis) .await?; Ok(customer.id) } } -pub async fn task(pool: &PgPool, redis: &RedisPool) -> Result<(), ApiError> { - Ok(()) - - // TODO: scheduler for charging recurring payments - +pub async fn task(stripe_client: stripe::Client, pool: PgPool, redis: RedisPool) { // if subscription is cancelled and expired, unprovision - // if subscription is payment failed and expired, unprovision - // if subscription is payment failed and last attempt is > 4 days ago, try again to charge + // if subscription is payment failed and last attempt is > 2 days ago, try again to charge and unprovision // if subscription is active and expired, attempt to charge and set as processing + loop { + info!("Indexing billing queue"); + let res = async { + let expired = + user_subscription_item::UserSubscriptionItem::get_all_expired(&pool).await?; + let subscription_prices = product_item::ProductPriceItem::get_many( + &expired + .iter() + .map(|x| x.price_id) + .collect::>() + .into_iter() + .collect::>(), + &pool, + ) + .await?; + let subscription_products = product_item::ProductItem::get_many( + &subscription_prices + .iter() + .map(|x| x.product_id) + .collect::>() + .into_iter() + .collect::>(), + &pool, + ) + .await?; + let users = crate::database::models::User::get_many_ids( + &expired + .iter() + .map(|x| x.user_id) + .collect::>() + .into_iter() + .collect::>(), + &pool, + &redis, + ) + .await?; + + let mut transaction = pool.begin().await?; + let mut clear_cache_users = Vec::new(); + + for mut subscription in expired { + let user = users.iter().find(|x| x.id == subscription.user_id); + + if let Some(user) = user { + let product_price = subscription_prices + .iter() + .find(|x| x.id == subscription.price_id); + + if let Some(product_price) = product_price { + let product = subscription_products + .iter() + .find(|x| x.id == product_price.product_id); + + if let Some(product) = product { + let cancelled = subscription.status == SubscriptionStatus::Cancelled; + let payment_failed = subscription + .last_charge + .map(|y| { + (subscription.status == SubscriptionStatus::PaymentFailed + && Utc::now() - y > Duration::days(2)) + }) + .unwrap_or(false); + let active = subscription.status == SubscriptionStatus::Active; + + // Unprovision subscription + if cancelled || payment_failed { + match product.metadata { + ProductMetadata::Midas => { + let badges = user.badges - Badges::MIDAS; + + sqlx::query!( + " + UPDATE users + SET badges = $1 + WHERE (id = $2) + ", + badges.bits() as i64, + user.id as crate::database::models::ids::UserId, + ) + .execute(&mut *transaction) + .await?; + } + } + + clear_cache_users.push(user.id); + } + + if payment_failed || active { + let customer_id = get_or_create_customer( + user.id.into(), + user.stripe_customer_id.as_deref(), + user.email.as_deref(), + &stripe_client, + &pool, + &redis, + ) + .await?; + + let customer = + stripe::Customer::retrieve(&stripe_client, &customer_id, &[]) + .await?; + + let mut intent = CreatePaymentIntent::new( + product_price.price as i64, + Currency::from_str(&product_price.currency_code) + .unwrap_or(Currency::USD), + ); + + let mut metadata = HashMap::new(); + metadata.insert( + "modrinth_user_id".to_string(), + to_base62(user.id.0 as u64), + ); + metadata.insert( + "modrinth_price_id".to_string(), + to_base62(product_price.id.0 as u64), + ); + metadata.insert( + "modrinth_subscription_id".to_string(), + to_base62(subscription.id.0 as u64), + ); + + intent.metadata = Some(metadata); + intent.customer = Some(customer_id); + + if let Some(payment_method) = customer + .invoice_settings + .and_then(|x| x.default_payment_method.map(|x| x.id())) + { + intent.payment_method = Some(payment_method); + intent.confirm = Some(false); + intent.off_session = + Some(PaymentIntentOffSession::Exists(true)); + + stripe::PaymentIntent::create(&stripe_client, intent).await?; + + subscription.status = SubscriptionStatus::PaymentProcessing; + } else { + subscription.status = SubscriptionStatus::PaymentFailed; + } - // get all users - // get all user customers - - // Un provision subscription - // match metadata.product.metadata { - // ProductMetadata::Midas => { - // let badges = metadata.user.badges - Badges::MIDAS; - // - // sqlx::query!( - // " - // UPDATE users - // SET badges = $1 - // WHERE (id = $2) - // ", - // badges.bits() as i64, - // metadata.user.id as crate::database::models::ids::UserId, - // ) - // .execute(&mut *transaction) - // .await?; - // } - // } + subscription.upsert(&mut transaction).await?; + } + } + } + } + } + + crate::database::models::User::clear_caches( + &clear_cache_users + .into_iter() + .map(|x| (x, None)) + .collect::>(), + &redis, + ) + .await?; + transaction.commit().await?; + + Ok::<(), ApiError>(()) + } + .await; + + if let Err(e) = res { + warn!("Error indexing billing queue: {:?}", e); + } + + info!("Done indexing billing queue"); + + tokio::time::sleep(std::time::Duration::from_secs(60 * 5)).await; + } }