-
Notifications
You must be signed in to change notification settings - Fork 438
Description
Describe the bug
I am using TorchRL to train agents in the Knights Archers Zombies (PettingZoo) environment. I've started with a very minimalistic implementation using only archers, it works perfectly for any number of archers (N) > 1, but fails consistently with a single archer.
The issue stems from how data is reshaped within my replay buffer Transform class. I’ve traced the error to a specific section of the TorchRL source code responsible for reshaping TensorDicts. The current logic seems to handle multi-agent batches correctly but collapses or misinterprets dimensions when a single agent is present. Could you clarify the intended design for reshaping dimensions in this context? I have detailed the specific line of code and the resulting error below.
To Reproduce
Steps to reproduce the behavior.
from torchrl.modules import ProbabilisticActor, MLP
from core.transform_data import TransformRB
from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage, RandomSampler
from torchrl.collectors import SyncDataCollector
from torchrl.envs import PettingZooWrapper, Transform, TransformedEnv, DTypeCastTransform
from pettingzoo.butterfly import knights_archers_zombies_v10
from tensordict.nn import TensorDictModule
import torch
import math
def _create_env(**kwargs):
    """Build a Knights-Archers-Zombies parallel env wrapped for TorchRL.

    Creates the PettingZoo parallel env, groups agents by type into
    ``group_map`` ("archers"/"knights"), wraps it in ``PettingZooWrapper``,
    and casts float64 outputs to float32 via ``DTypeCastTransform``.

    Args:
        **kwargs: PettingZoo env options (render_mode, num_archers,
            num_knights, max_zombies, spawn_rate, max_arrows, vector_state,
            use_typemasks) plus wrapper options (return_state, use_mask,
            categorical_actions, done_on_any).

    Returns:
        A ``TransformedEnv`` wrapping the grouped multi-agent env.

    Raises:
        ValueError: if an agent name has an unrecognized type prefix.
    """
    env = knights_archers_zombies_v10.parallel_env(**{
        key: kwargs[key] for key in [
            "render_mode", "num_archers", "num_knights", "max_zombies",
            "spawn_rate", "max_arrows", "vector_state", "use_typemasks"
        ]
    })
    group_map = {}
    if kwargs.get("num_knights", 0) > 0:
        group_map["knights"] = []
    if kwargs.get("num_archers", 0) > 0:
        group_map["archers"] = []
    for agent in env.possible_agents:
        # Agent names look like "archer_0" / "knight_1"; avoid shadowing the
        # builtins `type` and `id` with the unpacked parts.
        agent_type, _agent_id = agent.split("_")
        if agent_type == "archer":
            group_map["archers"].append(agent)
        elif agent_type == "knight":
            group_map["knights"].append(agent)
        else:
            # BUG FIX: the original built ValueError(...) without `raise`,
            # silently ignoring unknown agent types.
            raise ValueError("Wrong agent type")
    env = PettingZooWrapper(
        env,
        return_state=kwargs.get("return_state"),
        group_map=group_map,
        use_mask=kwargs.get("use_mask"),
        categorical_actions=kwargs.get("categorical_actions"),
        done_on_any=kwargs.get("done_on_any"),
    )
    return TransformedEnv(env, DTypeCastTransform(torch.float64, torch.float32))
class ActorModel(torch.nn.Module):
    """Policy network: collapse the trailing two input dims, then apply an MLP.

    The flatten of the last two dimensions merges the per-agent feature
    matrix into a single feature vector before the MLP head.
    """

    def __init__(self, input_shape, output_dim, device, hidden_dim=256, depth=3):
        super().__init__()
        # Flatten(-2) merges dims -2 and -1 into one trailing feature axis.
        self.flatten = torch.nn.Flatten(-2)
        self.mlp = MLP(
            in_features=input_shape,
            out_features=output_dim,
            num_cells=hidden_dim,
            depth=depth,
            device=device,
        )

    def forward(self, x):
        # Chain flatten -> MLP; output shape is (..., output_dim).
        return self.mlp(self.flatten(x))
class TransformRB(Transform):
    """Replay-buffer inverse transform that reshapes collected data into a
    (n_envs, traj_len, n_agents) batch before storage.

    NOTE(review): this is the transform under discussion in the issue; with a
    single agent the resulting (1, traj_len, 1) batch collides with TorchRL's
    numel()-based index reshaping in ``ReplayBuffer.extend`` — see the
    traceback analysis below.
    """

    def __init__(self, agent_group_name):
        super().__init__()
        # Key of the agent group (e.g. "archers") whose sub-tensordict
        # carries the per-agent dimension.
        self.agent_name = agent_group_name

    def _inv_call(self, tensordict):
        tensordict = super()._inv_call(tensordict)
        # assumes tensordict[agent_name] has batch shape (..., traj_len, n_agents)
        # — TODO confirm against the collector's output layout.
        shaping = tensordict[self.agent_name].shape
        nb_env = 1  # single-environment collector assumed (see SyncDataCollector below)
        traj_len = shaping[-2]
        nb_agent = shaping[-1]
        # Split the per-agent entries (current and "next") from the shared,
        # env-level entries so each side can be reshaped independently.
        td_agents, td_rest = tensordict.split_keys([self.agent_name, ("next", self.agent_name)])
        # Broadcast the shared entries over a trailing agent dimension to reach
        # (nb_env, traj_len, nb_agent); the agent entries only gain the leading
        # env dimension.
        td_new = td_rest.unsqueeze(-1).expand(nb_env, traj_len, nb_agent)
        td_agents = td_agents.unsqueeze(0)
        td_new.update(td_agents)
        return td_new
if __name__ == "__main__":
    # --- Environment hyperparameters -------------------------------------
    ENV_SPAWN_RATE = 10
    ENV_NUM_ARCHERS = 1   # N = 1 is the failing configuration described in the issue
    ENV_NUM_KNIGHTS = 0
    ENV_MAX_ZOMBIES = 10
    ENV_MAX_ARROWS = 9
    # --- Actor network hyperparameters -----------------------------------
    AC_HIDDEN_DIM = 256
    AC_DEPTH = 2
    # --- Replay buffer / collector hyperparameters -----------------------
    RB_CAPACITY = 2e6
    RB_MINIBATCH_SIZE = 64
    C_FRAMES_PER_BATCH = 2048
    default_env_args = {
        "render_mode": None,
        "vector_state": True,
        "use_typemasks": True,
        "return_state": True,
        "use_mask": True,
        "categorical_actions": True,
        "done_on_any": False,
    }
    env = _create_env(num_archers = ENV_NUM_ARCHERS, num_knights = ENV_NUM_KNIGHTS, max_zombies = ENV_MAX_ZOMBIES, spawn_rate = ENV_SPAWN_RATE, max_arrows = ENV_MAX_ARROWS, **default_env_args)
    group = "archers"
    # Categorical policy: an inner module maps observations to logits, and the
    # ProbabilisticActor samples a discrete action from those logits.
    actor = ProbabilisticActor(
        TensorDictModule(ActorModel(
            # Flatten the trailing (agents, features) dims of the observation spec.
            input_shape=math.prod(env.observation_spec[group]["observation"].shape[-2:]),
            output_dim=env.action_spec[group]["action"].space.n,
            device="cuda",
            hidden_dim=AC_HIDDEN_DIM,
            depth=AC_DEPTH
        ),
        in_keys=[(group, "observation")],
        out_keys=[(group, "logits")]
        ),
        in_keys=[(group, "logits")],
        out_keys=[(group, "action")],
        distribution_class=torch.distributions.Categorical
    )
    # ndim=3 matches the (n_envs, traj_len, n_agents) layout produced by
    # TransformRB._inv_call on the way into storage.
    replay_buffer = TensorDictReplayBuffer(
        storage = LazyTensorStorage(RB_CAPACITY, ndim=3),
        sampler = RandomSampler(),
        batch_size = RB_MINIBATCH_SIZE,
        priority_key = "td_error",
        pin_memory = True,
        prefetch = 2,
        transform = TransformRB(group)
    )
    # Move sampled batches to GPU after sampling.
    replay_buffer.append_transform(lambda x: x.cuda(non_blocking=True))
    # Env stepping on CPU, policy inference on GPU; total_frames=-1 collects
    # indefinitely.
    collector = SyncDataCollector(
        create_env_fn = env,
        policy = actor,
        frames_per_batch = C_FRAMES_PER_BATCH,
        total_frames = -1,
        device = "cpu",
        storing_device = "cpu",
        env_device = "cpu",
        policy_device = "cuda"
    )
    # The extend() call below is where the reported RuntimeError is raised.
    for i, collected in enumerate(collector):
        replay_buffer.extend(collected)

Traceback (most recent call last):
File "debug.py", line 148, in <module>
replay_buffer.extend(collected)
File ".pixi/envs/petzoo/lib/python3.12/site-packages/torchrl/data/replay_buffers/replay_buffers.py", line 87, in wrapper
return func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".pixi/envs/petzoo/lib/python3.12/site-packages/torchrl/data/replay_buffers/replay_buffers.py", line 1640, in extend
self._set_index_in_td(tensordicts, index)
File ".pixi/envs/petzoo/lib/python3.12/site-packages/torchrl/data/replay_buffers/replay_buffers.py", line 1669, in _set_index_in_td
tensordict.set("index", index)
File ".pixi/envs/petzoo/lib/python3.12/site-packages/tensordict/base.py", line 7464, in set
return self._set_tuple(
^^^^^^^^^^^^^^^^
File ".pixi/envs/petzoo/lib/python3.12/site-packages/tensordict/_td.py", line 2537, in _set_tuple
return self._set_str(
^^^^^^^^^^^^^^
File ".pixi/envs/petzoo/lib/python3.12/site-packages/tensordict/_td.py", line 2480, in _set_str
value = self._validate_value(
^^^^^^^^^^^^^^^^^^^^^
File ".pixi/envs/petzoo/lib/python3.12/site-packages/tensordict/base.py", line 13022, in _validate_value_generic
raise RuntimeError(
RuntimeError: batch dimension mismatch, got self.batch_size=torch.Size([1, 2048, 1]) and value.shape=torch.Size([1, 2048, 3]).

Reason and Possible fixes
The issue occurs in the .extend method of the replay_buffer when _set_index_in_td() is called to determine how to reshape the index tensor to match the tensordict batch size. The logic relies on comparing the .numel() of the index shape with the .numel() of the tensordict batch dimensions. While this works for N > 1 agents, it fails when N = 1 because the check returns True prematurely, leading to a missing dimension in the final index shape.
Traceback Analysis
The TransformRB class was implemented to reshape the tensordict in the shape : (number_envs, number_batches, number_agents).
When N = 1, the loop hits a "false positive" match at dim = 2 because the product of the dimensions aligns too early:
For a batch of 2048 with one env, here is what happens in the if index.shape[:1].numel() == tensordict.shape[:dim].numel(): loop:
Case N=3 (Correct): index.shape[:1].numel() = 6144
At dim=2, tensordict.shape[:2].numel() = 1 * 2048 = 2048 (False)
At dim=3, tensordict.shape[:3].numel() = 2048 * 3 = 6144 (True) -> reshapes index to size (1, 2048, 3, 3)
Case N=1 (Bug): index.shape[:1].numel() = 2048
At dim=2, tensordict.shape[:2].numel() = 1 * 2048 = 2048 (True — incorrect early exit) -> The index is reshaped to (1, 2048, 3) instead of (1, 2048, 1, 3), causing a failure in _validate_value_generic() because the shape lacks the agent dimension.
Is the reliance on .numel() intended to support flexible batch nesting? If so, how should we handle cases where the agent dimension is 1, causing the product of leading dimensions to overlap with the index size?
Thank you for your help.