Skip to content

Commit

Permalink
server/checkout: prevent redeeming recurring discount on one-time pri…
Browse files Browse the repository at this point in the history
…cing

Fix #4847
  • Loading branch information
frankie567 committed Jan 17, 2025
1 parent 8594e96 commit f8a0811
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 34 deletions.
1 change: 1 addition & 0 deletions server/.vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"python.analysis.typeCheckingMode": "basic",
"python.analysis.autoImportCompletions": true,
"python.analysis.showOnlyDirectDependenciesInAutoImport": true,
"python.envFile": "${workspaceFolder}/.env.testing",
"python.terminal.activateEnvironment": true,
"python.terminal.activateEnvInCurrentTerminal": true,
Expand Down
92 changes: 58 additions & 34 deletions server/polar/checkout/service.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import typing
import uuid
from collections.abc import Sequence
from typing import Any
Expand Down Expand Up @@ -62,7 +63,12 @@
UserOrganization,
)
from polar.models.checkout import CheckoutStatus
from polar.models.product_price import ProductPriceAmountType, ProductPriceFree
from polar.models.discount import DiscountDuration
from polar.models.product_price import (
ProductPriceAmountType,
ProductPriceFree,
ProductPriceType,
)
from polar.models.webhook_endpoint import WebhookEventType
from polar.organization.service import organization as organization_service
from polar.postgres import AsyncSession
Expand Down Expand Up @@ -286,7 +292,7 @@ async def create(
discount: Discount | None = None
if checkout_create.discount_id is not None:
discount = await self._get_validated_discount(
session, checkout_create.discount_id, product, price
session, product, price, discount_id=checkout_create.discount_id
)

customer_tax_id: TaxID | None = None
Expand Down Expand Up @@ -619,7 +625,7 @@ async def checkout_link_create(
if checkout_link.discount_id is not None:
try:
discount = await self._get_validated_discount(
session, checkout_link.discount_id, product, price
session, product, price, discount_id=checkout_link.discount_id
)
# If the discount is not valid, just ignore it
except PolarRequestValidationError:
Expand Down Expand Up @@ -1213,12 +1219,34 @@ async def _get_validated_product(

return product, product.prices[0]

@typing.overload
async def _get_validated_discount(
self,
session: AsyncSession,
product: Product,
price: ProductPrice,
*,
discount_id: uuid.UUID,
) -> Discount: ...

@typing.overload
async def _get_validated_discount(
self,
session: AsyncSession,
product: Product,
price: ProductPrice,
*,
discount_code: str,
) -> Discount: ...

async def _get_validated_discount(
self,
session: AsyncSession,
product: Product,
price: ProductPrice,
*,
discount_id: uuid.UUID | None = None,
discount_code: str | None = None,
) -> Discount:
if price.amount_type not in {
ProductPriceAmountType.fixed,
Expand All @@ -1235,9 +1263,14 @@ async def _get_validated_discount(
]
)

discount = await discount_service.get_by_id_and_product(
session, discount_id, product
)
if discount_id is not None:
discount = await discount_service.get_by_id_and_product(
session, discount_id, product
)
elif discount_code is not None:
discount = await discount_service.get_by_code_and_product(
session, discount_code, product
)

if discount is None:
raise PolarRequestValidationError(
Expand All @@ -1251,6 +1284,21 @@ async def _get_validated_discount(
]
)

if (
price.type == ProductPriceType.one_time
and discount.duration == DiscountDuration.repeating
):
raise PolarRequestValidationError(
[
{
"type": "value_error",
"loc": ("body", "discount_id"),
"msg": "Discount is not applicable to this product.",
"input": discount_id,
}
]
)

return discount

async def _get_validated_subscription(
Expand Down Expand Up @@ -1403,9 +1451,9 @@ async def _update_checkout(
if checkout_update.discount_id is not None:
checkout.discount = await self._get_validated_discount(
session,
checkout_update.discount_id,
checkout.product,
checkout.product_price,
discount_id=checkout_update.discount_id,
)
# User explicitly removed the discount
elif "discount_id" in checkout_update.model_fields_set:
Expand All @@ -1415,36 +1463,12 @@ async def _update_checkout(
and checkout.allow_discount_codes
):
if checkout_update.discount_code is not None:
if not checkout.is_discount_applicable:
raise PolarRequestValidationError(
[
{
"type": "value_error",
"loc": ("body", "discount_code"),
"msg": "Discounts are only applicable to fixed prices.",
"input": checkout_update.discount_code,
}
]
)
discount = await discount_service.get_by_code_and_product(
discount = await self._get_validated_discount(
session,
checkout_update.discount_code,
checkout.product,
checkout.product_price,
discount_code=checkout_update.discount_code,
)
if discount is None:
raise PolarRequestValidationError(
[
{
"type": "value_error",
"loc": ("body", "discount_code"),
"msg": (
"This discount code does not exist "
"or is no longer available."
),
"input": checkout_update.discount_code,
}
]
)
checkout.discount = discount
# User explicitly removed the discount
elif "discount_code" in checkout_update.model_fields_set:
Expand Down
26 changes: 26 additions & 0 deletions server/tests/checkout/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
)
from polar.models.checkout import CheckoutStatus
from polar.models.custom_field import CustomFieldType
from polar.models.discount import DiscountDuration, DiscountType
from polar.models.product_price import (
ProductPriceCustom,
ProductPriceFixed,
Expand All @@ -67,6 +68,7 @@
create_checkout_link,
create_custom_field,
create_customer,
create_discount,
create_product,
create_product_price_fixed,
create_subscription,
Expand Down Expand Up @@ -1509,6 +1511,30 @@ async def test_invalid_discount_code_not_applicable(
CheckoutUpdatePublic(discount_code=discount_fixed_once.code),
)

async def test_invalid_recurring_discount_on_one_time_price(
self,
save_fixture: SaveFixture,
session: AsyncSession,
checkout_one_time_fixed: Checkout,
organization: Organization,
) -> None:
recurring_discount = await create_discount(
save_fixture,
type=DiscountType.fixed,
code="RECURRING",
amount=1000,
currency="usd",
duration=DiscountDuration.repeating,
duration_in_months=12,
organization=organization,
)
with pytest.raises(PolarRequestValidationError):
await checkout_service.update(
session,
checkout_one_time_fixed,
CheckoutUpdatePublic(discount_code=recurring_discount.code),
)

async def test_valid_price_fixed_change(
self,
save_fixture: SaveFixture,
Expand Down

0 comments on commit f8a0811

Please sign in to comment.