Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions getgather/auth/auth.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import cast

from fastapi import FastAPI
Expand All @@ -6,15 +7,16 @@
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware
from mcp.server.auth.provider import TokenVerifier
from pydantic import BaseModel, field_validator
from nanoid import generate
from pydantic import BaseModel, field_validator, model_validator
from starlette.datastructures import Headers
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.responses import RedirectResponse
from starlette.types import Receive, Scope, Send

from getgather.auth.provider import CustomOAuthProvider
from getgather.config import settings
from getgather.config import FRIENDLY_CHARS, settings


class RequireAuthMiddlewareCustom(RequireAuthMiddleware):
Expand Down Expand Up @@ -99,14 +101,32 @@ class AuthUser(BaseModel):
@field_validator("auth_provider")
@classmethod
def validate_auth_provider(cls, v: str) -> str:
if v not in [settings.FIRST_PARTY_OAUTH_PROVIDER_NAME, "google"]:
valid_providers = (
[settings.FIRST_PARTY_OAUTH_PROVIDER_NAME, "google"]
if settings.auth_enabled
else [NO_AUTH_PROVIDER]
)

if v not in valid_providers:
raise ValueError(f"Invalid auth provider: {v}")
return v

@model_validator(mode="after")
def validate_user_id(self) -> "AuthUser":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you want to reverse parse it to check if the auth_provider is valid?

if len(self.user_id) > 54:
raise ValueError(f"User id is too long: {self.user_id}")
if not re.match(r"^[a-z0-9-]+$", self.user_id):
raise ValueError(f"User id contains invalid characters: {self.user_id}")
return self

@property
def user_id(self) -> str:
"""Unique user name combining login and auth provider"""
return f"{self.sub}.{self.auth_provider}"
"""
Unique user name combining login and auth provider.
Only numbers, lowercase letters and dashes are allowed.
Maximum length is 54 characters.
"""
return f"{self.sub}-{self.auth_provider}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why `{sub}-{auth_provider}` rather than `{auth_provider}-{sub}`? Just curious.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no strong preference, more like firstname - lastname


@classmethod
def from_user_id(cls, user_id: str) -> "AuthUser":
Expand All @@ -121,8 +141,7 @@ def dump(self):

def get_auth_user() -> AuthUser:
if not settings.auth_enabled:
# for testing only when auth is disabled
return AuthUser(sub="test_user", auth_provider=settings.FIRST_PARTY_OAUTH_PROVIDER_NAME)
return _get_user_for_no_auth()

token = get_access_token()
if not token:
Expand All @@ -137,3 +156,12 @@ def get_auth_user() -> AuthUser:
raise RuntimeError("Missing sub or provider in auth token")

return AuthUser(sub=sub, auth_provider=provider, name=name, email=email, app_name=app_name)


NO_AUTH_PROVIDER = "noauth"


def _get_user_for_no_auth() -> AuthUser:
    """Build a stand-in AuthUser for when auth is disabled.

    The subject is a 6-character id generated from FRIENDLY_CHARS on each call,
    and the provider is NO_AUTH_PROVIDER, so callers can treat the result like
    any other AuthUser.
    """
    return AuthUser(
        sub=generate(FRIENDLY_CHARS, 6),
        auth_provider=NO_AUTH_PROVIDER,
    )
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@faisalive @broerjuang Please confirm that this new format, e.g. xyz123-noauth, will not break sign-in etc. for demo apps.

Copy link
Contributor Author

@bin-ario bin-ario Jan 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the gateway, demo apps will set the token header to ario_APPKEY_xyz123,
so the user_id will look like xyz123-ario,

where ario is the first-party auth provider name (it used to be getgather), APPKEY is a pre-defined random key, and xyz123 comes from the Express session id.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, so it sounds like the user_id will always avoid underscores, which I think is needed.

2 changes: 2 additions & 0 deletions getgather/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from getgather.auth.settings import AuthSettings
from getgather.browser.proxy_types import ProxyConfig

FRIENDLY_CHARS = "23456789abcdefghijkmnpqrstuvwxyz"

PROJECT_DIR = Path(__file__).resolve().parent.parent


Expand Down
12 changes: 5 additions & 7 deletions getgather/zen_distill.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
from nanoid import generate
from zendriver.core.connection import ProtocolException

from getgather.auth.auth import get_auth_user
from getgather.browser.chromefleet import create_remote_browser, terminate_remote_browser
from getgather.browser.proxy import setup_proxy
from getgather.browser.resource_blocker import blocked_domains, load_blocklists, should_be_blocked
from getgather.config import settings
from getgather.config import FRIENDLY_CHARS, settings
from getgather.container_utils import check_x_server_available
from getgather.logs import logger
from getgather.mcp.browser import browser_manager
Expand Down Expand Up @@ -358,9 +359,6 @@ async def auth_challenge_handler(event: zd.cdp.fetch.AuthRequired):
await page.send(zd.cdp.fetch.enable(handle_auth_requests=True))


FRIENDLY_CHARS = "23456789abcdefghijkmnpqrstuvwxyz"


async def _create_zendriver_browser(id: str | None = None) -> zd.Browser:
if id is None:
id = nanoid.generate(FRIENDLY_CHARS, 6)
Expand Down Expand Up @@ -1146,10 +1144,10 @@ async def short_lived_mcp_tool(
) -> tuple[bool, dict[str, Any]]:
path = os.path.join(os.path.dirname(__file__), "mcp", "patterns", pattern_wildcard)
patterns = load_distillation_patterns(path)
id = generate(FRIENDLY_CHARS, 6)
browser = await create_remote_browser(browser_id=id)
browser_id = get_auth_user().user_id
browser = await create_remote_browser(browser_id=browser_id)
terminated, distilled, converted = await run_distillation_loop(location, patterns, browser)
await terminate_remote_browser(browser_id=id)
await terminate_remote_browser(browser_id=browser_id)

result: dict[str, Any] = {result_key: converted if converted else distilled}
if result_key in result:
Expand Down
2 changes: 1 addition & 1 deletion tests/manual.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ async def call_tool(
result = None
error = None
try:
async with Client(url, auth=(token or "oauth"), timeout=60) as client:
async with Client(url, auth=(token or "oauth"), timeout=180) as client:
result = await client.call_tool_mcp(tool, {})
except Exception:
error = traceback.format_exc()
Expand Down