Skip to content

Commit

Permalink
Added possibility to set OptimizationObject from dict
Browse files Browse the repository at this point in the history
  • Loading branch information
S-Dafarra committed Mar 21, 2024
1 parent 83eaf9c commit dcf6978
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 27 deletions.
70 changes: 43 additions & 27 deletions src/hippopt/base/optimization_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,36 +69,41 @@ def to_mx(self) -> cs.MX:
return cs.vertcat(*self.to_list())

@staticmethod
def _to_dict(
def _scan(
input_object: TOptimizationObject | list[TOptimizationObject],
name_prefix: str = "",
get_metadata: bool = False,
parent_metadata: dict | None = None,
) -> dict:
input_dict: dict | None = None,
) -> (dict, dict):
output_dict = {}
metadata_dict = {}
if isinstance(input_object, list):
assert all(
isinstance(elem, OptimizationObject) or isinstance(elem, list)
for elem in input_object
)
for i, elem in enumerate(input_object):
output_dict.update(
OptimizationObject._to_dict(
input_object=elem,
name_prefix=name_prefix + f"[{str(i)}].",
get_metadata=get_metadata,
parent_metadata=parent_metadata,
)
inner_dict, inner_metadata = OptimizationObject._scan(
input_object=elem,
name_prefix=name_prefix + f"[{str(i)}].",
parent_metadata=parent_metadata,
input_dict=input_dict,
)
return output_dict
output_dict.update(inner_dict)
metadata_dict.update(inner_metadata)
return output_dict, metadata_dict

assert isinstance(input_object, OptimizationObject)
for field in dataclasses.fields(input_object):
composite_value = input_object.__getattribute__(field.name)

list_of_optimization_objects = isinstance(composite_value, list) and all(
isinstance(elem, OptimizationObject) or isinstance(elem, list)
for elem in composite_value
list_of_optimization_objects = (
isinstance(composite_value, list)
and len(composite_value) > 0
and all(
isinstance(elem, OptimizationObject) or isinstance(elem, list)
for elem in composite_value
)
)

if (
Expand Down Expand Up @@ -128,14 +133,14 @@ def _to_dict(
new_parent_metadata = composite_metadata

separator = "" if list_of_optimization_objects else "."
output_dict.update(
OptimizationObject._to_dict(
input_object=composite_value,
name_prefix=name_prefix + field.name + separator,
get_metadata=get_metadata,
parent_metadata=new_parent_metadata,
)
inner_dict, inner_metadata = OptimizationObject._scan(
input_object=composite_value,
name_prefix=name_prefix + field.name + separator,
parent_metadata=new_parent_metadata,
input_dict=input_dict,
)
output_dict.update(inner_dict)
metadata_dict.update(inner_metadata)
continue

if OptimizationObject.StorageTypeField in field.metadata:
Expand All @@ -154,18 +159,29 @@ def _to_dict(
parent_metadata[OptimizationObject.StorageTypeField]
)

output_dict[name_prefix + field.name] = (
composite_value if not get_metadata else value_metadata
)
full_name = name_prefix + field.name

if input_dict is not None and full_name in input_dict:
input_value = input_dict[full_name]
input_object.__setattr__(field.name, input_value)

metadata_dict[full_name] = value_metadata
output_dict[full_name] = composite_value

continue

return output_dict
return output_dict, metadata_dict

def to_dict(self) -> dict:
return OptimizationObject._to_dict(input_object=self)
output_dict, _ = OptimizationObject._scan(input_object=self)
return output_dict

def metadata_to_dict(self) -> dict:
return OptimizationObject._to_dict(input_object=self, get_metadata=True)
_, metadata_dict = OptimizationObject._scan(input_object=self)
return metadata_dict

def from_dict(self, input_dict: dict) -> None:
OptimizationObject._scan(input_object=self, input_dict=input_dict)

@classmethod
def default_storage_metadata(cls, **kwargs) -> dict:
Expand Down
5 changes: 5 additions & 0 deletions test/test_opti_generate_objects.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import dataclasses

import casadi as cs
Expand Down Expand Up @@ -97,6 +98,10 @@ def test_generate_objects():
]
)
assert "other" not in as_dict
dict_copy = copy.deepcopy(as_dict)
dict_copy["aggregated.scalar"] = 7.0
opti_var.from_dict(dict_copy)
assert opti_var.aggregated.scalar == 7.0


def test_generate_objects_list():
Expand Down

0 comments on commit dcf6978

Please sign in to comment.