diff --git a/src/nemo_run/__init__.py b/src/nemo_run/__init__.py index e106299..d8b0d5e 100644 --- a/src/nemo_run/__init__.py +++ b/src/nemo_run/__init__.py @@ -13,9 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nemo_run import cli +from nemo_run import cli, io from nemo_run.api import autoconvert, dryrun_fn -from nemo_run.config import Config, Partial, Script +from nemo_run.config import Config, Partial, Script, build from nemo_run.core.execution.base import Executor, ExecutorMacros, FaultTolerance, Torchrun from nemo_run.core.execution.local import LocalExecutor from nemo_run.core.execution.skypilot import SkypilotExecutor @@ -31,8 +31,10 @@ __all__ = [ "autoconvert", + "build", "cli", "dryrun_fn", + "io", "Config", "DevSpace", "Executor", diff --git a/src/nemo_run/config.py b/src/nemo_run/config.py index 3ce7ad7..25f4a30 100644 --- a/src/nemo_run/config.py +++ b/src/nemo_run/config.py @@ -36,15 +36,18 @@ from typing_extensions import Annotated, ParamSpec, Self import nemo_run.exceptions as run_exceptions +from nemo_run.io.api import _IO_REGISTRY, get, register Params = ParamSpec("Params") ReturnType = TypeVar("ReturnType") _T = TypeVar("_T") _BuildableT = TypeVar("_BuildableT", bound=fdl.Buildable) +build = fdl.build RECURSIVE_TYPES = (typing.Union, typing.Optional) NEMORUN_HOME = os.environ.get("NEMORUN_HOME", os.path.expanduser("~/.nemo_run")) +USE_IO_REGISTRY: bool = True def get_type_namespace(typ: Type | Callable) -> str: @@ -235,10 +238,29 @@ def _repr_svg_(self): return self.__repr__() +# List of classes that require direct initialization +DIRECT_INIT_CLASSES = [Path] + + class Config(Generic[_T], fdl.Config[_T], _CloneAndFNMixin, _VisualizeMixin): """ Wrapper around fdl.Config with nemo_run specific functionality. See `fdl.Config `_ for more. + + This class extends fdl.Config to provide special handling for certain types of objects, + particularly those that require direct initialization (e.g., pathlib.Path). + + The DIRECT_INIT_CLASSES list contains classes that should be instantiated directly, + bypassing Fiddle's normal build process. By default, this includes pathlib.Path. + + To add more classes for direct initialization, simply append them to DIRECT_INIT_CLASSES. + + Example: + >>> from pathlib import Path + >>> path_config = Config(Path, "/tmp/test") + >>> path_instance = build(path_config) + >>> isinstance(path_instance, Path) + True """ def __init__( @@ -257,6 +279,34 @@ def __init__( super().__init__(fn_or_cls, *args, **new_kwargs) + def __build__(self, *args, **kwargs): + """ + Build the instance, with special handling for classes in DIRECT_INIT_CLASSES. + + This method checks if the class to be instantiated is in DIRECT_INIT_CLASSES. + If so, it directly instantiates the class instead of using Fiddle's build process. + This is particularly useful for classes like pathlib.Path that rely on __new__ + for instantiation. + + Args: + *args: Positional arguments for instantiation. + **kwargs: Keyword arguments for instantiation. + + Returns: + The instantiated object. + """ + cls = self.__fn_or_cls__ + if cls in DIRECT_INIT_CLASSES: + # Direct initialization for classes in the list + instance = cls(*args, **kwargs) + else: + instance = super().__build__(*args, **kwargs) + + if USE_IO_REGISTRY: + register(instance, copy.deepcopy(self)) + + return instance + class Partial(Generic[_T], fdl.Partial[_T], _CloneAndFNMixin, _VisualizeMixin): """ @@ -280,6 +330,13 @@ def __init__( super().__init__(fn_or_cls, *args, **new_kwargs) + def __build__(self, *args, **kwargs): + instance = super().__build__(*args, **kwargs) + if USE_IO_REGISTRY: + register(instance, copy.deepcopy(self)) + + return instance + register_supported_cast(fdl.Config, Config) register_supported_cast(fdl.Partial, Partial) @@ -415,6 +472,8 @@ def _construct_args( Config, fdl_dc.convert_dataclasses_to_configs(arg, allow_post_init=True), ) + elif arg in _IO_REGISTRY: + final_args[name] = get(arg) else: final_args[name] = arg elif str(parameter.annotation).startswith("typing.Annotated"): diff --git a/src/nemo_run/io/__init__.py b/src/nemo_run/io/__init__.py new file mode 100644 index 0000000..c6752a1 --- /dev/null +++ b/src/nemo_run/io/__init__.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_run.io.api import capture, get, register, reinit + +__all__ = ["capture", "get", "register", "reinit"] diff --git a/src/nemo_run/io/api.py b/src/nemo_run/io/api.py new file mode 100644 index 0000000..e5d1ba2 --- /dev/null +++ b/src/nemo_run/io/api.py @@ -0,0 +1,208 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses as dc +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, Optional, Set, Type, TypeVar, Union, overload + +import fiddle as fdl +import fiddle._src.experimental.dataclasses as fdl_dc + +from nemo_run.io.capture import _CaptureContext +from nemo_run.io.registry import _ConfigRegistry + +if TYPE_CHECKING: + from nemo_run.config import Config + +_T = TypeVar("_T") +_IO_REGISTRY = _ConfigRegistry() + + +class capture: + """ + A decorator and context manager for capturing object configurations. + + This class provides functionality to automatically capture and register configurations + of objects created within its scope. It can be used as a decorator on functions or as + a context manager. + + Args: + cls_to_ignore (Optional[Set[Type]]): A set of classes to ignore during capture. + + Examples: + As a decorator: + >>> @capture() + ... def create_object(): + ... return SomeClass(42) + >>> obj = create_object() + >>> cfg: run.Config[SomeClass] = get(obj) # Configuration is automatically captured + + As a context manager: + >>> with capture(): + ... obj = SomeClass(42) + >>> cfg: run.Config[SomeClass] = get(obj) # Configuration is automatically captured + + With classes to ignore: + >>> @capture(cls_to_ignore={IgnoredClass}) + ... def create_objects(): + ... obj1 = SomeClass(1) + ... obj2 = IgnoredClass(2) + ... return obj1, obj2 + >>> obj1, obj2 = create_objects() + >>> cfg1: run.Config[SomeClass] = get(obj1) # Works + >>> cfg2: run.Config[IgnoredClass] = get(obj2) # Raises ObjectNotFoundError + + Notes: + - Nested captures are supported. + - Exceptions within the capture scope do not prevent object registration. + - Dataclasses are automatically converted to configs without registration. + - Complex arguments (lists, dicts, callables) are supported in captured configs. + - Unsupported types may raise ValueError during capture. + """ + + def __init__(self, cls_to_ignore: Optional[Set[Type]] = None): + self.cls_to_ignore = cls_to_ignore + self._context: Optional[_CaptureContext] = None + + @overload + def __call__(self, func: Callable[..., _T]) -> Callable[..., _T]: ... + + @overload + def __call__(self) -> "capture": ... + + def __call__( + self, func: Optional[Callable[..., _T]] = None + ) -> Union[Callable[..., _T], "capture"]: + """ + Allows the capture class to be used as a decorator. + + If called without arguments, returns the capture instance for use as a context manager. + If called with a function argument, returns a wrapped version of the function that + executes within a capture context. + + Args: + func (Optional[Callable[..., _T]]): The function to be wrapped. + + Returns: + Union[Callable[..., _T], "capture"]: Either the wrapped function or the capture instance. + """ + if func is None: + return self + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> _T: + with self: + return func(*args, **kwargs) + + return wrapper + + def __enter__(self) -> None: + """ + Enters the capture context. + + This method is called when entering a `with` block or at the start of a decorated function. + It sets up the capture context for automatic configuration registration. + + Returns: + None + """ + self._context = _CaptureContext(get, register, self.cls_to_ignore) + return self._context.__enter__() + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[Any], + ) -> Optional[bool]: + """ + Exits the capture context. + + This method is called when exiting a `with` block or at the end of a decorated function. + It ensures that the capture context is properly closed, even if an exception occurred. + + Args: + exc_type (Optional[Type[BaseException]]): The type of the exception that occurred, if any. + exc_value (Optional[BaseException]): The exception instance that occurred, if any. + traceback (Optional[Any]): The traceback object for the exception, if any. + + Returns: + Optional[bool]: Returns the result of the context's __exit__ method, if applicable. + """ + if self._context: + return self._context.__exit__(exc_type, exc_value, traceback) + return None + + +def register(instance: _T, cfg: "Config[_T]") -> None: + """ + Registers a configuration for a given instance in the global registry. + + Args: + instance (_T): The instance to associate with the configuration. + cfg (Config[_T]): The configuration object to register. + + Returns: + None + + Example: + >>> cfg = SomeConfig() + >>> instance = SomeClass() + >>> register(instance, cfg) + """ + if dc.is_dataclass(instance): + return + + _IO_REGISTRY.register(instance, cfg) + + +def get(obj: _T) -> "Config[_T]": + """ + Retrieves the configuration for a given object from the global registry. + + Args: + obj (_T): The object to retrieve the configuration for. + + Returns: + Config[_T]: The configuration associated with the object. + + Raises: + ObjectNotFoundError: If no configuration is found for the given object. + + Example: + >>> instance = SomeClass() + >>> cfg = get(instance) + """ + if dc.is_dataclass(obj): + return fdl_dc.convert_dataclasses_to_configs(obj, allow_post_init=True) + return _IO_REGISTRY.get(obj) + + +def reinit(obj: _T) -> _T: + """ + Reinitializes an object using its stored configuration. + + Args: + obj (_T): The object to reinitialize. + + Returns: + _T: A new instance of the object created from its configuration. + + Example: + >>> import nemo_sdk as sdk + >>> instance = sdk.build(sdk.Config(SomeClass, a=1, b=2)) + >>> new_instance = reinit(instance) + """ + return fdl.build(get(obj)) diff --git a/src/nemo_run/io/capture.py b/src/nemo_run/io/capture.py new file mode 100644 index 0000000..5833615 --- /dev/null +++ b/src/nemo_run/io/capture.py @@ -0,0 +1,207 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from pathlib import Path +from types import FrameType +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type + +from fiddle._src.config import ordered_arguments + + +def process_args( + args: tuple[Any, ...], + kwargs: dict[str, Any], + func: Callable, + get_fn: Callable[[Any], Any], +) -> Dict[str, Any]: + """ + Process both positional and keyword arguments for a given function. + + This function handles the processing of all arguments passed to a function, + ensuring that each argument is properly processed using the provided get_fn. + + Args: + args (tuple[Any, ...]): Positional arguments. + kwargs (dict[str, Any]): Keyword arguments. + func (Callable): The function for which arguments are being processed. + get_fn (Callable[[Any], Any]): Function to process individual arguments. + + Returns: + Dict[str, Any]: A dictionary containing all processed arguments. + """ + # Process positional arguments + processed_args = [process_single_arg(arg, get_fn) for arg in args] + + # Process keyword arguments + processed_kwargs = {k: process_single_arg(v, get_fn) for k, v in kwargs.items()} + + # Combine processed positional and keyword arguments + result = dict(enumerate(processed_args)) + result.update(processed_kwargs) + return result + + +def process_single_arg(v: Any, get_fn: Callable[[Any], Any]) -> Any: + """ + Process a single argument, handling various data types. + + This function recursively processes complex data structures and applies + special handling for certain types like Path objects and callables. + + Args: + v (Any): The argument to process. + get_fn (Callable[[Any], Any]): Function to process non-primitive types. + + Returns: + Any: The processed argument. + """ + from nemo_run.config import Config + + if isinstance(v, (str, int, float, bool, type(None))): + return v + elif isinstance(v, Path): + return Config(Path, str(v)) + elif isinstance(v, (list, tuple)): + return [process_single_arg(item, get_fn) for item in v] + elif isinstance(v, dict): + return {key: process_single_arg(value, get_fn) for key, value in v.items()} + elif ( + callable(v) + or isinstance(v, type) + or (isinstance(v, set) and all(isinstance(item, type) for item in v)) + ): + return v + else: + try: + return get_fn(v) + except Exception: + return v # If we can't process it, return the original value + + +def wrap_init(frame: FrameType, capture_context: "_CaptureContext"): + """ + Wrap the __init__ method of a class to capture its arguments. + + This function is called when an object is instantiated within a capture context. + It processes the arguments passed to the __init__ method and creates a Config + object representing the instantiated class. + + Args: + frame (FrameType): The current stack frame. + capture_context (_CaptureContext): The current capture context. + """ + cls = frame.f_locals.get("self").__class__ + if cls not in capture_context.cls_to_ignore: + # Capture arguments for the current class + args = frame.f_locals.copy() + del args["self"] + if "__class__" in args: + del args["__class__"] # Remove __class__ attribute + capture_context.arg_stack.append((cls, args)) + + # If we've reached the top of the inheritance chain, create the Config + if len(capture_context.arg_stack) == len(cls.__mro__) - 1: # -1 to exclude 'object' + from nemo_run.config import Config + + combined_args = {} + for captured_cls, captured_args in reversed(capture_context.arg_stack): + combined_args.update(captured_args) + + # Use ordered_arguments to get all arguments, including defaults + cfg = Config(cls) + all_args = ordered_arguments(cfg, include_defaults=True) + + # Update all_args with the actually provided arguments + all_args.update(combined_args) + + # Process all arguments before creating the final Config + processed_args = { + name: process_single_arg(value, capture_context.get) + for name, value in all_args.items() + } + + # Create the Config with all processed arguments + cfg = Config(cls, **processed_args) + + if capture_context.register: + capture_context.register(frame.f_locals.get("self"), cfg) + + capture_context.arg_stack.clear() + + +class _CaptureContext: + """ + A context manager for capturing object configurations during instantiation. + + This class sets up a profiling function to intercept object instantiations + and capture their configurations. It's used internally by the `capture` decorator. + + Attributes: + get (Callable): Function to retrieve configurations. + register (Callable): Function to register captured configurations. + cls_to_ignore (Set[Type]): Set of classes to ignore during capture. + old_profile (Optional[Callable]): The previous profiling function. + arg_stack (List[Tuple[Type, Dict[str, Any]]]): Stack to store captured arguments. + """ + + def __init__( + self, get_fn: Callable, register_fn: Callable, cls_to_ignore: Optional[Set[Type]] = None + ): + """ + Initialize the _CaptureContext. + + Args: + get_fn (Callable): Function to retrieve configurations. + register_fn (Callable): Function to register captured configurations. + cls_to_ignore (Optional[Set[Type]]): Set of classes to ignore during capture. + """ + self.get = get_fn + self.register = register_fn + self.cls_to_ignore = cls_to_ignore or set() + self.old_profile = None + self.arg_stack: List[Tuple[Type, Dict[str, Any]]] = [] + + def __enter__(self): + """ + Enter the capture context, setting up the profiling function. + """ + self.old_profile = sys.getprofile() + sys.setprofile(self._profile_func) + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Exit the capture context, restoring the previous profiling function. + """ + sys.setprofile(self.old_profile) + + def _profile_func(self, frame: FrameType, event: str, arg: Any): + """ + Profiling function that intercepts object instantiations. + + This function is called for every function call while the context is active. + It specifically looks for __init__ calls to capture object configurations. + + Args: + frame (FrameType): The current stack frame. + event (str): The type of event (e.g., 'call', 'return'). + arg (Any): Event-specific argument. + + Returns: + Optional[Callable]: The previous profiling function, if any. + """ + if event == "call" and frame.f_code.co_name == "__init__": + wrap_init(frame, self) + return self.old_profile diff --git a/src/nemo_run/io/registry.py b/src/nemo_run/io/registry.py new file mode 100644 index 0000000..aa0577d --- /dev/null +++ b/src/nemo_run/io/registry.py @@ -0,0 +1,180 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import weakref +from typing import TYPE_CHECKING, Any, Dict, TypeVar + +if TYPE_CHECKING: + from nemo_run.config import Config + +_T = TypeVar("_T") + + +class _ConfigRegistry: + """ + A registry for storing and retrieving configuration objects. + + This class uses weak references to track object instances and a regular dictionary to store + configurations, automatically removing entries when instances are garbage collected. + + Attributes: + _objects (Dict[int, Config]): A dictionary to store configurations. + _ref_map (weakref.WeakKeyDictionary): A weak key dictionary to map instances to their IDs. + """ + + def __init__(self): + """Initializes the ConfigRegistry with empty dictionaries.""" + self._objects: Dict[int, "Config[_T]"] = {} + self._ref_map = weakref.WeakKeyDictionary() + self._strong_ref_map: Dict[Any, int] = {} # New dictionary for non-weakref objects + + def register(self, instance: _T, cfg: "Config[_T]") -> None: + """ + Registers a configuration for a given instance. + + Args: + instance (_T): The instance to associate with the configuration. + cfg (Config[_T]): The configuration object to register. + + Returns: + None + + Example: + >>> registry = ConfigRegistry() + >>> cfg = SomeConfig() + >>> instance = SomeClass() + >>> registry.register(instance, cfg) + """ + obj_id = id(instance) + self._objects[obj_id] = cfg + if self._is_weakref_able(instance): + self._ref_map[instance] = obj_id + else: + self._strong_ref_map[instance] = obj_id + + def _is_weakref_able(self, obj: Any) -> bool: + try: + weakref.ref(obj) + return True + except TypeError: + return False + + def get(self, obj: _T) -> "Config[_T]": + """ + Retrieves the configuration for a given object. + + Args: + obj (_T): The object to retrieve the configuration for. + + Returns: + Config[_T]: The configuration associated with the object. + + Raises: + ObjectNotFoundError: If no configuration is found for the given object. + + Example: + >>> registry = ConfigRegistry() + >>> instance = SomeClass() + >>> cfg = registry.get(instance) + """ + obj_id = id(obj) + cfg = self._objects.get(obj_id) + if cfg is None: + raise ObjectNotFoundError( + f"No configuration found for {obj} " + f"with id {obj_id}. Total configs in registry: {len(self._objects)}." + ) + return cfg + + def get_by_id(self, obj_id: int) -> "Config[_T]": + """ + Retrieves the configuration for a given object id. + + Args: + obj_id (int): The id of the object to retrieve the configuration for. + + Returns: + Config[_T]: The configuration associated with the object id. + + Raises: + ObjectNotFoundError: If no configuration is found for the given object id. + + Example: + >>> registry = ConfigRegistry() + >>> obj = SomeClass() + >>> obj_id = id(obj) + >>> cfg = registry.get_by_id(obj_id) + """ + cfg = self._objects.get(obj_id) + if cfg is None: + raise ObjectNotFoundError(f"No config found for id {obj_id}") + return cfg + + def __len__(self): + """ + Returns the number of configurations stored in the registry. + + Returns: + int: The number of configurations in the registry. + + Example: + >>> registry = ConfigRegistry() + >>> len(registry) + 0 + """ + return len(self._objects) + + def __contains__(self, obj) -> bool: + """ + Checks if a configuration for the given object exists in the registry. + + Args: + obj (_T): The object to check for. + + Returns: + bool: True if a configuration for the object exists, False otherwise. + + Example: + >>> registry = _ConfigRegistry() + >>> instance = SomeClass() + >>> registry.register(instance, SomeConfig()) + >>> instance in registry + True + >>> other_instance = SomeClass() + >>> other_instance in registry + False + """ + return id(obj) in self._objects + + def cleanup(self): + """ + Removes configurations for instances that have been garbage collected. + + This method should be called periodically to clean up the registry. + + Example: + >>> registry = ConfigRegistry() + >>> registry.cleanup() + """ + active_ids = set(self._ref_map.values()) | set(self._strong_ref_map.values()) + to_remove = set(self._objects.keys()) - active_ids + for obj_id in to_remove: + del self._objects[obj_id] + + +class ObjectNotFoundError(Exception): + """Custom exception for when an object is not found in the registry.""" + + pass diff --git a/test/core/packaging/test_base.py b/test/core/packaging/test_base.py index 6f167d0..366a704 100644 --- a/test/core/packaging/test_base.py +++ b/test/core/packaging/test_base.py @@ -14,9 +14,9 @@ # limitations under the License. import pytest -from nemo_run.config import Config -from src.nemo_run.core.packaging.base import Packager +from nemo_run.config import Config +from nemo_run.core.packaging.base import Packager @pytest.fixture diff --git a/test/io/__init__.py b/test/io/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/io/test_api.py b/test/io/test_api.py new file mode 100644 index 0000000..9da0fd2 --- /dev/null +++ b/test/io/test_api.py @@ -0,0 +1,383 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +from pathlib import Path + +import fiddle as fdl +import pytest + +import nemo_run as run +from nemo_run.io.registry import ObjectNotFoundError, _ConfigRegistry + + +class TestCapture: + class DummyClass: + def __init__(self, value): + self.value = value + + def test_capture_as_decorator(self): + @run.io.capture() + def create_object(): + return self.DummyClass(42) + + obj = create_object() + assert isinstance(obj, self.DummyClass) + assert obj.value == 42 + + cfg = run.io.get(obj) + assert isinstance(cfg, run.Config) + assert cfg.value == 42 + + def test_capture_as_context_manager(self): + with run.io.capture(): + obj = self.DummyClass(42) + + assert isinstance(obj, self.DummyClass) + assert obj.value == 42 + + cfg = run.io.get(obj) + assert isinstance(cfg, run.Config) + assert cfg.value == 42 + + def test_capture_with_cls_to_ignore(self): + class IgnoredClass: + def __init__(self, value): + self.value = value + + with run.io.capture(cls_to_ignore={IgnoredClass}): + obj1 = self.DummyClass(1) + obj2 = IgnoredClass(2) + + assert isinstance(run.io.get(obj1), run.Config) + with pytest.raises(ObjectNotFoundError): + run.io.get(obj2) + + def test_capture_as_decorator_with_cls_to_ignore(self): + class IgnoredClass: + def __init__(self, value): + self.value = value + + @run.io.capture(cls_to_ignore={IgnoredClass}) + def create_objects(): + obj1 = self.DummyClass(1) + obj2 = IgnoredClass(2) + return obj1, obj2 + + obj1, obj2 = create_objects() + + assert isinstance(run.io.get(obj1), run.Config) + with pytest.raises(ObjectNotFoundError): + run.io.get(obj2) + + def test_nested_capture(self): + with run.io.capture(): + obj1 = self.DummyClass(1) + with run.io.capture(): + obj2 = self.DummyClass(2) + + assert isinstance(run.io.get(obj1), run.Config) + assert isinstance(run.io.get(obj2), run.Config) + assert run.io.get(obj1).value == 1 + assert run.io.get(obj2).value == 2 + + def test_capture_exception_handling(self): + class TestException(Exception): + pass + + with pytest.raises(TestException): + with run.io.capture(): + obj = self.DummyClass(42) + raise TestException("Test exception") + + # The object should still be captured despite the exception + assert isinstance(run.io.get(obj), run.Config) + assert run.io.get(obj).value == 42 + + def test_capture_nested_objects(self): + class NestedClass: + def __init__(self, value): + self.value = value + + class OuterClass: + def __init__(self, nested): + self.nested = nested + + with run.io.capture(): + nested = NestedClass(42) + outer = OuterClass(nested) + + assert isinstance(run.io.get(outer), run.Config) + assert isinstance(run.io.get(outer).nested, run.Config) + assert run.io.get(outer).nested.value == 42 + + def test_capture_complex_arguments(self): + class ComplexClass: + def __init__(self, list_arg, dict_arg): + self.list_arg = list_arg + self.dict_arg = dict_arg + + with run.io.capture(): + obj = ComplexClass([1, 2, 3], {"a": 1, "b": 2}) + + cfg = run.io.get(obj) + assert isinstance(cfg, run.Config) + assert cfg.list_arg == [1, 2, 3] + assert cfg.dict_arg == {"a": 1, "b": 2} + + def test_capture_callable_arguments(self): + def dummy_func(): + pass + + class CallableClass: + def __init__(self, func): + self.func = func + + with run.io.capture(): + obj = CallableClass(dummy_func) + + cfg = run.io.get(obj) + assert isinstance(cfg, run.Config) + assert cfg.func == dummy_func + + def test_capture_path_arguments(self): + class PathClass: + def __init__(self, path): + self.path = path + + with run.io.capture(): + obj = PathClass(Path("/tmp/test")) + + cfg = run.io.get(obj) + assert isinstance(cfg, run.Config) + assert isinstance(cfg.path, run.Config) + assert str(fdl.build(cfg).path) == "/tmp/test" + + def test_capture_multiple_objects(self): + class ClassA: + def __init__(self, value): + self.value = value + + class ClassB: + def __init__(self, value): + self.value = value + + with run.io.capture(): + obj_a = ClassA(1) + obj_b = ClassB("test") + + assert isinstance(run.io.get(obj_a), run.Config) + assert isinstance(run.io.get(obj_b), run.Config) + assert run.io.get(obj_a).value == 1 + assert run.io.get(obj_b).value == "test" + + def test_capture_with_inheritance(self): + class BaseClass: + def __init__(self, base_value): + self.base_value = base_value + + class DerivedClass(BaseClass): + def __init__(self, base_value, derived_value): + super().__init__(base_value) + self.derived_value = derived_value + + with run.io.capture(): + obj = DerivedClass(1, "test") + + cfg = run.io.get(obj) + assert isinstance(cfg, run.Config) + assert cfg.base_value == 1 + assert cfg.derived_value == "test" + + def test_capture_with_default_arguments(self): + class DefaultArgClass: + def __init__(self, arg1, arg2="default"): + self.arg1 = arg1 + self.arg2 = arg2 + + with run.io.capture(): + obj1 = DefaultArgClass(1) + obj2 = DefaultArgClass(2, "custom") + + cfg1 = run.io.get(obj1) + cfg2 = run.io.get(obj2) + + assert cfg1.arg1 == 1 + assert cfg1.arg2 == "default" + assert cfg2.arg1 == 2 + assert cfg2.arg2 == "custom" + + def test_capture_exception_handling_with_object_persistence(self): + class TestException(Exception): + pass + + with pytest.raises(TestException): + with run.io.capture(): + obj = self.DummyClass(42) + raise TestException("Test exception") + + # The object should still be captured despite the exception + assert isinstance(run.io.get(obj), run.Config) + assert run.io.get(obj).value == 42 + + +class TestReinit: + def test_simple(self): + class DummyClass: + def __init__(self, value): + self.value = value + + cfg = run.Config(DummyClass, value=42) + instance = run.build(cfg) + + new_instance = run.io.reinit(instance) + assert isinstance(new_instance, DummyClass) + assert new_instance.value == 42 + assert new_instance is not instance + + def test_reinit_not_registered(self): + class DummyClass: + pass + + instance = DummyClass() + + with pytest.raises(ObjectNotFoundError): + run.io.reinit(instance) + + def test_reinit_dataclass(self): + """Test reinitializing a dataclass instance.""" + + @dataclasses.dataclass + class DummyDataClass: + value: int + name: str + + instance = DummyDataClass(value=42, name="test") + new_instance = run.io.reinit(instance) + + assert isinstance(new_instance, DummyDataClass) + assert new_instance.value == 42 + assert new_instance.name == "test" + assert new_instance is not instance + + +class TestIOCleanup: + @pytest.fixture + def registry(self): + return _ConfigRegistry() + + def test_cleanup_removes_garbage_collected_objects(self, registry): + class DummyObject: + pass + + obj1 = DummyObject() + obj2 = DummyObject() + cfg1 = run.Config(DummyObject) + cfg2 = run.Config(DummyObject) + + obj1_id = id(obj1) # Store the id before deleting the object + registry.register(obj1, cfg1) + registry.register(obj2, cfg2) + + assert len(registry) == 2 + + del obj1 # Make obj1 eligible for garbage collection + registry.cleanup() + + assert len(registry) == 1 + with pytest.raises(ObjectNotFoundError): + registry.get_by_id(obj1_id) # Use a new method to get by id + assert registry.get(obj2) == cfg2 + + def test_cleanup_keeps_live_objects(self, registry): + class DummyObject: + pass + + obj = DummyObject() + cfg = run.Config(DummyObject) + + registry.register(obj, cfg) + registry.cleanup() + + assert len(registry) == 1 + assert registry.get(obj) == cfg + + def test_cleanup_with_empty_registry(self, registry): + registry.cleanup() + assert len(registry) == 0 + + def test_cleanup_multiple_times(self, registry): + class DummyObject: + pass + + obj1 = DummyObject() + obj2 = DummyObject() + cfg1 = run.Config(DummyObject) + cfg2 = run.Config(DummyObject) + + registry.register(obj1, cfg1) + registry.register(obj2, cfg2) + + assert len(registry) == 2 + + del obj1 # Make obj1 eligible for garbage collection + registry.cleanup() + assert len(registry) == 1 + + registry.cleanup() # Second cleanup should not change anything + assert len(registry) == 1 + + del obj2 # Make obj2 eligible for garbage collection + registry.cleanup() + assert len(registry) == 0 + + def test_cleanup_after_reregistration(self, registry): + class DummyObject: + pass + + obj = DummyObject() + cfg1 = run.Config(DummyObject) + cfg2 = run.Config(DummyObject) + + registry.register(obj, cfg1) + registry.register(obj, cfg2) # Re-register with a new config + + registry.cleanup() + + assert len(registry) == 1 + assert registry.get(obj) == cfg2 + + def test_cleanup_stress_test(self, registry): + class DummyObject: + pass + + objects = [] + for _ in range(10000): + obj = DummyObject() + objects.append(obj) + registry.register(obj, run.Config(DummyObject)) + + assert len(registry) == 10000 + + # Delete all objects + del objects + + # Force garbage collection + import gc + + gc.collect() + + registry.cleanup() + assert len(registry) == 1