-
Notifications
You must be signed in to change notification settings - Fork 177
/
allennlp_jsonnet.py
74 lines (55 loc) · 2.42 KB
/
allennlp_jsonnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""
Optuna example that optimizes a classifier configuration for the IMDB movie review dataset.
This script is based on the allentune example (https://github.com/allenai/allentune).
In this example, we optimize the validation accuracy of
sentiment classification using an AllenNLP jsonnet config file.
Because using the training dataset is too time-consuming,
we use the validation dataset here instead.
"""
import os.path
import shutil
import optuna
from optuna.integration import AllenNLPExecutor
from optuna.integration.allennlp import dump_best_config
from packaging import version
import allennlp
# Directory (relative to the working directory) under which each trial's
# serialization directory is created.
MODEL_DIR = "result"
# File that receives the best configuration found by the study.
BEST_CONFIG_PATH = "best_classifier.json"
# The jsonnet config is resolved relative to this file instead of the working
# directory, because CI also runs this example from the repository root.
EXAMPLE_DIR = os.path.split(os.path.abspath(__file__))[0]
CONFIG_PATH = os.path.join(EXAMPLE_DIR, "classifier.jsonnet")
def objective(trial):
    """Run one AllenNLP training for ``trial`` and return the executor's result.

    The hyperparameters are registered on ``trial`` through its ``suggest_*``
    methods; the suggested values are not used directly in this function.
    NOTE(review): presumably ``AllenNLPExecutor`` reads them back from the
    trial and hands them to the jsonnet config -- confirm against the Optuna
    integration documentation.

    Args:
        trial: The Optuna trial being evaluated.

    Returns:
        Whatever ``AllenNLPExecutor.run()`` returns for this trial.
    """
    trial.suggest_float("DROPOUT", 0.0, 0.5)

    # Integer hyperparameters as (name, low, high), suggested in this order.
    int_search_space = (
        ("EMBEDDING_DIM", 20, 50),
        ("MAX_FILTER_SIZE", 3, 6),
        ("NUM_FILTERS", 16, 32),
        ("HIDDEN_SIZE", 16, 32),
    )
    for name, low, high in int_search_space:
        trial.suggest_int(name, low, high)

    # Each trial gets its own serialization directory, keyed by trial number.
    trial_dirname = "test_{}".format(trial.number)
    executor = AllenNLPExecutor(
        trial,
        CONFIG_PATH,
        os.path.join(MODEL_DIR, trial_dirname),
        force=True,
    )
    return executor.run()
if __name__ == "__main__":
    # Refuse to run on AllenNLP releases older than 2.0.0; the message points
    # users of older releases to the matching Optuna version and example.
    installed_allennlp = version.parse(allennlp.__version__)
    minimum_allennlp = version.parse("2.0.0")
    if installed_allennlp < minimum_allennlp:
        raise RuntimeError(
            "`allennlp>=2.0.0` is required for this example."
            " If you want to use `allennlp<2.0.0`, please install `optuna==2.5.0`"
            " and refer to the following example:"
            " https://github.com/optuna/optuna/blob/v2.5.0/examples/allennlp/allennlp_jsonnet.py"
        )

    # The study maximizes the objective and is stored in a local SQLite file;
    # the sampler is seeded with a fixed value.
    pruner = optuna.pruners.HyperbandPruner()
    sampler = optuna.samplers.TPESampler(seed=10)
    study = optuna.create_study(
        direction="maximize",
        storage="sqlite:///allennlp.db",
        pruner=pruner,
        sampler=sampler,
    )
    # Run with a limit of 50 trials and a timeout of 600.
    study.optimize(objective, n_trials=50, timeout=600)

    # Report the outcome of the search.
    print("Number of finished trials: ", len(study.trials))
    print("Best trial:")
    best_trial = study.best_trial
    print(" Value: ", best_trial.value)
    print(" Params: ")
    for param_name, param_value in best_trial.params.items():
        print(" {}: {}".format(param_name, param_value))

    # Write the best configuration to BEST_CONFIG_PATH, based on CONFIG_PATH.
    dump_best_config(CONFIG_PATH, BEST_CONFIG_PATH, study)
    print("\nCreated optimized AllenNLP config to `{}`.".format(BEST_CONFIG_PATH))

    # Remove the per-trial serialization directories created during the search.
    shutil.rmtree(MODEL_DIR)