Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
OneHot,
ReplayBuffer,
ReplayBufferEnsemble,
TensorDictReplayBuffer,
Unbounded,
UnboundedDiscrete,
)
Expand Down Expand Up @@ -5128,6 +5129,34 @@ def test_collector(self, task, parallel):
for _ in collector:
break

def test_single_agent_group_replay_buffer(self):
    """Regression test for gh#3515 - shape mismatch with single-agent group."""
    environment = PettingZooEnv(
        task="simple_v3",
        parallel=True,
        seed=0,
        use_mask=False,
    )
    # The scenario requires a group containing exactly one agent.
    first_group = next(iter(environment.group_map))
    assert len(environment.group_map[first_group]) == 1

    traj = environment.rollout(10)
    traj_len = traj.shape[0]
    single_agent = 1

    # Reproduce the gh#3515 layout: a replay-buffer Transform reshapes
    # collector output to (n_envs, traj_len, n_agents). With n_agents == 1
    # the trailing singleton dim made _set_index_in_td match the wrong
    # number of batch dims.
    reshaped = traj.unsqueeze(0).unsqueeze(-1)
    assert reshaped.shape == torch.Size([1, traj_len, single_agent])

    buffer = TensorDictReplayBuffer(
        storage=LazyTensorStorage(10_000, ndim=3),
        batch_size=4,
    )
    # Must not raise: extending with a trailing dim of 1 previously failed.
    buffer.extend(reshaped)


@pytest.mark.skipif(not _has_robohive, reason="RoboHive not found")
class TestRoboHive:
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1657,7 +1657,7 @@ def _set_index_in_td(self, tensordict, index):
if _is_int(index):
index = torch.as_tensor(index, device=tensordict.device)
elif index.ndim == 2 and index.shape[:1] != tensordict.shape[:1]:
for dim in range(2, tensordict.ndim + 1):
for dim in range(tensordict.ndim, 1, -1):
if index.shape[:1].numel() == tensordict.shape[:dim].numel():
# if index has 2 dims and is in a non-zero format
index = index.unflatten(0, tensordict.shape[:dim])
Expand Down
14 changes: 10 additions & 4 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7920,8 +7920,12 @@ def _propagate_to_nested_keys(self, next_tensordict: TensorDictBase) -> None:
for parent_key in self.truncated_keys:
parent_truncated = next_tensordict.get(parent_key, None)
if parent_truncated is not None:
# Expand parent truncated to match nested shape and apply OR
expanded = parent_truncated.expand_as(nested_truncated)
# Insert extra dims (e.g. agent dims) so the parent is
# broadcastable to the nested agent-level shape.
parent_val = parent_truncated
while parent_val.ndim < nested_truncated.ndim:
parent_val = parent_val.unsqueeze(-2)
expanded = parent_val.expand_as(nested_truncated)
next_tensordict.set(nested_key, nested_truncated | expanded)
break

Expand All @@ -7937,8 +7941,10 @@ def _propagate_to_nested_keys(self, next_tensordict: TensorDictBase) -> None:
for parent_key in self.done_keys:
parent_done = next_tensordict.get(parent_key, None)
if parent_done is not None:
# Expand parent done to match nested shape and apply OR
expanded = parent_done.expand_as(nested_done)
parent_val = parent_done
while parent_val.ndim < nested_done.ndim:
parent_val = parent_val.unsqueeze(-2)
expanded = parent_val.expand_as(nested_done)
next_tensordict.set(nested_key, nested_done | expanded)
break

Expand Down
Loading