1
1
"""Module for the Weighting Interface."""
2
2
3
3
from abc import ABCMeta , abstractmethod
4
+ from typing import final
5
+ from ..utils import check_positive_integer , is_function
6
+
7
+ _AGGREGATE_METHODS = {"sum" : sum , "mean" : lambda x : sum (x ) / len (x )}
4
8
5
9
6
10
class WeightingInterface (metaclass = ABCMeta ):
@@ -9,19 +13,86 @@ class WeightingInterface(metaclass=ABCMeta):
9
13
should inherit from this class.
10
14
"""
11
15
12
- def __init__ (self ):
16
+ def __init__ (
17
+ self ,
18
+ update_every_n_epochs = 1 ,
19
+ aggregator = "sum" ,
20
+ ):
13
21
"""
14
22
Initialization of the :class:`WeightingInterface` class.
23
+
24
+ :param int update_every_n_epochs: The number of training epochs between
25
+ weight updates. If set to 1, the weights are updated at every epoch.
26
+ This parameter is ignored by static weighting schemes. Default is 1.
27
+ :param aggregator: Aggregation method. Either:
28
+ - 'sum' → torch.sum
29
+ - 'mean' → torch.mean
30
+ - callable → custom aggregation function
15
31
"""
32
+ # Check consistency
33
+ check_positive_integer (value = update_every_n_epochs , strict = True )
34
+
35
+ if isinstance (aggregator , str ):
36
+ if aggregator not in _AGGREGATE_METHODS :
37
+ raise ValueError (
38
+ f"Invalid aggregator '{ aggregator } '. Must be one of "
39
+ f"{ list (_AGGREGATE_METHODS .keys ())} ."
40
+ )
41
+ aggregator = _AGGREGATE_METHODS [aggregator ]
42
+
43
+ elif not is_function (aggregator ):
44
+ raise TypeError (
45
+ f"Aggregator must be either a string or a callable, "
46
+ f"got { type (aggregator ).__name__ } ."
47
+ )
48
+
49
+ # Initialization
16
50
self ._solver = None
51
+ self .update_every_n_epochs = update_every_n_epochs
52
+ self .aggregator_fn = aggregator
53
+ self ._weights = {}
17
54
18
55
    @abstractmethod
    def weights_update(self, losses):
        """
        Update the weighting scheme based on the given losses.

        This method must be implemented by subclasses. Its role is to compute
        a mapping from loss names to their corresponding weights and return
        it; :meth:`aggregate` stores the returned mapping in ``self._weights``
        and uses it to compute the final aggregated loss. Every key of
        ``losses`` should appear in the returned mapping, since
        :meth:`aggregate` looks up a weight for each loss by name.

        :param losses: Dictionary mapping loss condition names to loss tensors.
        :type losses: dict[str, torch.Tensor]
        :return: Dictionary mapping loss names to their updated weight values.
        :rtype: dict[str, float]
        """
70
+
71
+ @final
19
72
def aggregate (self , losses ):
20
73
"""
21
- Aggregate the losses.
74
+ Update the weights (if required) and aggregate the given losses.
75
+
76
+ This method first checks whether the weights should be updated based on
77
+ the current epoch and the ``update_every_n_epochs`` setting. If so, it
78
+ calls :meth:`weights_update` to refresh the weights. Afterwards, it
79
+ aggregates the (weighted) losses into a single scalar tensor using the
80
+ configured aggregator function. This method must not be override.
22
81
23
- :param dict losses: The dictionary of losses.
82
+ :param losses: Dictionary mapping loss names to loss tensors.
83
+ :type losses: dict[str, torch.Tensor]
84
+ :return: The aggregated loss tensor.
85
+ :rtype: torch.Tensor
24
86
"""
87
+ # update weights
88
+ if self .solver .trainer .current_epoch % self .update_every_n_epochs == 0 :
89
+ self ._weights = self .weights_update (losses )
90
+
91
+ # aggregate
92
+ return self .aggregator_fn (
93
+ self ._weights [condition ] * loss
94
+ for condition , loss in losses .items ()
95
+ )
25
96
26
97
@property
27
98
def solver (self ):
0 commit comments