From 9d5912afe8d5aacdcdcacb3d6b25d3233b406dee Mon Sep 17 00:00:00 2001 From: Daniel Berglund Date: Fri, 10 Nov 2023 17:29:02 +0100 Subject: [PATCH] Add functions to check if makeradmin and stripe things are equal. Add functions to update stripe objects --- api/src/shop/stripe_util.py | 116 ++++++++++++++++++++------ api/src/shop/test/stripe_util_test.py | 2 + 2 files changed, 92 insertions(+), 26 deletions(-) diff --git a/api/src/shop/stripe_util.py b/api/src/shop/stripe_util.py index d0459f925..d016b9a74 100644 --- a/api/src/shop/stripe_util.py +++ b/api/src/shop/stripe_util.py @@ -3,7 +3,7 @@ import random import time from logging import getLogger -from typing import Callable, Optional, TypeVar +from typing import Any, Callable, Dict, TypeVar from service.error import InternalServerError from shop.models import Product @@ -24,11 +24,24 @@ } +def makeradmin_to_stripe_recurring(makeradmin_product: Product, price_type: PriceType) -> Dict[str, Any]: + if price_type == PriceType.RECURRING or price_type == PriceType.BINDING_PERIOD: + if makeradmin_product.unit in makeradmin_unit_to_stripe_unit: + interval = makeradmin_unit_to_stripe_unit[makeradmin_product.unit] + else: + raise RuntimeError( + f"Unexpected unit {makeradmin_product.unit} in makeradmin product {makeradmin_product.id}" + ) + interval_count = makeradmin_product.smallest_multiple if price_type == PriceType.BINDING_PERIOD else 1 + return {"interval": interval, "interval_count": interval_count} + else: + return {} + + def get_stripe_product_id(makeradmin_product: Product, livemode: bool = False) -> str: return f"prod_{makeradmin_product.id}" if livemode else f"test_{makeradmin_product.id}" -# TODO cache? def get_stripe_product(makeradmin_product: Product, livemode: bool = False) -> stripe.Product | None: id = get_stripe_product_id(makeradmin_product, livemode) try: @@ -42,13 +55,12 @@ def get_stripe_product(makeradmin_product: Product, livemode: bool = False) -> s return product -# TODO cache? def get_stripe_prices( - stripe_product: stripe.Product, filterInactive: bool = True, livemode: bool = False + stripe_product: stripe.Product, filter_inactive: bool = True, livemode: bool = False ) -> list[stripe.Price] | None: try: prices = list(retry(lambda: stripe.Price.list(product=stripe_product.stripe_id))) - if filterInactive: + if filter_inactive: prices = [p for p in prices if p.active] except stripe.error.InvalidRequestError as e: logger.warning(f"failed to retrive prices from stripe for stripe product with id {stripe_product.id}, {e}") @@ -58,6 +70,46 @@ def get_stripe_prices( return prices +def eq_makeradmin_stripe_product(makeradmin_product: Product, stripe_product: stripe.Product) -> bool: + return stripe_product.name == makeradmin_product.name + + +# TODO need different levels depending on what can be changed and not +def eq_makeradmin_stripe_price(makeradmin_product: Product, stripe_price: stripe.Price, price_type: PriceType) -> bool: + recurring = makeradmin_to_stripe_recurring(makeradmin_product, price_type) + different = [] + + if len(recurring) != 0: + if stripe_price.recurring is None: + return False + different.append(stripe_price.recurring.get("interval") == recurring["interval"]) + different.append(stripe_price.recurring.get("interval_count") == recurring["interval_count"]) + different.append(stripe_price.unit_amount == stripe_amount_from_makeradmin_product(makeradmin_product, recurring)) + different.append(stripe_price.currency == CURRENCY) + different.append(stripe_price.metadata.get("price_type") == price_type.value) + return not any(different) + + +def eq_makeradmin_stripe( + makeradmin_product: Product, stripe_product: stripe.Product, stripe_prices: Dict[PriceType, stripe.Price] +) -> Dict[str, bool]: + differences = {"product": eq_makeradmin_stripe_product(makeradmin_product, stripe_product)} + interval_count = makeradmin_product.smallest_multiple + expected_number_of_prices = 1 if interval_count == 1 else 2 + if len(stripe_prices) != expected_number_of_prices: + raise RuntimeError( + f"Number of stripe prices does not match with makeradmin product multiplier. Got {len(stripe_prices)}, expected {expected_number_of_prices}" + ) + differences[PriceType.RECURRING.value] = eq_makeradmin_stripe_price( + makeradmin_product, stripe_prices[PriceType.RECURRING], PriceType.RECURRING + ) + if expected_number_of_prices == 2: + differences[PriceType.BINDING_PERIOD.value] = eq_makeradmin_stripe_price( + makeradmin_product, stripe_prices[PriceType.BINDING_PERIOD], PriceType.BINDING_PERIOD + ) + return differences + + def _create_stripe_product(makeradmin_product: Product, livemode: bool = False) -> stripe.Product | None: id = get_stripe_product_id(makeradmin_product, livemode) stripe_product = retry( @@ -72,33 +124,22 @@ def _create_stripe_product(makeradmin_product: Product, livemode: bool = False) def _create_stripe_price( - makeradmin_product: Product, stripe_product: stripe.Product, priceType: PriceType, livemode: bool = False + makeradmin_product: Product, stripe_product: stripe.Product, price_type: PriceType, livemode: bool = False ) -> stripe.Price | None: nickname = ( - f"prod_{makeradmin_product.id}_{priceType.value}" + f"prod_{makeradmin_product.id}_{price_type.value}" if livemode - else f"test_{makeradmin_product.id}_{priceType.value}" + else f"test_{makeradmin_product.id}_{price_type.value}" ) - if priceType == PriceType.RECURRING or priceType == PriceType.BINDING_PERIOD: - if makeradmin_product.unit in makeradmin_unit_to_stripe_unit: - interval = makeradmin_unit_to_stripe_unit[makeradmin_product.unit] - else: - raise RuntimeError( - f"Unexpected unit {makeradmin_product.unit} in makeradmin product {makeradmin_product.id}" - ) - interval_count = makeradmin_product.smallest_multiple if priceType == PriceType.BINDING_PERIOD else 1 - recurring = {"interval": interval, "interval_count": interval_count} - else: - interval_count = 1 - recurring = {} + recurring = makeradmin_to_stripe_recurring(makeradmin_product, price_type) stripe_price = retry( lambda: stripe.Price.create( nickname=nickname, product=stripe_product.id, - unit_amount=convert_to_stripe_amount(makeradmin_product.price * interval_count), + unit_amount=stripe_amount_from_makeradmin_product(makeradmin_product, recurring), currency=CURRENCY, recurring=recurring, - metadata={"price_type": priceType.value}, + metadata={"price_type": price_type.value}, ) ) assert stripe_price.livemode == livemode @@ -128,9 +169,9 @@ def _find_price_type(stripe_prices: list[stripe.Price] | None, price_type: Price def find_or_create_stripe_prices_for_product( makeradmin_product: Product, stripe_product: stripe.Product, livemode: bool = False -) -> list[stripe.Price] | None: +) -> Dict[PriceType, stripe.Price] | None: interval_count = makeradmin_product.smallest_multiple - stripe_prices = get_stripe_prices(stripe_product, filterInactive=False, livemode=livemode) + stripe_prices = get_stripe_prices(stripe_product, filter_inactive=False, livemode=livemode) stripe_price_recurring = _find_price_type(stripe_prices, PriceType.RECURRING) @@ -140,7 +181,7 @@ def find_or_create_stripe_prices_for_product( stripe_price_recurring = retry(lambda: stripe.Price.modify(stripe_price_recurring.id, active=True)) if interval_count == 1: - return [stripe_price_recurring] + return {PriceType.RECURRING: stripe_price_recurring} stripe_price_binding = _find_price_type(stripe_prices, PriceType.BINDING_PERIOD) if stripe_price_binding is None: @@ -150,7 +191,30 @@ def find_or_create_stripe_prices_for_product( elif not stripe_price_binding.active: stripe_price_binding = retry(lambda: stripe.Price.modify(stripe_price_binding.id, active=True)) - return [stripe_price_recurring, stripe_price_binding] + return {PriceType.RECURRING: stripe_price_recurring, PriceType.BINDING_PERIOD: stripe_price_binding} + + +def update_stripe_product(makeradmin_product: Product, stripe_product: stripe.Product) -> stripe.Product: + return retry(lambda: stripe.Product.modify(stripe_product.id, name=makeradmin_product.name)) + + +def update_stripe_price(makeradmin_product: Product, stripe_price: stripe.Price, price_type: PriceType) -> stripe.Price: + # TOOD reccuring cant be changed, hmm??? + recurring = makeradmin_to_stripe_recurring(makeradmin_product, price_type) + unit_amount = stripe_amount_from_makeradmin_product(makeradmin_product, recurring) + currency_options_param = {"unit_amount": unit_amount} # TODO fix this + currency_options = {CURRENCY: currency_options_param} + return retry( + lambda: stripe.Price.modify( + stripe_price.id, + currency_options=currency_options, + metadata={"price_type": price_type.value}, + ) + ) + + +def stripe_amount_from_makeradmin_product(makeradmin_product: Product, recurring: Dict[str, Any]) -> int: + return convert_to_stripe_amount(makeradmin_product.price * recurring["interval_count"]) def convert_to_stripe_amount(amount: Decimal) -> int: diff --git a/api/src/shop/test/stripe_util_test.py b/api/src/shop/test/stripe_util_test.py index 178778834..ce72bb04f 100644 --- a/api/src/shop/test/stripe_util_test.py +++ b/api/src/shop/test/stripe_util_test.py @@ -14,6 +14,8 @@ logger = getLogger("makeradmin") +# The stripe ids product have to be unique in each test to prevent race conditions + class Test(ShopTestMixin, FlaskTestBase): models = [membership.models, messages.models, shop.models, core.models]