Skip to content

Commit

Permalink
Format code
Browse files Browse the repository at this point in the history
  • Loading branch information
tmke8 committed Feb 24, 2024
1 parent 33774ca commit 3ae31f0
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/algs/fs/dro.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
reduction = str_to_enum(str_=reduction, enum=ReductionType)
self.reduction = reduction
if loss_fn is None:
cross_entropy: Loss = CrossEntropyLoss(reduction=ReductionType.none) # type: ignore
cross_entropy: Loss = CrossEntropyLoss(reduction=ReductionType.none) # type: ignore
loss_fn = cross_entropy
else:
loss_fn.reduction = ReductionType.none
Expand All @@ -38,7 +38,7 @@ def __init__(
self.eta = eta

@override
def forward(self, input: Tensor, *, target: Tensor) -> Tensor: # type: ignore
def forward(self, input: Tensor, *, target: Tensor) -> Tensor:
sample_losses = (self.loss_fn(input, target=target) - self.eta).relu().pow(2)
return reduce(sample_losses, reduction_type=self.reduction)

Expand Down

0 comments on commit 3ae31f0

Please sign in to comment.