Skip to content

Commit

Permalink
Merge pull request #11 from ami-iit/black_update
Browse files Browse the repository at this point in the history
Updated black version
  • Loading branch information
S-Dafarra authored Feb 1, 2024
2 parents e375b05 + 90416c4 commit ec33e27
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 51 deletions.
18 changes: 10 additions & 8 deletions src/hippopt/base/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,16 @@ def equal(

def __eq__(
self,
other: cs.Function
| str
| list[str]
| cs.MX
| tuple[cs.Function, dict[str, str]]
| tuple[str, dict[str, str]]
| tuple[list[str], dict[str, str]]
| tuple[cs.MX, dict[str, str]],
other: (
cs.Function
| str
| list[str]
| cs.MX
| tuple[cs.Function, dict[str, str]]
| tuple[str, dict[str, str]]
| tuple[list[str], dict[str, str]]
| tuple[cs.MX, dict[str, str]]
),
) -> TDynamics:
if isinstance(other, tuple):
return self.equal(f=other[0], names_map=other[1])
Expand Down
20 changes: 11 additions & 9 deletions src/hippopt/base/multiple_shooting_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ class MultipleShootingSolver(OptimalControlSolver):

_flattened_variables: FlattenedVariableDict = dataclasses.field(default=None)

_symbolic_structure: TOptimizationObject | list[
TOptimizationObject
] = dataclasses.field(default=None)
_symbolic_structure: TOptimizationObject | list[TOptimizationObject] = (
dataclasses.field(default=None)
)

def __post_init__(
self,
Expand Down Expand Up @@ -363,9 +363,9 @@ def _generate_flattened_and_symbolic_objects( # TODO: remove some indentation
object_in=field_value,
top_level=False,
base_string=base_string + field.name + ".",
base_iterator=(base_iterator[0], generator)
if generator is not None
else None,
base_iterator=(
(base_iterator[0], generator) if generator is not None else None
),
)

output_dict = output_dict | inner_dict
Expand Down Expand Up @@ -408,9 +408,11 @@ def _generate_flattened_and_symbolic_objects( # TODO: remove some indentation
+ "["
+ str(k)
+ "].", # we flatten the list. Note the added [k]
base_iterator=(base_iterator[0], generator)
if generator is not None
else None,
base_iterator=(
(base_iterator[0], generator)
if generator is not None
else None
),
)

output_dict = output_dict | inner_dict
Expand Down
46 changes: 28 additions & 18 deletions src/hippopt/base/opti_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ class OptiSolver(OptimizationSolver):
_cost_expressions: dict[str, cs.MX] = dataclasses.field(default=None)
_constraint_expressions: dict[str, cs.MX] = dataclasses.field(default=None)
_solver: cs.Opti = dataclasses.field(default=None)
_output_solution: TOptimizationObject | list[
TOptimizationObject
] = dataclasses.field(default=None)
_output_solution: TOptimizationObject | list[TOptimizationObject] = (
dataclasses.field(default=None)
)
_output_cost: float = dataclasses.field(default=None)
_cost_values: dict[str, float] = dataclasses.field(default=None)
_constraint_values: dict[str, np.ndarray] = dataclasses.field(default=None)
Expand Down Expand Up @@ -304,9 +304,11 @@ def _get_opti_solution(

def _generate_solution_output(
self,
variables: TOptimizationObject
| list[TOptimizationObject]
| list[list[TOptimizationObject]],
variables: (
TOptimizationObject
| list[TOptimizationObject]
| list[list[TOptimizationObject]]
),
input_solution: cs.OptiSol | dict,
) -> TOptimizationObject | list[TOptimizationObject]:
output = copy.deepcopy(variables)
Expand Down Expand Up @@ -383,12 +385,16 @@ def _set_opti_guess(self, variable: cs.MX, value: np.ndarray) -> None:

def _set_initial_guess_internal(
self,
initial_guess: TOptimizationObject
| list[TOptimizationObject]
| list[list[TOptimizationObject]],
corresponding_variable: TOptimizationObject
| list[TOptimizationObject]
| list[list[TOptimizationObject]],
initial_guess: (
TOptimizationObject
| list[TOptimizationObject]
| list[list[TOptimizationObject]]
),
corresponding_variable: (
TOptimizationObject
| list[TOptimizationObject]
| list[list[TOptimizationObject]]
),
base_name: str = "",
) -> None:
if isinstance(initial_guess, list):
Expand Down Expand Up @@ -679,12 +685,16 @@ def solve(self) -> None:
opti=self._solver,
variables=variables,
parameters=parameters,
costs=list(self._cost_expressions.values())
if self._callback_save_costs
else [],
constraints=list(self._constraint_expressions.values())
if self._callback_save_constraint_multipliers
else [],
costs=(
list(self._cost_expressions.values())
if self._callback_save_costs
else []
),
constraints=(
list(self._constraint_expressions.values())
if self._callback_save_constraint_multipliers
else []
),
)
self._solver.callback(self._callback)
try:
Expand Down
6 changes: 3 additions & 3 deletions src/hippopt/base/optimal_control_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ def __iter__(self):

@dataclasses.dataclass
class OptimalControlProblem(Problem[TOptimalControlSolver, TInputObjects]):
optimal_control_solver: dataclasses.InitVar[
OptimalControlSolver
] = dataclasses.field(default=None)
optimal_control_solver: dataclasses.InitVar[OptimalControlSolver] = (
dataclasses.field(default=None)
)

def __post_init__(
self,
Expand Down
6 changes: 3 additions & 3 deletions src/hippopt/base/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ class Output(Generic[TGenericOptimizationObject]):
_cost_values: dataclasses.InitVar[dict[str, np.ndarray]] = dataclasses.field(
default=None
)
_constraint_multipliers: dataclasses.InitVar[
dict[str, np.ndarray]
] = dataclasses.field(default=None)
_constraint_multipliers: dataclasses.InitVar[dict[str, np.ndarray]] = (
dataclasses.field(default=None)
)

def __post_init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ class ContactPointStatePlotterSettings:
force_axes: matplotlib.axes.Axes | None = dataclasses.field(default=None)
terrain: TerrainDescriptor = dataclasses.field(default=None)

input_complementarity_axes: dataclasses.InitVar[
list[matplotlib.axes.Axes]
] = dataclasses.field(default=None)
input_complementarity_axes: dataclasses.InitVar[list[matplotlib.axes.Axes]] = (
dataclasses.field(default=None)
)
input_force_axes: dataclasses.InitVar[matplotlib.axes.Axes] = dataclasses.field(
default=None
)
Expand Down
6 changes: 3 additions & 3 deletions src/hippopt/robot_planning/variables/humanoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ class HumanoidState(OptimizationObject):

com: StorageType = default_storage_field(OverridableVariable)

contact_point_descriptors: dataclasses.InitVar[
FeetContactPointDescriptors
] = dataclasses.field(default=None)
contact_point_descriptors: dataclasses.InitVar[FeetContactPointDescriptors] = (
dataclasses.field(default=None)
)
number_of_joints: dataclasses.InitVar[int] = dataclasses.field(default=None)

def __post_init__(
Expand Down
8 changes: 4 additions & 4 deletions test/test_opti_generate_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,10 @@ class CustomCustomOverridableVariableInner(OptimizationObject):

@dataclasses.dataclass
class CustomCustomOverridableVariableNested(OptimizationObject):
composite: CompositeType[
CustomCustomOverridableVariableInner
] = default_composite_field(
cls=OverridableParameter, factory=CustomCustomOverridableVariableInner
composite: CompositeType[CustomCustomOverridableVariableInner] = (
default_composite_field(
cls=OverridableParameter, factory=CustomCustomOverridableVariableInner
)
)


Expand Down

0 comments on commit ec33e27

Please sign in to comment.