Skip to content

Commit 49bac72

Browse files
authored
Implement VitsAudioConfig (coqui-ai#1556)
* Implement VitsAudioConfig
* Update VITS LJSpeech recipe
* Update VITS VCTK recipe
* Make style
* Add missing decorator
* Add missing param
* Make style
* Update recipes
* Fix test
* Bug fix
* Exclude tests folder
* Make linter
* Make style
1 parent 34b80e0 commit 49bac72

File tree

17 files changed

+65
-76
lines changed

17 files changed

+65
-76
lines changed

MANIFEST.in

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ recursive-include TTS *.md
1111
recursive-include TTS *.py
1212
recursive-include TTS *.pyx
1313
recursive-include images *.png
14-
14+
recursive-exclude tests *
15+
prune tests*

TTS/tts/configs/vits_config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import List
33

44
from TTS.tts.configs.shared_configs import BaseTTSConfig
5-
from TTS.tts.models.vits import VitsArgs
5+
from TTS.tts.models.vits import VitsArgs, VitsAudioConfig
66

77

88
@dataclass
@@ -16,6 +16,9 @@ class VitsConfig(BaseTTSConfig):
1616
model_args (VitsArgs):
1717
Model architecture arguments. Defaults to `VitsArgs()`.
1818
19+
audio (VitsAudioConfig):
20+
Audio processing configuration. Defaults to `VitsAudioConfig()`.
21+
1922
grad_clip (List):
2023
Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.
2124
@@ -94,6 +97,7 @@ class VitsConfig(BaseTTSConfig):
9497
model: str = "vits"
9598
# model specific params
9699
model_args: VitsArgs = field(default_factory=VitsArgs)
100+
audio: VitsAudioConfig = VitsAudioConfig()
97101

98102
# optimizer
99103
grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])

TTS/tts/layers/losses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def forward(self, y_hat, y, length):
137137

138138
if ssim_loss.item() < 0.0:
139139
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0")
140-
ssim_loss = torch.tensor([0.0])
140+
ssim_loss = torch.tensor([0.0])
141141

142142
return ssim_loss
143143

TTS/tts/models/vits.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,22 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
200200
return spec
201201

202202

203+
#############################
204+
# CONFIGS
205+
#############################
206+
207+
208+
@dataclass
209+
class VitsAudioConfig(Coqpit):
210+
fft_size: int = 1024
211+
sample_rate: int = 22050
212+
win_length: int = 1024
213+
hop_length: int = 256
214+
num_mels: int = 80
215+
mel_fmin: int = 0
216+
mel_fmax: int = None
217+
218+
203219
##############################
204220
# DATASET
205221
##############################

TTS/tts/utils/ssim.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ def _reduce(x: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
1616
"""
1717
if reduction == "none":
1818
return x
19-
elif reduction == "mean":
19+
if reduction == "mean":
2020
return x.mean(dim=0)
21-
elif reduction == "sum":
21+
if reduction == "sum":
2222
return x.sum(dim=0)
2323
raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")
2424

TTS/utils/synthesizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def tts(
307307
waveform = waveform.squeeze()
308308

309309
# trim silence
310-
if self.tts_config.audio["do_trim_silence"] is True:
310+
if "do_trim_silence" in self.tts_config.audio and self.tts_config.audio["do_trim_silence"]:
311311
waveform = trim_silence(waveform, self.tts_model.ap)
312312

313313
wavs += list(waveform)

recipes/ljspeech/fast_pitch/train_fast_pitch.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@
5454
print_step=50,
5555
print_eval=False,
5656
mixed_precision=False,
57-
sort_by_audio_len=True,
5857
max_seq_len=500000,
5958
output_path=output_path,
6059
datasets=[dataset_config],

recipes/ljspeech/fast_speech/train_fast_speech.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
print_step=50,
5454
print_eval=False,
5555
mixed_precision=False,
56-
sort_by_audio_len=True,
5756
max_seq_len=500000,
5857
output_path=output_path,
5958
datasets=[dataset_config],

recipes/ljspeech/speedy_speech/train_speedy_speech.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
print_step=50,
4747
print_eval=False,
4848
mixed_precision=False,
49-
sort_by_audio_len=True,
5049
max_seq_len=500000,
5150
output_path=output_path,
5251
datasets=[dataset_config],

recipes/ljspeech/tacotron2-Capacitron/train_capacitron_t2.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@
6868
print_step=25,
6969
print_eval=True,
7070
mixed_precision=False,
71-
sort_by_audio_len=True,
7271
seq_len_norm=True,
7372
output_path=output_path,
7473
datasets=[dataset_config],

0 commit comments

Comments
 (0)