-
Notifications
You must be signed in to change notification settings - Fork 2
/
pretrained_models.py
executable file
·20 lines (20 loc) · 1.11 KB
/
pretrained_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from transformers import BertConfig, BertModel, BertTokenizer
from transformers import XLNetConfig, XLNetModel, XLNetTokenizer
from transformers import RobertaConfig, RobertaModel, RobertaTokenizer
from transformers import AlbertConfig, AlbertModel, AlbertTokenizer
from transformers import XLMRobertaConfig, XLMRobertaModel, XLMRobertaTokenizer
from transformers import ElectraConfig, ElectraModel, ElectraTokenizer
from transformers import T5Config, T5EncoderModel, T5Tokenizer
from transformers import DebertaConfig, DebertaModel, DebertaTokenizer
from module.san_model import SanModel
MODEL_CLASSES = {
"bert": (BertConfig, BertModel, BertTokenizer),
"xlnet": (XLNetConfig, XLNetModel, XLNetTokenizer),
"roberta": (RobertaConfig, RobertaModel, RobertaTokenizer),
"albert": (AlbertConfig, AlbertModel, AlbertTokenizer),
"xlm": (XLMRobertaConfig, XLMRobertaModel, XLMRobertaTokenizer),
"san": (BertConfig, SanModel, BertTokenizer),
"electra": (ElectraConfig, ElectraModel, ElectraTokenizer),
"t5": (T5Config, T5EncoderModel, T5Tokenizer),
"deberta": (DebertaConfig, DebertaModel, DebertaTokenizer),
}