Skip to content

Commit

Permalink
Added visualization of complementarity error
Browse files Browse the repository at this point in the history
  • Loading branch information
S-Dafarra committed Oct 31, 2023
1 parent e2ddbcb commit 796435f
Showing 1 changed file with 33 additions and 14 deletions.
47 changes: 33 additions & 14 deletions src/hippopt/robot_planning/utilities/foot_contact_state_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,26 @@

@dataclasses.dataclass
class ContactPointStatePlotterSettings:
axes: matplotlib.axes.Axes = dataclasses.field(default=None)
axes: list[matplotlib.axes.Axes] | None = dataclasses.field(default=None)
terrain: TerrainDescriptor = dataclasses.field(default=None)

input_axes: dataclasses.InitVar[matplotlib.axes.Axes] = dataclasses.field(
input_axes: dataclasses.InitVar[list[matplotlib.axes.Axes]] = dataclasses.field(
default=None
)
input_terrain: dataclasses.InitVar[TerrainDescriptor] = dataclasses.field(
default=None
)

def __post_init__(
self, input_axes: matplotlib.axes.Axes, input_terrain: TerrainDescriptor
self, input_axes: list[matplotlib.axes.Axes], input_terrain: TerrainDescriptor
):
self.axes = input_axes
self.axes = None
if isinstance(input_axes, list):
if len(input_axes) != 2:
raise ValueError("input_axes must be a list of length 2.")

self.axes = input_axes

self.terrain = (
input_terrain
if isinstance(input_terrain, TerrainDescriptor)
Expand Down Expand Up @@ -68,17 +74,22 @@ def plot_complementarity(
)

if self._axes is None:
self._fig, self._axes = plt.subplots()
self._fig, self._axes = plt.subplots(nrows=1, ncols=2)
plt.tight_layout()

height_function = self.settings.terrain.height_function()
normal_direction_fun = self.settings.terrain.normal_direction_function()

positions = np.array([height_function(s.p) for s in states]).flatten()
forces = np.array([normal_direction_fun(s.p).T @ s.f for s in states]).flatten()
self._axes.plot(_time_s, positions)
self._axes.set_ylabel("Height [m]", color="C0")
self._axes.tick_params(axis="y", color="C0", labelcolor="C0")
axes_force = self._axes.twinx()
complementarity_error = np.multiply(positions, forces)
self._axes[1].plot(_time_s, complementarity_error)
self._axes[1].set_ylabel("Complementarity Error [Nm]")
self._axes[1].set_xlabel("Time [s]")
self._axes[0].plot(_time_s, positions)
self._axes[0].set_ylabel("Height [m]", color="C0")
self._axes[0].tick_params(axis="y", color="C0", labelcolor="C0")
axes_force = self._axes[0].twinx()
axes_force.plot(_time_s, forces, "C1")
axes_force.set_ylabel("Normal Force [N]", color="C1")
axes_force.tick_params(axis="y", color="C1", labelcolor="C1")
Expand Down Expand Up @@ -164,9 +175,10 @@ def _create_complementarity_plot(
number_of_columns: int,
terrain: TerrainDescriptor,
):
number_of_plots = len(states[0])
number_of_points = len(states[0])
number_of_plots = number_of_points + 1
_number_of_columns = (
math.ceil(math.sqrt(number_of_plots))
math.floor(math.sqrt(number_of_plots))
if number_of_columns < 1
else number_of_columns
)
Expand All @@ -177,16 +189,23 @@ def _create_complementarity_plot(
ncols=_number_of_columns,
squeeze=False,
)
plt.tight_layout()
last_plot_column = number_of_points - _number_of_columns * (number_of_rows - 1)
last_plot = axes_list[number_of_rows - 1][last_plot_column]
_point_plotters = [
ContactPointStatePlotter(
ContactPointStatePlotterSettings(input_axes=el, terrain=terrain)
ContactPointStatePlotterSettings(
input_axes=[el, last_plot],
terrain=terrain,
)
)
for row in axes_list
for el in row
]
assert len(_point_plotters) == number_of_plots
for i in range(last_plot_column + 1, _number_of_columns):
axes_list[number_of_rows - 1][i].remove()

for p in range(number_of_plots):
for p in range(number_of_points):
contact_states = [state[p] for state in states]
_point_plotters[p].plot_complementarity(
states=contact_states, time_s=time_s
Expand Down

0 comments on commit 796435f

Please sign in to comment.