From a067304c0506c9828d62ce92a14c140b1ccd56f2 Mon Sep 17 00:00:00 2001 From: Xianzhe Ma Date: Sun, 1 Sep 2024 16:25:01 +0200 Subject: [PATCH 01/20] first commit --- .../remote_downsamplers/remote_gradnorm_downsampling.py | 4 +++- .../remote_uncertainty_downsampling_strategy.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py index 6c09bfca1..20257053b 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py @@ -52,7 +52,9 @@ def inform_samples( last_layer_gradients = self._compute_last_layer_gradient_wrt_loss_sum( self.per_sample_loss_fct, forward_output, target ) - scores = torch.norm(last_layer_gradients, dim=-1).cpu() + if last_layer_gradients.dim() == 1: + last_layer_gradients = last_layer_gradients.unsqueeze(1) + scores = torch.linalg.vector_norm(last_layer_gradients, dim=1).cpu() self.probabilities.append(scores) self.number_of_points_seen += forward_output.shape[0] self.index_sampleid_map += sample_ids diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py index 3a8eb24dc..90223a876 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py @@ -69,6 +69,8 @@ def inform_samples( self.index_sampleid_map += sample_ids def _compute_score(self, forward_output: torch.Tensor, disable_softmax: bool = False) -> np.ndarray: + if forward_output.dim() == 1: + forward_output = forward_output.unsqueeze(1) if self.score_metric == "LeastConfidence": scores = forward_output.max(dim=1).values.cpu().numpy() elif self.score_metric == "Entropy": @@ -139,7 +141,7 @@ def _select_indexes_from_scores(self, target_size: int) -> tuple[list[int], torc # we select those with minimal negative entropy, i.e., maximum entropy # Margin: We look for the smallest margin. The larger the margin, the more certain the # model is. 
- return np.argsort(self.scores)[:target_size], torch.ones(target_size).float() + return np.argsort(self.scores)[:target_size].tolist(), torch.ones(target_size).float() @property def requires_grad(self) -> bool: From ed12a56c46a73273c9541822ad8fead245687604 Mon Sep 17 00:00:00 2001 From: Xianzhe Ma Date: Wed, 4 Sep 2024 11:20:44 +0800 Subject: [PATCH 02/20] add test --- .../test_remote_gradnorm_downsample.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py index 600f6a17b..156ae0106 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py @@ -1,3 +1,4 @@ +import pytest import torch from torch import nn @@ -40,7 +41,9 @@ def test_sample_shape_ce(dummy_system_config: ModynConfig): assert set(downsampled_indexes) <= set(range(8)) -def test_sample_shape_other_losses(dummy_system_config: ModynConfig): + +@pytest.mark.parametrize("squeeze_dim", [True, False]) +def test_sample_shape_other_losses(dummy_system_config: ModynConfig, squeeze_dim): model = torch.nn.Linear(10, 1) downsampling_ratio = 50 per_sample_loss_fct = torch.nn.BCEWithLogitsLoss(reduction="none") @@ -51,11 +54,13 @@ def test_sample_shape_other_losses(dummy_system_config: ModynConfig): ) with torch.inference_mode(mode=(not sampler.requires_grad)): data = torch.randn(8, 10) + forward_outputs = model(data) target = torch.randint(2, size=(8,), dtype=torch.float32).unsqueeze(1) + if squeeze_dim: + target = target.squeeze(1) + forward_outputs = forward_outputs.squeeze(1) ids = list(range(8)) - forward_outputs = model(data) - sampler.inform_samples(ids, data, forward_outputs, target) downsampled_indexes, weights = sampler.select_points() From 16f6ffd8983e1ae9799d06bb1b2996958f7f8193 Mon Sep 17 00:00:00 2001 From: Xianzhe Ma Date: Fri, 6 Sep 2024 18:02:18 +0800 Subject: [PATCH 03/20] fix ruff --- .../remote_downsamplers/test_remote_gradnorm_downsample.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py index 156ae0106..6f370cb76 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py @@ -41,7 +41,6 @@ def test_sample_shape_ce(dummy_system_config: ModynConfig): assert set(downsampled_indexes) <= set(range(8)) - @pytest.mark.parametrize("squeeze_dim", [True, False]) def test_sample_shape_other_losses(dummy_system_config: ModynConfig, squeeze_dim): model = torch.nn.Linear(10, 1) From f364642edec1800878db90e1c3f0757852a93c87 Mon Sep 17 00:00:00 2001 From: Xianzhe Ma Date: Fri, 6 Sep 2024 23:24:58 +0800 Subject: [PATCH 04/20] add tests --- ...emote_uncertainty_downsampling_strategy.py | 28 +++++++++++++++++++ .../remote_gradnorm_downsampling.py | 1 + ...emote_uncertainty_downsampling_strategy.py | 3 ++ 3 files changed, 32 insertions(+) diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_uncertainty_downsampling_strategy.py 
b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_uncertainty_downsampling_strategy.py index 89828f03b..befe3e941 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_uncertainty_downsampling_strategy.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_uncertainty_downsampling_strategy.py @@ -74,6 +74,34 @@ def test_compute_score(sampler_config): assert np.allclose(scores, expected_scores, atol=1e-4) +binary_test_data = { + "LeastConfidence": { + "outputs": torch.tensor([[0.8], [0.5], [0.3]]), + "expected_scores": np.array([0.8, 0.5, 0.7]), # confidence just picks the highest probability + }, + "Entropy": { + "outputs": torch.tensor([[0.8], [0.5], [0.3]]), + "expected_scores": np.array([-0.5004 , -0.6931, -0.6109]), + }, + "Margin": { + "outputs": torch.tensor([[0.8], [0.5], [0.3]]), + "expected_scores": np.array([0.6, 0.0, 0.4]), # margin between top two classes + }, +} + + +@pytest.mark.parametrize("squeeze_dim", [True, False]) +def test_compute_score_binary(sampler_config, squeeze_dim): + metric = sampler_config[3]["score_metric"] + amds = RemoteUncertaintyDownsamplingStrategy(*sampler_config) + outputs = binary_test_data[metric]["outputs"] + if squeeze_dim: + outputs = outputs.squeeze() + expected_scores = binary_test_data[metric]["expected_scores"] + scores = amds._compute_score(outputs, disable_softmax=True) + assert np.allclose(scores, expected_scores, atol=1e-4) + + def test_select_points(balance_config): amds = RemoteUncertaintyDownsamplingStrategy(*balance_config) with torch.inference_mode(): diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py index 20257053b..4320a4666 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py @@ -54,6 +54,7 @@ def inform_samples( ) if last_layer_gradients.dim() == 1: last_layer_gradients = last_layer_gradients.unsqueeze(1) + # pylint: disable=not-callable scores = torch.linalg.vector_norm(last_layer_gradients, dim=1).cpu() self.probabilities.append(scores) self.number_of_points_seen += forward_output.shape[0] diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py index 90223a876..7daf340ba 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py @@ -71,6 +71,9 @@ def inform_samples( def _compute_score(self, forward_output: torch.Tensor, disable_softmax: bool = False) -> np.ndarray: if forward_output.dim() == 1: forward_output = forward_output.unsqueeze(1) + feature_size = forward_output.size(1) + if feature_size == 1: + forward_output = torch.cat((1 - forward_output, forward_output), dim=1) if self.score_metric == "LeastConfidence": scores = forward_output.max(dim=1).values.cpu().numpy() elif self.score_metric == "Entropy": From 9e3c4b6fadcf1caadb91e05068f68f71cd300fe6 Mon Sep 17 00:00:00 2001 From: Xianzhe Ma Date: Fri, 6 Sep 2024 23:27:18 +0800 Subject: [PATCH 05/20] fix ruff --- .../test_remote_uncertainty_downsampling_strategy.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_uncertainty_downsampling_strategy.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_uncertainty_downsampling_strategy.py index befe3e941..4ba3445e2 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_uncertainty_downsampling_strategy.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_uncertainty_downsampling_strategy.py @@ -81,7 +81,7 @@ def test_compute_score(sampler_config): }, "Entropy": { "outputs": torch.tensor([[0.8], [0.5], [0.3]]), - "expected_scores": np.array([-0.5004 , -0.6931, -0.6109]), + "expected_scores": np.array([-0.5004, -0.6931, -0.6109]), }, "Margin": { "outputs": torch.tensor([[0.8], [0.5], [0.3]]), From a0be9a73548d7bbbaf288f386ab8a92a9d8c067d Mon Sep 17 00:00:00 2001 From: Xianzhe Ma Date: Sat, 7 Sep 2024 00:04:48 +0800 Subject: [PATCH 06/20] add test to loss --- .../test_remote_gradnorm_downsample.py | 4 +-- .../test_remote_loss_downsample.py | 35 +++++++++++++++++++ .../remote_loss_downsampling.py | 2 ++ 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py index 6f370cb76..d82150216 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py @@ -22,7 +22,7 @@ def test_sample_shape_ce(dummy_system_config: ModynConfig): ) with torch.inference_mode(mode=(not sampler.requires_grad)): data = torch.randn(8, 10) - target = torch.randint(2, size=(8,)) + target = torch.randint(3, size=(8,)) ids = list(range(8)) forward_outputs = model(data) sampler.inform_samples(ids, data, forward_outputs, target) @@ -42,7 +42,7 @@ def test_sample_shape_ce(dummy_system_config: ModynConfig): @pytest.mark.parametrize("squeeze_dim", [True, False]) -def test_sample_shape_other_losses(dummy_system_config: ModynConfig, squeeze_dim): +def test_sample_shape_binary(dummy_system_config: ModynConfig, squeeze_dim): model = torch.nn.Linear(10, 1) downsampling_ratio = 50 per_sample_loss_fct = torch.nn.BCEWithLogitsLoss(reduction="none") diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_loss_downsample.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_loss_downsample.py index a278d12ca..52bd9e124 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_loss_downsample.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_loss_downsample.py @@ -1,3 +1,4 @@ +import pytest import torch from torch import nn @@ -35,6 +36,40 @@ def test_sample_shape(dummy_system_config: ModynConfig): assert len(indexes) == 4 +@pytest.mark.parametrize("squeeze_dim", [True, False]) +def test_sample_shape_binary(dummy_system_config: ModynConfig, squeeze_dim): + model = torch.nn.Linear(10, 1) + downsampling_ratio = 50 + per_sample_loss_fct = torch.nn.BCEWithLogitsLoss(reduction="none") + + params_from_selector = {"downsampling_ratio": downsampling_ratio, "sample_then_batch": False, "ratio_max": 100} + sampler = RemoteLossDownsampling( + 0, 0, 0, params_from_selector, dummy_system_config.model_dump(by_alias=True), 
per_sample_loss_fct, "cpu" + ) + with torch.inference_mode(mode=(not sampler.requires_grad)): + data = torch.randn(8, 10) + forward_outputs = model(data) + target = torch.randint(2, size=(8,), dtype=torch.float32).unsqueeze(1) + if squeeze_dim: + target = target.squeeze(1) + forward_outputs = forward_outputs.squeeze(1) + ids = list(range(8)) + + sampler.inform_samples(ids, data, forward_outputs, target) + downsampled_indexes, weights = sampler.select_points() + + assert len(downsampled_indexes) == 4 + assert weights.shape[0] == 4 + + sampled_data, sampled_target = get_tensors_subset(downsampled_indexes, data, target, ids) + + assert weights.shape[0] == sampled_target.shape[0] + assert sampled_data.shape[0] == 4 + assert sampled_data.shape[1] == data.shape[1] + assert weights.shape[0] == 4 + assert sampled_target.shape[0] == 4 + + def test_sample_weights(dummy_system_config: ModynConfig): model = torch.nn.Linear(10, 2) downsampling_ratio = 50 diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_loss_downsampling.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_loss_downsampling.py index 79bcee16c..239d86425 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_loss_downsampling.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_loss_downsampling.py @@ -52,6 +52,8 @@ def inform_samples( embedding: torch.Tensor | None = None, ) -> None: scores = self.get_scores(forward_output, target) + if scores.dim() == 2: + scores = scores.squeeze(1) self.probabilities.append(scores) self.number_of_points_seen += forward_output.shape[0] self.index_sampleid_map += sample_ids From b01f3bcc4d7d9d0e747eae08501121cbd5e862ed Mon Sep 17 00:00:00 2001 From: Xianzhe Ma Date: Sat, 7 Sep 2024 00:09:22 +0800 Subject: [PATCH 07/20] change num classes --- .../remote_downsamplers/test_remote_gradnorm_downsample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py index d82150216..cfee6d511 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py @@ -22,7 +22,7 @@ def test_sample_shape_ce(dummy_system_config: ModynConfig): ) with torch.inference_mode(mode=(not sampler.requires_grad)): data = torch.randn(8, 10) - target = torch.randint(3, size=(8,)) + target = torch.randint(2, size=(8,)) ids = list(range(8)) forward_outputs = model(data) sampler.inform_samples(ids, data, forward_outputs, target) From d08bedad8e37e48f5c977e924655b1a1f3626911 Mon Sep 17 00:00:00 2001 From: Xianzhe Ma Date: Sat, 7 Sep 2024 09:39:28 +0800 Subject: [PATCH 08/20] refix --- .../test_remote_gradnorm_downsample.py | 10 ++-- .../test_remote_loss_downsample.py | 10 ++-- ...emote_uncertainty_downsampling_strategy.py | 13 ++--- .../remote_gradnorm_downsampling.py | 7 ++- .../remote_loss_downsampling.py | 2 - ...emote_uncertainty_downsampling_strategy.py | 51 +++++++++++-------- 6 files changed, 45 insertions(+), 48 deletions(-) diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py index cfee6d511..901aa517b 100644 --- 
a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py @@ -41,8 +41,7 @@ def test_sample_shape_ce(dummy_system_config: ModynConfig): assert set(downsampled_indexes) <= set(range(8)) -@pytest.mark.parametrize("squeeze_dim", [True, False]) -def test_sample_shape_binary(dummy_system_config: ModynConfig, squeeze_dim): +def test_sample_shape_binary(dummy_system_config: ModynConfig): model = torch.nn.Linear(10, 1) downsampling_ratio = 50 per_sample_loss_fct = torch.nn.BCEWithLogitsLoss(reduction="none") @@ -53,11 +52,8 @@ def test_sample_shape_binary(dummy_system_config: ModynConfig, squeeze_dim): ) with torch.inference_mode(mode=(not sampler.requires_grad)): data = torch.randn(8, 10) - forward_outputs = model(data) - target = torch.randint(2, size=(8,), dtype=torch.float32).unsqueeze(1) - if squeeze_dim: - target = target.squeeze(1) - forward_outputs = forward_outputs.squeeze(1) + forward_outputs = model(data).squeeze(1) + target = torch.randint(2, size=(8,), dtype=torch.float32) ids = list(range(8)) sampler.inform_samples(ids, data, forward_outputs, target) diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_loss_downsample.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_loss_downsample.py index 52bd9e124..8b4174b93 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_loss_downsample.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_loss_downsample.py @@ -36,8 +36,7 @@ def test_sample_shape(dummy_system_config: ModynConfig): assert len(indexes) == 4 -@pytest.mark.parametrize("squeeze_dim", [True, False]) -def test_sample_shape_binary(dummy_system_config: ModynConfig, squeeze_dim): +def test_sample_shape_binary(dummy_system_config: ModynConfig): model = torch.nn.Linear(10, 1) downsampling_ratio = 50 per_sample_loss_fct = torch.nn.BCEWithLogitsLoss(reduction="none") @@ -48,11 +47,8 @@ def test_sample_shape_binary(dummy_system_config: ModynConfig, squeeze_dim): ) with torch.inference_mode(mode=(not sampler.requires_grad)): data = torch.randn(8, 10) - forward_outputs = model(data) - target = torch.randint(2, size=(8,), dtype=torch.float32).unsqueeze(1) - if squeeze_dim: - target = target.squeeze(1) - forward_outputs = forward_outputs.squeeze(1) + forward_outputs = model(data).squeeze(1) + target = torch.randint(2, size=(8,), dtype=torch.float32) ids = list(range(8)) sampler.inform_samples(ids, data, forward_outputs, target) diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_uncertainty_downsampling_strategy.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_uncertainty_downsampling_strategy.py index 4ba3445e2..d93539715 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_uncertainty_downsampling_strategy.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_uncertainty_downsampling_strategy.py @@ -76,27 +76,24 @@ def test_compute_score(sampler_config): binary_test_data = { "LeastConfidence": { - "outputs": torch.tensor([[0.8], [0.5], [0.3]]), - "expected_scores": np.array([0.8, 0.5, 0.7]), # confidence just picks the highest probability + "outputs": torch.tensor([-0.8, 0.5, 0.3]), + "expected_scores": np.array([0.8, 0.5, 0.3]), # confidence just picks the highest probability }, 
"Entropy": { - "outputs": torch.tensor([[0.8], [0.5], [0.3]]), + "outputs": torch.tensor([0.8, 0.5, 0.3]), "expected_scores": np.array([-0.5004, -0.6931, -0.6109]), }, "Margin": { - "outputs": torch.tensor([[0.8], [0.5], [0.3]]), + "outputs": torch.tensor([0.8, 0.5, 0.3]), "expected_scores": np.array([0.6, 0.0, 0.4]), # margin between top two classes }, } -@pytest.mark.parametrize("squeeze_dim", [True, False]) -def test_compute_score_binary(sampler_config, squeeze_dim): +def test_compute_score_binary(sampler_config): metric = sampler_config[3]["score_metric"] amds = RemoteUncertaintyDownsamplingStrategy(*sampler_config) outputs = binary_test_data[metric]["outputs"] - if squeeze_dim: - outputs = outputs.squeeze() expected_scores = binary_test_data[metric]["expected_scores"] scores = amds._compute_score(outputs, disable_softmax=True) assert np.allclose(scores, expected_scores, atol=1e-4) diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py index 4320a4666..b6a1c1b63 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py @@ -49,11 +49,14 @@ def inform_samples( target: torch.Tensor, embedding: torch.Tensor | None = None, ) -> None: + if forward_output.dim() == 1: + # BCEWithLogitsLoss requires that forward_output and target have the same shape + forward_output = forward_output.unsqueeze(1) + target = target.unsqueeze(1) + last_layer_gradients = self._compute_last_layer_gradient_wrt_loss_sum( self.per_sample_loss_fct, forward_output, target ) - if last_layer_gradients.dim() == 1: - last_layer_gradients = last_layer_gradients.unsqueeze(1) # pylint: disable=not-callable scores = torch.linalg.vector_norm(last_layer_gradients, dim=1).cpu() self.probabilities.append(scores) diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_loss_downsampling.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_loss_downsampling.py index 239d86425..79bcee16c 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_loss_downsampling.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_loss_downsampling.py @@ -52,8 +52,6 @@ def inform_samples( embedding: torch.Tensor | None = None, ) -> None: scores = self.get_scores(forward_output, target) - if scores.dim() == 2: - scores = scores.squeeze(1) self.probabilities.append(scores) self.number_of_points_seen += forward_output.shape[0] self.index_sampleid_map += sample_ids diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py index 7daf340ba..b92c64b97 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py @@ -72,30 +72,37 @@ def _compute_score(self, forward_output: torch.Tensor, disable_softmax: bool = F if forward_output.dim() == 1: forward_output = forward_output.unsqueeze(1) feature_size = forward_output.size(1) - if feature_size == 1: - forward_output = torch.cat((1 - forward_output, forward_output), dim=1) + if self.score_metric == "LeastConfidence": - scores = 
forward_output.max(dim=1).values.cpu().numpy() - elif self.score_metric == "Entropy": - preds = ( - torch.nn.functional.softmax(forward_output, dim=1).cpu().numpy() - if not disable_softmax - else forward_output.cpu().numpy() - ) - scores = (np.log(preds + 1e-6) * preds).sum(axis=1) - elif self.score_metric == "Margin": - preds = torch.nn.functional.softmax(forward_output, dim=1) if not disable_softmax else forward_output - preds_argmax = torch.argmax(preds, dim=1) # gets top class - max_preds = preds[torch.ones(preds.shape[0], dtype=bool), preds_argmax].clone() # gets scores of top class - - # remove highest class from softmax output - preds[torch.ones(preds.shape[0], dtype=bool), preds_argmax] = -1.0 - - preds_sub_argmax = torch.argmax(preds, dim=1) # gets new top class (=> 2nd top class) - second_max_preds = preds[torch.ones(preds.shape[0], dtype=bool), preds_sub_argmax] - scores = (max_preds - second_max_preds).cpu().numpy() + if feature_size == 1: + # for binary classification comparing how far away the element is from 0.5 after sigmoid layer + # is the same as comparing the absolute value of the element before sigmoid layer + scores = torch.abs(forward_output).squeeze(1).cpu().numpy() + else: + scores = forward_output.max(dim=1).values.cpu().numpy() else: - raise AssertionError("The required metric does not exist") + if feature_size == 1: + # for binary classification the softmax layer is reduced to sigmoid + preds = torch.sigmoid(forward_output) if not disable_softmax else forward_output + # we need to convert it to a 2D tensor with probabilities for both classes + preds = torch.cat((1 - preds, preds), dim=1) + else: + preds = torch.nn.functional.softmax(forward_output, dim=1) if not disable_softmax else forward_output + + if self.score_metric == "Entropy": + scores = (np.log(preds + 1e-6) * preds).sum(axis=1) + elif self.score_metric == "Margin": + preds_argmax = torch.argmax(preds, dim=1) # gets top class + max_preds = preds[torch.ones(preds.shape[0], dtype=bool), preds_argmax].clone() # gets scores of top class + + # remove highest class from softmax output + preds[torch.ones(preds.shape[0], dtype=bool), preds_argmax] = -1.0 + + preds_sub_argmax = torch.argmax(preds, dim=1) # gets new top class (=> 2nd top class) + second_max_preds = preds[torch.ones(preds.shape[0], dtype=bool), preds_sub_argmax] + scores = (max_preds - second_max_preds).cpu().numpy() + else: + raise AssertionError("The required metric does not exist") return scores From 127bc7a2362866edefdcd4a91b14f9c0df90084a Mon Sep 17 00:00:00 2001 From: Xianzhe Ma Date: Sat, 7 Sep 2024 10:08:07 +0800 Subject: [PATCH 09/20] fix craig --- .../test_craig_remote_downsampling.py | 31 +++++++++++++++++++ .../remote_craig_downsampling.py | 3 ++ 2 files changed, 34 insertions(+) diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_craig_remote_downsampling.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_craig_remote_downsampling.py index b72b55331..cfed636f9 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_craig_remote_downsampling.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_craig_remote_downsampling.py @@ -205,6 +205,37 @@ def test_bts(grad_approx: str, dummy_system_config: ModynConfig): assert all(id in [1, 2, 3, 10, 11, 12, 13] for id in selected_points) +@pytest.mark.parametrize("grad_approx", ["LastLayer", "LastLayerWithEmbedding"]) +def test_bts_binary(grad_approx: str, dummy_system_config: ModynConfig): + 
sampler_config = get_sampler_config(dummy_system_config, grad_approx=grad_approx) + per_sample_loss_fct = torch.nn.BCEWithLogitsLoss(reduction="none") + sampler_config = (0, 0, 0, sampler_config[3], sampler_config[4], per_sample_loss_fct, "cpu") + sampler = RemoteCraigDownsamplingStrategy(*sampler_config) + + with torch.inference_mode(mode=(not sampler.requires_grad)): + sample_ids = [1, 2, 3, 10, 11, 12, 13] + forward_input = torch.randn(7, 5) # 7 samples, 5 input features + forward_output = torch.randn(7, ) # 7 samples, 1 output classes + forward_output.requires_grad = True + target = torch.tensor([1, 1, 1, 0, 0, 0, 1], dtype=torch.float32) # 7 target labels + embedding = torch.randn(7, 10) # 7 samples, embedding dimension 10 + + assert sampler.distance_matrix.shape == (0, 0) + sampler.inform_samples(sample_ids, forward_input, forward_output, target, embedding) + sampler.inform_end_of_current_label() + assert sampler.distance_matrix.shape == (7, 7) + assert len(sampler.current_class_gradients) == 0 + + assert sampler.index_sampleid_map == [10, 11, 12, 1, 2, 3, 13] + + selected_points, selected_weights = sampler.select_points() + + assert len(selected_points) == 3 + assert len(selected_weights) == 3 + assert all(weight > 0 for weight in selected_weights) + assert all(id in [1, 2, 3, 10, 11, 12, 13] for id in selected_points) + + @pytest.mark.parametrize("grad_approx", ["LastLayerWithEmbedding", "LastLayer"]) def test_bts_equals_stb(grad_approx: str, dummy_system_config: ModynConfig): # data diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_craig_downsampling.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_craig_downsampling.py index 7ee7033ba..54668435d 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_craig_downsampling.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_craig_downsampling.py @@ -110,6 +110,9 @@ def _inform_samples_single_class( target: torch.Tensor, embedding: torch.Tensor | None, ) -> None: + if forward_output.dim() == 1: + forward_output = forward_output.unsqueeze(1) + target = target.unsqueeze(1) if self.full_grad_approximation == "LastLayerWithEmbedding": assert embedding is not None grads_wrt_loss_sum = self._compute_last_two_layers_gradient_wrt_loss_sum( From abac6ae7b7bce162d6e9a6d501e835420c9e7592 Mon Sep 17 00:00:00 2001 From: Xianzhe Ma Date: Sat, 7 Sep 2024 10:24:27 +0800 Subject: [PATCH 10/20] fix matrix downsampling --- ...t_abstract_matrix_downsampling_strategy.py | 39 +++++++++++++++++++ .../abstract_matrix_downsampling_strategy.py | 3 ++ 2 files changed, 42 insertions(+) diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_abstract_matrix_downsampling_strategy.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_abstract_matrix_downsampling_strategy.py index 9764d8b39..2453fd2bf 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_abstract_matrix_downsampling_strategy.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_abstract_matrix_downsampling_strategy.py @@ -142,3 +142,42 @@ def test_collect_gradients(matrix_content, dummy_system_config: ModynConfig): assert np.concatenate(amds.matrix_elements).shape == (7, gradient_shape) assert amds.index_sampleid_map == [1, 2, 3, 4, 21, 31, 41] + + +@pytest.mark.parametrize( + "matrix_content", [MatrixContent.LAST_LAYER_GRADIENTS, MatrixContent.LAST_TWO_LAYERS_GRADIENTS] +) 
+@patch.multiple(AbstractMatrixDownsamplingStrategy, __abstractmethods__=set()) +def test_collect_gradients_binary(matrix_content, dummy_system_config: ModynConfig): + per_sample_loss_fct = torch.nn.BCEWithLogitsLoss(reduction="none") + sampler_config = list(get_sampler_config(dummy_system_config, matrix_content=matrix_content)) + sampler_config[5] = per_sample_loss_fct + sampler_config = tuple(sampler_config) + amds = AbstractMatrixDownsamplingStrategy(*sampler_config) + with torch.inference_mode(mode=(not amds.requires_grad)): + forward_input = torch.randn((4, 5)) + first_output = torch.randn((4,)) + first_output.requires_grad = True + first_target = torch.tensor([1, 1, 1, 0], dtype=torch.float32) + first_embedding = torch.randn((4, 5)) + amds.inform_samples([1, 2, 3, 4], forward_input, first_output, first_target, first_embedding) + + second_output = torch.randn((3,)) + second_output.requires_grad = True + second_target = torch.tensor([0, 1, 0], dtype=torch.float32) + second_embedding = torch.randn((3, 5)) + amds.inform_samples([21, 31, 41], forward_input, second_output, second_target, second_embedding) + + assert len(amds.matrix_elements) == 2 + + # expected shape = (a, gradient_shape) + # a = 7 (4 samples in the first batch and 3 samples in the second batch) + if matrix_content == MatrixContent.LAST_LAYER_GRADIENTS: + # shape same as the last dimension of output + gradient_shape = 1 + else: + # 5 is the input dimension of the last layer and 2 is the output one + gradient_shape = 5 * 1 + 1 + assert np.concatenate(amds.matrix_elements).shape == (7, gradient_shape) + + assert amds.index_sampleid_map == [1, 2, 3, 4, 21, 31, 41] diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_matrix_downsampling_strategy.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_matrix_downsampling_strategy.py index 1534e740b..167cc80a0 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_matrix_downsampling_strategy.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_matrix_downsampling_strategy.py @@ -71,6 +71,9 @@ def inform_samples( ) -> None: batch_size = len(sample_ids) assert self.matrix_content is not None + if forward_output.dim() == 1: + forward_output = forward_output.unsqueeze(1) + target = target.unsqueeze(1) if self.matrix_content == MatrixContent.LAST_LAYER_GRADIENTS: grads_wrt_loss_sum = self._compute_last_layer_gradient_wrt_loss_sum(self.criterion, forward_output, target) grads_wrt_loss_mean = grads_wrt_loss_sum / batch_size From 39ec783b01418aff9fd67e87fa7f2845ac349c2f Mon Sep 17 00:00:00 2001 From: Xianzhe Ma Date: Sat, 7 Sep 2024 10:28:59 +0800 Subject: [PATCH 11/20] fix ruff --- .../remote_downsamplers/test_craig_remote_downsampling.py | 2 +- .../remote_downsamplers/test_remote_gradnorm_downsample.py | 1 - .../remote_downsamplers/test_remote_loss_downsample.py | 1 - .../remote_uncertainty_downsampling_strategy.py | 4 +++- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_craig_remote_downsampling.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_craig_remote_downsampling.py index cfed636f9..42981b2a6 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_craig_remote_downsampling.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_craig_remote_downsampling.py @@ -215,7 +215,7 @@ def test_bts_binary(grad_approx: str, dummy_system_config: 
ModynConfig): with torch.inference_mode(mode=(not sampler.requires_grad)): sample_ids = [1, 2, 3, 10, 11, 12, 13] forward_input = torch.randn(7, 5) # 7 samples, 5 input features - forward_output = torch.randn(7, ) # 7 samples, 1 output classes + forward_output = torch.randn(7,) forward_output.requires_grad = True target = torch.tensor([1, 1, 1, 0, 0, 0, 1], dtype=torch.float32) # 7 target labels embedding = torch.randn(7, 10) # 7 samples, embedding dimension 10 diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py index 901aa517b..2b3b49fd7 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_gradnorm_downsample.py @@ -1,4 +1,3 @@ -import pytest import torch from torch import nn diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_loss_downsample.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_loss_downsample.py index 8b4174b93..5f4f0ceb3 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_loss_downsample.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_loss_downsample.py @@ -1,4 +1,3 @@ -import pytest import torch from torch import nn diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py index b92c64b97..22faf2196 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py @@ -93,7 +93,9 @@ def _compute_score(self, forward_output: torch.Tensor, disable_softmax: bool = F scores = (np.log(preds + 1e-6) * preds).sum(axis=1) elif self.score_metric == "Margin": preds_argmax = torch.argmax(preds, dim=1) # gets top class - max_preds = preds[torch.ones(preds.shape[0], dtype=bool), preds_argmax].clone() # gets scores of top class + max_preds = preds[ + torch.ones(preds.shape[0], dtype=bool), preds_argmax + ].clone() # gets scores of top class # remove highest class from softmax output preds[torch.ones(preds.shape[0], dtype=bool), preds_argmax] = -1.0 From 73c96ff4cf2fcde75f02a7d85b07dc17e997b2b1 Mon Sep 17 00:00:00 2001 From: Xianzhe Ma Date: Sat, 7 Sep 2024 11:00:01 +0800 Subject: [PATCH 12/20] fix ci --- .../test_abstract_matrix_downsampling_strategy.py | 12 ++++++------ .../test_craig_remote_downsampling.py | 4 +++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_abstract_matrix_downsampling_strategy.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_abstract_matrix_downsampling_strategy.py index 2453fd2bf..4b121c9a0 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_abstract_matrix_downsampling_strategy.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_abstract_matrix_downsampling_strategy.py @@ -56,15 +56,15 @@ def test_collect_embeddings(dummy_system_config: ModynConfig): first_embedding = torch.randn((4, 5)) second_embedding = torch.randn((3, 5)) - amds.inform_samples([1, 
2, 3, 4], None, None, None, first_embedding) - amds.inform_samples([21, 31, 41], None, None, None, second_embedding) + amds.inform_samples([1, 2, 3, 4], None, torch.randn((4, 2)), None, first_embedding) + amds.inform_samples([21, 31, 41], None, torch.randn((3, 2)), None, second_embedding) assert np.concatenate(amds.matrix_elements).shape == (7, 5) assert all(torch.equal(el1, el2) for el1, el2 in zip(amds.matrix_elements, [first_embedding, second_embedding])) assert amds.index_sampleid_map == [1, 2, 3, 4, 21, 31, 41] third_embedding = torch.randn((23, 5)) - amds.inform_samples(list(range(1000, 1023)), None, None, None, third_embedding) + amds.inform_samples(list(range(1000, 1023)), None, torch.randn((23, 2)), None, third_embedding) assert np.concatenate(amds.matrix_elements).shape == (30, 5) assert all( @@ -88,8 +88,8 @@ def test_collect_embedding_balance(test_amds, dummy_system_config: ModynConfig): first_embedding = torch.randn((4, 5)) second_embedding = torch.randn((3, 5)) - amds.inform_samples([1, 2, 3, 4], None, None, None, first_embedding) - amds.inform_samples([21, 31, 41], None, None, None, second_embedding) + amds.inform_samples([1, 2, 3, 4], None, torch.randn((4, 2)), None, first_embedding) + amds.inform_samples([21, 31, 41], None, torch.randn((3, 2)), None, second_embedding) assert np.concatenate(amds.matrix_elements).shape == (7, 5) assert all(torch.equal(el1, el2) for el1, el2 in zip(amds.matrix_elements, [first_embedding, second_embedding])) @@ -99,7 +99,7 @@ def test_collect_embedding_balance(test_amds, dummy_system_config: ModynConfig): third_embedding = torch.randn((23, 5)) assert len(amds.matrix_elements) == 0 - amds.inform_samples(list(range(1000, 1023)), None, None, None, third_embedding) + amds.inform_samples(list(range(1000, 1023)), None, torch.randn((23, 2)), None, third_embedding) assert np.concatenate(amds.matrix_elements).shape == (23, 5) assert all(torch.equal(el1, el2) for el1, el2 in zip(amds.matrix_elements, [third_embedding])) diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_craig_remote_downsampling.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_craig_remote_downsampling.py index 42981b2a6..fe93c8e5a 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_craig_remote_downsampling.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_craig_remote_downsampling.py @@ -215,7 +215,9 @@ def test_bts_binary(grad_approx: str, dummy_system_config: ModynConfig): with torch.inference_mode(mode=(not sampler.requires_grad)): sample_ids = [1, 2, 3, 10, 11, 12, 13] forward_input = torch.randn(7, 5) # 7 samples, 5 input features - forward_output = torch.randn(7,) + forward_output = torch.randn( + 7, + ) forward_output.requires_grad = True target = torch.tensor([1, 1, 1, 0, 0, 0, 1], dtype=torch.float32) # 7 target labels embedding = torch.randn(7, 10) # 7 samples, embedding dimension 10 From e78bf91ba12a25f6ac3dda38247399c3bb17814e Mon Sep 17 00:00:00 2001 From: Xianzhe Ma Date: Sat, 7 Sep 2024 11:09:04 +0800 Subject: [PATCH 13/20] fix comment --- .../test_abstract_matrix_downsampling_strategy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_abstract_matrix_downsampling_strategy.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_abstract_matrix_downsampling_strategy.py index 4b121c9a0..dc77e6f2d 100644 --- 
a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_abstract_matrix_downsampling_strategy.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_abstract_matrix_downsampling_strategy.py @@ -176,7 +176,7 @@ def test_collect_gradients_binary(matrix_content, dummy_system_config: ModynConf # shape same as the last dimension of output gradient_shape = 1 else: - # 5 is the input dimension of the last layer and 2 is the output one + # 5 is the input dimension of the last layer and 1 is the output one gradient_shape = 5 * 1 + 1 assert np.concatenate(amds.matrix_elements).shape == (7, gradient_shape) From 65037b0e0d9b23f702f3d7892f665a0c9ce3674b Mon Sep 17 00:00:00 2001 From: Xianzhe Ma Date: Sat, 7 Sep 2024 11:10:03 +0800 Subject: [PATCH 14/20] fix comment --- .../trainer/remote_downsamplers/remote_gradnorm_downsampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py index b6a1c1b63..3154cdaa1 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py @@ -50,7 +50,7 @@ def inform_samples( embedding: torch.Tensor | None = None, ) -> None: if forward_output.dim() == 1: - # BCEWithLogitsLoss requires that forward_output and target have the same shape + # BCEWithLogitsLoss requires that forward_output and target have the same dimension forward_output = forward_output.unsqueeze(1) target = target.unsqueeze(1) From dd78e1d74cae139ef47d0140d50e2151c4b61f84 Mon Sep 17 00:00:00 2001 From: Xianzhe Ma Date: Tue, 10 Sep 2024 18:11:35 +0800 Subject: [PATCH 15/20] util func --- ...t_remote_uncertainty_downsampling_strategy.py | 6 +++--- .../abstract_matrix_downsampling_strategy.py | 6 +++--- .../abstract_remote_downsampling_strategy.py | 16 ++++++++++++++++ .../remote_craig_downsampling.py | 6 ++---- .../remote_gradnorm_downsampling.py | 7 ++----- .../remote_uncertainty_downsampling_strategy.py | 6 +++--- 6 files changed, 29 insertions(+), 18 deletions(-) diff --git a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_uncertainty_downsampling_strategy.py b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_uncertainty_downsampling_strategy.py index d93539715..52c391f50 100644 --- a/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_uncertainty_downsampling_strategy.py +++ b/modyn/tests/trainer_server/internal/trainer/remote_downsamplers/test_remote_uncertainty_downsampling_strategy.py @@ -76,15 +76,15 @@ def test_compute_score(sampler_config): binary_test_data = { "LeastConfidence": { - "outputs": torch.tensor([-0.8, 0.5, 0.3]), + "outputs": torch.tensor([[-0.8], [0.5], [0.3]]), "expected_scores": np.array([0.8, 0.5, 0.3]), # confidence just picks the highest probability }, "Entropy": { - "outputs": torch.tensor([0.8, 0.5, 0.3]), + "outputs": torch.tensor([[0.8], [0.5], [0.3]]), "expected_scores": np.array([-0.5004, -0.6931, -0.6109]), }, "Margin": { - "outputs": torch.tensor([0.8, 0.5, 0.3]), + "outputs": torch.tensor([[0.8], [0.5], [0.3]]), "expected_scores": np.array([0.6, 0.0, 0.4]), # margin between top two classes }, } diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_matrix_downsampling_strategy.py 
b/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_matrix_downsampling_strategy.py index 167cc80a0..6ed733831 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_matrix_downsampling_strategy.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_matrix_downsampling_strategy.py @@ -8,6 +8,8 @@ from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_per_label_remote_downsample_strategy import ( AbstractPerLabelRemoteDownsamplingStrategy, ) +from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import \ + unsqueeze_dimensions_if_necessary from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils.shuffling import _shuffle_list_and_tensor @@ -71,9 +73,7 @@ def inform_samples( ) -> None: batch_size = len(sample_ids) assert self.matrix_content is not None - if forward_output.dim() == 1: - forward_output = forward_output.unsqueeze(1) - target = target.unsqueeze(1) + forward_output, target = unsqueeze_dimensions_if_necessary(forward_output, target) if self.matrix_content == MatrixContent.LAST_LAYER_GRADIENTS: grads_wrt_loss_sum = self._compute_last_layer_gradient_wrt_loss_sum(self.criterion, forward_output, target) grads_wrt_loss_mean = grads_wrt_loss_sum / batch_size diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_remote_downsampling_strategy.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_remote_downsampling_strategy.py index 7d24bbffd..ecdd9e177 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_remote_downsampling_strategy.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_remote_downsampling_strategy.py @@ -36,6 +36,22 @@ def get_tensors_subset( return sub_data, sub_target +def unsqueeze_dimensions_if_necessary( + forward_output: torch.Tensor, target: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + """ + For binary classification, the forward output is a 1D tensor of length batch_size. We need to unsqueeze it to + have a 2D tensor of shape (batch_size, 1). + + For binary classification we use BCEWithLogitsLoss, which requires the same dimensionality between the forward + output and the target, so we also need to unsqueeze the target tensor. 
+ """ + if forward_output.dim() == 1: + forward_output = forward_output.unsqueeze(1) + target = target.unsqueeze(1) + return forward_output, target + + class AbstractRemoteDownsamplingStrategy(ABC): def __init__( self, diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_craig_downsampling.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_craig_downsampling.py index 54668435d..cc9d6c473 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_craig_downsampling.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_craig_downsampling.py @@ -8,7 +8,7 @@ AbstractPerLabelRemoteDownsamplingStrategy, ) from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import ( - FULL_GRAD_APPROXIMATION, + FULL_GRAD_APPROXIMATION, unsqueeze_dimensions_if_necessary, ) from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils import submodular_optimizer from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils.euclidean import euclidean_dist_pair_np @@ -110,9 +110,7 @@ def _inform_samples_single_class( target: torch.Tensor, embedding: torch.Tensor | None, ) -> None: - if forward_output.dim() == 1: - forward_output = forward_output.unsqueeze(1) - target = target.unsqueeze(1) + forward_output, target = unsqueeze_dimensions_if_necessary(forward_output, target) if self.full_grad_approximation == "LastLayerWithEmbedding": assert embedding is not None grads_wrt_loss_sum = self._compute_last_two_layers_gradient_wrt_loss_sum( diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py index 3154cdaa1..f5eb5e22b 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py @@ -4,7 +4,7 @@ import torch from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import ( - AbstractRemoteDownsamplingStrategy, + AbstractRemoteDownsamplingStrategy, unsqueeze_dimensions_if_necessary, ) logger = logging.getLogger(__name__) @@ -49,10 +49,7 @@ def inform_samples( target: torch.Tensor, embedding: torch.Tensor | None = None, ) -> None: - if forward_output.dim() == 1: - # BCEWithLogitsLoss requires that forward_output and target have the same dimension - forward_output = forward_output.unsqueeze(1) - target = target.unsqueeze(1) + forward_output, target = unsqueeze_dimensions_if_necessary(forward_output, target) last_layer_gradients = self._compute_last_layer_gradient_wrt_loss_sum( self.per_sample_loss_fct, forward_output, target diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py index 22faf2196..f6c6a87c3 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py @@ -6,6 +6,8 @@ from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_per_label_remote_downsample_strategy import ( AbstractPerLabelRemoteDownsamplingStrategy, ) +from 
modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import \ + unsqueeze_dimensions_if_necessary from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils.shuffling import _shuffle_list_and_tensor @@ -64,15 +66,13 @@ def inform_samples( ) -> None: assert embedding is None + forward_output, _ = unsqueeze_dimensions_if_necessary(forward_output, target) self.scores = np.append(self.scores, self._compute_score(forward_output.detach())) # keep the mapping index<->sample_id self.index_sampleid_map += sample_ids def _compute_score(self, forward_output: torch.Tensor, disable_softmax: bool = False) -> np.ndarray: - if forward_output.dim() == 1: - forward_output = forward_output.unsqueeze(1) feature_size = forward_output.size(1) - if self.score_metric == "LeastConfidence": if feature_size == 1: # for binary classification comparing how far away the element is from 0.5 after sigmoid layer From 8afc4977fe5d6e61314749473a50d283422aca54 Mon Sep 17 00:00:00 2001 From: Xianzhe Ma Date: Tue, 10 Sep 2024 18:24:41 +0800 Subject: [PATCH 16/20] fix ruff --- .../abstract_matrix_downsampling_strategy.py | 6 ++++-- .../abstract_remote_downsampling_strategy.py | 13 +++++++------ .../remote_craig_downsampling.py | 3 ++- .../remote_gradnorm_downsampling.py | 3 ++- .../remote_uncertainty_downsampling_strategy.py | 5 +++-- 5 files changed, 18 insertions(+), 12 deletions(-) diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_matrix_downsampling_strategy.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_matrix_downsampling_strategy.py index 6ed733831..2fc79b0ef 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_matrix_downsampling_strategy.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_matrix_downsampling_strategy.py @@ -8,8 +8,10 @@ from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_per_label_remote_downsample_strategy import ( AbstractPerLabelRemoteDownsamplingStrategy, ) -from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import \ - unsqueeze_dimensions_if_necessary +from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import ( + unsqueeze_dimensions_if_necessary, +) + from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils.shuffling import _shuffle_list_and_tensor diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_remote_downsampling_strategy.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_remote_downsampling_strategy.py index ecdd9e177..2f0e0b545 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_remote_downsampling_strategy.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_remote_downsampling_strategy.py @@ -37,14 +37,15 @@ def get_tensors_subset( def unsqueeze_dimensions_if_necessary( - forward_output: torch.Tensor, target: torch.Tensor + forward_output: torch.Tensor, target: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: - """ - For binary classification, the forward output is a 1D tensor of length batch_size. We need to unsqueeze it to - have a 2D tensor of shape (batch_size, 1). + """For binary classification, the forward output is a 1D tensor of length + batch_size. We need to unsqueeze it to have a 2D tensor of shape + (batch_size, 1). 
- For binary classification we use BCEWithLogitsLoss, which requires the same dimensionality between the forward - output and the target, so we also need to unsqueeze the target tensor. + For binary classification we use BCEWithLogitsLoss, which requires + the same dimensionality between the forward output and the target, + so we also need to unsqueeze the target tensor. """ if forward_output.dim() == 1: forward_output = forward_output.unsqueeze(1) diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_craig_downsampling.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_craig_downsampling.py index cc9d6c473..7a5c88878 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_craig_downsampling.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_craig_downsampling.py @@ -8,7 +8,8 @@ AbstractPerLabelRemoteDownsamplingStrategy, ) from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import ( - FULL_GRAD_APPROXIMATION, unsqueeze_dimensions_if_necessary, + FULL_GRAD_APPROXIMATION, + unsqueeze_dimensions_if_necessary, ) from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils import submodular_optimizer from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils.euclidean import euclidean_dist_pair_np diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py index f5eb5e22b..66ae24784 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_gradnorm_downsampling.py @@ -4,7 +4,8 @@ import torch from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import ( - AbstractRemoteDownsamplingStrategy, unsqueeze_dimensions_if_necessary, + AbstractRemoteDownsamplingStrategy, + unsqueeze_dimensions_if_necessary, ) logger = logging.getLogger(__name__) diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py index f6c6a87c3..93dd12635 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py @@ -6,8 +6,9 @@ from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_per_label_remote_downsample_strategy import ( AbstractPerLabelRemoteDownsamplingStrategy, ) -from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import \ - unsqueeze_dimensions_if_necessary +from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import ( + unsqueeze_dimensions_if_necessary, +) from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils.shuffling import _shuffle_list_and_tensor From 785dac04cb143577208226a8a42a09e58f919695 Mon Sep 17 00:00:00 2001 From: Xianzhe Ma Date: Tue, 10 Sep 2024 18:26:49 +0800 Subject: [PATCH 17/20] fix ruff again --- .../remote_downsamplers/abstract_matrix_downsampling_strategy.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_matrix_downsampling_strategy.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_matrix_downsampling_strategy.py index 2fc79b0ef..9b8a1bee9 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_matrix_downsampling_strategy.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/abstract_matrix_downsampling_strategy.py @@ -11,7 +11,6 @@ from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import ( unsqueeze_dimensions_if_necessary, ) - from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils.shuffling import _shuffle_list_and_tensor From aff337cd3fa248aaea3da2d98bf4f9732bcebd93 Mon Sep 17 00:00:00 2001 From: Xianzhe Ma Date: Wed, 11 Sep 2024 09:37:13 +0800 Subject: [PATCH 18/20] add comment --- .../remote_uncertainty_downsampling_strategy.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py index 93dd12635..d3e19a880 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py @@ -76,8 +76,15 @@ def _compute_score(self, forward_output: torch.Tensor, disable_softmax: bool = F feature_size = forward_output.size(1) if self.score_metric == "LeastConfidence": if feature_size == 1: - # for binary classification comparing how far away the element is from 0.5 after sigmoid layer - # is the same as comparing the absolute value of the element before sigmoid layer + # For binary classification there is only one pre-sigmoid output value, which, after sigmoid layer, + # is the probability of the positive class. The probability of the negative class is + # 1 - probability_positive_class. + # For each sample we need to compute the pre-sigmoid output for the class with the highest probability. + # If model_output_value > 0, then sigmoid(model_output_value) > 0.5, hence the positive class has the + # highest probability and model_output_value is what we need. + # If model_output_value < 0, this case is symmetric to the case where the model output value is + # - | model_output_value |. Hence, we can just take the absolute value. + # In any case, we just need to compute the absolute value of the model output value. 
scores = torch.abs(forward_output).squeeze(1).cpu().numpy() else: scores = forward_output.max(dim=1).values.cpu().numpy() From fab6511230f241523cd2504a10c23a0c88081f88 Mon Sep 17 00:00:00 2001 From: Xianzhe Ma Date: Wed, 11 Sep 2024 09:40:59 +0800 Subject: [PATCH 19/20] change comment --- .../remote_uncertainty_downsampling_strategy.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py index d3e19a880..e3500256c 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py @@ -82,8 +82,9 @@ def _compute_score(self, forward_output: torch.Tensor, disable_softmax: bool = F # For each sample we need to compute the pre-sigmoid output for the class with the highest probability. # If model_output_value > 0, then sigmoid(model_output_value) > 0.5, hence the positive class has the # highest probability and model_output_value is what we need. - # If model_output_value < 0, this case is symmetric to the case where the model output value is - # - | model_output_value |. Hence, we can just take the absolute value. + # If model_output_value < 0, then sigmoid(model_output_value) < 0.5, hence the negative class has the + # highest probability. The corresponding pre-sigmoid output value for the negative class + # is - model_output_value. # In any case, we just need to compute the absolute value of the model output value. scores = torch.abs(forward_output).squeeze(1).cpu().numpy() else: From eabe7de901989b00f54a957c4a764533f014354e Mon Sep 17 00:00:00 2001 From: Xianzhe Ma Date: Wed, 11 Sep 2024 09:41:22 +0800 Subject: [PATCH 20/20] add margin case --- .../remote_uncertainty_downsampling_strategy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py index e3500256c..3cf9341a6 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py @@ -80,7 +80,7 @@ def _compute_score(self, forward_output: torch.Tensor, disable_softmax: bool = F # is the probability of the positive class. The probability of the negative class is # 1 - probability_positive_class. # For each sample we need to compute the pre-sigmoid output for the class with the highest probability. - # If model_output_value > 0, then sigmoid(model_output_value) > 0.5, hence the positive class has the + # If model_output_value >= 0, then sigmoid(model_output_value) >= 0.5, hence the positive class has the # highest probability and model_output_value is what we need. # If model_output_value < 0, then sigmoid(model_output_value) < 0.5, hence the negative class has the # highest probability. The corresponding pre-sigmoid output value for the negative class
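
Notes on the binary-classification handling in this series (standalone sketches for review, not part of the patches):

The GradNorm fix in patches 01/08/15 exists because torch.norm(last_layer_gradients, dim=-1) collapses a 1-D gradient tensor (one logit per sample, as produced by a binary head trained with BCEWithLogitsLoss) into a single scalar instead of one score per sample. A minimal sketch of the final shape handling, using the analytic BCE-with-logits gradient sigmoid(o) - y instead of Modyn's _compute_last_layer_gradient_wrt_loss_sum helper:

    import torch

    forward_output = torch.randn(8)          # binary head: one logit per sample, 1-D
    target = torch.randint(2, (8,)).float()

    # what unsqueeze_dimensions_if_necessary (patch 15) does: give both tensors
    # an explicit class dimension so the BCEWithLogitsLoss shapes match
    if forward_output.dim() == 1:
        forward_output = forward_output.unsqueeze(1)
        target = target.unsqueeze(1)

    # gradient of the summed BCE-with-logits loss w.r.t. the logits: sigmoid(o) - y
    last_layer_gradients = torch.sigmoid(forward_output) - target    # shape (8, 1)
    scores = torch.linalg.vector_norm(last_layer_gradients, dim=1)   # shape (8,), one score per sample

With the unsqueeze in place, dim=1 always indexes the class dimension, so the same scoring path serves both the binary and the multi-class case.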
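
For the uncertainty strategy, the binary shortcuts can be checked against the test vectors from patch 08: with probability p for the positive class, the two-class distribution is (1 - p, p), so the negative entropy is p*log(p) + (1-p)*log(1-p) and the margin between the top two classes is |2p - 1|; LeastConfidence can stay on raw logits because max(sigmoid(o), 1 - sigmoid(o)) = sigmoid(|o|), which is monotone in |o|. A sketch, independent of the Modyn classes, that reproduces the expected scores:

    import torch

    p = torch.tensor([0.8, 0.5, 0.3])              # test inputs used with disable_softmax=True
    two_class = torch.stack((1 - p, p), dim=1)     # explicit (batch, 2) distribution

    neg_entropy = (two_class * two_class.log()).sum(dim=1)
    # tensor([-0.5004, -0.6931, -0.6109])  -> the "Entropy" expected_scores

    margin = (2 * p - 1).abs()
    # tensor([0.6000, 0.0000, 0.4000])     -> the "Margin" expected_scores

    logits = torch.tensor([-0.8, 0.5, 0.3])        # the "LeastConfidence" test inputs
    max_prob = torch.stack((1 - torch.sigmoid(logits), torch.sigmoid(logits)), dim=1).max(dim=1).values
    assert torch.allclose(max_prob, torch.sigmoid(logits.abs()))
    # hence scoring by |logit| ranks samples exactly as the max class probability would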
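
The matrix-downsampling test in patch 10 asserts a per-sample gradient width of 1 for LAST_LAYER_GRADIENTS and 5 * 1 + 1 = 6 for LAST_TWO_LAYERS_GRADIENTS. Both widths follow from the chain rule for a binary head; a sketch assuming an nn.Linear(5, 1) last layer (head and embedding are illustrative names, not Modyn code):

    import torch
    from torch import nn

    head = nn.Linear(5, 1)                         # hypothetical binary last layer
    embedding = torch.randn(3, 5)                  # penultimate activations, 3 samples
    target = torch.tensor([1.0, 0.0, 1.0]).unsqueeze(1)

    logits = head(embedding)
    grad_logits = torch.sigmoid(logits) - target   # dL_sum/dlogits for BCEWithLogitsLoss

    grad_bias = grad_logits                                          # (3, 1): width 1, the last-layer part
    grad_weight = grad_logits.unsqueeze(2) * embedding.unsqueeze(1)  # (3, 1, 5): per-sample dL/dW
    per_sample = torch.cat((grad_weight.flatten(1), grad_bias), dim=1)
    assert per_sample.shape == (3, 6)              # 5 weights + 1 bias per sample

The same contract explains the test churn in patch 12: once inform_samples calls the dimension check on forward_output, the embedding-only tests can no longer pass None there and now supply real (batch, num_classes) tensors, with num_classes = 1 in the binary case.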