
Commit ab22331

weighting refactor
Co-authored-by: Dario Coscia <[email protected]>
1 parent ca8f370 commit ab22331

File tree

11 files changed (+219, -138 lines)


pina/loss/ntk_weighting.py

Lines changed: 13 additions & 19 deletions
@@ -2,7 +2,7 @@

 import torch
 from .weighting_interface import WeightingInterface
-from ..utils import check_consistency
+from ..utils import check_consistency, in_range


 class NeuralTangentKernelWeighting(WeightingInterface):
@@ -20,32 +20,31 @@ class NeuralTangentKernelWeighting(WeightingInterface):

     """

-    def __init__(self, alpha=0.5):
+    def __init__(self, update_every_n_epochs=1, alpha=0.5):
         """
         Initialization of the :class:`NeuralTangentKernelWeighting` class.

+        :param int update_every_n_epochs: The number of training epochs between
+            weight updates. If set to 1, the weights are updated at every
+            epoch. Default is 1.
         :param float alpha: The alpha parameter.
         :raises ValueError: If ``alpha`` is not between 0 and 1 (inclusive).
         """
-        super().__init__()
+        super().__init__(update_every_n_epochs=update_every_n_epochs)

         # Check consistency
         check_consistency(alpha, float)
-        if alpha < 0 or alpha > 1:
-            raise ValueError("alpha should be a value between 0 and 1")
-
+        if not in_range(alpha, [0, 1], strict=False):
+            raise ValueError("alpha must be in range [0, 1].")
         # Initialize parameters
         self.alpha = alpha
-        self.weights = {}
-        self.default_value_weights = 1.0
+        self.history = {}

-    def aggregate(self, losses):
+    def weights_update(self, losses):
         """
-        Weight the losses according to the Neural Tangent Kernel algorithm.
+        Update the weights of the losses.

         :param dict(torch.Tensor) losses: The dictionary of losses.
-        :return: The aggregation of the losses. It should be a scalar Tensor.
-        :rtype: torch.Tensor
         """
         # Define a dictionary to store the norms of the gradients
         losses_norm = {}
@@ -59,15 +58,10 @@ def aggregate(self, losses):
             losses_norm[condition] = grads.norm()

         # Update the weights
-        self.weights = {
-            condition: self.alpha
-            * self.weights.get(condition, self.default_value_weights)
+        return {
+            condition: self.alpha * self.history.get(condition, 1)
             + (1 - self.alpha)
             * losses_norm[condition]
             / sum(losses_norm.values())
             for condition in losses
         }
-
-        return sum(
-            self.weights[condition] * loss for condition, loss in losses.items()
-        )
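
For readers following the refactor: the returned dictionary implements an exponential moving average over normalized gradient norms, w_c = alpha * w_prev + (1 - alpha) * ||g_c|| / sum_k ||g_k||, with the previous weight defaulting to 1 on the first update. Below is a minimal standalone sketch of the same arithmetic; the toy model, inputs, and condition names are invented for illustration, and torch.autograd.grad stands in for the committed backward()/p.grad pattern so the snippet is self-contained:

import torch

# Toy setup: one linear model, two loss conditions (illustrative only).
model = torch.nn.Linear(2, 1)
x = torch.randn(8, 2)
losses = {
    "physics": model(x).pow(2).mean(),
    "data": (model(x) - 1).pow(2).mean(),
}

alpha, history = 0.5, {}  # empty history, so .get(c, 1) falls back to 1

# Gradient norm of each loss w.r.t. all model parameters.
losses_norm = {}
for condition, loss in losses.items():
    grads = torch.autograd.grad(loss, model.parameters(), retain_graph=True)
    losses_norm[condition] = torch.cat([g.flatten() for g in grads]).norm()

# EMA update: w_c = alpha * w_prev + (1 - alpha) * ||g_c|| / sum_k ||g_k||
total = sum(losses_norm.values())
weights = {
    c: alpha * history.get(c, 1) + (1 - alpha) * losses_norm[c] / total
    for c in losses
}
print(weights)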

pina/loss/scalar_weighting.py

Lines changed: 32 additions & 32 deletions
@@ -4,22 +4,6 @@
 from ..utils import check_consistency


-class _NoWeighting(WeightingInterface):
-    """
-    Weighting scheme that does not apply any weighting to the losses.
-    """
-
-    def aggregate(self, losses):
-        """
-        Aggregate the losses.
-
-        :param dict losses: The dictionary of losses.
-        :return: The aggregated losses.
-        :rtype: torch.Tensor
-        """
-        return sum(losses.values())
-
-
 class ScalarWeighting(WeightingInterface):
     """
     Weighting scheme that assigns a scalar weight to each loss term.
@@ -36,28 +20,44 @@ def __init__(self, weights):
             dictionary, the default value is used.
         :type weights: float | int | dict
         """
-        super().__init__()

-        # Check consistency
+        super().__init__(update_every_n_epochs=1, aggregator="sum")
+
         check_consistency([weights], (float, dict, int))

-        # Weights initialization
-        if isinstance(weights, (float, int)):
+        if isinstance(weights, dict):
+            self.values = weights
+            self.default_value_weights = 1
+        elif isinstance(weights, (float, int)):
+            self.values = {}
             self.default_value_weights = weights
-            self.weights = {}
         else:
-            self.default_value_weights = 1.0
-            self.weights = weights
+            raise ValueError("weights must be a float, an int, or a dict.")

-    def aggregate(self, losses):
+    def weights_update(self, losses):
         """
-        Aggregate the losses.
+        Update the weighting scheme based on the given losses.
+
+        This method computes and returns the mapping from loss names to their
+        corresponding weights. The returned weights are then used by
+        :meth:`aggregate` to compute the final aggregated loss.

-        :param dict losses: The dictionary of losses.
-        :return: The aggregated losses.
-        :rtype: torch.Tensor
+        :param losses: Dictionary mapping loss condition names to loss tensors.
+        :type losses: dict[str, torch.Tensor]
+        :return: Dictionary mapping loss names to their updated weight values.
+        :rtype: dict[str, float]
         """
-        return sum(
-            self.weights.get(condition, self.default_value_weights) * loss
-            for condition, loss in losses.items()
-        )
+        return {
+            condition: self.values.get(condition, self.default_value_weights)
+            for condition in losses.keys()
+        }
+
+
+class _NoWeighting(ScalarWeighting):
+    """
+    Weighting scheme that does not apply any weighting to the losses.
+    """
+
+    def __init__(self):
+        super().__init__(weights=1)
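
With this change ScalarWeighting becomes a purely static scheme: weights_update just re-reads the fixed values on every call. A short usage sketch follows; the import path mirrors the file location above and is an assumption, as are the condition names:

import torch
from pina.loss.scalar_weighting import ScalarWeighting

# Per-condition weights; conditions absent from the dict fall back to the
# default value of 1.
weighting = ScalarWeighting(weights={"physics": 10.0, "data": 1.0})

dummy_losses = {"physics": torch.tensor(0.3), "data": torch.tensor(0.1)}
print(weighting.weights_update(dummy_losses))
# -> {'physics': 10.0, 'data': 1.0}

# A single scalar applies the same weight to every condition.
uniform = ScalarWeighting(weights=0.5)
print(uniform.weights_update(dummy_losses))
# -> {'physics': 0.5, 'data': 0.5}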

pina/loss/self_adaptive_weighting.py

Lines changed: 24 additions & 49 deletions
@@ -2,7 +2,6 @@

 import torch
 from .weighting_interface import WeightingInterface
-from ..utils import check_positive_integer


 class SelfAdaptiveWeighting(WeightingInterface):
@@ -22,59 +21,35 @@ class SelfAdaptiveWeighting(WeightingInterface):

     """

-    def __init__(self, k=100):
+    def __init__(self, update_every_n_epochs=1):
         """
         Initialization of the :class:`SelfAdaptiveWeighting` class.

-        :param int k: The number of epochs after which the weights are updated.
-            Default is 100.
-
-        :raises ValueError: If ``k`` is not a positive integer.
+        :param int update_every_n_epochs: The number of training epochs between
+            weight updates. If set to 1, the weights are updated at every
+            epoch. Default is 1.
         """
-        super().__init__()
-
-        # Check consistency
-        check_positive_integer(value=k, strict=True)
+        super().__init__(update_every_n_epochs=update_every_n_epochs)

-        # Initialize parameters
-        self.k = k
-        self.weights = {}
-        self.default_value_weights = 1.0
-
-    def aggregate(self, losses):
+    def weights_update(self, losses):
         """
-        Weight the losses according to the self-adaptive algorithm.
+        Update the weights of the losses.

-        :param dict(torch.Tensor) losses: The dictionary of losses.
-        :return: The aggregation of the losses. It should be a scalar Tensor.
-        :rtype: torch.Tensor
+        :param dict(torch.Tensor) losses: The dictionary of losses.
         """
-        # If weights have not been initialized, set them to 1
-        if not self.weights:
-            self.weights = {
-                condition: self.default_value_weights for condition in losses
-            }
-
-        # Update every k epochs
-        if self.solver.trainer.current_epoch % self.k == 0:
-
-            # Define a dictionary to store the norms of the gradients
-            losses_norm = {}
-
-            # Compute the gradient norms for each loss component
-            for condition, loss in losses.items():
-                loss.backward(retain_graph=True)
-                grads = torch.cat(
-                    [p.grad.flatten() for p in self.solver.model.parameters()]
-                )
-                losses_norm[condition] = grads.norm()
-
-            # Update the weights
-            self.weights = {
-                condition: sum(losses_norm.values()) / losses_norm[condition]
-                for condition in losses
-            }
-
-        return sum(
-            self.weights[condition] * loss for condition, loss in losses.items()
-        )
+        # Define a dictionary to store the norms of the gradients
+        losses_norm = {}
+
+        # Compute the gradient norms for each loss component
+        for condition, loss in losses.items():
+            loss.backward(retain_graph=True)
+            grads = torch.cat(
+                [p.grad.flatten() for p in self.solver.model.parameters()]
+            )
+            losses_norm[condition] = grads.norm()
+
+        # Update the weights
+        return {
+            condition: sum(losses_norm.values()) / losses_norm[condition]
+            for condition in losses
+        }
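
The returned weights follow the inverse gradient-norm rule w_c = (sum_k ||g_k||) / ||g_c||, so conditions whose gradients are small get proportionally larger weights. A worked toy example of just that arithmetic, with hypothetical norm values:

# Hypothetical gradient norms for three conditions.
losses_norm = {"pde": 4.0, "boundary": 1.0, "initial": 0.5}

total = sum(losses_norm.values())  # 5.5
weights = {c: total / n for c, n in losses_norm.items()}
print(weights)
# -> {'pde': 1.375, 'boundary': 5.5, 'initial': 11.0}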

pina/loss/weighting_interface.py

Lines changed: 74 additions & 3 deletions
@@ -1,6 +1,10 @@
 """Module for the Weighting Interface."""

 from abc import ABCMeta, abstractmethod
+from typing import final
+from ..utils import check_positive_integer, is_function
+
+_AGGREGATE_METHODS = {"sum": sum, "mean": lambda x: sum(x) / len(x)}


 class WeightingInterface(metaclass=ABCMeta):
@@ -9,19 +13,86 @@ class WeightingInterface(metaclass=ABCMeta):
     should inherit from this class.
     """

-    def __init__(self):
+    def __init__(
+        self,
+        update_every_n_epochs=1,
+        aggregator="sum",
+    ):
         """
         Initialization of the :class:`WeightingInterface` class.
+
+        :param int update_every_n_epochs: The number of training epochs between
+            weight updates. If set to 1, the weights are updated at every
+            epoch. This parameter is ignored by static weighting schemes.
+            Default is 1.
+        :param aggregator: The aggregation method. Either ``'sum'`` (built-in
+            ``sum``), ``'mean'`` (arithmetic mean), or a callable implementing
+            a custom aggregation function.
         """
+        # Check consistency
+        check_positive_integer(value=update_every_n_epochs, strict=True)
+
+        if isinstance(aggregator, str):
+            if aggregator not in _AGGREGATE_METHODS:
+                raise ValueError(
+                    f"Invalid aggregator '{aggregator}'. Must be one of "
+                    f"{list(_AGGREGATE_METHODS.keys())}."
+                )
+            aggregator = _AGGREGATE_METHODS[aggregator]
+
+        elif not is_function(aggregator):
+            raise TypeError(
+                f"Aggregator must be either a string or a callable, "
+                f"got {type(aggregator).__name__}."
+            )
+
+        # Initialization
         self._solver = None
+        self.update_every_n_epochs = update_every_n_epochs
+        self.aggregator_fn = aggregator
+        self._weights = {}

     @abstractmethod
+    def weights_update(self, losses):
+        """
+        Update the weighting scheme based on the given losses.
+
+        This method must be implemented by subclasses. Its role is to compute
+        the mapping from loss names to their corresponding weights. The
+        returned weights are then used by :meth:`aggregate` to compute the
+        final aggregated loss.
+
+        :param losses: Dictionary mapping loss condition names to loss tensors.
+        :type losses: dict[str, torch.Tensor]
+        :return: Dictionary mapping loss names to their updated weight values.
+        :rtype: dict[str, float]
+        """
+
+    @final
     def aggregate(self, losses):
         """
-        Aggregate the losses.
+        Update the weights (if required) and aggregate the given losses.
+
+        This method first checks whether the weights should be updated based on
+        the current epoch and the ``update_every_n_epochs`` setting. If so, it
+        calls :meth:`weights_update` to refresh the weights. Afterwards, it
+        aggregates the weighted losses into a single scalar tensor using the
+        configured aggregator function. This method must not be overridden.

-        :param dict losses: The dictionary of losses.
+        :param losses: Dictionary mapping loss names to loss tensors.
+        :type losses: dict[str, torch.Tensor]
+        :return: The aggregated loss tensor.
+        :rtype: torch.Tensor
         """
+        # Update the weights
+        if self.solver.trainer.current_epoch % self.update_every_n_epochs == 0:
+            self._weights = self.weights_update(losses)

+        # Aggregate; a list (not a generator) so 'mean' can compute len()
+        return self.aggregator_fn(
+            [self._weights[condition] * loss
+             for condition, loss in losses.items()]
+        )

     @property
     def solver(self):
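
The net effect of the refactor is that a custom scheme now only needs a weights_update implementation; the final aggregate method handles the epoch gating and aggregation uniformly. A hedged sketch of such a subclass (the class name and its inverse-loss heuristic are invented for illustration, not part of PINA):

from pina.loss.weighting_interface import WeightingInterface


class InverseLossWeighting(WeightingInterface):
    """Toy scheme: weight each condition by the inverse of its loss value."""

    def __init__(self, update_every_n_epochs=10):
        # Refresh weights every 10 epochs and average the weighted losses.
        super().__init__(
            update_every_n_epochs=update_every_n_epochs, aggregator="mean"
        )

    def weights_update(self, losses):
        # Larger losses get smaller weights; detach keeps the weights
        # out of the autograd graph.
        return {
            condition: 1.0 / (loss.detach().abs() + 1e-8)
            for condition, loss in losses.items()
        }

Because aggregate consults self.solver.trainer.current_epoch, a scheme like this is only exercised inside a training loop; in isolation you would call weights_update directly.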

pina/solver/garom.py

Lines changed: 2 additions & 2 deletions
@@ -274,7 +274,7 @@ def validation_step(self, batch):
             condition_loss[condition_name] = self._loss_fn(
                 snapshots, snapshots_gen
             )
-        loss = self.weighting.aggregate(condition_loss)
+        loss = self.weighting.update_and_aggregate(condition_loss)
         self.store_log("val_loss", loss, self.get_batch_size(batch))
         return loss

@@ -297,7 +297,7 @@ def test_step(self, batch):
             condition_loss[condition_name] = self._loss_fn(
                 snapshots, snapshots_gen
             )
-        loss = self.weighting.aggregate(condition_loss)
+        loss = self.weighting.update_and_aggregate(condition_loss)
         self.store_log("test_loss", loss, self.get_batch_size(batch))
         return loss

pina/solver/physics_informed_solver/rba_pinn.py

Lines changed: 3 additions & 1 deletion
@@ -266,7 +266,9 @@ def _optimization_cycle(self, batch, batch_idx, **kwargs):
         self._clamp_params()

         # aggregate
-        loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
+        loss = self.weighting.update_and_aggregate(losses).as_subclass(
+            torch.Tensor
+        )

         return loss

pina/solver/physics_informed_solver/self_adaptive_pinn.py

Lines changed: 3 additions & 1 deletion
@@ -367,7 +367,9 @@ def _optimization_cycle(self, batch, batch_idx, **kwargs):
         self._clamp_params()

         # Aggregate
-        loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
+        loss = self.weighting.update_and_aggregate(losses).as_subclass(
+            torch.Tensor
+        )

         return loss
