This repository contains the code necessary to train MEXMA, as presented in the MEXMA paper.
Python version 3.9.11
- conda create --name mexma python=3.9.11
- conda activate mexma
- git clone git@github.com:facebookresearch/mexma.git
- cd mexma
- pip install -r requirements.txt
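Once the environment is set up, a quick sanity check (a minimal sketch, not part of the repository) can confirm that PyTorch sees your GPU and that the XLM-RoBERTa tokenizer loads:

# Hypothetical sanity check, not part of the repository: verify the
# training prerequisites (PyTorch with CUDA, and transformers).
import torch
from transformers import AutoTokenizer

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large")
print("tokenizer vocab size:", tokenizer.vocab_size)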
More details about the training data are present in data/train_data.
To be able to evaluate on xsim during training, you need to add the xsim file to evaluation/xsim.
Additionally, you need to add the FLORES200 dataset inside data/flores200, which you can get here.
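Before launching training, a small pre-flight check (a hypothetical helper, not part of the repository) can confirm that the directories mentioned above exist and are not empty:

# Hypothetical pre-flight check: confirm the data and evaluation
# directories described above exist and contain files.
from pathlib import Path

for directory in ["data/train_data", "data/flores200", "evaluation/xsim"]:
    path = Path(directory)
    if not path.is_dir() or not any(path.iterdir()):
        print(f"WARNING: {directory} is missing or empty")
    else:
        print(f"OK: {directory} ({sum(1 for _ in path.iterdir())} entries)")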
To train the model, simply launch:
torchrun main.py \
--encoder xlm-roberta-large \
--max_model_context_length 200 \
--checkpoint None \
--mlm_loss_weight 1 \
--cls_loss_weight 1 \
--koleo_loss_weight 0.01 \
--number_of_linear_layers 0 \
--linear_layers_inputs_dims None \
--linear_layers_outputs_dims None \
--number_of_transformer_layers_in_head 6 \
--number_of_transformer_attention_heads_in_head 8 \
--initialization_method torch_default \
--train_data_file None \
--test_data_file None \
--hf_dataset_directory [YOUR_DIRECTORY_HERE] \
--batch_size 150 \
--workers 12 \
--device cuda \
--lr 0.0001 \
--epochs 3 \
--start_epoch 0 \
--src_mlm_probability 0.4 \
--trg_mlm_probability 0.4 \
--number_of_iterations_to_accumulated_gradients 2 \
--testing_frequency 5000000 \
--saving_frequency 2000 \
--mixed_precision_training \
--clip_grad_norm 1.2 \
--wd None \
--lr_scheduler_type cosineannealinglr \
--lr_warmup_percentage 0.3 \
--lr_warmup_method linear \
--lr_warmup_decay 0.1 \
--print_freq 10 \
--save_model_checkpoint 50000 \
--no_wandb \
--flores_200_src_languages acm_Arab aeb_Arab afr_Latn amh_Ethi ary_Arab arz_Arab asm_Beng azb_Arab azj_Latn bel_Cyrl ben_Beng bos_Latn bul_Cyrl cat_Latn ces_Latn ckb_Arab cym_Latn dan_Latn deu_Latn ell_Grek epo_Latn est_Latn eus_Latn fin_Latn fra_Latn gla_Latn gle_Latn glg_Latn guj_Gujr hau_Latn heb_Hebr hin_Deva hrv_Latn hun_Latn hye_Armn ind_Latn isl_Latn ita_Latn jav_Latn jpn_Jpan kan_Knda kat_Geor kaz_Cyrl khm_Khmr kir_Cyrl kor_Hang lao_Laoo mal_Mlym mar_Deva mkd_Cyrl mya_Mymr nld_Latn nno_Latn nob_Latn npi_Deva pol_Latn por_Latn ron_Latn rus_Cyrl san_Deva sin_Sinh slk_Latn slv_Latn snd_Arab som_Latn spa_Latn srp_Cyrl sun_Latn swe_Latn swh_Latn tam_Taml tel_Telu tha_Thai tur_Latn uig_Arab ukr_Cyrl urd_Arab vie_Latn xho_Latn zho_Hant
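With --batch_size 150 and --number_of_iterations_to_accumulated_gradients 2, each process accumulates gradients over two iterations, so the effective batch per optimizer step is 150 × 2 = 300 examples per process, multiplied by however many processes torchrun launches. The command above uses torchrun's default of a single process; assuming main.py handles the distributed setup that torchrun implies, a multi-GPU single-node run would add torchrun's standard --nproc_per_node flag (e.g. torchrun --nproc_per_node=8 main.py with the same arguments as above).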
To use the MEXMA model, you can load it directly from Hugging Face:
from transformers import AutoTokenizer, XLMRobertaModel
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large")
model = XLMRobertaModel.from_pretrained("facebook/MEXMA", add_pooling_layer=False)
example_sentences = ['Sentence1', 'Sentence2']
example_inputs = tokenizer(example_sentences, padding=True, return_tensors='pt')  # pad so sentences of different lengths can be batched
outputs = model(**example_inputs)
# The sentence representation is the final hidden state of the <s> (CLS) token.
sentence_representation = outputs.last_hidden_state[:, 0]
print(sentence_representation.shape)  # torch.Size([2, 1024])
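The pooled representations can then be compared directly. For example, a minimal sketch using plain PyTorch (not an API of this repository) computes the cosine similarity between the two sentences above:

import torch
import torch.nn.functional as F

# Cosine similarity between the two sentence embeddings computed above.
with torch.no_grad():
    similarity = F.cosine_similarity(
        sentence_representation[0], sentence_representation[1], dim=0
    )
print(similarity.item())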
MEXMA is MIT licensed. See the LICENSE file for details. However, portions of the project are available under separate license terms: backbone/block_diagonal_roberta.py and losses/koleo.py are licensed under the Apache-2.0 license.
@misc{janeiro2024mexma,
title={MEXMA: Token-level objectives improve sentence representations},
author={João Maria Janeiro and Benjamin Piwowarski and Patrick Gallinari and Loïc Barrault},
year={2024},
eprint={2409.12737},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2409.12737},
}