diff --git a/keras/src/losses/loss.py b/keras/src/losses/loss.py index 227c43b2128..e0197a6bb8a 100644 --- a/keras/src/losses/loss.py +++ b/keras/src/losses/loss.py @@ -123,7 +123,7 @@ def squeeze_or_expand_to_same_rank(x1, x2, expand_rank_1=True): return x1, x2 -def reduce_values(values, reduction="sum_over_batch_size"): +def reduce_values(values, reduction="sum_over_batch_size", sample_weight=None): if ( reduction is None or reduction == "none" @@ -132,7 +132,12 @@ def reduce_values(values, reduction="sum_over_batch_size"): ): return values loss = ops.sum(values) - if reduction == "sum_over_batch_size": + if sample_weight is not None and reduction == "sum_over_batch_size": + loss /= ops.cast( + ops.sum(sample_weight), + loss.dtype, + ) + elif reduction == "sum_over_batch_size": loss /= ops.cast( ops.prod(ops.convert_to_tensor(ops.shape(values), dtype="int32")), loss.dtype, @@ -169,7 +174,7 @@ def reduce_weighted_values( values = values * sample_weight # Apply reduction function to the individual weighted losses. - loss = reduce_values(values, reduction) + loss = reduce_values(values, reduction, sample_weight) return loss