-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from ami-iit/initialStructure
Initial structure and first definition of the variables and parameters classes
- Loading branch information
Showing
20 changed files
with
229 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from . import base | ||
from .base.optimization_object import ( | ||
OptimizationObject, | ||
StorageType, | ||
TOptimizationObject, | ||
default_storage_field, | ||
) | ||
from .base.parameter import Parameter, TParameter | ||
from .base.variable import TVariable, Variable |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from . import optimization_object, parameter, variable |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import abc | ||
import copy | ||
import dataclasses | ||
from typing import Any, ClassVar, Type, TypeVar | ||
|
||
import casadi as cs | ||
import numpy as np | ||
|
||
TOptimizationObject = TypeVar("TOptimizationObject", bound="OptimizationObject") | ||
StorageType = cs.MX | np.ndarray | ||
|
||
|
||
@dataclasses.dataclass | ||
class OptimizationObject(abc.ABC): | ||
StorageType: ClassVar[str] = "generic" | ||
StorageTypeMetadata: ClassVar[dict[str, Any]] = dict(StorageType=StorageType) | ||
|
||
def get_default_initialization( | ||
self: TOptimizationObject, field_name: str | ||
) -> np.ndarray: | ||
""" | ||
Get the default initialization of a given field | ||
It is supposed to be called only for the fields having the StorageType metadata | ||
""" | ||
return np.zeros(dataclasses.asdict(self)[field_name].shape) | ||
|
||
def get_default_initialized_object( | ||
self: TOptimizationObject, | ||
) -> TOptimizationObject: | ||
""" | ||
:return: A copy of the object with its initial values | ||
""" | ||
|
||
output = copy.deepcopy(self) | ||
output_dict = dataclasses.asdict(output) | ||
|
||
for field in dataclasses.fields(output): | ||
if "StorageType" in field.metadata: | ||
output.__setattr__( | ||
field.name, output.get_default_initialization(field.name) | ||
) | ||
continue | ||
|
||
if isinstance(output.__getattribute__(field.name), OptimizationObject): | ||
output.__setattr__( | ||
field.name, | ||
output.__getattribute__( | ||
field.name | ||
).get_default_initialized_object(), | ||
) | ||
|
||
return output | ||
|
||
|
||
def default_storage_field(cls: Type[OptimizationObject]): | ||
return dataclasses.field( | ||
default=None, | ||
metadata=cls.StorageTypeMetadata, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import dataclasses | ||
from typing import Any, ClassVar, TypeVar | ||
|
||
from hippopt.base.optimization_object import OptimizationObject | ||
|
||
TParameter = TypeVar("TParameter", bound="Parameter") | ||
|
||
|
||
@dataclasses.dataclass | ||
class Parameter(OptimizationObject): | ||
"""""" | ||
|
||
StorageType: ClassVar[str] = "parameter" | ||
StorageTypeMetadata: ClassVar[dict[str, Any]] = dict(StorageType=StorageType) |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import dataclasses | ||
from typing import Any, ClassVar, TypeVar | ||
|
||
from hippopt.base.optimization_object import OptimizationObject | ||
|
||
TVariable = TypeVar("TVariable", bound="Variable") | ||
|
||
|
||
@dataclasses.dataclass | ||
class Variable(OptimizationObject): | ||
"""""" | ||
|
||
StorageType: ClassVar[str] = "variable" | ||
StorageTypeMetadata: ClassVar[dict[str, Any]] = dict(StorageType=StorageType) |
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import dataclasses | ||
|
||
import numpy as np | ||
|
||
from hippopt import ( | ||
OptimizationObject, | ||
Parameter, | ||
StorageType, | ||
TOptimizationObject, | ||
Variable, | ||
default_storage_field, | ||
) | ||
|
||
|
||
@dataclasses.dataclass | ||
class TestVariable(OptimizationObject): | ||
storage: StorageType = default_storage_field(cls=Variable) | ||
|
||
def __post_init__(self): | ||
self.storage = np.ones(shape=3) | ||
|
||
|
||
@dataclasses.dataclass | ||
class TestParameter(OptimizationObject): | ||
storage: StorageType = default_storage_field(cls=Parameter) | ||
|
||
def __post_init__(self): | ||
self.storage = np.ones(shape=3) | ||
|
||
|
||
def test_zero_variable(): | ||
test_var = TestVariable() | ||
test_var_zero = test_var.get_default_initialized_object() | ||
assert test_var_zero.storage.shape == (3,) | ||
assert np.all(test_var_zero.storage == 0) | ||
|
||
|
||
def test_zero_parameter(): | ||
test_par = TestParameter() | ||
test_par_zero = test_par.get_default_initialized_object() | ||
assert test_par_zero.storage.shape == (3,) | ||
assert np.all(test_par_zero.storage == 0) | ||
|
||
|
||
@dataclasses.dataclass | ||
class CustomInitializationVariable(OptimizationObject): | ||
variable: StorageType = default_storage_field(cls=Variable) | ||
parameter: StorageType = default_storage_field(cls=Parameter) | ||
|
||
def __post_init__(self): | ||
self.variable = np.ones(shape=3) | ||
self.parameter = np.ones(shape=3) | ||
|
||
def get_default_initialization( | ||
self: TOptimizationObject, field_name: str | ||
) -> np.ndarray: | ||
if field_name == "variable": | ||
return 2 * np.ones(2) | ||
|
||
return OptimizationObject.get_default_initialization(self, field_name) | ||
|
||
|
||
def test_custom_initialization(): | ||
test_var = CustomInitializationVariable() | ||
test_var_init = test_var.get_default_initialized_object() | ||
assert test_var_init.parameter.shape == (3,) | ||
assert np.all(test_var_init.parameter == 0) | ||
assert test_var_init.variable.shape == (2,) | ||
assert np.all(test_var_init.variable == 2) | ||
|
||
|
||
@dataclasses.dataclass | ||
class AggregateClass(OptimizationObject): | ||
aggregated: CustomInitializationVariable | ||
other_parameter: StorageType = default_storage_field(cls=Parameter) | ||
other: str = "" | ||
|
||
def __post_init__(self): | ||
self.aggregated = CustomInitializationVariable() | ||
self.other_parameter = np.ones(3) | ||
self.other = "untouched" | ||
|
||
|
||
def test_aggregated(): | ||
test_var = AggregateClass(aggregated=CustomInitializationVariable()) | ||
test_var_init = test_var.get_default_initialized_object() | ||
assert test_var_init.aggregated.parameter.shape == (3,) | ||
assert np.all(test_var_init.aggregated.parameter == 0) | ||
assert test_var_init.aggregated.variable.shape == (2,) | ||
assert np.all(test_var_init.aggregated.variable == 2) | ||
assert test_var_init.other_parameter.shape == (3,) | ||
assert np.all(test_var_init.other_parameter == 0) | ||
assert test_var_init.other == "untouched" |