Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mppi goal to critic #4822

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ namespace mppi

/**
* @struct mppi::CriticData
* @brief Data to pass to critics for scoring, including state, trajectories, path, costs, and
* important parameters to share
* @brief Data to pass to critics for scoring, including state, trajectories,
* pruned path, global goal, costs, and important parameters to share
*/
struct CriticData
{
const models::State & state;
const models::Trajectories & trajectories;
const models::Path & path;
const geometry_msgs::msg::Pose & goal;

xt::xtensor<float, 1> & costs;
float & model_dt;
Expand Down
11 changes: 7 additions & 4 deletions nav2_mppi_controller/include/nav2_mppi_controller/optimizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class Optimizer
geometry_msgs::msg::TwistStamped evalControl(
const geometry_msgs::msg::PoseStamped & robot_pose,
const geometry_msgs::msg::Twist & robot_speed, const nav_msgs::msg::Path & plan,
nav2_core::GoalChecker * goal_checker);
const geometry_msgs::msg::Pose & goal, nav2_core::GoalChecker * goal_checker);

/**
* @brief Get the trajectories generated in a cycle for visualization
Expand Down Expand Up @@ -138,7 +138,8 @@ class Optimizer
void prepare(
const geometry_msgs::msg::PoseStamped & robot_pose,
const geometry_msgs::msg::Twist & robot_speed,
const nav_msgs::msg::Path & plan, nav2_core::GoalChecker * goal_checker);
const nav_msgs::msg::Path & plan,
const geometry_msgs::msg::Pose & goal, nav2_core::GoalChecker * goal_checker);

/**
* @brief Obtain the main controller's parameters
Expand Down Expand Up @@ -256,10 +257,12 @@ class Optimizer
std::array<mppi::models::Control, 4> control_history_;
models::Trajectories generated_trajectories_;
models::Path path_;
geometry_msgs::msg::Pose goal_;
xt::xtensor<float, 1> costs_;

CriticData critics_data_ =
{state_, generated_trajectories_, path_, costs_, settings_.model_dt, false, nullptr, nullptr,
CriticData critics_data_ = {
state_, generated_trajectories_, path_, goal_,
costs_, settings_.model_dt, false, nullptr, nullptr,
std::nullopt, std::nullopt}; /// Caution, keep references

rclcpp::Logger logger_{rclcpp::get_logger("MPPIController")};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ class PathHandler
*/
nav_msgs::msg::Path transformPath(const geometry_msgs::msg::PoseStamped & robot_pose);

/**
* @brief Get the global goal pose transformed to the local frame
* @return Transformed goal pose
*/
geometry_msgs::msg::PoseStamped getTransformedGoal();

protected:
/**
* @brief Transform a pose to another frame
Expand Down
27 changes: 9 additions & 18 deletions nav2_mppi_controller/include/nav2_mppi_controller/tools/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,27 +204,23 @@ inline models::Path toTensor(const nav_msgs::msg::Path & path)
* @brief Check if the robot pose is within the Goal Checker's tolerances to goal
* @param global_checker Pointer to the goal checker
* @param robot Pose of robot
* @param path Path to retrieve goal pose from
* @param goal Goal pose
* @return bool If robot is within goal checker tolerances to the goal
*/
inline bool withinPositionGoalTolerance(
nav2_core::GoalChecker * goal_checker,
const geometry_msgs::msg::Pose & robot,
const models::Path & path)
const geometry_msgs::msg::Pose & goal)
{
const auto goal_idx = path.x.shape(0) - 1;
const auto goal_x = path.x(goal_idx);
const auto goal_y = path.y(goal_idx);

if (goal_checker) {
geometry_msgs::msg::Pose pose_tolerance;
geometry_msgs::msg::Twist velocity_tolerance;
goal_checker->getTolerances(pose_tolerance, velocity_tolerance);

const auto pose_tolerance_sq = pose_tolerance.position.x * pose_tolerance.position.x;

auto dx = robot.position.x - goal_x;
auto dy = robot.position.y - goal_y;
auto dx = robot.position.x - goal.position.x;
auto dy = robot.position.y - goal.position.y;

auto dist_sq = dx * dx + dy * dy;

Expand All @@ -240,25 +236,20 @@ inline bool withinPositionGoalTolerance(
* @brief Check if the robot pose is within tolerance to the goal
* @param pose_tolerance Pose tolerance to use
* @param robot Pose of robot
* @param path Path to retrieve goal pose from
* @param goal Goal pose
* @return bool If robot is within tolerance to the goal
*/
inline bool withinPositionGoalTolerance(
float pose_tolerance,
const geometry_msgs::msg::Pose & robot,
const models::Path & path)
const geometry_msgs::msg::Pose & goal)
{
const auto goal_idx = path.x.shape(0) - 1;
const float goal_x = path.x(goal_idx);
const float goal_y = path.y(goal_idx);
const double & dist_sq =
std::pow(goal.position.x - robot.position.x, 2) +
std::pow(goal.position.y - robot.position.y, 2);

const float pose_tolerance_sq = pose_tolerance * pose_tolerance;

const float dx = static_cast<float>(robot.position.x) - goal_x;
const float dy = static_cast<float>(robot.position.y) - goal_y;

float dist_sq = dx * dx + dy * dy;

if (dist_sq < pose_tolerance_sq) {
return true;
}
Expand Down
4 changes: 3 additions & 1 deletion nav2_mppi_controller/src/controller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,15 @@ geometry_msgs::msg::TwistStamped MPPIController::computeVelocityCommands(
#endif

std::lock_guard<std::mutex> param_lock(*parameters_handler_->getLock());
geometry_msgs::msg::Pose goal = path_handler_.getTransformedGoal().pose;

nav_msgs::msg::Path transformed_plan = path_handler_.transformPath(robot_pose);

nav2_costmap_2d::Costmap2D * costmap = costmap_ros_->getCostmap();
std::unique_lock<nav2_costmap_2d::Costmap2D::mutex_t> costmap_lock(*(costmap->getMutex()));

geometry_msgs::msg::TwistStamped cmd =
optimizer_.evalControl(robot_pose, robot_speed, transformed_plan, goal_checker);
optimizer_.evalControl(robot_pose, robot_speed, transformed_plan, goal, goal_checker);

#ifdef BENCHMARK_TESTING
auto end = std::chrono::system_clock::now();
Expand Down
2 changes: 1 addition & 1 deletion nav2_mppi_controller/src/critics/cost_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ void CostCritic::score(CriticData & data)

// If near the goal, don't apply the preferential term since the goal is near obstacles
bool near_goal = false;
if (utils::withinPositionGoalTolerance(near_goal_distance_, data.state.pose.pose, data.path)) {
if (utils::withinPositionGoalTolerance(near_goal_distance_, data.state.pose.pose, data.goal)) {
near_goal = true;
}

Expand Down
2 changes: 1 addition & 1 deletion nav2_mppi_controller/src/critics/goal_angle_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void GoalAngleCritic::initialize()
void GoalAngleCritic::score(CriticData & data)
{
if (!enabled_ || !utils::withinPositionGoalTolerance(
threshold_to_consider_, data.state.pose.pose, data.path))
threshold_to_consider_, data.state.pose.pose, data.goal))
{
return;
}
Expand Down
8 changes: 3 additions & 5 deletions nav2_mppi_controller/src/critics/goal_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,13 @@ void GoalCritic::initialize()
void GoalCritic::score(CriticData & data)
{
if (!enabled_ || !utils::withinPositionGoalTolerance(
threshold_to_consider_, data.state.pose.pose, data.path))
threshold_to_consider_, data.state.pose.pose, data.goal))
{
return;
}

const auto goal_idx = data.path.x.shape(0) - 1;

const auto goal_x = data.path.x(goal_idx);
const auto goal_y = data.path.y(goal_idx);
const auto & goal_x = data.goal.position.x;
const auto & goal_y = data.goal.position.y;

const auto traj_x = xt::view(data.trajectories.x, xt::all(), xt::all());
const auto traj_y = xt::view(data.trajectories.y, xt::all(), xt::all());
Expand Down
2 changes: 1 addition & 1 deletion nav2_mppi_controller/src/critics/obstacles_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ void ObstaclesCritic::score(CriticData & data)

// If near the goal, don't apply the preferential term since the goal is near obstacles
bool near_goal = false;
if (utils::withinPositionGoalTolerance(near_goal_distance_, data.state.pose.pose, data.path)) {
if (utils::withinPositionGoalTolerance(near_goal_distance_, data.state.pose.pose, data.goal)) {
near_goal = true;
}

Expand Down
4 changes: 2 additions & 2 deletions nav2_mppi_controller/src/critics/path_align_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ void PathAlignCritic::initialize()
void PathAlignCritic::score(CriticData & data)
{
// Don't apply close to goal, let the goal critics take over
if (!enabled_ ||
utils::withinPositionGoalTolerance(threshold_to_consider_, data.state.pose.pose, data.path))
if (!enabled_ || utils::withinPositionGoalTolerance(
threshold_to_consider_, data.state.pose.pose, data.goal))
{
return;
}
Expand Down
2 changes: 1 addition & 1 deletion nav2_mppi_controller/src/critics/path_angle_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ void PathAngleCritic::initialize()
void PathAngleCritic::score(CriticData & data)
{
if (!enabled_ ||
utils::withinPositionGoalTolerance(threshold_to_consider_, data.state.pose.pose, data.path))
utils::withinPositionGoalTolerance(threshold_to_consider_, data.state.pose.pose, data.goal))
{
return;
}
Expand Down
2 changes: 1 addition & 1 deletion nav2_mppi_controller/src/critics/path_follow_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void PathFollowCritic::initialize()
void PathFollowCritic::score(CriticData & data)
{
if (!enabled_ || data.path.x.shape(0) < 2 ||
utils::withinPositionGoalTolerance(threshold_to_consider_, data.state.pose.pose, data.path))
utils::withinPositionGoalTolerance(threshold_to_consider_, data.state.pose.pose, data.goal))
{
return;
}
Expand Down
4 changes: 2 additions & 2 deletions nav2_mppi_controller/src/critics/prefer_forward_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ void PreferForwardCritic::initialize()
void PreferForwardCritic::score(CriticData & data)
{
using xt::evaluation_strategy::immediate;
if (!enabled_ ||
utils::withinPositionGoalTolerance(threshold_to_consider_, data.state.pose.pose, data.path))
if (!enabled_ || utils::withinPositionGoalTolerance(
threshold_to_consider_, data.state.pose.pose, data.goal))
{
return;
}
Expand Down
2 changes: 1 addition & 1 deletion nav2_mppi_controller/src/critics/twirling_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ void TwirlingCritic::score(CriticData & data)
{
using xt::evaluation_strategy::immediate;
if (!enabled_ ||
utils::withinPositionGoalTolerance(data.goal_checker, data.state.pose.pose, data.path))
utils::withinPositionGoalTolerance(data.goal_checker, data.state.pose.pose, data.goal))
{
return;
}
Expand Down
11 changes: 8 additions & 3 deletions nav2_mppi_controller/src/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,11 @@ bool Optimizer::isHolonomic() const
geometry_msgs::msg::TwistStamped Optimizer::evalControl(
const geometry_msgs::msg::PoseStamped & robot_pose,
const geometry_msgs::msg::Twist & robot_speed,
const nav_msgs::msg::Path & plan, nav2_core::GoalChecker * goal_checker)
const nav_msgs::msg::Path & plan,
const geometry_msgs::msg::Pose & goal,
nav2_core::GoalChecker * goal_checker)
{
prepare(robot_pose, robot_speed, plan, goal_checker);
prepare(robot_pose, robot_speed, plan, goal, goal_checker);

do {
optimize();
Expand Down Expand Up @@ -201,12 +203,15 @@ bool Optimizer::fallback(bool fail)
void Optimizer::prepare(
const geometry_msgs::msg::PoseStamped & robot_pose,
const geometry_msgs::msg::Twist & robot_speed,
const nav_msgs::msg::Path & plan, nav2_core::GoalChecker * goal_checker)
const nav_msgs::msg::Path & plan,
const geometry_msgs::msg::Pose & goal,
nav2_core::GoalChecker * goal_checker)
{
state_.pose = robot_pose;
state_.speed = robot_speed;
path_ = utils::toTensor(plan);
costs_.fill(0.0f);
goal_ = goal;

critics_data_.fail_flag = false;
critics_data_.goal_checker = goal_checker;
Expand Down
14 changes: 14 additions & 0 deletions nav2_mppi_controller/src/path_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,20 @@ void PathHandler::prunePlan(nav_msgs::msg::Path & plan, const PathIterator end)
plan.poses.erase(plan.poses.begin(), end);
}

geometry_msgs::msg::PoseStamped PathHandler::getTransformedGoal()
{
auto goal = global_plan_.poses.back();
goal.header.stamp = global_plan_.header.stamp;
if (goal.header.frame_id.empty()) {
throw nav2_core::ControllerTFError("Goal pose has an empty frame_id");
}
geometry_msgs::msg::PoseStamped transformed_goal;
if (!transformPose(costmap_->getGlobalFrameID(), goal, transformed_goal)) {
throw nav2_core::ControllerTFError("Unable to transform goal pose into costmap frame");
}
return transformed_goal;
}

bool PathHandler::isWithinInversionTolerances(const geometry_msgs::msg::PoseStamped & robot_pose)
{
// Keep full path if we are within tolerance of the inversion pose
Expand Down
3 changes: 2 additions & 1 deletion nav2_mppi_controller/test/critic_manager_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,11 @@ TEST(CriticManagerTests, BasicCriticOperations)
models::ControlSequence control_sequence;
models::Trajectories generated_trajectories;
models::Path path;
geometry_msgs::msg::Pose goal;
xt::xtensor<float, 1> costs;
float model_dt = 0.1;
CriticData data =
{state, generated_trajectories, path, costs, model_dt, false, nullptr, nullptr,
{state, generated_trajectories, path, goal, costs, model_dt, false, nullptr, nullptr,
std::nullopt, std::nullopt};

data.fail_flag = true;
Expand Down
Loading
Loading