Skip to content

Commit

Permalink
add shallow fusion with am and lm
Browse files Browse the repository at this point in the history
  • Loading branch information
lovemefan committed Oct 17, 2023
1 parent 5b35e24 commit c828e2b
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 17 deletions.
94 changes: 88 additions & 6 deletions paraformer/runtime/python/model/asr/paraformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import numpy as np

from paraformer.runtime.python.model.lm.transformer_lm import TransformerLM
from paraformer.runtime.python.utils.asrOrtInferRuntimeSession import (
AsrOfflineOrtInferRuntimeSession,
AsrOnlineDecoderOrtInferRuntimeSession,
Expand Down Expand Up @@ -262,7 +263,7 @@ def extract_feat(
)
return feats.astype(np.float32), feats_len.astype(np.int32)

def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
def decode(self, am_scores: np.ndarray, token_nums: int):
return [
self.decode_one(am_score, token_num)
for am_score, token_num in zip(am_scores, token_nums)
Expand Down Expand Up @@ -363,11 +364,14 @@ def cif_search(self, hidden, alphas, cache=None):

@singleton
class ParaformerOfflineModel:
def __init__(self, model_dir: str = None, intra_op_num_threads=4) -> None:
def __init__(
self, model_dir: str = None, use_lm=False, intra_op_num_threads=4
) -> None:
config_path = os.path.join(model_dir, "config.pkl")
with open(config_path, "rb") as file:
config = pickle.load(file)

self.use_lm = use_lm
self.converter = TokenIDConverter(config["token_list"])
self.tokenizer = CharTokenizer(**config["CharTokenizer"])
self.frontend = WavFrontend(
Expand All @@ -379,6 +383,11 @@ def __init__(self, model_dir: str = None, intra_op_num_threads=4) -> None:
model_file = glob.glob(os.path.join(model_dir, "model_quant_*.onnx"))

contextual_model = os.path.join(model_dir, "model_eb.onnx")

if use_lm:
lm_model_path = os.path.join(model_dir, "lm")
self.lm = TransformerLM(lm_model_path, intra_op_num_threads)

self.ort_infer = AsrOfflineOrtInferRuntimeSession(
model_file,
contextual_model=contextual_model,
Expand Down Expand Up @@ -412,10 +421,78 @@ def decoder_with_greedy_search(self, am_score):
texts = sentence_postprocess(token)
return texts

def decoder_with_beam_search(self, am_score):
pass
def search(self, beams, am_score: np.ndarray, beam_size=5, lm_weight=0.25):
    """Extend every running hypothesis by one token and keep the best ones.

    Each candidate is scored as
    ``hyp.score + am_score + lm_weight * lm_score``; the LM term is only
    added when ``self.use_lm`` is set (shallow fusion of acoustic and
    language model scores).

    Args:
        beams (List[Hypothesis]): running hypotheses on the beam. Each
            ``yseq`` is indexed here as a 2-D array of shape (1, T).
        am_score (np.ndarray): acoustic scores for the current step; must
            broadcast against a vector of length ``len(token_list)``.
        beam_size: number of expansions kept per hypothesis and maximum
            number of hypotheses returned.
        lm_weight: weight applied to the language model score.

    Returns:
        List[Hypothesis]: at most ``beam_size`` hypotheses, sorted by
        score in descending order.
    """
    n_vocab = len(self.converter.token_list)

    # np.argpartition raises when kth is out of bounds, and ``[-0:]`` would
    # select the whole array, so clamp the expansion count explicitly.
    n_expand = min(beam_size, n_vocab)
    if n_expand <= 0:
        return []

    candidates = []
    for hyp in beams:
        # Start from zeros so the result is always a fresh float vector of
        # length n_vocab, regardless of the dtype of am_score.
        weighted_scores = np.zeros(n_vocab)
        weighted_scores += am_score

        if self.use_lm:
            # Only the most recent 20 tokens are passed to the LM as context.
            lm_score = self.lm.lm(hyp.yseq[:, -20:])
            # NOTE(review): [0][0] presumably selects the next-token score
            # vector of the single batch entry -- confirm against the LM.
            weighted_scores += lm_weight * lm_score[0][0]

        # Accumulate the score of the prefix being extended.
        weighted_scores += hyp.score

        # Indices of the n_expand highest-scoring tokens (unordered).
        top_ids = np.argpartition(weighted_scores, -n_expand)[-n_expand:]
        for j in top_ids:
            candidates.append(
                Hypothesis(
                    score=weighted_scores[j],
                    yseq=np.concatenate(
                        (hyp.yseq[0], np.array([j], dtype=np.int64))
                    )[None, ...],
                )
            )

    # The pool holds up to len(beams) * n_expand candidates; sort them and
    # prune back down to beam_size.
    candidates.sort(key=lambda x: x.score, reverse=True)
    return candidates[:beam_size]

def decoder_with_beam_search(self, am_scores, beam_size=5, lm_weight=0.15):
# set length bounds
# main loop of prefix search
beams = [
Hypothesis(
score=0,
yseq=np.array([[1]], dtype=np.int64),
)
]
for score in am_scores:
beams = self.search(beams, score, beam_size=beam_size, lm_weight=lm_weight)

# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x not in (0, 2), beams[0].yseq.tolist()[0]))

def infer(self, audio: Union[str, np.ndarray, bytes], hot_words: str = None):
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
texts = sentence_postprocess(token)

return texts

def infer(
self,
audio: Union[str, np.ndarray, bytes],
hot_words: str = None,
beam_search=False,
beam_size=5,
lm_weight=0.15,
):
if isinstance(audio, str):
audio, _ = AudioReader.read_wav_file(audio)
elif isinstance(audio, bytes):
Expand Down Expand Up @@ -448,7 +525,12 @@ def infer(self, audio: Union[str, np.ndarray, bytes], hot_words: str = None):

results = []
for am_score in am_scores:
pred_res = self.decoder_with_greedy_search(am_score)
if beam_search:
pred_res = self.decoder_with_beam_search(
am_score, beam_size=beam_size, lm_weight=lm_weight
)
else:
pred_res = self.decoder_with_greedy_search(am_score)
results.append(pred_res)
return results if len(results) != 0 else [[""]]

Expand Down
20 changes: 14 additions & 6 deletions paraformer/runtime/python/model/lm/transformer_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
from paraformer.runtime.python.utils.lmOrtInderRuntimeSession import (
LMOrtInferRuntimeSession,
)
from paraformer.runtime.python.utils.singleton import singleton


@singleton
class TransformerLM:
def __init__(self, model_dir: str = None, intra_op_num_threads=4):
tokens_list_path = os.path.join(model_dir, "tokens.txt")
Expand Down Expand Up @@ -59,12 +61,11 @@ def seg_tokenize_wo_pattern(self, txt, seg_dict):
out_txt += "<unk>" + " "
return out_txt.strip().split()

def nll_and_ppl(self, text: str):
tokens = text.strip().split(" ")
if self.segment_dict is not None:
tokens = self.seg_tokenize_wo_pattern(tokens, self.segment_dict)
text_ints = np.array(self.converter.tokens2ids(tokens), dtype=np.int64)

def get_nll_and_ppl(self, text_ints):
"""
Args:
text_ints
"""
# 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
# text: (Batch, Length) -> x, y: (Batch, Length + 1)
x = np.pad(text_ints, [1, 0], "constant", constant_values=(1,))[None, ...]
Expand All @@ -87,3 +88,10 @@ def nll_and_ppl(self, text: str):
ppl = np.exp(negative_log_likelihood.mean())

return nll, ppl

def get_nll_and_ppl_from_text(self, text: str):
    """Score a whitespace-separated token string with the language model.

    Args:
        text: tokens separated by whitespace.

    Returns:
        The result of ``self.get_nll_and_ppl`` for the converted token ids.
    """
    # split() with no argument collapses runs of whitespace, so doubled
    # spaces no longer produce empty-string tokens that would be scored as
    # unknown words.
    tokens = text.split()
    if self.segment_dict is not None:
        tokens = self.seg_tokenize_wo_pattern(tokens, self.segment_dict)
    text_ints = np.array(self.converter.tokens2ids(tokens), dtype=np.int64)
    return self.get_nll_and_ppl(text_ints)
27 changes: 22 additions & 5 deletions paraformer/runtime/python/paraformerInfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import numpy as np

from paraformer.runtime.python.model.asr.paraformer import (
ParaformerOfflineModel, ParaformerOnlineModel)
ParaformerOfflineModel,
ParaformerOnlineModel,
)
from paraformer.runtime.python.utils.logger import logger


Expand Down Expand Up @@ -43,24 +45,39 @@ def infer_online(self, chunk: np.ndarray, is_final=False):


class ParaformerOffline:
def __init__(self, model_dir=None, *, intra_op_num_threads=4):
def __init__(self, model_dir=None, *, use_lm=False, intra_op_num_threads=4):
project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
model_dir = model_dir or os.path.join(project_dir, "onnx", "asr_offline")
logger.info(f"Load onnx model dir at {model_dir}")
self.model = ParaformerOfflineModel(
model_dir, intra_op_num_threads=intra_op_num_threads
model_dir, intra_op_num_threads=intra_op_num_threads, use_lm=use_lm
)
self.param_dict = {"cache": dict()}

def infer_offline(self, audio: np.ndarray, hot_words: str = ""):
def infer_offline(
self,
audio: np.ndarray,
hot_words: str = "",
beam_search=False,
beam_size=5,
lm_weight=0.15,
):
"""
Args:
audio: 600ms is best
hot_words: hot words split by space . eg `a b cc`
beam_search
beam_size
Return:
transcript of audio
"""
result = self.model.infer(audio, hot_words)
result = self.model.infer(
audio,
hot_words,
beam_search=beam_search,
beam_size=beam_size,
lm_weight=lm_weight,
)

return result[0][0]

0 comments on commit c828e2b

Please sign in to comment.