diff --git a/getgather/auth/auth.py b/getgather/auth/auth.py index 9b67458e..79ef2c0d 100644 --- a/getgather/auth/auth.py +++ b/getgather/auth/auth.py @@ -1,3 +1,4 @@ +import re from typing import cast from fastapi import FastAPI @@ -6,7 +7,8 @@ from mcp.server.auth.middleware.auth_context import AuthContextMiddleware from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware from mcp.server.auth.provider import TokenVerifier -from pydantic import BaseModel, field_validator +from nanoid import generate +from pydantic import BaseModel, field_validator, model_validator from starlette.datastructures import Headers from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware @@ -14,7 +16,7 @@ from starlette.types import Receive, Scope, Send from getgather.auth.provider import CustomOAuthProvider -from getgather.config import settings +from getgather.config import FRIENDLY_CHARS, settings class RequireAuthMiddlewareCustom(RequireAuthMiddleware): @@ -99,14 +101,32 @@ class AuthUser(BaseModel): @field_validator("auth_provider") @classmethod def validate_auth_provider(cls, v: str) -> str: - if v not in [settings.FIRST_PARTY_OAUTH_PROVIDER_NAME, "google"]: + valid_providers = ( + [settings.FIRST_PARTY_OAUTH_PROVIDER_NAME, "google"] + if settings.auth_enabled + else [NO_AUTH_PROVIDER] + ) + + if v not in valid_providers: raise ValueError(f"Invalid auth provider: {v}") return v + @model_validator(mode="after") + def validate_user_id(self) -> "AuthUser": + if len(self.user_id) > 54: + raise ValueError(f"User id is too long: {self.user_id}") + if not re.match(r"^[a-z0-9-]+$", self.user_id): + raise ValueError(f"User id contains invalid characters: {self.user_id}") + return self + @property def user_id(self) -> str: - """Unique user name combining login and auth provider""" - return f"{self.sub}.{self.auth_provider}" + """ + Unique user name combining login and auth provider. + Only numbers, lowercase letters and dashes are allowed. + Maximum length is 54 characters. + """ + return f"{self.sub}-{self.auth_provider}" @classmethod def from_user_id(cls, user_id: str) -> "AuthUser": @@ -121,8 +141,7 @@ def dump(self): def get_auth_user() -> AuthUser: if not settings.auth_enabled: - # for testing only when auth is disabled - return AuthUser(sub="test_user", auth_provider=settings.FIRST_PARTY_OAUTH_PROVIDER_NAME) + return _get_user_for_no_auth() token = get_access_token() if not token: @@ -137,3 +156,12 @@ def get_auth_user() -> AuthUser: raise RuntimeError("Missing sub or provider in auth token") return AuthUser(sub=sub, auth_provider=provider, name=name, email=email, app_name=app_name) + + +NO_AUTH_PROVIDER = "noauth" + + +def _get_user_for_no_auth() -> AuthUser: + """Fake auth user for when auth is disabled to keep the code consistent.""" + sub = generate(FRIENDLY_CHARS, 6) + return AuthUser(sub=sub, auth_provider=NO_AUTH_PROVIDER) diff --git a/getgather/config.py b/getgather/config.py index 37fea803..166fa1f2 100644 --- a/getgather/config.py +++ b/getgather/config.py @@ -6,6 +6,8 @@ from getgather.auth.settings import AuthSettings from getgather.browser.proxy_types import ProxyConfig +FRIENDLY_CHARS = "23456789abcdefghijkmnpqrstuvwxyz" + PROJECT_DIR = Path(__file__).resolve().parent.parent diff --git a/getgather/zen_distill.py b/getgather/zen_distill.py index a9c1403d..b38524e4 100644 --- a/getgather/zen_distill.py +++ b/getgather/zen_distill.py @@ -22,10 +22,11 @@ from nanoid import generate from zendriver.core.connection import ProtocolException +from getgather.auth.auth import get_auth_user from getgather.browser.chromefleet import create_remote_browser, terminate_remote_browser from getgather.browser.proxy import setup_proxy from getgather.browser.resource_blocker import blocked_domains, load_blocklists, should_be_blocked -from getgather.config import settings +from getgather.config import FRIENDLY_CHARS, settings from getgather.container_utils import check_x_server_available from getgather.logs import logger from getgather.mcp.browser import browser_manager @@ -358,9 +359,6 @@ async def auth_challenge_handler(event: zd.cdp.fetch.AuthRequired): await page.send(zd.cdp.fetch.enable(handle_auth_requests=True)) -FRIENDLY_CHARS = "23456789abcdefghijkmnpqrstuvwxyz" - - async def _create_zendriver_browser(id: str | None = None) -> zd.Browser: if id is None: id = nanoid.generate(FRIENDLY_CHARS, 6) @@ -1146,10 +1144,10 @@ async def short_lived_mcp_tool( ) -> tuple[bool, dict[str, Any]]: path = os.path.join(os.path.dirname(__file__), "mcp", "patterns", pattern_wildcard) patterns = load_distillation_patterns(path) - id = generate(FRIENDLY_CHARS, 6) - browser = await create_remote_browser(browser_id=id) + browser_id = get_auth_user().user_id + browser = await create_remote_browser(browser_id=browser_id) terminated, distilled, converted = await run_distillation_loop(location, patterns, browser) - await terminate_remote_browser(browser_id=id) + await terminate_remote_browser(browser_id=browser_id) result: dict[str, Any] = {result_key: converted if converted else distilled} if result_key in result: diff --git a/tests/manual.py b/tests/manual.py index 8415fa7f..bf0a345d 100644 --- a/tests/manual.py +++ b/tests/manual.py @@ -26,7 +26,7 @@ async def call_tool( result = None error = None try: - async with Client(url, auth=(token or "oauth"), timeout=60) as client: + async with Client(url, auth=(token or "oauth"), timeout=180) as client: result = await client.call_tool_mcp(tool, {}) except Exception: error = traceback.format_exc()