From 3ae31f02a1da441a28ee54478a1f12f94dc17509 Mon Sep 17 00:00:00 2001 From: Thomas M Kehrenberg Date: Sat, 24 Feb 2024 16:11:41 +0100 Subject: [PATCH] Format code --- src/algs/fs/dro.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/algs/fs/dro.py b/src/algs/fs/dro.py index 35d2ae2a..05bcd4da 100644 --- a/src/algs/fs/dro.py +++ b/src/algs/fs/dro.py @@ -29,7 +29,7 @@ def __init__( reduction = str_to_enum(str_=reduction, enum=ReductionType) self.reduction = reduction if loss_fn is None: - cross_entropy: Loss = CrossEntropyLoss(reduction=ReductionType.none) # type: ignore + cross_entropy: Loss = CrossEntropyLoss(reduction=ReductionType.none) # type: ignore loss_fn = cross_entropy else: loss_fn.reduction = ReductionType.none @@ -38,7 +38,7 @@ def __init__( self.eta = eta @override - def forward(self, input: Tensor, *, target: Tensor) -> Tensor: # type: ignore + def forward(self, input: Tensor, *, target: Tensor) -> Tensor: sample_losses = (self.loss_fn(input, target=target) - self.eta).relu().pow(2) return reduce(sample_losses, reduction_type=self.reduction)