diff --git a/.gitignore b/.gitignore index bd5163ee7a..46f781ab3a 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ docs/README.md docs/CONTRIBUTING.md venv base-image +flag_file diff --git a/python/cog/command/openapi_schema.py b/python/cog/command/openapi_schema.py index 7a260c7239..dd87293af3 100644 --- a/python/cog/command/openapi_schema.py +++ b/python/cog/command/openapi_schema.py @@ -7,8 +7,8 @@ import json from typing import Any, Dict, List, Union +from ..config import Config from ..errors import CogError, ConfigDoesNotExist, PredictorNotSet -from ..predictor import load_config from ..schema import Status from ..server.http import create_app from ..suppress_output import suppress_output @@ -37,8 +37,7 @@ def remove_title_next_to_ref( schema = {} try: with suppress_output(): - config = load_config() - app = create_app(config, shutdown_event=None, is_build=True) + app = create_app(cog_config=Config(), shutdown_event=None, is_build=True) if ( app.state.setup_result and app.state.setup_result.status == Status.FAILED diff --git a/python/cog/config.py b/python/cog/config.py new file mode 100644 index 0000000000..44675c79dc --- /dev/null +++ b/python/cog/config.py @@ -0,0 +1,167 @@ +import os +import sys +import uuid +from typing import Optional, Tuple, Type + +import structlog +import yaml +from pydantic import BaseModel + +from .base_input import BaseInput +from .base_predictor import BasePredictor +from .code_xforms import load_module_from_string, strip_model_source_code +from .env_property import env_property +from .errors import ConfigDoesNotExist +from .mode import Mode +from .predictor import ( + get_input_type, + get_output_type, + get_predictor, + get_training_input_type, + get_training_output_type, + load_full_predictor_from_file, +) +from .types import CogConfig +from .wait import wait_for_env + +COG_YAML_FILE = "cog.yaml" +COG_PREDICT_TYPE_STUB_ENV_VAR = "COG_PREDICT_TYPE_STUB" +COG_TRAIN_TYPE_STUB_ENV_VAR = "COG_TRAIN_TYPE_STUB" +COG_PREDICT_CODE_STRIP_ENV_VAR = "COG_PREDICT_CODE_STRIP" +COG_TRAIN_CODE_STRIP_ENV_VAR = "COG_TRAIN_CODE_STRIP" +COG_GPU_ENV_VAR = "COG_GPU" +PREDICT_METHOD_NAME = "predict" +TRAIN_METHOD_NAME = "train" + +log = structlog.get_logger("cog.config") + + +def _method_name_from_mode(mode: Mode) -> str: + if mode == Mode.PREDICT: + return PREDICT_METHOD_NAME + elif mode == Mode.TRAIN: + return TRAIN_METHOD_NAME + raise ValueError(f"Mode {mode} not recognised for method name mapping") + + +def _env_var_from_mode(mode: Mode) -> str: + if mode == Mode.PREDICT: + return COG_PREDICT_CODE_STRIP_ENV_VAR + elif mode == Mode.TRAIN: + return COG_TRAIN_CODE_STRIP_ENV_VAR + raise ValueError(f"Mode {mode} not recognised for env var mapping") + + +class Config: + """A class for reading the cog.yaml properties.""" + + def __init__(self, config: Optional[CogConfig] = None) -> None: + self._config = config + + @property + def _cog_config(self) -> CogConfig: + """ + Warning: Do not access this directly outside this class, instead + write an explicit public property and back it by an @env_property + to allow for the possibility of injecting the property you are + trying to read without relying on the underlying file. 
+ """ + config = self._config + if config is None: + wait_for_env(include_imports=False) + config_path = os.path.abspath(COG_YAML_FILE) + try: + with open(config_path, encoding="utf-8") as handle: + config = yaml.safe_load(handle) + except FileNotFoundError as e: + raise ConfigDoesNotExist( + f"Could not find {config_path}", + ) from e + self._config = config + return config + + @property + @env_property(COG_PREDICT_TYPE_STUB_ENV_VAR) + def predictor_predict_ref(self) -> Optional[str]: + """Find the predictor ref for the predict mode.""" + return self._cog_config.get(str(Mode.PREDICT)) + + @property + @env_property(COG_TRAIN_TYPE_STUB_ENV_VAR) + def predictor_train_ref(self) -> Optional[str]: + """Find the predictor ref for the train mode.""" + return self._cog_config.get(str(Mode.TRAIN)) + + @property + @env_property(COG_GPU_ENV_VAR) + def requires_gpu(self) -> bool: + """Whether this cog requires the use of a GPU.""" + return bool(self._cog_config.get("build", {}).get("gpu", False)) + + def _predictor_code( + self, + module_path: str, + class_name: str, + method_name: str, + mode: Mode, + module_name: str, + ) -> Optional[str]: + source_code = os.environ.get(_env_var_from_mode(mode)) + if source_code is not None: + return source_code + if sys.version_info >= (3, 9): + wait_for_env(include_imports=False) + with open(module_path, encoding="utf-8") as file: + return strip_model_source_code(file.read(), [class_name], [method_name]) + else: + log.debug(f"[{module_name}] cannot use fast loader as current Python <3.9") + return None + + def _load_predictor_for_types( + self, ref: str, method_name: str, mode: Mode + ) -> BasePredictor: + module_path, class_name = ref.split(":", 1) + module_name = os.path.basename(module_path).split(".py", 1)[0] + code = self._predictor_code( + module_path, class_name, method_name, mode, module_name + ) + module = None + if code is not None: + try: + module = load_module_from_string(uuid.uuid4().hex, code) + except Exception as e: # pylint: disable=broad-exception-caught + log.info(f"[{module_name}] fast loader failed: {e}") + if module is None: + log.debug(f"[{module_name}] falling back to slow loader") + wait_for_env(include_imports=False) + module = load_full_predictor_from_file(module_path, module_name) + return get_predictor(module, class_name) + + def get_predictor_ref(self, mode: Mode) -> str: + """Find the predictor reference for a given mode.""" + predictor_ref = None + if mode == Mode.PREDICT: + predictor_ref = self.predictor_predict_ref + elif mode == Mode.TRAIN: + predictor_ref = self.predictor_train_ref + if predictor_ref is None: + raise ValueError( + f"Can't run predictions: '{mode}' option not found in cog.yaml" + ) + return predictor_ref + + def get_predictor_types( + self, mode: Mode + ) -> Tuple[Type[BaseInput], Type[BaseModel]]: + """Find the input and output types of a predictor.""" + predictor_ref = self.get_predictor_ref(mode=mode) + predictor = self._load_predictor_for_types( + predictor_ref, _method_name_from_mode(mode=mode), mode + ) + if mode == Mode.PREDICT: + return get_input_type(predictor), get_output_type(predictor) + elif mode == Mode.TRAIN: + return get_training_input_type(predictor), get_training_output_type( + predictor + ) + raise ValueError(f"Mode {mode} not found for generating input/output types.") diff --git a/python/cog/env_property.py b/python/cog/env_property.py new file mode 100644 index 0000000000..0853f5255b --- /dev/null +++ b/python/cog/env_property.py @@ -0,0 +1,42 @@ +import os +from functools import wraps +from 
typing import Any, Callable, Optional, TypeVar, Union + +R = TypeVar("R") + + +def _get_origin(typ: Any) -> Any: + if hasattr(typ, "__origin__"): + return typ.__origin__ + return None + + +def _get_args(typ: Any) -> Any: + if hasattr(typ, "__args__"): + return typ.__args__ + return () + + +def env_property( + env_var: str, +) -> Callable[[Callable[[Any], R]], Callable[[Any], R]]: + """Wraps a class property in an environment variable check.""" + + def decorator(func: Callable[[Any], R]) -> Callable[[Any], R]: + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> R: + result = os.environ.get(env_var) + if result is not None: + expected_type = func.__annotations__.get("return", str) + if ( + _get_origin(expected_type) is Optional + or _get_origin(expected_type) is Union + ): + expected_type = _get_args(expected_type)[0] + return expected_type(result) + result = func(*args, **kwargs) + return result + + return wrapper + + return decorator diff --git a/python/cog/mode.py b/python/cog/mode.py new file mode 100644 index 0000000000..db45a0a8a5 --- /dev/null +++ b/python/cog/mode.py @@ -0,0 +1,11 @@ +from enum import Enum + + +class Mode(Enum): + """Enumeration over the different prediction modes.""" + + PREDICT = "predict" + TRAIN = "train" + + def __str__(self) -> str: + return str(self.value) diff --git a/python/cog/predictor.py b/python/cog/predictor.py index 0e25d90544..d0b04fb235 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -30,7 +30,6 @@ import pydantic import structlog -import yaml from pydantic import BaseModel, Field, create_model from pydantic.fields import FieldInfo @@ -40,10 +39,8 @@ from .base_input import BaseInput from .base_predictor import BasePredictor from .code_xforms import load_module_from_string, strip_model_source_code -from .errors import ConfigDoesNotExist, PredictorNotSet from .types import ( PYDANTIC_V2, - CogConfig, Input, ) from .types import ( @@ -142,43 +139,6 @@ def get_weights_type(setup_function: Callable[[Any], None]) -> Optional[Any]: return Type -def load_config() -> CogConfig: - """ - Reads cog.yaml and returns it as a typed dict. - """ - # Assumes the working directory is /src - config_path = os.path.abspath("cog.yaml") - try: - with open(config_path, encoding="utf-8") as fh: - config = yaml.safe_load(fh) - except FileNotFoundError as e: - raise ConfigDoesNotExist( - f"Could not find {config_path}", - ) from e - return config - - -def load_predictor(config: CogConfig) -> BasePredictor: - """ - Constructs an instance of the user-defined Predictor class from a config. 
- """ - - ref = get_predictor_ref(config) - return load_predictor_from_ref(ref) - - -def get_predictor_ref(config: CogConfig, mode: str = "predict") -> str: - if mode not in ["predict", "train"]: - raise ValueError(f"Invalid mode: {mode}") - - if mode not in config: - raise PredictorNotSet( - f"Can't run predictions: '{mode}' option not found in cog.yaml" - ) - - return config[mode] - - def load_full_predictor_from_file( module_path: str, module_name: str ) -> types.ModuleType: @@ -211,27 +171,6 @@ def get_predictor(module: types.ModuleType, class_name: str) -> Any: return predictor -def load_slim_predictor_from_ref(ref: str, method_name: str) -> BasePredictor: - module_path, class_name = ref.split(":", 1) - module_name = os.path.basename(module_path).split(".py", 1)[0] - module = None - try: - if sys.version_info >= (3, 9): - module = load_slim_predictor_from_file(module_path, class_name, method_name) - if not module: - log.debug(f"[{module_name}] fast loader returned None") - else: - log.debug(f"[{module_name}] cannot use fast loader as current Python <3.9") - except Exception as e: # pylint: disable=broad-exception-caught - log.debug(f"[{module_name}] fast loader failed: {e}") - finally: - if not module: - log.debug(f"[{module_name}] falling back to slow loader") - module = load_full_predictor_from_file(module_path, module_name) - predictor = get_predictor(module, class_name) - return predictor - - def load_predictor_from_ref(ref: str) -> BasePredictor: module_path, class_name = ref.split(":", 1) module_name = os.path.basename(module_path).split(".py", 1)[0] diff --git a/python/cog/server/http.py b/python/cog/server/http.py index 2d494ad8d1..75c04d1c9f 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -23,20 +23,13 @@ from pydantic import ValidationError from .. 
import schema +from ..config import Config from ..errors import PredictorNotSet from ..files import upload_file from ..json import upload_files from ..logging import setup_logging -from ..predictor import ( - get_input_type, - get_output_type, - get_predictor_ref, - get_training_input_type, - get_training_output_type, - load_config, - load_slim_predictor_from_ref, -) -from ..types import PYDANTIC_V2, CogConfig +from ..mode import Mode +from ..types import PYDANTIC_V2 try: from .._version import __version__ @@ -117,11 +110,11 @@ async def healthcheck_startup_failed() -> Any: def create_app( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements - config: CogConfig, # pylint: disable=redefined-outer-name + cog_config: Config, shutdown_event: Optional[threading.Event], # pylint: disable=redefined-outer-name - threads: int = 1, # pylint: disable=redefined-outer-name + app_threads: Optional[int] = None, upload_url: Optional[str] = None, - mode: str = "predict", + mode: Mode = Mode.PREDICT, is_build: bool = False, await_explicit_shutdown: bool = False, # pylint: disable=redefined-outer-name ) -> MyFastAPI: @@ -163,16 +156,13 @@ async def start_shutdown() -> Any: return JSONResponse({}, status_code=200) try: - predictor_ref = get_predictor_ref(config, mode) - predictor = load_slim_predictor_from_ref(predictor_ref, "predict") - InputType = get_input_type(predictor) # pylint: disable=invalid-name - OutputType = get_output_type(predictor) # pylint: disable=invalid-name + InputType, OutputType = cog_config.get_predictor_types(mode=Mode.PREDICT) except Exception: # pylint: disable=broad-exception-caught msg = "Error while loading predictor:\n\n" + traceback.format_exc() add_setup_failed_routes(app, started_at, msg) return app - worker = make_worker(predictor_ref=predictor_ref) + worker = make_worker(predictor_ref=cog_config.get_predictor_ref(mode=mode)) runner = PredictionRunner(worker=worker) class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType)): @@ -182,7 +172,9 @@ class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType input_type=InputType, output_type=OutputType ) - http_semaphore = asyncio.Semaphore(threads) + if app_threads is None: + app_threads = 1 if cog_config.requires_gpu else _cpu_count() + http_semaphore = asyncio.Semaphore(app_threads) def limited(f: "Callable[P, Awaitable[T]]") -> "Callable[P, Awaitable[T]]": @functools.wraps(f) @@ -203,12 +195,11 @@ async def wrapped(*args: "P.args", **kwargs: "P.kwargs") -> "T": # pylint: disa "predictions_cancel_url": "/predictions/{prediction_id}/cancel", } - if "train" in config: + if cog_config.predictor_train_ref: try: - trainer_ref = get_predictor_ref(config, "train") - trainer = load_slim_predictor_from_ref(trainer_ref, "train") - TrainingInputType = get_training_input_type(trainer) # pylint: disable=invalid-name - TrainingOutputType = get_training_output_type(trainer) # pylint: disable=invalid-name + TrainingInputType, TrainingOutputType = cog_config.get_predictor_types( + Mode.TRAIN + ) class TrainingRequest( schema.TrainingRequest.with_types(input_type=TrainingInputType) @@ -624,9 +615,9 @@ def _cpu_count() -> int: parser.add_argument( "--x-mode", dest="mode", - type=str, - default="predict", - choices=["predict", "train"], + type=Mode, + default=Mode.PREDICT, + choices=list(Mode), help="Experimental: Run in 'predict' or 'train' mode", ) args = parser.parse_args() @@ -642,13 +633,6 @@ def _cpu_count() -> int: log_level = 
logging.getLevelName(os.environ.get("COG_LOG_LEVEL", "INFO").upper()) setup_logging(log_level=log_level) - config = load_config() - - threads = args.threads - if threads is None: - gpu_enabled = config.get("build", {}).get("gpu", False) - threads = 1 if gpu_enabled else _cpu_count() - shutdown_event = threading.Event() await_explicit_shutdown = args.await_explicit_shutdown @@ -658,9 +642,9 @@ def _cpu_count() -> int: signal.signal(signal.SIGTERM, signal_set_event(shutdown_event)) app = create_app( - config=config, + cog_config=Config(), shutdown_event=shutdown_event, - threads=threads, + app_threads=args.threads, upload_url=args.upload_url, mode=args.mode, await_explicit_shutdown=await_explicit_shutdown, diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index 7633a586cc..55708b6341 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -16,9 +16,11 @@ import structlog +from ..base_predictor import BasePredictor from ..json import make_encodeable -from ..predictor import BasePredictor, get_predict, load_predictor_from_ref, run_setup +from ..predictor import get_predict, load_predictor_from_ref, run_setup from ..types import PYDANTIC_V2, URLPath +from ..wait import wait_for_env from .connection import AsyncConnection, LockedConnection from .eventtypes import ( Cancel, @@ -325,6 +327,7 @@ def send_cancel(self) -> None: def _setup(self, redirector: AsyncStreamRedirector) -> None: done = Done() + wait_for_env() try: self._predictor = load_predictor_from_ref(self._predictor_ref) # Could be a function or a class diff --git a/python/cog/wait.py b/python/cog/wait.py new file mode 100644 index 0000000000..796fca88a5 --- /dev/null +++ b/python/cog/wait.py @@ -0,0 +1,81 @@ +import importlib +import os +import sys +import time + +import structlog + +COG_WAIT_FILE_ENV_VAR = "COG_WAIT_FILE" +COG_EAGER_IMPORTS_ENV_VAR = "COG_EAGER_IMPORTS" +COG_PYENV_PATH_ENV_VAR = "COG_PYENV_PATH" +PYTHONPATH_ENV_VAR = "PYTHONPATH" +PYTHON_VERSION_ENV_VAR = "PYTHON_VERSION" + +log = structlog.get_logger("cog.wait") + + +def _wait_flag_fallen() -> bool: + wait_file = os.environ.get(COG_WAIT_FILE_ENV_VAR) + if wait_file is None: + return True + return os.path.exists(wait_file) + + +def _insert_pythonpath() -> None: + pyenv_path = os.environ.get(COG_PYENV_PATH_ENV_VAR) + if pyenv_path is None: + return + full_module_path = os.path.join( + pyenv_path, + "lib", + "python" + os.environ[PYTHON_VERSION_ENV_VAR], + "site-packages", + ) + if full_module_path not in sys.path: + sys.path.append(full_module_path) + os.environ[PYTHONPATH_ENV_VAR] = ":".join(sys.path) + + +def wait_for_file(timeout: float = 60.0) -> bool: + """Wait for a file in the environment variables.""" + wait_file = os.environ.get(COG_WAIT_FILE_ENV_VAR) + if wait_file is None: + return True + if os.path.exists(wait_file): + log.info(f"Wait file found {wait_file}...") + return True + log.info(f"Waiting for file {wait_file}...") + time_taken = 0.0 + while time_taken < timeout: + sleep_time = 0.01 + time.sleep(sleep_time) + time_taken += sleep_time + if os.path.exists(wait_file): + return True + log.info(f"Waiting for file {wait_file} timed out.") + return False + + +def eagerly_import_modules() -> int: + """Wait for python to import big modules.""" + wait_imports = os.environ.get(COG_EAGER_IMPORTS_ENV_VAR) + import_count = 0 + if wait_imports is None: + return import_count + log.info(f"Eagerly importing {wait_imports}.") + for import_statement in wait_imports.split(","): + importlib.import_module(import_statement) + 
import_count += 1 + return import_count + + +def wait_for_env(file_timeout: float = 60.0, include_imports: bool = True) -> bool: + """Wait for the environment to load.""" + if _wait_flag_fallen(): + _insert_pythonpath() + return True + if include_imports: + eagerly_import_modules() + waited = wait_for_file(timeout=file_timeout) + _insert_pythonpath() + return waited diff --git a/python/tests/server/conftest.py b/python/tests/server/conftest.py index 0c2e92a78b..e05a9a9cc0 100644 --- a/python/tests/server/conftest.py +++ b/python/tests/server/conftest.py @@ -10,6 +10,7 @@ from fastapi.testclient import TestClient from cog.command import ast_openapi_schema +from cog.config import Config from cog.server.http import create_app from cog.server.worker import make_worker @@ -98,7 +99,7 @@ def make_client( config.update(additional_config) app = create_app( - config=config, + cog_config=Config(config=config), shutdown_event=threading.Event(), upload_url=upload_url, ) diff --git a/python/tests/server/test_http_input.py b/python/tests/server/test_http_input.py index 9aa9203b89..9bcd02b733 100644 --- a/python/tests/server/test_http_input.py +++ b/python/tests/server/test_http_input.py @@ -8,6 +8,7 @@ from werkzeug.wrappers import Response from cog import schema +from cog.config import Config from cog.server.http import Health, create_app from cog.types import PYDANTIC_V2 @@ -300,7 +301,7 @@ def test_secret_str(client, match): def test_untyped_inputs(): config = {"predict": _fixture_path("input_untyped")} app = create_app( - config=config, + cog_config=Config(config), shutdown_event=threading.Event(), upload_url="input_untyped", ) @@ -314,7 +315,7 @@ def test_untyped_inputs(): def test_input_with_unsupported_type(): config = {"predict": _fixture_path("input_unsupported_type")} app = create_app( - config=config, + cog_config=Config(config), shutdown_event=threading.Event(), upload_url="input_untyped", ) diff --git a/python/tests/server/test_predictor.py b/python/tests/server/test_predictor.py index 1b86b7da5c..e7fa767d74 100644 --- a/python/tests/server/test_predictor.py +++ b/python/tests/server/test_predictor.py @@ -1,14 +1,15 @@ import inspect import os import sys +import uuid import pytest +from cog.code_xforms import load_module_from_string, strip_model_source_code from cog.predictor import ( get_predict, get_predictor, load_full_predictor_from_file, - load_slim_predictor_from_file, ) PREDICTOR_FIXTURES = [ @@ -49,10 +50,13 @@ def _fixture_path(name): @pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires Python 3.9 or newer") @pytest.mark.parametrize("fixture_name, class_name, method_name", PREDICTOR_FIXTURES) -def test_fast_slow_signatures(fixture_name, class_name, method_name): +def test_fast_slow_signatures(fixture_name: str, class_name: str, method_name: str): module_path = _fixture_path(fixture_name) # get signature from FAST loader - module_fast = load_slim_predictor_from_file(module_path, class_name, method_name) + code = None + with open(module_path, encoding="utf-8") as file: + code = strip_model_source_code(file.read(), [class_name], [method_name]) + module_fast = load_module_from_string(uuid.uuid4().hex, code) assert hasattr(module_fast, class_name) predictor_fast = get_predictor(module_fast, class_name) predict_fast = get_predict(predictor_fast) diff --git a/python/tests/test_config.py b/python/tests/test_config.py new file mode 100644 index 0000000000..2233965484 --- /dev/null +++ b/python/tests/test_config.py @@ -0,0 +1,220 @@ +import os +import tempfile + +import pytest + 
+from cog.config import ( + COG_GPU_ENV_VAR, + COG_PREDICT_CODE_STRIP_ENV_VAR, + COG_PREDICT_TYPE_STUB_ENV_VAR, + COG_TRAIN_TYPE_STUB_ENV_VAR, + COG_YAML_FILE, + Config, +) +from cog.errors import ConfigDoesNotExist +from cog.mode import Mode + + +def test_predictor_predict_ref_env_var(): + predict_ref = "predict.py:Predictor" + os.environ[COG_PREDICT_TYPE_STUB_ENV_VAR] = predict_ref + config = Config() + config_predict_ref = config.predictor_predict_ref + del os.environ[COG_PREDICT_TYPE_STUB_ENV_VAR] + assert ( + config_predict_ref == predict_ref + ), "Predict Reference should come from the environment variable." + + +def test_predictor_predict_ref_no_env_var(): + if COG_PREDICT_TYPE_STUB_ENV_VAR in os.environ: + del os.environ[COG_PREDICT_TYPE_STUB_ENV_VAR] + pwd = os.getcwd() + with tempfile.TemporaryDirectory() as tmpdir: + os.chdir(tmpdir) + with open(COG_YAML_FILE, "w", encoding="utf-8") as handle: + handle.write(""" +build: + python_version: "3.11" +predict: "predict.py:Predictor" +""") + config = Config() + config_predict_ref = config.predictor_predict_ref + assert ( + config_predict_ref == "predict.py:Predictor" + ), "Predict Reference should come from the cog config file." + os.chdir(pwd) + + +def test_config_no_config_file(): + if COG_PREDICT_TYPE_STUB_ENV_VAR in os.environ: + del os.environ[COG_PREDICT_TYPE_STUB_ENV_VAR] + config = Config() + with pytest.raises(ConfigDoesNotExist): + _ = config.predictor_predict_ref + + +def test_config_initial_values(): + if COG_PREDICT_TYPE_STUB_ENV_VAR in os.environ: + del os.environ[COG_PREDICT_TYPE_STUB_ENV_VAR] + config = Config(config={"predict": "predict.py:Predictor"}) + config_predict_ref = config.predictor_predict_ref + assert ( + config_predict_ref == "predict.py:Predictor" + ), "Predict Reference should come from the initial config dictionary." + + +def test_predictor_train_ref_env_var(): + train_ref = "predict.py:Predictor" + os.environ[COG_TRAIN_TYPE_STUB_ENV_VAR] = train_ref + config = Config() + config_train_ref = config.predictor_train_ref + del os.environ[COG_TRAIN_TYPE_STUB_ENV_VAR] + assert ( + config_train_ref == train_ref + ), "Train Reference should come from the environment variable." + + +def test_predictor_train_ref_no_env_var(): + train_ref = "predict.py:Predictor" + if COG_TRAIN_TYPE_STUB_ENV_VAR in os.environ: + del os.environ[COG_TRAIN_TYPE_STUB_ENV_VAR] + config = Config(config={"train": train_ref}) + config_train_ref = config.predictor_train_ref + assert ( + config_train_ref == train_ref + ), "Train Reference should come from the initial config dictionary." + + +def test_requires_gpu_env_var(): + gpu = True + os.environ[COG_GPU_ENV_VAR] = str(gpu) + config = Config() + config_gpu = config.requires_gpu + del os.environ[COG_GPU_ENV_VAR] + assert config_gpu, "Requires GPU should come from the environment variable." + + +def test_requires_gpu_no_env_var(): + if COG_GPU_ENV_VAR in os.environ: + del os.environ[COG_GPU_ENV_VAR] + config = Config(config={"build": {"gpu": False}}) + config_gpu = config.requires_gpu + assert ( + not config_gpu + ), "Requires GPU should come from the initial config dictionary." + + +def test_get_predictor_ref_predict(): + train_ref = "predict.py:Predictor" + config = Config(config={"train": train_ref}) + config_train_ref = config.get_predictor_ref(Mode.TRAIN) + assert ( + train_ref == config_train_ref + ), "The train ref should equal the config train ref." 
+ + +def test_get_predictor_ref_train(): + predict_ref = "predict.py:Predictor" + config = Config(config={"predict": predict_ref}) + config_predict_ref = config.get_predictor_ref(Mode.PREDICT) + assert ( + predict_ref == config_predict_ref + ), "The predict ref should equal the config predict ref." + + +def test_get_predictor_types_with_env_var(): + predict_ref = "predict.py:Predictor" + os.environ[COG_PREDICT_TYPE_STUB_ENV_VAR] = predict_ref + os.environ[COG_PREDICT_CODE_STRIP_ENV_VAR] = """ +from cog import BasePredictor, Path +from typing import Optional +from pydantic import BaseModel + +class ModelOutput(BaseModel): + success: bool + error: Optional[str] + segmentedImage: Optional[Path] + +class Predictor(BasePredictor): + + def predict(self, msg: str) -> ModelOutput: + return None +""" + config = Config() + input_type, output_type = config.get_predictor_types(Mode.PREDICT) + del os.environ[COG_PREDICT_CODE_STRIP_ENV_VAR] + del os.environ[COG_PREDICT_TYPE_STUB_ENV_VAR] + assert ( + str(input_type) == "" + ), "Predict input type should be the predictor Input." + assert ( + str(output_type) == ".Output'>" + ), "Predict output type should be the predictor Output." + + +def test_get_predictor_types(): + with tempfile.TemporaryDirectory() as tmpdir: + predict_python_file = os.path.join(tmpdir, "predict.py") + with open(predict_python_file, "w", encoding="utf-8") as handle: + handle.write(""" +import io + +from cog import BasePredictor, Path +from typing import Optional +from pydantic import BaseModel + + +class ModelOutput(BaseModel): + success: bool + error: Optional[str] + segmentedImage: Optional[Path] + + +class Predictor(BasePredictor): + # setup code + def predict(self, msg: str) -> ModelOutput: + return ModelOutput(success=False, error=msg, segmentedImage=None) +""") + predict_ref = f"{predict_python_file}:Predictor" + config = Config(config={"predict": predict_ref}) + input_type, output_type = config.get_predictor_types(Mode.PREDICT) + assert ( + str(input_type) == "" + ), "Predict input type should be the predictor Input." + assert ( + str(output_type) + == ".Output'>" + ), "Predict output type should be the predictor Output." + + +def test_get_predictor_types_for_train(): + with tempfile.TemporaryDirectory() as tmpdir: + predict_python_file = os.path.join(tmpdir, "train.py") + with open(predict_python_file, "w", encoding="utf-8") as handle: + handle.write(""" +from cog import BaseModel, Input, Path + +class TrainingOutput(BaseModel): + weights: Path + +def train( + n: int, +) -> TrainingOutput: + with open("weights.bin", "w") as fh: + for _ in range(n): + fh.write("a") + + return TrainingOutput( + weights=Path("weights.bin"), + ) +""") + train_ref = f"{predict_python_file}:train" + config = Config(config={"train": train_ref}) + input_type, output_type = config.get_predictor_types(Mode.TRAIN) + assert ( + str(input_type) == "" + ), "Predict input type should be the training Input." + assert str(output_type).endswith( + "TrainingOutput'>" + ), "Predict output type should be the training Output." 
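
Not part of the patch, but as a quick orientation: a minimal sketch of how the refactored call sites (create_app in python/cog/server/http.py, and the tests above) consume the new Config/Mode API. It assumes the working directory contains a cog.yaml whose predict entry resolves to an importable predictor; the "predict.py:Predictor" reference below is illustrative only.

# Minimal usage sketch for the Config/Mode API introduced by this diff (illustrative, not part of the patch).
from cog.config import Config
from cog.mode import Mode

cog_config = Config()  # cog.yaml is read lazily on first property access; env vars such as
                       # COG_PREDICT_TYPE_STUB or COG_GPU can override individual properties
predictor_ref = cog_config.get_predictor_ref(mode=Mode.PREDICT)  # e.g. "predict.py:Predictor"
InputType, OutputType = cog_config.get_predictor_types(mode=Mode.PREDICT)
print(predictor_ref, cog_config.requires_gpu, InputType, OutputType)
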
diff --git a/python/tests/test_predictor.py b/python/tests/test_predictor.py index df6363a7ca..d232d5ac0f 100644 --- a/python/tests/test_predictor.py +++ b/python/tests/test_predictor.py @@ -4,7 +4,10 @@ from unittest.mock import patch from cog import File, Path -from cog.predictor import get_weights_type, load_predictor_from_ref +from cog.predictor import ( + get_weights_type, + load_predictor_from_ref, +) def test_get_weights_type() -> None: diff --git a/python/tests/test_wait.py b/python/tests/test_wait.py new file mode 100644 index 0000000000..07afb9266d --- /dev/null +++ b/python/tests/test_wait.py @@ -0,0 +1,123 @@ +import os +import sys +import tempfile +import threading +import time +from pathlib import Path + +from cog.wait import ( + COG_EAGER_IMPORTS_ENV_VAR, + COG_PYENV_PATH_ENV_VAR, + COG_WAIT_FILE_ENV_VAR, + PYTHON_VERSION_ENV_VAR, + PYTHONPATH_ENV_VAR, + eagerly_import_modules, + wait_for_env, + wait_for_file, +) + + +def test_wait_for_file_no_env_var(): + if COG_WAIT_FILE_ENV_VAR in os.environ: + del os.environ[COG_WAIT_FILE_ENV_VAR] + result = wait_for_file() + assert result, "We should immediately return when no wait file is specified." + + +def test_wait_for_file_exists(): + with tempfile.NamedTemporaryFile() as tmpfile: + os.environ[COG_WAIT_FILE_ENV_VAR] = tmpfile.name + result = wait_for_file(timeout=5.0) + del os.environ[COG_WAIT_FILE_ENV_VAR] + assert result, "We should immediately return when the file already exists." + + +def test_wait_for_file_waits_for_file(): + wait_file = os.path.join(os.path.dirname(__file__), "flag_file") + if os.path.exists(wait_file): + os.remove(wait_file) + os.environ[COG_WAIT_FILE_ENV_VAR] = wait_file + + def create_file(): + time.sleep(2.0) + Path(wait_file).touch() + + thread = threading.Thread(target=create_file) + thread.start() + result = wait_for_file(timeout=5.0) + del os.environ[COG_WAIT_FILE_ENV_VAR] + os.remove(wait_file) + assert result, "We should return when the file is touched." + + +def test_wait_for_file_timeout(): + os.environ[COG_WAIT_FILE_ENV_VAR] = os.path.join( + os.path.dirname(__file__), "a_file_unknown" + ) + result = wait_for_file(timeout=5.0) + del os.environ[COG_WAIT_FILE_ENV_VAR] + assert not result, "We should return false when the timeout triggers." + + +def test_eagerly_import_modules_no_env_var(): + if COG_EAGER_IMPORTS_ENV_VAR in os.environ: + del os.environ[COG_EAGER_IMPORTS_ENV_VAR] + eagerly_import_modules() + + +def test_eagerly_import_modules(): + os.environ[COG_EAGER_IMPORTS_ENV_VAR] = "pytest,pathlib,time" + import_count = eagerly_import_modules() + del os.environ[COG_EAGER_IMPORTS_ENV_VAR] + assert import_count == 3, "There should be 3 imports performed" + + +def test_wait_for_env_no_env_vars(): + if COG_WAIT_FILE_ENV_VAR in os.environ: + del os.environ[COG_WAIT_FILE_ENV_VAR] + if COG_EAGER_IMPORTS_ENV_VAR in os.environ: + del os.environ[COG_EAGER_IMPORTS_ENV_VAR] + result = wait_for_env() + assert ( + result + ), "We should return true if we have no env vars associated with the wait." + + +def test_wait_for_env(): + with tempfile.NamedTemporaryFile() as tmpfile: + os.environ[COG_WAIT_FILE_ENV_VAR] = tmpfile.name + os.environ[COG_EAGER_IMPORTS_ENV_VAR] = "pytest,pathlib,time" + result = wait_for_env() + assert ( + result + ), "We should return true if we have waited for the right environment." 
+        del os.environ[COG_EAGER_IMPORTS_ENV_VAR]
+        del os.environ[COG_WAIT_FILE_ENV_VAR]
+
+
+def test_wait_inserts_pythonpath():
+    with tempfile.NamedTemporaryFile() as tmpfile:
+        original_sys_path = sys.path.copy()
+        original_python_path = os.environ.get(PYTHONPATH_ENV_VAR)
+        pyenv_path = os.path.dirname(tmpfile.name)
+        os.environ[COG_WAIT_FILE_ENV_VAR] = tmpfile.name
+        os.environ[COG_EAGER_IMPORTS_ENV_VAR] = "pytest,pathlib,time"
+        os.environ[COG_PYENV_PATH_ENV_VAR] = pyenv_path
+        os.environ[PYTHON_VERSION_ENV_VAR] = "3.11"
+        wait_for_env()
+        del os.environ[PYTHON_VERSION_ENV_VAR]
+        del os.environ[COG_PYENV_PATH_ENV_VAR]
+        del os.environ[COG_EAGER_IMPORTS_ENV_VAR]
+        del os.environ[COG_WAIT_FILE_ENV_VAR]
+        current_python_path = os.environ[PYTHONPATH_ENV_VAR]
+        if original_python_path is None:
+            del os.environ[PYTHONPATH_ENV_VAR]
+        else:
+            os.environ[PYTHONPATH_ENV_VAR] = original_python_path
+        sys.path = original_sys_path
+        expected_path = ":".join(
+            original_sys_path + [pyenv_path + "/lib/python3.11/site-packages"]
+        )
+        assert (
+            expected_path == current_python_path
+        ), "Our python path should be updated with the pyenv path."
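
Not part of the patch: a hedged sketch of the environment contract that the new python/cog/wait.py module expects an orchestrator to set up before the worker's _setup() reaches wait_for_env(). Every concrete value below (the wait-file path, the eager-import list, the pyenv path and Python version) is made up for illustration; a real deployment would point COG_EAGER_IMPORTS at heavy model dependencies rather than stdlib modules.

# Illustrative sketch of the environment variables read by cog.wait; all values are assumptions.
import os

os.environ["COG_WAIT_FILE"] = "/tmp/ready"         # wait_for_env() polls until this file exists
os.environ["COG_EAGER_IMPORTS"] = "json,pathlib"   # stdlib stand-ins for large model imports
os.environ["COG_PYENV_PATH"] = "/usr/local/pyenv"  # <path>/lib/python<version>/site-packages gets appended to sys.path
os.environ["PYTHON_VERSION"] = "3.11"

from cog.wait import wait_for_env

# Imports the eager modules, then blocks up to file_timeout seconds for COG_WAIT_FILE to appear,
# returning False on timeout; either way PYTHONPATH is extended with the pyenv site-packages path.
ready = wait_for_env(file_timeout=5.0)
print(ready)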