Skip to content

Commit

Permalink
Add new auc function - remove sklearn dependency (#102)
Browse files Browse the repository at this point in the history
* new_auc function

* remove sklearn from dependencies
  • Loading branch information
egillax authored Nov 8, 2023
1 parent 61bd834 commit 5a67020
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 25 deletions.
1 change: 0 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ Config/reticulate:
list(package = "polars"),
list(package = "tqdm"),
list(package = "connectorx"),
list(package = "scikit-learn"),
list(package = "pyarrow")
)
)
50 changes: 26 additions & 24 deletions inst/python/Estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from torch.utils.data import DataLoader, BatchSampler, RandomSampler, SequentialSampler
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.metrics import roc_auc_score


class Estimator:
Expand Down Expand Up @@ -155,8 +154,7 @@ def score(self, dataloader):
mean_loss = loss.mean().item()
predictions = torch.concat(predictions)
targets = torch.concat(targets)
auc = roc_auc_score(targets.cpu(), predictions.cpu())
# auc = compute_auc(predictions, targets)
auc = compute_auc(targets.cpu(), predictions.cpu())
scores = dict()
if self.metric:
if self.metric["name"] == "auc":
Expand Down Expand Up @@ -330,26 +328,30 @@ def batch_to_device(batch, device='cpu'):
return batch


def compute_auc(y_true, y_pred):
    """
    Computes the AUC (area under the ROC curve) for binary classification.

    The AUC is obtained from the Mann-Whitney U statistic: every prediction
    is given its 1-based rank among all predictions (tied predictions share
    the average rank of their group) and the rank sum of the positive
    examples is converted into the fraction of (positive, negative) pairs
    that are ordered correctly. A tied pair counts as half a correct pair,
    so the result does not depend on the order of the inputs.

    Args:
        y_true (torch.Tensor): True binary labels. Entries equal to 0 are
            negatives, every other entry is treated as a positive.
        y_pred (torch.Tensor): Predicted scores, one per label. Any shape
            is accepted as long as it holds as many elements as y_true.
    Returns:
        torch.Tensor: 0-dim tensor with the AUC score. It is 0.5 when
            y_true holds fewer than two classes (including empty input),
            because the AUC is undefined in that case.
    Raises:
        ValueError: If y_true and y_pred hold a different number of elements.
    """
    # Flatten so that e.g. (N, 1) model outputs are ranked across examples
    # and not along a trailing dimension of size one.
    scores = torch.as_tensor(y_pred).reshape(-1)
    labels = torch.as_tensor(y_true).reshape(-1).to(scores.device)
    if labels.numel() != scores.numel():
        raise ValueError(
            "y_true and y_pred must have the same number of elements, got "
            f"{labels.numel()} and {scores.numel()}"
        )

    is_pos = labels != 0
    # Integer counts keep all the arithmetic below exact.
    n_pos = is_pos.sum()
    n_neg = is_pos.numel() - n_pos

    # With a single class the ratio below would be 0 / 0 (NaN); report
    # chance level instead.
    if n_pos == 0 or n_neg == 0:
        return torch.tensor(0.5, device=scores.device)

    # Group identical scores; unique values come back in ascending order.
    _, group, counts = torch.unique(
        scores, sorted=True, return_inverse=True, return_counts=True
    )
    # A tie group occupies ranks (last - count + 1) .. last, so twice its
    # average rank is 2 * last - count + 1. Working with the doubled value
    # avoids fractional ranks.
    last_rank = torch.cumsum(counts, 0)
    twice_mid_rank = 2 * last_rank - counts + 1
    twice_pos_rank_sum = twice_mid_rank[group][is_pos].sum()

    # Mann-Whitney U of the positives (doubled), normalised by the number
    # of (positive, negative) pairs.
    twice_u = twice_pos_rank_sum - n_pos * (n_pos + 1)
    auc = twice_u / (2 * n_pos * n_neg)
    return auc




0 comments on commit 5a67020

Please sign in to comment.