Skip to content

Commit 549bc7e

Browse files
committed
update names
1 parent 7732364 commit 549bc7e

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

pina/loss/ntk_weighting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(self, update_every_n_epochs=1, alpha=0.5):
3939

4040
# Initialize parameters
4141
self.alpha = alpha
42-
self.history = {}
42+
self.weights = {}
4343

4444
def weights_update(self, losses):
4545
"""
@@ -62,7 +62,7 @@ def weights_update(self, losses):
6262

6363
# Update the weights
6464
return {
65-
condition: self.alpha * self.history.get(condition, 1)
65+
condition: self.alpha * self.weights.get(condition, 1)
6666
+ (1 - self.alpha)
6767
* losses_norm[condition]
6868
/ sum(losses_norm.values())

pina/loss/weighting_interface.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(self, update_every_n_epochs=1, aggregator="sum"):
4848
self._solver = None
4949
self.update_every_n_epochs = update_every_n_epochs
5050
self.aggregator_fn = aggregator
51-
self.weights = {}
51+
self._saved_weights = {}
5252

5353
@abstractmethod
5454
def weights_update(self, losses):
@@ -82,11 +82,13 @@ def aggregate(self, losses):
8282
"""
8383
# Update weights
8484
if self.solver.trainer.current_epoch % self.update_every_n_epochs == 0:
85-
self.weights = self.weights_update(losses)
85+
self._saved_weights = self.weights_update(losses)
8686

87-
# Aggregate
87+
# Aggregate (without the .get() we force error
88+
# if condition is not present in the returned dict)
8889
return self.aggregator_fn(
89-
self.weights[condition] * loss for condition, loss in losses.items()
90+
self._saved_weights[condition] * loss
91+
for condition, loss in losses.items()
9092
)
9193

9294
@property

0 commit comments

Comments
 (0)