From 1f6772fe277ad2e675e75beb5814077a29926d98 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Wed, 18 Feb 2026 13:01:26 +0000
Subject: [PATCH 1/2] [BugFix] Fix shape mismatch in _set_index_in_td with trailing dims of 1

Reverse the loop direction in _set_index_in_td to iterate from the
highest dim downward, so that dimensions of size 1 don't cause
premature numel() matches when reshaping the index tensor.

Fixes #3515

Co-authored-by: Cursor
---
 test/test_libs.py                             | 29 +++++++++++++++++++
 torchrl/data/replay_buffers/replay_buffers.py |  2 +-
 2 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/test/test_libs.py b/test/test_libs.py
index 26c653dbb2b..d4a66357147 100644
--- a/test/test_libs.py
+++ b/test/test_libs.py
@@ -64,6 +64,7 @@
     OneHot,
     ReplayBuffer,
     ReplayBufferEnsemble,
+    TensorDictReplayBuffer,
     Unbounded,
     UnboundedDiscrete,
 )
@@ -5128,6 +5129,34 @@ def test_collector(self, task, parallel):
         for _ in collector:
             break
 
+    def test_single_agent_group_replay_buffer(self):
+        """Regression test for gh#3515 - shape mismatch with single-agent group."""
+        env = PettingZooEnv(
+            task="simple_v3",
+            parallel=True,
+            seed=0,
+            use_mask=False,
+        )
+        group = list(env.group_map.keys())[0]
+        assert len(env.group_map[group]) == 1
+
+        rollout = env.rollout(10)
+        T = rollout.shape[0]
+        n_agents = 1
+
+        # Reshape to (1, T, n_agents) to reproduce the scenario from gh#3515
+        # where a replay buffer Transform reshapes collector output to
+        # (n_envs, traj_len, n_agents). When n_agents=1 the trailing dim of 1
+        # caused _set_index_in_td to match the wrong number of batch dims.
+        td = rollout.unsqueeze(0).unsqueeze(-1)
+        assert td.shape == torch.Size([1, T, n_agents])
+
+        rb = TensorDictReplayBuffer(
+            storage=LazyTensorStorage(10_000, ndim=3),
+            batch_size=4,
+        )
+        rb.extend(td)
+
 
 @pytest.mark.skipif(not _has_robohive, reason="RoboHive not found")
 class TestRoboHive:
diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index 8d2e15db305..d6b143f3df9 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -1657,7 +1657,7 @@ def _set_index_in_td(self, tensordict, index):
         if _is_int(index):
             index = torch.as_tensor(index, device=tensordict.device)
         elif index.ndim == 2 and index.shape[:1] != tensordict.shape[:1]:
-            for dim in range(2, tensordict.ndim + 1):
+            for dim in range(tensordict.ndim, 1, -1):
                 if index.shape[:1].numel() == tensordict.shape[:dim].numel():
                     # if index has 2 dims and is in a non-zero format
                     index = index.unflatten(0, tensordict.shape[:dim])
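[Reviewer note, not part of the patch] A minimal standalone sketch of the dim-matching issue fixed in PATCH 1/2, assuming a hypothetical ndim=3 storage batch shape of (1, 10, 1) (n_envs, traj_len, n_agents=1). The two loops mirror the pre- and post-patch code in _set_index_in_td, but run on a plain torch.Size rather than a TensorDict:

    import torch

    # Hypothetical shapes mirroring gh#3515: trailing batch dim of size 1.
    batch_shape = torch.Size([1, 10, 1])    # (n_envs, traj_len, n_agents=1)
    index = torch.arange(10).unsqueeze(-1)  # 2D index covering 10 elements

    # Old behaviour: scan prefixes upward and stop at the first numel() match.
    for dim in range(2, len(batch_shape) + 1):
        if index.shape[:1].numel() == batch_shape[:dim].numel():
            break
    assert dim == 2  # (1, 10).numel() == 10 already matches; the agent dim is dropped

    # Patched behaviour: scan downward so the longest matching prefix wins.
    for dim in range(len(batch_shape), 1, -1):
        if index.shape[:1].numel() == batch_shape[:dim].numel():
            break
    assert dim == 3  # index.unflatten(0, (1, 10, 1)) now lines up with the storage

Both scans look for a prefix whose numel() equals the number of indexed elements; only the reversed scan prefers the full (1, 10, 1) prefix when the trailing singleton makes the shorter prefix numerically identical.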
From 4234329c0d279ddf66bfcd05f64d3a8ad0e9023e Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Wed, 18 Feb 2026 14:15:23 +0000
Subject: [PATCH 2/2] [BugFix] Fix expand_as shape mismatch in StepCounter MARL propagation

In _propagate_to_nested_keys, the parent (root-level) truncated/done
tensor has fewer dimensions than the nested (agent-level) tensor.
expand_as fails because it aligns from the right, mismatching batch
dims with agent dims. Fix by unsqueezing extra dims before expanding.

Co-authored-by: Cursor
---
 torchrl/envs/transforms/transforms.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 8a2c191a598..20bc2f5a471 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -7920,8 +7920,12 @@ def _propagate_to_nested_keys(self, next_tensordict: TensorDictBase) -> None:
             for parent_key in self.truncated_keys:
                 parent_truncated = next_tensordict.get(parent_key, None)
                 if parent_truncated is not None:
-                    # Expand parent truncated to match nested shape and apply OR
-                    expanded = parent_truncated.expand_as(nested_truncated)
+                    # Insert extra dims (e.g. agent dims) so the parent is
+                    # broadcastable to the nested agent-level shape.
+                    parent_val = parent_truncated
+                    while parent_val.ndim < nested_truncated.ndim:
+                        parent_val = parent_val.unsqueeze(-2)
+                    expanded = parent_val.expand_as(nested_truncated)
                     next_tensordict.set(nested_key, nested_truncated | expanded)
                     break
 
@@ -7937,8 +7941,10 @@ def _propagate_to_nested_keys(self, next_tensordict: TensorDictBase) -> None:
             for parent_key in self.done_keys:
                 parent_done = next_tensordict.get(parent_key, None)
                 if parent_done is not None:
-                    # Expand parent done to match nested shape and apply OR
-                    expanded = parent_done.expand_as(nested_done)
+                    parent_val = parent_done
+                    while parent_val.ndim < nested_done.ndim:
+                        parent_val = parent_val.unsqueeze(-2)
+                    expanded = parent_val.expand_as(nested_done)
                     next_tensordict.set(nested_key, nested_done | expanded)
                     break
 
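[Reviewer note, not part of the patch] A minimal standalone sketch of the expand_as failure fixed in PATCH 2/2, assuming hypothetical shapes of (4, 1) for the root-level flag and (4, 3, 1) for the agent-level flag; the while loop mirrors the patched code:

    import torch

    n_envs, n_agents = 4, 3
    parent_done = torch.zeros(n_envs, 1, dtype=torch.bool)            # root-level done
    nested_done = torch.zeros(n_envs, n_agents, 1, dtype=torch.bool)  # agent-level done

    # Direct expand_as aligns from the right, pairing n_envs (4) with n_agents (3).
    try:
        parent_done.expand_as(nested_done)
    except RuntimeError as err:
        print("expand_as failed:", err)

    # Patched behaviour: insert singleton dims before the trailing flag dim
    # until the ranks match, then expand and OR into the nested tensor.
    parent_val = parent_done
    while parent_val.ndim < nested_done.ndim:
        parent_val = parent_val.unsqueeze(-2)       # (4, 1) -> (4, 1, 1)
    expanded = parent_val.expand_as(nested_done)    # (4, 3, 1)
    combined = nested_done | expanded
    assert combined.shape == nested_done.shape

With the unsqueeze in place the root-level flag broadcasts over the agent dim instead of colliding with it, which is what the patched _propagate_to_nested_keys does for the nested truncated and done keys.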