-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmain.py
executable file
·60 lines (50 loc) · 2.27 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#!/bin/env python3
import argparse
import rntn
import tree as tr
# Presumably the directory that holds the tree datasets.
# NOTE(review): this constant is not referenced anywhere in this file --
# tr.load_trees() is called with only the dataset name. Confirm whether
# another module relies on it before removing it.
DATA_DIR: str = "trees"
def _build_parser():
    """Return the argument parser for the RNTN train/test command line."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-d", "--dim", type=int, default=25,
                        help="Vector dimension")
    parser.add_argument("-k", "--output-dim", type=int, default=5,
                        help="Number of output classes")
    parser.add_argument("-e", "--epochs", type=int, default=50,
                        help="Maximum number of epochs")
    parser.add_argument("-f", "--dataset", type=str, default="train",
                        choices=['train', 'dev', 'test'], help="Dataset")
    parser.add_argument("-l", "--learning-rate", type=float, default=1e-2,
                        help="Learning rate")
    parser.add_argument("-b", "--batch-size", type=int, default=30,
                        help="Batch size")
    parser.add_argument("-r", "--reg", type=float, default=1e-6,
                        help="Regularization parameter")
    parser.add_argument("-t", "--test", action="store_true",
                        help="Test a model")
    parser.add_argument("-m", "--model", type=str, default='models/RNTN.pickle',
                        help="Model file")
    parser.add_argument("-o", "--optimizer", type=str, default='adagrad',
                        help="Optimizer", choices=['sgd', 'adagrad'])
    return parser


def _run_test(args):
    """Load the model at ``args.model``, evaluate it on ``args.dataset`` and print metrics."""
    print("Testing...")
    # NOTE(review): the default model path is a .pickle file; if RNTN.load
    # unpickles it, only point --model at files from a trusted source.
    model = rntn.RNTN.load(args.model)
    test_trees = tr.load_trees(args.dataset)
    cost, result = model.test(test_trees)
    # `result` is used as a square matrix whose diagonal counts correct
    # predictions (trace = correct, sum = total) -- presumably a confusion
    # matrix; confirm against RNTN.test.
    correct = result.trace()
    total = result.sum()
    # Report 0 % for an empty evaluation set instead of dividing by zero.
    accuracy = 100.0 * correct / total if total else 0.0
    print("Cost = {:.2f}, Correct = {:.0f} / {:.0f}, Accuracy = {:.2f} %".format(
        cost, correct, total, accuracy))


def _run_train(args):
    """Build a fresh RNTN from the CLI hyperparameters and train it on ``args.dataset``."""
    model = rntn.RNTN(
        dim=args.dim, output_dim=args.output_dim, batch_size=args.batch_size,
        reg=args.reg, learning_rate=args.learning_rate, max_epochs=args.epochs,
        optimizer=args.optimizer)
    train_trees = tr.load_trees(args.dataset)
    # The --model path is handed to fit() as its export filename.
    model.fit(train_trees, export_filename=args.model)


def main(argv=None):
    """Command-line entry point: train a new RNTN or evaluate a saved one.

    With ``-t/--test``, the model stored at ``--model`` is loaded and
    evaluated on ``--dataset``, and cost/accuracy are printed. Otherwise a new
    model is built from the hyperparameter flags and trained on
    ``--dataset``, with ``--model`` passed to ``fit`` as the export filename.

    Args:
        argv: Argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``, exactly as before; passing a list lets the CLI be
            driven programmatically.

    Raises:
        SystemExit: from argparse on ``--help`` or on invalid arguments.
    """
    args = _build_parser().parse_args(argv)
    if args.test:
        _run_test(args)
    else:
        _run_train(args)


if __name__ == '__main__':
    main()