Support TencentPretrain #57
Ran TencentPretrain's run_patrickstar.sh for 500 steps on a GeForce RTX 2060 and compared the logs against plain PyTorch. Accuracy looks very similar; speed is somewhat worse, but that may just be because the model is so small that PatrickStar's overhead dominates.
An annoying problem: people may well write code this way, but PatrickStar has no way to tell that one weight tensor is shared by two params. This comes up with tied weights, i.e., when the first layer's embedding weight and the last layer's linear weight share the same parameter; problems remain there for now.
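For concreteness, here is a minimal sketch of the weight-tying pattern in question (module names and sizes are illustrative, not taken from TencentPretrain):

```python
import torch
import torch.nn as nn

class TiedLM(nn.Module):
    # Toy LM with the input embedding tied to the output projection.
    def __init__(self, vocab_size=21128, hidden_size=768):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        # Tie the weights: after this assignment both modules hold the
        # very same nn.Parameter. A runtime that registers parameters
        # module by module encounters this tensor twice, and per the
        # comment above PatrickStar cannot tell the two occurrences
        # are one shared weight.
        self.lm_head.weight = self.embedding.weight

    def forward(self, ids):
        return self.lm_head(self.embedding(ids))

model = TiedLM()
# One Parameter object, two registration sites:
assert model.lm_head.weight is model.embedding.weight
logits = model(torch.randint(0, 21128, (2, 16)))  # (2, 16, vocab_size)
```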
Environment: 1x V100.

Commands:

```sh
python preprocess.py --corpus_path corpora/book_review.txt --vocab_path models/google_zh_vocab.txt \
--dataset_path dataset.pt --processes_num 8 --target lm

python -m torch.distributed.launch --nproc_per_node=1 pretrain.py \
--dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
--output_model_path models/output_model.bin \
--config_path models/gpt2/config_patrickstar_v2.json --learning_rate 1e-4 \
--world_size 1 --gpu_ranks 0 \
--embedding word_pos --remove_embedding_layernorm \
--encoder transformer --mask causal --layernorm_positioning pre \
--target lm \
--total_steps 500 --batch_size 64 \
--fp16 --report_steps 100 \
--use_patrickstar
```

Config:

```json
{
"emb_size": 768,
"feedforward_size": 3072,
"hidden_size": 768,
"hidden_act": "gelu_fast",
"heads_num": 4,
"layers_num": 3,
"max_seq_length": 1024,
"dropout": 0.1,
"embedding": "word_pos",
"remove_embedding_layernorm": true,
"encoder": "transformer",
"mask": "causal",
"layernorm_positioning": "pre",
"target": "lm"
}
```

Run results:
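For context on what the `--use_patrickstar` flag above plugs into: a rough sketch of the engine API from PatrickStar's README, with the config trimmed to match the run above. The `build_gpt2` helper is hypothetical, and pretrain.py's actual integration may differ.

```python
from patrickstar.runtime import initialize_engine

def model_func():
    # The model must be constructed on CPU; PatrickStar manages chunk
    # placement between CPU and GPU afterwards.
    return build_gpt2()  # hypothetical helper matching the config above

config = {
    "optimizer": {
        "type": "Adam",
        "params": {"lr": 1e-4, "betas": (0.9, 0.999), "eps": 1e-6,
                   "weight_decay": 0, "use_hybrid_adam": True},
    },
    # Loss-scaler settings corresponding to --fp16.
    "fp16": {"enabled": True, "loss_scale": 0, "initial_scale_power": 2 ** 3,
             "loss_scale_window": 1000, "hysteresis": 2, "min_loss_scale": 1},
    "default_chunk_size": 64 * 1024 * 1024,
    "release_after_init": True,
}

model, optimizer = initialize_engine(model_func=model_func, local_rank=0,
                                     config=config)

# A training step goes through the engine rather than loss.backward():
# loss = model(input_ids, labels)
# model.backward(loss)
# optimizer.step()
```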
TencentPretrain is a repo from the TEG Data Security Center; we can make use of its model structures and data.
https://git.woa.com/TencentNLP/TencentPretrain/merge_requests/61
TencentPretrain also has an open-source sibling project in the wild:
https://github.com/dbiir/UER-py