-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patheval_csqa.py
132 lines (100 loc) · 4.75 KB
/
eval_csqa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#!/usr/bin/env python3
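# Usage (the file names inside each directory are fixed by main()):
#   python eval_csqa.py --ground_truth_labels_dir datasets/csqa \
#       --predicted_labels_dir outputs/csqa --output_dir outputs/csqa
# This reads <ground_truth_labels_dir>/dev_rand_split.jsonl and
# <predicted_labels_dir>/dev.csv, and writes the accuracy to
# <output_dir>/metrics_output.txt.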
import csv
from typing import *
import logging
import os
import sys
import json
import argparse
# Distinct non-zero process exit codes, one per failure mode reported below.
(
    EXIT_STATUS_ANSWERS_MALFORMED,      # ground-truth JSONL file is invalid
    EXIT_STATUS_PREDICTIONS_MALFORMED,  # predictions CSV file is invalid
    EXIT_STATUS_PREDICTIONS_EXTRA,      # predictions for unknown question ids
    EXIT_STATUS_PREDICTION_MISSING,     # some question has no prediction
) = (1, 2, 3, 4)
def calculate_accuracy(question_answers: Dict[str, str], predictions: Dict[str, List[str]]) -> float:
    """Return the mean per-question score of ``predictions`` against ``question_answers``.

    A question scores ``1 / k`` when its gold answer appears among its ``k``
    predicted labels and 0 otherwise, so spreading a prediction over several
    labels earns proportionally less credit.

    Args:
        question_answers: Mapping of question id to its gold answer label.
        predictions: Mapping of question id to its list of predicted labels.
            The mapping is not modified.

    Returns:
        The summed per-question scores divided by the number of questions.

    Exits the process via ``sys.exit`` with:
        ``EXIT_STATUS_ANSWERS_MALFORMED`` if ``question_answers`` is empty,
        ``EXIT_STATUS_PREDICTION_MISSING`` if a question has no prediction,
        ``EXIT_STATUS_PREDICTIONS_EXTRA`` if ``predictions`` contains ids that
        are not in ``question_answers``.
    """
    if not question_answers:
        # Guard the division at the end; read_answers already rejects files
        # with no answers, but this function may be called directly.
        logging.error("No answers to score predictions against.")
        sys.exit(EXIT_STATUS_ANSWERS_MALFORMED)

    score = 0.0
    for question_id, answer in question_answers.items():
        if question_id not in predictions:
            logging.error("Missing prediction for question '%s'.", question_id)
            sys.exit(EXIT_STATUS_PREDICTION_MISSING)
        predictions_for_q = predictions[question_id]
        if answer in predictions_for_q:
            # Non-empty here because it contains the answer, so no zero division.
            score += 1.0 / len(predictions_for_q)

    # Predicted ids with no gold answer, in the mapping's insertion order, so
    # the example ids in the log follow the order they were supplied in.
    extra_ids = [qid for qid in predictions if qid not in question_answers]
    if extra_ids:
        logging.error("Found %d extra predictions, for example: %s", len(extra_ids),
                      ", ".join(extra_ids[:3]))
        sys.exit(EXIT_STATUS_PREDICTIONS_EXTRA)
    return score / len(question_answers)
def read_answers(filename: str) -> Dict[str, str]:
    """Load gold answers from a JSON Lines file.

    Each non-blank line must be a JSON object with an ``"id"`` field and an
    ``"answerKey"`` field. Blank lines are skipped.

    Args:
        filename: Path to the ground-truth ``.jsonl`` file.

    Returns:
        Mapping of question id to answer key, in file order.

    Exits the process with ``EXIT_STATUS_ANSWERS_MALFORMED`` when:
        a line is not valid JSON;
        a line lacks a required field or is not a JSON object;
        an id repeats;
        the file contains no answers.
    """
    answers = {}
    with open(filename, "rt", encoding="UTF-8", errors="replace") as f:
        for line_num, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # reporting them as malformed JSON.
                continue
            try:
                record = json.loads(line)
            except ValueError as e:
                logging.error("Error while reading file %s: %s", filename, e)
                sys.exit(EXIT_STATUS_ANSWERS_MALFORMED)
            try:
                question_id = record["id"]
                answer = record["answerKey"]
            except (KeyError, TypeError) as e:
                # KeyError: a field is missing. TypeError: the line is valid
                # JSON but not an object (e.g. a list or a bare string).
                logging.error("Record on line %d of %s must be a JSON object with 'id' and "
                              "'answerKey' fields (%r)", line_num, filename, e)
                sys.exit(EXIT_STATUS_ANSWERS_MALFORMED)
            if question_id in answers:
                logging.error("Key %s repeated in %s", question_id, filename)
                sys.exit(EXIT_STATUS_ANSWERS_MALFORMED)
            answers[question_id] = answer
    if not answers:
        logging.error("No answers found in file %s", filename)
        sys.exit(EXIT_STATUS_ANSWERS_MALFORMED)
    return answers
def read_predictions(filename: str) -> Dict[str, List[str]]:
    """Load predictions from a CSV file of ``question_id,labels`` rows.

    ``labels`` holds one or more answer labels separated by ``;``. Columns
    after the second are ignored, and completely empty rows (blank lines)
    are skipped.

    Args:
        filename: Path to the predictions CSV file.

    Returns:
        Mapping of question id to its list of predicted labels, in file order.

    Exits the process with ``EXIT_STATUS_PREDICTIONS_MALFORMED`` when:
        a row has fewer than two columns;
        an id is empty or repeated;
        a label is empty;
        the CSV cannot be parsed.
    """
    predictions = {}
    # The csv module expects files opened with newline="" so it can handle
    # line endings itself, including newlines inside quoted fields.
    with open(filename, "rt", encoding="UTF-8", errors="replace", newline="") as f:
        reader = csv.reader(f)
        try:
            for row in reader:
                if not row:
                    # csv.reader yields [] for a blank line; skip it rather
                    # than reporting a missing column.
                    continue
                try:
                    question_id = row[0]
                    prediction_raw = row[1]
                except IndexError as e:
                    logging.error("Error reading value from CSV file %s on line %d: %s",
                                  filename, reader.line_num, e)
                    sys.exit(EXIT_STATUS_PREDICTIONS_MALFORMED)
                if question_id in predictions:
                    logging.error("Key %s repeated in file %s on line %d",
                                  question_id, filename, reader.line_num)
                    sys.exit(EXIT_STATUS_PREDICTIONS_MALFORMED)
                if question_id == "":
                    logging.error("Key is empty in file %s on line %d", filename, reader.line_num)
                    sys.exit(EXIT_STATUS_PREDICTIONS_MALFORMED)
                prediction = prediction_raw.split(";")
                # Labels cannot be empty strings. This rejects an empty second
                # column as well as stray separators such as "A;" or "A;;B".
                if "" in prediction:
                    logging.error("Key %s has empty labels for prediction in file %s on line %d",
                                  question_id, filename, reader.line_num)
                    sys.exit(EXIT_STATUS_PREDICTIONS_MALFORMED)
                predictions[question_id] = prediction
        except csv.Error as e:
            logging.error('file %s, line %d: %s', filename, reader.line_num, e)
            sys.exit(EXIT_STATUS_PREDICTIONS_MALFORMED)
    return predictions
def main():
    """Command-line entry point: score dev-set predictions and record the accuracy.

    Reads ``dev_rand_split.jsonl`` from ``--ground_truth_labels_dir`` and
    ``dev.csv`` from ``--predicted_labels_dir``. Prints the accuracy line and
    writes the same line to ``metrics_output.txt`` in ``--output_dir``,
    creating that directory if needed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ground_truth_labels_dir", type=str, default="datasets/csqa")
    parser.add_argument("--predicted_labels_dir", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    # parse_known_args tolerates extra arguments (presumably deliberate, so it
    # is kept), but report them instead of dropping them silently.
    args, unknown_args = parser.parse_known_args()
    if unknown_args:
        logging.warning("Ignoring unrecognized arguments: %s", " ".join(unknown_args))

    # Create the output folder if it doesn't exist. exist_ok avoids the
    # check-then-create race. An empty value means the current directory,
    # which os.makedirs would reject.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    ground_truth_labels_file = os.path.join(args.ground_truth_labels_dir, "dev_rand_split.jsonl")
    predicted_labels_file = os.path.join(args.predicted_labels_dir, "dev.csv")
    output_file = os.path.join(args.output_dir, "metrics_output.txt")

    question_answers = read_answers(ground_truth_labels_file)
    predictions = read_predictions(predicted_labels_file)
    result_out = "Accuracy score = " + str(calculate_accuracy(question_answers, predictions)) + "\n"
    print(result_out)
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(result_out)


if __name__ == '__main__':
    main()