refactored the joint q2ar evaluation script
Rowan Zellers committed Feb 14, 2019
1 parent 73a2408 commit ae532f6
Showing 2 changed files with 64 additions and 86 deletions.
86 changes: 0 additions & 86 deletions eval_all.py

This file was deleted.

64 changes: 64 additions & 0 deletions models/eval_q2ar.py
@@ -0,0 +1,64 @@
"""
You can use this script to evaluate prediction files (valpreds.npy). Essentially this is needed if you want to, say,
combine answer and rationale predictions.
"""

import numpy as np
import json
import os
from config import VCR_ANNOTS_DIR
import argparse

parser = argparse.ArgumentParser(description='Evaluate question -> answer and rationale')
parser.add_argument(
    '-answer_preds',
    dest='answer_preds',
    default='saves/flagship_answer/valpreds.npy',
    help='Location of question->answer predictions',
    type=str,
)
parser.add_argument(
    '-rationale_preds',
    dest='rationale_preds',
    default='saves/flagship_rationale/valpreds.npy',
    help='Location of question+answer->rationale predictions',
    type=str,
)
parser.add_argument(
    '-split',
    dest='split',
    default='val',
    help='Split you\'re using. Probably you want val.',
    type=str,
)

args = parser.parse_args()

answer_preds = np.load(args.answer_preds)
rationale_preds = np.load(args.rationale_preds)

rationale_labels = []
answer_labels = []

with open(os.path.join(VCR_ANNOTS_DIR, '{}.jsonl'.format(args.split)), 'r') as f:
    for line in f:
        item = json.loads(line)
        answer_labels.append(item['answer_label'])
        rationale_labels.append(item['rationale_label'])

answer_labels = np.array(answer_labels)
rationale_labels = np.array(rationale_labels)

# Sanity checks
assert answer_preds.shape[0] == answer_labels.size
assert rationale_preds.shape[0] == rationale_labels.size
assert answer_preds.shape[1] == 4
assert rationale_preds.shape[1] == 4

answer_hits = answer_preds.argmax(1) == answer_labels
rationale_hits = rationale_preds.argmax(1) == rationale_labels
joint_hits = answer_hits & rationale_hits
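# An example is counted correct for Q->AR only when both the answer and the
# rationale are right; with 4 choices each, chance joint accuracy is
# (1/4) * (1/4) = 6.25%.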

print("Answer acc: {:.3f}".format(np.mean(answer_hits)), flush=True)
print("Rationale acc: {:.3f}".format(np.mean(rationale_hits)), flush=True)
print("Joint acc: {:.3f}".format(np.mean(answer_hits & rationale_hits)), flush=True)
