I'm in the process of moving a model from PyTorch to Flux, and I'm going to catalog the challenges I've found in migrating, perhaps for a doc page or to improve docs generally. If others wish to add anything, please do!
Weight initialization:
In PyTorch, the default weight init method is kaiming_uniform aka He initialization.
In Flux, the default weight init method is glorot_uniform aka Xavier initialization.
PyTorch chooses a gain for the init function based on the type of nonlinearity specified, which defaults to leaky_relu, and uses this:
```julia
a = √5  # argument passed into PyTorch's `kaiming_uniform_` function, the "negative slope" of the rectifier
gain = √(2 / (1 + a^2))
```
Flux defaults the kaiming_uniform gain to √2, which is what PyTorch would use if relu was specified rather than leaky_relu.
To replicate the PyTorch default, kaiming_uniform(gain = √(2 / (1 + a ^ 2))) can be provided for the init keyword argument of the layer constructor.
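For example, a minimal sketch of matching PyTorch's default gain (the layer sizes are arbitrary, chosen only for illustration):

```julia
using Flux

# PyTorch's default: kaiming_uniform_ with a = √5, i.e. gain = √(2 / (1 + 5)) ≈ 0.577
a = √5
pytorch_gain = √(2 / (1 + a^2))

# Pass the matching initializer via the `init` keyword (the same works for Conv, etc.)
layer = Dense(128 => 64; init = Flux.kaiming_uniform(gain = pytorch_gain))
```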
Bias initialization:
PyTorch initializes bias parameters with uniformly random values between ±1/√(fan_in), where fan_in is prod(filter) * cin ÷ groups for a Conv layer and in for a Dense(in => out) layer (i.e. first(nfan(...)) in Flux's terms). Flux initializes all biases to zero.
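A hedged sketch of replicating the PyTorch bias init for a Dense layer (the pytorch_bias helper below is ours, not a Flux API; sizes are arbitrary):

```julia
using Flux

# Uniform values in ±1/√(fan_in); for Dense(in => out), fan_in is `in`
pytorch_bias(fan_in, out) = (rand(Float32, out) .- 0.5f0) .* (2f0 / sqrt(Float32(fan_in)))

layer = Dense(128 => 64; bias = pytorch_bias(128, 64))
```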
Layers:
In PyTorch, there are separate layer types for different dimensionality (e.g. Conv1d, Conv2d, Conv3d). In Flux, the dimensionality is determined by the length of the kernel tuple passed to Conv.
In PyTorch, activation functions are inserted as separate steps in the chain, as peers of the layers. In Flux, they are passed as an argument to the layer constructor.
Sequential => Chain
Linear => Dense
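For example, a small model showing the points above side by side (layer sizes are arbitrary, for illustration only):

```julia
using Flux

# PyTorch: nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), ...)
model = Chain(                       # Sequential => Chain
    Conv((3, 3), 3 => 16, relu),     # kernel tuple sets dimensionality; activation lives inside the layer
    GlobalMeanPool(),                # average over the spatial dimensions
    Flux.flatten,
    Dense(16 => 10),                 # Linear => Dense
)

model(rand(Float32, 32, 32, 3, 1))   # Flux expects WHCN layout (PyTorch uses NCHW)
```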
Upsample in Flux (via NNlib) is equivalent to align_corners=True with PyTorch's Upsample, but the default there is False. Note that align_corners=True makes the gradients depend on the image size.
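A small sketch of the bilinear case (sizes arbitrary):

```julia
using Flux

# Behaves like PyTorch's nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
up = Upsample(:bilinear, scale = 2)
x = rand(Float32, 8, 8, 3, 1)   # WHCN layout
size(up(x))                     # (16, 16, 3, 1)
```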
When building a custom layer, defining a call method (::MyLayer)(input) = ... is the equivalent of PyTorch's def forward(self, input):
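A minimal sketch of a custom layer (the MyScale layer below is invented purely for illustration):

```julia
using Flux

struct MyScale
    w::Vector{Float32}
end
Flux.@layer MyScale                 # marks the fields as trainable (older Flux versions use Flux.@functor)

(m::MyScale)(x) = m.w .* x          # this call method plays the role of PyTorch's forward()

MyScale(rand(Float32, 3))([1f0, 2f0, 3f0])
```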
Often, when PyTorch has a utility function, there is an equivalent on the Julia side (many via MLUtils.jl), e.g. unsqueeze to insert a length-1 dimension, or erf to compute the error function. Note that the inverse of unsqueeze is actually Base.dropdims.
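For instance, a tiny sketch using MLUtils.jl:

```julia
using MLUtils

x = rand(Float32, 3, 4)
y = unsqueeze(x; dims = 2)   # size (3, 1, 4); the analogue of torch.unsqueeze
dropdims(y; dims = 2) == x   # true: Base.dropdims removes the singleton dimension
```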
Training:
A single step in Flux is simply gradient followed by update!. In PyTorch there are more steps: the optimizer's gradients must be zeroed with .zero_grad(), then the loss is calculated, then the tensor returned from the loss function is backpropagated with .backward() to compute the gradients, and finally the optimizer is stepped forward and the model parameters are updated with .step(). In Flux, both parts can be combined with the train! function, which can also iterate over a set of paired training inputs and outputs.
In Flux, an optimizer state object is first obtained by setup, and this state is passed to the training loop. In PyTorch, the optimizer object itself is manipulated in the training loop.
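A minimal sketch of both styles, assuming a toy model, a hand-written squared-error loss, and randomly generated (input, target) pairs:

```julia
using Flux

model = Dense(10 => 1)
data = [(randn(Float32, 10), randn(Float32, 1)) for _ in 1:4]   # toy (input, target) pairs
loss(ŷ, y) = sum(abs2, ŷ .- y)

# The optimizer state lives in a separate object obtained from `setup`
opt_state = Flux.setup(Adam(1e-3), model)

# Explicit loop: gradient followed by update!
for (x, y) in data
    grads = Flux.gradient(m -> loss(m(x), y), model)
    Flux.update!(opt_state, model, grads[1])
end

# Or equivalently, letting train! drive the iteration
Flux.train!((m, x, y) -> loss(m(x), y), model, data, opt_state)
```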
Added 10/13/24:
Upsample in PyTorch is actually deprecated in favor of nn.functional.interpolate, but the former just relies on the latter anyway.
clip_grad_norm_ in PyTorch (side note: they've adopted a trailing underscore to indicate a modifying function) can be accomplished by creating an Optimisers.OptimiserChain(ClipNorm(___), optimizer).
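A sketch of this (the threshold 1.0 is arbitrary; note that ClipNorm clips each parameter array's gradient norm, whereas clip_grad_norm_ clips the total norm over all parameters):

```julia
using Flux, Optimisers

model = Dense(10 => 1)
opt = Optimisers.OptimiserChain(Optimisers.ClipNorm(1.0), Optimisers.Adam(1e-3))
opt_state = Flux.setup(opt, model)   # then use opt_state in the training loop as usual
```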
Added 10/25/24:
torch.where, which selects elements from one of two arrays depending on a mask, can be accomplished with a broadcasted ifelse.
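For example:

```julia
mask = [true, false, true]
a = [1, 2, 3]
b = [10, 20, 30]
ifelse.(mask, a, b)   # [1, 20, 3], like torch.where(mask, a, b)
```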
The AdamW optimizer is implemented differently. In PyTorch, the weight decay is scaled by the learning rate. In Flux, it is not. See FluxML/Optimisers.jl#182 for a workaround until Flux makes something built-in available.
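One possible approximation (a hedged sketch only; the linked issue is the authoritative reference) is to scale the decay by the learning rate yourself, assuming Flux's AdamW applies WeightDecay after Adam without the learning-rate factor:

```julia
using Optimisers

η, λ = 1e-3, 1e-2   # intended to mimic PyTorch's AdamW(lr = η, weight_decay = λ)
opt = Optimisers.OptimiserChain(Optimisers.Adam(η), Optimisers.WeightDecay(η * λ))
```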
@BioTurboNick reported on Slack that with the default Flux initialization, his model would get stuck in an all-zeros state, but not with the PyTorch init.
Regarding PyTorch's initializations, there is a long discussion in pytorch/pytorch#18182 on how to update their default initialization to current best practice.
There seems to be consensus that zero bias is a good init, so we should stick to that.
PyTorch initializes with kaiming uniform [-1/sqrt(in_features), 1/sqrt(in_features)], while Flux does glorot uniform [-sqrt(6 / (in_features + out_features)), +sqrt(6 / (in_features + out_features))]. I couldn't find any convincing paper supporting the choice of one over the other.
I would therefore keep initialization as it is.
For AdamW, instead, I think we should align with PyTorch in Flux v0.15 and solve #2433.