Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dimensionality issue when binary classification outputs 1D instead of 2D tensor #609

Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ def test_collect_embeddings(dummy_system_config: ModynConfig):

first_embedding = torch.randn((4, 5))
second_embedding = torch.randn((3, 5))
amds.inform_samples([1, 2, 3, 4], None, None, None, first_embedding)
amds.inform_samples([21, 31, 41], None, None, None, second_embedding)
amds.inform_samples([1, 2, 3, 4], None, torch.randn((4, 2)), None, first_embedding)
amds.inform_samples([21, 31, 41], None, torch.randn((3, 2)), None, second_embedding)
MaxiBoether marked this conversation as resolved.
Show resolved Hide resolved

assert np.concatenate(amds.matrix_elements).shape == (7, 5)
assert all(torch.equal(el1, el2) for el1, el2 in zip(amds.matrix_elements, [first_embedding, second_embedding]))
assert amds.index_sampleid_map == [1, 2, 3, 4, 21, 31, 41]

third_embedding = torch.randn((23, 5))
amds.inform_samples(list(range(1000, 1023)), None, None, None, third_embedding)
amds.inform_samples(list(range(1000, 1023)), None, torch.randn((23, 2)), None, third_embedding)

assert np.concatenate(amds.matrix_elements).shape == (30, 5)
assert all(
Expand All @@ -88,8 +88,8 @@ def test_collect_embedding_balance(test_amds, dummy_system_config: ModynConfig):

first_embedding = torch.randn((4, 5))
second_embedding = torch.randn((3, 5))
amds.inform_samples([1, 2, 3, 4], None, None, None, first_embedding)
amds.inform_samples([21, 31, 41], None, None, None, second_embedding)
amds.inform_samples([1, 2, 3, 4], None, torch.randn((4, 2)), None, first_embedding)
amds.inform_samples([21, 31, 41], None, torch.randn((3, 2)), None, second_embedding)

assert np.concatenate(amds.matrix_elements).shape == (7, 5)
assert all(torch.equal(el1, el2) for el1, el2 in zip(amds.matrix_elements, [first_embedding, second_embedding]))
Expand All @@ -99,7 +99,7 @@ def test_collect_embedding_balance(test_amds, dummy_system_config: ModynConfig):

third_embedding = torch.randn((23, 5))
assert len(amds.matrix_elements) == 0
amds.inform_samples(list(range(1000, 1023)), None, None, None, third_embedding)
amds.inform_samples(list(range(1000, 1023)), None, torch.randn((23, 2)), None, third_embedding)

assert np.concatenate(amds.matrix_elements).shape == (23, 5)
assert all(torch.equal(el1, el2) for el1, el2 in zip(amds.matrix_elements, [third_embedding]))
Expand Down Expand Up @@ -142,3 +142,42 @@ def test_collect_gradients(matrix_content, dummy_system_config: ModynConfig):
assert np.concatenate(amds.matrix_elements).shape == (7, gradient_shape)

assert amds.index_sampleid_map == [1, 2, 3, 4, 21, 31, 41]


@pytest.mark.parametrize(
    "matrix_content", [MatrixContent.LAST_LAYER_GRADIENTS, MatrixContent.LAST_TWO_LAYERS_GRADIENTS]
)
@patch.multiple(AbstractMatrixDownsamplingStrategy, __abstractmethods__=set())
def test_collect_gradients_binary(matrix_content, dummy_system_config: ModynConfig):
    """Gradients are collected correctly when the model emits 1D (single-logit) outputs.

    Two batches (4 and 3 samples) with 1D outputs and float targets are passed to the
    strategy, which is configured with a per-sample ``BCEWithLogitsLoss``. Each batch
    must yield one matrix element, and the stacked matrix must have one row per sample.
    """
    bce_per_sample = torch.nn.BCEWithLogitsLoss(reduction="none")
    base_args = get_sampler_config(dummy_system_config, matrix_content=matrix_content)
    # Swap the argument at position 5 for the per-sample binary loss; keep the rest untouched.
    strategy_args = tuple(bce_per_sample if pos == 5 else arg for pos, arg in enumerate(base_args))
    amds = AbstractMatrixDownsamplingStrategy(*strategy_args)

    # (sample ids, binary labels) for the two batches fed to the strategy.
    batches = (
        ([1, 2, 3, 4], [1, 1, 1, 0]),
        ([21, 31, 41], [0, 1, 0]),
    )

    with torch.inference_mode(mode=(not amds.requires_grad)):
        # The same forward input is handed over with both batches.
        forward_input = torch.randn((4, 5))
        for batch_ids, labels in batches:
            # 1D output: one logit per sample instead of a (batch, num_classes) matrix.
            output = torch.randn((len(batch_ids),))
            output.requires_grad = True
            target = torch.tensor(labels, dtype=torch.float32)
            embedding = torch.randn((len(batch_ids), 5))
            amds.inform_samples(batch_ids, forward_input, output, target, embedding)

        assert len(amds.matrix_elements) == 2

        # Expected shape is (7, gradient_shape): 4 samples from the first batch plus 3 from the second.
        if matrix_content == MatrixContent.LAST_LAYER_GRADIENTS:
            # One gradient entry per output unit, and the binary output has a single unit.
            gradient_shape = 1
        else:
            # 5 is the embedding width (input of the last layer), 1 is the output width.
            gradient_shape = 5 * 1 + 1
        assert np.concatenate(amds.matrix_elements).shape == (7, gradient_shape)

        assert amds.index_sampleid_map == [1, 2, 3, 4, 21, 31, 41]
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,39 @@ def test_bts(grad_approx: str, dummy_system_config: ModynConfig):
assert all(id in [1, 2, 3, 10, 11, 12, 13] for id in selected_points)


@pytest.mark.parametrize("grad_approx", ["LastLayer", "LastLayerWithEmbedding"])
def test_bts_binary(grad_approx: str, dummy_system_config: ModynConfig):
    """Batch-then-sample CRAIG selection works with 1D binary logits.

    Seven samples with a 1D forward output and float targets are passed to a sampler
    that uses a per-sample ``BCEWithLogitsLoss``. After the label is closed, the
    distance matrix must be 7x7 and three weighted points must be selected.
    """
    base_config = get_sampler_config(dummy_system_config, grad_approx=grad_approx)
    bce_per_sample = torch.nn.BCEWithLogitsLoss(reduction="none")
    sampler = RemoteCraigDownsamplingStrategy(0, 0, 0, base_config[3], base_config[4], bce_per_sample, "cpu")

    with torch.inference_mode(mode=(not sampler.requires_grad)):
        sample_ids = [1, 2, 3, 10, 11, 12, 13]
        forward_input = torch.randn(7, 5)  # 7 samples, 5 input features
        forward_output = torch.randn(7)  # one logit per sample (1D output)
        forward_output.requires_grad = True
        target = torch.tensor([1, 1, 1, 0, 0, 0, 1], dtype=torch.float32)  # 7 target labels
        embedding = torch.randn(7, 10)  # 7 samples, embedding dimension 10

        # Nothing has been seen yet, so the distance matrix starts out empty.
        assert sampler.distance_matrix.shape == (0, 0)
        sampler.inform_samples(sample_ids, forward_input, forward_output, target, embedding)
        sampler.inform_end_of_current_label()
        assert sampler.distance_matrix.shape == (7, 7)
        assert len(sampler.current_class_gradients) == 0

        # The ids with target 0 come first, followed by the ids with target 1.
        assert sampler.index_sampleid_map == [10, 11, 12, 1, 2, 3, 13]

        selected_points, selected_weights = sampler.select_points()

        assert len(selected_points) == 3
        assert len(selected_weights) == 3
        assert all(weight > 0 for weight in selected_weights)
        assert all(point in [1, 2, 3, 10, 11, 12, 13] for point in selected_points)


@pytest.mark.parametrize("grad_approx", ["LastLayerWithEmbedding", "LastLayer"])
def test_bts_equals_stb(grad_approx: str, dummy_system_config: ModynConfig):
# data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_sample_shape_ce(dummy_system_config: ModynConfig):
assert set(downsampled_indexes) <= set(range(8))


def test_sample_shape_other_losses(dummy_system_config: ModynConfig):
def test_sample_shape_binary(dummy_system_config: ModynConfig):
model = torch.nn.Linear(10, 1)
downsampling_ratio = 50
per_sample_loss_fct = torch.nn.BCEWithLogitsLoss(reduction="none")
Expand All @@ -51,11 +51,10 @@ def test_sample_shape_other_losses(dummy_system_config: ModynConfig):
)
with torch.inference_mode(mode=(not sampler.requires_grad)):
data = torch.randn(8, 10)
target = torch.randint(2, size=(8,), dtype=torch.float32).unsqueeze(1)
forward_outputs = model(data).squeeze(1)
target = torch.randint(2, size=(8,), dtype=torch.float32)
ids = list(range(8))

forward_outputs = model(data)

sampler.inform_samples(ids, data, forward_outputs, target)
downsampled_indexes, weights = sampler.select_points()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,36 @@ def test_sample_shape(dummy_system_config: ModynConfig):
assert len(indexes) == 4


def test_sample_shape_binary(dummy_system_config: ModynConfig):
    """Loss-based downsampling keeps half of an 8-sample batch for 1D binary outputs.

    The single-logit model output is squeezed to 1D and the targets are 1D floats,
    which is the input shape a per-sample ``BCEWithLogitsLoss`` setup produces.
    """
    model = torch.nn.Linear(10, 1)
    bce_per_sample = torch.nn.BCEWithLogitsLoss(reduction="none")
    selector_params = {"downsampling_ratio": 50, "sample_then_batch": False, "ratio_max": 100}
    system_config = dummy_system_config.model_dump(by_alias=True)
    sampler = RemoteLossDownsampling(0, 0, 0, selector_params, system_config, bce_per_sample, "cpu")

    with torch.inference_mode(mode=(not sampler.requires_grad)):
        data = torch.randn(8, 10)
        # Drop the trailing singleton dimension: (8, 1) -> (8,).
        forward_outputs = model(data).squeeze(1)
        target = torch.randint(2, size=(8,), dtype=torch.float32)
        ids = list(range(8))

        sampler.inform_samples(ids, data, forward_outputs, target)
        downsampled_indexes, weights = sampler.select_points()

        # 50% of 8 samples.
        expected_size = 4
        assert len(downsampled_indexes) == expected_size
        assert weights.shape[0] == expected_size

        sampled_data, sampled_target = get_tensors_subset(downsampled_indexes, data, target, ids)

        # Weights, data and targets of the selection must all line up.
        assert weights.shape[0] == sampled_target.shape[0]
        assert sampled_data.shape[0] == expected_size
        assert sampled_data.shape[1] == data.shape[1]
        assert weights.shape[0] == expected_size
        assert sampled_target.shape[0] == expected_size


def test_sample_weights(dummy_system_config: ModynConfig):
model = torch.nn.Linear(10, 2)
downsampling_ratio = 50
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,31 @@ def test_compute_score(sampler_config):
assert np.allclose(scores, expected_scores, atol=1e-4)


# Inputs and expected scores for the binary (1D output) case, keyed by score metric.
# The consuming test calls `_compute_score` with `disable_softmax=True`, so the
# Entropy and Margin inputs are used as given, without a sigmoid being applied.
binary_test_data = dict(
    LeastConfidence=dict(
        outputs=torch.tensor([-0.8, 0.5, 0.3]),
        # for a single output the score is its absolute value
        expected_scores=np.array([0.8, 0.5, 0.3]),
    ),
    Entropy=dict(
        outputs=torch.tensor([0.8, 0.5, 0.3]),
        # sum of p * log(p) over the two classes (p, 1 - p)
        expected_scores=np.array([-0.5004, -0.6931, -0.6109]),
    ),
    Margin=dict(
        outputs=torch.tensor([0.8, 0.5, 0.3]),
        # margin between the two classes, i.e. |p - (1 - p)|
        expected_scores=np.array([0.6, 0.0, 0.4]),
    ),
)


def test_compute_score_binary(sampler_config):
    """Scores for 1D (binary) outputs match the values listed in ``binary_test_data``."""
    strategy = RemoteUncertaintyDownsamplingStrategy(*sampler_config)
    # The metric name configured for the strategy selects the matching test case.
    case = binary_test_data[sampler_config[3]["score_metric"]]
    # Softmax/sigmoid is disabled so the inputs are scored exactly as given.
    computed = strategy._compute_score(case["outputs"], disable_softmax=True)
    assert np.allclose(computed, case["expected_scores"], atol=1e-4)


def test_select_points(balance_config):
amds = RemoteUncertaintyDownsamplingStrategy(*balance_config)
with torch.inference_mode():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ def inform_samples(
) -> None:
batch_size = len(sample_ids)
assert self.matrix_content is not None
if forward_output.dim() == 1:
forward_output = forward_output.unsqueeze(1)
target = target.unsqueeze(1)
MaxiBoether marked this conversation as resolved.
Show resolved Hide resolved
if self.matrix_content == MatrixContent.LAST_LAYER_GRADIENTS:
grads_wrt_loss_sum = self._compute_last_layer_gradient_wrt_loss_sum(self.criterion, forward_output, target)
grads_wrt_loss_mean = grads_wrt_loss_sum / batch_size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ def _inform_samples_single_class(
target: torch.Tensor,
embedding: torch.Tensor | None,
) -> None:
if forward_output.dim() == 1:
forward_output = forward_output.unsqueeze(1)
target = target.unsqueeze(1)
if self.full_grad_approximation == "LastLayerWithEmbedding":
assert embedding is not None
grads_wrt_loss_sum = self._compute_last_two_layers_gradient_wrt_loss_sum(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,16 @@ def inform_samples(
target: torch.Tensor,
embedding: torch.Tensor | None = None,
) -> None:
if forward_output.dim() == 1:
# BCEWithLogitsLoss requires that forward_output and target have the same shape
forward_output = forward_output.unsqueeze(1)
target = target.unsqueeze(1)

last_layer_gradients = self._compute_last_layer_gradient_wrt_loss_sum(
self.per_sample_loss_fct, forward_output, target
)
scores = torch.norm(last_layer_gradients, dim=-1).cpu()
# pylint: disable=not-callable
scores = torch.linalg.vector_norm(last_layer_gradients, dim=1).cpu()
MaxiBoether marked this conversation as resolved.
Show resolved Hide resolved
self.probabilities.append(scores)
self.number_of_points_seen += forward_output.shape[0]
self.index_sampleid_map += sample_ids
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,28 +69,42 @@ def inform_samples(
self.index_sampleid_map += sample_ids

def _compute_score(self, forward_output: torch.Tensor, disable_softmax: bool = False) -> np.ndarray:
    """Compute one uncertainty score per sample according to ``self.score_metric``.

    A 1D ``forward_output`` is treated as a single-output (binary) prediction and is
    reshaped to ``(batch, 1)``. For the probability-based metrics, a single-column
    output is expanded to the two-column form ``(1 - p, p)``.

    Args:
        forward_output: Model outputs, either 1D or 2D with one row per sample.
            The tensor is never modified.
        disable_softmax: If True, ``forward_output`` is used as-is for the Entropy and
            Margin metrics instead of being passed through softmax (or sigmoid when
            there is a single column).

    Returns:
        A 1D numpy array with one score per sample.

    Raises:
        AssertionError: If ``self.score_metric`` is not one of "LeastConfidence",
            "Entropy" or "Margin".
    """
    if forward_output.dim() == 1:
        forward_output = forward_output.unsqueeze(1)
    feature_size = forward_output.size(1)

    if self.score_metric == "LeastConfidence":
        if feature_size == 1:
            # for binary classification comparing how far away the element is from 0.5 after
            # the sigmoid layer is the same as comparing the absolute value of the element
            # before the sigmoid layer
            return torch.abs(forward_output).squeeze(1).cpu().numpy()
        return forward_output.max(dim=1).values.cpu().numpy()

    if self.score_metric not in ("Entropy", "Margin"):
        raise AssertionError("The required metric does not exist")

    if feature_size == 1:
        # for binary classification the softmax layer is reduced to sigmoid
        preds = forward_output if disable_softmax else torch.sigmoid(forward_output)
        # we need a 2D tensor with the values for both classes
        preds = torch.cat((1 - preds, preds), dim=1)
    elif disable_softmax:
        preds = forward_output
    else:
        preds = torch.nn.functional.softmax(forward_output, dim=1)

    if self.score_metric == "Entropy":
        # Convert to numpy first so the result is always an ndarray, as the signature promises.
        probs = preds.cpu().numpy()
        return (np.log(probs + 1e-6) * probs).sum(axis=1)

    # Margin: difference between the two largest values of each row. topk reads the values
    # without writing into `preds`, which may alias the caller's tensor when softmax is disabled.
    top_two = torch.topk(preds, k=2, dim=1).values
    return (top_two[:, 0] - top_two[:, 1]).cpu().numpy()

Expand Down Expand Up @@ -139,7 +153,7 @@ def _select_indexes_from_scores(self, target_size: int) -> tuple[list[int], torc
# we select those with minimal negative entropy, i.e., maximum entropy
# Margin: We look for the smallest margin. The larger the margin, the more certain the
# model is.
return np.argsort(self.scores)[:target_size], torch.ones(target_size).float()
MaxiBoether marked this conversation as resolved.
Show resolved Hide resolved
return np.argsort(self.scores)[:target_size].tolist(), torch.ones(target_size).float()

@property
def requires_grad(self) -> bool:
Expand Down
Loading