Skip to content

Commit

Permalink
server/event: validate customer_id exists on organization
Browse files Browse the repository at this point in the history
  • Loading branch information
frankie567 committed Jan 29, 2025
1 parent c7753d3 commit 78a66ee
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 5 deletions.
48 changes: 45 additions & 3 deletions server/polar/event/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@

from sqlalchemy import UnaryExpression, asc, desc, or_, select

from polar.auth.models import AuthSubject, is_organization
from polar.auth.models import AuthSubject, is_organization, is_user
from polar.exceptions import PolarError, PolarRequestValidationError, ValidationError
from polar.kit.metadata import MetadataQuery
from polar.kit.pagination import PaginationParams
from polar.kit.sorting import Sorting
from polar.models import Event, Organization, User, UserOrganization
from polar.models import Customer, Event, Organization, User, UserOrganization
from polar.models.event import EventSource
from polar.postgres import AsyncSession

from .repository import EventRepository
from .schemas import EventsIngest, EventsIngestResponse
from .schemas import EventCreateCustomer, EventsIngest, EventsIngestResponse
from .sorting import EventSortProperty


Expand Down Expand Up @@ -108,6 +108,9 @@ async def ingest(
validate_organization_id = await self._get_organization_validation_function(
session, auth_subject
)
validate_customer_id = await self._get_customer_validation_function(
session, auth_subject
)

events: list[dict[str, Any]] = []
errors: list[ValidationError] = []
Expand All @@ -116,6 +119,8 @@ async def ingest(
organization_id = validate_organization_id(
index, event_create.organization_id
)
if isinstance(event_create, EventCreateCustomer):
validate_customer_id(index, event_create.customer_id)
except EventIngestValidationError as e:
errors.extend(e.errors)
continue
Expand Down Expand Up @@ -205,5 +210,42 @@ def _validate_organization_id_by_user(

return _validate_organization_id_by_user

async def _get_customer_validation_function(
self, session: AsyncSession, auth_subject: AuthSubject[User | Organization]
) -> Callable[[int, uuid.UUID], uuid.UUID]:
statement = select(Customer.id).where(Customer.deleted_at.is_(None))
if is_user(auth_subject):
statement = statement.where(
Customer.organization_id.in_(
select(UserOrganization.organization_id).where(
UserOrganization.user_id == auth_subject.subject.id,
UserOrganization.deleted_at.is_(None),
)
)
)
else:
statement = statement.where(
Customer.organization_id == auth_subject.subject.id
)
result = await session.execute(statement)
allowed_customers = set(result.scalars().all())

def _validate_customer_id(index: int, customer_id: uuid.UUID) -> uuid.UUID:
if customer_id not in allowed_customers:
raise EventIngestValidationError(
[
{
"type": "customer_id",
"msg": "Customer not found.",
"loc": ("body", "events", index, "customer_id"),
"input": customer_id,
}
]
)

return customer_id

return _validate_customer_id


event = EventService()
53 changes: 51 additions & 2 deletions server/tests/event/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
import pytest
from pydantic import ValidationError

from polar.auth.models import AuthSubject
from polar.auth.models import AuthSubject, is_user
from polar.event.repository import EventRepository
from polar.event.schemas import EventCreateExternalCustomer, EventsIngest
from polar.event.schemas import (
EventCreateCustomer,
EventCreateExternalCustomer,
EventsIngest,
)
from polar.event.service import event as event_service
from polar.exceptions import PolarRequestValidationError
from polar.kit.pagination import PaginationParams
Expand All @@ -16,6 +20,7 @@
from polar.postgres import AsyncSession
from tests.fixtures.auth import AuthSubjectFixture
from tests.fixtures.database import SaveFixture
from tests.fixtures.random_objects import create_customer


async def create_event(
Expand Down Expand Up @@ -287,6 +292,50 @@ async def test_invalid_organization(
errors = e.value.errors()
assert len(errors) == 1

@pytest.mark.auth(
AuthSubjectFixture(subject="user"),
AuthSubjectFixture(subject="organization"),
)
async def test_invalid_customer_id(
self,
save_fixture: SaveFixture,
session: AsyncSession,
auth_subject: AuthSubject[User | Organization],
organization: Organization,
organization_second: Organization,
user_organization: UserOrganization,
customer: Customer,
) -> None:
customer_organization_second = await create_customer(
save_fixture, organization=organization_second
)

ingest = EventsIngest(
events=[
EventCreateCustomer(
name="test",
customer_id=uuid.uuid4(),
organization_id=organization.id if is_user(auth_subject) else None,
),
EventCreateCustomer(
name="test",
customer_id=customer_organization_second.id,
organization_id=organization.id if is_user(auth_subject) else None,
),
EventCreateCustomer(
name="test",
customer_id=customer.id,
organization_id=organization.id if is_user(auth_subject) else None,
),
]
)

with pytest.raises(PolarRequestValidationError) as e:
await event_service.ingest(session, auth_subject, ingest)

errors = e.value.errors()
assert len(errors) == 2

@pytest.mark.auth
async def test_valid_user(
self,
Expand Down

0 comments on commit 78a66ee

Please sign in to comment.