Size mismatch #300

Open
dhurtigkth opened this issue May 31, 2024 · 0 comments

Hi, I'm trying to retrain Jukebox on new samples that contain only speech. I get a size mismatch error when attempting to sample:

/usr/local/lib/python3.10/dist-packages/jukebox/make_models.py in make_prior(hps, vqvae, device)
182 prior.apply(_convert_conv_weights_to_fp16)
183 prior = prior.to(device)
--> 184 restore_model(hps, prior, hps.restore_prior)
185 if hps.train:
186 print_all(f"Loading prior in train mode")

/usr/local/lib/python3.10/dist-packages/jukebox/make_models.py in restore_model(hps, model, checkpoint_path)
64 # print(k, "Checkpoint:", checkpoint_hps.get(k, None), "Ours:", hps.get(k, None))
65 checkpoint['model'] = {k[7:] if k[:7] == 'module.' else k: v for k, v in checkpoint['model'].items()}
---> 66 model.load_state_dict(checkpoint['model'])
67 if 'step' in checkpoint: model.step = checkpoint['step']
68

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict, assign)
2187
2188 if len(error_msgs) > 0:
-> 2189 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
2190 self.__class__.__name__, "\n\t".join(error_msgs)))
2191 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for SimplePrior:
size mismatch for prior.x_emb.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([2127, 1024]).
size mismatch for prior.x_out.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([2127, 1024]).
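
To see every parameter whose shape differs, rather than only the two that load_state_dict reports, the two state dicts can be compared directly before loading. The following is a minimal sketch, not part of Jukebox: the helper names are hypothetical, and it only assumes that each value exposes a `.shape`, as PyTorch tensors do.

```python
def strip_module_prefix(state):
    """Drop a leading 'module.' from each key, mirroring restore_model."""
    return {k[7:] if k.startswith("module.") else k: v for k, v in state.items()}


def shape_mismatches(checkpoint_state, model_state):
    """Return {name: (checkpoint_shape, model_shape)} where shapes differ."""
    checkpoint_state = strip_module_prefix(checkpoint_state)
    mismatches = {}
    for name, param in model_state.items():
        if name not in checkpoint_state:
            continue  # missing keys are a separate problem, not a size mismatch
        ckpt_shape = tuple(checkpoint_state[name].shape)
        model_shape = tuple(param.shape)
        if ckpt_shape != model_shape:
            mismatches[name] = (ckpt_shape, model_shape)
    return mismatches
```

With PyTorch this would presumably be called as `shape_mismatches(torch.load(path, map_location="cpu")["model"], prior.state_dict())`, where `path` stands for the checkpoint passed as hps.restore_prior.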

I've used small_single_enc_dec_prior and these are the hyperparameters:

small_single_enc_dec_prior = Hyperparams(
n_ctx=6144, # original: 6144, ours: 384
prior_width=1024,
prior_depth=48,
heads=2,
attn_order=12,
blocks=64,
init_scale=0.7,
c_res=1,
prime_loss_fraction=0.4,
single_enc_dec=True,
labels=True,
labels_v3=True,
y_bins=(10,100), # Set this to (genres, artists) for your dataset
max_bow_genre_size=1,
min_duration=24.0, # 24
max_duration=600.0, # 600
t_bins=64, # original: 64, ours: 16
use_tokens=True,
n_tokens=384, # original: 384, ours: 24
n_vocab=79,
)
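
One observation that may help narrow this down: 2127 = 2048 + 79, and n_vocab here is 79. The sketch below works through that arithmetic under two assumptions that are not confirmed anywhere in this issue: that the VQ-VAE codebook has 2048 entries (l_bins), and that with single_enc_dec=True the prior uses one embedding table covering both the lyric tokens and the codebook entries.

```python
# Assumption: VQ-VAE codebook size (l_bins). Not stated in this issue.
L_BINS = 2048


def shared_embedding_rows(n_vocab, l_bins=L_BINS):
    """Rows of a single embedding table shared by lyric tokens and codes."""
    return n_vocab + l_bins


def implied_n_vocab(embedding_rows, l_bins=L_BINS):
    """Token vocabulary size implied by an embedding table of a given height."""
    return embedding_rows - l_bins


print(shared_embedding_rows(79))  # 2127, the shape in the current model
print(implied_n_vocab(4096))      # 2048, the value implied by the checkpoint
```

If those assumptions hold, the checkpoint looks as if it was written with a different n_vocab than the 79 used when sampling, which would point at a hyperparameter difference between the training and sampling runs rather than at the data.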

I then restore my saved checkpoint and update it so that the keys match. This is how I set up training:

args = {
'hps': 'vqvae,small_single_enc_dec_prior,all_fp16,cpu_ema',
'name': 'pretrained_vqvae_small_single_enc_dec_prior_labels',
'sample_length': 786432, #49152
'bs': 4,
'aug_shift': True,
'aug_blend': True,
'audio_files_dir': '/content/data',
'train': True,
'test': True,
'prior': True,
'min_duration': 24, #24
'max_duration': 600, #600
'levels': 3,
'level': 2,
'weight_decay': 0.01,
'save_iters': 10,
'copy_input': True
}

Train the model:

train.run(**args)
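
Because the error only shows up when sampling, it may be worth diffing the hyperparameters of this training run against the ones used at sampling time. The commented-out line in restore_model hints that a checkpoint may carry its own hyperparameters, but that is an assumption here. The sketch below uses plain dicts; the helper name is hypothetical and the sampling-side values are made up for illustration.

```python
def diff_hps(saved, current):
    """Return {key: (saved_value, current_value)} for keys whose values differ."""
    differences = {}
    for key in sorted(set(saved) | set(current)):
        saved_value = saved.get(key)
        current_value = current.get(key)
        if saved_value != current_value:
            differences[key] = (saved_value, current_value)
    return differences


# Training-side values are taken from the hyperparameters above;
# the sampling-side values are hypothetical.
train_hps = {"n_vocab": 79, "n_tokens": 384, "labels": True}
sample_hps = {"n_vocab": 79, "n_tokens": 384, "labels": False}

for key, (saved, current) in diff_hps(train_hps, sample_hps).items():
    print(f"{key}: checkpoint={saved!r} current={current!r}")
```

Any key that differs and feeds into the size of prior.x_emb or prior.x_out would be a candidate cause of the mismatch.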

I've tried all kinds of things but still get the same error. Does anyone have an idea what the problem could be?
