Commit a16360a: Implement chunking gpt_cond
Parent: 6f1cba2

File tree: 2 files changed (+78 additions, -33 deletions)

TTS/tts/configs/xtts_config.py (8 additions, 2 deletions)

@@ -43,7 +43,12 @@ class XttsConfig(BaseTTSConfig):
             Defaults to `16`.
 
         gpt_cond_len (int):
-            Secs audio to be used as conditioning for the autoregressive model. Defaults to `3`.
+            Secs audio to be used as conditioning for the autoregressive model. Defaults to `12`.
+
+        gpt_cond_chunk_len (int):
+            Audio chunk size in secs. Audio is split into chunks and latents are extracted for each chunk. Then the
+            latents are averaged. Chunking improves the stability. It must be <= gpt_cond_len.
+            If gpt_cond_len == gpt_cond_chunk_len, no chunking. Defaults to `4`.
 
         max_ref_len (int):
             Maximum number of seconds of audio to be used as conditioning for the decoder. Defaults to `10`.

@@ -95,6 +100,7 @@ class XttsConfig(BaseTTSConfig):
     num_gpt_outputs: int = 1
 
     # cloning
-    gpt_cond_len: int = 3
+    gpt_cond_len: int = 12
+    gpt_cond_chunk_len: int = 4
     max_ref_len: int = 10
     sound_norm_refs: bool = False
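
For orientation, a minimal sketch (not part of the commit) of how the two cloning fields relate after this change; the no-argument XttsConfig() construction is assumed from the config file above, and the values shown are simply the new defaults:

    # Illustrative only: the new chunking field and its constraint.
    from TTS.tts.configs.xtts_config import XttsConfig

    config = XttsConfig()
    config.gpt_cond_len = 12       # seconds of reference audio used for GPT conditioning
    config.gpt_cond_chunk_len = 4  # per-chunk seconds; must be <= gpt_cond_len
    assert config.gpt_cond_chunk_len <= config.gpt_cond_len  # equal values mean no chunking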

TTS/tts/models/xtts.py (70 additions, 31 deletions)
@@ -255,39 +255,57 @@ def device(self):
         return next(self.parameters()).device
 
     @torch.inference_mode()
-    def get_gpt_cond_latents(self, audio, sr, length: int = 3):
+    def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = 6):
         """Compute the conditioning latents for the GPT model from the given audio.
 
         Args:
             audio (tensor): audio tensor.
             sr (int): Sample rate of the audio.
-            length (int): Length of the audio in seconds. Defaults to 3.
+            length (int): Length of the audio in seconds. If < 0, use the whole audio. Defaults to 30.
+            chunk_length (int): Length of the audio chunks in seconds. When `length == chunk_length`, the whole audio
+                is being used without chunking. It must be < `length`. Defaults to 6.
         """
         if sr != 22050:
             audio = torchaudio.functional.resample(audio, sr, 22050)
-        audio = audio[:, : 22050 * length]
+        if length > 0:
+            audio = audio[:, : 22050 * length]
         if self.args.gpt_use_perceiver_resampler:
-            n_fft = 2048
-            hop_length = 256
-            win_length = 1024
+            style_embs = []
+            for i in range(0, audio.shape[1], 22050 * chunk_length):
+                audio_chunk = audio[:, i : i + 22050 * chunk_length]
+                mel_chunk = wav_to_mel_cloning(
+                    audio_chunk,
+                    mel_norms=self.mel_stats.cpu(),
+                    n_fft=2048,
+                    hop_length=256,
+                    win_length=1024,
+                    power=2,
+                    normalized=False,
+                    sample_rate=22050,
+                    f_min=0,
+                    f_max=8000,
+                    n_mels=80,
+                )
+                style_emb = self.gpt.get_style_emb(mel_chunk.to(self.device), None)
+                style_embs.append(style_emb)
+
+            # mean style embedding
+            cond_latent = torch.stack(style_embs).mean(dim=0)
         else:
-            n_fft = 4096
-            hop_length = 1024
-            win_length = 4096
-        mel = wav_to_mel_cloning(
-            audio,
-            mel_norms=self.mel_stats.cpu(),
-            n_fft=n_fft,
-            hop_length=hop_length,
-            win_length=win_length,
-            power=2,
-            normalized=False,
-            sample_rate=22050,
-            f_min=0,
-            f_max=8000,
-            n_mels=80,
-        )
-        cond_latent = self.gpt.get_style_emb(mel.to(self.device))
+            mel = wav_to_mel_cloning(
+                audio,
+                mel_norms=self.mel_stats.cpu(),
+                n_fft=4096,
+                hop_length=1024,
+                win_length=4096,
+                power=2,
+                normalized=False,
+                sample_rate=22050,
+                f_min=0,
+                f_max=8000,
+                n_mels=80,
+            )
+            cond_latent = self.gpt.get_style_emb(mel.to(self.device))
         return cond_latent.transpose(1, 2)
 
     @torch.inference_mode()
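
The change above replaces a single style embedding over the whole reference with per-chunk embeddings that are averaged. A simplified, standalone sketch of that pattern follows (illustrative only; chunked_embedding and embed_fn are hypothetical names, with embed_fn standing in for the mel extraction plus gpt.get_style_emb call used in the method):

    import torch

    def chunked_embedding(audio: torch.Tensor, sr: int, chunk_secs: int, embed_fn) -> torch.Tensor:
        """audio: [1, T] waveform at sample rate `sr`; returns the mean per-chunk latent."""
        chunk_size = sr * chunk_secs
        embs = []
        for i in range(0, audio.shape[1], chunk_size):
            chunk = audio[:, i : i + chunk_size]  # last chunk may be shorter
            embs.append(embed_fn(chunk))          # one latent per chunk
        return torch.stack(embs).mean(dim=0)      # average the latents over chunks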
@@ -323,12 +341,24 @@ def get_speaker_embedding(self, audio, sr):
     def get_conditioning_latents(
         self,
         audio_path,
+        max_ref_length=30,
         gpt_cond_len=6,
-        max_ref_length=10,
+        gpt_cond_chunk_len=6,
         librosa_trim_db=None,
         sound_norm_refs=False,
-        load_sr=24000,
+        load_sr=22050,
     ):
+        """Get the conditioning latents for the GPT model from the given audio.
+
+        Args:
+            audio_path (str or List[str]): Path to reference audio file(s).
+            max_ref_length (int): Maximum length of each reference audio in seconds. Defaults to 30.
+            gpt_cond_len (int): Length of the audio used for gpt latents. Defaults to 6.
+            gpt_cond_chunk_len (int): Chunk length used for gpt latents. It must be <= gpt_cond_len. Defaults to 6.
+            librosa_trim_db (int, optional): Trim the audio using this value. If None, not trimming. Defaults to None.
+            sound_norm_refs (bool, optional): Whether to normalize the audio. Defaults to False.
+            load_sr (int, optional): Sample rate to load the audio. Defaults to 22050.
+        """
         # deal with multiples references
         if not isinstance(audio_path, list):
             audio_paths = [audio_path]
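
A hedged usage sketch of the new signature (assumes `model` is an already-loaded Xtts instance and "speaker.wav" is a placeholder reference clip; checkpoint loading is omitted):

    # Returns the GPT conditioning latent and the decoder speaker embedding,
    # in the order they are unpacked later in full_inference.
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=["speaker.wav"],  # one or more reference clips
        max_ref_length=30,           # cap each reference at 30 s
        gpt_cond_len=6,              # seconds of merged audio used for GPT latents
        gpt_cond_chunk_len=6,        # equal to gpt_cond_len, so no chunking here
    )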
@@ -349,14 +379,17 @@ def get_conditioning_latents(
349379
if librosa_trim_db is not None:
350380
audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
351381

382+
# compute latents for the decoder
352383
speaker_embedding = self.get_speaker_embedding(audio, load_sr)
353384
speaker_embeddings.append(speaker_embedding)
354385

355386
audios.append(audio)
356387

357-
# use a merge of all references for gpt cond latents
388+
# merge all the audios and compute the latents for the gpt
358389
full_audio = torch.cat(audios, dim=-1)
359-
gpt_cond_latents = self.get_gpt_cond_latents(full_audio, load_sr, length=gpt_cond_len) # [1, 1024, T]
390+
gpt_cond_latents = self.get_gpt_cond_latents(
391+
full_audio, load_sr, length=gpt_cond_len, chunk_length=gpt_cond_chunk_len
392+
) # [1, 1024, T]
360393

361394
if speaker_embeddings:
362395
speaker_embedding = torch.stack(speaker_embeddings)
@@ -397,6 +430,7 @@ def inference_with_config(self, text, config, ref_audio_path, language, **kwargs
             "top_k": config.top_k,
             "top_p": config.top_p,
             "gpt_cond_len": config.gpt_cond_len,
+            "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
             "max_ref_len": config.max_ref_len,
             "sound_norm_refs": config.sound_norm_refs,
         }
@@ -417,7 +451,8 @@ def full_inference(
         top_p=0.85,
         do_sample=True,
         # Cloning
-        gpt_cond_len=6,
+        gpt_cond_len=30,
+        gpt_cond_chunk_len=6,
         max_ref_len=10,
         sound_norm_refs=False,
         **hf_generate_kwargs,
@@ -448,7 +483,10 @@ def full_inference(
                 (aka boring) outputs. Defaults to 0.8.
 
             gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used
-                else the first `gpt_cond_len` secs is used. Defaults to 6 seconds.
+                else the first `gpt_cond_len` secs is used. Defaults to 30 seconds.
+
+            gpt_cond_chunk_len: (int) Chunk length used for cloning. It must be <= `gpt_cond_len`.
+                If gpt_cond_len == gpt_cond_chunk_len, no chunking. Defaults to 6 seconds.
 
             hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
                 transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
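
For context, a hedged sketch of how the two arguments are passed at inference time; `model` is again an assumed pre-loaded Xtts instance, and the text and reference path are placeholders:

    out = model.full_inference(
        text="Chunked conditioning keeps long references stable.",
        ref_audio_path="speaker.wav",
        language="en",
        gpt_cond_len=30,       # use up to 30 s of the reference audio
        gpt_cond_chunk_len=6,  # embed it in 6 s chunks and average the latents
    )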
@@ -461,6 +499,7 @@ def full_inference(
         (gpt_cond_latent, speaker_embedding) = self.get_conditioning_latents(
             audio_path=ref_audio_path,
             gpt_cond_len=gpt_cond_len,
+            gpt_cond_chunk_len=gpt_cond_chunk_len,
             max_ref_length=max_ref_len,
             sound_norm_refs=sound_norm_refs,
         )
@@ -566,7 +605,7 @@ def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
         if overlap_len > len(wav_chunk):
             # wav_chunk is smaller than overlap_len, pass on last wav_gen
             if wav_gen_prev is not None:
-                wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len):]
+                wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) :]
             else:
                 # not expecting will hit here as problem happens on last chunk
                 wav_chunk = wav_gen[-overlap_len:]
@@ -576,7 +615,7 @@ def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
             crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
             wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
             wav_chunk[:overlap_len] += crossfade_wav
-            
+
         wav_overlap = wav_gen[-overlap_len:]
         wav_gen_prev = wav_gen
         return wav_chunk, wav_gen_prev, wav_overlap
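
The two hunks above only adjust slice spacing and trailing whitespace in handle_chunks; the surrounding logic is a linear crossfade between overlapping streamed chunks, sketched below in isolation (linear_crossfade is a hypothetical helper, not part of the commit):

    import torch

    def linear_crossfade(prev_tail: torch.Tensor, new_head: torch.Tensor) -> torch.Tensor:
        """Blend the tail of the previous chunk into the head of the next one."""
        overlap_len = prev_tail.shape[0]
        fade_out = torch.linspace(1.0, 0.0, overlap_len).to(prev_tail.device)  # previous chunk fades out
        fade_in = torch.linspace(0.0, 1.0, overlap_len).to(new_head.device)    # new chunk fades in
        return prev_tail * fade_out + new_head * fade_in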
