feat: Support for RS2 Downsampler #465

Merged (14 commits) on Jun 4, 2024
29 changes: 28 additions & 1 deletion modyn/config/schema/sampling/downsampling_config.py
@@ -3,7 +3,7 @@
from typing import Annotated, List, Literal, Union

from modyn.config.schema.modyn_base_model import ModynBaseModel
from pydantic import Field, model_validator
from pydantic import Field, field_validator, model_validator
from typing_extensions import Self


@@ -121,6 +121,33 @@ class RHOLossDownsamplingConfig(BaseDownsamplingConfig):
il_training_config: ILTrainingConfig = Field(description="The configuration for the IL training.")


class RS2DownsamplingConfig(BaseDownsamplingConfig):
"""Config for the RS2 downsampling strategy."""

strategy: Literal["RS2"] = "RS2"
with_replacement: bool = Field(
description=(
"Whether we resample from the full TTS each epoch (= True) or train "
"on all the data with a different subset each epoch (= False)."
)
)

@field_validator("sample_then_batch")
@classmethod
def sample_then_batch_must_be_true(cls, v: bool) -> bool:
if not v:
raise ValueError("sample_then_batch must be set to True for this config.")
return v

@field_validator("period")
@classmethod
def only_support_period_one(cls, v: int) -> int:
if v != 1:
    # RS2 requires us to resample every epoch.
    raise ValueError("period must be set to 1 for this config.")
return v


SingleDownsamplingConfig = Annotated[
Union[
UncertaintyDownsamplingConfig,
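For context, here is a minimal usage sketch of the new config class. The ratio, sample_then_batch, and period fields are assumed to be inherited from BaseDownsamplingConfig (the latter two are exactly the fields constrained by the validators above); treat every field name other than strategy and with_replacement as illustrative.

from modyn.config.schema.sampling.downsampling_config import RS2DownsamplingConfig

# Assumed base-config fields: ratio, sample_then_batch, period.
config = RS2DownsamplingConfig(
    ratio=50,                # keep 50% of the trigger training set per epoch
    sample_then_batch=True,  # the validator above rejects False
    period=1,                # the validator above rejects any other value
    with_replacement=True,   # resample from the full TTS every epoch
)

# Misconfigurations fail at validation time, e.g. sample_then_batch=False raises
# "sample_then_batch must be set to True for this config."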
@@ -0,0 +1,172 @@
import torch
from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import (
get_tensors_subset,
)
from modyn.trainer_server.internal.trainer.remote_downsamplers.remote_rs2_downsampling import RemoteRS2Downsampling


def test_init():
pipeline_id = 0
trigger_id = 0
batch_size = 32
params_from_selector = {"replacement": True, "downsampling_ratio": 50}
per_sample_loss = None
device = "cpu"

downsampler = RemoteRS2Downsampling(
pipeline_id, trigger_id, batch_size, params_from_selector, per_sample_loss, device
)

assert downsampler.pipeline_id == pipeline_id
assert downsampler.trigger_id == trigger_id
assert downsampler.batch_size == batch_size
assert downsampler.device == device
assert not downsampler.forward_required
assert not downsampler.supports_bts
assert downsampler._all_sample_ids == []
assert downsampler._subsets == []
assert downsampler._current_subset == -1
assert downsampler._with_replacement == params_from_selector["replacement"]
assert downsampler._max_subset == -1
assert downsampler._first_epoch


def test_inform_samples():
pipeline_id = 0
trigger_id = 0
batch_size = 32
params_from_selector = {"replacement": True, "downsampling_ratio": 50}
per_sample_loss = None
device = "cpu"

downsampler = RemoteRS2Downsampling(
pipeline_id, trigger_id, batch_size, params_from_selector, per_sample_loss, device
)

sample_ids = [1, 2, 3, 4, 5]
forward_output = torch.randn(5, 10)
target = torch.randint(0, 10, (5,))

downsampler.inform_samples(sample_ids, forward_output, target)

assert downsampler._all_sample_ids == sample_ids
downsampler.inform_samples(sample_ids, forward_output, target)
assert downsampler._all_sample_ids == 2 * sample_ids
# Now it should not change anymore
downsampler.select_points()
downsampler.inform_samples(sample_ids, forward_output, target)
assert set(downsampler._all_sample_ids) == set(sample_ids)
assert len(downsampler._all_sample_ids) == 2 * len(sample_ids)


def test_multiple_epochs_with_replacement():
pipeline_id = 0
trigger_id = 0
batch_size = 32
params_from_selector = {"replacement": True, "downsampling_ratio": 50}
per_sample_loss = None
device = "cpu"

downsampler = RemoteRS2Downsampling(
pipeline_id, trigger_id, batch_size, params_from_selector, per_sample_loss, device
)
with torch.inference_mode(mode=(not downsampler.requires_grad)):
sample_ids = list(range(10))
data = torch.randn(10, 10)
target = torch.randint(0, 10, (10,))

for _ in range(3):
downsampler.inform_samples(sample_ids, data, target)
Collaborator:
Each call to inform_samples should be provided with a different set of sample_ids.

Contributor Author:
Why? That would not be the case in the trainer server / PyTorch trainer due to the nature of downsampling, and it also will not make a difference here.

Collaborator:
Yeah, it does not make a difference here, as we just test the shape. But naturally they should be different because, in sample_then_batch mode, pytorch_trainer.py first iterates over the dataloader and keeps informing the downsampler of each batch in _iterate_dataloader_and_compute_scores:

self._downsampler.inform_samples(sample_ids, model_output, target, embeddings)

The sample_ids come from the dataloader and should naturally be distinct, right? (They are the keys of the samples.)

Contributor Author:
No, they are not. What differs is the model output (on which true downsamplers sample), but the list of samples is always the same, since the trigger training set from the selector does not change between epochs. Since RS2 only relies on the IDs, it should not matter. The IDs will in all cases be identical across epochs.

Collaborator (@XianzheMa, Jun 4, 2024):
I think we have a misunderstanding. I am saying that consecutive calls to inform_samples between two select_points calls should contain different sample ids.

I copy the code of _iterate_dataloader_and_compute_scores here:

        for batch_number, batch in enumerate(dataloader):
            self.update_queue(AvailableQueues.DOWNSAMPLING, batch_number, number_of_samples, training_active=False)

            sample_ids, target, data = self.preprocess_batch(batch)
            number_of_samples += len(sample_ids)

            with torch.inference_mode(mode=(not self._downsampler.requires_grad)):
                with torch.autocast(self._device_type, enabled=self._amp):
                    # compute the scores and accumulate them
                    model_output = self._model.model(data)
                    embeddings = self.get_embeddings_if_recorded()
                    self._downsampler.inform_samples(sample_ids, model_output, target, embeddings)

You see: we load one batch after another from the dataloader. One inform_samples call does not contain the entire dataset, just one batch. The first batch must have different sample ids than the second batch. So if we do not call select_points in between, consecutive inform_samples calls should contain different sample ids.

I am not talking about sample ids across epochs. Those definitely do not change.

Collaborator (@XianzheMa, Jun 4, 2024):
I.e., if we have

downsampler.inform_samples(...)
downsampler.select_points(...)
downsampler.inform_samples(...)

then the first inform_samples call can have the same sample ids as the second one. But when we do

downsampler.inform_samples(...)
downsampler.inform_samples(...)
downsampler.select_points(...)
downsampler.inform_samples(...)
downsampler.inform_samples(...)

and the whole dataset contains two batches, then the first two inform_samples calls should contain different sample_ids.

In this unit test, we only keep calling inform_samples(...) without calling select_points(...), so each call should contain distinct sample_ids.

Collaborator:
Anyway, I think it does not really make a difference here to use different sample ids. But I still do think that consecutive inform_samples calls (without a select_points call in the middle) should contain distinct sample ids.

Contributor Author:
I understand your point and agree with your description, but I still don't understand why you are suggesting it here :D The code is this:

    with torch.inference_mode(mode=(not downsampler.requires_grad)):
        sample_ids = list(range(10))
        data = torch.randn(10, 10)
        target = torch.randint(0, 10, (10,))

        for _ in range(3):
            downsampler.inform_samples(sample_ids, data, target)
            selected_ids, weights = downsampler.select_points()

so the loop is the epoch loop (!). Since sample_ids = list(range(10)), we don't have duplicate samples within an epoch and the samples are consistent across epochs. This is exactly what you describe. I am not sure whether I am missing something or you confused this loop with something else. I am merging this for now and am happy to do a follow-up PR in case I am missing something here.

Contributor Author:
"In this unit test, we only keep calling inform_samples(...) without calling select_points(...)"

I don't get it. Isn't it directly below :D?

selected_ids, weights = downsampler.select_points()
sampled_data, sampled_target = get_tensors_subset(selected_ids, data, target, sample_ids)

assert len(set(selected_ids)) == 5
assert weights.shape == (5,)
assert all(idx in sample_ids for idx in selected_ids)
assert sampled_data.shape == (5, 10)
assert sampled_target.shape == (5,)


def test_multiple_epochs_without_replacement():
pipeline_id = 0
trigger_id = 0
batch_size = 32
params_from_selector = {"replacement": False, "downsampling_ratio": 50}
per_sample_loss = None
device = "cpu"

downsampler = RemoteRS2Downsampling(
pipeline_id, trigger_id, batch_size, params_from_selector, per_sample_loss, device
)
with torch.inference_mode(mode=(not downsampler.requires_grad)):

sample_ids = list(range(10))
data = torch.randn(10, 10)
target = torch.randint(0, 10, (10,))

# Epoch 1
downsampler.inform_samples(sample_ids, data, target)
epoch1_ids, weights = downsampler.select_points()
sampled_data, sampled_target = get_tensors_subset(epoch1_ids, data, target, sample_ids)

assert len(set(epoch1_ids)) == 5
assert weights.shape == (5,)
assert all(idx in sample_ids for idx in epoch1_ids)
assert sampled_data.shape == (5, 10)
assert sampled_target.shape == (5,)

# Epoch 2
downsampler.inform_samples(sample_ids, data, target)
epoch2_ids, weights = downsampler.select_points()
sampled_data, sampled_target = get_tensors_subset(epoch2_ids, data, target, sample_ids)

assert len(set(epoch2_ids)) == 5
assert weights.shape == (5,)
assert all(idx in sample_ids for idx in epoch2_ids)
assert not any(idx in epoch1_ids for idx in epoch2_ids) # No overlap across epochs
assert sampled_data.shape == (5, 10)
assert sampled_target.shape == (5,)

# Epoch 3
downsampler.inform_samples(sample_ids, data, target)
epoch3_ids, weights = downsampler.select_points()
sampled_data, sampled_target = get_tensors_subset(epoch3_ids, data, target, sample_ids)

assert len(set(epoch3_ids)) == 5
assert weights.shape == (5,)
assert all(idx in sample_ids for idx in epoch3_ids)
assert all(idx in epoch1_ids or idx in epoch2_ids for idx in epoch3_ids) # There needs to be overlap now
# but (with very high probability; this might be flaky, let's see) there is some difference
assert any(idx not in epoch1_ids for idx in epoch3_ids)
assert sampled_data.shape == (5, 10)
assert sampled_target.shape == (5,)


def test_multiple_epochs_without_replacement_leftover_data():
pipeline_id = 0
trigger_id = 0
batch_size = 32
params_from_selector = {"replacement": False, "downsampling_ratio": 40}
per_sample_loss = None
device = "cpu"

downsampler = RemoteRS2Downsampling(
pipeline_id, trigger_id, batch_size, params_from_selector, per_sample_loss, device
)
with torch.inference_mode(mode=(not downsampler.requires_grad)):
sample_ids = list(range(10))
data = torch.randn(10, 10)
target = torch.randint(0, 10, (10,))

for _ in range(3):
downsampler.inform_samples(sample_ids, data, target)

selected_ids, weights = downsampler.select_points()
sampled_data, sampled_target = get_tensors_subset(selected_ids, data, target, sample_ids)
assert len(set(selected_ids)) == 4
assert weights.shape == (4,)
assert sampled_data.shape == (4, 10)
assert sampled_target.shape == (4,)

assert all(idx in sample_ids for idx in selected_ids)
assert len(set(selected_ids)) == len(selected_ids)
6 changes: 4 additions & 2 deletions modyn/trainer_server/internal/trainer/pytorch_trainer.py
@@ -374,6 +374,8 @@ def train(self) -> None: # pylint: disable=too-many-locals, too-many-branches
self._info(f"Training will stop when the number of samples to pass reaches {self.num_samples_to_pass}.")

if self._downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE:
# Assertion only, since pydantic model validation should already catch this.
assert self._downsampler.supports_bts, "The downsampler does not support batch then sample"
# We cannot pass the target size from the trainer server since that depends on StB vs BtS.
post_downsampling_size = max(int(self._downsampler.downsampling_ratio * self._batch_size / 100), 1)
assert post_downsampling_size < self._batch_size
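As an aside, a quick worked example of the batch-then-sample target-size formula above (the values are illustrative, not taken from the PR):

# With a downsampling ratio of 50 and a batch size of 32, half of each batch survives:
downsampling_ratio, batch_size = 50, 32
post_downsampling_size = max(int(downsampling_ratio * batch_size / 100), 1)
assert post_downsampling_size == 16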
@@ -692,7 +694,7 @@ def downsample_batch(
self.start_embedding_recording_if_needed()

with torch.inference_mode(mode=(not self._downsampler.requires_grad)):
big_batch_output = self._model.model(data)
big_batch_output = self._model.model(data) if self._downsampler.forward_required else torch.Tensor()
embeddings = self.get_embeddings_if_recorded()
self._downsampler.inform_samples(sample_ids, big_batch_output, target, embeddings)

@@ -831,7 +833,7 @@ def _iterate_dataloader_and_compute_scores(
with torch.inference_mode(mode=(not self._downsampler.requires_grad)):
with torch.autocast(self._device_type, enabled=self._amp):
# compute the scores and accumulate them
model_output = self._model.model(data)
model_output = self._model.model(data) if self._downsampler.forward_required else torch.Tensor()
embeddings = self.get_embeddings_if_recorded()
self._downsampler.inform_samples(sample_ids, model_output, target, embeddings)

@@ -62,6 +62,16 @@ def __init__(
# CoresetSupportingModule for model implementations.
self.requires_coreset_supporting_module = False

# Some methods do not need information from the forward pass (e.g., completely random ones).
# Most methods do (by definition of downsampling), hence we default to True.
# We might want to refactor those downsamplers into presamplers and support some
# adaptivity at the selector, but for now we allow random downsamplers, mostly
# to support RS2.
self.forward_required = True

# Some methods might only support StB, not BtS.
self.supports_bts = True

@abstractmethod
def init_downsampler(self) -> None:
raise NotImplementedError
@@ -0,0 +1,81 @@
import logging
import random
from typing import Any, Optional

import torch
from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import (
AbstractRemoteDownsamplingStrategy,
)

logger = logging.getLogger(__name__)


class RemoteRS2Downsampling(AbstractRemoteDownsamplingStrategy):
"""
Method adapted from "Repeated Random Sampling for Minimizing the Time-to-Accuracy of Learning" (Okanovic et al., 2024)
https://openreview.net/pdf?id=JnRStoIuTe
"""

def __init__(
self,
pipeline_id: int,
trigger_id: int,
batch_size: int,
params_from_selector: dict,
per_sample_loss: Any,
device: str,
) -> None:
super().__init__(pipeline_id, trigger_id, batch_size, params_from_selector, device)
self.forward_required = False
self.supports_bts = False
self._all_sample_ids: list[int] = []
self._subsets: list[list[int]] = []
self._current_subset = -1
self._with_replacement: bool = params_from_selector["replacement"]
self._max_subset = -1
self._first_epoch = True

def init_downsampler(self) -> None:
pass # We take care of that in inform_samples

def inform_samples(
self,
sample_ids: list[int],
forward_output: torch.Tensor,
target: torch.Tensor,
embedding: Optional[torch.Tensor] = None,
) -> None:
# We only need to collect the sample information once
if self._first_epoch:
self._all_sample_ids.extend(sample_ids)

def _epoch_step_wr(self, target_size: int) -> None:
self._subsets = [self._all_sample_ids[:target_size]]
self._current_subset = 0

def _epoch_step_r(self, target_size: int) -> None:
self._max_subset = len(self._all_sample_ids) // target_size
self._current_subset += 1
if self._current_subset >= self._max_subset or len(self._subsets) == 0:
self._current_subset = 0
self._subsets = [
self._all_sample_ids[i * target_size : (i + 1) * target_size] for i in range(self._max_subset)
]

def _epoch_step(self) -> None:
target_size = max(int(self.downsampling_ratio * len(self._all_sample_ids) / 100), 1)
random.shuffle(self._all_sample_ids)

if self._with_replacement:
self._epoch_step_wr(target_size)
else:
self._epoch_step_r(target_size)

def select_points(self) -> tuple[list[int], torch.Tensor]:
self._first_epoch = False
self._epoch_step()
return self._subsets[self._current_subset], torch.ones(len(self._subsets[self._current_subset]))

@property
def requires_grad(self) -> bool:
return False
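To make the no-replacement path easier to follow, here is a small standalone sketch of the same shuffle, partition, and cycle idea, written as plain Python and independent of the Modyn classes (the helper name and signature are illustrative, not part of the PR):

import random

def rs2_subsets_without_replacement(sample_ids: list[int], ratio: int, epochs: int) -> list[list[int]]:
    """Return one subset per epoch; subsets are disjoint within a shuffle and reshuffled once exhausted."""
    target_size = max(int(ratio * len(sample_ids) / 100), 1)
    subsets: list[list[int]] = []
    current = -1
    per_epoch: list[list[int]] = []
    for _ in range(epochs):
        current += 1
        if not subsets or current >= len(subsets):
            random.shuffle(sample_ids)
            subsets = [
                sample_ids[i * target_size : (i + 1) * target_size]
                for i in range(len(sample_ids) // target_size)
            ]
            current = 0
        per_epoch.append(subsets[current])
    return per_epoch

# With 10 samples and ratio=50 there are two disjoint subsets per shuffle, so epochs 1 and 2
# never overlap and epoch 3 starts from a fresh shuffle (mirroring the tests above).
print(rs2_subsets_without_replacement(list(range(10)), ratio=50, epochs=3))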