diff --git a/DESCRIPTION b/DESCRIPTION
index d4c1a53..1ce2e2f 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -51,7 +51,6 @@ Config/reticulate:
       list(package = "polars"),
       list(package = "tqdm"),
       list(package = "connectorx"),
-      list(package = "scikit-learn"),
       list(package = "pyarrow")
     )
   )
diff --git a/inst/python/Estimator.py b/inst/python/Estimator.py
index f8a182f..cd534e2 100644
--- a/inst/python/Estimator.py
+++ b/inst/python/Estimator.py
@@ -5,7 +5,6 @@
 from torch.utils.data import DataLoader, BatchSampler, RandomSampler, SequentialSampler
 import torch.nn.functional as F
 from tqdm import tqdm
-from sklearn.metrics import roc_auc_score
 
 
 class Estimator:
@@ -155,8 +154,7 @@ def score(self, dataloader):
             mean_loss = loss.mean().item()
             predictions = torch.concat(predictions)
             targets = torch.concat(targets)
-            auc = roc_auc_score(targets.cpu(), predictions.cpu())
-            # auc = compute_auc(predictions, targets)
+            auc = compute_auc(targets.cpu(), predictions.cpu())
             scores = dict()
             if self.metric:
                 if self.metric["name"] == "auc":
@@ -330,26 +328,44 @@ def batch_to_device(batch, device='cpu'):
     return batch
 
 
-def compute_auc(input, target, n_threshold=1000):
-    threshold = torch.linspace(0, 1.0, n_threshold).to(device=input.device)
-    pred_label = input >= threshold[:, None, None]
-    input_target = pred_label * target
-
-    cum_tp = F.pad(input_target.sum(dim=-1).rot90(1, [1, 0]), (1, 0), value=0.0)
-    cum_fp = F.pad(
-        (pred_label.sum(dim=-1) - input_target.sum(dim=-1)).rot90(1, [1, 0]),
-        (1, 0),
-        value=0.0,
-    )
-
-    if len(cum_tp.shape) > 1:
-        factor = cum_tp[:, -1] * cum_fp[:, -1]
-    else:
-        factor = cum_tp[-1] * cum_fp[-1]
-    # Set AUROC to 0.5 when the target contains all ones or all zeros.
-    auroc = torch.where(
-        factor == 0,
-        0.5,
-        torch.trapz(cum_tp, cum_fp).double() / factor,
-    )
-    return auroc.item()
+def compute_auc(y_true, y_pred):
+    """
+    Compute the area under the ROC curve for binary labels.
+
+    Uses the rank-sum (Mann-Whitney U) formulation, so tied scores are
+    counted as half a concordant pair and the result does not depend on
+    how the sort happens to order equal predictions.
+
+    Args:
+        y_true (torch.Tensor): True binary labels (0 or 1).
+        y_pred (torch.Tensor): Predicted scores, same number of elements
+            as y_true.
+    Returns:
+        torch.Tensor: 0-dim float64 tensor holding the AUC; 0.5 when only
+            one class is present in y_true.
+    """
+    # Accumulate in float64: float32 sums lose exactness past 2**24.
+    labels = y_true.reshape(-1).to(torch.float64)
+    scores = y_pred.reshape(-1)
+
+    n_pos = labels.sum()
+    n_neg = labels.numel() - n_pos
+    if n_pos == 0 or n_neg == 0:
+        # AUC is undefined with a single class; report chance level, as
+        # the previous implementation did, instead of dividing by zero.
+        return torch.tensor(0.5, dtype=torch.float64, device=labels.device)
+
+    # Ascending 1-based ranks, with each group of equal scores sharing the
+    # mean of the ranks it spans.
+    sorted_scores, order = torch.sort(scores)
+    _, group, group_size = torch.unique_consecutive(
+        sorted_scores, return_inverse=True, return_counts=True
+    )
+    group_size = group_size.to(torch.float64)
+    mid_rank = torch.cumsum(group_size, 0) - (group_size - 1) / 2
+
+    # U = (rank sum of positives) - n_pos * (n_pos + 1) / 2 counts the
+    # (positive, negative) pairs ranked correctly, ties as one half.
+    pos_rank_sum = (mid_rank[group] * labels[order]).sum()
+    u_stat = pos_rank_sum - n_pos * (n_pos + 1) / 2
+    return u_stat / (n_pos * n_neg)