Hi, I'm trying to retrain Jukebox on new samples that contain only speech. When I try to sample, I get a strange size-mismatch error:
/usr/local/lib/python3.10/dist-packages/jukebox/make_models.py in make_prior(hps, vqvae, device)
182 prior.apply(_convert_conv_weights_to_fp16)
183 prior = prior.to(device)
--> 184 restore_model(hps, prior, hps.restore_prior)
185 if hps.train:
186 print_all(f"Loading prior in train mode")
/usr/local/lib/python3.10/dist-packages/jukebox/make_models.py in restore_model(hps, model, checkpoint_path)
64 # print(k, "Checkpoint:", checkpoint_hps.get(k, None), "Ours:", hps.get(k, None))
65 checkpoint['model'] = {k[7:] if k[:7] == 'module.' else k: v for k, v in checkpoint['model'].items()}
---> 66 model.load_state_dict(checkpoint['model'])
67 if 'step' in checkpoint: model.step = checkpoint['step']
68
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict, assign)
2187
2188 if len(error_msgs) > 0:
-> 2189 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
2190 self.__class__.__name__, "\n\t".join(error_msgs)))
2191 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for SimplePrior:
size mismatch for prior.x_emb.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([2127, 1024]).
size mismatch for prior.x_out.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([2127, 1024]).
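`load_state_dict` stops with a RuntimeError, so it helps to list every parameter whose shape disagrees, plus any keys present on only one side. The following is a minimal pure-Python sketch that works on `{name: shape}` mappings. The helper name `find_shape_mismatches` is hypothetical, and the example shapes are copied from the error above.

```python
def find_shape_mismatches(ckpt_shapes, model_shapes):
    """Compare two {param_name: shape} mappings.

    Returns (mismatched, missing, unexpected):
      mismatched: {name: (checkpoint_shape, model_shape)} for shared names whose shapes differ
      missing:    sorted names the model expects but the checkpoint lacks
      unexpected: sorted names the checkpoint has but the model does not
    """
    mismatched = {
        name: (tuple(ckpt_shapes[name]), tuple(model_shapes[name]))
        for name in ckpt_shapes.keys() & model_shapes.keys()
        if tuple(ckpt_shapes[name]) != tuple(model_shapes[name])
    }
    missing = sorted(model_shapes.keys() - ckpt_shapes.keys())
    unexpected = sorted(ckpt_shapes.keys() - model_shapes.keys())
    return mismatched, missing, unexpected


# Shapes taken from the error message above.
ckpt = {"prior.x_emb.weight": (4096, 1024), "prior.x_out.weight": (4096, 1024)}
model = {"prior.x_emb.weight": (2127, 1024), "prior.x_out.weight": (2127, 1024)}
mismatched, missing, unexpected = find_shape_mismatches(ckpt, model)
for name, (c, m) in sorted(mismatched.items()):
    print(f"{name}: checkpoint {c} vs model {m}")
```

With PyTorch installed, the real mappings could be built with `{k: tuple(v.shape) for k, v in torch.load(path, map_location="cpu")["model"].items()}` and `{k: tuple(v.shape) for k, v in prior.state_dict().items()}`. Strip any leading `module.` from the checkpoint keys first, as `restore_model` does.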
I've used small_single_enc_dec_prior with these hyperparameters:
small_single_enc_dec_prior = Hyperparams(
    n_ctx=6144,  # original: 6144, ours: 384
    prior_width=1024,
    prior_depth=48,
    heads=2,
    attn_order=12,
    blocks=64,
    init_scale=0.7,
    c_res=1,
    prime_loss_fraction=0.4,
    single_enc_dec=True,
    labels=True,
    labels_v3=True,
    y_bins=(10,100),  # Set this to (genres, artists) for your dataset
    max_bow_genre_size=1,
    min_duration=24.0,  # 24
    max_duration=600.0,  # 600
    t_bins=64,  # original: 64, ours: 16
    use_tokens=True,
    n_tokens=384,  # original: 384, ours: 24
    n_vocab=79,
)
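With `single_enc_dec=True`, the prior's input embedding (`x_emb`) and output projection (`x_out`) share one table that covers both the lyric tokens and the VQ-VAE codes. The sketch below reproduces that row count under two assumptions. First, the VQ-VAE codebook has Jukebox's default `l_bins=2048`. Second, as I read `jukebox/prior/prior.py`, setting `copy_input` makes the lyric slot use `l_bins` bins instead of `n_vocab`. The helper name `single_enc_dec_rows` is hypothetical.

```python
def single_enc_dec_rows(n_vocab, l_bins=2048, copy_input=False):
    """Rows of prior.x_emb / prior.x_out for a single encoder-decoder prior.

    Assumption: lyric (prime) bins come first, followed by the VQ-VAE code bins;
    with copy_input the prime bins are replaced by l_bins.
    """
    prime_bins = l_bins if copy_input else n_vocab
    return prime_bins + l_bins


print(single_enc_dec_rows(n_vocab=79))                   # 2127
print(single_enc_dec_rows(n_vocab=79, copy_input=True))  # 4096
```

If those assumptions hold, the 2127 rows in the freshly built model match `n_vocab=79` without `copy_input`. The 4096 rows in the checkpoint match a prior that was built with `copy_input` enabled.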
I then restore my saved checkpoint and update it so the keys match. This is how I've set up training:
args = {
    'hps': 'vqvae,small_single_enc_dec_prior,all_fp16,cpu_ema',
    'name': 'pretrained_vqvae_small_single_enc_dec_prior_labels',
    'sample_length': 786432,  # 49152
    'bs': 4,
    'aug_shift': True,
    'aug_blend': True,
    'audio_files_dir': '/content/data',
    'train': True,
    'test': True,
    'prior': True,
    'min_duration': 24,  # 24
    'max_duration': 600,  # 600
    'levels': 3,
    'level': 2,
    'weight_decay': 0.01,
    'save_iters': 10,
    'copy_input': True
}
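The prior is rebuilt from the hyperparameters when sampling, so every setting that changes its architecture has to match the one used in training. The following is a minimal sketch that compares two argument dicts. The choice of shape-relevant keys (`hps`, `copy_input`, `levels`, `level`) is an assumption, the helper `config_diff` is hypothetical, and the sampling dict is illustrative rather than taken from this post.

```python
# Assumed list of settings that change the prior's architecture.
SHAPE_KEYS = ("hps", "copy_input", "levels", "level")


def config_diff(train_args, sample_args, keys=SHAPE_KEYS):
    """Return {key: (train_value, sample_value)} for keys whose values differ.

    A key missing from one dict shows up as None on that side.
    """
    return {
        k: (train_args.get(k), sample_args.get(k))
        for k in keys
        if train_args.get(k) != sample_args.get(k)
    }


train_args = {"hps": "vqvae,small_single_enc_dec_prior,all_fp16,cpu_ema",
              "levels": 3, "level": 2, "copy_input": True}
# Hypothetical sampling config that leaves out copy_input.
sample_args = {"hps": "vqvae,small_single_enc_dec_prior,all_fp16,cpu_ema",
               "levels": 3, "level": 2}
print(config_diff(train_args, sample_args))  # {'copy_input': (True, None)}
```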
Train the model:
train.run(**args)
I've tried all kinds of things but still get the same error. Does anyone have an idea what the problem could be?