AttributeError: 'NoneType' object has no attribute 'attn_bias' #1490

Open
ganzf886 opened this issue Dec 31, 2024 · 8 comments

@ganzf886

ganzf886 commented Dec 31, 2024

I got an error from xformers:

Traceback (most recent call last):
  File "/data1/xuanyang/kooshot_cate_pred_llm_sft/sft.py", line 131, in <module>
    trainer_stats = trainer.train()
                    ^^^^^^^^^^^^^^^
  File "<string>", line 157, in train
  File "<string>", line 382, in _fast_inner_training_loop
  File "<string>", line 31, in _unsloth_training_step
  File "/data1/xuanyang/envs/unsloth/lib/python3.11/site-packages/unsloth/models/_utils.py", line 1062, in _unsloth_pre_compute_loss
    return self._old_compute_loss(model, inputs, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data1/xuanyang/envs/unsloth/lib/python3.11/site-packages/transformers/trainer.py", line 3708, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/data1/xuanyang/envs/unsloth/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data1/xuanyang/envs/unsloth/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data1/xuanyang/envs/unsloth/lib/python3.11/site-packages/accelerate/utils/operations.py", line 823, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data1/xuanyang/envs/unsloth/lib/python3.11/site-packages/accelerate/utils/operations.py", line 811, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data1/xuanyang/envs/unsloth/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data1/xuanyang/envs/unsloth/lib/python3.11/site-packages/torch/_compile.py", line 32, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data1/xuanyang/envs/unsloth/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data1/xuanyang/envs/unsloth/lib/python3.11/site-packages/unsloth/models/llama.py", line 1118, in PeftModelForCausalLM_fast_forward
    return self.base_model(
           ^^^^^^^^^^^^^^^^
  File "/data1/xuanyang/envs/unsloth/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data1/xuanyang/envs/unsloth/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data1/xuanyang/envs/unsloth/lib/python3.11/site-packages/peft/tuners/tuners_utils.py", line 197, in forward
    return self.model.forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data1/xuanyang/envs/unsloth/lib/python3.11/site-packages/unsloth/models/llama.py", line 970, in _CausalLM_fast_forward
    causal_mask = xformers.attn_bias.LowerTriangularMask()
                  ^^^^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'attn_bias'
  0%|          | 0/60 [00:00<?, ?it/s]

My xformers is built for py311_cu12.1.0_pyt2.5.1:

Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
accelerate                1.2.1                    pypi_0    pypi
aiohappyeyeballs          2.4.4                    pypi_0    pypi
aiohttp                   3.11.11                  pypi_0    pypi
aiosignal                 1.3.2                    pypi_0    pypi
attrs                     24.3.0                   pypi_0    pypi
bitsandbytes              0.42.0                   pypi_0    pypi
blas                      2.16                        mkl    conda-forge
bzip2                     1.0.8                h4bc722e_7    conda-forge
ca-certificates           2024.12.14           hbcca054_0    conda-forge
certifi                   2024.12.14               pypi_0    pypi
charset-normalizer        3.4.1                    pypi_0    pypi
cpython                   3.11.11         py311hd8ed1ab_1    conda-forge
cuda-cudart               12.1.105                      0    nvidia
cuda-cupti                12.1.105                      0    nvidia
cuda-libraries            12.1.0                        0    nvidia
cuda-nvrtc                12.1.105                      0    nvidia
cuda-nvtx                 12.1.105                      0    nvidia
cuda-opencl               12.4.127                      0    nvidia
cuda-runtime              12.1.0                        0    nvidia
cudatoolkit               11.7.0              hd8887f6_10    nvidia
cut-cross-entropy         24.12.3                  pypi_0    pypi
datasets                  3.2.0                    pypi_0    pypi
dill                      0.3.8                    pypi_0    pypi
docstring-parser          0.16                     pypi_0    pypi
filelock                  3.16.1             pyhd8ed1ab_1    conda-forge
frozenlist                1.5.0                    pypi_0    pypi
fsspec                    2024.9.0                 pypi_0    pypi
gmp                       6.3.0                hac33072_2    conda-forge
gmpy2                     2.1.5           py311h0f6cedb_3    conda-forge
hf-transfer               0.1.8                    pypi_0    pypi
huggingface-hub           0.27.0                   pypi_0    pypi
idna                      3.10                     pypi_0    pypi
intel-openmp              2022.0.1          h06a4308_3633    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
jinja2                    3.1.5              pyhd8ed1ab_0    conda-forge
ld_impl_linux-64          2.43                 h712a8e2_2    conda-forge
libblas                   3.8.0                    16_mkl    conda-forge
libcblas                  3.8.0                    16_mkl    conda-forge
libcublas                 12.1.0.26                     0    nvidia
libcufft                  11.0.2.4                      0    nvidia
libcufile                 1.9.1.3                       0    nvidia
libcurand                 10.3.5.147                    0    nvidia
libcusolver               11.4.4.55                     0    nvidia
libcusparse               12.0.2.55                     0    nvidia
libexpat                  2.6.4                h5888daf_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc                    14.2.0               h77fa898_1    conda-forge
libgcc-ng                 14.2.0               h69a702a_1    conda-forge
libgfortran-ng            7.5.0               h14aa051_20    conda-forge
libgfortran4              7.5.0               h14aa051_20    conda-forge
libgomp                   14.2.0               h77fa898_1    conda-forge
liblapack                 3.8.0                    16_mkl    conda-forge
liblapacke                3.8.0                    16_mkl    conda-forge
liblzma                   5.6.3                hb9d3cd8_1    conda-forge
libnpp                    12.0.2.50                     0    nvidia
libnsl                    2.0.1                hd590300_0    conda-forge
libnvjitlink              12.1.105                      0    nvidia
libnvjpeg                 12.1.1.14                     0    nvidia
libsqlite                 3.47.2               hee588c1_0    conda-forge
libstdcxx                 14.2.0               hc0a3c3a_1    conda-forge
libstdcxx-ng              14.2.0               h4852527_1    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libxcrypt                 4.4.36               hd590300_1    conda-forge
libzlib                   1.3.1                hb9d3cd8_2    conda-forge
llvm-openmp               15.0.7               h0cdce71_0    conda-forge
markdown-it-py            3.0.0                    pypi_0    pypi
markupsafe                3.0.2           py311h2dc5d0c_1    conda-forge
mdurl                     0.1.2                    pypi_0    pypi
mkl                       2020.2                      256    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
mpc                       1.3.1                h24ddda3_1    conda-forge
mpfr                      4.2.1                h90cbb55_3    conda-forge
mpmath                    1.3.0              pyhd8ed1ab_1    conda-forge
multidict                 6.1.0                    pypi_0    pypi
multiprocess              0.70.16                  pypi_0    pypi
ncurses                   6.5                  he02047a_1    conda-forge
networkx                  3.4.2              pyh267e887_2    conda-forge
numpy                     2.2.1                    pypi_0    pypi
openssl                   3.4.0                hb9d3cd8_0    conda-forge
packaging                 24.2                     pypi_0    pypi
pandas                    2.2.3                    pypi_0    pypi
peft                      0.14.0                   pypi_0    pypi
pillow                    11.0.0                   pypi_0    pypi
pip                       24.3.1             pyh8b19718_2    conda-forge
propcache                 0.2.1                    pypi_0    pypi
protobuf                  3.20.3                   pypi_0    pypi
psutil                    6.1.1                    pypi_0    pypi
pyarrow                   18.1.0                   pypi_0    pypi
pygments                  2.18.0                   pypi_0    pypi
python                    3.11.11         h9e4cc4f_1_cpython    conda-forge
python-dateutil           2.9.0.post0              pypi_0    pypi
python_abi                3.11                    5_cp311    conda-forge
pytorch                   2.5.1           py3.11_cuda12.1_cudnn9.1.0_0    pytorch
pytorch-cuda              12.1                 ha16c6d3_6    pytorch
pytorch-mutex             1.0                        cuda    pytorch
pytz                      2024.2                   pypi_0    pypi
pyyaml                    6.0.2           py311h9ecbd09_1    conda-forge
readline                  8.2                  h8228510_1    conda-forge
regex                     2024.11.6                pypi_0    pypi
requests                  2.32.3                   pypi_0    pypi
rich                      13.9.4                   pypi_0    pypi
safetensors               0.4.5                    pypi_0    pypi
scipy                     1.14.1                   pypi_0    pypi
sentencepiece             0.2.0                    pypi_0    pypi
setuptools                75.6.0             pyhff2d567_1    conda-forge
shtab                     1.7.1                    pypi_0    pypi
six                       1.17.0                   pypi_0    pypi
sympy                     1.13.1                   pypi_0    pypi
tk                        8.6.13          noxft_h4845f30_101    conda-forge
tokenizers                0.21.0                   pypi_0    pypi
torchtriton               3.1.0                     py311    pytorch
tqdm                      4.67.1                   pypi_0    pypi
transformers              4.47.1                   pypi_0    pypi
trl                       0.13.0                   pypi_0    pypi
typeguard                 4.4.1                    pypi_0    pypi
typing_extensions         4.12.2             pyha770c72_1    conda-forge
tyro                      0.9.5                    pypi_0    pypi
tzdata                    2024.2                   pypi_0    pypi
unsloth                   2024.12.12               pypi_0    pypi
unsloth-zoo               2024.12.7                pypi_0    pypi
urllib3                   2.3.0                    pypi_0    pypi
wheel                     0.45.1             pyhd8ed1ab_1    conda-forge
xformers                  0.0.28.post3    py311_cu12.1.0_pyt2.5.1    xformers
xxhash                    3.5.0                    pypi_0    pypi
yaml                      0.2.5                h7f98852_2    conda-forge
yarl                      1.18.3                   pypi_0    pypi
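
A quick way to confirm whether this xformers build actually loads against the installed PyTorch/CUDA stack (a hedged diagnostic, not anything Unsloth runs itself): if the import below fails, Unsloth disables its xformers path internally, which is consistent with its module handle being None at the line in the traceback above.

# Hedged diagnostic: verify that xformers imports cleanly against this
# PyTorch/CUDA build. If it raises, Unsloth falls back to a None xformers
# handle, matching the "'NoneType' object has no attribute 'attn_bias'" error.
import torch
import xformers
import xformers.ops  # pulls in the compiled attention kernels

print(torch.__version__, torch.version.cuda)  # expected here: 2.5.1, 12.1
print(xformers.__version__)                   # expected here: 0.0.28.post3

# xformers also ships a fuller self-report, run from a shell:
#   python -m xformers.info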

Why do I get this error?

@ItzAmirreza

ItzAmirreza commented Dec 31, 2024

Same here, happy new year btw.

I was doing continued pretraining of qwen2.5-coder-7b-bnb4bit.

@mosama1994

Yeah, this issue is happening for me too. The extra embed and lm_head layers are causing this. I don't think they are even needed for training if we are not using a different tokenizer or haven't changed the tokenizer.
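
A minimal sketch of that workaround, assuming Unsloth's documented FastLanguageModel.get_peft_model API (the model name and hyperparameters below are illustrative, taken from this thread):

# Minimal sketch: leave "embed_tokens" and "lm_head" out of target_modules
# when the tokenizer is unchanged, so LoRA only touches attention/MLP layers.
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = "unsloth/Qwen2.5-Coder-7B-bnb-4bit",  # model from this thread
    max_seq_length = 2048,
    load_in_4bit   = True,
)

model = FastLanguageModel.get_peft_model(
    model,
    r              = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    # "embed_tokens" and "lm_head" deliberately omitted here
    lora_alpha     = 16,
)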

@wkrahulkanojia

Is there any solution for this?
xformers.__version__ : '0.0.28.post3'

@mosama1994

It works on my local machine, and these are the versions I have:

Package Version


absl-py 2.1.0
accelerate 1.2.1
aiofiles 24.1.0
aiohappyeyeballs 2.4.4
aiohttp 3.11.11
aiosignal 1.3.2
annotated-types 0.7.0
anyio 4.7.0
asttokens 3.0.0
attrs 24.3.0
azure-common 1.1.28
azure-core 1.32.0
azure-search-documents 11.6.0b2
bitsandbytes 0.45.0
certifi 2024.12.14
charset-normalizer 3.4.0
comm 0.2.2
cut-cross-entropy 24.12.2
datasets 3.2.0
debugpy 1.8.11
decorator 5.1.1
dill 0.3.8
distro 1.9.0
docstring_parser 0.16
einops 0.8.0
executing 2.1.0
filelock 3.16.1
frozenlist 1.5.0
fsspec 2024.9.0
groq 0.13.1
grpcio 1.68.1
h11 0.14.0
hf_transfer 0.1.8
httpcore 1.0.7
httpx 0.28.1
huggingface-hub 0.27.0
idna 3.10
ipykernel 6.29.5
ipython 8.31.0
isodate 0.7.2
jedi 0.19.2
Jinja2 3.1.4
jupyter_client 8.6.3
jupyter_core 5.7.2
Markdown 3.7
markdown-it-py 3.0.0
MarkupSafe 3.0.2
matplotlib-inline 0.1.7
mdurl 0.1.2
mpmath 1.3.0
multidict 6.1.0
multiprocess 0.70.16
nest-asyncio 1.6.0
networkx 3.4.2
numpy 2.2.0
nvidia-cublas-cu12 12.4.5.8
nvidia-cuda-cupti-cu12 12.4.127
nvidia-cuda-nvrtc-cu12 12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.2.1.3
nvidia-curand-cu12 10.3.5.147
nvidia-cusolver-cu12 11.6.1.9
nvidia-cusparse-cu12 12.3.1.170
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.4.127
packaging 24.2
pandas 2.2.3
parso 0.8.4
peft 0.14.0
pexpect 4.9.0
pillow 11.0.0
pip 24.3.1
platformdirs 4.3.6
prompt_toolkit 3.0.48
propcache 0.2.1
protobuf 3.20.3
psutil 6.1.1
ptyprocess 0.7.0
pure_eval 0.2.3
pyarrow 18.1.0
pydantic 2.10.4
pydantic_core 2.27.2
Pygments 2.18.0
python-dateutil 2.9.0.post0
pytz 2024.2
PyYAML 6.0.2
pyzmq 26.2.0
regex 2024.11.6
requests 2.32.3
rich 13.9.4
safetensors 0.4.5
sentencepiece 0.2.0
setuptools 65.5.0
shtab 1.7.1
six 1.17.0
sniffio 1.3.1
stack-data 0.6.3
sympy 1.13.1
tensorboard 2.18.0
tensorboard-data-server 0.7.2
tokenizers 0.20.3
torch 2.5.0
torchaudio 2.5.1
torchvision 0.20.1
tornado 6.4.2
tqdm 4.67.1
traitlets 5.14.3
transformers 4.46.3
triton 3.1.0
trl 0.13.0
typeguard 4.4.1
typing_extensions 4.12.2
tyro 0.9.4
tzdata 2024.2
unsloth 2024.12.8
unsloth_zoo 2024.12.3
urllib3 2.2.3
wcwidth 0.2.13
Werkzeug 3.1.3
wheel 0.45.1
xformers 0.0.28.post2
xxhash 3.5.0
yarl 1.18.3
Note: you may need to restart the kernel to use updated packages.

Locally I am using bf16 precision. On Colab I was trying to do the pretraining following the same script, but I got an error when I added the lm_head and embed layers: a data-type mismatch between float16 and float32. It seems to be something to do with these two layers for the Qwen models. Full traceback:

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 6,478 | Num Epochs = 1
O^O/ \_/ \    Batch size per device = 2 | Gradient Accumulation steps = 8
\        /    Total batch size = 16 | Total steps = 120
 "-____-"     Number of trainable parameters = 1,412,956,160


AssertionError Traceback (most recent call last)

/usr/local/lib/python3.10/dist-packages/triton/language/core.py in wrapper(*args, **kwargs)
34 "(_builder argument must be provided outside of JIT functions.)")
---> 35 return fn(*args, **kwargs)
36

33 frames

/usr/local/lib/python3.10/dist-packages/triton/language/core.py in dot(input, other, acc, input_precision, allow_tf32, max_num_imprecise_acc, out_dtype, _builder)
1533 max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc)
-> 1534 return semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype, _builder)
1535

/usr/local/lib/python3.10/dist-packages/triton/language/semantic.py in dot(lhs, rhs, acc, input_precision, max_num_imprecise_acc, out_dtype, builder)
1354 assert lhs.type.is_block() and rhs.type.is_block()
-> 1355 assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.options)
1356 if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15():

/usr/local/lib/python3.10/dist-packages/triton/language/semantic.py in assert_dtypes_valid(lhs_dtype, rhs_dtype, options)
1327 return
-> 1328 assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
1329 else:

AssertionError: First input (fp16) and second input (fp32) must have the same dtype!

The above exception was the direct cause of the following exception:

CompilationError Traceback (most recent call last)

in <cell line: 1>()
----> 1 trainer_stats = trainer.train()

/usr/local/lib/python3.10/dist-packages/unsloth/tokenizer_utils.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)

/usr/local/lib/python3.10/dist-packages/unsloth/models/llama.py in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)

/usr/local/lib/python3.10/dist-packages/unsloth/models/_utils.py in _unsloth_training_step(self, model, inputs, num_items_in_batch)

/usr/local/lib/python3.10/dist-packages/unsloth/models/_utils.py in _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs)
1060 )
1061 pass
-> 1062 return self._old_compute_loss(model, inputs, *args, **kwargs)
1063 pass
1064

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
3706 loss_kwargs["num_items_in_batch"] = num_items_in_batch
3707 inputs = {**inputs, **loss_kwargs}
-> 3708 outputs = model(**inputs)
3709 # Save past state if it exists
3710 # TODO: this needs to be fixed and made cleaner later.

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
1737
1738 # torchrec tests the code consistency with the following code

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1748
1749 result = None

/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py in forward(*args, **kwargs)
821
822 def forward(*args, **kwargs):
--> 823 return model_forward(*args, **kwargs)
824
825 # To act like a decorator so that it can be popped when doing extract_model_from_parallel

/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py in call(self, *args, **kwargs)
809
810 def call(self, *args, **kwargs):
--> 811 return convert_to_fp32(self.model_forward(*args, **kwargs))
812
813 def getstate(self):

/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py in decorate_autocast(*args, **kwargs)
42 def decorate_autocast(*args, **kwargs):
43 with autocast_instance:
---> 44 return func(*args, **kwargs)
45
46 decorate_autocast.__script_unsupported = "@autocast() decorator is not supported in script mode" # type: ignore[attr-defined]

/usr/local/lib/python3.10/dist-packages/torch/_compile.py in inner(*args, **kwargs)
30 fn.__dynamo_disable = disable_fn
31
---> 32 return disable_fn(*args, **kwargs)
33
34 return inner

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py in _fn(*args, **kwargs)
630 prior = _maybe_set_eval_frame(callback)
631 try:
--> 632 return fn(*args, **kwargs)
633 finally:
634 _maybe_set_eval_frame(prior)

/usr/local/lib/python3.10/dist-packages/unsloth/models/llama.py in PeftModelForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, num_logits_to_keep, **kwargs)
1116 **kwargs,
1117 ):
-> 1118 return self.base_model(
1119 input_ids=input_ids,
1120 causal_mask=causal_mask,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
1737
1738 # torchrec tests the code consistency with the following code

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1748
1749 result = None

/usr/local/lib/python3.10/dist-packages/peft/tuners/tuners_utils.py in forward(self, *args, **kwargs)
195
196 def forward(self, *args: Any, **kwargs: Any):
--> 197 return self.model.forward(*args, **kwargs)
198
199 def _pre_injection_hook(self, model: nn.Module, config: PeftConfig, adapter_name: str) -> None:

/usr/local/lib/python3.10/dist-packages/unsloth/models/llama.py in _CausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, num_logits_to_keep, *args, **kwargs)
1010 if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None:
1011 n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None)
-> 1012 loss = fused_linear_cross_entropy(
1013 hidden_states = hidden_states,
1014 lm_weight = lm_head,

/usr/local/lib/python3.10/dist-packages/unsloth_zoo/loss_utils.py in fused_linear_cross_entropy(hidden_states, lm_weight, labels, num_items_in_batch, ignore_index, reduction, logit_softcapping, accuracy_threshold)
150 reduction = "sum" if num_items_in_batch is not None else "mean"
151 if logit_softcapping == 0: logit_softcapping = None
--> 152 loss = linear_cross_entropy(
153 hidden_states,
154 lm_weight,

/usr/local/lib/python3.10/dist-packages/cut_cross_entropy/linear_cross_entropy.py in linear_cross_entropy(e, c, targets, ignore_index, softcap, reduction, shift, filter_eps, impl)
56
57 assert cce_linear_cross_entropy is not None
---> 58 return cce_linear_cross_entropy(
59 e, c, targets, ignore_index, softcap, reduction, shift, filter_eps
60 )

/usr/local/lib/python3.10/dist-packages/cut_cross_entropy/cce.py in cce_linear_cross_entropy(e, c, targets, ignore_index, softcap, reduction, shift, filter_eps)
166 targets = targets.flatten()
167
--> 168 return linear_cross_entropy_apply(
169 e,
170 c,

/usr/local/lib/python3.10/dist-packages/cut_cross_entropy/cce.py in linear_cross_entropy_apply(e, c, params)
123 params: CCEParams,
124 ) -> torch.Tensor:
--> 125 loss = LinearCrossEntropyFunction.apply(e, c, params)
126 assert isinstance(loss, torch.Tensor)
127

/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py in apply(cls, *args, **kwargs)
573 # See NOTE: [functorch vjp and autograd interaction]
574 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575 return super().apply(*args, **kwargs) # type: ignore[misc]
576
577 if not is_setup_ctx_defined:

/usr/local/lib/python3.10/dist-packages/cut_cross_entropy/cce.py in forward(ctx, e, c, params)
44 return_logit_avg = needs_grad and params.filter_eps is not None
45
---> 46 ret = cce_lse_forward_kernel(
47 e,
48 c,

/usr/local/lib/python3.10/dist-packages/cut_cross_entropy/cce_lse_forward.py in cce_lse_forward_kernel(e, c, valids, softcap, return_logit_avg)
182 return (triton.cdiv(B, META["BLOCK_B"]) * triton.cdiv(V, META["BLOCK_V"]),)
183
--> 184 _cce_lse_forward_kernel[grid](
185 e,
186 c,

/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py in (*args, **kwargs)
343 memorizes the grid.
344 """
--> 345 return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
346 # return cast(T, functools.partial(cast(Callable, self.run), grid=grid))
347

/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py in run(self, *args, **kwargs)
336 for v, heur in self.values.items():
337 kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
--> 338 return self.fn.run(*args, **kwargs)
339
340

/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py in run(self, *args, **kwargs)
336 for v, heur in self.values.items():
337 kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
--> 338 return self.fn.run(*args, **kwargs)
339
340

/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py in run(self, grid, warmup, *args, **kwargs)
660 # compile the kernel
661 src = self.ASTSource(self, signature, constants, configs[0])
--> 662 kernel = self.compile(
663 src,
664 target=target,

/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py in compile(src, target, options)
274 codegen_fns = backend.get_codegen_implementation()
275 try:
--> 276 module = src.make_ir(options, codegen_fns, context)
277 except Exception as e:
278 filter_traceback(e)

/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py in make_ir(self, options, codegen_fns, context)
111
112 def make_ir(self, options, codegen_fns, context):
--> 113 return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
114
115 def parse_options(self):

CompilationError: at 60:16:

accum = tl.zeros((BLOCK_B, BLOCK_V), dtype=tl.float32)
for d in range(0, tl.cdiv(D, BLOCK_D)):
    # Load the next block of A and B, generate a mask by checking the K dimension.
    # If it is out of bounds, set it to 0.
    if EVEN_D:
        e = tl.load(e_ptrs)
        c = tl.load(c_ptrs)
    else:
        e = tl.load(e_ptrs, mask=offs_d[None, :] < D - d * BLOCK_D, other=0.0)
        c = tl.load(c_ptrs, mask=offs_d[:, None] < D - d * BLOCK_D, other=0.0)
    accum = tl.dot(e, c, accum, input_precision=DOT_PRECISION)
            ^
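
For reference, a minimal sketch of how the Unsloth notebooks usually pick the precision flags, using Unsloth's is_bfloat16_supported helper (the remaining arguments are illustrative, mirroring the run banner above):

# Minimal sketch: select fp16 vs bf16 from hardware support, as the Unsloth
# notebooks do. On Colab's T4 (no bf16) this resolves to fp16, the setup
# that triggered the dtype-mismatch error above.
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

training_args = TrainingArguments(
    output_dir                  = "outputs",  # illustrative
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 8,
    max_steps                   = 120,
    fp16 = not is_bfloat16_supported(),
    bf16 = is_bfloat16_supported(),
)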

@mosama1994

I get the same error for Llama 3.2 models as well. The default Mistral model (unsloth/mistral-7b-v0.3) in the pretraining notebook works fine.

@mosama1994

On Colab I have to use float16; that is the only change, so maybe the error arises from that. Locally I am using bfloat16 and did not run into any errors. My notebook link: https://github.com/mosama1994/Unsloth-Pretraining/blob/main/Pretraining.ipynb
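
A quick way to check which case a given GPU falls into, using plain PyTorch:

# Prints False on Colab's T4 (no bfloat16 support) and True on Ampere and
# newer GPUs, which is why float16 gets forced on Colab.
import torch
print(torch.cuda.is_bf16_supported())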

@ganzf886
Author

ganzf886 commented Jan 2, 2025

It's OK when I downgrade xformers to 0.0.27.post2 (the lowest xformers version that Unsloth requires)...
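
For reference, in a pip-managed environment that downgrade is pip install "xformers==0.0.27.post2"; the environment above installed xformers from the conda xformers channel, where the equivalent version pin applies.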

@FartyPants

FartyPants commented Jan 5, 2025

What is the point of the example notebooks if they don't work?
