Skip to content

Commit

Permalink
Added possibility to run conversions when converting from dict
Browse files Browse the repository at this point in the history
  • Loading branch information
S-Dafarra committed May 17, 2024
1 parent 944b987 commit fee7b98
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
25 changes: 22 additions & 3 deletions src/hippopt/base/optimization_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class OptimizationObject(abc.ABC):
or isinstance(value, cs.DM)
or isinstance(value, cs.MX)
)
DMConversion: ClassVar[Callable[[str, Any], Any]] = lambda _, value: (
value.full().flatten() if isinstance(value, cs.DM) else value
)

@staticmethod
def _convert_to_np_array(value: Any) -> Any | np.ndarray:
Expand Down Expand Up @@ -64,6 +67,7 @@ def _scan(
parent_metadata: dict | None = None,
input_dict: dict | None = None,
output_filter: Callable[[str, Any, dict], bool] | None = None,
input_conversion: Callable[[str, Any], Any] | None = None,
) -> (dict, dict):
output_dict = {}
metadata_dict = {}
Expand All @@ -79,6 +83,7 @@ def _scan(
parent_metadata=parent_metadata,
input_dict=input_dict,
output_filter=output_filter,
input_conversion=input_conversion,
)
output_dict.update(inner_dict)
metadata_dict.update(inner_metadata)
Expand Down Expand Up @@ -130,6 +135,7 @@ def _scan(
parent_metadata=new_parent_metadata,
input_dict=input_dict,
output_filter=output_filter,
input_conversion=input_conversion,
)
output_dict.update(inner_dict)
metadata_dict.update(inner_metadata)
Expand Down Expand Up @@ -163,7 +169,12 @@ def _scan(
full_name = name_radix + postfix

if input_dict is not None and full_name in input_dict:
value_from_dict.append(input_dict[full_name])
converted_input = (
input_conversion(full_name, input_dict[full_name])
if input_conversion is not None
else input_dict[full_name]
)
value_from_dict.append(converted_input)

output_value = (
OptimizationObject._convert_to_np_array(composite_value[i])
Expand Down Expand Up @@ -208,9 +219,17 @@ def to_dicts(
)
return output_dict, metadata_dict

def from_dict(self, input_dict: dict, prefix: str = "") -> None:
def from_dict(
self,
input_dict: dict,
prefix: str = "",
input_conversion: Callable[[str, Any], Any] | None = None,
) -> None:
OptimizationObject._scan(
input_object=self, name_prefix=prefix, input_dict=input_dict
input_object=self,
name_prefix=prefix,
input_dict=input_dict,
input_conversion=input_conversion,
)

def to_list(self) -> list:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -590,12 +590,10 @@ def get_full_output_function(
link_densities=parametric_link_densities,
)

for key in computed_output:
if isinstance(computed_output[key], cs.DM):
computed_output[key] = computed_output[key].full().flatten()

output = planner.get_variables_structure()
output.from_dict(computed_output)
output.from_dict(
computed_output, input_conversion=hippopt.OptimizationObject.DMConversion
)

humanoid_states = [s.to_humanoid_state() for s in output.system]
left_contact_points = [s.contact_points.left for s in humanoid_states]
Expand Down

0 comments on commit fee7b98

Please sign in to comment.