add e2e tests for Unsloth qlora and test the builds #2093

Open · wants to merge 9 commits into main
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
@@ -67,6 +67,7 @@ jobs:
run: |
pip3 show torch
pip3 install -U -e .
python scripts/unsloth_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt

- name: Run tests
2 changes: 2 additions & 0 deletions cicd/Dockerfile.jinja
@@ -37,6 +37,8 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi

RUN python scripts/unsloth_install.py | sh

# So we can test the Docker image
RUN pip install -r requirements-dev.txt -r requirements-tests.txt

2 changes: 2 additions & 0 deletions docker/Dockerfile
@@ -26,6 +26,8 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi

RUN python scripts/unsloth_install.py | sh

# So we can test the Docker image
RUN pip install pytest

7 changes: 5 additions & 2 deletions scripts/unsloth_install.py
@@ -8,7 +8,10 @@

v = V(torch.__version__)
cuda = str(torch.version.cuda)
is_ampere = torch.cuda.get_device_capability()[0] >= 8
try:
is_ampere = torch.cuda.get_device_capability()[0] >= 8
except RuntimeError:
is_ampere = False
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
raise RuntimeError(f"CUDA = {cuda} not supported!")
if v <= V("2.1.0"):
@@ -29,5 +32,5 @@
raise RuntimeError(f"Torch = {v} too new!")
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
print(
f'pip install unsloth-zoo && pip install --no-deps "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"'
f'pip install unsloth-zoo==2024.11.7 && pip install --no-deps "unsloth[{x}]==2024.11.9"'
)
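For context, the helper prints a pinned pip command, which CI and the Dockerfiles pipe to `sh`. Below is a minimal sketch of consuming it from Python instead; the subprocess-based runner is illustrative and not part of this PR.

```python
# Illustrative only: run the install helper, capture the pinned pip command it
# prints, and execute it -- the Python equivalent of
# `python scripts/unsloth_install.py | sh` used in the workflow and Dockerfiles.
import subprocess

result = subprocess.run(
    ["python", "scripts/unsloth_install.py"],
    check=True,
    capture_output=True,
    text=True,
)
pip_command = result.stdout.strip()
# e.g. pip install unsloth-zoo==2024.11.7 && pip install --no-deps "unsloth[...]==2024.11.9"
print(pip_command)
subprocess.run(pip_command, shell=True, check=True)
```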
44 changes: 22 additions & 22 deletions src/axolotl/utils/models.py
@@ -2,10 +2,12 @@

# pylint: disable=too-many-lines
import gc
import importlib
import logging
import math
import os
import types
from functools import cached_property
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401

import addict
@@ -409,7 +411,7 @@ def apply_patches(self) -> None:
)

if self.cfg.is_llama_derived_model:
self.patch_loss()
self.patch_loss_llama()
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora

@@ -451,27 +453,34 @@ def patch_attention(self) -> None:

replace_stablelm_attn_with_flash_attn(self.cfg.base_model)

def patch_loss(self) -> None:
@cached_property
def flash_attn(self) -> bool:
"""Check if flash attention is installed"""
return importlib.util.find_spec("flash_attn") is not None

def patch_loss_llama(self) -> None:
"""
Patch loss functions
"""
from axolotl.monkeypatch.llama_attn_hijack_flash import (
patch_llama_cross_entropy,
patch_llama_rms_norm,
)
if self.flash_attn:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
patch_llama_cross_entropy,
patch_llama_rms_norm,
)

if self.cfg.flash_attn_cross_entropy:
if self.cfg.flash_attn_cross_entropy and self.flash_attn:
patch_llama_cross_entropy()
if self.cfg.flash_attn_rms_norm:
elif self.cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch

integrate_cross_entropy_loss_patch(model_type="llama")

if self.cfg.flash_attn_rms_norm and self.flash_attn:
patch_llama_rms_norm()
elif self.cfg.unsloth_rms_norm:
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm

patch_unsloth_layernorm()
if self.cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch

integrate_cross_entropy_loss_patch(model_type="llama")
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora

@@ -481,6 +490,7 @@ def patch_llama_derived_model(self) -> None:
"""
Modify all llama derived models in one block
"""
self.patch_loss_llama()

if self.cfg.flash_attention:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
Expand Down Expand Up @@ -528,16 +538,6 @@ def patch_llama_derived_model(self) -> None:
"Shifted-sparse attention not currently implemented without flash attention."
)

if self.cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch

integrate_cross_entropy_loss_patch(model_type="llama")

if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora

patch_self_attn_lora()

def set_auto_model_loader(self) -> None:
"""set self.AutoModelLoader
- default value: AutoModelForCausalLM (set at __init__)
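As a standalone illustration of the gating pattern this diff introduces, here is a minimal sketch (the class name is a hypothetical stand-in for the loader) of caching an importability probe with `functools.cached_property`, so optional flash-attn imports only happen when the package is installed.

```python
# Minimal sketch of the availability gate: probe for flash-attn once per
# instance and cache the result, so optional kernel patches are skipped
# cleanly when the package is absent.
import importlib.util
from functools import cached_property


class PatchGate:  # hypothetical stand-in for the model loader class
    @cached_property
    def flash_attn(self) -> bool:
        # find_spec returns None when the module cannot be found
        return importlib.util.find_spec("flash_attn") is not None


gate = PatchGate()
print("flash-attn available:", gate.flash_attn)
```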
188 changes: 188 additions & 0 deletions tests/e2e/patched/test_unsloth_qlora.py
@@ -0,0 +1,188 @@
"""
e2e tests for unsloth qlora
"""
import logging
import os
from pathlib import Path

import pytest
from e2e.utils import most_recent_subdir
from tbparse import SummaryReader

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"


class TestUnslothQLoRA:
"""
Test class for Unsloth QLoRA Llama models
"""

@pytest.mark.parametrize(
"sample_packing",
[True, False],
)
def test_unsloth_llama_qlora_fa2(self, temp_dir, sample_packing):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"sample_packing": sample_packing,
"flash_attention": True,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.2,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"save_steps": 10,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
"bf16": "auto",
}
)

normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.0, "Loss is too high"

def test_unsloth_llama_qlora_unpacked_no_fa2(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"sample_packing": False,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.2,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"save_steps": 10,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
"bf16": "auto",
}
)

normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.0, "Loss is too high"

@pytest.mark.parametrize(
"sdp_attention",
[True, False],
)
def test_unsloth_llama_qlora_unpacked_no_fa2_fp16(self, temp_dir, sdp_attention):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"sample_packing": False,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.2,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"save_steps": 10,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"sdp_attention": sdp_attention,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
"fp16": True,
}
)

normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.0, "Loss is too high"
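A minimal sketch for running the new e2e test module locally, assuming a CUDA-capable environment with unsloth installed via the helper above; the invocation flags are illustrative.

```python
# Illustrative local run of the new test module through pytest's Python API;
# equivalent to `pytest -q tests/e2e/patched/test_unsloth_qlora.py`.
import sys

import pytest

sys.exit(pytest.main(["-q", "tests/e2e/patched/test_unsloth_qlora.py"]))
```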