Skip to content

Commit

Permalink
convert feat from double to float
Browse files Browse the repository at this point in the history
lovemefan committed Sep 6, 2023
1 parent bf2a066 commit 9257436
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions paraformerOnline/runtime/python/utils/preprocess.py
Original file line number Diff line number Diff line change
@@ -91,7 +91,7 @@ def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
feat = self.apply_cmvn(feat)

feat_len = np.array(feat.shape[0]).astype(np.int32)
return feat, feat_len
return feat.astype(np.float32), feat_len

@staticmethod
def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
@@ -221,7 +221,7 @@ def compute_frame_num(
)

def fbank(
self, input: np.ndarray, input_lengths: np.ndarray
self, input: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
self.fbank_fn = knf.OnlineFbank(self.opts)
batch_size = input.shape[0]
@@ -277,14 +277,14 @@ def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]:
return self.fbanks, self.fbanks_lens

def lfr_cmvn(
self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
self, input: np.ndarray, is_final: bool = False
) -> Tuple[np.ndarray, np.ndarray, List[int]]:
batch_size = input.shape[0]
feats = []
feats_lens = []
lfr_splice_frame_idxs = []
for i in range(batch_size):
mat = input[i, : input_lengths[i], :]
mat = input[i, : len(input[i]), :]
lfr_splice_frame_idx = -1
if self.lfr_m != 1 or self.lfr_n != 1:
# update self.lfr_splice_cache in self.apply_lfr
@@ -309,9 +309,7 @@ def extract_fbank(
assert (
batch_size == 1
), "we support to extract feature online only when the batch size is equal to 1 now"
waveforms, feats, feats_lengths = self.fbank(
input, input_lengths
) # input shape: B T D
waveforms, feats, feats_lengths = self.fbank(input) # input shape: B T D
if feats.shape[0]:
self.waveforms = (
waveforms
@@ -339,7 +337,7 @@ def extract_fbank(
(self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
)
feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn(
feats, feats_lengths, is_final
feats, is_final
)
if self.lfr_m == 1:
self.reserve_waveforms = None

0 comments on commit 9257436

Please sign in to comment.