Skip to content

Commit

Permalink
Add version and more URLs to index document
Browse files Browse the repository at this point in the history
plus supporting `python -m cog.http.server --version` for improved
runtime inspectability.
  • Loading branch information
meatballhat committed Oct 31, 2024
1 parent eb04c7b commit de13e89
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 5 deletions.
37 changes: 32 additions & 5 deletions python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@
)
from ..types import PYDANTIC_V2, CogConfig

try:
from .._version import __version__
except ImportError:
__version__ = "dev"

if PYDANTIC_V2:
from .helpers import (
unwrap_pydantic_serialization_iterators,
Expand Down Expand Up @@ -187,6 +192,17 @@ async def wrapped(*args: "P.args", **kwargs: "P.kwargs") -> "T": # pylint: disa

return wrapped

index_document = {
"cog_version": __version__,
"docs_url": "/docs",
"openapi_url": "/openapi.json",
"shutdown_url": "/shutdown",
"healthcheck_url": "/health-check",
"predictions_url": "/predictions",
"predictions_idempotent_url": "/predictions/{prediction_id}",
"predictions_cancel_url": "/predictions/{prediction_id}/cancel",
}

if "train" in config:
try:
trainer_ref = get_predictor_ref(config, "train")
Expand Down Expand Up @@ -281,6 +297,14 @@ def cancel_training(
) -> Any:
return cancel(training_id)

index_document.update(
{
"trainings_url": "/trainings",
"trainings_idempotent_url": "/trainings/{training_id}",
"trainings_cancel_url": "/trainings/{training_id}/cancel",
}
)

except Exception as e: # pylint: disable=broad-exception-caught
if isinstance(e, (PredictorNotSet, FileNotFoundError)) and not is_build:
pass # ignore missing train.py for backward compatibility with existing "bad" models in use
Expand Down Expand Up @@ -310,11 +334,7 @@ def shutdown() -> None:

@app.get("/")
async def root() -> Any:
return {
# "cog_version": "", # TODO
"docs_url": "/docs",
"openapi_url": "/openapi.json",
}
return index_document

@app.get("/health-check")
async def healthcheck() -> Any:
Expand Down Expand Up @@ -570,6 +590,9 @@ def _cpu_count() -> int:

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Cog HTTP server")
parser.add_argument(
"-v", "--version", action="store_true", help="Show version and exit"
)
parser.add_argument(
"--host",
dest="host",
Expand Down Expand Up @@ -608,6 +631,10 @@ def _cpu_count() -> int:
)
args = parser.parse_args()

if args.version:
print(f"cog.server.http {__version__}")
sys.exit(0)

# log level is configurable so we can make it quiet or verbose for `cog predict`
# cog predict --debug # -> debug
# cog predict # -> warning
Expand Down
18 changes: 18 additions & 0 deletions python/tests/server/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,24 @@
)


def test_index_document():
client = make_client(fixture_name="slow_setup")
resp = client.get("/")
data = resp.json()
for field in (
"cog_version",
"docs_url",
"openapi_url",
"shutdown_url",
"healthcheck_url",
"predictions_url",
"predictions_idempotent_url",
"predictions_cancel_url",
):
assert field in data
assert data[field] is not None


def test_setup_healthcheck():
client = make_client(fixture_name="slow_setup")
resp = client.get("/health-check")
Expand Down

0 comments on commit de13e89

Please sign in to comment.