Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/allenai/WildBench into main
Browse files Browse the repository at this point in the history
  • Loading branch information
yuchenlin committed Jun 13, 2024
2 parents 2dbd6fb + a2612f4 commit 1afa425
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 43 deletions.
25 changes: 25 additions & 0 deletions scripts/yi-large.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# export YI_API_KEY=your_yi_api_key
#
# Run WildBench v2 inference with yi-large through the Yi API, split into
# parallel background shards, then merge the shard outputs into one file.
model_name="yi/yi-large"
model_pretty_name="yi-large"
output_dir="result_dirs/wild_bench_v2/"
TEMP=0; TOP_P=1.0; MAX_TOKENS=4096;

# The benchmark has 1024 examples, so shard_size must equal 1024 / n_shards.
n_shards=8
shard_size=128
shards_dir="${output_dir}/tmp_${model_pretty_name}"

# One background API job per shard over the half-open range [start, end).
# NOTE: the original script also tracked a gpu index, but the yi engine is a
# remote API call and no GPU is ever assigned, so that counter was dead code.
for ((i = 0; i < n_shards; i++)); do
    start=$((i * shard_size))
    end=$((start + shard_size))
    python src/unified_infer.py \
        --data_name wild_bench \
        --start_index $start --end_index $end \
        --engine yi \
        --model_name $model_name \
        --top_p $TOP_P --temperature $TEMP \
        --max_tokens $MAX_TOKENS \
        --output_folder $shards_dir/ \
        &
done
wait  # block until every shard job has finished before merging
python src/merge_results.py $shards_dir/ $model_pretty_name
cp $shards_dir/${model_pretty_name}.json $output_dir/${model_pretty_name}.json
32 changes: 31 additions & 1 deletion src/unified_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import os
from unified_utils import load_eval_data, save_outputs
from global_configs import HF_TEMPLATED_MODELS, IM_END_MODELS
from unified_utils import openai_chat_request, retry_handler, google_chat_request, cohere_chat_request, mistral_chat_request, anthropic_chat_request, together_chat_request, reka_chat_request
from unified_utils import openai_chat_request, retry_handler, google_chat_request, cohere_chat_request, mistral_chat_request, anthropic_chat_request, together_chat_request, reka_chat_request, yi_chat_request
from hf_models import DecoderOnlyModelManager
from transformers import AutoTokenizer

Expand Down Expand Up @@ -83,6 +83,8 @@ def sanitize_args(args):
pass
elif args.engine == "reka":
pass
elif args.engine == "yi":
pass

print("loading dataset!")

Expand Down Expand Up @@ -387,3 +389,31 @@ def api(**kwargs):
outputs.append(result)
save_outputs(args, id_strs, outputs, chat_history, metadata, model_inputs, filepath)

# Yi engine branch: mirrors the other API-engine branches in this chain —
# wrap the request in a retry handler, convert each flat chat history into
# Yi's OpenAI-style messages format, and generate one output per remaining
# input, then persist everything at the end.
elif args.engine == "yi":
    # Resume support: num_skipped entries already have saved outputs,
    # so only the tail of the chat history still needs generation.
    todo_chats = chat_history[num_skipped:]
    @retry_handler(retry_limit=10)
    def api(**kwargs):
        # Thin wrapper so the retry decorator can re-invoke the request
        # (up to 10 times) on transient API failures.
        result = yi_chat_request(**kwargs)
        return result

    for cur_id in tqdm(range(0, len(todo_inputs)), desc=f"Generating {args.model_name} from {args.start_index} to {args.end_index}"):
        # input_text = todo_inputs[cur_id]
        chat = todo_chats[cur_id]
        # The flat chat history alternates turns: even index = user,
        # odd index = assistant. Prepend a fixed system prompt.
        yi_msg = [{"role":"system", "content":"You are a helpful AI assistant."}]
        for i, chat_item in enumerate(chat):
            if i % 2 == 0:
                yi_msg.append({"role":"user","content": chat_item})
            else:
                yi_msg.append({"role":"assistant","content": chat_item})
        yi_args = {
            # NOTE(review): the API model id comes from model_pretty_name
            # (e.g. "yi-large"), not model_name ("yi/yi-large") — confirm the
            # launcher script passes a matching --model_pretty_name.
            "model": args.model_pretty_name,
            # Chat mode: content travels in "messages", so no raw prompt.
            "prompt": None,
            "messages": yi_msg,
            "top_p": args.top_p,
            "temperature": args.temperature,
            "max_tokens": args.max_tokens,
            "stop": stop_words,
        }
        result = api(**yi_args)
        outputs.append(result)
    # Write all accumulated outputs for this shard in one pass.
    save_outputs(args, id_strs, outputs, chat_history, metadata, model_inputs, filepath)
Loading

0 comments on commit 1afa425

Please sign in to comment.