Added: Python offline Punc and VAD models; the VAD model is 16k. #17

Open · wants to merge 3 commits into main
20 changes: 17 additions & 3 deletions python/README.md
@@ -5,7 +5,9 @@
<a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Win%2C%20Mac-pink.svg"></a>
</p>

- The model comes from Alibaba DAMO Academy: [Paraformer speech recognition - Chinese - general - 16k - offline - large - pytorch](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
- The ASR model comes from Alibaba DAMO Academy: [Paraformer speech recognition - Chinese - general - 16k - offline - large - pytorch](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
- The VAD model, FSMN-VAD, comes from Alibaba DAMO Academy: [FSMN voice activity detection - Chinese - general - 16k](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary)
- The Punc model, CT-Transformer, comes from Alibaba DAMO Academy: [CT-Transformer punctuation - Chinese - general - pytorch](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary)
- 🎉 The core code of this project has been merged into [FunASR](https://github.com/alibaba-damo-academy/FunASR)
- This repository only converts the models and runs them with the ONNXRuntime inference engine

@@ -52,17 +54,18 @@
3. Run the demo
```python
from rapid_paraformer import RapidParaformer

config_path = 'resources/config.yaml'
paraformer = RapidParaformer(config_path)

# Input: Union[str, np.ndarray, List[str]] are all supported
# Output: List[asr_res]
wav_path = [
    'test_wavs/0478_00017.wav',
]

result = paraformer(wav_path)
print(result)
```
@@ -71,3 +74,14 @@
['呃说不配合就不配合的好以上的话呢我们摘取八九十三条因为这三条的话呢比较典型啊一些数字比较明确尤其是时间那么我们要投资者就是了解这一点啊不要轻信这个市场可以快速回来啊这些配市公司啊后期又利好了可以快速快速攻能包括像前一段时间啊有些媒体在二三月份的时候']
```

Changes:

1. Added VAD and Punc support

Most of the code in this update comes from [FunASR](https://github.com/alibaba-damo-academy/FunASR).

For model export, see [here](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export); just put the exported model.onnx into the corresponding folder (a quick placement check is sketched below).
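A minimal sanity check, sketched below. The folder names are assumptions based on the config files in this PR (`punc_model/` comes from punc.yaml; the VAD folder name may differ in your checkout):

```python
# Sketch only: confirm the exported ONNX models were copied into the expected folders.
# Folder names are assumptions based on this PR's layout, not a fixed convention.
from pathlib import Path

candidates = [
    Path("python/rapid_paraformer/punc_model/model.onnx"),  # referenced by punc.yaml
    Path("python/rapid_paraformer/vad_model/model.onnx"),   # assumed VAD folder name
]
for path in candidates:
    print(f"{path}: {'found' if path.exists() else 'missing'}")
```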

demo.py shows how to combine them. For now the VAD results are not great, so I simply cut the audio into fixed 30-second chunks, recognize each chunk, and join the results; a minimal sketch of that flow follows.
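The sketch below mirrors demo.py: it assumes the no-argument constructors used there and a librosa installation; error handling and the VAD step are omitted:

```python
# Sketch only: fixed 30-second chunking + ASR + punctuation, mirroring demo.py in this PR.
import librosa

from rapid_paraformer import RapidParaformer
from rapid_paraformer.rapid_punc import PuncParaformer

paraformer = RapidParaformer()
punc = PuncParaformer()

y, sr = librosa.load("test_wavs/0478_00017.wav", sr=16000)
chunk = 16000 * 30  # 30 seconds at 16 kHz

texts = []
for i in range(0, len(y), chunk):
    res = paraformer(y[i:i + chunk])  # each chunk is recognized on its own
    texts.append(res[0][0] if res else "")

print(punc("".join(texts))[0])  # restore punctuation on the joined text
```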

179 changes: 163 additions & 16 deletions python/demo.py
@@ -1,23 +1,170 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
import threading
import time
from concurrent.futures import ThreadPoolExecutor

import librosa
import moviepy.editor as mp

from rapid_paraformer import RapidParaformer
from rapid_paraformer.rapid_punc import PuncParaformer
from rapid_paraformer.rapid_vad import RapidVad

vad_model = RapidVad()
paraformer = RapidParaformer()
punc = PuncParaformer()

# Decorator that reports a function's execution time
def timeit(func):
    def wrapper(*args, **kwargs):
        start = time.time()
        print(f"function name: {func.__name__}")
        result = func(*args, **kwargs)
        end = time.time() - start
        print(f"cost time: {end}")
        return result
    return wrapper

# Audio duration
@timeit
def get_audio_duration(wav_path):
    import wave
    with wave.open(wav_path, 'rb') as f:
        nchannels, sampwidth, framerate, nframes = f.getparams()[:4]
        duration = nframes / framerate
    # Convert to HH:MM:SS format
    m, s = divmod(duration, 60)
    h, m = divmod(m, 60)
    duration = "%02d:%02d:%02d" % (h, m, s)
    return duration

# Read audio and split it into fixed 30-second chunks
@timeit
def load_audio(wav_path):
    '''
    Load an audio file and split it into fixed 30-second chunks
    :param wav_path: path to the audio file
    :return: generator of 30-second waveform chunks
    '''
    # If the source is an mp4 it should be converted to wav first (see __main__ below)
    y, sr = librosa.load(wav_path, sr=16000)  # decode the audio into a 16 kHz waveform array
    # 16000 samples per second * 30 seconds per chunk
    wav_list = (y[i:i + 16000 * 30] for i in range(0, len(y), 16000 * 30))
    return wav_list

def split_string(text, length):
    """
    Split a string into chunks of the given length and return the chunks as a list
    """
    result = []
    start = 0
    while start < len(text):
        end = start + length
        result.append(text[start:end])
        start = end
    return result

def text_process(text):
    # Insert a line break after full stops and question marks
    text = text.replace('。', '。\n')
    text = text.replace('?', '?\n')
    return text



@timeit
def load_and_cut_audio(wav_path, num_threads=4, chunk_size=30):
    """
    Threaded alternative to load_audio: slice the waveform into chunk_size-second
    pieces with a small pool of threads (kept for reference; not used by default)
    """
    y, sr = librosa.load(wav_path, sr=16000)
    n_samples = len(y)
    chunk_size_samples = 16000 * chunk_size
    num_chunks = n_samples // chunk_size_samples + (n_samples % chunk_size_samples > 0)
    chunk_indices = [(i * chunk_size_samples, min(n_samples, (i + 1) * chunk_size_samples)) for i in range(num_chunks)]

    results = [None] * num_chunks

    def load_and_cut_thread(start_index, end_index, result_list, index):
        result_list[index] = y[start_index:end_index]

    threads = []
    for i, (start_index, end_index) in enumerate(chunk_indices):
        t = threading.Thread(target=load_and_cut_thread, args=(start_index, end_index, results, i))
        threads.append(t)
        t.start()

        # Join every num_threads threads so at most num_threads run at once
        if i % num_threads == num_threads - 1:
            for thread in threads[i - num_threads + 1:i + 1]:
                thread.join()

    # Join whatever remains in the last, possibly partial, batch
    for thread in threads[(num_chunks - 1) // num_threads * num_threads:]:
        thread.join()

    return results



@timeit
def vad(vad_model, wav_path):
    return vad_model(wav_path)

def asr_single(wav):
    # Recognize a single chunk; return an empty string if recognition fails
    try:
        result_text = paraformer(wav)[0][0]
    except Exception:
        result_text = ''
    return result_text


@timeit
def asr(wav_path):
    wav_list = load_audio(wav_path)
    # wav_list = load_and_cut_audio(wav_path)  # threaded alternative
    # Recognize the 30-second chunks concurrently; list() keeps the results ordered
    with ThreadPoolExecutor() as executor:
        result = list(executor.map(asr_single, wav_list))

    return result

if __name__ == '__main__':
    wave_path = r'C:\Users\ADMINI~1\AppData\Local\Temp\gradio\d5e738ea910657f76c96e6fbfb74f7de8c6fdb11\11.mp3'
    # wave_path = r'E:\10分钟.wav'
    if wave_path.endswith('.mp4'):
        wav_path = wave_path.replace('.mp4', '.wav')
        clip = mp.VideoFileClip(wave_path)
        clip.audio.write_audiofile(wav_path, fps=22050, bitrate='64k')  # write the clip's audio track to a wav file
        print('mp4 to wav conversion finished')
        print(wav_path)
        print(clip.duration)
        wave_path = wav_path  # recognize the converted wav from here on

    # Audio duration
    # duration = get_audio_duration(wave_path)
    # print(f"Audio duration: {duration}")
    # VAD
    # vad_result = vad(vad_model, wave_path)
    # print(f"VAD result: {vad_result}")

    # ASR
    asr_result = asr(wave_path)
    print('ASR finished')
    print(asr_result)

    # print(f"ASR result: {asr_result}")
    # Punctuation restoration
    new_text = punc(''.join(asr_result))
    processed_text = text_process(new_text[0])
    print(f"Punctuation result: {processed_text}")
    # Write the recognition result to a txt file named after the audio file
    # with open(f'{wave_path.replace(".mp4", "")}.txt', 'w') as f:
    #     f.write(processed_text)

config_path = 'resources/config.yaml'

paraformer = RapidParaformer(config_path)

wav_path = [
    'test_wavs/0478_00017.wav',
    'test_wavs/asr_example_zh.wav',
    'test_wavs/0478_00017.wav',
    'test_wavs/asr_example_zh.wav',
    'test_wavs/0478_00017.wav',
    'test_wavs/asr_example_zh.wav',
]

print(wav_path)
# wav_path = 'test_wavs/0478_00017.wav'
result = paraformer(wav_path)
print(result)
19 changes: 19 additions & 0 deletions python/rapid_paraformer/punc_model/punc.yaml
@@ -0,0 +1,19 @@
punc_list:
- <unk>
- _
- ','
- 。
- '?'
- 、

TokenIDConverter:
  token_path: punc_model/punc_token_list.pkl
  unk_symbol: <unk>
Model:
  model_path: punc_model/model.onnx
  use_cuda: false
  CUDAExecutionProvider:
    device_id: 0
    arena_extend_strategy: kNextPowerOfTwo
    cudnn_conv_algo_search: EXHAUSTIVE
    do_copy_in_default_stream: true
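For reference, a minimal sketch of reading a config in this shape with PyYAML; how PuncParaformer actually consumes it is not part of this diff, so the field names below are simply taken from the file above:

```python
# Sketch only: load punc.yaml and inspect the fields added in this PR (assumes PyYAML is installed).
import yaml

with open("python/rapid_paraformer/punc_model/punc.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

print(cfg["punc_list"])            # ['<unk>', '_', ',', '。', '?', '、']
print(cfg["Model"]["model_path"])  # punc_model/model.onnx
print(cfg["Model"]["use_cuda"])    # False
```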