Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated black version #11

Merged
merged 1 commit into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions src/hippopt/base/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,16 @@ def equal(

def __eq__(
self,
other: cs.Function
| str
| list[str]
| cs.MX
| tuple[cs.Function, dict[str, str]]
| tuple[str, dict[str, str]]
| tuple[list[str], dict[str, str]]
| tuple[cs.MX, dict[str, str]],
other: (
cs.Function
| str
| list[str]
| cs.MX
| tuple[cs.Function, dict[str, str]]
| tuple[str, dict[str, str]]
| tuple[list[str], dict[str, str]]
| tuple[cs.MX, dict[str, str]]
),
) -> TDynamics:
if isinstance(other, tuple):
return self.equal(f=other[0], names_map=other[1])
Expand Down
20 changes: 11 additions & 9 deletions src/hippopt/base/multiple_shooting_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ class MultipleShootingSolver(OptimalControlSolver):

_flattened_variables: FlattenedVariableDict = dataclasses.field(default=None)

_symbolic_structure: TOptimizationObject | list[
TOptimizationObject
] = dataclasses.field(default=None)
_symbolic_structure: TOptimizationObject | list[TOptimizationObject] = (
dataclasses.field(default=None)
)

def __post_init__(
self,
Expand Down Expand Up @@ -363,9 +363,9 @@ def _generate_flattened_and_symbolic_objects( # TODO: remove some indentation
object_in=field_value,
top_level=False,
base_string=base_string + field.name + ".",
base_iterator=(base_iterator[0], generator)
if generator is not None
else None,
base_iterator=(
(base_iterator[0], generator) if generator is not None else None
),
)

output_dict = output_dict | inner_dict
Expand Down Expand Up @@ -408,9 +408,11 @@ def _generate_flattened_and_symbolic_objects( # TODO: remove some indentation
+ "["
+ str(k)
+ "].", # we flatten the list. Note the added [k]
base_iterator=(base_iterator[0], generator)
if generator is not None
else None,
base_iterator=(
(base_iterator[0], generator)
if generator is not None
else None
),
)

output_dict = output_dict | inner_dict
Expand Down
46 changes: 28 additions & 18 deletions src/hippopt/base/opti_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ class OptiSolver(OptimizationSolver):
_cost_expressions: dict[str, cs.MX] = dataclasses.field(default=None)
_constraint_expressions: dict[str, cs.MX] = dataclasses.field(default=None)
_solver: cs.Opti = dataclasses.field(default=None)
_output_solution: TOptimizationObject | list[
TOptimizationObject
] = dataclasses.field(default=None)
_output_solution: TOptimizationObject | list[TOptimizationObject] = (
dataclasses.field(default=None)
)
_output_cost: float = dataclasses.field(default=None)
_cost_values: dict[str, float] = dataclasses.field(default=None)
_constraint_values: dict[str, np.ndarray] = dataclasses.field(default=None)
Expand Down Expand Up @@ -304,9 +304,11 @@ def _get_opti_solution(

def _generate_solution_output(
self,
variables: TOptimizationObject
| list[TOptimizationObject]
| list[list[TOptimizationObject]],
variables: (
TOptimizationObject
| list[TOptimizationObject]
| list[list[TOptimizationObject]]
),
input_solution: cs.OptiSol | dict,
) -> TOptimizationObject | list[TOptimizationObject]:
output = copy.deepcopy(variables)
Expand Down Expand Up @@ -383,12 +385,16 @@ def _set_opti_guess(self, variable: cs.MX, value: np.ndarray) -> None:

def _set_initial_guess_internal(
self,
initial_guess: TOptimizationObject
| list[TOptimizationObject]
| list[list[TOptimizationObject]],
corresponding_variable: TOptimizationObject
| list[TOptimizationObject]
| list[list[TOptimizationObject]],
initial_guess: (
TOptimizationObject
| list[TOptimizationObject]
| list[list[TOptimizationObject]]
),
corresponding_variable: (
TOptimizationObject
| list[TOptimizationObject]
| list[list[TOptimizationObject]]
),
base_name: str = "",
) -> None:
if isinstance(initial_guess, list):
Expand Down Expand Up @@ -679,12 +685,16 @@ def solve(self) -> None:
opti=self._solver,
variables=variables,
parameters=parameters,
costs=list(self._cost_expressions.values())
if self._callback_save_costs
else [],
constraints=list(self._constraint_expressions.values())
if self._callback_save_constraint_multipliers
else [],
costs=(
list(self._cost_expressions.values())
if self._callback_save_costs
else []
),
constraints=(
list(self._constraint_expressions.values())
if self._callback_save_constraint_multipliers
else []
),
)
self._solver.callback(self._callback)
try:
Expand Down
6 changes: 3 additions & 3 deletions src/hippopt/base/optimal_control_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ def __iter__(self):

@dataclasses.dataclass
class OptimalControlProblem(Problem[TOptimalControlSolver, TInputObjects]):
optimal_control_solver: dataclasses.InitVar[
OptimalControlSolver
] = dataclasses.field(default=None)
optimal_control_solver: dataclasses.InitVar[OptimalControlSolver] = (
dataclasses.field(default=None)
)

def __post_init__(
self,
Expand Down
6 changes: 3 additions & 3 deletions src/hippopt/base/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ class Output(Generic[TGenericOptimizationObject]):
_cost_values: dataclasses.InitVar[dict[str, np.ndarray]] = dataclasses.field(
default=None
)
_constraint_multipliers: dataclasses.InitVar[
dict[str, np.ndarray]
] = dataclasses.field(default=None)
_constraint_multipliers: dataclasses.InitVar[dict[str, np.ndarray]] = (
dataclasses.field(default=None)
)

def __post_init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ class ContactPointStatePlotterSettings:
force_axes: matplotlib.axes.Axes | None = dataclasses.field(default=None)
terrain: TerrainDescriptor = dataclasses.field(default=None)

input_complementarity_axes: dataclasses.InitVar[
list[matplotlib.axes.Axes]
] = dataclasses.field(default=None)
input_complementarity_axes: dataclasses.InitVar[list[matplotlib.axes.Axes]] = (
dataclasses.field(default=None)
)
input_force_axes: dataclasses.InitVar[matplotlib.axes.Axes] = dataclasses.field(
default=None
)
Expand Down
6 changes: 3 additions & 3 deletions src/hippopt/robot_planning/variables/humanoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ class HumanoidState(OptimizationObject):

com: StorageType = default_storage_field(OverridableVariable)

contact_point_descriptors: dataclasses.InitVar[
FeetContactPointDescriptors
] = dataclasses.field(default=None)
contact_point_descriptors: dataclasses.InitVar[FeetContactPointDescriptors] = (
dataclasses.field(default=None)
)
number_of_joints: dataclasses.InitVar[int] = dataclasses.field(default=None)

def __post_init__(
Expand Down
8 changes: 4 additions & 4 deletions test/test_opti_generate_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,10 @@ class CustomCustomOverridableVariableInner(OptimizationObject):

@dataclasses.dataclass
class CustomCustomOverridableVariableNested(OptimizationObject):
composite: CompositeType[
CustomCustomOverridableVariableInner
] = default_composite_field(
cls=OverridableParameter, factory=CustomCustomOverridableVariableInner
composite: CompositeType[CustomCustomOverridableVariableInner] = (
default_composite_field(
cls=OverridableParameter, factory=CustomCustomOverridableVariableInner
)
)


Expand Down
Loading