Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Config class #2042

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions python/cog/command/openapi_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import json
from typing import Any, Dict, List, Union

from ..config import Config
from ..errors import CogError, ConfigDoesNotExist, PredictorNotSet
from ..predictor import load_config
from ..schema import Status
from ..server.http import create_app
from ..suppress_output import suppress_output
Expand Down Expand Up @@ -37,8 +37,7 @@ def remove_title_next_to_ref(
schema = {}
try:
with suppress_output():
config = load_config()
app = create_app(config, shutdown_event=None, is_build=True)
app = create_app(cog_config=Config(), shutdown_event=None, is_build=True)
if (
app.state.setup_result
and app.state.setup_result.status == Status.FAILED
Expand Down
143 changes: 143 additions & 0 deletions python/cog/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import os
import sys
import uuid
from typing import Optional, Tuple, Type

import structlog
import yaml
from pydantic import BaseModel

from .base_input import BaseInput
from .base_predictor import BasePredictor
from .code_xforms import load_module_from_string, strip_model_source_code
from .errors import ConfigDoesNotExist
from .mode import Mode
from .predictor import (
get_input_type,
get_output_type,
get_predictor,
get_training_input_type,
get_training_output_type,
load_full_predictor_from_file,
)
from .types import CogConfig

COG_YAML_FILE = "cog.yaml"
PREDICT_METHOD_NAME = "predict"
TRAIN_METHOD_NAME = "train"

log = structlog.get_logger("cog.config")


def _method_name_from_mode(mode: Mode) -> str:
    """Map a prediction mode to the predictor method name it invokes."""
    method_names = {
        Mode.PREDICT: PREDICT_METHOD_NAME,
        Mode.TRAIN: TRAIN_METHOD_NAME,
    }
    name = method_names.get(mode)
    if name is None:
        raise ValueError(f"Mode {mode} not recognised for method name mapping")
    return name


class Config:
    """A class for reading the cog.yaml properties.

    The underlying mapping may be injected via the constructor (useful for
    tests and for callers that already parsed the YAML); otherwise it is
    lazily loaded from cog.yaml in the current working directory on first
    access and cached for the lifetime of this object.
    """

    def __init__(self, config: Optional[CogConfig] = None) -> None:
        # None means "load from cog.yaml on first access" (see _cog_config).
        self._config = config

    @property
    def _cog_config(self) -> CogConfig:
        """
        Warning: Do not access this directly outside this class, instead
        write an explicit public property and back it by an @env_property
        to allow for the possibility of injecting the property you are
        trying to read without relying on the underlying file.
        """
        config = self._config
        if config is None:
            # Assumes the working directory contains cog.yaml (usually /src).
            config_path = os.path.abspath(COG_YAML_FILE)
            try:
                with open(config_path, encoding="utf-8") as handle:
                    config = yaml.safe_load(handle)
            except FileNotFoundError as e:
                raise ConfigDoesNotExist(
                    f"Could not find {config_path}",
                ) from e
            # Cache so the file is read and parsed at most once.
            self._config = config
        return config

    @property
    def predictor_predict_ref(self) -> Optional[str]:
        """Find the predictor ref ("module.py:Class") for the predict mode."""
        return self._cog_config.get(str(Mode.PREDICT))

    @property
    def predictor_train_ref(self) -> Optional[str]:
        """Find the predictor ref ("module.py:Class") for the train mode."""
        return self._cog_config.get(str(Mode.TRAIN))

    @property
    def requires_gpu(self) -> bool:
        """Whether this cog requires the use of a GPU (build.gpu in cog.yaml)."""
        return bool(self._cog_config.get("build", {}).get("gpu", False))

    def _predictor_code(
        self,
        module_path: str,
        class_name: str,
        method_name: str,
        mode: Mode,  # NOTE(review): currently unused; kept for signature stability.
        module_name: str,
    ) -> Optional[str]:
        """Return stripped predictor source suitable for the fast loader.

        Returns None when the fast loader cannot be used (Python < 3.9 or
        the source cannot be stripped), signalling the caller to fall back
        to fully importing the module.
        """
        if sys.version_info >= (3, 9):
            with open(module_path, encoding="utf-8") as file:
                return strip_model_source_code(file.read(), [class_name], [method_name])
        log.debug(f"[{module_name}] cannot use fast loader as current Python <3.9")
        return None

    def _load_predictor_for_types(
        self, ref: str, method_name: str, mode: Mode
    ) -> BasePredictor:
        """Load the predictor class named by *ref*, preferring the fast loader.

        The fast path imports only the stripped class/method source under a
        random module name; any failure there falls back to importing the
        full module from disk.
        """
        module_path, class_name = ref.split(":", 1)
        module_name = os.path.basename(module_path).split(".py", 1)[0]
        code = self._predictor_code(
            module_path, class_name, method_name, mode, module_name
        )
        module = None
        if code is not None:
            try:
                # uuid-based module name avoids clashing with a real import.
                module = load_module_from_string(uuid.uuid4().hex, code)
            except Exception as e:  # pylint: disable=broad-exception-caught
                log.info(f"[{module_name}] fast loader failed: {e}")
        if module is None:
            log.debug(f"[{module_name}] falling back to slow loader")
            module = load_full_predictor_from_file(module_path, module_name)
        return get_predictor(module, class_name)

    def get_predictor_ref(self, mode: Mode) -> str:
        """Find the predictor reference for a given mode.

        Raises:
            PredictorNotSet: when cog.yaml has no entry for *mode*. This
                matches the exception (and message) previously raised by
                predictor.get_predictor_ref, which existing callers such as
                openapi_schema still catch — raising ValueError here would
                silently change that contract.
        """
        predictor_ref = None
        if mode == Mode.PREDICT:
            predictor_ref = self.predictor_predict_ref
        elif mode == Mode.TRAIN:
            predictor_ref = self.predictor_train_ref
        if predictor_ref is None:
            raise PredictorNotSet(
                f"Can't run predictions: '{mode}' option not found in cog.yaml"
            )
        return predictor_ref

    def get_predictor_types(
        self, mode: Mode
    ) -> Tuple[Type[BaseInput], Type[BaseModel]]:
        """Find the input and output types of a predictor for *mode*."""
        predictor_ref = self.get_predictor_ref(mode=mode)
        predictor = self._load_predictor_for_types(
            predictor_ref, _method_name_from_mode(mode=mode), mode
        )
        if mode == Mode.PREDICT:
            return get_input_type(predictor), get_output_type(predictor)
        elif mode == Mode.TRAIN:
            return get_training_input_type(predictor), get_training_output_type(
                predictor
            )
        raise ValueError(f"Mode {mode} not found for generating input/output types.")
11 changes: 11 additions & 0 deletions python/cog/mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from enum import Enum


class Mode(Enum):
    """The supported prediction modes.

    Each member's value is the key used for it in cog.yaml, so ``str(mode)``
    yields that key directly.
    """

    PREDICT = "predict"
    TRAIN = "train"

    def __str__(self) -> str:
        # Values are already strings; render the raw value.
        return self.value
61 changes: 0 additions & 61 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

import pydantic
import structlog
import yaml
from pydantic import BaseModel, Field, create_model
from pydantic.fields import FieldInfo

Expand All @@ -40,10 +39,8 @@
from .base_input import BaseInput
from .base_predictor import BasePredictor
from .code_xforms import load_module_from_string, strip_model_source_code
from .errors import ConfigDoesNotExist, PredictorNotSet
from .types import (
PYDANTIC_V2,
CogConfig,
Input,
)
from .types import (
Expand Down Expand Up @@ -142,43 +139,6 @@ def get_weights_type(setup_function: Callable[[Any], None]) -> Optional[Any]:
return Type


def load_config() -> CogConfig:
    """Read cog.yaml from the working directory and return the parsed mapping.

    Raises:
        ConfigDoesNotExist: if cog.yaml is not present.
    """
    # Assumes the working directory is /src
    path = os.path.abspath("cog.yaml")
    try:
        handle = open(path, encoding="utf-8")
    except FileNotFoundError as err:
        raise ConfigDoesNotExist(
            f"Could not find {path}",
        ) from err
    with handle:
        return yaml.safe_load(handle)


def load_predictor(config: CogConfig) -> BasePredictor:
    """Build an instance of the user-defined Predictor class from *config*."""
    return load_predictor_from_ref(get_predictor_ref(config))


def get_predictor_ref(config: CogConfig, mode: str = "predict") -> str:
    """Return the predictor reference configured for *mode*.

    Raises:
        ValueError: if *mode* is not 'predict' or 'train'.
        PredictorNotSet: if cog.yaml has no entry for *mode*.
    """
    if mode not in ("predict", "train"):
        raise ValueError(f"Invalid mode: {mode}")
    if mode in config:
        return config[mode]
    raise PredictorNotSet(
        f"Can't run predictions: '{mode}' option not found in cog.yaml"
    )


def load_full_predictor_from_file(
module_path: str, module_name: str
) -> types.ModuleType:
Expand Down Expand Up @@ -211,27 +171,6 @@ def get_predictor(module: types.ModuleType, class_name: str) -> Any:
return predictor


def load_slim_predictor_from_ref(ref: str, method_name: str) -> BasePredictor:
    """Load the predictor class named by ``ref`` ("path/to/file.py:ClassName").

    Tries the fast (slim) loader first on Python >= 3.9 and falls back to
    fully importing the module whenever the fast path returns nothing or
    raises. Returns the predictor object found by ``get_predictor``.
    """
    module_path, class_name = ref.split(":", 1)
    module_name = os.path.basename(module_path).split(".py", 1)[0]
    module = None
    try:
        if sys.version_info >= (3, 9):
            # Fast path: import only the stripped class/method source.
            module = load_slim_predictor_from_file(module_path, class_name, method_name)
            if not module:
                log.debug(f"[{module_name}] fast loader returned None")
        else:
            log.debug(f"[{module_name}] cannot use fast loader as current Python <3.9")
    except Exception as e:  # pylint: disable=broad-exception-caught
        log.debug(f"[{module_name}] fast loader failed: {e}")
    finally:
        # The fallback lives in `finally` so it runs both when the fast path
        # returned a falsy module and when it raised (exception was caught above).
        if not module:
            log.debug(f"[{module_name}] falling back to slow loader")
            module = load_full_predictor_from_file(module_path, module_name)
    predictor = get_predictor(module, class_name)
    return predictor


def load_predictor_from_ref(ref: str) -> BasePredictor:
module_path, class_name = ref.split(":", 1)
module_name = os.path.basename(module_path).split(".py", 1)[0]
Expand Down
56 changes: 20 additions & 36 deletions python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,13 @@
from pydantic import ValidationError

from .. import schema
from ..config import Config
from ..errors import PredictorNotSet
from ..files import upload_file
from ..json import upload_files
from ..logging import setup_logging
from ..predictor import (
get_input_type,
get_output_type,
get_predictor_ref,
get_training_input_type,
get_training_output_type,
load_config,
load_slim_predictor_from_ref,
)
from ..types import PYDANTIC_V2, CogConfig
from ..mode import Mode
from ..types import PYDANTIC_V2

try:
from .._version import __version__
Expand Down Expand Up @@ -117,11 +110,11 @@ async def healthcheck_startup_failed() -> Any:


def create_app( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
config: CogConfig, # pylint: disable=redefined-outer-name
cog_config: Config,
shutdown_event: Optional[threading.Event], # pylint: disable=redefined-outer-name
threads: int = 1, # pylint: disable=redefined-outer-name
app_threads: Optional[int] = None,
upload_url: Optional[str] = None,
mode: str = "predict",
mode: Mode = Mode.PREDICT,
is_build: bool = False,
await_explicit_shutdown: bool = False, # pylint: disable=redefined-outer-name
) -> MyFastAPI:
Expand Down Expand Up @@ -163,16 +156,13 @@ async def start_shutdown() -> Any:
return JSONResponse({}, status_code=200)

try:
predictor_ref = get_predictor_ref(config, mode)
predictor = load_slim_predictor_from_ref(predictor_ref, "predict")
InputType = get_input_type(predictor) # pylint: disable=invalid-name
OutputType = get_output_type(predictor) # pylint: disable=invalid-name
InputType, OutputType = cog_config.get_predictor_types(mode=Mode.PREDICT)
except Exception: # pylint: disable=broad-exception-caught
msg = "Error while loading predictor:\n\n" + traceback.format_exc()
add_setup_failed_routes(app, started_at, msg)
return app

worker = make_worker(predictor_ref=predictor_ref)
worker = make_worker(predictor_ref=cog_config.get_predictor_ref(mode=mode))
runner = PredictionRunner(worker=worker)

class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType)):
Expand All @@ -182,7 +172,9 @@ class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType
input_type=InputType, output_type=OutputType
)

http_semaphore = asyncio.Semaphore(threads)
if app_threads is None:
app_threads = 1 if cog_config.requires_gpu else _cpu_count()
http_semaphore = asyncio.Semaphore(app_threads)

def limited(f: "Callable[P, Awaitable[T]]") -> "Callable[P, Awaitable[T]]":
@functools.wraps(f)
Expand All @@ -203,12 +195,11 @@ async def wrapped(*args: "P.args", **kwargs: "P.kwargs") -> "T": # pylint: disa
"predictions_cancel_url": "/predictions/{prediction_id}/cancel",
}

if "train" in config:
if cog_config.predictor_train_ref:
try:
trainer_ref = get_predictor_ref(config, "train")
trainer = load_slim_predictor_from_ref(trainer_ref, "train")
TrainingInputType = get_training_input_type(trainer) # pylint: disable=invalid-name
TrainingOutputType = get_training_output_type(trainer) # pylint: disable=invalid-name
TrainingInputType, TrainingOutputType = cog_config.get_predictor_types(
Mode.TRAIN
)

class TrainingRequest(
schema.TrainingRequest.with_types(input_type=TrainingInputType)
Expand Down Expand Up @@ -624,9 +615,9 @@ def _cpu_count() -> int:
parser.add_argument(
"--x-mode",
dest="mode",
type=str,
default="predict",
choices=["predict", "train"],
type=Mode,
default=Mode.PREDICT,
choices=list(Mode),
help="Experimental: Run in 'predict' or 'train' mode",
)
args = parser.parse_args()
Expand All @@ -642,13 +633,6 @@ def _cpu_count() -> int:
log_level = logging.getLevelName(os.environ.get("COG_LOG_LEVEL", "INFO").upper())
setup_logging(log_level=log_level)

config = load_config()

threads = args.threads
if threads is None:
gpu_enabled = config.get("build", {}).get("gpu", False)
threads = 1 if gpu_enabled else _cpu_count()

shutdown_event = threading.Event()

await_explicit_shutdown = args.await_explicit_shutdown
Expand All @@ -658,9 +642,9 @@ def _cpu_count() -> int:
signal.signal(signal.SIGTERM, signal_set_event(shutdown_event))

app = create_app(
config=config,
cog_config=Config(),
shutdown_event=shutdown_event,
threads=threads,
app_threads=args.threads,
upload_url=args.upload_url,
mode=args.mode,
await_explicit_shutdown=await_explicit_shutdown,
Expand Down
Loading