@@ -255,39 +255,57 @@ def device(self):
         return next(self.parameters()).device

     @torch.inference_mode()
-    def get_gpt_cond_latents(self, audio, sr, length: int = 3):
+    def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = 6):
         """Compute the conditioning latents for the GPT model from the given audio.

         Args:
             audio (tensor): audio tensor.
             sr (int): Sample rate of the audio.
-            length (int): Length of the audio in seconds. Defaults to 3.
+            length (int): Length of the audio in seconds. If < 0, use the whole audio. Defaults to 30.
+            chunk_length (int): Length of the audio chunks in seconds. When `length == chunk_length`, the whole
+                audio is used without chunking. It must be <= `length`. Defaults to 6.
265267 """
266268 if sr != 22050 :
267269 audio = torchaudio .functional .resample (audio , sr , 22050 )
268- audio = audio [:, : 22050 * length ]
270+ if length > 0 :
271+ audio = audio [:, : 22050 * length ]
269272 if self .args .gpt_use_perceiver_resampler :
270- n_fft = 2048
271- hop_length = 256
272- win_length = 1024
273+ style_embs = []
274+ for i in range (0 , audio .shape [1 ], 22050 * chunk_length ):
275+ audio_chunk = audio [:, i : i + 22050 * chunk_length ]
276+ mel_chunk = wav_to_mel_cloning (
277+ audio_chunk ,
278+ mel_norms = self .mel_stats .cpu (),
279+ n_fft = 2048 ,
280+ hop_length = 256 ,
281+ win_length = 1024 ,
282+ power = 2 ,
283+ normalized = False ,
284+ sample_rate = 22050 ,
285+ f_min = 0 ,
286+ f_max = 8000 ,
287+ n_mels = 80 ,
288+ )
289+ style_emb = self .gpt .get_style_emb (mel_chunk .to (self .device ), None )
290+ style_embs .append (style_emb )
291+
292+ # mean style embedding
293+ cond_latent = torch .stack (style_embs ).mean (dim = 0 )
273294 else :
274- n_fft = 4096
275- hop_length = 1024
276- win_length = 4096
277- mel = wav_to_mel_cloning (
278- audio ,
279- mel_norms = self .mel_stats .cpu (),
280- n_fft = n_fft ,
281- hop_length = hop_length ,
282- win_length = win_length ,
283- power = 2 ,
284- normalized = False ,
285- sample_rate = 22050 ,
286- f_min = 0 ,
287- f_max = 8000 ,
288- n_mels = 80 ,
289- )
290- cond_latent = self .gpt .get_style_emb (mel .to (self .device ))
295+ mel = wav_to_mel_cloning (
296+ audio ,
297+ mel_norms = self .mel_stats .cpu (),
298+ n_fft = 4096 ,
299+ hop_length = 1024 ,
300+ win_length = 4096 ,
301+ power = 2 ,
302+ normalized = False ,
303+ sample_rate = 22050 ,
304+ f_min = 0 ,
305+ f_max = 8000 ,
306+ n_mels = 80 ,
307+ )
308+ cond_latent = self .gpt .get_style_emb (mel .to (self .device ))
291309 return cond_latent .transpose (1 , 2 )
292310
293311 @torch .inference_mode ()
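
With the perceiver-resampler path, the commit replaces the single mel pass with a chunk-and-average scheme: the reference audio is cut into `chunk_length`-second pieces, each piece gets its own style embedding, and the embeddings are averaged. A minimal standalone sketch of that idea, with a toy `embed_chunk` standing in for the real `wav_to_mel_cloning` + `gpt.get_style_emb` pipeline (names and shapes here are illustrative, not the library API):

```python
import torch

def chunked_mean_embedding(audio: torch.Tensor, sr: int = 22050, chunk_secs: int = 6) -> torch.Tensor:
    """Average per-chunk embeddings over a long reference clip."""

    def embed_chunk(chunk: torch.Tensor) -> torch.Tensor:
        # toy stand-in for the mel + style-embedding path; returns shape [1, 1]
        return chunk.mean(dim=-1, keepdim=True)

    step = sr * chunk_secs
    embs = [embed_chunk(audio[:, i : i + step]) for i in range(0, audio.shape[1], step)]
    # same combination as the diff: stack the per-chunk embeddings and take the mean
    return torch.stack(embs).mean(dim=0)

# a 30 s reference at 22.05 kHz yields 5 chunk embeddings that are averaged into one
cond = chunked_mean_embedding(torch.randn(1, 22050 * 30))
```

Averaging over chunks lets the model consume much longer references (30 s by default now) without pushing one very long mel spectrogram through the style encoder in a single pass.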
@@ -323,12 +341,24 @@ def get_speaker_embedding(self, audio, sr):
     def get_conditioning_latents(
         self,
         audio_path,
+        max_ref_length=30,
         gpt_cond_len=6,
-        max_ref_length=10,
+        gpt_cond_chunk_len=6,
         librosa_trim_db=None,
         sound_norm_refs=False,
-        load_sr=24000,
+        load_sr=22050,
     ):
351+ """Get the conditioning latents for the GPT model from the given audio.
352+
353+ Args:
354+ audio_path (str or List[str]): Path to reference audio file(s).
355+ max_ref_length (int): Maximum length of each reference audio in seconds. Defaults to 30.
356+ gpt_cond_len (int): Length of the audio used for gpt latents. Defaults to 6.
357+ gpt_cond_chunk_len (int): Chunk length used for gpt latents. It must be <= gpt_conf_len. Defaults to 6.
358+ librosa_trim_db (int, optional): Trim the audio using this value. If None, not trimming. Defaults to None.
359+ sound_norm_refs (bool, optional): Whether to normalize the audio. Defaults to False.
360+ load_sr (int, optional): Sample rate to load the audio. Defaults to 24000.
361+ """
         # deal with multiples references
         if not isinstance(audio_path, list):
             audio_paths = [audio_path]
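
The interaction between `max_ref_length`, `gpt_cond_len`, and `gpt_cond_chunk_len` can be made concrete with a small helper. This is not part of the library; it only assumes the slicing behaviour of `get_gpt_cond_latents` shown above:

```python
import math

def num_style_chunks(ref_secs: float, gpt_cond_len: int = 30, gpt_cond_chunk_len: int = 6) -> int:
    """How many chunk embeddings get averaged for a merged reference of `ref_secs` seconds.

    Mirrors get_gpt_cond_latents: the merged audio is capped at `gpt_cond_len`
    seconds (unless negative), then embedded in `gpt_cond_chunk_len`-second chunks.
    """
    used = ref_secs if gpt_cond_len < 0 else min(ref_secs, gpt_cond_len)
    return math.ceil(used / gpt_cond_chunk_len)

print(num_style_chunks(45))        # 5 -> capped at 30 s, five 6 s chunks
print(num_style_chunks(10))        # 2 -> one full 6 s chunk plus a 4 s remainder
print(num_style_chunks(6, 6, 6))   # 1 -> length == chunk_length, effectively no chunking
```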
@@ -349,14 +379,17 @@ def get_conditioning_latents(
             if librosa_trim_db is not None:
                 audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]

+            # compute latents for the decoder
             speaker_embedding = self.get_speaker_embedding(audio, load_sr)
             speaker_embeddings.append(speaker_embedding)

             audios.append(audio)

-        # use a merge of all references for gpt cond latents
+        # merge all the audios and compute the latents for the gpt
         full_audio = torch.cat(audios, dim=-1)
-        gpt_cond_latents = self.get_gpt_cond_latents(full_audio, load_sr, length=gpt_cond_len)  # [1, 1024, T]
+        gpt_cond_latents = self.get_gpt_cond_latents(
+            full_audio, load_sr, length=gpt_cond_len, chunk_length=gpt_cond_chunk_len
+        )  # [1, 1024, T]

         if speaker_embeddings:
             speaker_embedding = torch.stack(speaker_embeddings)
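
When several reference files are given, the loop above computes one speaker embedding per file, while the GPT latents come from the concatenation of all files. A self-contained sketch of that merge step, using random tensors in place of decoded reference clips and a toy reduction in place of the real speaker encoder:

```python
import torch

# three hypothetical reference clips of different lengths (mono, 22.05 kHz)
refs = [torch.randn(1, 22050 * secs) for secs in (4, 7, 11)]

# one "speaker embedding" per reference (toy stand-in for get_speaker_embedding)
speaker_embeddings = [r.mean(dim=-1, keepdim=True) for r in refs]
stacked = torch.stack(speaker_embeddings)  # combined later, as in the hunk above

# a single merged waveform feeds the GPT conditioning path
full_audio = torch.cat(refs, dim=-1)
print(full_audio.shape)  # torch.Size([1, 485100]) -> 22 s of merged reference audio
```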
@@ -397,6 +430,7 @@ def inference_with_config(self, text, config, ref_audio_path, language, **kwargs
397430 "top_k" : config .top_k ,
398431 "top_p" : config .top_p ,
399432 "gpt_cond_len" : config .gpt_cond_len ,
433+ "gpt_cond_chunk_len" : config .gpt_cond_chunk_len ,
400434 "max_ref_len" : config .max_ref_len ,
401435 "sound_norm_refs" : config .sound_norm_refs ,
402436 }
@@ -417,7 +451,8 @@ def full_inference(
         top_p=0.85,
         do_sample=True,
         # Cloning
-        gpt_cond_len=6,
+        gpt_cond_len=30,
+        gpt_cond_chunk_len=6,
         max_ref_len=10,
         sound_norm_refs=False,
         **hf_generate_kwargs,
@@ -448,7 +483,10 @@ def full_inference(
                 (aka boring) outputs. Defaults to 0.8.

             gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used
-                else the first `gpt_cond_len` secs is used. Defaults to 6 seconds.
+                else the first `gpt_cond_len` secs is used. Defaults to 30 seconds.
+
+            gpt_cond_chunk_len: (int) Chunk length used for cloning. It must be <= `gpt_cond_len`.
+                If `gpt_cond_len` == `gpt_cond_chunk_len`, no chunking is applied. Defaults to 6 seconds.

             hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
                 transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
@@ -461,6 +499,7 @@ def full_inference(
         (gpt_cond_latent, speaker_embedding) = self.get_conditioning_latents(
             audio_path=ref_audio_path,
             gpt_cond_len=gpt_cond_len,
+            gpt_cond_chunk_len=gpt_cond_chunk_len,
             max_ref_length=max_ref_len,
             sound_norm_refs=sound_norm_refs,
         )
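
For orientation, a hypothetical call into the updated `full_inference` with the new chunking knob; `model` stands for an already-loaded XTTS instance, and the text and paths are placeholders, so treat this as a sketch rather than the project's documented usage:

```python
# `model` is assumed to be a loaded XTTS model exposing full_inference() as in the diff.
out = model.full_inference(
    text="Hello world.",
    ref_audio_path="reference.wav",  # placeholder path
    language="en",
    gpt_cond_len=30,                 # use up to 30 s of (merged) reference audio
    gpt_cond_chunk_len=6,            # embed it in 6 s chunks, then average
    max_ref_len=10,                  # cap each reference clip at 10 s (default from the signature above)
)
```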
@@ -566,7 +605,7 @@ def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
             if overlap_len > len(wav_chunk):
                 # wav_chunk is smaller than overlap_len, pass on last wav_gen
                 if wav_gen_prev is not None:
-                    wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len):]
+                    wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) :]
                 else:
                     # not expecting will hit here as problem happens on last chunk
                     wav_chunk = wav_gen[-overlap_len:]
@@ -576,7 +615,7 @@ def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
                 crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
                 wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
                 wav_chunk[:overlap_len] += crossfade_wav
-
+
                 wav_overlap = wav_gen[-overlap_len:]
         wav_gen_prev = wav_gen
         return wav_chunk, wav_gen_prev, wav_overlap
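
These last two hunks only touch formatting, but the surrounding logic is the streaming crossfade: the tail of the previous generated chunk is faded out with `torch.linspace(1.0, 0.0, ...)` while the head of the new chunk is faded in with `torch.linspace(0.0, 1.0, ...)`. A self-contained sketch of that fade, detached from the model's chunk bookkeeping:

```python
import torch

def crossfade_head(wav_chunk: torch.Tensor, wav_overlap: torch.Tensor) -> torch.Tensor:
    """Fade the previous tail out while fading the new chunk's head in (linear crossfade)."""
    overlap_len = wav_overlap.shape[0]
    fade_in = torch.linspace(0.0, 1.0, overlap_len)
    fade_out = torch.linspace(1.0, 0.0, overlap_len)
    out = wav_chunk.clone()
    out[:overlap_len] = wav_overlap * fade_out + wav_chunk[:overlap_len] * fade_in
    return out

# toy example: a 1024-sample chunk blended with a 256-sample tail from the previous chunk
prev_tail = torch.randn(256)
chunk = torch.randn(1024)
smoothed = crossfade_head(chunk, prev_tail)
```

The two linear ramps sum to 1 at every sample, so the overlap region keeps roughly constant energy across chunk boundaries.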