Skip to content

Commit

Permalink
code format
Browse files Browse the repository at this point in the history
  • Loading branch information
lovemefan committed Sep 26, 2023
1 parent 2667794 commit 2d6da90
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 20 deletions.
31 changes: 15 additions & 16 deletions paraformerOnline/runtime/python/asr_all_in_one.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ def extract_endpoint_from_vad_result(self, segments_result):
return segments

def one_sentence_asr(self, audio: np.ndarray):
"""asr offline + punc
"""
"""asr offline + punc"""
result = self.asr_offline.infer_offline(audio, hot_words=self.hot_words)
result = self.punc.punctuate(result)[0]
return result
Expand All @@ -126,14 +125,14 @@ def file_transcript(self, audio: np.ndarray, step=9600):
speech_length = len(audio)
sample_offset = 0
for sample_offset in range(
0, speech_length, min(step, speech_length - sample_offset)
0, speech_length, min(step, speech_length - sample_offset)
):
if sample_offset + step >= speech_length - 1:
step = speech_length - sample_offset
is_final = True
else:
is_final = False
chunk = audio[sample_offset: sample_offset + step]
chunk = audio[sample_offset : sample_offset + step]
vad_pre_idx += len(chunk)
segments_result = self.vad.segments_online(chunk, is_final=is_final)
start_frame = 0
Expand All @@ -147,10 +146,12 @@ def file_transcript(self, audio: np.ndarray, step=9600):
if end != -1:
end_frame = end * 16
end_ms = end
data = np.array(audio[start_ms * 16: end_frame])
data = np.array(audio[start_ms * 16 : end_frame])
time_start = time.time()
asr_offline_final = self.asr_offline.infer_offline(data)
logger.debug(f"asr offline inference use {time.time() - time_start} s")
logger.debug(
f"asr offline inference use {time.time() - time_start} s"
)
if self.speaker_verification:
time_start = time.time()
speaker_id = self.sv.recognize(data)
Expand All @@ -161,13 +162,12 @@ def file_transcript(self, audio: np.ndarray, step=9600):
self.speech_start = False
time_start = time.time()
_final = self.punc.punctuate(asr_offline_final)[0]
logger.debug(f"punc online inference use {time.time() - time_start} s")
logger.debug(
f"punc online inference use {time.time() - time_start} s"
)

result["text"] = _final
result['time_stamp'] = {
'start': start_ms,
'end': end_ms
}
result["time_stamp"] = {"start": start_ms, "end": end_ms}

if is_final:
self.reset_asr()
Expand Down Expand Up @@ -215,7 +215,9 @@ def two_pass_asr(self, chunk: np.ndarray, is_final: bool = False, hot_words=None
end = self.end_frame + len(self.frames) - self.vad_pre_idx
data = np.array(self.frames[:end])
self.frames = self.frames[end:]
asr_offline_final = self.asr_offline.infer_offline(data, hot_words=(hot_words or self.hot_words))
asr_offline_final = self.asr_offline.infer_offline(
data, hot_words=(hot_words or self.hot_words)
)
logger.debug(f"asr offline inference use {time.time() - time_start} s")
if self.speaker_verification:
time_start = time.time()
Expand All @@ -235,10 +237,7 @@ def two_pass_asr(self, chunk: np.ndarray, is_final: bool = False, hot_words=None
if final is not None:
result["final"] = final
result["partial"] = ""
result['time_stamp'] = {
'start': time_stamp_start,
'end': time_stamp_end
}
result["time_stamp"] = {"start": time_stamp_start, "end": time_stamp_end}
if self.speaker_verification:
result["speaker_id"] = speaker_id
self.text_cache = ""
Expand Down
13 changes: 12 additions & 1 deletion paraformerOnline/runtime/python/model/sv/campplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,17 @@ def __init__(self, onnx_path=None, threshold=0.5):
),
"onnx/sv/campplus.onnx",
)
self.sess = onnxruntime.InferenceSession(self.onnx)
cpu_ep = "CPUExecutionProvider"
cpu_provider_options = {
"arena_extend_strategy": "kSameAsRequested",
}

self.sess = onnxruntime.InferenceSession(
self.onnx,
providers=[
(cpu_ep, cpu_provider_options),
],
)
self.output_name = [nd.name for nd in self.sess.get_outputs()]
self.threshhold = threshold
self.memory: np.ndarray = None
Expand Down Expand Up @@ -95,6 +105,7 @@ def recognize(self, waveform: Union[str, Path, bytes], threshold=0.65):
self.memory = emb / np.linalg.norm(emb)
return 0
sim = self.compute_cos_similarity(emb)[0]
print(threshold, sim)
max_sim_index = np.argmax(sim)

if sim[max_sim_index] <= threshold:
Expand Down
6 changes: 3 additions & 3 deletions paraformerOnline/runtime/python/svInfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ def __init__(self, model_path=None, model_name="cam++", threshold=0.5):
project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
model_dir = os.path.join(project_dir, "onnx", "sv")
model_path = model_path or os.path.join(model_dir, model_names[model_name][1])

self.threshold = threshold
self.model = model_names[model_name][0](model_path, threshold)

def register_speaker(self, emb: np.ndarray):
self.model.recognize(emb)
self.model.register_speaker(emb)

def recognize(self, waveform: Union[str, Path, bytes]):
return self.model.recognize(waveform)
return self.model.recognize(waveform, self.threshold)

0 comments on commit 2d6da90

Please sign in to comment.