
Commit cbabef6

Add en_ewt config example, upgrade AllenNLP to 0.9.0
1 parent deb8e07 commit cbabef6

File tree

- README.md
- config/ud/en/udify_bert_finetune_en_ewt.json
- requirements.txt
- udify/models/dependency_decoder.py
- udify/models/udify_model.py
- udify/modules/bert_pretrained.py

6 files changed: +171 -12 lines changed

README.md

+27-2

````diff
@@ -48,6 +48,25 @@ be saved under `logs/multilingual`. Note that this process is highly memory inte
 12+ GB of GPU memory (requirements are half if fp16 is enabled in AllenNLP, but this [requires custom changes to the library](https://github.com/allenai/allennlp/issues/2149)).
 The training may take 20 or more days to complete all 80 epochs depending on the type of your GPU.
 
+### Training on Other Datasets
+
+An example config is given for fine-tuning on just English EWT. Just run:
+
+```bash
+python train.py --config config/ud/en/udify_bert_finetune_en_ewt.json --name en_ewt
+```
+
+To train on your own dataset, copy `config/ud/multilingual/udify_bert_finetune_multilingual.json` and modify the following
+JSON parameters:
+
+- `train_data_path`, `validation_data_path`, and `test_data_path` to the paths of the dataset conllu files. These can
+  optionally be `null`.
+- `directory_path` to `data/vocab/<dataset_name>/vocabulary`.
+- `warmup_steps` and `start_step` to be equal to the number of steps in the first epoch. A good initial value is in the
+  range `100-1000`. Alternatively, run the training script first to see the number of steps to the right of the progress
+  bar.
+- If using just one treebank, optionally add `xpos` to the `tasks` list.
+
 ### Viewing Model Performance
 
 One can view how well the models are performing by running TensorBoard
@@ -110,9 +129,15 @@ python train.py --config config/sigmorphon/multilingual/udify_bert_sigmorphon_mu
 
 1. When fine-tuning, my scores/metrics show poor performance.
 
-    It should take about 10 epochs to start seeing good scores coming from all the metrics, and 80 epochs to be competitive with UDPipe Future.
+    It should take about 10 epochs to start seeing good scores coming from all the metrics, and 80 epochs to be competitive
+    with UDPipe Future.
 
-    One caveat is that if you use a subset of treebanks for fine-tuning instead of all 124 UD v2.3 treebanks, *you must modify the configuration file*. Make sure to tune the learning rate scheduler to the number of training steps. Copy the [`udify_bert_finetune_multilingual.json`](https://github.com/Hyperparticle/udify/blob/master/config/ud/multilingual/udify_bert_finetune_multilingual.json) config and modify the `"warmup_steps"` and `"start_step"` values. A good initial choice would be to set both to be equal to the number of training batches of one epoch ( run the training script first to see the batches remaining).
+    One caveat is that if you use a subset of treebanks for fine-tuning instead of all 124 UD v2.3 treebanks,
+    *you must modify the configuration file*. Make sure to tune the learning rate scheduler to the number of
+    training steps. Copy the [`udify_bert_finetune_multilingual.json`](https://github.com/Hyperparticle/udify/blob/master/config/ud/multilingual/udify_bert_finetune_multilingual.json)
+    config and modify the `"warmup_steps"` and `"start_step"` values. A good initial choice would be to set both to be
+    equal to the number of training batches of one epoch (run the training script first to see the batches remaining, to
+    the right of the progress bar).
 
 ## Cite This Paper
 
````
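
The new "Training on Other Datasets" section and the updated FAQ give the same rule of thumb: set `warmup_steps` and `start_step` to roughly the number of training batches in one epoch. A minimal sketch of that arithmetic (the sentence count is the approximate size of the UD English-EWT train split and is only illustrative; the iterator's `maximum_samples_per_batch` cap can lower the real batch count, so the number shown next to the training progress bar remains the authoritative value):

```python
import math

# Rough estimate of warmup_steps / start_step for a treebank.
num_train_sentences = 12_543   # approximate size of the en_ewt-ud-train.conllu split
batch_size = 32                # matches the iterator setting in the config below

steps_per_epoch = math.ceil(num_train_sentences / batch_size)
print(steps_per_epoch)         # ~392, the warmup_steps / start_step value used for en_ewt
```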

config/ud/en/udify_bert_finetune_en_ewt.json

+130-0

New file (130 lines):

```json
{
  "dataset_reader": {
    "lazy": false,
    "token_indexers": {
      "tokens": {
        "type": "single_id",
        "lowercase_tokens": true
      },
      "bert": {
        "type": "udify-bert-pretrained",
        "pretrained_model": "config/archive/bert-base-multilingual-cased/vocab.txt",
        "do_lowercase": false,
        "use_starting_offsets": true
      }
    }
  },
  "train_data_path": "data/ud-treebanks-v2.3/UD_English-EWT/en_ewt-ud-train.conllu",
  "validation_data_path": "data/ud-treebanks-v2.3/UD_English-EWT/en_ewt-ud-dev.conllu",
  "test_data_path": "data/ud-treebanks-v2.3/UD_English-EWT/en_ewt-ud-test.conllu",
  "vocabulary": {
    "directory_path": "data/vocab/en_ewt/vocabulary"
  },
  "model": {
    "word_dropout": 0.2,
    "mix_embedding": 12,
    "layer_dropout": 0.1,
    "tasks": ["upos", "feats", "lemmas", "deps"],
    "text_field_embedder": {
      "type": "udify_embedder",
      "dropout": 0.5,
      "allow_unmatched_keys": true,
      "embedder_to_indexer_map": {
        "bert": ["bert", "bert-offsets"]
      },
      "token_embedders": {
        "bert": {
          "type": "udify-bert-pretrained",
          "pretrained_model": "bert-base-multilingual-cased",
          "requires_grad": true,
          "dropout": 0.15,
          "layer_dropout": 0.1,
          "combine_layers": "all"
        }
      }
    },
    "encoder": {
      "type": "pass_through",
      "input_dim": 768
    },
    "decoders": {
      "upos": {
        "encoder": {
          "type": "pass_through",
          "input_dim": 768
        }
      },
      "feats": {
        "encoder": {
          "type": "pass_through",
          "input_dim": 768
        },
        "adaptive": true
      },
      "lemmas": {
        "encoder": {
          "type": "pass_through",
          "input_dim": 768
        },
        "adaptive": true
      },
      "deps": {
        "tag_representation_dim": 256,
        "arc_representation_dim": 768,
        "encoder": {
          "type": "pass_through",
          "input_dim": 768
        }
      }
    }
  },
  "iterator": {
    "batch_size": 32,
    "maximum_samples_per_batch": ["num_tokens", 32 * 100]
  },
  "trainer": {
    "num_epochs": 80,
    "patience": 80,
    "num_serialized_models_to_keep": 1,
    "should_log_learning_rate": true,
    "summary_interval": 100,
    "optimizer": {
      "type": "bert_adam",
      "b1": 0.9,
      "b2": 0.99,
      "weight_decay": 0.01,
      "lr": 1e-3,
      "parameter_groups": [
        [["^text_field_embedder.*.bert_model.embeddings",
          "^text_field_embedder.*.bert_model.encoder"], {}],
        [["^text_field_embedder.*._scalar_mix",
          "^text_field_embedder.*.pooler",
          "^scalar_mix",
          "^decoders",
          "^shared_encoder"], {}]
      ]
    },
    "learning_rate_scheduler": {
      "type": "ulmfit_sqrt",
      "model_size": 1,
      "warmup_steps": 392,
      "start_step": 392,
      "factor": 5.0,
      "gradual_unfreezing": true,
      "discriminative_fine_tuning": true,
      "decay_factor": 0.04
    }
  },
  "udify_replace": [
    "dataset_reader.token_indexers",
    "model.text_field_embedder",
    "model.encoder",
    "model.decoders.xpos",
    "model.decoders.deps.encoder",
    "model.decoders.upos.encoder",
    "model.decoders.feats.encoder",
    "model.decoders.lemmas.encoder",
    "trainer.learning_rate_scheduler",
    "trainer.optimizer"
  ]
}
```
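
Before committing to an 80-epoch run with a copied config, it can help to confirm that the file parses and that the data paths resolve. A minimal sketch, assuming AllenNLP 0.9.0's `Params.from_file` (which evaluates Jsonnet expressions such as the `32 * 100` in the iterator section, unlike plain `json.load`):

```python
import os

from allennlp.common import Params

# Parse the config the same way AllenNLP does.
params = Params.from_file("config/ud/en/udify_bert_finetune_en_ewt.json")

# The three dataset paths may legitimately be null; otherwise they should exist on disk.
for key in ("train_data_path", "validation_data_path", "test_data_path"):
    path = params.get(key)
    status = "null" if path is None else ("ok" if os.path.exists(path) else "MISSING")
    print(f"{key}: {path} [{status}]")
```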

requirements.txt

+1-1

```diff
@@ -1,4 +1,4 @@
-allennlp==0.8.5
+allennlp==0.9.0
 tensorflow
 pandas
 jupyter
```

udify/models/dependency_decoder.py

+1-1

```diff
@@ -331,7 +331,7 @@ def _greedy_decode(self,
         attended_arcs = attended_arcs + torch.diag(attended_arcs.new(mask.size(1)).fill_(-numpy.inf))
         # Mask padded tokens, because we only want to consider actual words as heads.
         if mask is not None:
-            minus_mask = (1 - mask).byte().unsqueeze(2)
+            minus_mask = (1 - mask).bool().unsqueeze(2)
             attended_arcs.masked_fill_(minus_mask, -numpy.inf)
 
         # Compute the heads greedily.
```
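
This `.byte()` → `.bool()` switch, like the `torch.uint8` → `torch.bool` change in `udify/models/udify_model.py` below, appears to track the newer PyTorch versions pulled in alongside AllenNLP 0.9.0, where masking operations such as `masked_fill_` expect boolean masks and deprecate uint8 ones. A standalone sketch of the masking pattern, with illustrative shapes and values (not the UDify code itself):

```python
import torch

# Arc scores for a batch of 2 sentences with max length 4: (batch, length, length).
attended_arcs = torch.randn(2, 4, 4)
# 1 marks a real token, 0 marks padding.
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])

# Boolean mask of padded positions; (1 - mask).byte() would still run but
# triggers a deprecation warning on recent PyTorch releases.
minus_mask = (1 - mask).bool().unsqueeze(2)   # (batch, length, 1)

# Broadcasts over the last dimension, filling every score in a padded token's row.
attended_arcs.masked_fill_(minus_mask, float("-inf"))
```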

udify/models/udify_model.py

+2-2

```diff
@@ -185,9 +185,9 @@ def token_dropout(tokens: torch.LongTensor,
         device = tokens.device
 
         # This creates a mask that only considers unpadded tokens for mapping to oov
-        padding_mask = torch.ones(tokens.size(), dtype=torch.uint8).to(device)
+        padding_mask = torch.ones(tokens.size(), dtype=torch.bool).to(device)
         for pad in padding_tokens:
-            padding_mask &= tokens != pad
+            padding_mask &= (tokens != pad)
 
         # Create a uniformly random mask selecting either the original words or OOV tokens
         dropout_mask = (torch.empty(tokens.size()).uniform_() < p).to(device)
```

udify/modules/bert_pretrained.py

+10-6

```diff
@@ -253,12 +253,16 @@ def get_padding_lengths(self, token: int) -> Dict[str, int]:  # pylint: disable=
         return {}
 
     @overrides
-    def pad_token_sequence(self,
-                           tokens: Dict[str, List[int]],
-                           desired_num_tokens: Dict[str, int],
-                           padding_lengths: Dict[str, int]) -> Dict[str, List[int]]:  # pylint: disable=unused-argument
-        return {key: pad_sequence_to_length(val, desired_num_tokens[key])
-                for key, val in tokens.items()}
+    def as_padded_tensor(
+            self,
+            tokens: Dict[str, List[int]],
+            desired_num_tokens: Dict[str, int],
+            padding_lengths: Dict[str, int],
+    ) -> Dict[str, torch.Tensor]:
+        return {
+            key: torch.LongTensor(pad_sequence_to_length(val, desired_num_tokens[key]))
+            for key, val in tokens.items()
+        }
 
     @overrides
     def get_keys(self, index_name: str) -> List[str]:
```
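
The renamed method follows the AllenNLP 0.9.0 `TokenIndexer` interface, where `pad_token_sequence` (returning padded lists of ids) was replaced by `as_padded_tensor` (returning padded tensors). A standalone sketch of the same padding behaviour; the wordpiece ids and target lengths are made up for illustration:

```python
import torch
from allennlp.common.util import pad_sequence_to_length

# Per-key id lists as produced by the indexer, and the lengths to pad them to.
tokens = {"bert": [101, 7592, 2088, 102], "bert-offsets": [1, 2]}
desired_num_tokens = {"bert": 6, "bert-offsets": 4}

# Pad each list and wrap it in a LongTensor rather than returning plain lists.
padded = {
    key: torch.LongTensor(pad_sequence_to_length(val, desired_num_tokens[key]))
    for key, val in tokens.items()
}
print(padded["bert"])          # ids padded with zeros to length 6
print(padded["bert-offsets"])  # offsets padded with zeros to length 4
```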
