diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py
index fb17ee1661..b59257513f 100644
--- a/nemo_rl/distributed/model_utils.py
+++ b/nemo_rl/distributed/model_utils.py
@@ -220,7 +220,9 @@ def backward(
         seq_size = int(vocab_parallel_logits.shape[1])
         num_chunks = (seq_size + chunk_size - 1) // chunk_size
 
-        all_grad_input = []
+        grad_input: torch.Tensor = torch.empty_like(
+            vocab_parallel_logits, dtype=torch.float32
+        )
 
         for chunk_idx in range(num_chunks):
             chunk_start = chunk_idx * chunk_size
@@ -243,13 +245,18 @@ def backward(
                 num_classes=partition_vocab_size,
             )
 
-            grad_input = is_chosen.float().sub_(softmax_output)
-
-            grad_input.mul_(grad_output[:, chunk_start:chunk_end].unsqueeze(dim=-1))
+            # In-place index into the preallocated grad_input tensor
+            grad_input_chunk = grad_input[:, chunk_start:chunk_end, :]
 
-            all_grad_input.append(grad_input)
+            grad_input_chunk.copy_(
+                is_chosen.float().sub_(softmax_output)
+            )  # in-place copy
+            grad_input_chunk.mul_(
+                grad_output[:, chunk_start:chunk_end].unsqueeze(dim=-1)
+            )
 
-        grad_input = torch.cat(all_grad_input, dim=1)
+            # Explicitly free before the next iteration allocates
+            del softmax_output, is_chosen, logits
 
         # if you add an argument to the forward method, then you must add a corresponding None here
         return grad_input, None, None, None, None, None, None
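
The change replaces the per-chunk list plus torch.cat (which transiently holds roughly two full copies of the gradient) with a single buffer preallocated up front; each chunk writes its result into a slice view of that buffer, and per-chunk temporaries are dropped eagerly. Below is a minimal standalone sketch of the same pattern, not the NeMo-RL implementation: the function name chunked_grad, the 3-D [batch, seq, vocab] shapes, and the direct use of softmax/one_hot are simplifying assumptions (the real backward operates on tensor-parallel vocab shards).

import torch
import torch.nn.functional as F


def chunked_grad(
    logits: torch.Tensor,       # [batch, seq, vocab]
    targets: torch.Tensor,      # [batch, seq], integer class ids
    grad_output: torch.Tensor,  # [batch, seq], upstream gradient per token
    chunk_size: int = 128,
) -> torch.Tensor:
    batch, seq, vocab = logits.shape
    # Preallocate the full gradient once instead of collecting chunks and cat-ing.
    grad_input = torch.empty_like(logits, dtype=torch.float32)

    for start in range(0, seq, chunk_size):
        end = min(start + chunk_size, seq)

        # Per-chunk temporaries: softmax over the vocab and a one-hot of the chosen token.
        softmax_output = torch.softmax(logits[:, start:end, :].float(), dim=-1)
        is_chosen = F.one_hot(targets[:, start:end], num_classes=vocab)

        # Slicing returns a view, so these in-place ops write straight into grad_input.
        chunk = grad_input[:, start:end, :]
        chunk.copy_(is_chosen.float().sub_(softmax_output))
        chunk.mul_(grad_output[:, start:end].unsqueeze(-1))

        # Free the temporaries before the next iteration allocates its own.
        del softmax_output, is_chosen

    return grad_input


# Tiny usage example with toy shapes.
logits = torch.randn(2, 10, 32)
targets = torch.randint(0, 32, (2, 10))
grad_output = torch.ones(2, 10)
g = chunked_grad(logits, targets, grad_output, chunk_size=4)
print(g.shape)  # torch.Size([2, 10, 32])

Peak memory drops because the only full-size allocation is grad_input itself; the cat-based version keeps every chunk alive until the concatenation completes, and the explicit del lets the allocator reuse the chunk-sized temporaries immediately.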