-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmain.py
83 lines (64 loc) · 2.92 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import argparse
import logging
import os.path
import torch
import concreteness
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Default number of neighbors to search for when -k is not given on the command line.
DEFAULT_K = 50
def _setup_logging(verbose):
    """Configure the root logger via ``logging.basicConfig``.

    Args:
        verbose: When truthy, log at DEBUG level; otherwise log at INFO.
    """
    if verbose:
        level = logging.DEBUG
    else:
        level = logging.INFO
    # Each record is rendered as: timestamp, bracketed level name, message.
    logging.basicConfig(
        format='%(asctime)s [%(levelname)s] %(message)s',
        level=level,
    )
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--dataset-dir", type=str, required=True,
                        help="Path to the directory of the dataset.")
    parser.add_argument("-c", "--cache-dir", type=str, required=False,
                        help="Path to a directory to use for cache.")
    parser.add_argument("-v", "--verbose", help="Increase output verbosity.",
                        action="store_true")
    parser.add_argument("-k", help="Number of neighbors to search for.", type=int,
                        default=DEFAULT_K, required=False)
    parser.add_argument("-t", help="Type of dataset: mirflickr | mscoco", type=str,
                        default="mirflickr")
    return parser.parse_args()


def _load_dataset(dataset_type, dataset_dir):
    """Instantiate the dataset class selected by ``dataset_type``.

    Args:
        dataset_type: Either ``"mirflickr"`` or ``"mscoco"``.
        dataset_dir: Root directory of the dataset on disk.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: If ``dataset_type`` is not one of the supported names.
    """
    # The dataset modules are imported inside the branches so that only the
    # module for the selected dataset type is imported.
    if dataset_type == "mirflickr":
        from mirflickr import MirflickrImagesDataset as Dataset
        images_directory = os.path.join(dataset_dir, "images")
        tags_directory = os.path.join(dataset_dir, "tags")
    elif dataset_type == "mscoco":
        from mscoco import MSCOCODataset as Dataset
        images_directory = os.path.join(dataset_dir, "train2014/")
        tags_directory = os.path.join(dataset_dir, "annotations/captions_train2014.json")
    else:
        # ValueError is still caught by callers handling the generic Exception.
        raise ValueError("Data type {:s} not supported.".format(dataset_type))

    log.info("Loading dataset.")
    dataset = Dataset(
        images_directory, tags_directory, transform=concreteness.get_tensor_for_image
    )
    log.info("Dataset is loaded.")
    return dataset


def _load_or_build_vectors(dataset, vectors_file):
    """Return image vectors, reusing the cache file when one exists.

    Args:
        dataset: Dataset passed to ``concreteness.build_image_vectors``.
        vectors_file: Path of the cache file, or ``None`` to disable caching.

    Returns:
        The image vectors, either loaded from ``vectors_file`` or freshly built.
    """
    if vectors_file is not None and os.path.isfile(vectors_file):
        # NOTE(review): torch.load unpickles the file; only point --cache-dir
        # at a directory whose contents you trust.
        return torch.load(vectors_file)

    log.info("Building image vectors.")
    img_vectors = concreteness.build_image_vectors(dataset)
    log.info("Built image vectors.")
    if vectors_file is not None:
        torch.save(img_vectors, vectors_file)
        log.info("Saved image vectors to %s", vectors_file)
    return img_vectors


def main():
    """Run the concreteness pipeline from the command line.

    Parses the arguments, loads the selected dataset, loads or builds the
    image vectors, computes the concreteness scores, and finally opens an
    IPython shell in which the local results can be inspected.

    Raises:
        ValueError: If the ``-t`` dataset type is not supported.
    """
    args = _parse_args()
    _setup_logging(args.verbose)

    dataset = _load_dataset(args.t, args.dataset_dir)

    vectors_file = None
    annoy_index_file = None
    if args.cache_dir is not None:
        # Create the cache directory up front so that saving the vectors
        # cannot fail on a missing directory after they have been built.
        os.makedirs(args.cache_dir, exist_ok=True)
        vectors_file = os.path.join(args.cache_dir, "vectors.pt")
        annoy_index_file = os.path.join(args.cache_dir, "index.ann")

    img_vectors = _load_or_build_vectors(dataset, vectors_file)

    log.info("Computing concreteness.")
    nns = concreteness.build_nns(img_vectors, args.k, annoy_index_file=annoy_index_file)
    concreteness_dict = concreteness.get_concreteness(dataset, nns, args.k)
    log.info("Done!")

    # (key, value) pairs ordered by value, highest first.
    sorted_concreteness = sorted(concreteness_dict.items(), key=lambda x: x[1], reverse=True)
    log.info("Computed concreteness for %d entries.", len(sorted_concreteness))

    # Open an interactive shell so the results above can be explored.
    # NOTE(review): using=False presumably disables IPython's async loop
    # runner for the embedded shell -- confirm against the IPython docs.
    from IPython import embed
    embed(using=False)
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()