diff --git a/python/cog/server/http.py b/python/cog/server/http.py index 4de384fa00..2d494ad8d1 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -38,6 +38,11 @@ ) from ..types import PYDANTIC_V2, CogConfig +try: + from .._version import __version__ +except ImportError: + __version__ = "dev" + if PYDANTIC_V2: from .helpers import ( unwrap_pydantic_serialization_iterators, @@ -187,6 +192,17 @@ async def wrapped(*args: "P.args", **kwargs: "P.kwargs") -> "T": # pylint: disa return wrapped + index_document = { + "cog_version": __version__, + "docs_url": "/docs", + "openapi_url": "/openapi.json", + "shutdown_url": "/shutdown", + "healthcheck_url": "/health-check", + "predictions_url": "/predictions", + "predictions_idempotent_url": "/predictions/{prediction_id}", + "predictions_cancel_url": "/predictions/{prediction_id}/cancel", + } + if "train" in config: try: trainer_ref = get_predictor_ref(config, "train") @@ -281,6 +297,14 @@ def cancel_training( ) -> Any: return cancel(training_id) + index_document.update( + { + "trainings_url": "/trainings", + "trainings_idempotent_url": "/trainings/{training_id}", + "trainings_cancel_url": "/trainings/{training_id}/cancel", + } + ) + except Exception as e: # pylint: disable=broad-exception-caught if isinstance(e, (PredictorNotSet, FileNotFoundError)) and not is_build: pass # ignore missing train.py for backward compatibility with existing "bad" models in use @@ -310,11 +334,7 @@ def shutdown() -> None: @app.get("/") async def root() -> Any: - return { - # "cog_version": "", # TODO - "docs_url": "/docs", - "openapi_url": "/openapi.json", - } + return index_document @app.get("/health-check") async def healthcheck() -> Any: @@ -570,6 +590,9 @@ def _cpu_count() -> int: if __name__ == "__main__": parser = argparse.ArgumentParser(description="Cog HTTP server") + parser.add_argument( + "-v", "--version", action="store_true", help="Show version and exit" + ) parser.add_argument( "--host", dest="host", @@ -608,6 +631,10 @@ def _cpu_count() -> int: ) args = parser.parse_args() + if args.version: + print(f"cog.server.http {__version__}") + sys.exit(0) + # log level is configurable so we can make it quiet or verbose for `cog predict` # cog predict --debug # -> debug # cog predict # -> warning diff --git a/python/tests/server/test_http.py b/python/tests/server/test_http.py index 0bae95f6b6..515a0e5cb7 100644 --- a/python/tests/server/test_http.py +++ b/python/tests/server/test_http.py @@ -18,6 +18,24 @@ ) +def test_index_document(): + client = make_client(fixture_name="slow_setup") + resp = client.get("/") + data = resp.json() + for field in ( + "cog_version", + "docs_url", + "openapi_url", + "shutdown_url", + "healthcheck_url", + "predictions_url", + "predictions_idempotent_url", + "predictions_cancel_url", + ): + assert field in data + assert data[field] is not None + + def test_setup_healthcheck(): client = make_client(fixture_name="slow_setup") resp = client.get("/health-check")