From 2d6da90b0dd378846465e1ea967dab9314f996c7 Mon Sep 17 00:00:00 2001
From: lovemefan
Date: Tue, 26 Sep 2023 14:14:12 +0800
Subject: [PATCH] code format

---
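Notes (below the cut line, so git am drops them from the commit message):
besides the black-style reformatting, this patch also changes behavior:
campplus now pins its ONNX session to CPUExecutionProvider with a
kSameAsRequested arena strategy, register_speaker() in svInfer.py now
registers instead of calling recognize(), and recognize() now forwards the
threshold passed to the constructor. A minimal usage sketch of the svInfer
wrapper follows; the class name SpeakerVerificationInfer and the wav path
are illustrative assumptions, not taken from this patch:

    # hypothetical import; svInfer.py defines the wrapper patched in the
    # last hunk below
    from paraformerOnline.runtime.python.svInfer import SpeakerVerificationInfer

    # cam++ backend; after this patch the threshold is stored on the
    # wrapper and forwarded on every recognize() call
    sv = SpeakerVerificationInfer(model_name="cam++", threshold=0.5)

    # recognize() accepts str/Path/bytes; with an empty speaker memory the
    # first call registers the voice and returns speaker id 0
    speaker_id = sv.recognize("speaker1.wav")
    print("speaker id:", speaker_id)
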
 .../runtime/python/asr_all_in_one.py       | 31 +++++++++----------
 .../runtime/python/model/sv/campplus.py    | 13 +++++++-
 paraformerOnline/runtime/python/svInfer.py |  6 ++--
 3 files changed, 30 insertions(+), 20 deletions(-)

diff --git a/paraformerOnline/runtime/python/asr_all_in_one.py b/paraformerOnline/runtime/python/asr_all_in_one.py
index 71b71e2..db8896f 100644
--- a/paraformerOnline/runtime/python/asr_all_in_one.py
+++ b/paraformerOnline/runtime/python/asr_all_in_one.py
@@ -106,8 +106,7 @@ def extract_endpoint_from_vad_result(self, segments_result):
         return segments
 
     def one_sentence_asr(self, audio: np.ndarray):
-        """asr offline + punc
-        """
+        """asr offline + punc"""
         result = self.asr_offline.infer_offline(audio, hot_words=self.hot_words)
         result = self.punc.punctuate(result)[0]
         return result
@@ -126,14 +125,14 @@ def file_transcript(self, audio: np.ndarray, step=9600):
         speech_length = len(audio)
         sample_offset = 0
         for sample_offset in range(
-                0, speech_length, min(step, speech_length - sample_offset)
+            0, speech_length, min(step, speech_length - sample_offset)
         ):
             if sample_offset + step >= speech_length - 1:
                 step = speech_length - sample_offset
                 is_final = True
             else:
                 is_final = False
-            chunk = audio[sample_offset: sample_offset + step]
+            chunk = audio[sample_offset : sample_offset + step]
             vad_pre_idx += len(chunk)
             segments_result = self.vad.segments_online(chunk, is_final=is_final)
             start_frame = 0
@@ -147,10 +146,12 @@ def file_transcript(self, audio: np.ndarray, step=9600):
             if end != -1:
                 end_frame = end * 16
                 end_ms = end
-                data = np.array(audio[start_ms * 16: end_frame])
+                data = np.array(audio[start_ms * 16 : end_frame])
                 time_start = time.time()
                 asr_offline_final = self.asr_offline.infer_offline(data)
-                logger.debug(f"asr offline inference use {time.time() - time_start} s")
+                logger.debug(
+                    f"asr offline inference use {time.time() - time_start} s"
+                )
                 if self.speaker_verification:
                     time_start = time.time()
                     speaker_id = self.sv.recognize(data)
@@ -161,13 +162,12 @@ def file_transcript(self, audio: np.ndarray, step=9600):
                 self.speech_start = False
                 time_start = time.time()
                 _final = self.punc.punctuate(asr_offline_final)[0]
-                logger.debug(f"punc online inference use {time.time() - time_start} s")
+                logger.debug(
+                    f"punc online inference use {time.time() - time_start} s"
+                )
                 result["text"] = _final
 
-                result['time_stamp'] = {
-                    'start': start_ms,
-                    'end': end_ms
-                }
+                result["time_stamp"] = {"start": start_ms, "end": end_ms}
 
         if is_final:
             self.reset_asr()
@@ -215,7 +215,9 @@ def two_pass_asr(self, chunk: np.ndarray, is_final: bool = False, hot_words=None
             end = self.end_frame + len(self.frames) - self.vad_pre_idx
             data = np.array(self.frames[:end])
             self.frames = self.frames[end:]
-            asr_offline_final = self.asr_offline.infer_offline(data, hot_words=(hot_words or self.hot_words))
+            asr_offline_final = self.asr_offline.infer_offline(
+                data, hot_words=(hot_words or self.hot_words)
+            )
             logger.debug(f"asr offline inference use {time.time() - time_start} s")
             if self.speaker_verification:
                 time_start = time.time()
@@ -235,10 +237,7 @@ def two_pass_asr(self, chunk: np.ndarray, is_final: bool = False, hot_words=None
         if final is not None:
             result["final"] = final
             result["partial"] = ""
-            result['time_stamp'] = {
-                'start': time_stamp_start,
-                'end': time_stamp_end
-            }
+            result["time_stamp"] = {"start": time_stamp_start, "end": time_stamp_end}
             if self.speaker_verification:
                 result["speaker_id"] = speaker_id
             self.text_cache = ""
diff --git a/paraformerOnline/runtime/python/model/sv/campplus.py b/paraformerOnline/runtime/python/model/sv/campplus.py
index be31ae5..0c955db 100644
--- a/paraformerOnline/runtime/python/model/sv/campplus.py
+++ b/paraformerOnline/runtime/python/model/sv/campplus.py
@@ -25,7 +25,17 @@ def __init__(self, onnx_path=None, threshold=0.5):
             ),
             "onnx/sv/campplus.onnx",
         )
-        self.sess = onnxruntime.InferenceSession(self.onnx)
+        cpu_ep = "CPUExecutionProvider"
+        cpu_provider_options = {
+            "arena_extend_strategy": "kSameAsRequested",
+        }
+
+        self.sess = onnxruntime.InferenceSession(
+            self.onnx,
+            providers=[
+                (cpu_ep, cpu_provider_options),
+            ],
+        )
         self.output_name = [nd.name for nd in self.sess.get_outputs()]
         self.threshhold = threshold
         self.memory: np.ndarray = None
@@ -95,6 +105,7 @@ def recognize(self, waveform: Union[str, Path, bytes], threshold=0.65):
             self.memory = emb / np.linalg.norm(emb)
             return 0
         sim = self.compute_cos_similarity(emb)[0]
+        print(threshold, sim)
         max_sim_index = np.argmax(sim)
 
         if sim[max_sim_index] <= threshold:
diff --git a/paraformerOnline/runtime/python/svInfer.py b/paraformerOnline/runtime/python/svInfer.py
index ae7e23d..ab5ae00 100644
--- a/paraformerOnline/runtime/python/svInfer.py
+++ b/paraformerOnline/runtime/python/svInfer.py
@@ -26,11 +26,11 @@ def __init__(self, model_path=None, model_name="cam++", threshold=0.5):
         project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
         model_dir = os.path.join(project_dir, "onnx", "sv")
         model_path = model_path or os.path.join(model_dir, model_names[model_name][1])
-
+        self.threshold = threshold
         self.model = model_names[model_name][0](model_path, threshold)
 
     def register_speaker(self, emb: np.ndarray):
-        self.model.recognize(emb)
+        self.model.register_speaker(emb)
 
     def recognize(self, waveform: Union[str, Path, bytes]):
-        return self.model.recognize(waveform)
+        return self.model.recognize(waveform, self.threshold)