Description
The BatchLimiter class in src/saev/utils/scheduling.py overcounts the number of samples seen during iteration: whenever an actual batch is smaller than the expected batch size, the inflated count makes the iterator terminate at the wrong point.
Location
src/saev/utils/scheduling.py:114
Root Cause
In the __iter__ method, the code always increments self.n_seen by self.batch_size:

```python
self.n_seen += self.batch_size
if self.n_seen > self.n_samples:
    return
```

However, the actual batch yielded might have fewer samples than self.batch_size, particularly:
- For the last batch when drop_last=False
- For dataloaders with uneven dataset sizes

This causes the limiter to overcount samples, terminating the iterator at the wrong time.
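The miscount can be reproduced with a minimal sketch of the buggy counting logic (this is a simplification using lists as batches, not the actual saev code):

```python
def limited_batches(batches, batch_size, n_samples):
    """Yield batches until a (buggy) counter exceeds n_samples.

    Bug: the counter adds batch_size on every iteration, even when the
    batch actually yielded is shorter.
    """
    n_seen = 0
    for batch in batches:
        yield batch
        n_seen += batch_size  # bug: assumes every batch has batch_size samples
        if n_seen > n_samples:
            return

# 105 samples with batch_size=32 and drop_last=False -> batch sizes [32, 32, 32, 9]
batches = [list(range(32))] * 3 + [list(range(9))]
yielded = sum(len(b) for b in limited_batches(batches, batch_size=32, n_samples=100))
print(yielded)  # 105: the counter read 128, but only 105 real samples were yielded
```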
Expected Behavior
The BatchLimiter should count the actual number of samples in each batch, not assume all batches have size self.batch_size.
Actual Behavior
The limiter terminates based on incorrect counts, yielding either too many or too few samples.
Example
If we have:
- A dataloader with 105 samples
- batch_size = 32
- drop_last = False
- n_samples = 100 (what we want from BatchLimiter)
The batches would be: [32, 32, 32, 9]
But the counter would be: [32, 64, 96, 128]
When the counter hits 128 > 100, it returns after yielding all 105 samples (not 100).
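The divergence between the counter and the real cumulative sample count can be checked with a few lines of arithmetic (a sketch of the walkthrough above, not repo code):

```python
from itertools import accumulate

# 105 samples, batch_size=32, drop_last=False
batch_sizes = [32, 32, 32, 9]

counter = [32 * (i + 1) for i in range(len(batch_sizes))]  # buggy: always +32
actual = list(accumulate(batch_sizes))                     # samples really yielded

print(counter)  # [32, 64, 96, 128]
print(actual)   # [32, 64, 96, 105]
```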
Reproduction
See the unit tests in tests/test_batch_limiter.py, which demonstrate this bug:
```sh
uv run --no-dev python -m pytest tests/test_batch_limiter.py -v
```

Test results:
- test_batch_limiter_with_uneven_batches: Expected ≤100 samples, got 105
- test_batch_limiter_early_termination: Expected 100 samples, got 160
Proposed Fix
Change line 114 to count the actual batch size instead of always using self.batch_size.
See PR for the implementation.
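As a rough illustration of the direction (the actual PR may differ), the limiter can count the samples really present in each batch; truncating the final batch so that exactly n_samples are yielded is an additional assumption here, suggested by the tests expecting ≤100 samples:

```python
def limit_batches(batches, n_samples):
    """Sketch of a fixed limiter: count actual batch lengths.

    Assumes list-like batches; truncating the last batch to land exactly
    on n_samples is an assumption, not something the issue specifies.
    """
    n_seen = 0
    for batch in batches:
        remaining = n_samples - n_seen
        if remaining <= 0:
            return
        batch = batch[:remaining]  # assumption: trim the batch that would overshoot
        n_seen += len(batch)       # fix: count the samples actually yielded
        yield batch

batches = [list(range(32))] * 3 + [list(range(9))]
total = sum(len(b) for b in limit_batches(batches, n_samples=100))
print(total)  # 100
```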