Skip to content

Commit

Permalink
Moved conversion to np.array in conversion to dict
Browse files Browse the repository at this point in the history
  • Loading branch information
S-Dafarra committed Mar 26, 2024
1 parent 9e432a2 commit d331560
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 24 deletions.
25 changes: 10 additions & 15 deletions src/hippopt/base/opti_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,26 +138,21 @@ def _generate_opti_object(
if value is None:
raise ValueError("Field " + name + " is tagged as storage, but it is None.")

if isinstance(value, np.ndarray):
if value.ndim > 2:
raise ValueError(
"Field " + name + " has number of dimensions greater than 2."
)
if value.ndim == 0:
raise ValueError("Field " + name + " is a zero-dimensional vector.")

if value.ndim < 2:
value = np.expand_dims(value, axis=1)
if not isinstance(value, np.ndarray):
raise ValueError(
f"Field {name} is tagged as storage, but it is not an array "
f"(it is a {str(type(value))})."
)

if isinstance(value, float):
value = value * np.ones((1, 1))
if value.ndim != 2:
raise ValueError(
f"Field {name} has number of dimensions different from 2 "
f"(input: {value.ndim})."
)

if value.shape[0] * value.shape[1] == 0:
raise ValueError("Field " + name + " has a zero dimension.")

assert isinstance(value, np.ndarray)
assert value.ndim == 2

if storage_type is Variable.StorageTypeValue:
self._logger.debug("Creating variable " + name)
opti_object = self._solver.variable(*value.shape)
Expand Down
35 changes: 26 additions & 9 deletions src/hippopt/base/optimization_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,25 @@ def to_list(self) -> list:
def to_mx(self) -> cs.MX:
return cs.vertcat(*self.to_list())

@staticmethod
def _convert_to_np_array(value: Any) -> Any | np.ndarray:
output_value = value
list_of_float = isinstance(output_value, list) and (
len(output_value) == 0
or all(isinstance(elem, float) for elem in output_value)
)
if list_of_float:
output_value = np.array(output_value)

if isinstance(output_value, np.ndarray):
if output_value.ndim < 2:
output_value = np.expand_dims(output_value, axis=1)

if isinstance(output_value, float):
output_value = output_value * np.ones((1, 1))

return output_value

@staticmethod
def _scan(
input_object: TOptimizationObject | list[TOptimizationObject],
Expand Down Expand Up @@ -159,15 +178,10 @@ def _scan(
parent_metadata[OptimizationObject.StorageTypeField]
)

value_is_list = isinstance(composite_value, list)
list_of_float = value_is_list and (
len(composite_value) == 0
or all(isinstance(elem, float) for elem in composite_value)
composite_value = OptimizationObject._convert_to_np_array(
composite_value
)
if list_of_float:
composite_value = np.array(composite_value)
value_is_list = False

value_is_list = isinstance(composite_value, list)
value_list = composite_value if value_is_list else [composite_value]
name_radix = name_prefix + field.name
value_from_dict = []
Expand All @@ -179,8 +193,11 @@ def _scan(
value_from_dict.append(input_dict[full_name])

metadata_dict[full_name] = value_metadata

output_dict[full_name] = (
composite_value[i] if value_is_list else composite_value
OptimizationObject._convert_to_np_array(composite_value[i])
if value_is_list
else composite_value
)

if len(value_from_dict) > 0:
Expand Down

0 comments on commit d331560

Please sign in to comment.