diff --git a/server/polar/customer/service.py b/server/polar/customer/service.py index edd22ddc2e..5dc38d7d51 100644 --- a/server/polar/customer/service.py +++ b/server/polar/customer/service.py @@ -4,6 +4,7 @@ from sqlalchemy import Select, UnaryExpression, asc, desc, func, or_, select from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.sql.base import ExecutableOption from stripe import Customer as StripeCustomer from polar.auth.models import AuthSubject, is_organization, is_user @@ -64,10 +65,14 @@ async def get_by_id( session: AsyncSession, auth_subject: AuthSubject[User | Organization], id: uuid.UUID, + *, + options: Sequence[ExecutableOption] | None = None, ) -> Customer | None: statement = self._get_readable_customer_statement(auth_subject).where( Customer.id == id ) + if options is not None: + statement = statement.options(*options) result = await session.execute(statement) return result.unique().scalar_one_or_none() diff --git a/server/polar/customer_session/schemas.py b/server/polar/customer_session/schemas.py index 465c59c5ff..7d329db0a8 100644 --- a/server/polar/customer_session/schemas.py +++ b/server/polar/customer_session/schemas.py @@ -23,5 +23,6 @@ class CustomerSession(IDSchema, TimestampedSchema): token: str = Field(validation_alias="raw_token") expires_at: datetime + customer_portal_url: str customer_id: UUID4 customer: Customer diff --git a/server/polar/customer_session/service.py b/server/polar/customer_session/service.py index 0497332f87..e65dbde163 100644 --- a/server/polar/customer_session/service.py +++ b/server/polar/customer_session/service.py @@ -1,4 +1,5 @@ from sqlalchemy import delete, select +from sqlalchemy.orm import joinedload from polar.auth.models import AuthSubject, Organization, User from polar.config import settings @@ -23,7 +24,10 @@ async def create( customer_create: CustomerSessionCreate, ) -> CustomerSession: customer = await customer_service.get_by_id( - session, auth_subject, customer_create.customer_id + session, + auth_subject, + customer_create.customer_id, + options=(joinedload(Customer.organization),), ) if customer is None: raise PolarRequestValidationError( diff --git a/server/polar/models/customer_session.py b/server/polar/models/customer_session.py index 85a423b5de..d3625afb35 100644 --- a/server/polar/models/customer_session.py +++ b/server/polar/models/customer_session.py @@ -7,7 +7,8 @@ from polar.config import settings from polar.kit.db.models.base import RecordModel from polar.kit.utils import utc_now -from polar.models.customer import Customer + +from .customer import Customer def get_expires_at() -> datetime: @@ -37,3 +38,9 @@ def raw_token(self) -> str | None: @raw_token.setter def raw_token(self, value: str) -> None: self._raw_token = value + + @property + def customer_portal_url(self) -> str: + return settings.generate_frontend_url( + f"/{self.customer.organization.slug}/portal?customer_session_token={self.raw_token}" + ) diff --git a/server/tests/customer_session/test_endpoints.py b/server/tests/customer_session/test_endpoints.py index 8d038a3545..cbe2f0b720 100644 --- a/server/tests/customer_session/test_endpoints.py +++ b/server/tests/customer_session/test_endpoints.py @@ -44,3 +44,4 @@ async def test_valid( assert json["token"].startswith(CUSTOMER_SESSION_TOKEN_PREFIX) assert json["customer_id"] == str(customer.id) + assert json["customer_portal_url"].endswith(json["token"])