Skip to content

Commit

Permalink
Add functions to check if makeradmin and stripe things are equal. Add…
Browse files Browse the repository at this point in the history
… functions to update stripe objects
  • Loading branch information
BerglundDaniel committed Nov 10, 2023
1 parent 60fed22 commit 9d5912a
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 26 deletions.
116 changes: 90 additions & 26 deletions api/src/shop/stripe_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import random
import time
from logging import getLogger
from typing import Callable, Optional, TypeVar
from typing import Any, Callable, Dict, TypeVar

from service.error import InternalServerError
from shop.models import Product
Expand All @@ -24,11 +24,24 @@
}


def makeradmin_to_stripe_recurring(makeradmin_product: Product, price_type: PriceType) -> Dict[str, Any]:
if price_type == PriceType.RECURRING or price_type == PriceType.BINDING_PERIOD:
if makeradmin_product.unit in makeradmin_unit_to_stripe_unit:
interval = makeradmin_unit_to_stripe_unit[makeradmin_product.unit]
else:
raise RuntimeError(
f"Unexpected unit {makeradmin_product.unit} in makeradmin product {makeradmin_product.id}"
)
interval_count = makeradmin_product.smallest_multiple if price_type == PriceType.BINDING_PERIOD else 1
return {"interval": interval, "interval_count": interval_count}
else:
return {}


def get_stripe_product_id(makeradmin_product: Product, livemode: bool = False) -> str:
return f"prod_{makeradmin_product.id}" if livemode else f"test_{makeradmin_product.id}"


# TODO cache?
def get_stripe_product(makeradmin_product: Product, livemode: bool = False) -> stripe.Product | None:
id = get_stripe_product_id(makeradmin_product, livemode)
try:
Expand All @@ -42,13 +55,12 @@ def get_stripe_product(makeradmin_product: Product, livemode: bool = False) -> s
return product


# TODO cache?
def get_stripe_prices(
stripe_product: stripe.Product, filterInactive: bool = True, livemode: bool = False
stripe_product: stripe.Product, filter_inactive: bool = True, livemode: bool = False
) -> list[stripe.Price] | None:
try:
prices = list(retry(lambda: stripe.Price.list(product=stripe_product.stripe_id)))
if filterInactive:
if filter_inactive:
prices = [p for p in prices if p.active]
except stripe.error.InvalidRequestError as e:
logger.warning(f"failed to retrive prices from stripe for stripe product with id {stripe_product.id}, {e}")
Expand All @@ -58,6 +70,46 @@ def get_stripe_prices(
return prices


def eq_makeradmin_stripe_product(makeradmin_product: Product, stripe_product: stripe.Product) -> bool:
return stripe_product.name == makeradmin_product.name


# TODO need different levels depending on what can be changed and not
def eq_makeradmin_stripe_price(makeradmin_product: Product, stripe_price: stripe.Price, price_type: PriceType) -> bool:
recurring = makeradmin_to_stripe_recurring(makeradmin_product, price_type)
different = []

if len(recurring) != 0:
if stripe_price.recurring is None:
return False
different.append(stripe_price.recurring.get("interval") == recurring["interval"])
different.append(stripe_price.recurring.get("interval_count") == recurring["interval_count"])
different.append(stripe_price.unit_amount == stripe_amount_from_makeradmin_product(makeradmin_product, recurring))
different.append(stripe_price.currency == CURRENCY)
different.append(stripe_price.metadata.get("price_type") == price_type.value)
return not any(different)


def eq_makeradmin_stripe(
makeradmin_product: Product, stripe_product: stripe.Product, stripe_prices: Dict[PriceType, stripe.Price]
) -> Dict[str, bool]:
differences = {"product": eq_makeradmin_stripe_product(makeradmin_product, stripe_product)}
interval_count = makeradmin_product.smallest_multiple
expected_number_of_prices = 1 if interval_count == 1 else 2
if len(stripe_prices) != expected_number_of_prices:
raise RuntimeError(
f"Number of stripe prices does not match with makeradmin product multiplier. Got {len(stripe_prices)}, expected {expected_number_of_prices}"
)
differences[PriceType.RECURRING.value] = eq_makeradmin_stripe_price(
makeradmin_product, stripe_prices[PriceType.RECURRING], PriceType.RECURRING
)
if expected_number_of_prices == 2:
differences[PriceType.BINDING_PERIOD.value] = eq_makeradmin_stripe_price(
makeradmin_product, stripe_prices[PriceType.BINDING_PERIOD], PriceType.BINDING_PERIOD
)
return differences


def _create_stripe_product(makeradmin_product: Product, livemode: bool = False) -> stripe.Product | None:
id = get_stripe_product_id(makeradmin_product, livemode)
stripe_product = retry(
Expand All @@ -72,33 +124,22 @@ def _create_stripe_product(makeradmin_product: Product, livemode: bool = False)


def _create_stripe_price(
makeradmin_product: Product, stripe_product: stripe.Product, priceType: PriceType, livemode: bool = False
makeradmin_product: Product, stripe_product: stripe.Product, price_type: PriceType, livemode: bool = False
) -> stripe.Price | None:
nickname = (
f"prod_{makeradmin_product.id}_{priceType.value}"
f"prod_{makeradmin_product.id}_{price_type.value}"
if livemode
else f"test_{makeradmin_product.id}_{priceType.value}"
else f"test_{makeradmin_product.id}_{price_type.value}"
)
if priceType == PriceType.RECURRING or priceType == PriceType.BINDING_PERIOD:
if makeradmin_product.unit in makeradmin_unit_to_stripe_unit:
interval = makeradmin_unit_to_stripe_unit[makeradmin_product.unit]
else:
raise RuntimeError(
f"Unexpected unit {makeradmin_product.unit} in makeradmin product {makeradmin_product.id}"
)
interval_count = makeradmin_product.smallest_multiple if priceType == PriceType.BINDING_PERIOD else 1
recurring = {"interval": interval, "interval_count": interval_count}
else:
interval_count = 1
recurring = {}
recurring = makeradmin_to_stripe_recurring(makeradmin_product, price_type)
stripe_price = retry(
lambda: stripe.Price.create(
nickname=nickname,
product=stripe_product.id,
unit_amount=convert_to_stripe_amount(makeradmin_product.price * interval_count),
unit_amount=stripe_amount_from_makeradmin_product(makeradmin_product, recurring),
currency=CURRENCY,
recurring=recurring,
metadata={"price_type": priceType.value},
metadata={"price_type": price_type.value},
)
)
assert stripe_price.livemode == livemode
Expand Down Expand Up @@ -128,9 +169,9 @@ def _find_price_type(stripe_prices: list[stripe.Price] | None, price_type: Price

def find_or_create_stripe_prices_for_product(
makeradmin_product: Product, stripe_product: stripe.Product, livemode: bool = False
) -> list[stripe.Price] | None:
) -> Dict[PriceType, stripe.Price] | None:
interval_count = makeradmin_product.smallest_multiple
stripe_prices = get_stripe_prices(stripe_product, filterInactive=False, livemode=livemode)
stripe_prices = get_stripe_prices(stripe_product, filter_inactive=False, livemode=livemode)

stripe_price_recurring = _find_price_type(stripe_prices, PriceType.RECURRING)

Expand All @@ -140,7 +181,7 @@ def find_or_create_stripe_prices_for_product(
stripe_price_recurring = retry(lambda: stripe.Price.modify(stripe_price_recurring.id, active=True))

if interval_count == 1:
return [stripe_price_recurring]
return {PriceType.RECURRING: stripe_price_recurring}

stripe_price_binding = _find_price_type(stripe_prices, PriceType.BINDING_PERIOD)
if stripe_price_binding is None:
Expand All @@ -150,7 +191,30 @@ def find_or_create_stripe_prices_for_product(
elif not stripe_price_binding.active:
stripe_price_binding = retry(lambda: stripe.Price.modify(stripe_price_binding.id, active=True))

return [stripe_price_recurring, stripe_price_binding]
return {PriceType.RECURRING: stripe_price_recurring, PriceType.BINDING_PERIOD: stripe_price_binding}


def update_stripe_product(makeradmin_product: Product, stripe_product: stripe.Product) -> stripe.Product:
return retry(lambda: stripe.Product.modify(stripe_product.id, name=makeradmin_product.name))


def update_stripe_price(makeradmin_product: Product, stripe_price: stripe.Price, price_type: PriceType) -> stripe.Price:
# TOOD reccuring cant be changed, hmm???
recurring = makeradmin_to_stripe_recurring(makeradmin_product, price_type)
unit_amount = stripe_amount_from_makeradmin_product(makeradmin_product, recurring)
currency_options_param = {"unit_amount": unit_amount} # TODO fix this
currency_options = {CURRENCY: currency_options_param}
return retry(
lambda: stripe.Price.modify(
stripe_price.id,
currency_options=currency_options,
metadata={"price_type": price_type.value},
)
)


def stripe_amount_from_makeradmin_product(makeradmin_product: Product, recurring: Dict[str, Any]) -> int:
return convert_to_stripe_amount(makeradmin_product.price * recurring["interval_count"])


def convert_to_stripe_amount(amount: Decimal) -> int:
Expand Down
2 changes: 2 additions & 0 deletions api/src/shop/test/stripe_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

logger = getLogger("makeradmin")

# The stripe ids product have to be unique in each test to prevent race conditions


class Test(ShopTestMixin, FlaskTestBase):
models = [membership.models, messages.models, shop.models, core.models]
Expand Down

0 comments on commit 9d5912a

Please sign in to comment.