Skip to content

Commit b0dd669

Browse files
committed
batch multinomial
1 parent 40827e0 commit b0dd669

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

pufferlib/cleanrl.py

+9 -1
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,15 @@ def sample_logits(logits: Union[torch.Tensor, List[torch.Tensor]],
41 41

42 42

43 43
if action is None:
44-
action = torch.stack([torch.multinomial(logits_to_probs(l), 1).squeeze() for l in logits])
44+
probs = logits_to_probs(
45+
torch.nn.utils.rnn.pad_sequence(
46+
[l.transpose(0,1) for l in logits],
47+
batch_first=False,
48+
padding_value=-torch.inf
49+
).permute(1,2,0)
50+
)
51+
action = torch.multinomial(probs.reshape(-1, probs.shape[-1]), 1)
52+
action = action.reshape(probs.shape[:-1])
4553
else:
4654
batch = logits[0].shape[0]
4755
action = action.view(batch, -1).T

0 commit comments

Comments
 (0)