Skip to content

Commit efc9e32

Browse files
authored
Merge pull request #615 from mathLab/dev
Dev Updates
2 parents 3778ef7 + 419ac7f commit efc9e32

24 files changed

+1145
-352
lines changed

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,20 +90,21 @@ Do you want to learn more about it? Look at our [Tutorials](https://github.com/m
9090
### Solve Data Driven Problems
9191
Data driven modelling aims to learn a function that given some input data gives an output (e.g. regression, classification, ...). In PINA you can easily do this by:
9292
```python
93+
import torch
9394
from pina import Trainer
9495
from pina.model import FeedForward
9596
from pina.solver import SupervisedSolver
9697
from pina.problem.zoo import SupervisedProblem
9798

9899
input_tensor = torch.rand((10, 1))
99-
output_tensor = input_tensor.pow(3)
100+
target_tensor = input_tensor.pow(3)
100101

101102
# Step 1. Define problem
102103
problem = SupervisedProblem(input_tensor, target_tensor)
103104
# Step 2. Design model (you can use your favourite torch.nn.Module in here)
104105
model = FeedForward(input_dimensions=1, output_dimensions=1, layers=[64, 64])
105106
# Step 3. Define Solver
106-
solver = SupervisedSolver(problem, model)
107+
solver = SupervisedSolver(problem, model, use_lt=False)
107108
# Step 4. Train
108109
trainer = Trainer(solver, max_epochs=1000, accelerator='gpu')
109110
trainer.train()
@@ -149,6 +150,7 @@ class SimpleODE(SpatialProblem):
149150

150151
# Step 1. Define problem
151152
problem = SimpleODE()
153+
problem.discretise_domain(n=100, mode="grid", domains=["D", "x0"])
152154
# Step 2. Design model (you can use your favourite torch.nn.Module in here)
153155
model = FeedForward(input_dimensions=1, output_dimensions=1, layers=[64, 64])
154156
# Step 3. Define Solver

docs/source/_rst/_code.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ Models
104104
LowRankNeuralOperator <model/low_rank_neural_operator.rst>
105105
GraphNeuralOperator <model/graph_neural_operator.rst>
106106
GraphNeuralKernel <model/graph_neural_operator_integral_kernel.rst>
107+
PirateNet <model/pirate_network.rst>
107108

108109
Blocks
109110
-------------
@@ -121,6 +122,7 @@ Blocks
121122
Continuous Convolution Interface <model/block/convolution_interface.rst>
122123
Continuous Convolution Block <model/block/convolution.rst>
123124
Orthogonal Block <model/block/orthogonal.rst>
125+
PirateNet Block <model/block/pirate_network_block.rst>
124126

125127
Message Passing
126128
-------------------
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
PirateNet Block
2+
=======================================
3+
.. currentmodule:: pina.model.block.pirate_network_block
4+
5+
.. autoclass:: PirateNetBlock
6+
:members:
7+
:show-inheritance:
8+
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
PirateNet
2+
=======================
3+
.. currentmodule:: pina.model.pirate_network
4+
5+
.. autoclass:: PirateNet
6+
:members:
7+
:show-inheritance:

pina/callback/optimizer_callback.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,45 +21,52 @@ def __init__(self, new_optimizers, epoch_switch):
2121
single :class:`torch.optim.Optimizer` instance or a list of them
2222
for multiple model solver.
2323
:type new_optimizers: pina.optim.TorchOptimizer | list
24-
:param epoch_switch: The epoch at which the optimizer switch occurs.
25-
:type epoch_switch: int
24+
:param int epoch_switch: The epoch at which the optimizer switch occurs.
2625
2726
Example:
28-
>>> switch_callback = SwitchOptimizer(new_optimizers=optimizer,
29-
>>> epoch_switch=10)
27+
>>> optimizer = TorchOptimizer(torch.optim.Adam, lr=0.01)
28+
>>> switch_callback = SwitchOptimizer(
29+
>>> new_optimizers=optimizer, epoch_switch=10
30+
>>> )
3031
"""
3132
super().__init__()
3233

34+
# Check if epoch_switch is greater than 1
3335
if epoch_switch < 1:
3436
raise ValueError("epoch_switch must be greater than one.")
3537

38+
# If new_optimizers is not a list, convert it to a list
3639
if not isinstance(new_optimizers, list):
3740
new_optimizers = [new_optimizers]
3841

39-
# check type consistency
42+
# Check consistency
43+
check_consistency(epoch_switch, int)
4044
for optimizer in new_optimizers:
4145
check_consistency(optimizer, TorchOptimizer)
42-
check_consistency(epoch_switch, int)
43-
# save new optimizers
46+
47+
# Store the new optimizers and epoch switch
4448
self._new_optimizers = new_optimizers
4549
self._epoch_switch = epoch_switch
4650

4751
def on_train_epoch_start(self, trainer, __):
4852
"""
4953
Switch the optimizer at the start of the specified training epoch.
5054
51-
:param trainer: The trainer object managing the training process.
52-
:type trainer: pytorch_lightning.Trainer
55+
:param lightning.pytorch.Trainer trainer: The trainer object managing
56+
the training process.
5357
:param _: Placeholder argument (not used).
54-
55-
:return: None
56-
:rtype: None
5758
"""
59+
# Check if the current epoch matches the switch epoch
5860
if trainer.current_epoch == self._epoch_switch:
5961
optims = []
6062

63+
# Hook the new optimizers to the model parameters
6164
for idx, optim in enumerate(self._new_optimizers):
6265
optim.hook(trainer.solver._pina_models[idx].parameters())
6366
optims.append(optim)
6467

68+
# Update the solver's optimizers
6569
trainer.solver._pina_optimizers = optims
70+
71+
# Update the trainer's strategy optimizers
72+
trainer.strategy.optimizers = [o.instance for o in optims]

pina/equation/system_equation.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,51 @@
88

99
class SystemEquation(EquationInterface):
1010
"""
11-
Implementation of the System of Equations. Every ``equation`` passed to a
12-
:class:`~pina.condition.condition.Condition` object must be either a
13-
:class:`~pina.equation.equation.Equation` or a
14-
:class:`~pina.equation.system_equation.SystemEquation` instance.
11+
Implementation of the System of Equations, to be passed to a
12+
:class:`~pina.condition.condition.Condition` object.
13+
14+
Unlike the :class:`~pina.equation.equation.Equation` class, which represents
15+
a single equation, the :class:`SystemEquation` class allows multiple
16+
equations to be grouped together into a system. This is particularly useful
17+
when dealing with multi-component outputs or coupled physical models, where
18+
the residual must be computed collectively across several constraints.
19+
20+
Each equation in the system must be either:
21+
- An instance of :class:`~pina.equation.equation.Equation`;
22+
- A callable function.
23+
24+
The residuals from each equation are computed independently and then
25+
aggregated using an optional reduction strategy (e.g., ``mean``, ``sum``).
26+
The resulting residual is returned as a single :class:`~pina.LabelTensor`.
27+
28+
:Example:
29+
30+
>>> from pina.equation import SystemEquation, FixedValue, FixedGradient
31+
>>> from pina import LabelTensor
32+
>>> import torch
33+
>>> pts = LabelTensor(torch.rand(10, 2), labels=["x", "y"])
34+
>>> pts.requires_grad = True
35+
>>> output_ = torch.pow(pts, 2)
36+
>>> output_.labels = ["u", "v"]
37+
>>> system_equation = SystemEquation(
38+
... [
39+
... FixedValue(value=1.0, components=["u"]),
40+
... FixedGradient(value=0.0, components=["v"],d=["y"]),
41+
... ],
42+
... reduction="mean",
43+
... )
44+
>>> residual = system_equation.residual(pts, output_)
45+
1546
"""
1647

1748
def __init__(self, list_equation, reduction=None):
1849
"""
1950
Initialization of the :class:`SystemEquation` class.
2051
21-
:param Callable equation: A ``torch`` callable function used to compute
22-
the residual of a mathematical equation.
52+
:param list_equation: A list containing either callable functions or
53+
instances of :class:`~pina.equation.equation.Equation`, used to
54+
compute the residuals of mathematical equations.
55+
:type list_equation: list[Callable] | list[Equation]
2356
:param str reduction: The reduction method to aggregate the residuals of
2457
each equation. Available options are: ``None``, ``mean``, ``sum``,
2558
``callable``.
@@ -32,9 +65,10 @@ def __init__(self, list_equation, reduction=None):
3265
check_consistency([list_equation], list)
3366

3467
# equations definition
35-
self.equations = []
36-
for _, equation in enumerate(list_equation):
37-
self.equations.append(Equation(equation))
68+
self.equations = [
69+
equation if isinstance(equation, Equation) else Equation(equation)
70+
for equation in list_equation
71+
]
3872

3973
# possible reduction
4074
if reduction == "mean":
@@ -45,7 +79,7 @@ def __init__(self, list_equation, reduction=None):
4579
self.reduction = reduction
4680
else:
4781
raise NotImplementedError(
48-
"Only mean and sum reductions implemented."
82+
"Only mean and sum reductions are currenly supported."
4983
)
5084

5185
def residual(self, input_, output_, params_=None):

pina/model/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"LowRankNeuralOperator",
1414
"Spline",
1515
"GraphNeuralOperator",
16+
"PirateNet",
1617
]
1718

1819
from .feed_forward import FeedForward, ResidualFeedForward
@@ -24,3 +25,4 @@
2425
from .low_rank_neural_operator import LowRankNeuralOperator
2526
from .spline import Spline
2627
from .graph_neural_operator import GraphNeuralOperator
28+
from .pirate_network import PirateNet

pina/model/block/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"LowRankBlock",
1919
"RBFBlock",
2020
"GNOBlock",
21+
"PirateNetBlock",
2122
]
2223

2324
from .convolution_2d import ContinuousConvBlock
@@ -35,3 +36,4 @@
3536
from .low_rank_block import LowRankBlock
3637
from .rbf_block import RBFBlock
3738
from .gno_block import GNOBlock
39+
from .pirate_network_block import PirateNetBlock
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""Module for the PirateNet block class."""
2+
3+
import torch
4+
from ...utils import check_consistency, check_positive_integer
5+
6+
7+
class PirateNetBlock(torch.nn.Module):
8+
"""
9+
The inner block of Physics-Informed residual adaptive network (PirateNet).
10+
11+
The block consists of three dense layers with dual gating operations and an
12+
adaptive residual connection. The trainable ``alpha`` parameter controls
13+
the contribution of the residual connection.
14+
15+
.. seealso::
16+
17+
**Original reference**:
18+
Wang, S., Sankaran, S., Stinis., P., Perdikaris, P. (2025).
19+
*Simulating Three-dimensional Turbulence with Physics-informed Neural
20+
Networks*.
21+
DOI: `arXiv preprint arXiv:2507.08972.
22+
<https://arxiv.org/abs/2507.08972>`_
23+
"""
24+
25+
def __init__(self, inner_size, activation):
26+
"""
27+
Initialization of the :class:`PirateNetBlock` class.
28+
29+
:param int inner_size: The number of hidden units in the dense layers.
30+
:param torch.nn.Module activation: The activation function.
31+
"""
32+
super().__init__()
33+
34+
# Check consistency
35+
check_consistency(activation, torch.nn.Module, subclass=True)
36+
check_positive_integer(inner_size, strict=True)
37+
38+
# Initialize the linear transformations of the dense layers
39+
self.linear1 = torch.nn.Linear(inner_size, inner_size)
40+
self.linear2 = torch.nn.Linear(inner_size, inner_size)
41+
self.linear3 = torch.nn.Linear(inner_size, inner_size)
42+
43+
# Initialize the scales of the dense layers
44+
self.scale1 = torch.nn.Parameter(torch.zeros(inner_size))
45+
self.scale2 = torch.nn.Parameter(torch.zeros(inner_size))
46+
self.scale3 = torch.nn.Parameter(torch.zeros(inner_size))
47+
48+
# Initialize the adaptive residual connection parameter
49+
self._alpha = torch.nn.Parameter(torch.zeros(1))
50+
51+
# Initialize the activation function
52+
self.activation = activation()
53+
54+
def forward(self, x, U, V):
55+
"""
56+
Forward pass of the PirateNet block. It computes the output of the block
57+
by applying the dense layers with scaling, and combines the results with
58+
the input using the adaptive residual connection.
59+
60+
:param x: The input tensor.
61+
:type x: torch.Tensor | LabelTensor
62+
:param torch.Tensor U: The first shared gating tensor. It must have the
63+
same shape as ``x``.
64+
:param torch.Tensor V: The second shared gating tensor. It must have the
65+
same shape as ``x``.
66+
:return: The output tensor of the block.
67+
:rtype: torch.Tensor | LabelTensor
68+
"""
69+
# Compute the output of the first dense layer with scaling
70+
f = self.activation(self.linear1(x) * torch.exp(self.scale1))
71+
z1 = f * U + (1 - f) * V
72+
73+
# Compute the output of the second dense layer with scaling
74+
g = self.activation(self.linear2(z1) * torch.exp(self.scale2))
75+
z2 = g * U + (1 - g) * V
76+
77+
# Compute the output of the block
78+
h = self.activation(self.linear3(z2) * torch.exp(self.scale3))
79+
return self._alpha * h + (1 - self._alpha) * x
80+
81+
@property
82+
def alpha(self):
83+
"""
84+
Return the alpha parameter.
85+
86+
:return: The alpha parameter controlling the residual connection.
87+
:rtype: torch.nn.Parameter
88+
"""
89+
return self._alpha

0 commit comments

Comments
 (0)