diff --git a/src/hippopt/__init__.py b/src/hippopt/__init__.py index db08a86a..927b4b4f 100644 --- a/src/hippopt/__init__.py +++ b/src/hippopt/__init__.py @@ -3,7 +3,7 @@ OptimizationObject, StorageType, TOptimizationObject, - default_storage_type, + default_storage_field, ) from .base.parameter import Parameter, TParameter from .base.variable import TVariable, Variable diff --git a/src/hippopt/base/optimization_object.py b/src/hippopt/base/optimization_object.py index ea410ecd..4d75da95 100644 --- a/src/hippopt/base/optimization_object.py +++ b/src/hippopt/base/optimization_object.py @@ -52,8 +52,8 @@ def get_default_initialized_object( return output -def default_storage_type(input_type: Type[OptimizationObject]): +def default_storage_field(cls: Type[OptimizationObject]): return dataclasses.field( default=None, - metadata=input_type.StorageTypeMetadata, + metadata=cls.StorageTypeMetadata, ) diff --git a/test/test_base.py b/test/test_base.py index 9b390c06..8c7dc394 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -8,13 +8,13 @@ StorageType, TOptimizationObject, Variable, - default_storage_type, + default_storage_field, ) @dataclasses.dataclass class TestVariable(OptimizationObject): - storage: StorageType = default_storage_type(Variable) + storage: StorageType = default_storage_field(cls=Variable) def __post_init__(self): self.storage = np.ones(shape=3) @@ -22,7 +22,7 @@ def __post_init__(self): @dataclasses.dataclass class TestParameter(OptimizationObject): - storage: StorageType = default_storage_type(Parameter) + storage: StorageType = default_storage_field(cls=Parameter) def __post_init__(self): self.storage = np.ones(shape=3) @@ -44,8 +44,8 @@ def test_zero_parameter(): @dataclasses.dataclass class CustomInitializationVariable(OptimizationObject): - variable: StorageType = default_storage_type(Variable) - parameter: StorageType = default_storage_type(Parameter) + variable: StorageType = default_storage_field(cls=Variable) + parameter: StorageType = default_storage_field(cls=Parameter) def __post_init__(self): self.variable = np.ones(shape=3) @@ -72,7 +72,7 @@ def test_custom_initialization(): @dataclasses.dataclass class AggregateClass(OptimizationObject): aggregated: CustomInitializationVariable - other_parameter: StorageType = default_storage_type(Parameter) + other_parameter: StorageType = default_storage_field(cls=Parameter) other: str = "" def __post_init__(self):