Skip to content

Commit

Permalink
Added visualization of complementarity.
Browse files Browse the repository at this point in the history
Exploiting external processes to avoid matplotlib blocking the execution
  • Loading branch information
S-Dafarra committed Oct 31, 2023
1 parent 14a9725 commit e2ddbcb
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 33 deletions.
6 changes: 6 additions & 0 deletions src/hippopt/robot_planning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@
quaternion_xyzw_normalization,
quaternion_xyzw_velocity_to_right_trivialized_angular_velocity,
)
from .utilities.foot_contact_state_plotter import (
ContactPointStatePlotter,
ContactPointStatePlotterSettings,
FootContactStatePlotter,
FootContactStatePlotterSettings,
)
from .utilities.humanoid_state_visualizer import (
HumanoidStateVisualizer,
HumanoidStateVisualizerSettings,
Expand Down
2 changes: 1 addition & 1 deletion src/hippopt/robot_planning/utilities/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import humanoid_state_visualizer, terrain_descriptor
from . import foot_contact_state_plotter, humanoid_state_visualizer, terrain_descriptor
103 changes: 72 additions & 31 deletions src/hippopt/robot_planning/utilities/foot_contact_state_plotter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import copy
import dataclasses
import logging
import math
import multiprocessing
from typing import TypeVar

import matplotlib.axes
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -70,8 +73,8 @@ def plot_complementarity(
height_function = self.settings.terrain.height_function()
normal_direction_fun = self.settings.terrain.normal_direction_function()

positions = [height_function(s.p) for s in states]
forces = [normal_direction_fun(s.p).T @ s.f for s in states]
positions = np.array([height_function(s.p) for s in states]).flatten()
forces = np.array([normal_direction_fun(s.p).T @ s.f for s in states]).flatten()
self._axes.plot(_time_s, positions)
self._axes.set_ylabel("Height [m]", color="C0")
self._axes.tick_params(axis="y", color="C0", labelcolor="C0")
Expand All @@ -86,6 +89,7 @@ def plot_complementarity(
self._fig.suptitle(title)
plt.draw()
plt.pause(0.001)
plt.show()


@dataclasses.dataclass
Expand All @@ -94,64 +98,101 @@ class FootContactStatePlotterSettings:
terrain: TerrainDescriptor = dataclasses.field(default=None)


TFootContactStatePlotter = TypeVar(
"TFootContactStatePlotter", bound="FootContactStatePlotter"
)


class FootContactStatePlotter:
def __init__(
self,
settings: FootContactStatePlotterSettings = FootContactStatePlotterSettings(),
):
self.settings = settings
self.number_of_rows = -1
self.fig = None
self.point_plotters = []
self._settings = settings
self._ext_process = None
self._logger = logging.getLogger("[hippopt::FootContactStatePlotter]")

def plot_complementarity(
self,
states: list[FootContactState],
time_s: float | list[float] | np.ndarray = None,
title: str = "Foot Contact Complementarity",
blocking: bool = False,
):
if self._ext_process is not None:
self._logger.warning(
"A plot is already running. "
"Make sure to close the previous plot first."
)
self._ext_process.join()
self._ext_process = None
_time_s = copy.deepcopy(time_s)
_states = copy.deepcopy(states)
_terrain = copy.deepcopy(self._settings.terrain)
if _time_s is None or isinstance(_time_s, float) or _time_s.size == 1:
single_step = _time_s if _time_s is not None else 0.0
_time_s = np.linspace(0, len(states) * single_step, len(states))

if len(_time_s) != len(states):
if len(_time_s) != len(_states):
raise ValueError(
"timestep_s and foot_contact_states have different lengths."
)

if len(states) == 0:
if len(_states) == 0:
return

self._ext_process = multiprocessing.Process(
target=FootContactStatePlotter._create_complementarity_plot,
args=(
_states,
_time_s,
title,
self._settings.number_of_columns,
_terrain,
),
)
self._ext_process.start()

if blocking:
self._ext_process.join()

@staticmethod
def _create_complementarity_plot(
states: list[FootContactState],
time_s: np.ndarray,
title: str,
number_of_columns: int,
terrain: TerrainDescriptor,
):
number_of_plots = len(states[0])
if self.settings.number_of_columns < 1:
self.settings.number_of_columns = math.ceil(math.sqrt(number_of_plots))
number_of_rows = math.ceil(number_of_plots / self.settings.number_of_columns)

if self.number_of_rows != number_of_rows:
self.fig, axes_list = plt.subplots(
nrows=number_of_rows,
ncols=self.settings.number_of_columns,
squeeze=False,
_number_of_columns = (
math.ceil(math.sqrt(number_of_plots))
if number_of_columns < 1
else number_of_columns
)
number_of_rows = math.ceil(number_of_plots / _number_of_columns)

_fig, axes_list = plt.subplots(
nrows=number_of_rows,
ncols=_number_of_columns,
squeeze=False,
)
_point_plotters = [
ContactPointStatePlotter(
ContactPointStatePlotterSettings(input_axes=el, terrain=terrain)
)
self.number_of_rows = number_of_rows
self.point_plotters = [
ContactPointStatePlotter(
ContactPointStatePlotterSettings(
input_axes=el, terrain=self.settings.terrain
)
)
for row in axes_list
for el in row
]
assert len(self.point_plotters) == number_of_plots
for row in axes_list
for el in row
]
assert len(_point_plotters) == number_of_plots

for p in range(number_of_plots):
contact_states = [state[p] for state in states]
self.point_plotters[p].plot_complementarity(
states=contact_states, time_s=_time_s
_point_plotters[p].plot_complementarity(
states=contact_states, time_s=time_s
)

self.fig.suptitle(title)
_fig.suptitle(title)
plt.draw()
plt.pause(0.001)
plt.show()
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,13 @@ def _visualize_multiple_states(
folder_name = f"{self._settings.working_folder}/{file_name_stem}_frames"
pathlib.Path(folder_name).mkdir(parents=True, exist_ok=True)

if save:
self._logger.info(
f"Saving visualization frames in {folder_name}. "
"Make sure to have the visualizer open, "
"otherwise the process will hang."
)

for i, state in enumerate(states):
self._logger.info(f"Visualizing state {i + 1}/{len(states)}")
start = time.time()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,10 +395,28 @@ def get_references(
output = planner.solve()

humanoid_states = [s.to_humanoid_state() for s in output.values.system]

left_contact_points = [s.contact_points.left for s in humanoid_states]
right_contact_points = [s.contact_points.right for s in humanoid_states]
print("Press [Enter] to visualize the solution.")
input()

plotter_settings = hp_rp.FootContactStatePlotterSettings()
plotter_settings.terrain = planner_settings.terrain
left_complementarity_plotter = hp_rp.FootContactStatePlotter(plotter_settings)
left_complementarity_plotter.plot_complementarity(
states=left_contact_points,
time_s=output.values.dt,
title="Left Foot Complementarity",
blocking=False,
)
right_complementarity_plotter = hp_rp.FootContactStatePlotter(plotter_settings)
right_complementarity_plotter.plot_complementarity(
states=right_contact_points,
time_s=output.values.dt,
title="Right Foot Complementarity",
blocking=False,
)

visualizer.visualize(
state=humanoid_states, timestep_s=output.values.dt, time_multiplier=2.0
)
Expand Down

0 comments on commit e2ddbcb

Please sign in to comment.