Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support pyg #1492

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
218 changes: 214 additions & 4 deletions stable_baselines3/common/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@
import torch as th
from gymnasium import spaces

from torch_geometric.data import Data, Batch

from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
from stable_baselines3.common.type_aliases import (
DictReplayBufferSamples,
DictRolloutBufferSamples,
ReplayBufferSamples,
RolloutBufferSamples,
GraphRolloutBufferSamples,
)
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env.util import dict_to_obs, graph_copy_obs_dict
from stable_baselines3.common.vec_env import VecNormalize

try:
Expand Down Expand Up @@ -439,7 +443,6 @@ def add(

# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))

self.observations[self.pos] = np.array(obs).copy()
self.actions[self.pos] = np.array(action).copy()
self.rewards[self.pos] = np.array(reward).copy()
Expand Down Expand Up @@ -493,6 +496,214 @@ def _get_samples(
return RolloutBufferSamples(*tuple(map(self.to_torch, data)))


class GraphRolloutBuffer(BaseBuffer):
"""
Rollout buffer used in on-policy algorithms like A2C/PPO.
It corresponds to ``buffer_size`` transitions collected
using the current policy.
This experience will be discarded after the policy update.
In order to use PPO objective, we also store the current value of each state
and the log probability of each taken action.

The term rollout here refers to the model-free notion and should not
be used with the concept of rollout used in model-based RL or planning.
Hence, it is only involved in policy and value function training but not action selection.

:param buffer_size: Max number of element in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param device: PyTorch device
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
Equivalent to classic advantage when set to 1.
:param gamma: Discount factor
:param n_envs: Number of parallel environments
"""

observations: np.ndarray
actions: np.ndarray
rewards: np.ndarray
advantages: np.ndarray
returns: np.ndarray
episode_starts: np.ndarray
log_probs: np.ndarray
values: np.ndarray

def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
):
super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
assert isinstance(self.observation_space, spaces.Graph), "Graph buffer"

self.gae_lambda = gae_lambda
self.gamma = gamma
self.generator_ready = False
self.observations = {"node": {}, "edge_weight": {}, "edge_index": {}}
self.reset()

def reset(self) -> None:
self.observations = {"node": {}, "edge_weight": {}, "edge_index": {}} # variable size
self.actions = {}
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.log_probs = {}
self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.generator_ready = False
super().reset()

def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
"""
Post-processing step: compute the lambda-return (TD(lambda) estimate)
and GAE(lambda) advantage.

Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
to compute the advantage. To obtain Monte-Carlo advantage estimate (A(s) = R - V(S))
where R is the sum of discounted reward with value bootstrap
(because we don't always have full episode), set ``gae_lambda=1.0`` during initialization.

The TD(lambda) estimator has also two special cases:
- TD(1) is Monte-Carlo estimate (sum of discounted rewards)
- TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))

For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.

:param last_values: state value estimation for the last step (one for each env)
:param dones: if the last step was a terminal step (one bool for each env).
"""
# Convert to numpy
last_values = last_values.clone().cpu().numpy().flatten()

last_gae_lam = 0
for step in reversed(range(self.buffer_size)):
if step == self.buffer_size - 1:
next_non_terminal = 1.0 - dones
next_values = last_values
else:
next_non_terminal = 1.0 - self.episode_starts[step + 1]
next_values = self.values[step + 1]
delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
self.advantages[step] = last_gae_lam
# TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
# in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
self.returns = self.advantages + self.values

def add(
self,
obs: np.ndarray,
action: np.ndarray,
reward: np.ndarray,
episode_start: np.ndarray,
value: th.Tensor,
log_prob: th.Tensor,
) -> None:
"""
:param obs: Observation
:param action: Action -- assumed to be ndarray by clipping
:param reward:
:param episode_start: Start of episode signal.
:param value: estimated value of the current state
following the current policy.
:param log_prob: log probability of the action
following the current policy.
"""
if isinstance(log_prob, List):
if len(log_prob) == 1:
log_prob = log_prob[0].cpu()
else:
raise NotImplementedError
if len(log_prob.shape) == 0:
# Reshape 0-d tensor to avoid error
log_prob = log_prob.reshape(-1, 1)

# Reshape needed when using multiple envs with discrete observations
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
if isinstance(self.observation_space, spaces.Discrete):
obs = obs.reshape((self.n_envs, *self.obs_shape))

# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
if isinstance(action, List):
if len(action) == 1:
action = action[0]
else:
# Probably loop trough and add each entry independently into the buffer
raise NotImplementedError
else:
action = action.reshape((self.n_envs, self.action_dim))

assert isinstance(obs, Data)
self.observations["node"][self.pos] = obs.x # should be a pyg Data entry
self.observations["edge_index"][self.pos] = obs.edge_index
self.observations["edge_weight"][self.pos] = obs.w

self.actions[self.pos] = np.array(action).copy()
self.rewards[self.pos] = np.array(reward).copy()
self.episode_starts[self.pos] = np.array(episode_start).copy()
self.values[self.pos] = value.clone().cpu().numpy().flatten()
self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
self.pos += 1
if self.pos == self.buffer_size:
self.full = True

def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
assert self.full, ""
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data

if not self.generator_ready:
self.observations = dict_to_obs(self.observation_space, self.observations)
assert all([isinstance(k, int) for k in self.actions.keys()]), f"Action not indexed correctly"
self.actions_flat = np.stack([self.actions[i] for i in range(len(self.actions.keys()))])
self.actions = self.actions_flat

assert all([isinstance(k, int) for k in self.log_probs.keys()]), f"Action not indexed correctly"
self.log_probs_flat = np.stack([self.log_probs[i] for i in range(len(self.log_probs.keys()))])
self.log_probs = self.log_probs_flat

_tensor_names = [
"values",
"advantages",
"returns",
]

for tensor in _tensor_names:
self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
self.generator_ready = True

# Return everything, don't create minibatches
if batch_size is None:
batch_size = self.buffer_size * self.n_envs

start_idx = 0
while start_idx < self.buffer_size * self.n_envs:
yield self._get_samples(indices[start_idx : start_idx + batch_size])
start_idx += batch_size

def _get_samples(
self,
batch_inds: np.ndarray,
env: Optional[VecNormalize] = None,
) -> RolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME
data = (
Batch.from_data_list(self.observations[batch_inds]),
self.to_torch(self.actions[batch_inds]),
self.to_torch(self.values[batch_inds].flatten()),
self.to_torch(self.log_probs[batch_inds]),
self.to_torch(self.advantages[batch_inds]),
self.to_torch(self.returns[batch_inds].flatten()),
)

return GraphRolloutBufferSamples(*tuple(data))


class DictReplayBuffer(ReplayBuffer):
"""
Dict Replay buffer used in off-policy algorithms like SAC/TD3.
Expand Down Expand Up @@ -681,8 +892,6 @@ class DictRolloutBuffer(RolloutBuffer):
:param n_envs: Number of parallel environments
"""

observations: Dict[str, np.ndarray]

def __init__(
self,
buffer_size: int,
Expand All @@ -699,7 +908,8 @@ def __init__(

self.gae_lambda = gae_lambda
self.gamma = gamma

self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
self.generator_ready = False
self.reset()

Expand Down
49 changes: 44 additions & 5 deletions stable_baselines3/common/on_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from gymnasium import spaces

from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer, GraphRolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
Expand Down Expand Up @@ -100,15 +100,20 @@ def __init__(
self.ent_coef = ent_coef
self.vf_coef = vf_coef
self.max_grad_norm = max_grad_norm
self.episodes = 0

if _init_setup_model:
self._setup_model()

def _setup_model(self) -> None:
self._setup_lr_schedule()
self.set_random_seed(self.seed)

buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RolloutBuffer
if isinstance(self.observation_space, spaces.Dict):
buffer_cls = DictRolloutBuffer
elif isinstance(self.observation_space, spaces.Graph):
buffer_cls = GraphRolloutBuffer
else:
buffer_cls = RolloutBuffer

self.rollout_buffer = buffer_cls(
self.n_steps,
Expand Down Expand Up @@ -158,6 +163,8 @@ def collect_rollouts(

callback.on_rollout_start()

self.episodes = 0

while n_steps < n_rollout_steps:
if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
# Sample a new noise matrix
Expand All @@ -167,13 +174,32 @@ def collect_rollouts(
# Convert to pytorch tensor or to TensorDict
obs_tensor = obs_as_tensor(self._last_obs, self.device)
actions, values, log_probs = self.policy(obs_tensor)
actions = actions.cpu().numpy()
actions, log_probs = actions, log_probs

# Rescale and perform action
clipped_actions = actions
# Clip the actions to avoid out of bound error
if isinstance(self.action_space, spaces.Box):
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
if isinstance(actions, List):
clipped_actions = [
th.clamp(
a,
min=th.Tensor(self.action_space.low).to(a.device),
max=th.Tensor(self.action_space.high).to(a.device),
)
for a in actions
]
actions = [a.cpu().numpy() for a in actions]
else:
clipped_actions = th.clamp(
actions,
min=th.Tensor(self.action_space.low).to(actions.device),
max=th.Tensor(self.action_space.high).to(actions.device),
)
actions = actions.cpu().numpy()
else:
clipped_actions = actions.cpu().numpy()
actions = actions.cpu().numpy()

new_obs, rewards, dones, infos = env.step(clipped_actions)

Expand Down Expand Up @@ -203,6 +229,8 @@ def collect_rollouts(
with th.no_grad():
terminal_value = self.policy.predict_values(terminal_obs)[0] # type: ignore[arg-type]
rewards[idx] += self.gamma * terminal_value
if done:
self.episodes += 1

rollout_buffer.add(
self._last_obs, # type: ignore[arg-type]
Expand Down Expand Up @@ -273,6 +301,17 @@ def learn(
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))

diff = self.episodes
self.logger.record("result/success", sum([ep_info["success"] for ep_info in self.ep_info_buffer][-diff:]))
self.logger.record("result/failed", sum([ep_info["failed"] for ep_info in self.ep_info_buffer][-diff:]))
self.logger.record(
"result/truncated", sum([ep_info["truncated"] for ep_info in self.ep_info_buffer][-diff:])
)
self.logger.record(
"result/terminated", sum([ep_info["terminated"] for ep_info in self.ep_info_buffer][-diff:])
)

self.logger.record("time/fps", fps)
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
Expand Down
Loading