
Encountered errors when reproducing lightning training example #271

Open

ReginaZh opened this issue Sep 26, 2024 · 3 comments

@ReginaZh

🐛 Describe the bug

Tried to reproduce the Liger Kernel optimization with the Lightning trainer and DeepSpeed ZeRO-3, but encountered several errors.

Reproduce

script:

cd /examples/lightning/
python training.py --model Qwen/Qwen2-0.5B-Instruct --num_gpu 1 --max_length 1024 --strategy deepspeed

output:

[INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
/opt/conda/envs/ptca/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_validation.py:118: UserWarning: onnxruntime training package info: package_name: onnxruntime-training
  warnings.warn("onnxruntime training package info: package_name: %s" % package_name)
/opt/conda/envs/ptca/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_validation.py:119: UserWarning: onnxruntime training package info: __version__: 1.18.0
  warnings.warn("onnxruntime training package info: __version__: %s" % version)
/opt/conda/envs/ptca/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_validation.py:120: UserWarning: onnxruntime training package info: cuda_version: 12.2
  warnings.warn("onnxruntime training package info: cuda_version: %s" % cuda_version)
/opt/conda/envs/ptca/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_validation.py:121: UserWarning: onnxruntime build info: cudart_version: 12020
  warnings.warn("onnxruntime build info: cudart_version: %s" % cudart_version)
/opt/conda/envs/ptca/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_validation.py:129: UserWarning: WARNING: failed to find cudart version that matches onnxruntime build info
  warnings.warn("WARNING: failed to find cudart version that matches onnxruntime build info")
/opt/conda/envs/ptca/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_validation.py:130: UserWarning: WARNING: found cudart versions: [12010]
  warnings.warn("WARNING: found cudart versions: %s" % local_cudart_versions)
2024-09-26 03:11:07.596978: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-26 03:11:07.611316: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-26 03:11:07.615979: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-26 03:11:07.627834: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-09-26 03:11:08.472073: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Seed set to 42
2024-09-26 03:11:09,359 root [WARNING] - Cannot import JIT optimized kernels. CUDA extension will be disabled.
Traceback (most recent call last):
  File "/Liger-Kernel/examples/lightning/training.py", line 289, in <module>
    train()
  File "/Liger-Kernel/examples/lightning/training.py", line 257, in train
    strategy = DeepSpeedStrategy(stage=3)
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/lightning/pytorch/strategies/deepspeed.py", line 305, in __init__
    deepspeed.utils.logging.logger.setLevel(logging_level)
AttributeError: module 'deepspeed.utils' has no attribute 'logging'

I fixed the above error by adding `import deepspeed` to training.py.
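For reference, a minimal sketch of that workaround (only the explicit `import deepspeed` is new; the strategy construction matches the traceback above):

```python
# Workaround sketch (assumes lightning 2.x and deepspeed are installed):
# importing deepspeed up front loads its submodules, so Lightning's call to
# deepspeed.utils.logging.logger.setLevel(...) inside DeepSpeedStrategy.__init__
# no longer raises AttributeError.
import deepspeed  # noqa: F401  -- imported for its side effect of loading deepspeed.utils.logging

from lightning.pytorch.strategies import DeepSpeedStrategy

strategy = DeepSpeedStrategy(stage=3)  # the line 257 call from the traceback above
```

After applying that workaround, however, another error was raised: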

[rank0]: Traceback (most recent call last):
[rank0]:   File "/Liger-Kernel/examples/lightning/training.py", line 289, in <module>
[rank0]:     train()
[rank0]:   File "/Liger-Kernel/examples/lightning/training.py", line 285, in train
[rank0]:     trainer.fit(model, datamodule=data_module)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
[rank0]:     call._call_and_handle_interrupt(
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 46, in _call_and_handle_interrupt
[rank0]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
[rank0]:     return function(*args, **kwargs)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
[rank0]:     self._run(model, ckpt_path=ckpt_path)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 945, in _run
[rank0]:     call._call_configure_model(self)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 119, in _call_configure_model
[rank0]:     _call_lightning_module_hook(trainer, "configure_model")
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 167, in _call_lightning_module_hook
[rank0]:     output = fn(*args, **kwargs)
[rank0]:   File "/Liger-Kernel/examples/lightning/training.py", line 76, in configure_model
[rank0]:     self.model = AutoLigerKernelForCausalLM.from_pretrained(
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/liger_kernel/transformers/auto_model.py", line 31, in from_pretrained
[rank0]:     return super().from_pretrained(
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 564, in from_pretrained
[rank0]:     return model_class.from_pretrained(
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3838, in from_pretrained
[rank0]:     ) = cls._load_pretrained_model(
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4349, in _load_pretrained_model
[rank0]:     raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
[rank0]: RuntimeError: Error(s) in loading state_dict for Qwen2ForCausalLM:
[rank0]:        size mismatch for model.embed_tokens.weight: copying a param with shape torch.Size([151936, 896]) from checkpoint, the shape in current model is torch.Size([0]).
[rank0]:        size mismatch for model.layers.0.self_attn.q_proj.weight: copying a param with shape torch.Size([896, 896]) from checkpoint, the shape in current model is torch.Size([0]).
[rank0]:        size mismatch for model.layers.0.self_attn.q_proj.bias: copying a param with shape torch.Size([896]) from checkpoint, the shape in current model is torch.Size([0]).

Versions

Environment Report:

Operating System: Linux-6.5.0-1025-azure-x86_64-with-glibc2.31
Python version: 3.10.14
PyTorch version: 2.4.1+cu121
CUDA version: 12.1
Triton version: 3.0.0
Transformers version: 4.42.4
deepspeed version: 0.15.0
liger_kernel version: 0.3.0

@yundai424 (Collaborator)

I think it's related to the DeepSpeed model init method. When using DeepSpeed, the model should be initialized in a context where all newly created tensors have shape 0; the sharding & broadcast is then handled inside the DeepSpeed source. Something could have fallen through the cracks, either in the Liger diffs or in a new deepspeed/HF release. Will take a look and get back to this issue ASAP.
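For context, a rough sketch of the DeepSpeed mechanism being described (assuming deepspeed is installed and a distributed environment is already initialized; this is not the exact Lightning/transformers code path):

```python
import deepspeed
import torch.nn as nn

# Under ZeRO stage 3, parameters created inside deepspeed.zero.Init() are
# partitioned across ranks right away, and the local param.data becomes an
# empty placeholder -- which is why the traceback above reports
# "the shape in current model is torch.Size([0])".
# Assumes distributed setup has already happened (e.g. deepspeed.init_distributed()).
with deepspeed.zero.Init():
    layer = nn.Linear(896, 896)

print(layer.weight.shape)  # expected: torch.Size([0]); the real weights live in the sharded metadata
```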

@yundai424 (Collaborator)

So it was `ignore_mismatch_shapes=True` occasionally being dropped, and it has been fixed very recently in #263 😄 @ReginaZh you can try installing liger-kernel-nightly and it should fix your issue. @shimizust do you think we can make a quick patch release for it 🤔?
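To illustrate the failure mode (a purely hypothetical sketch, not the actual Liger-Kernel code that #263 fixed): if a `from_pretrained` wrapper filters keyword arguments against an allow-list, flags the caller depends on can be silently dropped before they ever reach transformers.

```python
from transformers import AutoModelForCausalLM


class NaiveAutoWrapper(AutoModelForCausalLM):
    """Hypothetical wrapper showing the bug class described above."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        # BUG: only forward kwargs from a hand-maintained allow-list; anything
        # else (e.g. a shape-mismatch flag) is silently discarded.
        allowed = {"torch_dtype", "trust_remote_code"}
        filtered = {k: v for k, v in kwargs.items() if k in allowed}
        return super().from_pretrained(pretrained_model_name_or_path, *args, **filtered)
```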

@ReginaZh (Author) commented Oct 15, 2024

Thanks @yundai424, the above issue has been solved by installing liger-kernel-nightly.
However, I found another strange phenomenon:
when I reproduce the lightning example, it takes 2h59m to finish.

python training.py --model Qwen/Qwen2-0.5B-Instruct --num_gpu 1 --max_length 1024 --strategy deepspeed

But after changing AutoLigerKernelForCausalLM to AutoModelForCausalLM in training.py (the line below, from configure_model), it took 2h42m to finish, which means AutoModelForCausalLM is even faster than AutoLigerKernelForCausalLM.

self.model = AutoLigerKernelForCausalLM.from_pretrained(

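For reference, the comparison described above amounts to swapping the model factory used in configure_model; a minimal sketch (the toggle and function name here are illustrative, not the exact code in training.py):

```python
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import AutoLigerKernelForCausalLM

USE_LIGER = True  # flip to False for the AutoModelForCausalLM baseline timing


def load_model(model_name: str):
    # AutoLigerKernelForCausalLM applies Liger's Triton kernel patches for
    # supported architectures and otherwise follows the regular
    # AutoModelForCausalLM.from_pretrained loading path.
    factory = AutoLigerKernelForCausalLM if USE_LIGER else AutoModelForCausalLM
    return factory.from_pretrained(model_name)


model = load_model("Qwen/Qwen2-0.5B-Instruct")
```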

I wonder, is this expected? And what should the baseline for the lightning trainer optimization be?
