-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce mBart #29
base: main
Are you sure you want to change the base?
Introduce mBart #29
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good to me. Just left a few nit comments. Just wondering thoug; have you have any asset cards that we can bundle with this PR? How did you verify parity with the original fairseq implementation?
num_encoder_attn_heads=16, | ||
num_decoder_attn_heads=16, | ||
ffn_inner_dim=4096, | ||
pos_encoder_type="learned", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like pos_encoder_type
and norm_order
are always learned
, and POST
according to this. If that is the case, I would suggest removing these configuration parameters.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm having to do this to successfully load the mBart checkpoint with UnitY: https://github.com/fairinternal/seamless_communication/pull/28/files#diff-189811785a49637a011c2db015430cfd708d92f832f8ef30ed7e10dc7f922635R103
The argument about norm_order
makes sense, I'll remove that.
|
||
def build_frontend(self, embed: Embedding) -> TransformerFrontend: | ||
"""Build a Transformer encoder/decoder front-end.""" | ||
if self.config.pos_encoder_type == "sinusoidal": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As mentioned above, I don't think that this is necessary. mBART always uses learned positional embeddings.
@cbalioglu I'm yet to verify parity with the fairseq mBart model by running forward passes. The asset has an internal checkpoint, wondering what the best way to open-source that would be. |
You can use one of mBARTs public checkpoints here (e.g. mbart.CC25) to verify parity and include it as an asset card in your PR. |
14a5b9b
to
c52ce3a
Compare
What does this PR do? Please describe:
Implements the mBart model and its text tokenizer. We are able to successfully load the base model.
Testing the text tokenizer:
We see that the encoded_tokens is the same as the sample_tokens and the decoded_str is the same as the round_trip_str.
TODO: Check parity for forward pass through the same checkpoint with fairseq1.
Fixes #{issue number}
Does your PR introduce any breaking changes? If yes, please list them:
List of all backwards-incompatible changes.
Check list: