
Layer Normalization #4

Open
Aktsvigun opened this issue Jul 8, 2021 · 3 comments

@Aktsvigun

Hi,
thanks for a great implementation!

I wanted to clarify one thing that doesn't match the code given in the paper itself. In your code you pre-normalize the inputs, i.e. they are passed through LayerNorm before the FFT. The code presented in the paper, however, is:

class FNetEncoderBlock(nn.Module):
    fourier_layer: FourierTransformLayer
    ff_layer: FeedForwardLayer

    @nn.compact
    def __call__(self, x, deterministic):
        mixing_output = self.fourier_layer(x)
        x = nn.LayerNorm(1e-12, name="mixing_layer_norm")(x + mixing_output)
        feed_forward_output = self.ff_layer(x, deterministic)
        return nn.LayerNorm(
            1e-12, name="output_layer_norm")(x + feed_forward_output)

which to my eye applies normalization in the opposite order: LayerNorm comes after the residual sum around each sublayer (post-norm), not before the sublayer (pre-norm).
Am I mistaken, or is it indeed a bug?
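
For reference, here is a minimal sketch of the two orderings side by side. This is my own illustration, not code from either source; the epsilon value and the generic mixing/ff submodules are assumptions:

import flax.linen as nn

class PostNormBlock(nn.Module):
    # Paper-style ordering: LayerNorm applied AFTER each residual sum.
    mixing: nn.Module
    ff: nn.Module

    @nn.compact
    def __call__(self, x, deterministic):
        # sublayer first, then residual add, then LayerNorm
        x = nn.LayerNorm(epsilon=1e-12)(x + self.mixing(x))
        return nn.LayerNorm(epsilon=1e-12)(x + self.ff(x, deterministic))

class PreNormBlock(nn.Module):
    # The ordering described above: LayerNorm applied BEFORE each sublayer.
    mixing: nn.Module
    ff: nn.Module

    @nn.compact
    def __call__(self, x, deterministic):
        # LayerNorm first, then sublayer, then residual add
        x = x + self.mixing(nn.LayerNorm(epsilon=1e-12)(x))
        return x + self.ff(nn.LayerNorm(epsilon=1e-12)(x), deterministic)

Both variants appear in the Transformer literature, but they are not interchangeable: with post-norm the residual stream itself is normalized, which matters in particular when loading checkpoints trained with the other ordering.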

@Aktsvigun
Author

In case the code above renders poorly, here is the image (listing A.5 in the paper):
[Screenshot: code listing A.5 from the paper]

@Aktsvigun
Author

A similar question concerns dropout in the FeedForward layer. You apply it twice, while in the paper it is applied only once, at the end:
[Screenshot: the paper's FeedForwardLayer code]
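
For illustration, a minimal sketch of a feed-forward sublayer with dropout applied only on the output, as described in the paper. This is my own sketch; the d_ff field, layer names, and GELU activation are assumptions, not the paper's exact listing:

import flax.linen as nn

class FeedForwardLayer(nn.Module):
    d_ff: int            # width of the intermediate projection (assumed name)
    dropout_rate: float

    @nn.compact
    def __call__(self, x, deterministic):
        d_model = x.shape[-1]
        h = nn.Dense(self.d_ff, name="intermediate")(x)
        h = nn.gelu(h)
        h = nn.Dense(d_model, name="output")(h)
        # Dropout exactly once, on the sublayer output.
        return nn.Dropout(rate=self.dropout_rate)(h, deterministic=deterministic)

Adding a second dropout after the intermediate Dense changes the effective regularization, so activations would differ statistically from the published checkpoints.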

@erksch

erksch commented Jul 25, 2021

@Aktsvigun you can check out our repo https://github.com/erksch/fnet-pytorch. We reimplemented the architecture precisely enough that we can even use the official checkpoints (converted from Jax to PyTorch).
