@@ -643,6 +643,7 @@ def inference(
         expected_output_len = torch.tensor(
             [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
         )
+
         text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
         gpt_latents = self.gpt(
             text_tokens,
@@ -788,6 +789,25 @@ def eval(self): # pylint: disable=redefined-builtin
         self.gpt.init_gpt_for_inference()
         super().eval()
 
+    def get_compatible_checkpoint_state_dict(self, model_path):
+        checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
+        ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
+        ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
+        ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
+        for key in list(checkpoint.keys()):
+            # check if it is from the coqui Trainer if so convert it
+            if key.startswith("xtts."):
+                new_key = key.replace("xtts.", "")
+                checkpoint[new_key] = checkpoint[key]
+                del checkpoint[key]
+                key = new_key
+
+            # remove unused keys
+            if key.split(".")[0] in ignore_keys:
+                del checkpoint[key]
+
+        return checkpoint
+
     def load_checkpoint(
         self,
         config,
@@ -821,22 +841,7 @@ def load_checkpoint(
 
         self.init_models()
 
-        checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
-        ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
-        ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
-        ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
-        for key in list(checkpoint.keys()):
-            # check if it is from the coqui Trainer if so convert it
-            if key.startswith("xtts."):
-                coqui_trainer_checkpoint = True
-                new_key = key.replace("xtts.", "")
-                checkpoint[new_key] = checkpoint[key]
-                del checkpoint[key]
-                key = new_key
-
-            # remove unused keys
-            if key.split(".")[0] in ignore_keys:
-                del checkpoint[key]
+        checkpoint = self.get_compatible_checkpoint_state_dict(model_path)
 
         # deal with v1 and v1.1. V1 has the init_gpt_for_inference keys, v1.1 do not
         try:
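
For reference, a minimal usage sketch of the loading path touched by this refactor, assuming the standard XTTS setup (a config JSON next to the checkpoint directory); the paths below are placeholders and are not part of the commit:

```python
# Minimal sketch: load_checkpoint() now delegates stripping the coqui Trainer
# "xtts." key prefix and dropping unused decoder keys to
# get_compatible_checkpoint_state_dict(). Paths are placeholders.
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
config.load_json("/path/to/xtts/config.json")

model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)
```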