forked from KdaiP/StableTTS
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpreprocess.py
105 lines (83 loc) · 3.86 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import json
from tqdm import tqdm
from dataclasses import dataclass, asdict
import torch
from torch.multiprocessing import Pool, set_start_method
import torchaudio
from config import MelConfig, TrainConfig
from utils.audio import LogMelSpectrogram, load_and_resample_audio
from text.mandarin import chinese_to_cnm3
from text.english import english_to_ipa2
from text.japanese import japanese_to_ipa2
# Use the GPU when one is available, otherwise fall back to the CPU; audio
# loading and mel extraction below both run on this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@dataclass
class DataConfig:
    """Settings for one preprocessing run.

    Every field carries a type annotation so that ``@dataclass`` registers it
    as a real field. Without annotations the values would only be plain class
    attributes, the generated ``__init__`` would accept no arguments, and
    ``dataclasses.asdict`` would return an empty dict. With them, settings can
    be overridden at construction time, e.g. ``DataConfig(language='chinese')``.
    Defaults are unchanged.
    """
    input_filelist_path: str = './filelists/filelist.txt'  # text filelist, one 'audio_path|text' entry per line
    output_filelist_path: str = './filelists/filelist.json'  # where the generated filelist is written
    output_feature_path: str = './stableTTS_datasets'  # root dir for mel features (and resampled audio if enabled)
    language: str = 'english'  # chinese, japanese or english
    resample: bool = False  # waveforms are not used in training, so saving resampled copies is unnecessary
# Grapheme-to-phoneme converter for each supported DataConfig.language value.
g2p_mapping = {
    'chinese': chinese_to_cnm3,
    'japanese': japanese_to_ipa2,
    'english': english_to_ipa2,
}

data_config = DataConfig()
train_config = TrainConfig()
mel_config = MelConfig()

# Resolve the G2P function up front and fail fast on an unknown language.
# A None here would make every g2p(text) call in process_filelist raise, and its
# catch-all handler would turn that into a run that silently saves nothing.
g2p = g2p_mapping.get(data_config.language)
if g2p is None:
    raise ValueError(
        f"Unsupported language {data_config.language!r}; "
        f"expected one of {sorted(g2p_mapping)}"
    )

input_filelist_path = data_config.input_filelist_path
output_filelist_path = data_config.output_filelist_path
output_feature_path = data_config.output_feature_path

# Ensure output directories exist
output_mel_dir = os.path.join(output_feature_path, 'mels')
os.makedirs(output_mel_dir, exist_ok=True)
# dirname() is '' for a bare filename such as 'filelist.json', and
# os.makedirs('') raises FileNotFoundError, so only create a real parent dir.
output_filelist_dir = os.path.dirname(output_filelist_path)
if output_filelist_dir:
    os.makedirs(output_filelist_dir, exist_ok=True)

if data_config.resample:
    output_wav_dir = os.path.join(output_feature_path, 'waves')
    os.makedirs(output_wav_dir, exist_ok=True)

# Mel extractor built from every MelConfig setting and moved to the run device.
mel_extractor = LogMelSpectrogram(**asdict(mel_config)).to(device)
def load_filelist(path) -> list:
    """Parse an ``audio_path|text`` filelist into ``(idx, audio_path, text)`` tuples.

    Each line is split on the first ``'|'`` only, so the text may itself contain
    ``'|'``. Whitespace around both fields is trimmed, so ``'a.wav | hello'`` is
    accepted as well as ``'a.wav|hello'``. Blank lines, such as a trailing empty
    line at the end of the file, are skipped. A leading UTF-8 byte-order mark
    (as written by some Windows editors) is discarded instead of ending up in
    the first audio path.

    Args:
        path: Path to the UTF-8 encoded filelist.

    Returns:
        A list of ``(idx, audio_path, text)`` string tuples, where ``idx`` is the
        0-based line index in the file. Skipped blank lines still consume an
        index, so ``idx`` always identifies the source line.

    Raises:
        ValueError: If a non-blank line has no ``'|'`` separator.
    """
    file_list = []
    # 'utf-8-sig' decodes exactly like 'utf-8' but drops a leading BOM.
    with open(path, 'r', encoding='utf-8-sig') as f:
        for idx, line in enumerate(f):
            line = line.strip()
            if not line:
                continue
            audio_path, sep, text = line.partition('|')
            if not sep:
                raise ValueError(
                    f"{path}:{idx + 1}: expected 'audio_path|text', got {line!r}"
                )
            file_list.append((str(idx), audio_path.strip(), text.strip()))
    return file_list
@torch.inference_mode()
def process_filelist(line) -> str:
    """Compute and save the mel spectrogram for one filelist entry.

    Args:
        line: An ``(idx, audio_path, text)`` tuple, as built by ``load_filelist``.

    Returns:
        A JSON string with keys ``mel_path``, ``phone``, ``audio_path``, ``text``
        and ``mel_length``. When ``data_config.resample`` is set, ``audio_path``
        points at the saved resampled wav instead of the input file.
        Returns ``None`` when ``load_and_resample_audio`` returns ``None``, when
        the phoneme sequence is empty, or when any step raises (the error is
        printed and the entry is skipped).
    """
    idx, audio_path, text = line

    waveform = load_and_resample_audio(audio_path, mel_config.sample_rate, device=device)  # shape: [1, time]
    if waveform is None:
        return None

    # Output files are named '<idx>_<input file stem>' so they stay unique.
    audio_name = os.path.splitext(os.path.basename(audio_path))[0]

    try:
        phone = g2p(text)
        if len(phone) == 0:
            return None

        mel = mel_extractor(waveform.to(device)).cpu().squeeze(0)  # shape: [n_mels, time // hop_length]
        output_mel_path = os.path.join(output_mel_dir, f'{idx}_{audio_name}.pt')
        torch.save(mel, output_mel_path)

        if data_config.resample:
            # Rebinding audio_path makes the JSON record point at the resampled copy.
            audio_path = os.path.join(output_wav_dir, f'{idx}_{audio_name}.wav')
            torchaudio.save(audio_path, waveform.cpu(), mel_config.sample_rate)

        record = {
            'mel_path': output_mel_path,
            'phone': phone,
            'audio_path': audio_path,
            'text': text,
            'mel_length': mel.size(-1),
        }
        return json.dumps(record, ensure_ascii=False, allow_nan=False)
    except Exception as e:
        # Best effort: report the failing entry and keep processing the rest.
        print(f'Error processing {audio_path}: {str(e)}')

    return None
def main(num_workers: int = 2):
    """Preprocess every filelist entry in parallel and write the output filelist.

    Each successful entry becomes one JSON object on its own line of
    ``output_filelist_path``. Entries for which ``process_filelist`` returns
    ``None`` are left out.

    Args:
        num_workers: Number of worker processes in the pool. Defaults to 2.
    """
    # CUDA must use the spawn start method. force=True keeps a repeated call in
    # the same interpreter (e.g. an interactive session) from raising
    # "context has already been set".
    set_start_method('spawn', force=True)
    input_filelist = load_filelist(input_filelist_path)
    results = []
    with Pool(processes=num_workers) as pool:
        for result in tqdm(pool.imap(process_filelist, input_filelist), total=len(input_filelist)):
            # None means the entry was skipped or failed.
            if result is not None:
                results.append(f'{result}\n')
    # save filelist
    with open(output_filelist_path, 'w', encoding='utf-8') as f:
        f.writelines(results)
    # Surface silently skipped entries so an empty or short filelist is noticed.
    print(f"{len(results)} of {len(input_filelist)} entries processed successfully")
    print(f"filelist file has been saved to {output_filelist_path}")
# Cap PyTorch at one intra-op and one inter-op thread per process: faster and
# much lighter on CPU. This runs at import time, so with the 'spawn' start
# method each pool worker (which re-imports this module) applies it as well.
for _limit_threads in (torch.set_num_threads, torch.set_num_interop_threads):
    _limit_threads(1)
del _limit_threads

if __name__ == '__main__':
    main()