Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Configure ruff linting and formatting #5

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 47 additions & 22 deletions app/api/api.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
import logging
import asyncio
import importlib
import logging
import os.path
import api.globals as cms_globals

from typing import Dict, Any, Optional
from concurrent.futures import ThreadPoolExecutor
from anyio.lowlevel import RunVar
from typing import Any, Dict, Optional

from anyio import CapacityLimiter
from anyio.lowlevel import RunVar
from fastapi import FastAPI, Request
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import RedirectResponse, HTMLResponse
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
from prometheus_fastapi_instrumentator import Instrumentator

from domain import Tags, TagsStreamable
from utils import get_settings

import api.globals as cms_globals
from api.auth.db import make_sure_db_and_tables
from api.auth.users import Props
from api.dependencies import ModelServiceDep
from api.utils import add_exception_handlers, add_rate_limiter
from domain import Tags, TagsStreamable
from management.tracker_client import TrackerClient
from utils import get_settings


logging.getLogger("asyncio").setLevel(logging.ERROR)
logger = logging.getLogger("cms")
Expand Down Expand Up @@ -87,25 +87,37 @@ def get_stream_server(msd_overwritten: Optional[ModelServiceDep] = None) -> Fast
return app


def _get_app(msd_overwritten: Optional[ModelServiceDep] = None, streamable: bool = False) -> FastAPI:
tags_metadata = [{"name": tag.name, "description": tag.value} for tag in (Tags if not streamable else TagsStreamable)]
def _get_app(
msd_overwritten: Optional[ModelServiceDep] = None, streamable: bool = False
) -> FastAPI:
tags_metadata = [
{"name": tag.name, "description": tag.value}
for tag in (Tags if not streamable else TagsStreamable)
]
config = get_settings()
app = FastAPI(title="CogStack ModelServe",
summary="A model serving and governance system for CogStack NLP solutions",
docs_url=None,
redoc_url=None,
debug=(config.DEBUG == "true"),
openapi_tags=tags_metadata)
app = FastAPI(
title="CogStack ModelServe",
summary="A model serving and governance system for CogStack NLP solutions",
docs_url=None,
redoc_url=None,
debug=(config.DEBUG == "true"),
openapi_tags=tags_metadata,
)
add_exception_handlers(app)
instrumentator = Instrumentator(
excluded_handlers=["/docs", "/redoc", "/metrics", "/openapi.json", "/favicon.ico", "none"]).instrument(app)
excluded_handlers=["/docs", "/redoc", "/metrics", "/openapi.json", "/favicon.ico", "none"]
).instrument(app)

if msd_overwritten is not None:
cms_globals.model_service_dep = msd_overwritten

cms_globals.props = Props(config.AUTH_USER_ENABLED == "true")

app.mount("/static", StaticFiles(directory=os.path.join(os.path.dirname(__file__), "static")), name="static")
app.mount(
"/static",
StaticFiles(directory=os.path.join(os.path.dirname(__file__), "static")),
name="static",
)

@app.on_event("startup")
async def on_startup() -> None:
Expand Down Expand Up @@ -160,8 +172,11 @@ def custom_openapi() -> Dict[str, Any]:
openapi_schema = get_openapi(
title=f"{cms_globals.model_service_dep().model_name} APIs",
version=cms_globals.model_service_dep().api_version,
description="by CogStack ModelServe, a model serving and governance system for CogStack NLP solutions.",
routes=app.routes
description=(
"by CogStack ModelServe, a model serving and governance system for CogStack NLP"
" solutions."
),
routes=app.routes,
)
openapi_schema["info"]["x-logo"] = {
"url": "https://avatars.githubusercontent.com/u/28688163?s=200&v=4"
Expand Down Expand Up @@ -189,69 +204,79 @@ def custom_openapi() -> Dict[str, Any]:

def _load_auth_router(app: FastAPI) -> FastAPI:
from api.routers import authentication

importlib.reload(authentication)
app.include_router(authentication.router)
return app


def _load_model_card(app: FastAPI) -> FastAPI:
from api.routers import model_card

importlib.reload(model_card)
app.include_router(model_card.router)
return app


def _load_invocation_router(app: FastAPI) -> FastAPI:
from api.routers import invocation

importlib.reload(invocation)
app.include_router(invocation.router)
return app


def _load_supervised_training_router(app: FastAPI) -> FastAPI:
from api.routers import supervised_training

importlib.reload(supervised_training)
app.include_router(supervised_training.router)
return app


def _load_evaluation_router(app: FastAPI) -> FastAPI:
from api.routers import evaluation

importlib.reload(evaluation)
app.include_router(evaluation.router)
return app


def _load_preview_router(app: FastAPI) -> FastAPI:
from api.routers import preview

importlib.reload(preview)
app.include_router(preview.router)
return app


def _load_unsupervised_training_router(app: FastAPI) -> FastAPI:
from api.routers import unsupervised_training

importlib.reload(unsupervised_training)
app.include_router(unsupervised_training.router)
return app


def _load_metacat_training_router(app: FastAPI) -> FastAPI:
from api.routers import metacat_training

importlib.reload(metacat_training)
app.include_router(metacat_training.router)
return app


def _load_health_check_router(app: FastAPI) -> FastAPI:
from api.routers import health_check

importlib.reload(health_check)
app.include_router(health_check.router)
return app


def _load_stream_router(app: FastAPI) -> FastAPI:
from api.routers import stream

importlib.reload(stream)
app.include_router(stream.router, prefix="/stream")
return app
25 changes: 19 additions & 6 deletions app/api/auth/backends.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
from functools import lru_cache
from typing import List
from fastapi_users.authentication.transport.base import Transport

from fastapi_users.authentication import (
AuthenticationBackend,
BearerTransport,
CookieTransport,
JWTStrategy,
)
from fastapi_users.authentication.strategy.base import Strategy
from fastapi_users.authentication import BearerTransport, JWTStrategy
from fastapi_users.authentication import AuthenticationBackend, CookieTransport
from fastapi_users.authentication.transport.base import Transport

from utils import get_settings


@lru_cache
def get_backends() -> List[AuthenticationBackend]:
return [
AuthenticationBackend(name="jwt", transport=_get_bearer_transport(), get_strategy=_get_strategy),
AuthenticationBackend(name="cookie", transport=_get_cookie_transport(), get_strategy=_get_strategy),
AuthenticationBackend(
name="jwt", transport=_get_bearer_transport(), get_strategy=_get_strategy
),
AuthenticationBackend(
name="cookie", transport=_get_cookie_transport(), get_strategy=_get_strategy
),
]


Expand All @@ -24,4 +34,7 @@ def _get_cookie_transport() -> Transport:


def _get_strategy() -> Strategy:
return JWTStrategy(secret=get_settings().AUTH_JWT_SECRET, lifetime_seconds=get_settings().AUTH_ACCESS_TOKEN_EXPIRE_SECONDS)
return JWTStrategy(
secret=get_settings().AUTH_JWT_SECRET,
lifetime_seconds=get_settings().AUTH_ACCESS_TOKEN_EXPIRE_SECONDS,
)
5 changes: 4 additions & 1 deletion app/api/auth/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase

from utils import get_settings


Expand All @@ -29,5 +30,7 @@ async def make_sure_db_and_tables() -> None:
await conn.run_sync(Base.metadata.create_all)


async def get_user_db(session: AsyncSession = Depends(_get_async_session)) -> AsyncGenerator[SQLAlchemyUserDatabase, None]:
async def get_user_db(
session: AsyncSession = Depends(_get_async_session),
) -> AsyncGenerator[SQLAlchemyUserDatabase, None]:
yield SQLAlchemyUserDatabase(session, User)
30 changes: 20 additions & 10 deletions app/api/auth/users.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import uuid
import logging
from typing import Optional, AsyncGenerator, List, Callable
import uuid
from typing import AsyncGenerator, Callable, List, Optional

from fastapi import Depends, Request
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin
from fastapi_users.db import SQLAlchemyUserDatabase
from fastapi_users.authentication import AuthenticationBackend
from api.auth.db import User, get_user_db
from api.auth.backends import get_backends
from fastapi_users.db import SQLAlchemyUserDatabase

from utils import get_settings

from api.auth.backends import get_backends
from api.auth.db import User, get_user_db

logger = logging.getLogger("cms")


Expand All @@ -19,26 +22,33 @@ class CmsUserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async def on_after_register(self, user: User, request: Optional[Request] = None) -> None:
logger.info("User %s has registered.", user.id)

async def on_after_forgot_password(self, user: User, token: str, request: Optional[Request] = None) -> None:
async def on_after_forgot_password(
self, user: User, token: str, request: Optional[Request] = None
) -> None:
logger.info("User %s has forgot their password. Reset token: %s", user.id, token)

async def on_after_request_verify(self, user: User, token: str, request: Optional[Request] = None) -> None:
async def on_after_request_verify(
self, user: User, token: str, request: Optional[Request] = None
) -> None:
logger.info("Verification requested for user %s. Verification token: %s", user.id, token)


async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)) -> AsyncGenerator:
async def get_user_manager(
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
) -> AsyncGenerator:
yield CmsUserManager(user_db)


class Props(object):

def __init__(self, auth_user_enabled: bool) -> None:
self._auth_backends: List = []
self._fastapi_users = None
self._current_active_user = lambda: None
if auth_user_enabled:
self._auth_backends = get_backends()
self._fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, self.auth_backends)
self._fastapi_users = FastAPIUsers[User, uuid.UUID](
get_user_manager, self.auth_backends
)
self._current_active_user = self._fastapi_users.current_user(active=True)

@property
Expand Down
21 changes: 12 additions & 9 deletions app/api/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
import logging
import re
from typing import Union
from typing_extensions import Annotated
from typing import Optional, Union

from fastapi import HTTPException, Query
from starlette.status import HTTP_400_BAD_REQUEST
from typing_extensions import Annotated

from typing import Optional
from config import Settings
from registry import model_service_registry
from model_services.base import AbstractModelService

from management.model_manager import ModelManager
from model_services.base import AbstractModelService

TRACKING_ID_REGEX = re.compile(r"^[a-zA-Z0-9][\w\-]{0,255}$")

logger = logging.getLogger("cms")


class ModelServiceDep(object):

@property
def model_service(self) -> AbstractModelService:
return self._model_sevice
Expand All @@ -41,12 +40,11 @@ def __call__(self) -> AbstractModelService:
self._model_sevice = model_service_registry[self._model_type](self._config)
else:
logger.error("Unknown model type: %s", self._model_type)
exit(1) # throw an exception?
exit(1) # throw an exception?
return self._model_sevice


class ModelManagerDep(object):

def __init__(self, model_service: AbstractModelService) -> None:
self._model_manager = ModelManager(model_service.__class__, model_service.service_config)
self._model_manager.model_service = model_service
Expand All @@ -56,11 +54,16 @@ def __call__(self) -> ModelManager:


def validate_tracking_id(
tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the requested task")] = None,
tracking_id: Annotated[
Union[str, None], Query(description="The tracking ID of the requested task")
] = None,
) -> Union[str, None]:
if tracking_id is not None and TRACKING_ID_REGEX.match(tracking_id) is None:
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
detail=f"Invalid tracking ID '{tracking_id}', must be an alphanumeric string of length 1 to 256",
detail=(
f"Invalid tracking ID '{tracking_id}',"
" must be an alphanumeric string of length 1 to 256"
),
)
return tracking_id
6 changes: 5 additions & 1 deletion app/api/routers/authentication.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import logging
import api.globals as cms_globals

from fastapi import APIRouter

from domain import Tags

import api.globals as cms_globals

router = APIRouter()
logger = logging.getLogger("cms")

Expand Down
Loading
Loading