
[BUG] Shape mismatch in Transform with single-agent environments (Knights Archers Zombies) #3515

@CatherineRizk

Description

Describe the bug

I am using TorchRL to train agents in the Knights Archers Zombies (PettingZoo) environment. I started with a very minimalistic implementation using only archers; it works perfectly for any number of archers N > 1, but fails consistently with a single archer.

The issue stems from how data is reshaped within my replay buffer Transform class. I've traced the error to a specific section of the TorchRL source code responsible for reshaping TensorDicts: the current logic handles multi-agent batches correctly, but collapses or misinterprets dimensions when a single agent is present. Could you clarify the intended design for reshaping dimensions in this context? I have detailed the specific line of code and the resulting error below.

To Reproduce

Steps to reproduce the behavior.

from torchrl.modules import ProbabilisticActor, MLP
from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage, RandomSampler
from torchrl.collectors import SyncDataCollector
from torchrl.envs import PettingZooWrapper, Transform, TransformedEnv, DTypeCastTransform
from pettingzoo.butterfly import knights_archers_zombies_v10
from tensordict.nn import TensorDictModule
import torch
import math
    
def _create_env(**kwargs):

    env = knights_archers_zombies_v10.parallel_env(**{
        key: kwargs[key] for key in [
            "render_mode", "num_archers", "num_knights", "max_zombies",
            "spawn_rate", "max_arrows", "vector_state", "use_typemasks"
        ]
    })
    group_map = {}
    if kwargs.get("num_knights", 0) > 0:
        group_map["knights"] = []
    if kwargs.get("num_archers", 0) > 0:
        group_map["archers"] = [] 
    
    for agent in env.possible_agents:
        agent_type, _ = agent.split("_")
        if agent_type == "archer":
            group_map["archers"].append(agent)
        elif agent_type == "knight":
            group_map["knights"].append(agent)
        else:
            raise ValueError("Wrong agent type")

    env = PettingZooWrapper(
        env,
        return_state=kwargs.get("return_state"),
        group_map=group_map,
        use_mask=kwargs.get("use_mask"),
        categorical_actions=kwargs.get("categorical_actions"),
        done_on_any=kwargs.get("done_on_any"),
    )
    return TransformedEnv(env, DTypeCastTransform(torch.float64, torch.float32))

class ActorModel(torch.nn.Module):
    def __init__(self, input_shape, output_dim, device, hidden_dim=256, depth=3):
        super().__init__()
        self.flatten = torch.nn.Flatten(-2)
        self.mlp = MLP(
            in_features=input_shape,
            out_features=output_dim,
            num_cells=hidden_dim,
            depth=depth,
            device=device
        )

    def forward(self, x):
        x = self.flatten(x)
        return self.mlp(x)

class TransformRB(Transform):
    def __init__(self, agent_group_name):
        super().__init__()
        self.agent_name = agent_group_name
    
    def _inv_call(self, tensordict):
        tensordict = super()._inv_call(tensordict)

        shaping = tensordict[self.agent_name].shape
        nb_env = 1
        traj_len = shaping[-2]
        nb_agent = shaping[-1]

        td_agents, td_rest = tensordict.split_keys([self.agent_name, ("next", self.agent_name)])
        td_new = td_rest.unsqueeze(-1).expand(nb_env, traj_len, nb_agent)
        td_agents = td_agents.unsqueeze(0)
        td_new.update(td_agents)
        return td_new
    
if __name__ == "__main__":

    ENV_SPAWN_RATE = 10
    ENV_NUM_ARCHERS = 1
    ENV_NUM_KNIGHTS = 0
    ENV_MAX_ZOMBIES = 10
    ENV_MAX_ARROWS = 9

    AC_HIDDEN_DIM = 256
    AC_DEPTH = 2

    RB_CAPACITY = int(2e6)
    RB_MINIBATCH_SIZE = 64
    C_FRAMES_PER_BATCH = 2048
    
    default_env_args = {
        "render_mode": None,
        "vector_state": True, 
        "use_typemasks": True, 
        "return_state": True, 
        "use_mask": True, 
        "categorical_actions": True, 
        "done_on_any": False,
    }

    env = _create_env(
        num_archers=ENV_NUM_ARCHERS,
        num_knights=ENV_NUM_KNIGHTS,
        max_zombies=ENV_MAX_ZOMBIES,
        spawn_rate=ENV_SPAWN_RATE,
        max_arrows=ENV_MAX_ARROWS,
        **default_env_args,
    )

    group = "archers"

    actor = ProbabilisticActor(
        TensorDictModule(ActorModel(
            input_shape=math.prod(env.observation_spec[group]["observation"].shape[-2:]),
            output_dim=env.action_spec[group]["action"].space.n,
            device="cuda",
            hidden_dim=AC_HIDDEN_DIM,
            depth=AC_DEPTH
            ), 
        in_keys=[(group, "observation")],
        out_keys=[(group, "logits")]
        ),
        in_keys=[(group, "logits")],
        out_keys=[(group, "action")],
        distribution_class=torch.distributions.Categorical
    )

    replay_buffer = TensorDictReplayBuffer(
        storage = LazyTensorStorage(RB_CAPACITY, ndim=3),
        sampler = RandomSampler(),
        batch_size = RB_MINIBATCH_SIZE,
        priority_key = "td_error",
        pin_memory = True,
        prefetch = 2,
        transform = TransformRB(group)
    )
    replay_buffer.append_transform(lambda x: x.cuda(non_blocking=True))

    collector = SyncDataCollector(
        create_env_fn = env,
        policy = actor,
        frames_per_batch = C_FRAMES_PER_BATCH,
        total_frames = -1,
        device = "cpu",
        storing_device = "cpu",
        env_device = "cpu",
        policy_device = "cuda"
    )

    for i, collected in enumerate(collector): 
        replay_buffer.extend(collected)
Traceback (most recent call last):
  File "debug.py", line 148, in <module>
    replay_buffer.extend(collected)
  File ".pixi/envs/petzoo/lib/python3.12/site-packages/torchrl/data/replay_buffers/replay_buffers.py", line 87, in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".pixi/envs/petzoo/lib/python3.12/site-packages/torchrl/data/replay_buffers/replay_buffers.py", line 1640, in extend
    self._set_index_in_td(tensordicts, index)
  File ".pixi/envs/petzoo/lib/python3.12/site-packages/torchrl/data/replay_buffers/replay_buffers.py", line 1669, in _set_index_in_td
    tensordict.set("index", index)
  File ".pixi/envs/petzoo/lib/python3.12/site-packages/tensordict/base.py", line 7464, in set
    return self._set_tuple(
           ^^^^^^^^^^^^^^^^
  File ".pixi/envs/petzoo/lib/python3.12/site-packages/tensordict/_td.py", line 2537, in _set_tuple
    return self._set_str(
           ^^^^^^^^^^^^^^
  File ".pixi/envs/petzoo/lib/python3.12/site-packages/tensordict/_td.py", line 2480, in _set_str
    value = self._validate_value(
            ^^^^^^^^^^^^^^^^^^^^^
  File ".pixi/envs/petzoo/lib/python3.12/site-packages/tensordict/base.py", line 13022, in _validate_value_generic
    raise RuntimeError(
RuntimeError: batch dimension mismatch, got self.batch_size=torch.Size([1, 2048, 1]) and value.shape=torch.Size([1, 2048, 3]).

Reason and Possible fixes

The issue occurs in the replay buffer's .extend method, when _set_index_in_td() decides how to reshape the index tensor to match the tensordict batch size. The logic compares the .numel() of the index shape against the .numel() of the leading tensordict batch dimensions. This works for N > 1 agents, but for N = 1 the check returns True prematurely, leaving a dimension missing from the final index shape.

Traceback Analysis

The TransformRB class was implemented to reshape the tensordict to (number_envs, trajectory_length, number_agents).
When N = 1, the loop hits a "false positive" match at dim = 2 because the product of the leading dimensions aligns too early.
For a batch of 2048 frames with one env, here is what happens in the `if index.shape[:1].numel() == tensordict.shape[:dim].numel():` check:

Case N=3 (correct): index.shape[:1].numel() = 6144
At dim=2, tensordict.shape[:2].numel() = 1 * 2048 = 2048 (False)
At dim=3, tensordict.shape[:3].numel() = 1 * 2048 * 3 = 6144 (True) -> the index is reshaped to (1, 2048, 3, 3)

Case N=1 (bug): index.shape[:1].numel() = 2048
At dim=2, tensordict.shape[:2].numel() = 1 * 2048 = 2048 (True: incorrect early exit) -> the index is reshaped to (1, 2048, 3) instead of (1, 2048, 1, 3), and _validate_value_generic() fails because the agent dimension is missing.
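The early exit can be reproduced in isolation with plain torch.Size arithmetic. The following is a standalone sketch of the numel comparison described above (find_match_dim is a hypothetical helper, not TorchRL's actual code):

```python
import torch

def find_match_dim(index_numel, batch_size):
    """Return the first dim where the product of leading batch dims equals index_numel."""
    for dim in range(1, len(batch_size) + 1):
        if batch_size[:dim].numel() == index_numel:
            return dim
    return None

# Case N = 3: the flat index holds 1 * 2048 * 3 = 6144 entries, and only the
# full prefix (1, 2048, 3) multiplies out to 6144 -> correct match at dim 3.
print(find_match_dim(6144, torch.Size([1, 2048, 3])))  # 3

# Case N = 1: the flat index holds 1 * 2048 * 1 = 2048 entries, but the
# product of the first two dims is already 2048, so the scan stops one dim
# too early and the trailing agent dimension is never accounted for.
print(find_match_dim(2048, torch.Size([1, 2048, 1])))  # 2
```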

Is the reliance on .numel() intended to support flexible batch nesting? If so, how should we handle cases where the agent dimension is 1, causing the product of leading dimensions to overlap with the index size?
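One possible direction, sketched here purely as a hypothesis (reshape_index is a made-up helper, not a patch against TorchRL's internals): if the flat index is required to cover the whole batch, the reshape can be done in a single step against the full batch_size rather than by scanning prefixes with .numel(), which removes the ambiguity when an inner dimension is 1.

```python
import torch

def reshape_index(index, batch_size):
    """Reshape a flat (numel, ndim) index to (*batch_size, ndim) in one step."""
    # Instead of matching prefix products dim by dim, require the index to
    # cover the whole batch, so a size-1 agent dim cannot cause an early exit.
    if index.shape[0] == batch_size.numel():
        return index.reshape(*batch_size, index.shape[-1])
    raise ValueError("index does not cover the full batch")

# N = 1 case from the report: 2048 flat entries, storage ndim = 3.
idx = torch.arange(2048 * 3).reshape(2048, 3)
out = reshape_index(idx, torch.Size([1, 2048, 1]))
print(out.shape)  # torch.Size([1, 2048, 1, 3])
```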

Thank you for your help.
