feat: Support for RS2 Downsampler #465
Conversation
modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_rs2_downsampling.py
target = torch.randint(0, 10, (10,))

for _ in range(3):
    downsampler.inform_samples(sample_ids, data, target)
Each call to inform_samples should be provided with a different set of sample_ids.
Why? That would not be the case in the trainer server / pytorch trainer due to the nature of downsampling, and it also will not make a difference.
Yeah, it does not make a difference here, as we just test the shape. But naturally they should be different: in sample_and_batch in pytorch_trainer.py, we first iterate over the dataloader and keep informing each batch in _iterate_dataloader_and_compute_scores:

self._downsampler.inform_samples(sample_ids, model_output, target, embeddings)

The sample_ids come from the dataloader and should be naturally distinct, right? (They are the keys of the samples.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, they are not. What differs is the model output (on which true downsamplers sample), but the list of samples is always the same, since the trigger training set from the selector does not change between epochs. Since RS2 only relies on the IDs, it should not matter. The IDs will in all cases be identical across epochs.
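To make that concrete: since RS2 ignores the model output and only works on the accumulated IDs, a per-epoch selection can be sketched as below. This is a minimal illustrative sketch, not the PR's actual implementation in remote_rs2_downsampling.py; the class name, the downsampling_ratio argument, and the uniform weights are assumptions.

import random

import torch


class RandomSubsetSelector:
    # Illustrative only: per-epoch random subset selection over accumulated sample IDs.
    def __init__(self, downsampling_ratio: float) -> None:
        self.downsampling_ratio = downsampling_ratio  # fraction of informed samples kept per epoch
        self._sample_ids: list[int] = []

    def inform_samples(self, sample_ids: list[int], model_output=None, target=None) -> None:
        # Only the IDs are stored; the forward-pass outputs are irrelevant for random selection.
        self._sample_ids.extend(sample_ids)

    def select_points(self) -> tuple[list[int], torch.Tensor]:
        if not self._sample_ids:
            return [], torch.empty(0)
        target_size = max(1, int(self.downsampling_ratio * len(self._sample_ids)))
        selected = random.sample(self._sample_ids, target_size)  # drawn without replacement within an epoch
        self._sample_ids = []  # reset so the next epoch starts fresh
        # Uniform weights, since the selection is not score-based.
        return selected, torch.ones(target_size)

Whether the same IDs are informed again in the next epoch therefore changes nothing about the selection logic; only the random draw differs.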
I think we have a misunderstanding. I am saying that consecutive calls to inform_samples within two select_points call boundaries should contain different sample ids. I copy the code of _iterate_dataloader_and_compute_scores here:
for batch_number, batch in enumerate(dataloader):
    self.update_queue(AvailableQueues.DOWNSAMPLING, batch_number, number_of_samples, training_active=False)

    sample_ids, target, data = self.preprocess_batch(batch)
    number_of_samples += len(sample_ids)

    with torch.inference_mode(mode=(not self._downsampler.requires_grad)):
        with torch.autocast(self._device_type, enabled=self._amp):
            # compute the scores and accumulate them
            model_output = self._model.model(data)
            embeddings = self.get_embeddings_if_recorded()
            self._downsampler.inform_samples(sample_ids, model_output, target, embeddings)
You see: we load one batch after another from the dataloader. One inform_samples call does not contain the entire dataset but just one batch. The first batch must have different sample ids than the second batch. That means that if we do not call select_points in the middle, then the inform_samples calls should contain different sample ids. I am not talking about sample ids across epochs; those definitely do not change.
i.e. if we have
downsampler.inform_samples(...)
downsampler.select_points(...)
downsampler.inform_samples(...)
Then the first inform_samples call can have the same sample ids as the second inform_samples.
But when we do
downsampler.inform_samples(...)
downsampler.inform_samples(...)
downsampler.select_points(...)
downsampler.inform_samples(...)
downsampler.inform_samples(...)
Suppose the whole dataset contains two batches. Then the first two inform_samples calls should contain different sample_ids.

In this unit test, we only keep calling inform_samples(...) without calling select_points(...), so each call should contain distinct sample_ids.
Anyway, I think it does not really make a difference here to use different sample ids. But I still do think that consecutive inform_samples calls (without a select_points call in the middle) should contain distinct sample ids.
I understand your point and agree with your description, but I still don't understand why you are suggesting it here :D The code is this:
with torch.inference_mode(mode=(not downsampler.requires_grad)):
    sample_ids = list(range(10))
    data = torch.randn(10, 10)
    target = torch.randint(0, 10, (10,))

    for _ in range(3):
        downsampler.inform_samples(sample_ids, data, target)
        selected_ids, weights = downsampler.select_points()
so the loop is the epoch loop (!). Since sample_ids = list(range(10)), we don't have duplicate samples within the same epoch, and the samples are consistent across epochs. This is exactly as you describe. I am not sure if I am missing something or you just confused this loop with something else. I am merging this for now and am happy to do a follow-up PR in case I am missing something here.
" this unit test, we only keep calling inform_samples(...) without calling select_points(...),"
i dont get it. isn't it directly below :D?
Answered your comments and addressed them where possible for now :) |
There are still some points to address, but after that, feel free to merge🚀, thanks a lot!
Feel free to merge the PR! Thanks for the further explanation!!!!
This implements the random selection from the RS2 paper minus the learning rate scheduling adjustments.
Note that it is a bit suboptimal to use the downsampling infrastructure here (#466). We might want to think about making the selector a bit more dynamic, but for now, this will suffice to run experiments. #462 should be merged before this is reviewed.
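For context, the core of what this PR implements can be sketched independently of Modyn: RS2 redraws a uniform random subset of the training set every epoch and trains only on that subset. The sketch below is illustrative only; the function name, the ratio parameter, and the without-replacement draw are assumptions, and the paper's learning rate schedule adjustment is omitted here on purpose, as it is in this PR.

import random
from typing import Callable


def rs2_training_loop(
    dataset_ids: list[int],
    ratio: float,
    epochs: int,
    train_one_epoch: Callable[[list[int]], None],
) -> None:
    # Illustrative RS2-style loop: each epoch trains on a freshly drawn uniform random subset.
    subset_size = max(1, int(ratio * len(dataset_ids)))
    for _ in range(epochs):
        epoch_subset = random.sample(dataset_ids, subset_size)  # redrawn every epoch, without replacement
        train_one_epoch(epoch_subset)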