Skip to content

Commit

Permalink
merge with upstream
Browse files Browse the repository at this point in the history
  • Loading branch information
baixiac committed Jul 1, 2024
2 parents ea0da01 + d2052d5 commit 7efe62e
Show file tree
Hide file tree
Showing 75 changed files with 1,804 additions and 579 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/api-docs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
run: |
typer app/cli/cli.py utils docs --output cli.md
if ! cmp -s cli.md app/cli/README.md; then
echo "The CMS CLI doc needs updating"
echo "The CMS CLI README needs updating"
exit 1
fi
- name: Generate API docs
Expand Down
101 changes: 57 additions & 44 deletions app/api/api.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,76 @@
import logging
import asyncio
import importlib
import logging
import os.path
import api.globals as cms_globals

from typing import Dict, Callable, Any, Optional
from urllib.parse import urlencode
from typing import Dict, Any, Optional
from concurrent.futures import ThreadPoolExecutor
from anyio.lowlevel import RunVar
from anyio import CapacityLimiter
from fastapi import FastAPI, Request, Response
from fastapi import FastAPI, Request
from fastapi.openapi.utils import get_openapi
from fastapi.responses import RedirectResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
from starlette.datastructures import QueryParams
from prometheus_fastapi_instrumentator import Instrumentator

from api.auth.db import make_sure_db_and_tables
from api.auth.users import Props
from api.dependencies import ModelServiceDep
from api.utils import add_exception_handlers, add_middlewares
from domain import Tags
from domain import Tags, TagsStreamable
from management.tracker_client import TrackerClient
from utils import get_settings


logger = logging.getLogger(__name__)
logging.getLogger("asyncio").setLevel(logging.ERROR)


def get_model_server(msd_overwritten: Optional[ModelServiceDep] = None) -> FastAPI:
    """Build the FastAPI application exposing the model-serving routers.

    The base application comes from ``_get_app``; middlewares are then added
    and routers are included according to the flags on ``get_settings()``.

    Args:
        msd_overwritten: Optional model service dependency forwarded to
            ``_get_app``.

    Returns:
        The configured FastAPI application.
    """
    # The tag metadata is computed inside ``_get_app``; building it here as
    # well left an unused local, so it has been dropped.
    app = _get_app(msd_overwritten)
    config = get_settings()
    add_middlewares(app, config)

    app = _load_health_check_router(app)

    if config.AUTH_USER_ENABLED == "true":
        app = _load_auth_router(app)

    app = _load_invocation_router(app)

    # NOTE(review): indentation was lost in the reviewed copy; the two
    # DISABLE_* checks are assumed to be nested under ENABLE_TRAINING_APIS.
    # Confirm against the repository before merging.
    if config.ENABLE_TRAINING_APIS == "true":
        app = _load_supervised_training_router(app)
        if config.DISABLE_UNSUPERVISED_TRAINING != "true":
            app = _load_unsupervised_training_router(app)
        if config.DISABLE_METACAT_TRAINING != "true":
            app = _load_metacat_training_router(app)

    if config.ENABLE_EVALUATION_APIS == "true":
        app = _load_evaluation_router(app)
    if config.ENABLE_PREVIEWS_APIS == "true":
        app = _load_preview_router(app)

    return app


def get_stream_server(msd_overwritten: Optional[ModelServiceDep] = None) -> FastAPI:
    """Build the FastAPI application exposing the streaming router.

    Args:
        msd_overwritten: Optional model service dependency forwarded to
            ``_get_app``.

    Returns:
        The streamable FastAPI application with the health-check router,
        the auth router (only when ``AUTH_USER_ENABLED`` is ``"true"``) and
        the stream router included, in that order.
    """
    stream_app = _get_app(msd_overwritten, streamable=True)
    settings = get_settings()
    add_middlewares(stream_app, settings, streamable=True)

    # Routers are included in this exact order.
    router_loaders = [_load_health_check_router]
    if settings.AUTH_USER_ENABLED == "true":
        router_loaders.append(_load_auth_router)
    router_loaders.append(_load_stream_router)

    for load_router in router_loaders:
        stream_app = load_router(stream_app)

    return stream_app


def _get_app(msd_overwritten: Optional[ModelServiceDep] = None, streamable: bool = False) -> FastAPI:
tags_metadata = [{"name": tag.name, "description": tag.value} for tag in (Tags if not streamable else TagsStreamable)]
config = get_settings()
app = FastAPI(title="CogStack ModelServe",
summary="A model serving and governance system for CogStack NLP solutions",
Expand All @@ -39,9 +79,9 @@ def get_model_server(msd_overwritten: Optional[ModelServiceDep] = None) -> FastA
debug=(config.DEBUG == "true"),
openapi_tags=tags_metadata)
add_exception_handlers(app)
add_middlewares(app, config)

instrumentator = Instrumentator(excluded_handlers=["/docs", "/redoc", "/metrics", "/openapi.json", "/favicon.ico", "none"]).instrument(app)
instrumentator = Instrumentator(
excluded_handlers=["/docs", "/redoc", "/metrics", "/openapi.json", "/favicon.ico", "none"]).instrument(app)

if msd_overwritten is not None:
cms_globals.model_service_dep = msd_overwritten
Expand All @@ -59,19 +99,6 @@ async def on_startup() -> None:
if config.AUTH_USER_ENABLED == "true":
await make_sure_db_and_tables()

@app.middleware("http")
async def verify_blank_query_params(request: Request, call_next: Callable) -> Response:
scope = request.scope
if request.method != "POST":
return await call_next(request)
if not scope or not scope.get("query_string"):
return await call_next(request)

query_params = QueryParams(scope["query_string"])

scope["query_string"] = urlencode([(k, v) for k, v in query_params._list if v and v.strip()]).encode("latin-1")
return await call_next(Request(scope, request.receive, request._send))

@app.get("/docs", include_in_schema=False)
async def swagger_doc(req: Request) -> HTMLResponse:
root_path = req.scope.get("root_path", "").rstrip("/")
Expand Down Expand Up @@ -135,27 +162,6 @@ def custom_openapi() -> Dict[str, Any]:
app.openapi_schema = openapi_schema
return app.openapi_schema

app = _load_health_check_router(app)

if config.AUTH_USER_ENABLED == "true":
app = _load_auth_router(app)

app = _load_invocation_router(app)

if config.ENABLE_TRAINING_APIS == "true":
app = _load_supervised_training_router(app)
if config.DISABLE_UNSUPERVISED_TRAINING != "true":
app = _load_unsupervised_training_router(app)
if config.DISABLE_METACAT_TRAINING != "true":
app = _load_metacat_training_router(app)

if config.ENABLE_EVALUATION_APIS == "true":
app = _load_evaluation_router(app)
if config.ENABLE_PREVIEWS_APIS == "true":
app = _load_preview_router(app)

app.openapi = custom_openapi # type: ignore

return app


Expand Down Expand Up @@ -213,3 +219,10 @@ def _load_health_check_router(app: FastAPI) -> FastAPI:
importlib.reload(health_check)
app.include_router(health_check.router)
return app


def _load_stream_router(app: FastAPI) -> FastAPI:
    """Include the stream router in ``app`` and return the same ``app``.

    The router module is imported inside the function and reloaded before
    its ``router`` is read, so its module-level code runs again on each call.
    """
    from api.routers import stream as stream_module

    importlib.reload(stream_module)
    stream_router = stream_module.router
    app.include_router(stream_router)
    return app
2 changes: 1 addition & 1 deletion app/api/auth/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ key = Fernet.generate_key()
print(key.decode("utf-8"))
```

Your CMS users can be stored either in a local file-based database (e.g., `<DATABASE_URL>` set to `sqlite+aiosqlite:///./cms-users.db` when SQLite is used) or in a remote one (e.g., `<DATABASE_URL>` set to `postgresql+asyncpg://<AUTH_DB_USERNAME>:<AUTH_DB_PASSWORD>@auth-db:5432/cms-users` when you have an [auth-db container](./../../docker-compose-auth.yml) running).
Your CMS users can be stored either in a local file-based database (e.g., `<DATABASE_URL>` set to `sqlite+aiosqlite:///./cms-users.db` when SQLite is used) or in a remote one (e.g., `<DATABASE_URL>` set to `postgresql+asyncpg://<AUTH_DB_USERNAME>:<AUTH_DB_PASSWORD>@auth-db:5432/cms-users` when you have an [auth-db container](./../../../docker-compose-auth.yml) running).

Currently, user management tasks such as registration and removal are performed by the admin. As an administrator, in order to create a new user, you need to log into the database and create a new record by running:
```sql
Expand Down
2 changes: 1 addition & 1 deletion app/api/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from api.auth.backends import get_backends
from utils import get_settings

logger = logging.getLogger(__name__)
logger = logging.getLogger("cms")


class _UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
Expand Down
13 changes: 12 additions & 1 deletion app/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from config import Settings
from registry import model_service_registry
from model_services.base import AbstractModelService
from management.model_manager import ModelManager

logger = logging.getLogger(__name__)
logger = logging.getLogger("cms")


class ModelServiceDep(object):
Expand Down Expand Up @@ -34,3 +35,13 @@ def __call__(self) -> AbstractModelService:
logger.error(f"Unknown model type: {self._model_type}")
exit(1) # throw an exception?
return self._model_sevice


class ModelManagerDep(object):
    """Callable that returns the ``ModelManager`` built for a model service.

    One manager is created at construction time and the same instance is
    returned on every call.
    """

    def __init__(self, model_service: AbstractModelService) -> None:
        """Create the manager from the service's class and ``service_config``.

        Args:
            model_service: The model service the manager is attached to.
        """
        manager = ModelManager(model_service.__class__, model_service.service_config)
        # Attach the already-constructed service instance to the manager.
        manager.model_service = model_service
        self._model_manager = manager

    def __call__(self) -> ModelManager:
        """Return the manager created in ``__init__``."""
        return self._model_manager
1 change: 1 addition & 0 deletions app/api/globals.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
model_service_dep = None
model_manager_dep = None
props = None
2 changes: 1 addition & 1 deletion app/api/routers/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastapi import APIRouter
from domain import Tags
router = APIRouter()
logger = logging.getLogger(__name__)
logger = logging.getLogger("cms")

for auth_backend in cms_globals.props.auth_backends:
router.include_router(
Expand Down
36 changes: 33 additions & 3 deletions app/api/routers/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
get_iaa_scores_per_doc,
get_iaa_scores_per_span,
concat_trainer_exports,
get_stats_from_trainer_export,
)
from exception import AnnotationException
from utils import filter_by_concept_ids

router = APIRouter()
logger = logging.getLogger(__name__)
logger = logging.getLogger("cms")


@router.post("/evaluate",
Expand Down Expand Up @@ -91,7 +92,7 @@ def get_sanity_check_with_trainer_export(request: Request,
stream = io.StringIO()
metrics.to_csv(stream, index=False)
response = StreamingResponse(iter([stream.getvalue()]), media_type="text/csv")
response.headers["Content-Disposition"] = f'attachment ; filename="evaluation_{str(uuid.uuid4())}.csv"'
response.headers["Content-Disposition"] = f'attachment ; filename="sanity_check_{str(uuid.uuid4())}.csv"'
return response


Expand Down Expand Up @@ -129,7 +130,7 @@ def get_inter_annotator_agreement_scores(request: Request,
stream = io.StringIO()
iaa_scores.to_csv(stream, index=False)
response = StreamingResponse(iter([stream.getvalue()]), media_type="text/csv")
response.headers["Content-Disposition"] = f'attachment ; filename="evaluation_{str(uuid.uuid4())}.csv"'
response.headers["Content-Disposition"] = f'attachment ; filename="iaa_{str(uuid.uuid4())}.csv"'
return response


Expand All @@ -153,3 +154,32 @@ def get_concatenated_trainer_exports(request: Request,
response = JSONResponse(concatenated, media_type="application/json; charset=utf-8")
response.headers["Content-Disposition"] = f'attachment ; filename="concatenated_{str(uuid.uuid4())}.json"'
return response


@router.post("/annotation-stats",
             tags=[Tags.Evaluating.name],
             response_class=StreamingResponse,
             dependencies=[Depends(cms_globals.props.current_active_user)],
             description="Get annotation stats of trainer export files")
def get_annotation_stats(request: Request,
                         trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")]) -> StreamingResponse:
    """Return annotation statistics for the uploaded trainer exports as CSV.

    Each upload is copied into a named temporary file, the temporary files
    are concatenated with ``concat_trainer_exports`` (recurring document IDs
    disallowed) and the stats from ``get_stats_from_trainer_export`` are
    written out as a CSV attachment.

    Args:
        request: The incoming request (not read by this handler).
        trainer_export: One or more uploaded trainer export files.

    Returns:
        A ``text/csv`` streaming response with a ``stats_<uuid>.csv``
        attachment filename.
    """
    temp_files = []
    try:
        for te in trainer_export:
            temp_te = tempfile.NamedTemporaryFile()
            # Register the file before writing to it, so it is still closed
            # (and therefore removed) if copying this upload fails part-way.
            temp_files.append(temp_te)
            for line in te.file:
                temp_te.write(line)
            # Flush so the content is visible when re-opened by name below.
            temp_te.flush()
        concatenated = concat_trainer_exports([temp_file.name for temp_file in temp_files],
                                              allow_recurring_doc_ids=False)
    finally:
        for temp_file in temp_files:
            temp_file.close()

    stats = get_stats_from_trainer_export(concatenated, return_df=True)
    stream = io.StringIO()
    stats.to_csv(stream, index=False)
    response = StreamingResponse(iter([stream.getvalue()]), media_type="text/csv")
    response.headers["Content-Disposition"] = f'attachment ; filename="stats_{str(uuid.uuid4())}.csv"'
    return response
Loading

0 comments on commit 7efe62e

Please sign in to comment.