From cc780037ef274e9e8d95f28218eb2d5a1221a697 Mon Sep 17 00:00:00 2001 From: Stefano Date: Thu, 9 Mar 2023 19:35:12 +0100 Subject: [PATCH] Using post_init in test_base. --- test/test_base.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/test/test_base.py b/test/test_base.py index 8e676be8..9b390c06 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -16,15 +16,20 @@ class TestVariable(OptimizationObject): storage: StorageType = default_storage_type(Variable) + def __post_init__(self): + self.storage = np.ones(shape=3) + @dataclasses.dataclass class TestParameter(OptimizationObject): storage: StorageType = default_storage_type(Parameter) + def __post_init__(self): + self.storage = np.ones(shape=3) + def test_zero_variable(): test_var = TestVariable() - test_var.storage = np.ones(shape=3) test_var_zero = test_var.get_default_initialized_object() assert test_var_zero.storage.shape == (3,) assert np.all(test_var_zero.storage == 0) @@ -32,7 +37,6 @@ def test_zero_variable(): def test_zero_parameter(): test_par = TestParameter() - test_par.storage = np.ones(shape=3) test_par_zero = test_par.get_default_initialized_object() assert test_par_zero.storage.shape == (3,) assert np.all(test_par_zero.storage == 0) @@ -43,6 +47,10 @@ class CustomInitializationVariable(OptimizationObject): variable: StorageType = default_storage_type(Variable) parameter: StorageType = default_storage_type(Parameter) + def __post_init__(self): + self.variable = np.ones(shape=3) + self.parameter = np.ones(shape=3) + def get_default_initialization( self: TOptimizationObject, field_name: str ) -> np.ndarray: @@ -54,8 +62,6 @@ def get_default_initialization( def test_custom_initialization(): test_var = CustomInitializationVariable() - test_var.variable = np.ones(3) - test_var.parameter = np.ones(3) test_var_init = test_var.get_default_initialized_object() assert test_var_init.parameter.shape == (3,) assert np.all(test_var_init.parameter == 0) @@ -69,13 +75,14 @@ class AggregateClass(OptimizationObject): other_parameter: StorageType = default_storage_type(Parameter) other: str = "" + def __post_init__(self): + self.aggregated = CustomInitializationVariable() + self.other_parameter = np.ones(3) + self.other = "untouched" + def test_aggregated(): test_var = AggregateClass(aggregated=CustomInitializationVariable()) - test_var.aggregated.variable = np.ones(3) - test_var.aggregated.parameter = np.ones(3) - test_var.other_parameter = np.ones(3) - test_var.other = "untouched" test_var_init = test_var.get_default_initialized_object() assert test_var_init.aggregated.parameter.shape == (3,) assert np.all(test_var_init.aggregated.parameter == 0)