Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for qianfan api in gen_dataset.py #28

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion configs/api_cfg.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
ali_qwen_api_key: {your_ali_qwen_api_key}
baidu_ernie_api_key: {your_baidu_ernie_api_key}
kimi_api_key: {kimi_api_key}
kimi_api_key: {kimi_api_key}
# For the Qianfan API, fill in the two entries below
baidu_qianfan_api_key: {your_baidu_qianfan_api_key}
baidu_qianfan_secret_key: {your_baidu_qianfan_secret_key}
24 changes: 20 additions & 4 deletions dataset/gen_dataset/gen_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@
import yaml
from tqdm import tqdm

def get_access_token(api_key, secret_key):
    """Exchange a Baidu Qianfan AK/SK pair for an OAuth access token.

    Args:
        api_key (str): Qianfan API key (AK).
        secret_key (str): Qianfan secret key (SK).

    Returns:
        str: the access token issued by the Baidu auth endpoint.

    Raises:
        RuntimeError: if the endpoint response contains no access token
            (e.g. wrong AK/SK), instead of silently returning the
            string "None" as the original code did.
    """
    url = "https://aip.baidubce.com/oauth/2.0/token"
    params = {
        "grant_type": "client_credentials",
        "client_id": api_key,
        "client_secret": secret_key,
    }
    # Explicit timeout so a hung auth endpoint cannot block dataset
    # generation forever (requests has no default timeout).
    response = requests.post(url, params=params, timeout=30)
    token = response.json().get("access_token")
    if token is None:
        raise RuntimeError(f"Failed to obtain Qianfan access token: {response.text}")
    return str(token)

def set_api_key(api_type, api_yaml_path):
    """Load API credentials for the selected model backend from a yaml file.

    Args:
        api_type (str): backend name, one of "qwen", "ernie" or "qianfan".
        api_yaml_path (str): path to the api config yaml
            (see configs/api_cfg.yaml).

    Returns:
        tuple: ``(api_key, secret_key)``. ``secret_key`` is only populated
        for the "qianfan" backend (AK/SK auth); it is ``None`` otherwise.

    Raises:
        ValueError: if ``api_type`` is not one of the supported backends.
        KeyError: if the expected key is missing from the yaml file.
    """
    # Read the yaml config
    with open(api_yaml_path, "r", encoding="utf-8") as f:
        api_yaml = yaml.safe_load(f)

    secret_key = None

    # Pick the credentials for the requested backend
    if api_type == "qwen":
        api_key = api_yaml["ali_qwen_api_key"]
        dashscope.api_key = api_key
    elif api_type == "ernie":
        api_key = api_yaml["baidu_ernie_api_key"]
    elif api_type == "qianfan":
        api_key = api_yaml["baidu_qianfan_api_key"]
        secret_key = api_yaml["baidu_qianfan_secret_key"]
    else:
        # Keep this message in sync with the supported backends above;
        # the original still said "qwen or ernie" after qianfan was added.
        raise ValueError("api_type must be qwen, ernie or qianfan")

    return api_key, secret_key


def call_qwen_message(content_str, model_type=dashscope.Generation.Models.qwen_turbo):
Expand Down Expand Up @@ -212,7 +224,7 @@ def gen_dataset(dastset_yaml_path: str, api_yaml_path: str, save_json_root: Path
), f"{specific_name} not in dataset_yaml['role_type'] ({dataset_yaml['role_type']}), pls check dataset yaml!"

# 设置 api key
api_key = set_api_key(model_name, api_yaml_path)
api_key,secret_key = set_api_key(model_name, api_yaml_path)

data_gen_setting = dataset_yaml["data_generation_setting"]
gen_num = data_gen_setting["each_product_gen"]
Expand Down Expand Up @@ -312,6 +324,10 @@ def gen_dataset(dastset_yaml_path: str, api_yaml_path: str, save_json_root: Path
format_json = process_request(call_qwen_message, content_str, qwen_model_type[idx], model_name)
elif model_name == "ernie":
format_json = process_request(call_ernie_message, content_str, api_key, model_name)
elif model_name == "qianfan":
api_key = get_access_token(api_key,secret_key)
model_name = "ernie"
format_json = process_request(call_ernie_message, content_str, api_key, model_name)
else:
raise ValueError(f"model_name {model_name} not support")

Expand Down Expand Up @@ -359,7 +375,7 @@ def gen_dataset(dastset_yaml_path: str, api_yaml_path: str, save_json_root: Path

# 命令行输入参数
parser = argparse.ArgumentParser(description="Gen Dataset")
parser.add_argument("model_name", type=str, choices=["qwen", "ernie"], help="Model name for data generation")
parser.add_argument("model_name", type=str, choices=["qwen", "ernie", "qianfan"], help="Model name for data generation")
parser.add_argument("--data_yaml", type=str, default="../../configs/conversation_cfg.yaml", help="data setting file path")
parser.add_argument("--api_yaml", type=str, default="../../configs/api_cfg.yaml", help="api setting file path")
parser.add_argument("--output_dir", type=str, default="./train_dataset/response", help="generation json output dir")
Expand Down