diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index fe633e1af..75d506183 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -6,14 +6,18 @@ import torch as th from gymnasium import spaces +from torch_geometric.data import Data, Batch + from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape from stable_baselines3.common.type_aliases import ( DictReplayBufferSamples, DictRolloutBufferSamples, ReplayBufferSamples, RolloutBufferSamples, + GraphRolloutBufferSamples, ) from stable_baselines3.common.utils import get_device +from stable_baselines3.common.vec_env.util import dict_to_obs, graph_copy_obs_dict from stable_baselines3.common.vec_env import VecNormalize try: @@ -439,7 +443,6 @@ def add( # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392 action = action.reshape((self.n_envs, self.action_dim)) - self.observations[self.pos] = np.array(obs).copy() self.actions[self.pos] = np.array(action).copy() self.rewards[self.pos] = np.array(reward).copy() @@ -493,6 +496,214 @@ def _get_samples( return RolloutBufferSamples(*tuple(map(self.to_torch, data))) +class GraphRolloutBuffer(BaseBuffer): + """ + Rollout buffer used in on-policy algorithms like A2C/PPO. + It corresponds to ``buffer_size`` transitions collected + using the current policy. + This experience will be discarded after the policy update. + In order to use PPO objective, we also store the current value of each state + and the log probability of each taken action. + + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + Hence, it is only involved in policy and value function training but not action selection. + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: PyTorch device + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. 
+ :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + observations: np.ndarray + actions: np.ndarray + rewards: np.ndarray + advantages: np.ndarray + returns: np.ndarray + episode_starts: np.ndarray + log_probs: np.ndarray + values: np.ndarray + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[th.device, str] = "auto", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) + assert isinstance(self.observation_space, spaces.Graph), "Graph buffer" + + self.gae_lambda = gae_lambda + self.gamma = gamma + self.generator_ready = False + self.observations = {"node": {}, "edge_weight": {}, "edge_index": {}} + self.reset() + + def reset(self) -> None: + self.observations = {"node": {}, "edge_weight": {}, "edge_index": {}} # variable size + self.actions = {} + self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.log_probs = {} + self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.generator_ready = False + super().reset() + + def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None: + """ + Post-processing step: compute the lambda-return (TD(lambda) estimate) + and GAE(lambda) advantage. + + Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438) + to compute the advantage. To obtain Monte-Carlo advantage estimate (A(s) = R - V(S)) + where R is the sum of discounted reward with value bootstrap + (because we don't always have full episode), set ``gae_lambda=1.0`` during initialization. + + The TD(lambda) estimator has also two special cases: + - TD(1) is Monte-Carlo estimate (sum of discounted rewards) + - TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1})) + + For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375. + + :param last_values: state value estimation for the last step (one for each env) + :param dones: if the last step was a terminal step (one bool for each env). + """ + # Convert to numpy + last_values = last_values.clone().cpu().numpy().flatten() + + last_gae_lam = 0 + for step in reversed(range(self.buffer_size)): + if step == self.buffer_size - 1: + next_non_terminal = 1.0 - dones + next_values = last_values + else: + next_non_terminal = 1.0 - self.episode_starts[step + 1] + next_values = self.values[step + 1] + delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step] + last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam + self.advantages[step] = last_gae_lam + # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)" + # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA + self.returns = self.advantages + self.values + + def add( + self, + obs: np.ndarray, + action: np.ndarray, + reward: np.ndarray, + episode_start: np.ndarray, + value: th.Tensor, + log_prob: th.Tensor, + ) -> None: + """ + :param obs: Observation + :param action: Action -- assumed to be ndarray by clipping + :param reward: + :param episode_start: Start of episode signal. 
+ :param value: estimated value of the current state + following the current policy. + :param log_prob: log probability of the action + following the current policy. + """ + if isinstance(log_prob, List): + if len(log_prob) == 1: + log_prob = log_prob[0].cpu() + else: + raise NotImplementedError + if len(log_prob.shape) == 0: + # Reshape 0-d tensor to avoid error + log_prob = log_prob.reshape(-1, 1) + + # Reshape needed when using multiple envs with discrete observations + # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) + if isinstance(self.observation_space, spaces.Discrete): + obs = obs.reshape((self.n_envs, *self.obs_shape)) + + # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392 + if isinstance(action, List): + if len(action) == 1: + action = action[0] + else: + # Probably loop trough and add each entry independently into the buffer + raise NotImplementedError + else: + action = action.reshape((self.n_envs, self.action_dim)) + + assert isinstance(obs, Data) + self.observations["node"][self.pos] = obs.x # should be a pyg Data entry + self.observations["edge_index"][self.pos] = obs.edge_index + self.observations["edge_weight"][self.pos] = obs.w + + self.actions[self.pos] = np.array(action).copy() + self.rewards[self.pos] = np.array(reward).copy() + self.episode_starts[self.pos] = np.array(episode_start).copy() + self.values[self.pos] = value.clone().cpu().numpy().flatten() + self.log_probs[self.pos] = log_prob.clone().cpu().numpy() + self.pos += 1 + if self.pos == self.buffer_size: + self.full = True + + def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]: + assert self.full, "" + indices = np.random.permutation(self.buffer_size * self.n_envs) + # Prepare the data + + if not self.generator_ready: + self.observations = dict_to_obs(self.observation_space, self.observations) + assert all([isinstance(k, int) for k in self.actions.keys()]), f"Action not indexed correctly" + self.actions_flat = np.stack([self.actions[i] for i in range(len(self.actions.keys()))]) + self.actions = self.actions_flat + + assert all([isinstance(k, int) for k in self.log_probs.keys()]), f"Action not indexed correctly" + self.log_probs_flat = np.stack([self.log_probs[i] for i in range(len(self.log_probs.keys()))]) + self.log_probs = self.log_probs_flat + + _tensor_names = [ + "values", + "advantages", + "returns", + ] + + for tensor in _tensor_names: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + yield self._get_samples(indices[start_idx : start_idx + batch_size]) + start_idx += batch_size + + def _get_samples( + self, + batch_inds: np.ndarray, + env: Optional[VecNormalize] = None, + ) -> RolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME + data = ( + Batch.from_data_list(self.observations[batch_inds]), + self.to_torch(self.actions[batch_inds]), + self.to_torch(self.values[batch_inds].flatten()), + self.to_torch(self.log_probs[batch_inds]), + self.to_torch(self.advantages[batch_inds]), + self.to_torch(self.returns[batch_inds].flatten()), + ) + + return GraphRolloutBufferSamples(*tuple(data)) + + class DictReplayBuffer(ReplayBuffer): """ Dict Replay buffer used in off-policy algorithms like SAC/TD3. 
@@ -681,8 +892,6 @@ class DictRolloutBuffer(RolloutBuffer): :param n_envs: Number of parallel environments """ - observations: Dict[str, np.ndarray] - def __init__( self, buffer_size: int, @@ -699,7 +908,8 @@ def __init__( self.gae_lambda = gae_lambda self.gamma = gamma - + self.observations, self.actions, self.rewards, self.advantages = None, None, None, None + self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None self.generator_ready = False self.reset() diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 87e192990..76e2e4bba 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -7,7 +7,7 @@ from gymnasium import spaces from stable_baselines3.common.base_class import BaseAlgorithm -from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer +from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer, GraphRolloutBuffer from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.policies import ActorCriticPolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule @@ -100,6 +100,7 @@ def __init__( self.ent_coef = ent_coef self.vf_coef = vf_coef self.max_grad_norm = max_grad_norm + self.episodes = 0 if _init_setup_model: self._setup_model() @@ -107,8 +108,12 @@ def __init__( def _setup_model(self) -> None: self._setup_lr_schedule() self.set_random_seed(self.seed) - - buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RolloutBuffer + if isinstance(self.observation_space, spaces.Dict): + buffer_cls = DictRolloutBuffer + elif isinstance(self.observation_space, spaces.Graph): + buffer_cls = GraphRolloutBuffer + else: + buffer_cls = RolloutBuffer self.rollout_buffer = buffer_cls( self.n_steps, @@ -158,6 +163,8 @@ def collect_rollouts( callback.on_rollout_start() + self.episodes = 0 + while n_steps < n_rollout_steps: if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: # Sample a new noise matrix @@ -167,13 +174,32 @@ def collect_rollouts( # Convert to pytorch tensor or to TensorDict obs_tensor = obs_as_tensor(self._last_obs, self.device) actions, values, log_probs = self.policy(obs_tensor) - actions = actions.cpu().numpy() + actions, log_probs = actions, log_probs # Rescale and perform action clipped_actions = actions # Clip the actions to avoid out of bound error if isinstance(self.action_space, spaces.Box): - clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) + if isinstance(actions, List): + clipped_actions = [ + th.clamp( + a, + min=th.Tensor(self.action_space.low).to(a.device), + max=th.Tensor(self.action_space.high).to(a.device), + ) + for a in actions + ] + actions = [a.cpu().numpy() for a in actions] + else: + clipped_actions = th.clamp( + actions, + min=th.Tensor(self.action_space.low).to(actions.device), + max=th.Tensor(self.action_space.high).to(actions.device), + ) + actions = actions.cpu().numpy() + else: + clipped_actions = actions.cpu().numpy() + actions = actions.cpu().numpy() new_obs, rewards, dones, infos = env.step(clipped_actions) @@ -203,6 +229,8 @@ def collect_rollouts( with th.no_grad(): terminal_value = self.policy.predict_values(terminal_obs)[0] # type: ignore[arg-type] rewards[idx] += self.gamma * terminal_value + if done: + self.episodes += 1 rollout_buffer.add( self._last_obs, # type: ignore[arg-type] @@ -273,6 
+301,17 @@ def learn( if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) + + diff = self.episodes + self.logger.record("result/success", sum([ep_info["success"] for ep_info in self.ep_info_buffer][-diff:])) + self.logger.record("result/failed", sum([ep_info["failed"] for ep_info in self.ep_info_buffer][-diff:])) + self.logger.record( + "result/truncated", sum([ep_info["truncated"] for ep_info in self.ep_info_buffer][-diff:]) + ) + self.logger.record( + "result/terminated", sum([ep_info["terminated"] for ep_info in self.ep_info_buffer][-diff:]) + ) + self.logger.record("time/fps", fps) self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard") self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 21d2034d6..531c718aa 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -624,9 +624,13 @@ def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tenso # Evaluate the values for the given observations values = self.value_net(latent_vf) distribution = self._get_action_dist_from_latent(latent_pi) - actions = distribution.get_actions(deterministic=deterministic) - log_prob = distribution.log_prob(actions) - actions = actions.reshape((-1, *self.action_space.shape)) + if isinstance(distribution, List): + actions = [d.get_actions(deterministic=deterministic) for d in distribution] + log_prob = [d.log_prob(a) for d, a in zip(distribution, actions)] + else: + actions = distribution.get_actions(deterministic=deterministic) + log_prob = distribution.log_prob(actions) + actions = actions.reshape((-1, *self.action_space.shape)) return actions, values, log_prob def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]: @@ -650,23 +654,36 @@ def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution: :param latent_pi: Latent code for the actor :return: Action distribution """ - mean_actions = self.action_net(latent_pi) - - if isinstance(self.action_dist, DiagGaussianDistribution): - return self.action_dist.proba_distribution(mean_actions, self.log_std) - elif isinstance(self.action_dist, CategoricalDistribution): - # Here mean_actions are the logits before the softmax - return self.action_dist.proba_distribution(action_logits=mean_actions) - elif isinstance(self.action_dist, MultiCategoricalDistribution): - # Here mean_actions are the flattened logits - return self.action_dist.proba_distribution(action_logits=mean_actions) - elif isinstance(self.action_dist, BernoulliDistribution): - # Here mean_actions are the logits (before rounding to get the binary actions) - return self.action_dist.proba_distribution(action_logits=mean_actions) - elif isinstance(self.action_dist, StateDependentNoiseDistribution): - return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi) + if isinstance(latent_pi, List): # indicates we're doing one policy to rule them all + mean_actions_graphs = [self.action_net(a_latent_pi) for a_latent_pi in latent_pi] + if isinstance(self.action_dist, DiagGaussianDistribution): + dist = [ + self.action_dist.proba_distribution(mean_actions, self.log_std) for mean_actions in mean_actions_graphs + ] 
+ else: + raise NotImplementedError + return dist else: - raise ValueError("Invalid action distribution") + mean_actions = self.action_net(latent_pi) + + dist = None + if isinstance(self.action_dist, DiagGaussianDistribution): + dist = self.action_dist.proba_distribution(mean_actions, self.log_std) + elif isinstance(self.action_dist, CategoricalDistribution): + # Here mean_actions are the logits before the softmax + dist = self.action_dist.proba_distribution(action_logits=mean_actions) + elif isinstance(self.action_dist, MultiCategoricalDistribution): + # Here mean_actions are the flattened logits + dist = self.action_dist.proba_distribution(action_logits=mean_actions) + elif isinstance(self.action_dist, BernoulliDistribution): + # Here mean_actions are the logits (before rounding to get the binary actions) + dist = self.action_dist.proba_distribution(action_logits=mean_actions) + elif isinstance(self.action_dist, StateDependentNoiseDistribution): + dist = self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi) + else: + raise ValueError("Invalid action distribution") + + return dist def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: """ @@ -697,9 +714,17 @@ def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tenso latent_pi = self.mlp_extractor.forward_actor(pi_features) latent_vf = self.mlp_extractor.forward_critic(vf_features) distribution = self._get_action_dist_from_latent(latent_pi) - log_prob = distribution.log_prob(actions) + if isinstance(distribution, List): + log_prob = [d.log_prob(a) for d, a in zip(distribution, actions)] + log_prob = th.stack(log_prob) + entropy = [d.entropy() for d in distribution] + entropy = th.stack(entropy) + else: + log_prob = distribution.log_prob(actions) + + entropy = distribution.entropy() values = self.value_net(latent_vf) - entropy = distribution.entropy() + return values, log_prob, entropy def get_distribution(self, obs: th.Tensor) -> Distribution: diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index bc0959480..6031b6e07 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -134,7 +134,8 @@ def preprocess_obs( for key, _obs in obs.items(): preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images) return preprocessed_obs - + elif isinstance(observation_space, spaces.Graph): + return obs else: raise NotImplementedError(f"Preprocessing not implemented for {observation_space}") @@ -161,7 +162,8 @@ def get_obs_shape( return observation_space.shape elif isinstance(observation_space, spaces.Dict): return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} # type: ignore[misc] - + elif isinstance(observation_space, spaces.Graph): + return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} else: raise NotImplementedError(f"{observation_space} observation space is not supported") diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index d38d7cf73..2c08a4062 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -7,6 +7,7 @@ import gymnasium as gym import numpy as np import torch as th +from torch_geometric.data import Batch if sys.version_info >= (3, 8): from typing import Protocol @@ -39,6 +40,15 @@ class RolloutBufferSamples(NamedTuple): returns: th.Tensor +class 
GraphRolloutBufferSamples(NamedTuple): + observations: Batch + actions: th.Tensor + old_values: th.Tensor + old_log_prob: th.Tensor + advantages: th.Tensor + returns: th.Tensor + + class DictRolloutBufferSamples(NamedTuple): observations: TensorDict actions: th.Tensor diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 08366bda1..7e1c2c5f4 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -13,6 +13,7 @@ import numpy as np import torch as th from gymnasium import spaces +from torch_geometric.data import Data import stable_baselines3 as sb3 @@ -482,6 +483,8 @@ def obs_as_tensor(obs: Union[np.ndarray, Dict[str, np.ndarray]], device: th.devi """ if isinstance(obs, np.ndarray): return th.as_tensor(obs, device=device) + elif isinstance(obs, Data): + return obs.to(device) elif isinstance(obs, dict): return {key: th.as_tensor(_obs, device=device) for (key, _obs) in obs.items()} else: diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index 2c036373a..3f35fecae 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -3,7 +3,7 @@ from typing import Optional, Type, Union from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper -from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv +from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv, DummyGraphVecEnv from stable_baselines3.common.vec_env.stacked_observations import StackedObservations from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 822025f53..0242629b9 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -8,7 +8,7 @@ from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn from stable_baselines3.common.vec_env.patch_gym import _patch_env -from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info +from stable_baselines3.common.vec_env.util import graph_copy_obs_dict, copy_obs_dict, dict_to_obs, obs_space_info class DummyVecEnv(VecEnv): @@ -24,8 +24,6 @@ class DummyVecEnv(VecEnv): :raises ValueError: If the same environment instance is passed as the output of two or more different env_fn. 
""" - actions: np.ndarray - def __init__(self, env_fns: List[Callable[[], gym.Env]]): self.envs = [_patch_env(fn()) for fn in env_fns] if len(set([id(env.unwrapped) for env in self.envs])) != len(self.envs): @@ -40,13 +38,14 @@ def __init__(self, env_fns: List[Callable[[], gym.Env]]): ) env = self.envs[0] VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space, env.render_mode) - obs_space = env.observation_space - self.keys, shapes, dtypes = obs_space_info(obs_space) + self.obs_space = env.observation_space + self.keys, shapes, dtypes = obs_space_info(self.obs_space) self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs, *tuple(shapes[k])), dtype=dtypes[k])) for k in self.keys]) self.buf_dones = np.zeros((self.num_envs,), dtype=bool) self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32) self.buf_infos: List[Dict[str, Any]] = [{} for _ in range(self.num_envs)] + self.actions = None self.metadata = env.metadata def step_async(self, actions: np.ndarray) -> None: @@ -146,3 +145,20 @@ def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndice def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]: indices = self._get_indices(indices) return [self.envs[i] for i in indices] + + +class DummyGraphVecEnv(DummyVecEnv): + def __init__(self, env_fns: List[Callable[[], gym.Env]]): + super().__init__(env_fns) + assert isinstance(self.obs_space, gym.spaces.Graph) + self.buf_obs = OrderedDict([(k, {}) for k in self.keys]) + + def _obs_from_buf(self) -> VecEnvObs: + assert isinstance(self.observation_space, gym.spaces.Graph) + return dict_to_obs(self.observation_space, graph_copy_obs_dict(self.buf_obs)) + + def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None: + assert isinstance(self.envs[env_idx].observation_space, gym.spaces.Graph) + self.buf_obs["node"][env_idx] = obs.x + self.buf_obs["edge_weight"][env_idx] = obs.w + self.buf_obs["edge_index"][env_idx] = obs.edge_index diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py index 2a03d8e70..241e67a33 100644 --- a/stable_baselines3/common/vec_env/util.py +++ b/stable_baselines3/common/vec_env/util.py @@ -3,6 +3,7 @@ """ from collections import OrderedDict from typing import Any, Dict, List, Tuple +from torch_geometric.data import Data, Batch import numpy as np from gymnasium import spaces @@ -22,6 +23,17 @@ def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: return OrderedDict([(k, np.copy(v)) for k, v in obs.items()]) +def graph_copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """ + Deep-copy a dict of numpy arrays. + + :param obs: a dict of numpy arrays. + :return: a dict of copied numpy arrays. 
+ """ + assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'" + return OrderedDict([(k, {i: env.clone() for i, env in v.items()}) for k, v in obs.items()]) + + def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs: """ Convert an internal representation raw_obs into the appropriate type @@ -38,6 +50,12 @@ def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> Vec elif isinstance(obs_space, spaces.Tuple): assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space" return tuple(obs_dict[i] for i in range(len(obs_space.spaces))) + elif isinstance(obs_space, spaces.Graph): + indexes = obs_dict["node"].keys() + list_of_graphs = [ + Data(x=obs_dict["node"][i], edge_index=obs_dict["edge_index"][i], w=obs_dict["edge_weight"][i]) for i in indexes + ] + return Batch.from_data_list(list_of_graphs) else: assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space" return obs_dict[None] @@ -63,6 +81,8 @@ def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[ subspaces = obs_space.spaces elif isinstance(obs_space, spaces.Tuple): subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment] + elif isinstance(obs_space, spaces.Graph): + subspaces = obs_space.spaces else: assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'" subspaces = {None: obs_space} # type: ignore[assignment]
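
A minimal sketch of how the per-env graph observations kept by DummyGraphVecEnv and the spaces.Graph branch of dict_to_obs end up as a single torch_geometric Batch for the policy. This assumes torch_geometric is installed; the env count, node features, edges, and edge weights below are made up for illustration.

import torch as th
from torch_geometric.data import Batch, Data

# Per-env storage layout used by DummyGraphVecEnv.buf_obs and
# GraphRolloutBuffer.observations: one entry per env index under
# "node", "edge_index" and "edge_weight".
buf_obs = {"node": {}, "edge_index": {}, "edge_weight": {}}
for env_idx in range(2):  # two parallel envs, sizes are illustrative
    buf_obs["node"][env_idx] = th.randn(4, 3)  # 4 nodes, 3 features each
    buf_obs["edge_index"][env_idx] = th.tensor([[0, 1, 2], [1, 2, 3]], dtype=th.long)  # shape [2, num_edges]
    buf_obs["edge_weight"][env_idx] = th.rand(3)  # one weight per edge

# Same conversion as the spaces.Graph branch of dict_to_obs: one Data per env,
# then a single Batch the policy can process in one forward pass.
graphs = [
    Data(x=buf_obs["node"][i], edge_index=buf_obs["edge_index"][i], w=buf_obs["edge_weight"][i])
    for i in buf_obs["node"]
]
obs_batch = Batch.from_data_list(graphs)

print(obs_batch.num_graphs)  # 2
print(obs_batch.x.shape)     # torch.Size([8, 3]) -- nodes of both graphs stacked
print(obs_batch.batch)       # graph id of every node, used to split the batch again
recovered = obs_batch.to_data_list()  # back to per-env Data objects

GraphRolloutBuffer.add stores the same three pieces per step (obs.x, obs.edge_index, obs.w), and _get_samples rebuilds a Batch with Batch.from_data_list before returning a GraphRolloutBufferSamples to the training loop.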