import argparse
import json
import logging

import numpy as np
import onnxruntime
import torch
import torchaudio
import whisper
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm


def single_job(utt):
    audio, sample_rate = torchaudio.load(utt2wav[utt])
    if sample_rate != 16000:
        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
    if audio.shape[1] / 16000 > 30:
        # logging.warning('do not support extract speech token for audio longer than 30s, using sliding window')
        # Split the waveform into 30s segments and tokenize each one.
        time_step = 0
        audios = []
        while time_step < audio.shape[1] / 16000:
            audio_segment = audio[:, time_step * 16000: (time_step + 30) * 16000]
            audios.append(audio_segment)
            time_step += 30
        speech_token = []
        for a in audios:
            feat = whisper.log_mel_spectrogram(a, n_mels=128)
            tokens = ort_session.run(None, {
                ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
                ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)
            })[0].flatten().tolist()
            speech_token.extend(tokens)
    else:
        feat = whisper.log_mel_spectrogram(audio, n_mels=128)
        speech_token = ort_session.run(None, {
            ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
            ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)
        })[0].flatten().tolist()
    # main() unpacks (utt, speech_token) from each future, so both must be returned.
    return utt, speech_token


def main(args):
    all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()]
    utt2speech_token = {}
    for future in tqdm(as_completed(all_task)):
        utt, speech_token = future.result()
        utt2speech_token[utt] = ' '.join(str(i) for i in speech_token)
    torch.save(utt2speech_token, '/mnt/workspace/haonan/code/llama_audio_character/cosyvoice/examples/libritts/cosyvoice/pretrain_tokenizer/{}.pt'.format(args.dir))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", type=str)
    parser.add_argument("--onnx_path", type=str, default="/mnt/workspace/haonan/code/llama_audio_character/cosyvoice/pretrained_models/CosyVoice2-0.5B/speech_tokenizer_v2.onnx")
    parser.add_argument("--num_thread", type=int, default=8)
    args = parser.parse_args()

    utt2wav = {}
    # with open('{}/wav.scp'.format(args.dir)) as f:
    #     for l in f:
    #         l = l.replace('\n', '').split()
    #         utt2wav[l[0]] = ' '.join(i for i in l[1:])
    data_path = "/mnt/workspace/lr/workspace/LLaVA_Her/emotion/llava_pretrain_speech_tokenizer_0.json"
    args.dir = "llava_pretrain_speech_tokenizer_0"
    print(args.dir)
    a = json.load(open(data_path, "r"))
    for i in a:
        utt2wav[i['data_idx']] = i['audio']

    option = onnxruntime.SessionOptions()
    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    option.intra_op_num_threads = 1
    providers = ["CUDAExecutionProvider"]
    ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
    executor = ThreadPoolExecutor(max_workers=args.num_thread)
    main(args)
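For anyone reproducing this: the loop over the JSON file above implies it is a flat list of records with 'data_idx' and 'audio' keys, where 'audio' is a wav path torchaudio can load. A minimal sketch of producing such a file (the output path and utterance IDs here are made up for illustration):

import json

records = [
    {"data_idx": "utt_000001", "audio": "/data/wavs/utt_000001.wav"},
    {"data_idx": "utt_000002", "audio": "/data/wavs/utt_000002.wav"},
]
with open("llava_pretrain_speech_tokenizer_0.json", "w") as f:
    json.dump(records, f, indent=2)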
CUDA inference with speech_tokenizer_v1.onnx is fast, but inference with v2 is very sluggish; I'm not sure what's going on.
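One thing worth ruling out before blaming the v2 model itself: onnxruntime silently falls back to CPUExecutionProvider when the CUDA provider fails to load (e.g. a CPU-only build or a CUDA/cuDNN mismatch), which looks exactly like "v2 is very slow". A quick diagnostic sketch, not a fix confirmed by the maintainers (the model path is shortened here for illustration):

import onnxruntime

# If CUDAExecutionProvider is missing from this list, the onnxruntime-gpu
# install or the CUDA/cuDNN setup is the problem, not the model.
print(onnxruntime.get_available_providers())

session = onnxruntime.InferenceSession(
    "speech_tokenizer_v2.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
# Shows which providers the session actually ended up with after any silent fallback.
print(session.get_providers())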