forked from Deelvin/lm-evaluation-harness
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
105 lines (87 loc) · 3.43 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import argparse
import json
import logging
import os
from lm_eval import tasks, evaluator, utils
logging.getLogger("openai").setLevel(logging.WARNING)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True)
parser.add_argument("--model_args", default="")
parser.add_argument(
"--tasks", default=None, choices=utils.MultiChoice(tasks.ALL_TASKS)
)
parser.add_argument("--provide_description", action="store_true")
parser.add_argument("--num_fewshot", type=int, default=0)
parser.add_argument("--batch_size", type=str, default=None)
parser.add_argument(
"--max_batch_size",
type=int,
default=None,
help="Maximal batch size to try with --batch_size auto",
)
parser.add_argument("--device", type=str, default=None)
parser.add_argument("--output_path", default=None)
parser.add_argument(
"--limit",
type=float,
default=None,
help="Limit the number of examples per task. "
"If <1, limit is a percentage of the total number of examples.",
)
parser.add_argument("--data_sampling", type=float, default=None)
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--decontamination_ngrams_path", default=None)
parser.add_argument("--description_dict_path", default=None)
parser.add_argument("--check_integrity", action="store_true")
parser.add_argument("--write_out", action="store_true", default=False)
parser.add_argument("--output_base_path", type=str, default=None)
return parser.parse_args()
def main():
args = parse_args()
assert not args.provide_description # not implemented
if args.limit:
print(
"WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
if args.tasks is None:
task_names = tasks.ALL_TASKS
else:
task_names = utils.pattern_match(args.tasks.split(","), tasks.ALL_TASKS)
print(f"Selected Tasks: {task_names}")
description_dict = {}
if args.description_dict_path:
with open(args.description_dict_path, "r") as f:
description_dict = json.load(f)
results = evaluator.simple_evaluate(
model=args.model,
model_args=args.model_args,
tasks=task_names,
num_fewshot=args.num_fewshot,
batch_size=args.batch_size,
max_batch_size=args.max_batch_size,
device=args.device,
no_cache=args.no_cache,
limit=args.limit,
description_dict=description_dict,
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
write_out=args.write_out,
output_base_path=args.output_base_path,
)
dumped = json.dumps(results, indent=2)
print(dumped)
if args.output_path:
dirname = os.path.dirname(args.output_path)
if dirname:
os.makedirs(dirname, exist_ok=True)
with open(args.output_path, "w") as f:
f.write(dumped)
batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
print(
f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
)
print(evaluator.make_table(results))
if __name__ == "__main__":
main()