Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Diego Ferigo <[email protected]>
  • Loading branch information
S-Dafarra and diegoferigo authored Mar 31, 2023
1 parent 109859c commit aee7685
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
16 changes: 9 additions & 7 deletions src/hippopt/base/opti_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,10 @@ def _generate_objects_from_list(
or list_of_optimization_objects
)

output = copy.deepcopy(input_structure)
for i in range(len(output)):
output[i] = self.generate_optimization_objects(output[i])
output = [
self.generate_optimization_objects(copy.deepcopy(input_structure[i]))
for i in range(len(output))
]

self._variables = output
return output
Expand Down Expand Up @@ -173,7 +174,7 @@ def _set_initial_guess_internal(
self,
initial_guess: TOptimizationObject,
corresponding_variable: TOptimizationObject,
):
) -> None:
for field in dataclasses.fields(initial_guess):
has_storage_field = OptimizationObject.StorageTypeField in field.metadata

Expand Down Expand Up @@ -321,9 +322,10 @@ def _set_initial_guess_internal(
def generate_optimization_objects(
self, input_structure: TOptimizationObject | List[TOptimizationObject], **kwargs
) -> TOptimizationObject | List[TOptimizationObject]:
if isinstance(input_structure, OptimizationObject):
return self._generate_objects_from_instance(input_structure=input_structure)
return self._generate_objects_from_list(input_structure=input_structure)
return (
self._generate_objects_from_instance if isinstance(input_structure, OptimizationObject)
else self._generate_objects_from_instance
)(input_structure=input_structure)

def get_optimization_objects(
self,
Expand Down
6 changes: 3 additions & 3 deletions src/hippopt/base/optimization_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def generate_optimization_objects(
@abc.abstractmethod
def set_initial_guess(
self, initial_guess: TOptimizationObject | List[TOptimizationObject]
):
) -> None:
pass

@abc.abstractmethod
Expand All @@ -36,11 +36,11 @@ def get_cost_value(self) -> float | None:
pass

@abc.abstractmethod
def add_cost(self, input_cost: cs.MX):
def add_cost(self, input_cost: cs.MX) -> None:
pass

@abc.abstractmethod
def add_constraint(self, input_constraint: cs.MX):
def add_constraint(self, input_constraint: cs.MX) -> None:
pass

@abc.abstractmethod
Expand Down

0 comments on commit aee7685

Please sign in to comment.