
Full fine-tuning inference not working with Fish-Speech 1.4 #772

Closed · 6 tasks done
kdcyberdude opened this issue Dec 20, 2024 · 1 comment
Labels: bug (Something isn't working)

Self Checks

  • This template is only for bug reports. For questions, please visit Discussions.
  • I have thoroughly reviewed the project documentation (installation, training, inference) but couldn't find information to solve my problem.
  • I have searched for existing issues, including closed ones.
  • I confirm that I am using English to submit this report (I have read and agree to the Language Policy).
  • [FOR CHINESE USERS] Please submit issues in English, otherwise they will be closed. Thank you! :)
  • Please do not modify this template and fill in all required fields.

Cloud or Self Hosted

Self Hosted (Source)

Environment Details

Ubuntu 22.04, torch==2.4.1, Gradio 5.9.0

Steps to Reproduce

I am fine-tuning Fish-Speech 1.4 on a new language (Panjabi) without LoRA.

  1. First, I checked out the release tag: git checkout tags/v1.4.3
  2. Then I started training with the config below.
  3. This created the step*.ckpt files in the {result} dir, which I converted to .pth using tools/extract_model.py.
  4. Replaced model.pth from fish-speech-1.4 with the new checkpoint.
  5. When trying to run inference, I get the following error in Gradio (see the inspection sketch below): `The expanded size of the tensor (4096) must match the existing size (4170) at non-singleton dimension 1. Target sizes: [9, 4096]. Tensor sizes: [9, 4170]`

[screenshot: Gradio error message]
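
Before digging further, it may be worth ruling out that the tools/extract_model.py conversion changed any tensor shapes relative to the released weights. The snippet below is a minimal sketch, assuming both files are plain state dicts loadable with torch.load; the two paths are placeholders for my local layout.

```python
import torch

# Hypothetical local paths; adjust to your own checkpoint locations.
ORIGINAL = "checkpoints/fish-speech-1.4/model.pth"
CONVERTED = "results/text2semantic_finetune_dual_ar/model.pth"

def load_state_dict(path):
    # Some checkpoints wrap the weights in a dict under "state_dict".
    obj = torch.load(path, map_location="cpu")
    if isinstance(obj, dict) and "state_dict" in obj:
        return obj["state_dict"]
    return obj

def shape_of(sd, key):
    value = sd.get(key)
    return tuple(value.shape) if hasattr(value, "shape") else None

orig = load_state_dict(ORIGINAL)
ft = load_state_dict(CONVERTED)

# Print every key whose shape differs between the two files
# (or that exists in only one of them).
for key in sorted(set(orig) | set(ft)):
    a, b = shape_of(orig, key), shape_of(ft, key)
    if a != b:
        print(f"{key}: base={a} fine-tuned={b}")
```

If every shape matches, that would suggest the 4170 comes from the inference-time input (e.g. a prompt or reference longer than max_length = 4096) rather than from the converted weights.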

NOTE: The loss also started increasing after a certain number of iterations; I am not sure why. You can check the logs on wandb: https://wandb.ai/kdcyberdude/fish-speech/workspace?nw=nwuserkdcyberdude

YAML config used for training:

```yaml
defaults:
  - base
  - _self_

paths:
  run_dir: results/${project}
  ckpt_dir: ${paths.run_dir}/ft_checkpoints2

hydra:
  run:
    dir: ${paths.run_dir}


project: text2semantic_finetune_dual_ar
max_length: 4096
pretrained_ckpt_path: checkpoints/fish-speech-1.4

# Lightning Trainer
trainer:
  accumulate_grad_batches: 2
  gradient_clip_val: 1.0
  gradient_clip_algorithm: "norm"
  max_steps: 10000
  precision: bf16-true
  limit_val_batches: 10
  val_check_interval: 500
  benchmark: true


# Dataset Configuration
tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: ${pretrained_ckpt_path}  

# Dataset Configuration
train_dataset:
  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
  proto_files:
    - data/protos
  tokenizer: ${tokenizer}
  causal: true
  max_length: ${max_length}
  use_speaker: false
  interactive_prob: 0.7

val_dataset:
  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
  proto_files:
    - data/protos
  tokenizer: ${tokenizer}
  causal: true
  max_length: ${max_length}
  use_speaker: false
  interactive_prob: 0.7

data:
  _target_: fish_speech.datasets.semantic.SemanticDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 8
  batch_size: 20
  tokenizer: ${tokenizer}
  max_length: ${max_length}

# Model Configuration
model:
  _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
  model: 
    _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
    path: ${pretrained_ckpt_path}
    load_weights: true
    max_length: ${max_length}
    lora_config: null

  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 2e-5
    weight_decay: 0
    betas: [0.9, 0.95]
    eps: 1e-5

  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 2500

# Callbacks
callbacks:
  model_checkpoint:
    every_n_train_steps: ${trainer.val_check_interval}
    dirpath: ${paths.ckpt_dir}
    filename: "step_{step:09d}"
    save_last: true # additionally always save an exact copy of the last checkpoint to a file last.ckpt
    save_top_k: 5 # save 5 latest checkpoints
    monitor: step # use step to monitor checkpoints
    mode: max # save the latest checkpoint with the highest global_step
    every_n_epochs: null # don't save checkpoints by epoch end
    auto_insert_metric_name: false

  model_summary:
    _target_: lightning.pytorch.callbacks.ModelSummary
    max_depth: 2 # the maximum depth of layer nesting that the summary will include

  learning_rate_monitor:
    _target_: lightning.pytorch.callbacks.LearningRateMonitor
    logging_interval: step
    log_momentum: false

  grad_norm_monitor:
    _target_: fish_speech.callbacks.GradNormMonitor
    norm_type: 2
    logging_interval: step

# Logger
logger:
  tensorboard:
    _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
    save_dir: "${paths.run_dir}/tensorboard/"
    name: null
    log_graph: false
    default_hp_metric: true
    prefix: ""

  wandb:
    _target_: lightning.pytorch.loggers.wandb.WandbLogger
    # name: "" # name of the run (normally generated by wandb)
    save_dir: "${paths.run_dir}"
    offline: False
    id: null # pass correct id to resume experiment!
    anonymous: null # enable anonymous logging
    project: "fish-speech"
    log_model: False # upload lightning ckpts
    prefix: "" # a string to put at the beginning of metric keys
    # entity: "" # set to name of your wandb team
    group: ""
    tags: ["vq", "hq", "finetune"]
    job_type: ""
```
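
For completeness, the interpolated values in this config can be checked without launching a training run. This is a minimal sketch using OmegaConf (which Hydra builds on); the config path is a placeholder for wherever the file is saved, and since `defaults: - base` is only merged by Hydra, only keys defined in this file resolve here.

```python
from omegaconf import OmegaConf

# Hypothetical path to the config shown above; adjust to where you saved it.
cfg = OmegaConf.load("fish_speech/configs/text2semantic_finetune.yaml")

# Interpolations such as ${max_length} are resolved on access.
print("max_length:            ", cfg.max_length)              # 4096
print("model.model.max_length:", cfg.model.model.max_length)  # 4096 via ${max_length}
print("pretrained_ckpt_path:  ", cfg.pretrained_ckpt_path)    # checkpoints/fish-speech-1.4
print("paths.ckpt_dir:        ", cfg.paths.ckpt_dir)          # results/.../ft_checkpoints2
```

Both max_length values resolve to 4096, which is the target size reported in the error message above.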

### ✔️ Expected Behavior

Inference should work with the fine-tuned checkpoint.
The loss should not diverge.

### ❌ Actual Behavior

Inference fails with the tensor-size error shown above.
kdcyberdude added the bug label on Dec 20, 2024
Whale-Dolphin (Collaborator) commented Dec 21, 2024

Fine-tuning works now in 1.5.
