Skip to content

Commit

Permalink
Using post_init in test_base.
Browse files Browse the repository at this point in the history
  • Loading branch information
S-Dafarra committed Mar 9, 2023
1 parent 0ddcd4d commit cc78003
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,27 @@
class TestVariable(OptimizationObject):
storage: StorageType = default_storage_type(Variable)

def __post_init__(self):
self.storage = np.ones(shape=3)


@dataclasses.dataclass
class TestParameter(OptimizationObject):
storage: StorageType = default_storage_type(Parameter)

def __post_init__(self):
self.storage = np.ones(shape=3)


def test_zero_variable():
test_var = TestVariable()
test_var.storage = np.ones(shape=3)
test_var_zero = test_var.get_default_initialized_object()
assert test_var_zero.storage.shape == (3,)
assert np.all(test_var_zero.storage == 0)


def test_zero_parameter():
test_par = TestParameter()
test_par.storage = np.ones(shape=3)
test_par_zero = test_par.get_default_initialized_object()
assert test_par_zero.storage.shape == (3,)
assert np.all(test_par_zero.storage == 0)
Expand All @@ -43,6 +47,10 @@ class CustomInitializationVariable(OptimizationObject):
variable: StorageType = default_storage_type(Variable)
parameter: StorageType = default_storage_type(Parameter)

def __post_init__(self):
self.variable = np.ones(shape=3)
self.parameter = np.ones(shape=3)

def get_default_initialization(
self: TOptimizationObject, field_name: str
) -> np.ndarray:
Expand All @@ -54,8 +62,6 @@ def get_default_initialization(

def test_custom_initialization():
test_var = CustomInitializationVariable()
test_var.variable = np.ones(3)
test_var.parameter = np.ones(3)
test_var_init = test_var.get_default_initialized_object()
assert test_var_init.parameter.shape == (3,)
assert np.all(test_var_init.parameter == 0)
Expand All @@ -69,13 +75,14 @@ class AggregateClass(OptimizationObject):
other_parameter: StorageType = default_storage_type(Parameter)
other: str = ""

def __post_init__(self):
self.aggregated = CustomInitializationVariable()
self.other_parameter = np.ones(3)
self.other = "untouched"


def test_aggregated():
test_var = AggregateClass(aggregated=CustomInitializationVariable())
test_var.aggregated.variable = np.ones(3)
test_var.aggregated.parameter = np.ones(3)
test_var.other_parameter = np.ones(3)
test_var.other = "untouched"
test_var_init = test_var.get_default_initialized_object()
assert test_var_init.aggregated.parameter.shape == (3,)
assert np.all(test_var_init.aggregated.parameter == 0)
Expand Down

0 comments on commit cc78003

Please sign in to comment.