diff --git a/config/fs2.json b/config/fs2.json index dba17991..eae4105a 100644 --- a/config/fs2.json +++ b/config/fs2.json @@ -15,7 +15,7 @@ "extract_energy": true, "energy_extract_mode": "from_tacotron_stft", "extract_duration": true, - "use_phone": true, + "use_phone": false, "pitch_norm": true, "energy_norm": true, "pitch_remove_outlier": true, @@ -47,6 +47,7 @@ "mert_dir": "mert", "spk2id":"spk2id.json", "utt2spk":"utt2spk", + "valid_file": "test.json", // Features used for model training "use_mel": true, diff --git a/models/tts/fastspeech2/fs2_inference.py b/models/tts/fastspeech2/fs2_inference.py index 64d1d387..5a03afd0 100644 --- a/models/tts/fastspeech2/fs2_inference.py +++ b/models/tts/fastspeech2/fs2_inference.py @@ -100,6 +100,7 @@ def inference_for_batches(self): ) os.remove(os.path.join(self.args.output_dir, f"{uid}.pt")) + @torch.inference_mode() def _inference_each_batch(self, batch_data): device = self.accelerator.device control_values = ( diff --git a/processors/acoustic_extractor.py b/processors/acoustic_extractor.py index 2423be59..ea9d24b1 100644 --- a/processors/acoustic_extractor.py +++ b/processors/acoustic_extractor.py @@ -18,6 +18,11 @@ from utils.data_utils import remove_outlier from preprocessors.metadata import replace_augment_name from scipy.interpolate import interp1d +from utils.mel import ( + extract_mel_features, + extract_linear_features, + extract_mel_features_tts, +) ZERO = 1e-12 @@ -124,16 +129,12 @@ def __extract_utt_acoustic_features(dataset_output, cfg, utt): wav_torch = torch.from_numpy(wav).to(wav_torch.device) if cfg.preprocess.extract_linear_spec: - from utils.mel import extract_linear_features - linear = extract_linear_features(wav_torch.unsqueeze(0), cfg.preprocess) save_feature( dataset_output, cfg.preprocess.linear_dir, uid, linear.cpu().numpy() ) if cfg.preprocess.extract_mel: - from utils.mel import extract_mel_features - if cfg.preprocess.mel_extract_mode == "taco": _stft = TacotronSTFT( sampling_rate=cfg.preprocess.sample_rate, diff --git a/utils/mel.py b/utils/mel.py index d32d3822..e222884b 100644 --- a/utils/mel.py +++ b/utils/mel.py @@ -232,14 +232,12 @@ def extract_mel_features_tts( spec = torch.matmul(mel_basis[str(cfg.fmax) + "_" + str(y.device)], spec) spec = spectral_normalize_torch(spec) spec = spec.squeeze(0) + spec = torch.matmul(mel_basis[str(cfg.fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) else: audio = torch.clip(y, -1, 1) audio = torch.autograd.Variable(audio, requires_grad=False) spec, energy = _stft.mel_spectrogram(audio) - spec = torch.squeeze(spec, 0) - - spec = torch.matmul(mel_basis[str(cfg.fmax) + "_" + str(y.device)], spec) - spec = spectral_normalize_torch(spec) return spec.squeeze(0)