Recalculate Returns and Advantages After Callback to Ensure Reward Consistency (common/on_policy_algorithm.py) #2000

Open · wants to merge 4 commits into base: master
8 changes: 7 additions & 1 deletion stable_baselines3/a2c/a2c.py
@@ -155,8 +155,12 @@ def train(self) -> None:
             if self.normalize_advantage:
                 advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

-            # Policy gradient loss
+            # Policy gradient
+            add_loss = None
             policy_loss = -(advantages * log_prob).mean()
+            if self.has_additional_loss:
+                add_loss = self._calculate_additional_loss(rollout_data.observations, log_prob).mean()
+                policy_loss += add_loss

             # Value loss using the TD(gae_lambda) target
             value_loss = F.mse_loss(rollout_data.returns, values)
@@ -188,6 +192,8 @@ def train(self) -> None:
         self.logger.record("train/value_loss", value_loss.item())
         if hasattr(self.policy, "log_std"):
             self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
+        if add_loss is not None:
+            self.logger.record(f"train/{self.additional_loss_name}", add_loss.item())

     def learn(
         self: SelfA2C,
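For reference, a compatible additional-loss function for the on-policy path above could look like the sketch below; the name confidence_penalty and the 0.01 coefficient are illustrative only, not part of this PR. A2C (and PPO further down) pass the minibatch observations and the log-probabilities of the taken actions, and the algorithm calls .mean() on the returned tensor, so returning per-sample values is fine.

import torch as th

def confidence_penalty(observations: th.Tensor, log_prob: th.Tensor) -> th.Tensor:
    # Entropy-like regularizer: minimizing the mean log-probability of the
    # chosen actions pushes the policy toward higher entropy. The observations
    # argument is accepted to match the expected signature but unused here.
    return 0.01 * log_prob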
31 changes: 29 additions & 2 deletions stable_baselines3/common/base_class.py
@@ -6,7 +6,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from collections import deque
-from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union

 import gymnasium as gym
 import numpy as np
@@ -22,7 +22,14 @@
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first
 from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
-from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule, TensorDict
+from stable_baselines3.common.type_aliases import (
+    GymEnv,
+    MaybeCallback,
+    ReplayBufferSamples,
+    RolloutBufferSamples,
+    Schedule,
+    TensorDict,
+)
 from stable_baselines3.common.utils import (
     check_for_correct_spaces,
     get_device,
@@ -199,6 +206,9 @@ def __init__(
                 np.isfinite(np.array([self.action_space.low, self.action_space.high]))
             ), "Continuous action space must have a finite lower and upper bound"

+        # in order to initialize values
+        self.remove_additional_loss()
+
     @staticmethod
     def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> VecEnv:
         """ "
@@ -864,3 +874,20 @@ def save(
         params_to_save = self.get_parameters()

         save_to_zip_file(path, data=data, params=params_to_save, pytorch_variables=pytorch_variables)
+
+    def set_additional_loss(
+        self,
+        loss_fn: Callable[[th.Tensor, th.Tensor], th.Tensor],
+        name: str,
+    ):
+        self.has_additional_loss = True
+        self.additional_loss_func = loss_fn
+        self.additional_loss_name = name if name.endswith("_loss") else f"{name}_loss"
+
+    def remove_additional_loss(self):
+        self.has_additional_loss = False
+        self.additional_loss_func = None
+        self.additional_loss_name = None
+
+    def _calculate_additional_loss(self, observations: th.Tensor, logits: th.Tensor) -> th.Tensor:
+        return self.additional_loss_func(observations, logits) if self.has_additional_loss else th.Tensor(0)
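To illustrate how the new base-class hooks fit together, here is a hypothetical usage sketch; l2_log_prob_reg, its coefficient, and the environment id are examples, not part of this PR. set_additional_loss stores the function and normalizes the logging name (appending "_loss" when it is missing), while remove_additional_loss restores the disabled default set in __init__.

import torch as th
from stable_baselines3 import PPO

def l2_log_prob_reg(observations: th.Tensor, log_prob: th.Tensor) -> th.Tensor:
    # Small L2 penalty on the action log-probabilities of the current minibatch.
    return 1e-3 * log_prob.pow(2)

model = PPO("MlpPolicy", "CartPole-v1", verbose=1)
model.set_additional_loss(l2_log_prob_reg, name="l2_reg")  # logged as "train/l2_reg_loss"
model.learn(total_timesteps=10_000)
model.remove_additional_loss()  # subsequent updates use only the standard PPO loss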
16 changes: 15 additions & 1 deletion stable_baselines3/common/callbacks.py
@@ -5,8 +5,9 @@

 import gymnasium as gym
 import numpy as np
-
+import torch as th
 from stable_baselines3.common.logger import Logger
+from stable_baselines3.common.type_aliases import ReplayBufferSamples, RolloutBufferSamples

 try:
     from tqdm import TqdmExperimentalWarning
@@ -125,6 +126,19 @@ def on_rollout_end(self) -> None:
     def _on_rollout_end(self) -> None:
         pass

+    def on_update_loss(
+        self,
+        samples: Union[RolloutBufferSamples, ReplayBufferSamples],
+    ) -> th.Tensor:
+        self.is_rollout_buffer = isinstance(samples, RolloutBufferSamples)
+        return self._on_update_loss(samples)
+
+    def _on_update_loss(
+        self,
+        samples: Union[RolloutBufferSamples, ReplayBufferSamples],
+    ) -> th.Tensor:
+        pass
+
     def update_locals(self, locals_: Dict[str, Any]) -> None:
         """
         Update the references to the local variables.
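The new on_update_loss hook is defined here but not invoked anywhere in the diff shown, so the sketch below only illustrates the intended shape of a subclass, assuming the training loop would eventually call callback.on_update_loss(samples) with the current batch; the class name and coefficients are illustrative.

import torch as th
from stable_baselines3.common.callbacks import BaseCallback

class ObservationNormPenalty(BaseCallback):
    def _on_step(self) -> bool:
        # Required by BaseCallback; returning True keeps training going.
        return True

    def _on_update_loss(self, samples) -> th.Tensor:
        # `self.is_rollout_buffer` is set by `on_update_loss` before this is called,
        # so the penalty can be scaled differently for on-policy and off-policy batches.
        coef = 1e-3 if self.is_rollout_buffer else 1e-4
        return coef * samples.observations.pow(2).mean()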
1 change: 1 addition & 0 deletions stable_baselines3/common/on_policy_algorithm.py
@@ -242,6 +242,7 @@ def collect_rollouts(

         callback.on_rollout_end()

+
         return True

     def train(self) -> None:
10 changes: 9 additions & 1 deletion stable_baselines3/ppo/ppo.py
@@ -198,6 +198,7 @@ def train(self) -> None:
         entropy_losses = []
         pg_losses, value_losses = [], []
         clip_fractions = []
+        additional_losses = []

         continue_training = True
         # train for n_epochs epochs
@@ -252,8 +253,12 @@ def train(self) -> None:
                 entropy_loss = -th.mean(entropy)

                 entropy_losses.append(entropy_loss.item())
-
+                add_loss = None
                 loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
+                if self.has_additional_loss:
+                    add_loss = self._calculate_additional_loss(rollout_data.observations, log_prob).mean()
+                    loss += add_loss
+                    additional_losses.append(add_loss.item())

                 # Calculate approximate form of reverse KL Divergence for early stopping
                 # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
@@ -299,6 +304,9 @@ def train(self) -> None:
         if self.clip_range_vf is not None:
             self.logger.record("train/clip_range_vf", clip_range_vf)

+        if len(additional_losses) > 0:
+            self.logger.record(f"train/{self.additional_loss_name}", np.mean(additional_losses))
+
     def learn(
         self: SelfPPO,
         total_timesteps: int,
8 changes: 8 additions & 0 deletions stable_baselines3/sac/sac.py
@@ -209,6 +209,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:

         ent_coef_losses, ent_coefs = [], []
         actor_losses, critic_losses = [], []
+        additional_losses = []

         for gradient_step in range(gradient_steps):
             # Sample replay buffer
@@ -270,9 +271,14 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
             # Compute actor loss
             # Alternative: actor_loss = th.mean(log_prob - qf1_pi)
             # Min over all critic networks
+            add_loss = None
             q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1)
             min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
             actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
+            if self.has_additional_loss:
+                add_loss = self._calculate_additional_loss(replay_data.observations, actions_pi).mean()
+                actor_loss += add_loss
+                additional_losses.append(add_loss.item())
             actor_losses.append(actor_loss.item())

             # Optimize the actor
@@ -294,6 +300,8 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
         self.logger.record("train/critic_loss", np.mean(critic_losses))
         if len(ent_coef_losses) > 0:
             self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
+        if len(additional_losses):
+            self.logger.record(f"train/{self.additional_loss_name}", np.mean(additional_losses))

     def learn(
         self: SelfSAC,
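For the off-policy algorithms, the second argument handed to _calculate_additional_loss is an action tensor (the freshly sampled actions_pi for SAC, the actor output named logits in the TD3 hunk below), so a compatible loss might penalize action magnitude. The function name and coefficient here are illustrative, not part of this PR.

import torch as th

def action_magnitude_penalty(observations: th.Tensor, actions: th.Tensor) -> th.Tensor:
    # Discourage saturated actions near the bounds of a Box action space.
    return 1e-2 * actions.pow(2).sum(dim=-1)

# Attached the same way as any other additional loss, e.g.:
# model.set_additional_loss(action_magnitude_penalty, "action_l2")  # logged as "train/action_l2_loss"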
11 changes: 10 additions & 1 deletion stable_baselines3/td3/td3.py
@@ -159,6 +159,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
         self._update_learning_rate([self.actor.optimizer, self.critic.optimizer])

         actor_losses, critic_losses = [], []
+        additional_losses = []
         for _ in range(gradient_steps):
             self._n_updates += 1
             # Sample replay buffer
@@ -191,7 +192,13 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
             # Delayed policy updates
             if self._n_updates % self.policy_delay == 0:
                 # Compute actor loss
-                actor_loss = -self.critic.q1_forward(replay_data.observations, self.actor(replay_data.observations)).mean()
+                add_loss = None
+                logits = self.actor(replay_data.observations)
+                actor_loss = -self.critic.q1_forward(replay_data.observations, logits).mean()
+                if self.has_additional_loss:
+                    add_loss = self._calculate_additional_loss(replay_data.observations, logits).mean()
+                    actor_loss += add_loss
+                    additional_losses.append(add_loss.item())
                 actor_losses.append(actor_loss.item())

                 # Optimize the actor
@@ -209,6 +216,8 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
         if len(actor_losses) > 0:
             self.logger.record("train/actor_loss", np.mean(actor_losses))
         self.logger.record("train/critic_loss", np.mean(critic_losses))
+        if len(additional_losses) > 0:
+            self.logger.record(f"train/{self.additional_loss_name}", np.mean(additional_losses))

     def learn(
         self: SelfTD3,