Skip to content

Commit

Permalink
Bunch of fixes, some review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
BerglundDaniel committed Nov 8, 2023
1 parent 4091235 commit c767119
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 28 deletions.
49 changes: 32 additions & 17 deletions api/src/shop/stripe_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,21 @@

logger = getLogger("makeradmin")

CURRENCY = "sek" # TODO fix this
makeradmin_unit_to_stripe_unit = {
"mån": "month",
"month": "month",
"år": "year",
"year": "year",
}


def get_stripe_product_id(makeradmin_product: Product, livemode: bool = False) -> str:
return f'prod_{makeradmin_product["id"]}' if livemode else f'test_{makeradmin_product["id"]}'


# TODO cache?
def get_stripe_product(makeradmin_product: Product, livemode: bool = False) -> stripe.Product | None:
id = f'prod_{makeradmin_product["id"]}' if livemode else f'test_{makeradmin_product["id"]}'
id = get_stripe_product_id(makeradmin_product, livemode)
try:
product = retry(lambda: stripe.Product.retrieve(id=id))
except stripe.error.InvalidRequestError as e:
Expand Down Expand Up @@ -50,7 +59,7 @@ def get_stripe_prices(


def _create_stripe_product(makeradmin_product: Product, livemode: bool = False) -> stripe.Product | None:
id = f'prod_{makeradmin_product["id"]}' if livemode else f'test_{makeradmin_product["id"]}'
id = get_stripe_product_id(makeradmin_product, livemode)
stripe_product = retry(
lambda: stripe.Product.create(
id=id,
Expand All @@ -65,11 +74,13 @@ def _create_stripe_product(makeradmin_product: Product, livemode: bool = False)
def _create_stripe_price(
makeradmin_product: Product, stripe_product: stripe.Product, priceType: PriceType, livemode: bool = False
) -> stripe.Price | None:
id = f'prod_{makeradmin_product["id"]}_{priceType}' if livemode else f'test_{makeradmin_product["id"]}_{priceType}'
if "mån" in makeradmin_product["unit"] or "month" in makeradmin_product["unit"]:
interval = "month"
elif "år" in makeradmin_product["unit"] or "year" in makeradmin_product["unit"]:
interval = "year"
nickname = (
f'prod_{makeradmin_product["id"]}_{priceType.value}'
if livemode
else f'test_{makeradmin_product["id"]}_{priceType.value}'
)
if makeradmin_product["unit"] in makeradmin_unit_to_stripe_unit:
interval = makeradmin_unit_to_stripe_unit[makeradmin_product["unit"]]
else:
raise RuntimeError(
f'Unexpected unit {makeradmin_product["unit"]} in makeradmin product {makeradmin_product["id"]}'
Expand All @@ -78,12 +89,12 @@ def _create_stripe_price(
recurring = {"interval": interval, "interval_count": interval_count}
stripe_price = retry(
lambda: stripe.Price.create(
nickname=id,
nickname=nickname,
product=stripe_product.id,
unit_amount=convert_to_stripe_amount(makeradmin_product["price"] * interval_count),
currency=CURRENCY,
recurring=recurring,
metadata={"price_type": priceType},
metadata={"price_type": priceType.value},
)
)
assert stripe_price.livemode == livemode
Expand All @@ -93,11 +104,10 @@ def _create_stripe_price(
def find_or_create_stripe_product(makeradmin_product: Product, livemode: bool = False) -> stripe.Product | None:
stripe_product = get_stripe_product(makeradmin_product, livemode)
if stripe_product is None:
return _create_stripe_product(makeradmin_product, livemode)
else:
if not stripe_product.active:
retry(lambda: stripe.Product.modify(stripe_product.id, active=True))
return stripe_product
stripe_product = _create_stripe_product(makeradmin_product, livemode)
elif not stripe_product.active:
stripe_product = retry(lambda: stripe.Product.modify(stripe_product.id, active=True))
return stripe_product


def _find_price_type(stripe_prices: list[stripe.Price] | None, price_type: PriceType):
Expand All @@ -116,19 +126,24 @@ def find_or_create_stripe_prices_for_product(
stripe_prices = get_stripe_prices(stripe_product, filterInactive=False, livemode=livemode)

stripe_price_recurring = _find_price_type(stripe_prices, PriceType.RECURRING)

if stripe_price_recurring is None:
stripe_price_recurring = _create_stripe_price(makeradmin_product, stripe_product, PriceType.RECURRING, livemode)
elif not stripe_price_recurring.active:
stripe_price_recurring = retry(lambda: stripe.Price.modify(stripe_price_recurring.id, active=True))

if interval_count == 1:
return [stripe_price_recurring]

stripe_price_binding = _find_price_type(stripe_prices, PriceType.RECURRING)
stripe_price_binding = _find_price_type(stripe_prices, PriceType.BINDING_PERIOD)
if stripe_price_binding is None:
stripe_price_binding = _create_stripe_price(
makeradmin_product, stripe_product, PriceType.BINDING_PERIOD, livemode
)
elif not stripe_price_binding.active:
stripe_price_binding = retry(lambda: stripe.Price.modify(stripe_price_binding.id, active=True))

return [stripe_price_recurring, stripe_price_recurring]
return [stripe_price_recurring, stripe_price_binding]


def convert_to_stripe_amount(amount: Decimal) -> int:
Expand Down
26 changes: 15 additions & 11 deletions api/src/shop/test/stripe_util_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from datetime import datetime, timezone
from logging import getLogger
from unittest import skipIf

Expand All @@ -8,11 +7,7 @@
import core.models
from shop import stripe_util
from shop.models import Product, ProductAction
from shop.stripe_constants import (
STRIPE_CURRENTY_BASE,
PriceType,
CURRENCY,
)
from shop import stripe_constants
import stripe
from test_aid.test_base import FlaskTestBase, ShopTestMixin

Expand Down Expand Up @@ -51,8 +46,10 @@ def setUp(self) -> None:

@staticmethod
def assertPrice(stripe_price: stripe.Price, makeradmin_product: Product):
assert stripe_price.unit_amount == stripe_util.convert_to_stripe_amount(makeradmin_product["price"])
assert stripe_price.currency == CURRENCY
assert stripe_price.unit_amount == stripe_util.convert_to_stripe_amount(
makeradmin_product["price"] * makeradmin_product["smallest_multiple"]
)
assert stripe_price.currency == stripe_constants.CURRENCY
assert stripe_price.active
assert stripe_price.type == "recurring"
reccuring = stripe_price.recurring
Expand All @@ -62,8 +59,8 @@ def assertPrice(stripe_price: stripe.Price, makeradmin_product: Product):
assert reccuring["interval"] == "year"
assert reccuring["interval_count"] == makeradmin_product["smallest_multiple"]
assert (
stripe_price.metadata["price_type"] == PriceType.BINDING_PERIOD.value
or stripe_price.metadata["price_type"] == PriceType.RECURRING.value
stripe_price.metadata["price_type"] == stripe_constants.PriceType.BINDING_PERIOD.value
or stripe_price.metadata["price_type"] == stripe_constants.PriceType.RECURRING.value
)

def tearDown(self) -> None:
Expand All @@ -76,7 +73,7 @@ def tearDown(self) -> None:
if stripe_prices is None:
continue
for price in stripe_prices:
stripe_util.retry(lambda: stripe.Product.modify(price.id, active=False))
stripe_util.retry(lambda: stripe.Price.modify(price.id, active=False))
stripe_util.retry(lambda: stripe.Product.modify(stripe_product.id, active=False))
return super().tearDown()

Expand Down Expand Up @@ -112,6 +109,7 @@ def test_create_product_with_price_yearly_simple(self) -> None:
)
assert stripe_test_prices
assert len(stripe_test_prices) == 1
self.assertPrice(stripe_test_prices[0], makeradmin_test_product)

def test_create_product_with_price_monthly_with_binding_period(self) -> None:
makeradmin_test_product = self.products[2]
Expand All @@ -124,3 +122,9 @@ def test_create_product_with_price_monthly_with_binding_period(self) -> None:
)
assert stripe_test_prices
assert len(stripe_test_prices) == 2
self.assertPrice(stripe_test_prices[0], makeradmin_test_product)
self.assertPrice(stripe_test_prices[1], makeradmin_test_product)

test_price_types = [p.metadata["price_type"] for p in stripe_test_prices]
assert stripe_constants.PriceType.BINDING_PERIOD.value in test_price_types
assert stripe_constants.PriceType.RECURRING.value in test_price_types

0 comments on commit c767119

Please sign in to comment.