
Commit

first commit
XianzheMa committed Sep 6, 2024
1 parent c24e24c commit 787940b
Showing 9 changed files with 45 additions and 13 deletions.
14 changes: 9 additions & 5 deletions modyn/trainer_server/internal/trainer/pytorch_trainer.py
@@ -45,8 +45,7 @@
     AbstractPerLabelRemoteDownsamplingStrategy,
 )
 from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import (
-    AbstractRemoteDownsamplingStrategy,
-    get_tensors_subset,
+    AbstractRemoteDownsamplingStrategy
 )
 from modyn.trainer_server.internal.utils.metric_type import MetricType
 from modyn.trainer_server.internal.utils.trainer_messages import TrainerMessages
@@ -569,11 +568,16 @@ def downsample_batch(
 
         # TODO(#218) Persist information on the sample IDs/weights when downsampling is performed
         selected_indexes, weights = self._downsampler.select_points()
-        selected_data, selected_target = get_tensors_subset(selected_indexes, data, target, sample_ids)
-        sample_ids, data, target = selected_indexes, selected_data, selected_target
+        selected_target = target[selected_indexes]
+
+        if isinstance(data, torch.Tensor):
+            selected_data = data[selected_indexes]
+        else:
+            selected_data = {key: tensor[selected_indexes] for key, tensor in data.items()}
+        selected_sample_ids = torch.tensor(sample_ids)[selected_indexes].tolist()
         # TODO(#219) Investigate if we can avoid 2 forward passes
         self._model.model.train()
-        return data, sample_ids, target, weights.to(self._device)
+        return selected_data, selected_sample_ids, selected_target, weights.to(self._device)
 
     def start_embedding_recording_if_needed(self) -> None:
         if self._downsampler.requires_coreset_supporting_module:
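The removed helper get_tensors_subset is now inlined above. A minimal, runnable sketch of the new subset logic, with hypothetical batch contents (`data` may be a single tensor or, e.g. for tokenized text, a dict of tensors):

    import torch

    def subset_batch(data, target, sample_ids, selected_indexes):
        # Row-index the targets with the positions chosen by the downsampler.
        selected_target = target[selected_indexes]
        # `data` is either one tensor or a dict of tensors; index each the same way.
        if isinstance(data, torch.Tensor):
            selected_data = data[selected_indexes]
        else:
            selected_data = {key: tensor[selected_indexes] for key, tensor in data.items()}
        # Sample IDs arrive as a plain list; round-trip through a tensor for fancy indexing.
        selected_sample_ids = torch.tensor(sample_ids)[selected_indexes].tolist()
        return selected_data, selected_sample_ids, selected_target

    # Hypothetical batch of four samples; keep positions 2 and 0.
    data = {"input_ids": torch.arange(12).reshape(4, 3), "attention_mask": torch.ones(4, 3)}
    target = torch.tensor([0, 1, 1, 0])
    subset, ids, tgt = subset_batch(data, target, [10, 11, 12, 13], torch.tensor([2, 0]))
    print(ids)  # [12, 10]
    print(tgt)  # tensor([1, 0])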
(next changed file; path not shown)
@@ -9,6 +9,7 @@
     AbstractPerLabelRemoteDownsamplingStrategy,
 )
 from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils.shuffling import _shuffle_list_and_tensor
+from modyn.utils import DownsamplingMode
 
 
 class MatrixContent(Enum):
@@ -116,7 +117,10 @@ def _select_from_matrix(self) -> tuple[list[int], torch.Tensor]:
         number_of_samples = len(matrix)
         target_size = max(int(self.downsampling_ratio * number_of_samples / self.ratio_max), 1)
         selected_indices, weights = self._select_indexes_from_matrix(matrix, target_size)
-        selected_ids = [self.index_sampleid_map[index] for index in selected_indices]
+        if self.downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE:
+            selected_ids = selected_indices
+        else:
+            selected_ids = [self.index_sampleid_map[index] for index in selected_indices]
         return selected_ids, weights
 
     def init_downsampler(self) -> None:
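This branch is the core of the commit: in batch-then-sample mode the trainer indexes the batch it already holds in memory (see the pytorch_trainer.py change above), so select_points must return raw positions within that batch; in sample-then-batch mode the IDs are apparently consumed outside the batch context, so positions are still translated through index_sampleid_map. For the target_size line: assuming downsampling_ratio is a percentage and ratio_max is 100, a 128-row matrix at ratio 50 yields max(int(50 * 128 / 100), 1) = 64. A small sketch with a stand-in enum (the real one lives in modyn.utils; the member names are taken from the diff):

    from enum import Enum

    class DownsamplingMode(Enum):
        # Stand-in for modyn.utils.DownsamplingMode.
        BATCH_THEN_SAMPLE = 1
        SAMPLE_THEN_BATCH = 2

    def to_selected_ids(selected_indices, index_sampleid_map, mode):
        if mode == DownsamplingMode.BATCH_THEN_SAMPLE:
            # Caller slices the in-memory batch, so raw positions suffice.
            return selected_indices
        # Caller works with global sample IDs, so translate positions.
        return [index_sampleid_map[i] for i in selected_indices]

    index_sampleid_map = [1001, 1002, 1003, 1004]
    print(to_selected_ids([3, 1], index_sampleid_map, DownsamplingMode.BATCH_THEN_SAMPLE))  # [3, 1]
    print(to_selected_ids([3, 1], index_sampleid_map, DownsamplingMode.SAMPLE_THEN_BATCH))  # [1004, 1002]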
(next changed file; path not shown)
@@ -3,6 +3,8 @@
 
 import torch
 
+from modyn.utils import DownsamplingMode
+
 FULL_GRAD_APPROXIMATION = ["LastLayer", "LastLayerWithEmbedding"]
 
 
@@ -83,6 +85,10 @@ def __init__(
         # Some methods might only support StB, not BtS.
         self.supports_bts = True
 
+    def set_downsampling_mode(self, downsampling_mode: DownsamplingMode) -> None:
+        # pylint: disable=attribute-defined-outside-init
+        self.downsampling_mode = downsampling_mode
+
     @abstractmethod
     def init_downsampler(self) -> None:
         raise NotImplementedError
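This setter is what lets every strategy in this commit branch on self.downsampling_mode. Defining it once on the abstract base presumably avoids touching each subclass constructor, at the cost of the pylint suppression for an attribute set outside __init__. A hedged usage sketch (assumes Modyn is installed so the import resolves; DummyStrategy is hypothetical):

    from modyn.utils import DownsamplingMode  # import path taken from the diff

    class DummyStrategy:
        # Mirrors the new setter on AbstractRemoteDownsamplingStrategy.
        def set_downsampling_mode(self, downsampling_mode: DownsamplingMode) -> None:
            # pylint: disable=attribute-defined-outside-init
            self.downsampling_mode = downsampling_mode

    strategy = DummyStrategy()
    strategy.set_downsampling_mode(DownsamplingMode.BATCH_THEN_SAMPLE)
    print(strategy.downsampling_mode)  # DownsamplingMode.BATCH_THEN_SAMPLE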
(next changed file; path not shown)
@@ -19,6 +19,7 @@
 from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils.submodular_optimizer import (
     OPTIMIZER_CHOICES,
 )
+from modyn.utils import DownsamplingMode
 
 
 class RemoteCraigDownsamplingStrategy(AbstractPerLabelRemoteDownsamplingStrategy):
@@ -193,7 +194,10 @@ def _select_points_from_distance_matrix(self) -> tuple[list[int], torch.Tensor]:
             batch=self.selection_batch,
         )
         weights = self.calc_weights(self.distance_matrix, selection_result)
-        selected_ids = [self.index_sampleid_map[sample] for sample in selection_result]
+        if self.downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE:
+            selected_ids = selection_result
+        else:
+            selected_ids = [self.index_sampleid_map[sample] for sample in selection_result]
         return selected_ids, weights
 
     def init_downsampler(self) -> None:
(next changed file; path not shown)
@@ -6,6 +6,7 @@
 from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import (
     AbstractRemoteDownsamplingStrategy,
 )
+from modyn.utils import DownsamplingMode
 
 logger = logging.getLogger(__name__)
 
@@ -72,8 +73,10 @@ def select_points(self) -> tuple[list[int], torch.Tensor]:
 
         # lower probability, higher weight to reduce the variance
         weights = 1.0 / (self.number_of_points_seen * probabilities[downsampled_idxs])
-
-        selected_ids = [self.index_sampleid_map[sample] for sample in downsampled_idxs]
+        if self.downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE:
+            selected_ids = downsampled_idxs
+        else:
+            selected_ids = [self.index_sampleid_map[sample] for sample in downsampled_idxs]
         return selected_ids, weights
 
     @property
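The weights line above is a standard importance-sampling correction: if sample i is drawn with probability p_i out of N seen points, scaling its loss by 1 / (N * p_i) keeps the weighted subset loss an unbiased estimate of the mean loss over all N points. A numeric sketch (the data is made up, and it samples with replacement for simplicity):

    import torch

    torch.manual_seed(0)
    losses = torch.rand(1000)                     # hypothetical per-sample losses
    probabilities = losses / losses.sum()         # loss-proportional sampling probabilities
    idxs = torch.multinomial(probabilities, 200, replacement=True)
    # lower probability, higher weight to reduce the variance (as in the diff)
    weights = 1.0 / (len(losses) * probabilities[idxs])
    estimate = (weights * losses[idxs]).mean()
    print(estimate.item(), losses.mean().item())  # the two means should be close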
(next changed file; path not shown)
@@ -6,6 +6,7 @@
 from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import (
     AbstractRemoteDownsamplingStrategy,
 )
+from modyn.utils import DownsamplingMode
 
 logger = logging.getLogger(__name__)
 
@@ -71,8 +72,10 @@ def select_points(self) -> tuple[list[int], torch.Tensor]:
 
         # lower probability, higher weight to reduce the variance
         weights = 1.0 / (self.number_of_points_seen * probabilities[downsampled_idxs])
-
-        selected_ids = [self.index_sampleid_map[sample] for sample in downsampled_idxs]
+        if self.downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE:
+            selected_ids = downsampled_idxs
+        else:
+            selected_ids = [self.index_sampleid_map[sample] for sample in downsampled_idxs]
         return selected_ids, weights
 
     @property
(next changed file; path not shown)
@@ -6,6 +6,7 @@
 from modyn.trainer_server.internal.trainer.remote_downsamplers.rho_loss_utils.irreducible_loss_producer import (
     IrreducibleLossProducer,
 )
+from modyn.utils import DownsamplingMode
 
 
 class RemoteRHOLossDownsampling(AbstractRemoteDownsamplingStrategy):
@@ -61,8 +62,9 @@ def select_points(self) -> tuple[list[int], torch.Tensor]:
         target_size = max(int(self.downsampling_ratio * self.number_of_points_seen / self.ratio_max), 1)
         # find the indices of maximal "target_size" elements in the list of rho_loss
         selected_indices = torch.topk(self.rho_loss, target_size).indices
+        assert self.downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE
         # use sorted() because we keep the relative order of the selected samples
-        selected_sample_ids = [self.index_sampleid_map[i] for i in sorted(selected_indices)]
+        selected_sample_ids = sorted(selected_indices)
         # all selected samples have weight 1.0
         selected_weights = torch.ones(target_size)
         return selected_sample_ids, selected_weights
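torch.topk returns the positions of the largest rho-loss values ordered by value, not by position, so the sort restores the original batch order. That matters because batch-then-sample, the only mode the new assert allows, returns raw positions into the in-memory batch. A small sketch (.tolist() is used here to get plain ints; the diff sorts the index tensor directly):

    import torch

    rho_loss = torch.tensor([0.1, 0.9, 0.4, 0.7, 0.2])
    target_size = 3
    selected = torch.topk(rho_loss, target_size).indices  # tensor([1, 3, 2]), descending by value
    ordered = sorted(selected.tolist())                   # [1, 2, 3], original batch order
    weights = torch.ones(target_size)                     # every kept sample has weight 1.0
    print(ordered, weights)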
(next changed file; path not shown)
@@ -7,6 +7,7 @@
 from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import (
     AbstractRemoteDownsamplingStrategy,
 )
+from modyn.utils import DownsamplingMode
 
 logger = logging.getLogger(__name__)
 
@@ -85,6 +86,7 @@ def select_points(self) -> tuple[list[int], torch.Tensor]:
         assert self._current_subset < len(
             self._subsets
         ), f"Inconsistent state: {self._current_subset}\n{self._subsets}\n{self._first_epoch}\n{self._all_sample_ids}"
+        assert self.downsampling_mode == DownsamplingMode.SAMPLE_THEN_BATCH, "Only sample-then-batch is supported"
         return self._subsets[self._current_subset], torch.ones(len(self._subsets[self._current_subset]))
 
     @property
(next changed file; path not shown)
@@ -7,6 +7,7 @@
     AbstractPerLabelRemoteDownsamplingStrategy,
 )
 from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils.shuffling import _shuffle_list_and_tensor
+from modyn.utils import DownsamplingMode
 
 
 class RemoteUncertaintyDownsamplingStrategy(AbstractPerLabelRemoteDownsamplingStrategy):
@@ -123,7 +124,10 @@ def _select_from_scores(self) -> tuple[list[int], torch.Tensor]:
         number_of_samples = len(self.scores)
         target_size = max(int(self.downsampling_ratio * number_of_samples / self.ratio_max), 1)
         selected_indices, weights = self._select_indexes_from_scores(target_size)
-        selected_ids = [self.index_sampleid_map[index] for index in selected_indices]
+        if self.downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE:
+            selected_ids = selected_indices
+        else:
+            selected_ids = [self.index_sampleid_map[index] for index in selected_indices]
         return selected_ids, weights
 
     def init_downsampler(self) -> None:
