Add ability to wait for an environment #1957

Open
Wants to merge 80 commits into base: main
Commits (80)
6075901
Send in app threads directly from args
8W9aG Sep 16, 2024
3d348cc
Load config right before necessary
8W9aG Sep 16, 2024
1024cac
Add waiting for a wait file
8W9aG Sep 16, 2024
cc7fccf
Add wait_for_imports ability
8W9aG Sep 16, 2024
19cc38d
Fix lint on src_path
8W9aG Sep 17, 2024
e79e25d
Fix watchdog version
8W9aG Sep 17, 2024
6dfe3ec
Remove load_config in openapi_schema cmd
8W9aG Sep 17, 2024
03f7011
Do not access root files on GHA workers
8W9aG Sep 17, 2024
4f75de4
Set recursive to true
8W9aG Sep 17, 2024
949792c
Watch the directory instead
8W9aG Sep 17, 2024
c83f2e8
Add code_xforms test
8W9aG Sep 17, 2024
7265582
Add http server to test to let it respond
8W9aG Sep 17, 2024
71854b9
Abstract away cog config
8W9aG Sep 18, 2024
5ff8030
Wait for environment before executing setup
8W9aG Sep 18, 2024
8f3f41d
Fix Type on lower python versions
8W9aG Sep 18, 2024
1bebbc3
Skip test_strip_model_source_code if < 3.9
8W9aG Sep 18, 2024
d61d782
Change COG_WAIT_IMPORTS to COG_EAGER_IMPORTS
8W9aG Sep 18, 2024
e9464df
Bump integration test timeout to 20 mins
8W9aG Sep 18, 2024
adb04a4
Add tests for Config class
8W9aG Sep 18, 2024
25fb547
Fix get_args and get_origin in python 3.7
8W9aG Sep 19, 2024
ef2a1a0
Add more tests for config
8W9aG Sep 19, 2024
061ff16
Check wait flag has fallen before eager import
8W9aG Sep 19, 2024
33b4106
Add watch handler tests
8W9aG Sep 19, 2024
2c1f2b9
Remove watchdog and use SIGUSR2 for signalling
8W9aG Sep 19, 2024
f6d0e45
Fix no torch import in tests
8W9aG Sep 19, 2024
b480479
Merge branch 'main' into add-waiting-env
8W9aG Sep 19, 2024
42b7306
Do naive waiting for file
8W9aG Sep 20, 2024
f79c597
Merge branch 'main' into add-waiting-env
8W9aG Sep 24, 2024
f0c5b79
Merge branch 'main' into add-waiting-env
8W9aG Sep 25, 2024
8fc79b0
Consolidate code_xforms tests
8W9aG Sep 25, 2024
6b15565
Add logic to keep referenced globals
8W9aG Sep 25, 2024
e4c2675
Convert set of globals to list
8W9aG Sep 26, 2024
bcc14a4
Add logging for file waiting
8W9aG Sep 26, 2024
c33ead0
Fix lint issue by checking that tree is ast.Module
8W9aG Sep 26, 2024
065cf19
Use typing.List instead of list for older python
8W9aG Sep 26, 2024
75d7430
Fix more List issues
8W9aG Sep 26, 2024
f1921c2
Add further logging if the module fast loader failed
8W9aG Sep 26, 2024
f3494c8
Change tuple to Tuple for older pythons
8W9aG Sep 26, 2024
9fbb7ad
Add setup logs to mirror prediction logs
8W9aG Sep 26, 2024
096c45b
Add 3.12 to torch 2.3.0
8W9aG Sep 30, 2024
a522a0f
Add insertion of python path
8W9aG Oct 2, 2024
92d0b75
Add support for handling subclasses in slim predict
8W9aG Oct 3, 2024
f51c576
Merge branch 'main' into add-waiting-env
8W9aG Oct 3, 2024
1a55b0d
Fix listing issues
8W9aG Oct 3, 2024
3b7756a
Wrap in str
8W9aG Oct 3, 2024
21c1724
Remove extraneous print
8W9aG Oct 3, 2024
4598529
Fix python path race condition
8W9aG Oct 3, 2024
3dc9e9d
Merge branch 'main' into add-waiting-env
8W9aG Oct 9, 2024
0031765
Remove predictor import
8W9aG Oct 9, 2024
272756f
Merge branch 'main' into add-waiting-env
8W9aG Oct 9, 2024
1a8a374
Fix pydantic 2 errors
8W9aG Oct 9, 2024
6e7b4c7
More pydantic 2 fixes
8W9aG Oct 9, 2024
84e947d
Support multiple class names and method names
8W9aG Oct 11, 2024
dc89489
Merge branch 'main' into add-waiting-env
8W9aG Oct 11, 2024
5f86293
Use lists instead single instances for stripping
8W9aG Oct 11, 2024
5e1c842
Merge branch 'main' into add-waiting-env
8W9aG Oct 14, 2024
2f883e4
Fix types imports
8W9aG Oct 14, 2024
697adf9
Merge branch 'main' into add-waiting-env
8W9aG Oct 15, 2024
c2d3269
Merge branch 'main' into add-waiting-env
8W9aG Oct 15, 2024
6b05a79
Revert "Add 3.12 to torch 2.3.0"
8W9aG Oct 15, 2024
926f2f5
Merge branch 'main' into add-waiting-env
8W9aG Oct 16, 2024
bd9c6cf
Merge branch 'main' into add-waiting-env
8W9aG Oct 16, 2024
8d6c4ce
Fix imports after merge
8W9aG Oct 16, 2024
bf222ce
Add COG_TRAIN_PREDICTOR
8W9aG Oct 16, 2024
0c3c6f1
Rename env vars to be sensical
8W9aG Oct 16, 2024
00b98bc
Add PYTHONPATH environment variables
8W9aG Oct 17, 2024
5100333
Merge branch 'main' into add-waiting-env
8W9aG Oct 18, 2024
07de3f1
Merge branch 'main' into add-waiting-env
8W9aG Oct 21, 2024
8ea4663
Merge branch 'main' into add-waiting-env
8W9aG Oct 21, 2024
3a29229
Add consistent debug logging in config
8W9aG Oct 21, 2024
7657c23
Add test_strip_model_source_code_keeps_referenced_class_from_function
8W9aG Oct 21, 2024
e3c39c3
Explicitly check return code in test train
8W9aG Oct 21, 2024
5dc23a4
Handle response types in _predict
8W9aG Oct 22, 2024
a0477d8
Make cog train a first class CLI function
8W9aG Oct 22, 2024
c046ff0
Merge branch 'main' into add-waiting-env
8W9aG Oct 22, 2024
94d1ff4
Add back missing imports from merge
8W9aG Oct 22, 2024
187eeea
Merge branch 'main' into add-waiting-env
8W9aG Oct 23, 2024
8ca482f
Merge branch 'main' into add-waiting-env
8W9aG Oct 28, 2024
2840fc7
Remove connection import
8W9aG Oct 28, 2024
200173c
Merge branch 'main' into add-waiting-env
8W9aG Nov 1, 2024
Files changed (changes from 41 commits)
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -118,7 +118,7 @@ jobs:
name: "Test integration"
needs: build-python
runs-on: ubuntu-latest-16-cores
timeout-minutes: 10
timeout-minutes: 20
steps:
- uses: actions/checkout@v4
with:
1 change: 1 addition & 0 deletions .gitignore
@@ -18,3 +18,4 @@ docs/README.md
docs/CONTRIBUTING.md
venv
base-image
flag_file
3 changes: 2 additions & 1 deletion pkg/config/torch_compatibility_matrix.json
@@ -191,7 +191,8 @@
"3.8",
"3.9",
"3.10",
"3.11"
"3.11",
"3.12"
]
},
{
2 changes: 1 addition & 1 deletion python/cog/__init__.py
@@ -1,6 +1,6 @@
from pydantic import BaseModel

from .predictor import BasePredictor
from .base_predictor import BasePredictor
from .types import ConcatenateIterator, File, Input, Path, Secret

try:
32 changes: 32 additions & 0 deletions python/cog/base_input.py
@@ -0,0 +1,32 @@
from pathlib import Path

from pydantic import BaseModel

from .types import (
URLPath,
)


# Base class for inputs, constructed dynamically in get_input_type().
# (This can't be a docstring or it gets passed through to the schema.)
class BaseInput(BaseModel):
class Config:
# When using `choices`, the type is converted into an enum to validate
# But, after validation, we want to pass the actual value to predict(), not the enum object
use_enum_values = True

def cleanup(self) -> None:
"""
Cleanup any temporary files created by the input.
"""
for _, value in self:
# Handle URLPath objects specially for cleanup.
# Also handle pathlib.Path objects, which cog.Path is a subclass of.
# A pathlib.Path object shouldn't make its way here,
# but both have an unlink() method, so we may as well be safe.
if isinstance(value, (URLPath, Path)):
# TODO: use unlink(missing_ok=...) when we drop Python 3.7 support.
try:
value.unlink()
except FileNotFoundError:
pass
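
For illustration only (not part of this diff): a minimal sketch of how a BaseInput subclass behaves, assuming the package layout added above. The field name and values are hypothetical; in practice get_input_type() builds the concrete subclass dynamically from the predict() signature.

from cog.base_input import BaseInput

class ExampleInput(BaseInput):
    prompt: str = "hello"

inp = ExampleInput(prompt="a photo of a cat")
print(inp.prompt)  # "a photo of a cat"
inp.cleanup()      # unlinks any URLPath/Path field values, ignoring missing files
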
26 changes: 26 additions & 0 deletions python/cog/base_predictor.py
@@ -0,0 +1,26 @@
from abc import ABC, abstractmethod
from typing import Any, Optional, Union

from .types import (
File as CogFile,
)
from .types import (
Path as CogPath,
)


class BasePredictor(ABC):
def setup(
self,
weights: Optional[Union[CogFile, CogPath, str]] = None, # pylint: disable=unused-argument
) -> None:
"""
An optional method to prepare the model so multiple predictions run efficiently.
"""
return

@abstractmethod
def predict(self, **kwargs: Any) -> Any:
"""
Run a single prediction on the model
"""
60 changes: 50 additions & 10 deletions python/cog/code_xforms.py
@@ -1,7 +1,7 @@
import ast
import re
import types
from typing import Optional, Set, Union
from typing import List, Optional, Set, Tuple, Union

COG_IMPORT_MODULES = {"cog", "typing", "sys", "os", "functools", "pydantic", "numpy"}

@@ -67,7 +67,9 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:  # pylint: disable=invalid-name
return extractor.function_source if extractor.function_source else ""


def make_class_methods_empty(source_code: Union[str, ast.AST], class_name: str) -> str:
def make_class_methods_empty(
source_code: Union[str, ast.AST], class_name: str, globals: List[ast.Assign]
) -> Tuple[str, List[ast.Assign]]:
"""
Transforms the source code of a specified class to remove the bodies of all its methods
and replace them with 'return None'.
@@ -79,21 +81,42 @@ def make_class_methods_empty(source_code: Union[str, ast.AST], class_name: str)
"""

class MethodBodyTransformer(ast.NodeTransformer):
def __init__(self, globals: List[ast.Assign]) -> None:
self.used_globals = set()
self._targets = {
target.id: global_name
for global_name in globals
for target in global_name.targets
if isinstance(target, ast.Name)
}

def visit_ClassDef(self, node: ast.ClassDef) -> Optional[ast.AST]: # pylint: disable=invalid-name
if node.name == class_name:
for body_item in node.body:
if isinstance(body_item, ast.FunctionDef):
# Replace the body of the method with `return None`
body_item.body = [ast.Return(value=ast.Constant(value=None))]
# Remove decorators from the function
body_item.decorator_list = []
# Determine if one of our globals is referenced by the function.
for default in body_item.args.defaults:
if isinstance(default, ast.Call):
for keyword in default.keywords:
if isinstance(keyword.value, ast.Name):
corresponding_global = self._targets.get(
keyword.value.id
)
if corresponding_global is not None:
self.used_globals.add(corresponding_global)
return node

return None

tree = source_code if isinstance(source_code, ast.AST) else ast.parse(source_code)
transformer = MethodBodyTransformer()
transformer = MethodBodyTransformer(globals)
transformed_tree = transformer.visit(tree)
class_code = ast.unparse(transformed_tree)
return class_code
return class_code, list(transformer.used_globals)


def extract_method_return_type(
@@ -215,6 +238,17 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:  # pylint: disable=invalid-name
return "\n".join(extractor.imports)


def _extract_globals(source_code: Union[str, ast.AST]) -> List[ast.Assign]:
tree = source_code if isinstance(source_code, ast.AST) else ast.parse(source_code)
if isinstance(tree, ast.Module):
return [x for x in tree.body if isinstance(x, ast.Assign)]
return []


def _render_globals(globals: List[ast.Assign]) -> str:
return "\n".join([ast.unparse(x) for x in globals])


def strip_model_source_code(
source_code: str, class_name: str, method_name: str
) -> Optional[str]:
Expand All @@ -234,14 +268,22 @@ def strip_model_source_code(
class_source = (
None if not class_name else extract_class_source(source_code, class_name)
)
globals = _extract_globals(source_code)
if class_source:
class_source = make_class_methods_empty(class_source, class_name)
class_source, globals = make_class_methods_empty(
class_source, class_name, globals
)
return_type = extract_method_return_type(class_source, class_name, method_name)
return_class_source = (
extract_class_source(source_code, return_type) if return_type else ""
)
model_source = (
imports + "\n\n" + return_class_source + "\n\n" + class_source + "\n"
rendered_globals = _render_globals(globals)
model_source = "\n".join(
[
x
for x in [imports, rendered_globals, return_class_source, class_source]
if x
]
)
else:
# use class_name specified in cog.yaml as method_name
@@ -256,7 +298,5 @@
return_class_source = (
extract_class_source(source_code, return_type) if return_type else ""
)
model_source = (
imports + "\n\n" + return_class_source + "\n\n" + function_source + "\n"
)
model_source = "\n".join([imports, return_class_source, function_source])
return model_source
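
For illustration only (not part of this diff): a rough sketch of the stripping pass, using the signature as it appears in this version of code_xforms.py (a single class name and method name; requires Python 3.9+, matching the guard in config.py). The example predictor source is hypothetical.

from cog.code_xforms import strip_model_source_code

SOURCE = '''
import torch  # non-cog import, not carried into the stripped module

from cog import BasePredictor, Input

CHOICES = ["a", "b"]

class Predictor(BasePredictor):
    def setup(self):
        self.model = torch.nn.Identity()

    def predict(self, mode: str = Input(choices=CHOICES)) -> str:
        return mode
'''

slim = strip_model_source_code(SOURCE, "Predictor", "predict")
print(slim)
# Expected shape of the result: the cog import, the CHOICES assignment (kept
# because a predict() default references it), and a Predictor class whose
# method bodies have been replaced with `return None`.
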
5 changes: 2 additions & 3 deletions python/cog/command/openapi_schema.py
@@ -7,8 +7,8 @@
import json
from typing import Any, Dict, List, Union

from ..config import Config
from ..errors import CogError, ConfigDoesNotExist, PredictorNotSet
from ..predictor import load_config
from ..schema import Status
from ..server.http import create_app
from ..suppress_output import suppress_output
@@ -37,8 +37,7 @@ def remove_title_next_to_ref(
schema = {}
try:
with suppress_output():
config = load_config()
app = create_app(config, shutdown_event=None, is_build=True)
app = create_app(cog_config=Config(), shutdown_event=None, is_build=True)
if (
app.state.setup_result
and app.state.setup_result.status == Status.FAILED
144 changes: 144 additions & 0 deletions python/cog/config.py
@@ -0,0 +1,144 @@
import os
import sys
import uuid
from typing import Optional, Tuple, Type

import structlog
import yaml
from pydantic import BaseModel

from .base_input import BaseInput
from .base_predictor import BasePredictor
from .code_xforms import load_module_from_string, strip_model_source_code
from .env_property import env_property
from .errors import ConfigDoesNotExist
from .mode import Mode
from .predictor import (
get_input_type,
get_output_type,
get_predictor,
get_training_input_type,
get_training_output_type,
load_full_predictor_from_file,
)
from .types import CogConfig
from .wait import wait_for_env

COG_YAML_FILE = "cog.yaml"
COG_PREDICTOR_PREDICT_ENV_VAR = "COG_PREDICTOR_PREDICT"
COG_PREDICTOR_TRAIN_ENV_VAR = "COG_PREDICTOR_TRAIN"
COG_PREDICTOR_ENV_VAR = "COG_PREDICTOR"
COG_GPU_ENV_VAR = "COG_GPU"
PREDICT_METHOD_NAME = "predict"
TRAIN_METHOD_NAME = "train"

log = structlog.get_logger("cog.config")


def _method_name_from_mode(mode: Mode) -> str:
if mode == Mode.PREDICT:
return PREDICT_METHOD_NAME
elif mode == Mode.TRAIN:
return TRAIN_METHOD_NAME
raise ValueError(f"Mode {mode} not recognised for method name mapping")


class Config:
"""A class for reading the cog.yaml properties."""

def __init__(self, config: Optional[CogConfig] = None) -> None:
self._config = config

@property
def _cog_config(self) -> CogConfig:
"""
Warning: Do not access this directly outside this class, instead
write an explicit public property and back it by an @env_property
to allow for the possibility of injecting the property you are
trying to read without relying on the underlying file.
"""
config = self._config
if config is None:
wait_for_env(include_imports=False)
config_path = os.path.abspath(COG_YAML_FILE)
try:
with open(config_path, encoding="utf-8") as handle:
config = yaml.safe_load(handle)
except FileNotFoundError as e:
raise ConfigDoesNotExist(
f"Could not find {config_path}",
) from e
self._config = config
return config

@property
@env_property(COG_PREDICTOR_PREDICT_ENV_VAR)
def predictor_predict_ref(self) -> Optional[str]:
"""Find the predictor ref for the predict mode."""
return self._cog_config.get(str(Mode.PREDICT))

@property
@env_property(COG_PREDICTOR_TRAIN_ENV_VAR)
def predictor_train_ref(self) -> Optional[str]:
"""Find the predictor ref for the train mode."""
return self._cog_config.get(str(Mode.TRAIN))

@property
@env_property(COG_GPU_ENV_VAR)
def requires_gpu(self) -> bool:
"""Whether this cog requires the use of a GPU."""
return bool(self._cog_config.get("build", {}).get("gpu", False))

def _predictor_code(
self, module_path: str, class_name: str, method_name: str
) -> Optional[str]:
source_code = os.environ.get(COG_PREDICTOR_ENV_VAR)
if source_code is not None:
return source_code
if sys.version_info >= (3, 9):
wait_for_env(include_imports=False)
with open(module_path, encoding="utf-8") as file:
return strip_model_source_code(file.read(), class_name, method_name)
return None

def _load_predictor_for_types(self, ref: str, method_name: str) -> BasePredictor:
module_path, class_name = ref.split(":", 1)
code = self._predictor_code(module_path, class_name, method_name)
module_name = os.path.basename(module_path).split(".py", 1)[0]
module = None
if code is not None:
try:
module = load_module_from_string(uuid.uuid4().hex, code)
except Exception as e: # pylint: disable=broad-exception-caught
log.info(f"[{module_name}] fast loader failed: {e}")
if module is None:
wait_for_env(include_imports=False)
module = load_full_predictor_from_file(module_path, module_name)
return get_predictor(module, class_name)

def get_predictor_ref(self, mode: Mode) -> str:
"""Find the predictor reference for a given mode."""
predictor_ref = None
if mode == Mode.PREDICT:
predictor_ref = self.predictor_predict_ref
elif mode == Mode.TRAIN:
predictor_ref = self.predictor_train_ref
if predictor_ref is None:
raise ValueError(f"Could not find predictor ref for mode {mode}")
return predictor_ref

def get_predictor_types(
self, mode: Mode
) -> Tuple[Type[BaseInput], Type[BaseModel]]:
"""Find the input and output types of a predictor."""
predictor_ref = self.get_predictor_ref(mode=mode)
predictor = self._load_predictor_for_types(
predictor_ref, _method_name_from_mode(mode=mode)
)
if mode == Mode.PREDICT:
return get_input_type(predictor), get_output_type(predictor)
elif mode == Mode.TRAIN:
return get_training_input_type(predictor), get_training_output_type(
predictor
)
raise ValueError(f"Mode {mode} not found for generating input/output types.")
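
For illustration only (not part of this diff): a rough sketch of driving the new Config class, assuming a cog.yaml is present in the working directory (or the COG_PREDICTOR_* environment variable overrides are set).

from cog.config import Config
from cog.mode import Mode

config = Config()

# Resolved via the @env_property override (COG_PREDICTOR_PREDICT) or from cog.yaml.
predict_ref = config.get_predictor_ref(Mode.PREDICT)

# Loads the predictor (stripped fast path when possible, otherwise a full
# import after wait_for_env) and derives the pydantic input/output types.
input_type, output_type = config.get_predictor_types(Mode.PREDICT)
print(predict_ref, input_type, output_type)
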