
Commit ed7d3b1

bump version (#23)
1 parent c638ce4 commit ed7d3b1

File tree

3 files changed: +110 -2 lines changed


oat/__about__.py

Lines changed: 1 addition & 1 deletion

@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Version."""
-__version__ = "0.0.4"
+__version__ = "0.0.5"

oat/interface.py

Lines changed: 0 additions & 1 deletion

@@ -62,7 +62,6 @@ def get_program(
         "gpu_memory_utilization": args.vllm_gpu_ratio,
         "dtype": "bfloat16",
         "enable_prefix_caching": False,
-        "max_model_len": args.max_model_len,
     }

     actors = []
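
Dropping the "max_model_len" entry means the engine no longer receives an explicit context-length cap. Below is a minimal sketch (not part of this commit; the vllm_args name and the concrete values are illustrative) of how kwargs like these are typically passed to vLLM's LLM constructor, and of the likely effect of the removal: with max_model_len omitted, vLLM derives the context length from the model's own config rather than from args.max_model_len.

from vllm import LLM

# Illustrative stand-in for the engine kwargs assembled in get_program();
# the values here are placeholders, not values from the oat codebase.
vllm_args = {
    "gpu_memory_utilization": 0.5,   # stands in for args.vllm_gpu_ratio
    "dtype": "bfloat16",
    "enable_prefix_caching": False,
    # "max_model_len" omitted: vLLM falls back to the length declared by the
    # model config (e.g. max_position_embeddings) instead of a manual cap.
}

llm = LLM(model="meta-llama/Llama-3.2-1B", **vllm_args)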

scripts/inference.py

Lines changed: 109 additions & 0 deletions

@@ -0,0 +1,109 @@ (new file)
# Copyright 2024 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Model inference for using vllm."""

import argparse
import json
import os
import time

from datasets import load_dataset
from vllm import LLM, SamplingParams

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model",
    type=str,
    default="meta-llama/Llama-3.2-1B",
    help="Path to the LLM model",
)
parser.add_argument(
    "--temperature", type=float, default=0.9, help="Temperature for sampling"
)
parser.add_argument(
    "--top_p", type=float, default=1, help="Top-p probability for sampling"
)
parser.add_argument(
    "--max_tokens", type=int, default=2048, help="Maximum number of tokens to generate"
)
parser.add_argument(
    "--output_dir", type=str, default="inference_outputs", help="output_dir"
)
args = parser.parse_args()
args.seed = int(time.time_ns() // 2 * 20)  # Less bias to a fixed random seed.

print(args)

llm = LLM(model=args.model, dtype="bfloat16")


tokenizer = llm.get_tokenizer()

eval_set = load_dataset("lkevinzc/alpaca_eval2")["eval"]

prompts = eval_set["instruction"]

conversations = [
    tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    for prompt in prompts
]

sampling_params = SamplingParams(
    temperature=args.temperature,
    top_p=args.top_p,
    max_tokens=args.max_tokens,
    seed=args.seed,
)

if tokenizer.bos_token:
    # lstrip bos_token because vllm will add it.
    print(conversations[0].startswith(tokenizer.bos_token))
    conversations = [p.lstrip(tokenizer.bos_token) for p in conversations]

outputs = llm.generate(conversations[:1], sampling_params)

if tokenizer.bos_token:
    # make sure vllm added bos_token.
    assert tokenizer.bos_token_id in outputs[0].prompt_token_ids

outputs = llm.generate(conversations, sampling_params)

# Save the outputs as a JSON file.
output_data = []
model_name = args.model.replace("/", "_")
for i, output in enumerate(outputs):
    prompt = output.prompt
    generated_text = output.outputs[0].text
    output_data.append(
        {
            "instruction": prompts[i],
            "format_instruction": prompt,
            "output": generated_text,
            "generator": model_name,
        }
    )

output_file = f"{model_name}_{args.seed}.json"
if not os.path.exists(args.output_dir):
    os.makedirs(args.output_dir)

with open(os.path.join(args.output_dir, output_file), "w") as f:
    json.dump(output_data, f, indent=4)

print(f"Outputs saved to {os.path.join(args.output_dir, output_file)}")

0 commit comments