-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_onnx.py
107 lines (87 loc) · 4.43 KB
/
main_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import onnxruntime as ort
import numpy as np
import argparse
import os
import soundfile as sf
import librosa
import time
def get_args():
    """Parse command-line arguments for the ONNX voice-conversion script.

    Returns:
        argparse.Namespace with:
            input_audio  -- required input .wav path
            output_audio -- output .wav path (default ./output.wav)
            encoder      -- encoder ONNX model path
            decoder      -- decoder ONNX model path
            g_src        -- source speaker feature binary path
            g_dst        -- target speaker feature binary path
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_audio", "-i", type=str, required=True, help="Input audio file(.wav)")
    # Fixed typo in help text: "Seperated" -> "Separated".
    parser.add_argument("--output_audio", "-o", type=str, required=False, default="./output.wav", help="Separated wav path")
    parser.add_argument("--encoder", "-e", type=str, required=False, default="./models/encoder.onnx", help="encoder onnx model")
    parser.add_argument("--decoder", "-d", type=str, required=False, default="./models/decoder.onnx", help="decoder onnx model")
    parser.add_argument("--g_src", type=str, required=False, default="./models/g_src.bin", help="source speaker feature")
    parser.add_argument("--g_dst", type=str, required=False, default="./models/g_dst.bin", help="target speaker feature")
    return parser.parse_args()
def spectrogram_np(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    """Compute a magnitude spectrogram of waveform `y` with a leading batch axis.

    The waveform is reflect-padded by (n_fft - hop_size) / 2 samples on each
    side, transformed with a Hann-window STFT, and the magnitude is taken as
    sqrt(power + 1e-6) — the epsilon guards against sqrt of exact zero.

    NOTE(review): `sampling_rate` is accepted but not used here — presumably
    kept for signature parity with the original framework; confirm with callers.

    Returns:
        np.ndarray of shape (1, n_fft // 2 + 1, num_frames).
    """
    half_pad = int((n_fft - hop_size) / 2)
    padded = np.pad(y, half_pad, mode="reflect")
    stft = librosa.stft(
        padded,
        n_fft=n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window="hann",
        center=center,
        pad_mode="reflect",
    )
    power = stft.real ** 2 + stft.imag ** 2
    magnitude = np.sqrt(power + 1e-6)
    return magnitude[None, ...]
def main():
    """Run ONNX voice conversion end to end.

    Loads the input audio, computes its spectrogram, runs it through the
    encoder in fixed-size slices, converts the latent with source/target
    speaker embeddings through the decoder, and writes the converted audio.
    """
    args = get_args()
    # Validate paths up front. (The original messages referenced a
    # nonexistent `args.model`, which would have raised AttributeError
    # instead of printing the intended message on failure — fixed to use
    # the actual attributes. NOTE: `assert` is stripped under -O.)
    assert os.path.exists(args.input_audio), f"Input audio {args.input_audio} not exist"
    assert os.path.exists(args.encoder), f"Encoder {args.encoder} not exist"
    assert os.path.exists(args.decoder), f"Decoder {args.decoder} not exist"
    assert os.path.exists(args.g_src), f"{args.g_src} not exist"
    assert os.path.exists(args.g_dst), f"{args.g_dst} not exist"

    input_audio = args.input_audio
    output_audio = args.output_audio
    encoder_path = args.encoder
    decoder_path = args.decoder

    # Speaker embeddings: raw float32 blobs reshaped to (1, 256, 1).
    g_src = np.fromfile(args.g_src, dtype=np.float32).reshape((1, 256, 1))
    g_dst = np.fromfile(args.g_dst, dtype=np.float32).reshape((1, 256, 1))

    # STFT / model constants — must match the exported ONNX graphs.
    sampling_rate = 22050
    filter_length = 1024
    hop_length = 256
    win_length = 1024
    enc_len = 1024   # spectrogram frames per encoder slice
    dec_len = 128    # latent frames per decoder slice

    print(f"Input audio: {input_audio}")
    print(f"Output audio: {output_audio}")
    print(f"Encoder: {encoder_path}")
    print(f"Decoder: {decoder_path}")

    print("Loading audio...")
    # Resample to the model's rate; the original sample rate is not needed.
    audio, _ = librosa.load(input_audio, sr=sampling_rate)

    print("Loading model...")
    start = time.time()
    sess_enc = ort.InferenceSession(encoder_path, providers=["CPUExecutionProvider"])
    sess_dec = ort.InferenceSession(decoder_path, providers=["CPUExecutionProvider"])
    print(f"Load model take {(time.time() - start) * 1000}ms")

    print("Preprocessing audio...")
    start = time.time()
    spec = spectrogram_np(audio, filter_length,
                          sampling_rate, hop_length, win_length,
                          center=False)
    real_enc_len = spec.shape[-1]
    # Zero-pad the frame axis to a multiple of enc_len so the encoder
    # always receives fixed-size slices.
    pad_frames = int(np.ceil(spec.shape[-1] / enc_len)) * enc_len - spec.shape[-1]
    spec = np.concatenate((spec, np.zeros((*spec.shape[:-1], pad_frames), dtype=np.float32)), axis=-1)
    print(f"Preprocess take {(time.time() - start) * 1000}ms")

    print("Running model...")
    slice_num = spec.shape[-1] // enc_len
    z = []
    for i in range(slice_num):
        spec_slice = spec[..., i * enc_len : (i + 1) * enc_len]
        start = time.time()
        z.append(sess_enc.run(None, {"y": spec_slice})[0])
        print(f"Run encoder slice {i + 1}/{slice_num} take {(time.time() - start) * 1000}ms")
    # Drop the encoder padding, then re-pad to a multiple of dec_len.
    z = np.concatenate(z, axis=-1)[..., :real_enc_len]
    pad_frames = int(np.ceil(z.shape[-1] / dec_len)) * dec_len - z.shape[-1]
    z = np.concatenate((z, np.zeros((*z.shape[:-1], pad_frames), dtype=np.float32)), axis=-1)

    slice_num = z.shape[-1] // dec_len
    audio_list = []
    for i in range(slice_num):
        z_slice = z[..., i * dec_len : (i + 1) * dec_len]
        start = time.time()
        # Use a distinct name so the loaded waveform `audio` is not clobbered.
        converted_slice = sess_dec.run(None, {"z": z_slice, "g_src": g_src, "g_dst": g_dst})[0]
        audio_list.append(converted_slice.flatten())
        print(f"Run decoder slice {i + 1}/{slice_num} take {(time.time() - start) * 1000}ms")

    # Each spectrogram frame corresponds to hop_length output samples; trim
    # the decoder padding. (The original hard-coded 256 here — replaced with
    # hop_length, which has the same value, so the constants stay in sync.)
    converted = np.concatenate(audio_list, axis=-1)[: hop_length * real_enc_len]
    sf.write(output_audio, converted, sampling_rate)
    print(f"Save audio to {output_audio}")
# Script entry point: run the conversion only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()