Fix dimensionality issue when binary classification outputs 1D instead of 2D tensor #609

Merged
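For context, a minimal sketch of the mismatch this PR addresses, assuming a single-logit binary model trained with BCEWithLogitsLoss (the variable names are illustrative, not the actual Modyn call sites): the model emits a 1D tensor of logits, while the remote downsamplers expect a 2D (batch_size, num_classes) layout.

import torch

model = torch.nn.Linear(10, 1)                     # binary classifier with a single logit
data = torch.randn(8, 10)
target = torch.randint(2, size=(8,), dtype=torch.float32)

logits = model(data).squeeze(1)                    # shape (8,): 1D instead of (batch_size, num_classes)
per_sample_loss = torch.nn.BCEWithLogitsLoss(reduction="none")(logits, target)  # target must match dims

# The helper added in this PR (unsqueeze_dimensions_if_necessary) normalizes both tensors to 2D
# before the downsamplers compute gradients or scores:
logits_2d, target_2d = logits.unsqueeze(1), target.unsqueeze(1)   # shapes (8, 1) and (8, 1)

The diff below wires this normalization into the matrix, Craig, GradNorm, and uncertainty downsamplers and extends their tests to the binary case.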
@@ -56,15 +56,15 @@ def test_collect_embeddings(dummy_system_config: ModynConfig):

first_embedding = torch.randn((4, 5))
second_embedding = torch.randn((3, 5))
amds.inform_samples([1, 2, 3, 4], None, None, None, first_embedding)
amds.inform_samples([21, 31, 41], None, None, None, second_embedding)
amds.inform_samples([1, 2, 3, 4], None, torch.randn((4, 2)), None, first_embedding)
amds.inform_samples([21, 31, 41], None, torch.randn((3, 2)), None, second_embedding)

assert np.concatenate(amds.matrix_elements).shape == (7, 5)
assert all(torch.equal(el1, el2) for el1, el2 in zip(amds.matrix_elements, [first_embedding, second_embedding]))
assert amds.index_sampleid_map == [1, 2, 3, 4, 21, 31, 41]

third_embedding = torch.randn((23, 5))
amds.inform_samples(list(range(1000, 1023)), None, None, None, third_embedding)
amds.inform_samples(list(range(1000, 1023)), None, torch.randn((23, 2)), None, third_embedding)

assert np.concatenate(amds.matrix_elements).shape == (30, 5)
assert all(
@@ -88,8 +88,8 @@ def test_collect_embedding_balance(test_amds, dummy_system_config: ModynConfig):

first_embedding = torch.randn((4, 5))
second_embedding = torch.randn((3, 5))
amds.inform_samples([1, 2, 3, 4], None, None, None, first_embedding)
amds.inform_samples([21, 31, 41], None, None, None, second_embedding)
amds.inform_samples([1, 2, 3, 4], None, torch.randn((4, 2)), None, first_embedding)
amds.inform_samples([21, 31, 41], None, torch.randn((3, 2)), None, second_embedding)

assert np.concatenate(amds.matrix_elements).shape == (7, 5)
assert all(torch.equal(el1, el2) for el1, el2 in zip(amds.matrix_elements, [first_embedding, second_embedding]))
@@ -99,7 +99,7 @@ def test_collect_embedding_balance(test_amds, dummy_system_config: ModynConfig):

third_embedding = torch.randn((23, 5))
assert len(amds.matrix_elements) == 0
amds.inform_samples(list(range(1000, 1023)), None, None, None, third_embedding)
amds.inform_samples(list(range(1000, 1023)), None, torch.randn((23, 2)), None, third_embedding)

assert np.concatenate(amds.matrix_elements).shape == (23, 5)
assert all(torch.equal(el1, el2) for el1, el2 in zip(amds.matrix_elements, [third_embedding]))
@@ -142,3 +142,42 @@ def test_collect_gradients(matrix_content, dummy_system_config: ModynConfig):
assert np.concatenate(amds.matrix_elements).shape == (7, gradient_shape)

assert amds.index_sampleid_map == [1, 2, 3, 4, 21, 31, 41]


@pytest.mark.parametrize(
"matrix_content", [MatrixContent.LAST_LAYER_GRADIENTS, MatrixContent.LAST_TWO_LAYERS_GRADIENTS]
)
@patch.multiple(AbstractMatrixDownsamplingStrategy, __abstractmethods__=set())
def test_collect_gradients_binary(matrix_content, dummy_system_config: ModynConfig):
per_sample_loss_fct = torch.nn.BCEWithLogitsLoss(reduction="none")
sampler_config = list(get_sampler_config(dummy_system_config, matrix_content=matrix_content))
sampler_config[5] = per_sample_loss_fct
sampler_config = tuple(sampler_config)
amds = AbstractMatrixDownsamplingStrategy(*sampler_config)
with torch.inference_mode(mode=(not amds.requires_grad)):
forward_input = torch.randn((4, 5))
first_output = torch.randn((4,))
first_output.requires_grad = True
first_target = torch.tensor([1, 1, 1, 0], dtype=torch.float32)
first_embedding = torch.randn((4, 5))
amds.inform_samples([1, 2, 3, 4], forward_input, first_output, first_target, first_embedding)

second_output = torch.randn((3,))
second_output.requires_grad = True
second_target = torch.tensor([0, 1, 0], dtype=torch.float32)
second_embedding = torch.randn((3, 5))
amds.inform_samples([21, 31, 41], forward_input, second_output, second_target, second_embedding)

assert len(amds.matrix_elements) == 2

# expected shape = (a, gradient_shape)
# a = 7 (4 samples in the first batch and 3 samples in the second batch)
if matrix_content == MatrixContent.LAST_LAYER_GRADIENTS:
# the gradient shape equals the last dimension of the output (1 for binary classification)
gradient_shape = 1
else:
# 5 is the input dimension of the last layer and 1 is the output one
gradient_shape = 5 * 1 + 1
assert np.concatenate(amds.matrix_elements).shape == (7, gradient_shape)

assert amds.index_sampleid_map == [1, 2, 3, 4, 21, 31, 41]
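As a sanity check on the gradient_shape arithmetic above (5 * 1 + 1), here is a hedged sketch of how the per-sample last-two-layers gradient ends up with 6 entries when the last layer is assumed to be Linear(5, 1); the exact stacking inside the strategy may differ:

import torch

logits = torch.randn(7, 1, requires_grad=True)              # binary outputs after unsqueezing
target = torch.randint(2, (7, 1), dtype=torch.float32)
embedding = torch.randn(7, 5)                                # inputs to the assumed Linear(5, 1) last layer

loss_sum = torch.nn.BCEWithLogitsLoss(reduction="sum")(logits, target)
grad_logits = torch.autograd.grad(loss_sum, logits)[0]       # (7, 1): per-sample bias gradients
weight_grads = grad_logits.unsqueeze(2) * embedding.unsqueeze(1)   # (7, 1, 5): per-sample weight gradients

per_sample = torch.cat([grad_logits, weight_grads.flatten(1)], dim=1)
print(per_sample.shape)                                      # torch.Size([7, 6]) == (7, 5 * 1 + 1)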
@@ -205,6 +205,39 @@ def test_bts(grad_approx: str, dummy_system_config: ModynConfig):
assert all(id in [1, 2, 3, 10, 11, 12, 13] for id in selected_points)


@pytest.mark.parametrize("grad_approx", ["LastLayer", "LastLayerWithEmbedding"])
def test_bts_binary(grad_approx: str, dummy_system_config: ModynConfig):
sampler_config = get_sampler_config(dummy_system_config, grad_approx=grad_approx)
per_sample_loss_fct = torch.nn.BCEWithLogitsLoss(reduction="none")
sampler_config = (0, 0, 0, sampler_config[3], sampler_config[4], per_sample_loss_fct, "cpu")
sampler = RemoteCraigDownsamplingStrategy(*sampler_config)

with torch.inference_mode(mode=(not sampler.requires_grad)):
sample_ids = [1, 2, 3, 10, 11, 12, 13]
forward_input = torch.randn(7, 5) # 7 samples, 5 input features
forward_output = torch.randn(
7,
)
forward_output.requires_grad = True
target = torch.tensor([1, 1, 1, 0, 0, 0, 1], dtype=torch.float32) # 7 target labels
embedding = torch.randn(7, 10) # 7 samples, embedding dimension 10

assert sampler.distance_matrix.shape == (0, 0)
sampler.inform_samples(sample_ids, forward_input, forward_output, target, embedding)
sampler.inform_end_of_current_label()
assert sampler.distance_matrix.shape == (7, 7)
assert len(sampler.current_class_gradients) == 0

assert sampler.index_sampleid_map == [10, 11, 12, 1, 2, 3, 13]

selected_points, selected_weights = sampler.select_points()

assert len(selected_points) == 3
assert len(selected_weights) == 3
assert all(weight > 0 for weight in selected_weights)
assert all(id in [1, 2, 3, 10, 11, 12, 13] for id in selected_points)
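A rough illustration of the (7, 7) shape asserted above: the Craig strategy builds a pairwise matrix over the per-sample gradients of the current class; the plain Euclidean distance used here is only a stand-in for whatever similarity the actual implementation computes.

import numpy as np

grads = np.random.randn(7, 6)                                 # 7 per-sample gradient vectors, dimension assumed
dists = np.sqrt(((grads[:, None, :] - grads[None, :, :]) ** 2).sum(axis=-1))
assert dists.shape == (7, 7)                                   # one entry per pair of samples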


@pytest.mark.parametrize("grad_approx", ["LastLayerWithEmbedding", "LastLayer"])
def test_bts_equals_stb(grad_approx: str, dummy_system_config: ModynConfig):
# data
@@ -40,7 +40,7 @@ def test_sample_shape_ce(dummy_system_config: ModynConfig):
assert set(downsampled_indexes) <= set(range(8))


def test_sample_shape_other_losses(dummy_system_config: ModynConfig):
def test_sample_shape_binary(dummy_system_config: ModynConfig):
model = torch.nn.Linear(10, 1)
downsampling_ratio = 50
per_sample_loss_fct = torch.nn.BCEWithLogitsLoss(reduction="none")
@@ -51,11 +51,10 @@ def test_sample_shape_other_losses(dummy_system_config: ModynConfig):
)
with torch.inference_mode(mode=(not sampler.requires_grad)):
data = torch.randn(8, 10)
target = torch.randint(2, size=(8,), dtype=torch.float32).unsqueeze(1)
forward_outputs = model(data).squeeze(1)
target = torch.randint(2, size=(8,), dtype=torch.float32)
ids = list(range(8))

forward_outputs = model(data)

sampler.inform_samples(ids, data, forward_outputs, target)
downsampled_indexes, weights = sampler.select_points()

@@ -35,6 +35,36 @@ def test_sample_shape(dummy_system_config: ModynConfig):
assert len(indexes) == 4


def test_sample_shape_binary(dummy_system_config: ModynConfig):
model = torch.nn.Linear(10, 1)
downsampling_ratio = 50
per_sample_loss_fct = torch.nn.BCEWithLogitsLoss(reduction="none")

params_from_selector = {"downsampling_ratio": downsampling_ratio, "sample_then_batch": False, "ratio_max": 100}
sampler = RemoteLossDownsampling(
0, 0, 0, params_from_selector, dummy_system_config.model_dump(by_alias=True), per_sample_loss_fct, "cpu"
)
with torch.inference_mode(mode=(not sampler.requires_grad)):
data = torch.randn(8, 10)
forward_outputs = model(data).squeeze(1)
target = torch.randint(2, size=(8,), dtype=torch.float32)
ids = list(range(8))

sampler.inform_samples(ids, data, forward_outputs, target)
downsampled_indexes, weights = sampler.select_points()

assert len(downsampled_indexes) == 4
assert weights.shape[0] == 4

sampled_data, sampled_target = get_tensors_subset(downsampled_indexes, data, target, ids)

assert weights.shape[0] == sampled_target.shape[0]
assert sampled_data.shape[0] == 4
assert sampled_data.shape[1] == data.shape[1]
assert weights.shape[0] == 4
assert sampled_target.shape[0] == 4
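The asserts above reduce to one invariant: selecting 4 of the 8 sample ids keeps data, targets, and weights aligned. A hypothetical positional-indexing sketch (not the actual get_tensors_subset implementation):

import torch

data = torch.randn(8, 10)
target = torch.randint(2, size=(8,), dtype=torch.float32)
ids = list(range(8))
downsampled_indexes = [0, 3, 5, 6]                            # an assumed selection of 4 ids

positions = [ids.index(sample_id) for sample_id in downsampled_indexes]
sampled_data, sampled_target = data[positions], target[positions]
assert sampled_data.shape == (4, 10) and sampled_target.shape == (4,)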


def test_sample_weights(dummy_system_config: ModynConfig):
model = torch.nn.Linear(10, 2)
downsampling_ratio = 50
@@ -74,6 +74,31 @@ def test_compute_score(sampler_config):
assert np.allclose(scores, expected_scores, atol=1e-4)


binary_test_data = {
"LeastConfidence": {
"outputs": torch.tensor([[-0.8], [0.5], [0.3]]),
"expected_scores": np.array([0.8, 0.5, 0.3]), # confidence just picks the highest probability
},
"Entropy": {
"outputs": torch.tensor([[0.8], [0.5], [0.3]]),
"expected_scores": np.array([-0.5004, -0.6931, -0.6109]),
},
"Margin": {
"outputs": torch.tensor([[0.8], [0.5], [0.3]]),
"expected_scores": np.array([0.6, 0.0, 0.4]), # margin between top two classes
},
}


def test_compute_score_binary(sampler_config):
metric = sampler_config[3]["score_metric"]
amds = RemoteUncertaintyDownsamplingStrategy(*sampler_config)
outputs = binary_test_data[metric]["outputs"]
expected_scores = binary_test_data[metric]["expected_scores"]
scores = amds._compute_score(outputs, disable_softmax=True)
assert np.allclose(scores, expected_scores, atol=1e-4)
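The expected values in binary_test_data can be reproduced by hand: with disable_softmax=True the outputs are treated as positive-class probabilities and expanded to two columns, exactly as the new binary branch of _compute_score does.

import numpy as np
import torch

p_pos = torch.tensor([[0.8], [0.5], [0.3]])
preds = torch.cat((1 - p_pos, p_pos), dim=1).numpy()           # [[0.2, 0.8], [0.5, 0.5], [0.7, 0.3]]

entropy_scores = (np.log(preds + 1e-6) * preds).sum(axis=1)    # ~ [-0.5004, -0.6931, -0.6109]
margin_scores = np.abs(preds[:, 1] - preds[:, 0])              # [0.6, 0.0, 0.4]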


def test_select_points(balance_config):
amds = RemoteUncertaintyDownsamplingStrategy(*balance_config)
with torch.inference_mode():
@@ -8,6 +8,9 @@
from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_per_label_remote_downsample_strategy import (
AbstractPerLabelRemoteDownsamplingStrategy,
)
from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import (
unsqueeze_dimensions_if_necessary,
)
from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils.shuffling import _shuffle_list_and_tensor


@@ -71,6 +74,7 @@ def inform_samples(
) -> None:
batch_size = len(sample_ids)
assert self.matrix_content is not None
forward_output, target = unsqueeze_dimensions_if_necessary(forward_output, target)
if self.matrix_content == MatrixContent.LAST_LAYER_GRADIENTS:
grads_wrt_loss_sum = self._compute_last_layer_gradient_wrt_loss_sum(self.criterion, forward_output, target)
grads_wrt_loss_mean = grads_wrt_loss_sum / batch_size
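The division above relies on the linearity of gradients: the gradient of the mean loss is the gradient of the summed loss divided by the batch size. A quick check with assumed shapes:

import torch

logits = torch.randn(4, 1, requires_grad=True)
target = torch.randint(2, (4, 1), dtype=torch.float32)
criterion = torch.nn.BCEWithLogitsLoss(reduction="none")

grad_sum = torch.autograd.grad(criterion(logits, target).sum(), logits)[0]
grad_mean = torch.autograd.grad(criterion(logits, target).mean(), logits)[0]
assert torch.allclose(grad_mean, grad_sum / 4)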
@@ -36,6 +36,23 @@ def get_tensors_subset(
return sub_data, sub_target


def unsqueeze_dimensions_if_necessary(
forward_output: torch.Tensor, target: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""For binary classification, the forward output is a 1D tensor of length
batch_size. We need to unsqueeze it to have a 2D tensor of shape
(batch_size, 1).

For binary classification we use BCEWithLogitsLoss, which requires
the same dimensionality between the forward output and the target,
so we also need to unsqueeze the target tensor.
"""
if forward_output.dim() == 1:
forward_output = forward_output.unsqueeze(1)
target = target.unsqueeze(1)
return forward_output, target
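A quick usage sketch of the new helper, with shapes assumed for illustration:

import torch

from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import (
    unsqueeze_dimensions_if_necessary,
)

forward_output = torch.randn(8)                                # 1D logits from a binary head
target = torch.randint(2, (8,), dtype=torch.float32)

forward_output, target = unsqueeze_dimensions_if_necessary(forward_output, target)
assert forward_output.shape == (8, 1) and target.shape == (8, 1)
# 2D multi-class outputs pass through unchanged.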


class AbstractRemoteDownsamplingStrategy(ABC):
def __init__(
self,
@@ -9,6 +9,7 @@
)
from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import (
FULL_GRAD_APPROXIMATION,
unsqueeze_dimensions_if_necessary,
)
from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils import submodular_optimizer
from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils.euclidean import euclidean_dist_pair_np
@@ -110,6 +111,7 @@ def _inform_samples_single_class(
target: torch.Tensor,
embedding: torch.Tensor | None,
) -> None:
forward_output, target = unsqueeze_dimensions_if_necessary(forward_output, target)
if self.full_grad_approximation == "LastLayerWithEmbedding":
assert embedding is not None
grads_wrt_loss_sum = self._compute_last_two_layers_gradient_wrt_loss_sum(
@@ -5,6 +5,7 @@

from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import (
AbstractRemoteDownsamplingStrategy,
unsqueeze_dimensions_if_necessary,
)

logger = logging.getLogger(__name__)
@@ -49,10 +50,13 @@ def inform_samples(
target: torch.Tensor,
embedding: torch.Tensor | None = None,
) -> None:
forward_output, target = unsqueeze_dimensions_if_necessary(forward_output, target)

last_layer_gradients = self._compute_last_layer_gradient_wrt_loss_sum(
self.per_sample_loss_fct, forward_output, target
)
scores = torch.norm(last_layer_gradients, dim=-1).cpu()
# pylint: disable=not-callable
scores = torch.linalg.vector_norm(last_layer_gradients, dim=1).cpu()
self.probabilities.append(scores)
self.number_of_points_seen += forward_output.shape[0]
self.index_sampleid_map += sample_ids
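The switch to torch.linalg.vector_norm computes one L2 norm per sample over the flattened last-layer gradient, for example (shapes assumed):

import torch

last_layer_gradients = torch.randn(8, 6)                        # (batch_size, flattened gradient)
scores = torch.linalg.vector_norm(last_layer_gradients, dim=1)  # shape (8,): one sampling score per sample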
@@ -6,6 +6,9 @@
from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_per_label_remote_downsample_strategy import (
AbstractPerLabelRemoteDownsamplingStrategy,
)
from modyn.trainer_server.internal.trainer.remote_downsamplers.abstract_remote_downsampling_strategy import (
unsqueeze_dimensions_if_necessary,
)
from modyn.trainer_server.internal.trainer.remote_downsamplers.deepcore_utils.shuffling import _shuffle_list_and_tensor


@@ -64,33 +67,53 @@
) -> None:
assert embedding is None

forward_output, _ = unsqueeze_dimensions_if_necessary(forward_output, target)
self.scores = np.append(self.scores, self._compute_score(forward_output.detach()))
# keep the mapping index<->sample_id
self.index_sampleid_map += sample_ids

def _compute_score(self, forward_output: torch.Tensor, disable_softmax: bool = False) -> np.ndarray:
feature_size = forward_output.size(1)
if self.score_metric == "LeastConfidence":
scores = forward_output.max(dim=1).values.cpu().numpy()
elif self.score_metric == "Entropy":
preds = (
torch.nn.functional.softmax(forward_output, dim=1).cpu().numpy()
if not disable_softmax
else forward_output.cpu().numpy()
)
scores = (np.log(preds + 1e-6) * preds).sum(axis=1)
elif self.score_metric == "Margin":
preds = torch.nn.functional.softmax(forward_output, dim=1) if not disable_softmax else forward_output
preds_argmax = torch.argmax(preds, dim=1) # gets top class
max_preds = preds[torch.ones(preds.shape[0], dtype=bool), preds_argmax].clone() # gets scores of top class

# remove highest class from softmax output
preds[torch.ones(preds.shape[0], dtype=bool), preds_argmax] = -1.0

preds_sub_argmax = torch.argmax(preds, dim=1) # gets new top class (=> 2nd top class)
second_max_preds = preds[torch.ones(preds.shape[0], dtype=bool), preds_sub_argmax]
scores = (max_preds - second_max_preds).cpu().numpy()
if feature_size == 1:
# For binary classification there is only one pre-sigmoid output value, which, after sigmoid layer,
# is the probability of the positive class. The probability of the negative class is
# 1 - probability_positive_class.
# For each sample we need to compute the pre-sigmoid output for the class with the highest probability.
# If model_output_value >= 0, then sigmoid(model_output_value) >= 0.5, hence the positive class has the
# highest probability and model_output_value is what we need.
# If model_output_value < 0, then sigmoid(model_output_value) < 0.5, hence the negative class has the
# highest probability. The corresponding pre-sigmoid output value for the negative class
# is - model_output_value.
# In any case, we just need to compute the absolute value of the model output value.
scores = torch.abs(forward_output).squeeze(1).cpu().numpy()
else:
scores = forward_output.max(dim=1).values.cpu().numpy()
else:
raise AssertionError("The required metric does not exist")
if feature_size == 1:
# for binary classification the softmax layer is reduced to sigmoid
preds = torch.sigmoid(forward_output) if not disable_softmax else forward_output
# we need to convert it to a 2D tensor with probabilities for both classes
preds = torch.cat((1 - preds, preds), dim=1)
else:
preds = torch.nn.functional.softmax(forward_output, dim=1) if not disable_softmax else forward_output

if self.score_metric == "Entropy":
scores = (np.log(preds + 1e-6) * preds).sum(axis=1)
elif self.score_metric == "Margin":
preds_argmax = torch.argmax(preds, dim=1) # gets top class
max_preds = preds[
torch.ones(preds.shape[0], dtype=bool), preds_argmax
].clone() # gets scores of top class

# remove highest class from softmax output
preds[torch.ones(preds.shape[0], dtype=bool), preds_argmax] = -1.0

preds_sub_argmax = torch.argmax(preds, dim=1) # gets new top class (=> 2nd top class)
second_max_preds = preds[torch.ones(preds.shape[0], dtype=bool), preds_sub_argmax]
scores = (max_preds - second_max_preds).cpu().numpy()
else:
raise AssertionError("The required metric does not exist")

Codecov / codecov/patch warning: added line 116 in modyn/trainer_server/internal/trainer/remote_downsamplers/remote_uncertainty_downsampling_strategy.py was not covered by tests.

return scores
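The comment block above boils down to a symmetry of the sigmoid: sigmoid(-z) = 1 - sigmoid(z), so the probability of the predicted class is sigmoid(|z|) and |z| works directly as the confidence score. Using the LeastConfidence inputs from binary_test_data:

import torch

z = torch.tensor([[-0.8], [0.5], [0.3]])                       # pre-sigmoid outputs
p_pos = torch.sigmoid(z)
p_top = torch.maximum(p_pos, 1 - p_pos)                         # probability of the predicted class
assert torch.allclose(p_top, torch.sigmoid(z.abs()))

scores = z.abs().squeeze(1)                                     # [0.8, 0.5, 0.3], as asserted in the binary test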

@@ -139,7 +162,7 @@
# we select those with minimal negative entropy, i.e., maximum entropy
# Margin: We look for the smallest margin. The larger the margin, the more certain the
# model is.
return np.argsort(self.scores)[:target_size], torch.ones(target_size).float()
return np.argsort(self.scores)[:target_size].tolist(), torch.ones(target_size).float()

@property
def requires_grad(self) -> bool: