Commit affaf11

Add XTTS training unit test

Parent: 1f92741

File tree

5 files changed: +12858 −17 lines

TTS/tts/layers/xtts/trainer/gpt_trainer.py (4 additions & 1 deletion)

@@ -268,6 +268,7 @@ def format_batch_on_device(self, batch):
         dvae_wav = batch["wav"]
         dvae_mel_spec = self.torch_mel_spectrogram_dvae(dvae_wav)
         codes = self.dvae.get_codebook_indices(dvae_mel_spec)
+
         batch["audio_codes"] = codes
         # delete useless batch tensors
         del batch["padded_text"]
@@ -454,7 +455,9 @@ def load_checkpoint(
         target_options={"anon": True},
     ):  # pylint: disable=unused-argument, disable=W0201, disable=W0102, redefined-builtin
         """Load the model checkpoint and setup for training or inference"""
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))["model"]
+
+        state = self.xtts.get_compatible_checkpoint_state_dict(checkpoint_path)
+
         # load the model weights
         self.xtts.load_state_dict(state, strict=strict)
 
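A note on this trainer-side change: load_checkpoint no longer reads the raw "model" dict with load_fsspec but asks the wrapped XTTS model for a compatible state dict. The minimal sketch below uses toy torch modules with made-up names (TinyXtts, TinyTrainerWrapper), not the project's classes, to show why a checkpoint saved from a Trainer wrapper cannot be loaded strictly into the bare model until the "xtts." prefix is stripped, which is what the new helper takes care of.

# Toy illustration (assumed names, not project code): why "xtts."-prefixed keys
# from a Trainer checkpoint do not load strictly into the bare model.
from torch import nn


class TinyXtts(nn.Module):
    def __init__(self):
        super().__init__()
        self.gpt = nn.Linear(4, 4)


class TinyTrainerWrapper(nn.Module):
    """Mimics the trainer, which holds the model under the attribute name `xtts`."""

    def __init__(self):
        super().__init__()
        self.xtts = TinyXtts()


wrapper_state = TinyTrainerWrapper().state_dict()  # keys look like "xtts.gpt.weight"
bare_model = TinyXtts()

try:
    bare_model.load_state_dict(wrapper_state, strict=True)
except RuntimeError as err:
    print("strict load fails:", type(err).__name__)

# Stripping the prefix, as the new helper does, makes the same weights load cleanly.
remapped = {k.replace("xtts.", "", 1): v for k, v in wrapper_state.items()}
bare_model.load_state_dict(remapped, strict=True)

Delegating this to the model keeps the remapping logic in one place instead of duplicating it in the trainer.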

TTS/tts/models/xtts.py (21 additions & 16 deletions)

@@ -643,6 +643,7 @@ def inference(
         expected_output_len = torch.tensor(
             [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
         )
+
         text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
         gpt_latents = self.gpt(
             text_tokens,

@@ -788,6 +789,25 @@ def eval(self):  # pylint: disable=redefined-builtin
         self.gpt.init_gpt_for_inference()
         super().eval()
 
+    def get_compatible_checkpoint_state_dict(self, model_path):
+        checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
+        ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
+        ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
+        ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
+        for key in list(checkpoint.keys()):
+            # check if it is from the coqui Trainer if so convert it
+            if key.startswith("xtts."):
+                new_key = key.replace("xtts.", "")
+                checkpoint[new_key] = checkpoint[key]
+                del checkpoint[key]
+                key = new_key
+
+            # remove unused keys
+            if key.split(".")[0] in ignore_keys:
+                del checkpoint[key]
+
+        return checkpoint
+
     def load_checkpoint(
         self,
         config,

@@ -821,22 +841,7 @@ def load_checkpoint(
 
         self.init_models()
 
-        checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
-        ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
-        ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
-        ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
-        for key in list(checkpoint.keys()):
-            # check if it is from the coqui Trainer if so convert it
-            if key.startswith("xtts."):
-                coqui_trainer_checkpoint = True
-                new_key = key.replace("xtts.", "")
-                checkpoint[new_key] = checkpoint[key]
-                del checkpoint[key]
-                key = new_key
-
-            # remove unused keys
-            if key.split(".")[0] in ignore_keys:
-                del checkpoint[key]
+        checkpoint = self.get_compatible_checkpoint_state_dict(model_path)
 
         # deal with v1 and v1.1. V1 has the init_gpt_for_inference keys, v1.1 do not
         try:
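For reference, a standalone sketch of the key handling that get_compatible_checkpoint_state_dict now centralizes, mirrored from the diff above onto a fake in-memory checkpoint. make_compatible and the sample keys are illustrative only; the real method loads the file with load_fsspec and builds ignore_keys from self.args.

# Illustrative re-implementation of the remapping/filtering loop above,
# applied to a fake checkpoint dict instead of a file loaded via load_fsspec.
def make_compatible(checkpoint, ignore_keys):
    for key in list(checkpoint.keys()):
        # checkpoints saved by the coqui Trainer prefix every key with "xtts."
        if key.startswith("xtts."):
            new_key = key.replace("xtts.", "")
            checkpoint[new_key] = checkpoint.pop(key)
            key = new_key
        # drop sub-modules that the current decoder configuration does not use
        if key.split(".")[0] in ignore_keys:
            del checkpoint[key]
    return checkpoint


fake_checkpoint = {
    "xtts.gpt.layer.weight": 0.0,
    "xtts.ne_hifigan_decoder.conv.weight": 0.0,
    "hifigan_decoder.conv.bias": 0.0,
}

# reduced ignore_keys list for illustration; the real one is derived from self.args
cleaned = make_compatible(fake_checkpoint, ignore_keys=["ne_hifigan_decoder"])
print(sorted(cleaned))  # ['gpt.layer.weight', 'hifigan_decoder.conv.bias']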

recipes/ljspeech/xtts_v1/train_gpt_xtts.py (1 addition & 0 deletions)

@@ -52,6 +52,7 @@
 )
 LANGUAGE = config_dataset.language
 
+
 def main():
     # init args and config
     model_args = GPTArgs(
