Support flickr30k.
ruotianluo committed Jul 18, 2019
1 parent 8118670 commit c4d5d13
Showing 7 changed files with 211 additions and 53 deletions.
53 changes: 4 additions & 49 deletions README.md
@@ -23,56 +23,11 @@ Pretrained models are provided [here](https://drive.google.com/open?id=0B7fNdx_j

If you want to do evaluation only, you can then follow [this section](#generate-image-captions) after downloading the pretrained models (and also the pretrained resnet101).

## Train your own network on COCO
## Train your own network on COCO/Flickr30k

### Download COCO captions and preprocess them
### Prepare data

Download the preprocessed COCO captions from Karpathy's homepage ([link](http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip)). Extract `dataset_coco.json` from the zip file and copy it into `data/`. This file provides the preprocessed captions and the standard train-val-test splits.

Then do:

```bash
$ python scripts/prepro_labels.py --input_json data/dataset_coco.json --output_json data/cocotalk.json --output_h5 data/cocotalk
```

`prepro_labels.py` maps all words that occur <= 5 times to a special `UNK` token and creates a vocabulary for the remaining words. The image information and vocabulary are dumped into `data/cocotalk.json`, and the discretized caption data are dumped into `data/cocotalk_label.h5`.

### Download COCO dataset and pre-extract the image features (skip if you are using bottom-up features)

Download the COCO images from [link](http://mscoco.org/dataset/#download). We need the 2014 training and 2014 validation images. Put `train2014/` and `val2014/` in the same directory, denoted as `$IMAGE_ROOT`.

Then:

```bash
$ python scripts/prepro_feats.py --input_json data/dataset_coco.json --output_dir data/cocotalk --images_root $IMAGE_ROOT
```


`prepro_feats.py` extracts the resnet101 features (both the fc feature and the last conv feature) of each image. The features are saved in `data/cocotalk_fc` and `data/cocotalk_att`, and the resulting files are about 200GB.

(Check the prepro scripts for more options, like other resnet models or other attention sizes.)

**Warning**: the prepro script will fail with the default MSCOCO data because one of their images is corrupted. See [this issue](https://github.com/karpathy/neuraltalk2/issues/4) for the fix; it involves manually replacing one image in the dataset.

### Download Bottom-up features (Skip if you are using resnet features)

Download the pre-extracted features from [link](https://github.com/peteanderson80/bottom-up-attention). You can download either the adaptive or the fixed version.

For example:
```bash
mkdir data/bu_data; cd data/bu_data
wget https://storage.googleapis.com/bottom-up-attention/trainval.zip
unzip trainval.zip
```

Then:

```bash
python scripts/make_bu_data.py --output_dir data/cocobu
```

This will create `data/cocobu_fc`, `data/cocobu_att` and `data/cocobu_box`. If you want to use the bottom-up features, just follow the remaining steps and replace every occurrence of cocotalk with cocobu.
We now support both flickr30k and COCO. See `data/README.md` for details. (Note: the later sections assume the COCO dataset; adapting them to flickr30k should be trivial.)

### Start training

@@ -108,7 +63,7 @@ $ bash scripts/copy_model.sh fc fc_rl

Then
```bash
$ python train.py --id fc_rl --caption_model fc --input_json data/cocotalk.json --input_fc_dir data/cocotalk_fc --input_att_dir data/cocotalk_att --input_label_h5 data/cocotalk_label.h5 --batch_size 10 --learning_rate 5e-5 --start_from log_fc_rl --checkpoint_path log_fc_rl --save_checkpoint_every 6000 --language_eval 1 --val_images_use 5000 --self_critical_after 30
$ python train.py --id fc_rl --caption_model fc --input_json data/cocotalk.json --input_fc_dir data/cocotalk_fc --input_att_dir data/cocotalk_att --input_label_h5 data/cocotalk_label.h5 --batch_size 10 --learning_rate 5e-5 --start_from log_fc_rl --checkpoint_path log_fc_rl --save_checkpoint_every 6000 --language_eval 1 --val_images_use 5000 --self_critical_after 30 --cached_tokens coco-train-idxs
```

You will see a huge boost in the CIDEr score. :)
100 changes: 100 additions & 0 deletions data/README.md
@@ -0,0 +1,100 @@
# Prepare data

## COCO

### Download COCO captions and preprocess them

Download the preprocessed COCO captions from Karpathy's homepage ([link](http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip)). Extract `dataset_coco.json` from the zip file and copy it into `data/`. This file provides the preprocessed captions and the standard train-val-test splits.

Then do:

```bash
$ python scripts/prepro_labels.py --input_json data/dataset_coco.json --output_json data/cocotalk.json --output_h5 data/cocotalk
```

`prepro_labels.py` maps all words that occur <= 5 times to a special `UNK` token and creates a vocabulary for the remaining words. The image information and vocabulary are dumped into `data/cocotalk.json`, and the discretized caption data are dumped into `data/cocotalk_label.h5`.
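
For a quick sanity check of the outputs, something like the following can be used (a sketch; the field names are assumed to match what `prepro_labels.py` writes, i.e. `ix_to_word` and `images` in the json and `labels`, `label_start_ix`, `label_length` in the h5):

```python
import json
import h5py

# Vocabulary and per-image info written by prepro_labels.py
info = json.load(open('data/cocotalk.json'))
print(len(info['ix_to_word']), 'words in the vocabulary (1-indexed)')
print(len(info['images']), 'images')

# Discretized captions plus per-image start/end pointers
with h5py.File('data/cocotalk_label.h5', 'r') as f:
    print(f['labels'].shape)        # (M, max_length) encoded captions, zero padded
    print(f['label_start_ix'][:5])  # 1-indexed pointers into labels for each image
    print(f['label_length'][:5])    # length of each encoded caption
```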

### Download COCO dataset and pre-extract the image features (skip if you are using bottom-up features)

Download the COCO images from [link](http://mscoco.org/dataset/#download). We need the 2014 training and 2014 validation images. Put `train2014/` and `val2014/` in the same directory, denoted as `$IMAGE_ROOT`.

Then:

```bash
$ python scripts/prepro_feats.py --input_json data/dataset_coco.json --output_dir data/cocotalk --images_root $IMAGE_ROOT
```


`prepro_feats.py` extracts the resnet101 features (both the fc feature and the last conv feature) of each image. The features are saved in `data/cocotalk_fc` and `data/cocotalk_att`, and the resulting files are about 200GB.

(Check the prepro scripts for more options, like other resnet models or other attention sizes.)

**Warning**: the prepro script will fail with the default MSCOCO data because one of their images is corrupted. See [this issue](https://github.com/karpathy/neuraltalk2/issues/4) for the fix; it involves manually replacing one image in the dataset.
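
As a rough check after extraction, something along these lines can be used (a sketch; it assumes `prepro_feats.py` writes one `.npy` file per image id in `cocotalk_fc` and one `.npz` file with a `feat` array in `cocotalk_att`, which may differ across versions):

```python
import os
import numpy as np

fc_dir, att_dir = 'data/cocotalk_fc', 'data/cocotalk_att'
some_id = os.path.splitext(os.listdir(fc_dir)[0])[0]  # pick an arbitrary image id

fc = np.load(os.path.join(fc_dir, some_id + '.npy'))            # pooled fc feature
att = np.load(os.path.join(att_dir, some_id + '.npz'))['feat']  # spatial conv feature
print(fc.shape, att.shape)  # e.g. (2048,) and (14, 14, 2048) for resnet101
```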

### Download Bottom-up features (Skip if you are using resnet features)

Download the pre-extracted features from [link](https://github.com/peteanderson80/bottom-up-attention). You can download either the adaptive or the fixed version.

For example:
```bash
mkdir data/bu_data; cd data/bu_data
wget https://storage.googleapis.com/bottom-up-attention/trainval.zip
unzip trainval.zip
```

Then:

```bash
python scripts/make_bu_data.py --output_dir data/cocobu
```

This will create `data/cocobu_fc`, `data/cocobu_att` and `data/cocobu_box`. If you want to use the bottom-up features, just follow the remaining steps and replace every occurrence of cocotalk with cocobu.

## Flickr30k

The process is similar to COCO. First, preprocess the labels and the n-grams (the latter are used for self-critical training):

```bash
python scripts/prepro_labels.py --input_json data/dataset_flickr30k.json --output_json data/f30ktalk.json --output_h5 data/f30ktalk
python scripts/prepro_ngrams.py --input_json data/dataset_flickr30k.json --dict_json data/f30ktalk.json --output_pkl data/f30k-train --split train
```

Then generate the COCO-style annotation file used for evaluation with coco-caption:

```bash
python scripts/prepro_reference_json.py --input_json data/dataset_flickr30k.json --output_json data/f30k_captions4eval.json
```
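
The resulting file follows the coco-caption annotation format, i.e. an `images` list and an `annotations` list as written by `scripts/prepro_reference_json.py` (see the new script below). A quick check might look like:

```python
import json

refs = json.load(open('data/f30k_captions4eval.json'))
print(len(refs['images']), 'non-train images')
print(len(refs['annotations']), 'reference captions')
print(refs['annotations'][0])  # {'image_id': ..., 'caption': ..., 'id': 0}
```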

### Feature extraction

For resnet features, follow the same steps as for COCO.

For bottom-up features, download the pre-extracted features from [link](https://github.com/kuanghuei/SCAN):

`wget https://scanproject.blob.core.windows.net/scan-data/data.zip`

and then convert them to a pth file using the following script:

```python
import numpy as np
import os
import torch
from tqdm import tqdm

out = {}

def transform(id_file, feat_file):
    ids = open(id_file, 'r').readlines()
    ids = [_.strip('\n') for _ in ids]
    feats = np.load(feat_file)
    assert feats.shape[0] == len(ids)
    for _id, _feat in tqdm(zip(ids, feats)):
        out[str(_id)] = _feat

transform('dev_ids.txt', 'dev_ims.npy')
transform('train_ids.txt', 'train_ims.npy')
transform('test_ids.txt', 'test_ims.npy')

torch.save(out, 'f30kbu_att.pth')
```
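
A quick check of the converted file (a sketch; the printed id depends on the SCAN id files):

```python
import torch

feats = torch.load('f30kbu_att.pth')
print(len(feats), 'images with bottom-up features')  # dev + train + test combined
some_id = next(iter(feats))                          # keys are image ids as strings
print(some_id, feats[some_id].shape)                 # region features for one image
```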
7 changes: 7 additions & 0 deletions dataloader.py
@@ -33,6 +33,11 @@ def __init__(self, db_path, ext):
            self.env = lmdb.open(db_path, subdir=os.path.isdir(db_path),
                                readonly=True, lock=False,
                                readahead=False, meminit=False)
        elif db_path.endswith('.pth'): # Assume a key,value dictionary
            self.db_type = 'pth'
            self.feat_file = torch.load(db_path)
            self.loader = lambda x: x
            print('HybridLoader: ext is ignored')
        else:
            self.db_type = 'dir'

@@ -43,6 +48,8 @@ def get(self, key):
            with env.begin(write=False) as txn:
                byteflow = txn.get(key)
            f_input = six.BytesIO(byteflow)
        elif self.db_type == 'pth':
            f_input = self.feat_file[key]
        else:
            f_input = os.path.join(self.db_path, key + self.ext)

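For context, a minimal sketch of how the new `.pth` branch can be exercised (the path and image id are hypothetical, and it assumes `get()` returns `self.loader(f_input)` as in the rest of the class):

```python
from dataloader import HybridLoader

# For a .pth file the whole {id: feature} dict is loaded into memory,
# ext is ignored, and loader is the identity, so get() returns the stored array.
att_loader = HybridLoader('data/f30kbu_att.pth', ext='.npz')
feat = att_loader.get('36979')  # image id as a string (hypothetical id)
print(feat.shape)
```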
1 change: 1 addition & 0 deletions eval.py
@@ -71,6 +71,7 @@


# Set sample options
opt.dataset = opt.input_json
loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader,
    vars(opt))

5 changes: 4 additions & 1 deletion eval_utils.py
@@ -28,7 +28,10 @@ def count_bad(sen):
def language_eval(dataset, preds, model_id, split):
    import sys
    sys.path.append("coco-caption")
    annFile = 'coco-caption/annotations/captions_val2014.json'
    if 'coco' in dataset:
        annFile = 'coco-caption/annotations/captions_val2014.json'
    elif 'flickr30k' in dataset or 'f30k' in dataset:
        annFile = 'coco-caption/f30k_captions4eval.json'
    from pycocotools.coco import COCO
    from pycocoevalcap.eval import COCOEvalCap

9 changes: 6 additions & 3 deletions scripts/prepro_labels.py
@@ -168,9 +168,12 @@ def main(params):

        jimg = {}
        jimg['split'] = img['split']
        if 'filename' in img: jimg['file_path'] = os.path.join(img['filepath'], img['filename']) # copy it over, might need
        if 'cocoid' in img: jimg['id'] = img['cocoid'] # copy over & maintain an id, if present (e.g. coco ids, useful)

        if 'filename' in img: jimg['file_path'] = os.path.join(img.get('filepath', ''), img['filename']) # copy it over, might need
        if 'cocoid' in img:
            jimg['id'] = img['cocoid'] # copy over & maintain an id, if present (e.g. coco ids, useful)
        elif 'imgid' in img:
            jimg['id'] = img['imgid']

        if params['images_root'] != '':
            with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img:
                jimg['width'], jimg['height'] = _img.size
89 changes: 89 additions & 0 deletions scripts/prepro_reference_json.py
@@ -0,0 +1,89 @@
# coding: utf-8
"""
Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua
Input: json file that has the form
[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
example element in this list would look like
{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
This script reads this json, does some basic preprocessing on the captions
(e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays
Output: a json file and an hdf5 file
The hdf5 file contains several fields:
/images is (N,3,256,256) uint8 array of raw image data in RGB format
/labels is (M,max_length) uint32 array of encoded labels, zero padded
/label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the
first and last indices (in range 1..M) of labels for each image
/label_length stores the length of the sequence for each of the M sequences
The json file has a dict that contains:
- an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed
- an 'images' field that is a list holding auxiliary information for each image,
such as in particular the 'split' it was assigned to.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import json
import argparse
import sys
import hashlib
from random import shuffle, seed


def main(params):

    imgs = json.load(open(params['input_json'][0], 'r'))['images']
    # tmp = []
    # for k in imgs.keys():
    #     for img in imgs[k]:
    #         img['filename'] = img['image_id'] # k+'/'+img['image_id']
    #         img['image_id'] = int(
    #             int(hashlib.sha256(img['image_id']).hexdigest(), 16) % sys.maxint)
    #         tmp.append(img)
    # imgs = tmp

    # create output json file
    out = {u'info': {u'description': u'This is stable 1.0 version of the 2014 MS COCO dataset.', u'url': u'http://mscoco.org', u'version': u'1.0', u'year': 2014, u'contributor': u'Microsoft COCO group', u'date_created': u'2015-01-27 09:11:52.357475'}, u'licenses': [{u'url': u'http://creativecommons.org/licenses/by-nc-sa/2.0/', u'id': 1, u'name': u'Attribution-NonCommercial-ShareAlike License'}, {u'url': u'http://creativecommons.org/licenses/by-nc/2.0/', u'id': 2, u'name': u'Attribution-NonCommercial License'}, {u'url': u'http://creativecommons.org/licenses/by-nc-nd/2.0/', u'id': 3, u'name': u'Attribution-NonCommercial-NoDerivs License'}, {u'url': u'http://creativecommons.org/licenses/by/2.0/', u'id': 4, u'name': u'Attribution License'}, {u'url': u'http://creativecommons.org/licenses/by-sa/2.0/', u'id': 5, u'name': u'Attribution-ShareAlike License'}, {u'url': u'http://creativecommons.org/licenses/by-nd/2.0/', u'id': 6, u'name': u'Attribution-NoDerivs License'}, {u'url': u'http://flickr.com/commons/usage/', u'id': 7, u'name': u'No known copyright restrictions'}, {u'url': u'http://www.usa.gov/copyright.shtml', u'id': 8, u'name': u'United States Government Work'}], u'type': u'captions'}
    out.update({'images': [], 'annotations': []})

    cnt = 0
    empty_cnt = 0
    for i, img in enumerate(imgs):
        if img['split'] == 'train':
            continue
        out['images'].append(
            {u'id': img.get('cocoid', img['imgid'])})
        for j, s in enumerate(img['sentences']):
            if len(s) == 0:
                continue
            s = ' '.join(s['tokens'])
            out['annotations'].append(
                {'image_id': out['images'][-1]['id'], 'caption': s, 'id': cnt})
            cnt += 1

    json.dump(out, open(params['output_json'], 'w'))
    print('wrote ', params['output_json'])


if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # input json
    parser.add_argument('--input_json', nargs='+', required=True,
                        help='input json file to process into hdf5')
    parser.add_argument('--output_json', default='data.json',
                        help='output json file')

    args = parser.parse_args()
    params = vars(args) # convert to ordinary dict
    print('parsed input parameters:')
    print(json.dumps(params, indent=2))
    main(params)
