MulticlassAccuracy Calculation Results in Dimension Mismatch Error #2946
-
When I try to compute the multiclass accuracy for a preds/target input with a top_k value greater than 1, I get an error about mismatched dimensions.
To Reproduce
I get the error message:
Code sample
!pip install torch torchmetrics --quiet --upgrade
import torch
from torchmetrics.classification import MulticlassAccuracy
batch_size = 1
num_classes = 4
sequence_length = 7
top_k = 3
preds = torch.randn(batch_size, num_classes, sequence_length)
targets = torch.randint(num_classes, (batch_size, sequence_length))
accuracy = MulticlassAccuracy(num_classes=num_classes, top_k=top_k, multidim_average='global')
acc = accuracy(preds, targets)
Expected Behavior
I expect the output to be a scalar tensor, as stated in the documentation.
-
The dimension mismatch error occurs because MulticlassAccuracy expects the predictions (preds) and targets to have matching sample dimensions, but your input has an extra sequence-length dimension that, with top_k > 1, causes a shape misalignment internally.
Your preds tensor has shape (batch_size, num_classes, sequence_length) and your targets tensor has shape (batch_size, sequence_length). Flattening the sample dimensions (batch plus any extra dimensions) into a single dimension, while keeping classes as a separate dimension, gives the metric plain (N, C) predictions and (N,) targets.
The fix is to flatten the batch and sequence-length dimensions together for both preds and targets before passing them to MulticlassAccuracy. Here is the corrected code:
import torch
from torchmetrics.classification import MulticlassAccuracy
batch_size = 1
num_classes = 4
sequence_length = 7
top_k = 3
# Original tensors
preds = torch.randn(batch_size, num_classes, sequence_length)
targets = torch.randint(num_classes, (batch_size, sequence_length))
# Flatten batch and sequence dims: preds (B, C, L) -> (B*L, C), targets (B, L) -> (B*L,)
preds_flat = preds.permute(0, 2, 1).reshape(-1, num_classes)  # shape: (batch_size * sequence_length, num_classes)
targets_flat = targets.view(-1)  # shape: (batch_size * sequence_length,)
# Compute top-k accuracy
accuracy = MulticlassAccuracy(num_classes=num_classes, top_k=top_k)
acc = accuracy(preds_flat, targets_flat)
print("Top-k Accuracy:", acc)
This reshaping aligns preds and targets so that the metric computes accuracy over all batch_size * sequence_length samples, avoiding the dimension mismatch error.
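For intuition about what the reshape does, here is a dependency-free sketch in plain Python (no torch or torchmetrics) that mirrors preds.permute(0, 2, 1).reshape(-1, num_classes) and targets.view(-1) on nested lists, then computes top-k accuracy as the fraction of positions whose true class is among the k highest scores. The helper names flatten_preds, flatten_targets and topk_accuracy are hypothetical, and the toy scores are made up for illustration.

```python
# Plain-Python sketch (no torch/torchmetrics) of the flattening used in the fix:
#   preds.permute(0, 2, 1).reshape(-1, num_classes)  and  targets.view(-1)
# followed by a micro-averaged top-k accuracy on the flattened data.
# Helper names are hypothetical, chosen for illustration only.


def flatten_preds(preds):
    """Map nested (B, C, L) scores to (B*L, C) rows.

    Row b * L + l holds the C class scores for position l of sample b,
    the same row order that permute(0, 2, 1).reshape(-1, C) produces.
    """
    rows = []
    for sample in preds:  # sample: C lists, each of length L
        num_classes = len(sample)
        seq_len = len(sample[0])
        for pos in range(seq_len):
            rows.append([sample[c][pos] for c in range(num_classes)])
    return rows


def flatten_targets(targets):
    """Map nested (B, L) labels to a flat (B*L,) list in the same row order."""
    return [label for sample in targets for label in sample]


def topk_accuracy(rows, labels, k):
    """Fraction of rows whose true label is among the k highest scores.

    Ties are broken by the lower class index here (stable sort);
    torch.topk may break ties differently.
    """
    hits = 0
    for scores, label in zip(rows, labels):
        top = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)[:k]
        hits += int(label in top)
    return hits / len(labels)


# Toy input: batch_size=1, num_classes=4, sequence_length=3
preds = [[
    [0.9, 0.1, 0.2],  # class 0 scores at positions 0, 1, 2
    [0.5, 0.8, 0.1],  # class 1
    [0.3, 0.7, 0.6],  # class 2
    [0.0, 0.2, 0.9],  # class 3
]]
targets = [[1, 0, 3]]

rows = flatten_preds(preds)
labels = flatten_targets(targets)
print(rows)    # one row of 4 class scores per (sample, position)
print(labels)  # [1, 0, 3]
print(topk_accuracy(rows, labels, k=2))  # 0.6666666666666666
```

This per-position hit rate is a micro average, so it should correspond to MulticlassAccuracy(..., average="micro") on the flattened tensors; the metric's default average="macro" averages per class and can give a different number.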