-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
82 lines (61 loc) · 1.78 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#!/usr/bin/env python
"""Command-line entry point for the moth classifier (train / evaluate / extract).

This file is meant to be executed as a script only; importing it raises.
"""
# Refuse to be imported: the module calls main() at top level, so importing it
# would immediately parse the command line and start a run.
if __name__ != '__main__': raise Exception("Do not import me!")
# E402 (module-level import not at top of file) is expected here: the imports
# follow the guard above and the matplotlib backend selection below.
# ruff: noqa: E402
import chainer
import logging
import matplotlib
import numpy as np
# Select the non-interactive Agg backend (no display required) before the
# project modules below are imported -- presumably some of them import pyplot.
matplotlib.use('Agg')
from cvfinetune.parser.utils import populate_args
from pathlib import Path
from moth_classifier import core
from moth_classifier.utils import parser
def main(args, experiment_name="Moth classifier"):
    """Run the moth classifier in the mode selected by ``args.mode``.

    Supported modes:

    * ``"train"``: fine-tune via ``tuner.run`` using ``core.Trainer``.
    * ``"evaluate"``: call ``tuner.evaluate`` with the file
      ``<parent of args.load>/<args.eval_output>``.
    * ``"extract"``: call ``tuner.extract_to`` with
      ``<parent of args.load>/features[.<args.suffix>].npz``.

    Args:
        args: Parsed command-line namespace. It is modified in place:
            ``populate_args`` updates it for evaluate/extract, ``dataset`` is
            replaced by ``cross_dataset`` when that is set, and ``dtype`` is
            set to the name of Chainer's default dtype.
        experiment_name: Name passed to ``core.finetuner.new``.

    Raises:
        NotImplementedError: If ``args.mode`` is not a supported mode. This
            is checked before any expensive setup is done.
    """
    supported_modes = ("train", "evaluate", "extract")
    # Fail fast: building the finetuner below (profiling images, fitting the
    # size model) is costly, so reject an unknown mode before any of it runs.
    if args.mode not in supported_modes:
        raise NotImplementedError(f"mode not implemented: {args.mode}")

    if args.mode in ("evaluate", "extract"):
        # NOTE(review): presumably restores the options of the run referenced
        # by args.load, keeping the values of the `ignore`d options from this
        # command line -- confirm against cvfinetune.
        populate_args(
            args,
            ignore=[
                "mode", "load", "load_path", "gpu",
                "mpi", "n_jobs", "batch_size",
                "test_fold_id",
                "center_crop_on_val",
                "only_klass",
            ],
            fc_params=[
                "fc/b",
                "fc8/b",
                "fc6/b",
                "wrapped/output/fc/b",
                "wrapped/output/fc2/b",
            ],
        )

    # An explicitly requested cross_dataset replaces args.dataset (which
    # populate_args may have restored from the loaded run).
    if args.cross_dataset:
        args.dataset = args.cross_dataset

    chainer.set_debug(args.debug)
    # Cap the cuDNN workspace Chainer may allocate at 512 MiB.
    MiB = 1024**2
    chainer.backends.cuda.set_max_workspace_size(512 * MiB)
    if args.debug:
        logging.warning("DEBUG MODE ENABLED!")

    # Record the name of Chainer's default dtype (e.g. "float32") on args;
    # np.dtype() resolves it without allocating a throwaway array.
    args.dtype = np.dtype(chainer.get_dtype()).name
    logging.info(f"Default dtype: {args.dtype}")

    # The second return value (presumably an MPI communicator) is not used here.
    tuner, _comm = core.finetuner.new(args, experiment_name)
    tuner.profile_images()

    logging.info("Fitting size model, if possible")
    tuner.clf.fit_size_model(tuner.train_data)

    if args.mode == "train":
        tuner.run(opts=args, trainer_cls=core.Trainer)
        return

    # "evaluate" and "extract" both write their output into the directory
    # containing the file given by args.load.
    dest_folder = Path(args.load).parent
    if args.mode == "evaluate":
        tuner.evaluate(dest_folder / args.eval_output, force=args.force)
    else:  # "extract" -- the only remaining supported mode
        fname = f"features.{args.suffix}.npz" if args.suffix else "features.npz"
        tuner.extract_to(dest_folder / fname)
if __name__ == '__main__':
    # Script entry point: parse the command line and run the selected mode.
    # (The import guard at the top of the file already ensures this holds.)
    main(parser.parse_args())