Skip to content

Commit

Permalink
Renamed default_storage_type to default_storage_field.
Browse files Browse the repository at this point in the history
  • Loading branch information
S-Dafarra committed Mar 13, 2023
1 parent 09ecc8e commit 67c68b9
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/hippopt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
OptimizationObject,
StorageType,
TOptimizationObject,
default_storage_type,
default_storage_field,
)
from .base.parameter import Parameter, TParameter
from .base.variable import TVariable, Variable
4 changes: 2 additions & 2 deletions src/hippopt/base/optimization_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def get_default_initialized_object(
return output


def default_storage_type(input_type: Type[OptimizationObject]):
def default_storage_field(cls: Type[OptimizationObject]):
return dataclasses.field(
default=None,
metadata=input_type.StorageTypeMetadata,
metadata=cls.StorageTypeMetadata,
)
12 changes: 6 additions & 6 deletions test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,21 @@
StorageType,
TOptimizationObject,
Variable,
default_storage_type,
default_storage_field,
)


@dataclasses.dataclass
class TestVariable(OptimizationObject):
storage: StorageType = default_storage_type(Variable)
storage: StorageType = default_storage_field(cls=Variable)

def __post_init__(self):
self.storage = np.ones(shape=3)


@dataclasses.dataclass
class TestParameter(OptimizationObject):
storage: StorageType = default_storage_type(Parameter)
storage: StorageType = default_storage_field(cls=Parameter)

def __post_init__(self):
self.storage = np.ones(shape=3)
Expand All @@ -44,8 +44,8 @@ def test_zero_parameter():

@dataclasses.dataclass
class CustomInitializationVariable(OptimizationObject):
variable: StorageType = default_storage_type(Variable)
parameter: StorageType = default_storage_type(Parameter)
variable: StorageType = default_storage_field(cls=Variable)
parameter: StorageType = default_storage_field(cls=Parameter)

def __post_init__(self):
self.variable = np.ones(shape=3)
Expand All @@ -72,7 +72,7 @@ def test_custom_initialization():
@dataclasses.dataclass
class AggregateClass(OptimizationObject):
aggregated: CustomInitializationVariable
other_parameter: StorageType = default_storage_type(Parameter)
other_parameter: StorageType = default_storage_field(cls=Parameter)
other: str = ""

def __post_init__(self):
Expand Down

0 comments on commit 67c68b9

Please sign in to comment.