diff --git a/server/polar/order/schemas.py b/server/polar/order/schemas.py index 83bb063252..2261caa119 100644 --- a/server/polar/order/schemas.py +++ b/server/polar/order/schemas.py @@ -5,7 +5,10 @@ from pydantic import UUID4, AliasChoices, AliasPath, Field, computed_field from pydantic.json_schema import SkipJsonSchema -from polar.custom_field.data import CustomFieldDataOutputMixin +from polar.custom_field.data import ( + CustomFieldDataInputMixin, + CustomFieldDataOutputMixin, +) from polar.customer.schemas.customer import CustomerBase from polar.discount.schemas import DiscountMinimal from polar.exceptions import ResourceNotFound @@ -177,18 +180,20 @@ class Order(CustomFieldDataOutputMixin, MetadataOutputMixin, OrderBase): items: list[OrderItemSchema] = Field(description="Line items composing the order.") -class OrderUpdateBase(Schema): +class OrderUpdateBase(CustomFieldDataInputMixin, Schema): billing_name: str | None = Field( + default=None, description=( "The name of the customer that should appear on the invoice. " "Can't be updated after the invoice is generated." - ) + ), ) billing_address: Address | None = Field( + default=None, description=( "The address of the customer that should appear on the invoice. " "Can't be updated after the invoice is generated." - ) + ), ) diff --git a/server/polar/order/service.py b/server/polar/order/service.py index dad462f713..045ae1154d 100644 --- a/server/polar/order/service.py +++ b/server/polar/order/service.py @@ -7,7 +7,7 @@ import stripe as stripe_lib import structlog from sqlalchemy import select -from sqlalchemy.orm import contains_eager, joinedload +from sqlalchemy.orm import contains_eager, joinedload, selectinload from polar.account.repository import AccountRepository from polar.auth.models import AuthSubject @@ -15,6 +15,7 @@ from polar.checkout.eventstream import CheckoutEvent, publish_checkout_event from polar.checkout.repository import CheckoutRepository from polar.config import settings +from polar.custom_field.data import validate_custom_field_data from polar.customer.repository import CustomerRepository from polar.customer_portal.schemas.order import ( CustomerOrderPaymentConfirmation, @@ -402,8 +403,9 @@ async def get( .options( *repository.get_eager_options( customer_load=contains_eager(Order.customer), - product_load=joinedload(Order.product).joinedload( - Product.organization + product_load=joinedload(Order.product).options( + joinedload(Product.organization), + selectinload(Product.attached_custom_fields), ), ) ) @@ -436,10 +438,19 @@ async def update( if errors: raise PolarRequestValidationError(errors) + update_dict = order_update.model_dump(exclude_unset=True) + + if "custom_field_data" in update_dict: + # Validate custom field data against the product's attached custom fields + custom_field_data = validate_custom_field_data( + order.product.attached_custom_fields, + order_update.custom_field_data, + validate_required=False, # Allow merchants to update even if required fields are missing + ) + update_dict["custom_field_data"] = custom_field_data + repository = OrderRepository.from_session(session) - order = await repository.update( - order, update_dict=order_update.model_dump(exclude_unset=True) - ) + order = await repository.update(order, update_dict=update_dict) await self.send_webhook(session, order, WebhookEventType.order_updated) diff --git a/server/tests/order/test_endpoints.py b/server/tests/order/test_endpoints.py index 35ce0fd84c..06336c54ae 100644 --- a/server/tests/order/test_endpoints.py +++ b/server/tests/order/test_endpoints.py @@ -5,10 +5,17 @@ from httpx import AsyncClient from polar.auth.scope import Scope -from polar.models import Customer, Order, Product, UserOrganization +from polar.enums import SubscriptionRecurringInterval +from polar.models import Customer, Order, Organization, Product, UserOrganization +from polar.models.custom_field import CustomFieldType from tests.fixtures.auth import AuthSubjectFixture from tests.fixtures.database import SaveFixture -from tests.fixtures.random_objects import create_order +from tests.fixtures.random_objects import ( + create_custom_field, + create_customer, + create_order, + create_product, +) @pytest_asyncio.fixture @@ -166,6 +173,187 @@ async def test_custom_field( assert json["custom_field_data"] == {"test": None} +@pytest.mark.asyncio +class TestUpdateOrder: + async def test_anonymous(self, client: AsyncClient, orders: list[Order]) -> None: + response = await client.patch( + f"/v1/orders/{orders[0].id}", + json={"custom_field_data": {"test": "updated"}}, + ) + + assert response.status_code == 401 + + @pytest.mark.auth + async def test_not_existing(self, client: AsyncClient) -> None: + response = await client.patch( + f"/v1/orders/{uuid.uuid4()}", + json={"custom_field_data": {"test": "updated"}}, + ) + + assert response.status_code == 404 + + @pytest.mark.auth + async def test_user_not_organization_member( + self, client: AsyncClient, orders: list[Order] + ) -> None: + response = await client.patch( + f"/v1/orders/{orders[0].id}", + json={"custom_field_data": {"test": "updated"}}, + ) + + assert response.status_code == 404 + + @pytest.mark.auth( + AuthSubjectFixture(scopes={Scope.web_write}), + AuthSubjectFixture(scopes={Scope.orders_write}), + ) + async def test_user_valid( + self, + save_fixture: SaveFixture, + client: AsyncClient, + user_organization: UserOrganization, + organization: Organization, + ) -> None: + # Create a product with custom fields + text_field = await create_custom_field( + save_fixture, + type=CustomFieldType.text, + slug="text", + organization=organization, + ) + select_field = await create_custom_field( + save_fixture, + type=CustomFieldType.select, + slug="select", + organization=organization, + properties={ + "options": [{"value": "a", "label": "A"}, {"value": "b", "label": "B"}], + }, + ) + product = await create_product( + save_fixture, + organization=organization, + recurring_interval=SubscriptionRecurringInterval.month, + attached_custom_fields=[(text_field, False), (select_field, True)], + ) + + # Create an order with custom field data + order = await create_order( + save_fixture, + product=product, + customer=await create_customer(save_fixture, organization=organization), + custom_field_data={"text": "original", "select": "a"}, + ) + + response = await client.patch( + f"/v1/orders/{order.id}", + json={"custom_field_data": {"text": "updated", "select": "b"}}, + ) + + assert response.status_code == 200 + + json = response.json() + assert json["custom_field_data"] == {"text": "updated", "select": "b"} + + @pytest.mark.auth( + AuthSubjectFixture(subject="organization", scopes={Scope.orders_write}), + ) + async def test_organization( + self, save_fixture: SaveFixture, client: AsyncClient, organization: Organization + ) -> None: + # Create a product with custom fields + text_field = await create_custom_field( + save_fixture, + type=CustomFieldType.text, + slug="text", + organization=organization, + ) + select_field = await create_custom_field( + save_fixture, + type=CustomFieldType.select, + slug="select", + organization=organization, + properties={ + "options": [{"value": "a", "label": "A"}, {"value": "b", "label": "B"}], + }, + ) + product = await create_product( + save_fixture, + organization=organization, + recurring_interval=SubscriptionRecurringInterval.month, + attached_custom_fields=[(text_field, False), (select_field, True)], + ) + + # Create an order with custom field data + order = await create_order( + save_fixture, + product=product, + customer=await create_customer(save_fixture, organization=organization), + custom_field_data={"text": "original", "select": "a"}, + ) + + response = await client.patch( + f"/v1/orders/{order.id}", + json={"custom_field_data": {"text": "updated", "select": "b"}}, + ) + + assert response.status_code == 200 + + json = response.json() + assert json["custom_field_data"] == {"text": "updated", "select": "b"} + + @pytest.mark.auth( + AuthSubjectFixture(scopes={Scope.web_write}), + ) + async def test_update_existing_custom_field_data( + self, + save_fixture: SaveFixture, + client: AsyncClient, + user_organization: UserOrganization, + organization: Organization, + ) -> None: + # Create a product with custom fields + text_field = await create_custom_field( + save_fixture, + type=CustomFieldType.text, + slug="text", + organization=organization, + ) + select_field = await create_custom_field( + save_fixture, + type=CustomFieldType.select, + slug="select", + organization=organization, + properties={ + "options": [{"value": "a", "label": "A"}, {"value": "b", "label": "B"}], + }, + ) + product = await create_product( + save_fixture, + organization=organization, + recurring_interval=SubscriptionRecurringInterval.month, + attached_custom_fields=[(text_field, False), (select_field, True)], + ) + + # Create an order with custom field data + order = await create_order( + save_fixture, + product=product, + customer=await create_customer(save_fixture, organization=organization), + custom_field_data={"text": "original", "select": "a"}, + ) + + response = await client.patch( + f"/v1/orders/{order.id}", + json={"custom_field_data": {"text": "updated", "select": "b"}}, + ) + + assert response.status_code == 200 + + json = response.json() + assert json["custom_field_data"] == {"text": "updated", "select": "b"} + + @pytest.mark.asyncio class TesGetOrdersStatistics: async def test_anonymous(self, client: AsyncClient) -> None: diff --git a/server/tests/order/test_service.py b/server/tests/order/test_service.py index 430f582507..d75611623b 100644 --- a/server/tests/order/test_service.py +++ b/server/tests/order/test_service.py @@ -15,6 +15,7 @@ from polar.auth.models import AuthSubject from polar.checkout.eventstream import CheckoutEvent from polar.enums import PaymentProcessor, SubscriptionRecurringInterval +from polar.exceptions import PolarRequestValidationError from polar.held_balance.service import held_balance as held_balance_service from polar.integrations.stripe.schemas import ProductType from polar.integrations.stripe.service import StripeService @@ -37,6 +38,7 @@ ) from polar.models.billing_entry import BillingEntryDirection, BillingEntryType from polar.models.checkout import CheckoutStatus +from polar.models.custom_field import CustomFieldType from polar.models.discount import DiscountDuration, DiscountFixed, DiscountType from polar.models.order import OrderBillingReason, OrderStatus from polar.models.organization import Organization @@ -44,6 +46,7 @@ from polar.models.product import ProductBillingType from polar.models.subscription import SubscriptionStatus from polar.models.transaction import PlatformFeeType, TransactionType +from polar.order.schemas import OrderUpdate from polar.order.service import ( CardPaymentFailed, MissingCheckoutCustomer, @@ -75,6 +78,7 @@ create_billing_entry, create_canceled_subscription, create_checkout, + create_custom_field, create_customer, create_discount, create_event, @@ -3135,3 +3139,107 @@ async def test_process_retry_payment_already_in_progress( await order_service.process_retry_payment( session, order, "ctoken_test", PaymentProcessor.stripe ) + + +@pytest_asyncio.fixture +async def product_custom_fields( + save_fixture: SaveFixture, organization: Organization +) -> Product: + text_field = await create_custom_field( + save_fixture, type=CustomFieldType.text, slug="text", organization=organization + ) + select_field = await create_custom_field( + save_fixture, + type=CustomFieldType.select, + slug="select", + organization=organization, + properties={ + "options": [{"value": "a", "label": "A"}, {"value": "b", "label": "B"}], + }, + ) + return await create_product( + save_fixture, + organization=organization, + recurring_interval=SubscriptionRecurringInterval.month, + attached_custom_fields=[(text_field, False), (select_field, True)], + ) + + +@pytest.mark.asyncio +class TestUpdateOrder: + async def test_update_custom_field_data( + self, + session: AsyncSession, + save_fixture: SaveFixture, + product_custom_fields: Product, + customer: Customer, + ) -> None: + """Test updating custom field data for an order.""" + order = await create_order( + save_fixture, + product=product_custom_fields, + customer=customer, + custom_field_data={"text": "original", "select": "a"}, + ) + + updated_order = await order_service.update( + session, + order, + OrderUpdate(custom_field_data={"text": "updated", "select": "b"}), + ) + + assert updated_order.custom_field_data == {"text": "updated", "select": "b"} + + async def test_update_billing_name( + self, + session: AsyncSession, + save_fixture: SaveFixture, + product: Product, + customer: Customer, + ) -> None: + """Test updating billing name for an order.""" + order = await create_order( + save_fixture, + product=product, + customer=customer, + billing_name="Original Name", + ) + + updated_order = await order_service.update( + session, + order, + OrderUpdate(billing_name="Updated Name"), + ) + + assert updated_order.billing_name == "Updated Name" + + async def test_update_with_invoice_generated( + self, + session: AsyncSession, + save_fixture: SaveFixture, + product: Product, + customer: Customer, + ) -> None: + """Test that billing fields cannot be updated after invoice is generated.""" + order = await create_order( + save_fixture, + product=product, + customer=customer, + billing_name="Original Name", + ) + + # Set invoice_path after creation + order.invoice_path = "/path/to/invoice.pdf" # Invoice already generated + await save_fixture(order) + + with pytest.raises(PolarRequestValidationError) as e: + await order_service.update( + session, + order, + OrderUpdate(billing_name="Updated Name"), + ) + + errors = e.value.errors() + assert len(errors) == 1 + assert errors[0]["loc"] == ("body", "billing_name") + assert "cannot be updated after the invoice is generated" in errors[0]["msg"]