
LongLoRA + Flash Attention 2 causing illegal memory access #148

ArturNiederfahrenhorst opened this issue Nov 21, 2023 · 7 comments

ArturNiederfahrenhorst commented Nov 21, 2023

Thanks for providing the LongLoRA forward functions.
Your flash-attn/non-flash-attn implementations of SSN show divergent behavior in my case.

For a repro script, please have a look at the issue I opened over at the flash-attention repo: Dao-AILab/flash-attention#670

The version without flash attention works without problems for me. I stepped through it, and the ops and shapes make sense to me.
There, the shift is implemented by rolling.
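
For reference, a minimal sketch of what I mean by rolling (my own illustration with made-up shapes; my understanding is that half of the attention heads are shifted by half a group size along the sequence dimension):

import torch

bsz, q_len, num_heads, head_dim, group_size = 1, 8192, 32, 128, 2048

qkv = torch.randn(bsz, q_len, 3, num_heads, head_dim)
# Shift the second half of the heads by -group_size // 2 along the sequence axis.
qkv[:, :, :, num_heads // 2:] = qkv[:, :, :, num_heads // 2:].roll(-group_size // 2, dims=1)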

The version with flash attention shows odd behaviour. The shift is not just a roll; instead, the code manipulates cu_q_lens. To me, the code looks like it was written with token sequences longer than half of the group size in mind. For a batch with 4k context length but only 8 unpadded tokens, I end up with cu_q_lens=[ 0, 8, 520, 16]. For smaller group sizes, the 520 in this tensor "shrinks".

Could you please elaborate on the calculations or help me fix this?


yukang2017 commented Nov 23, 2023

Hi,

Many thanks for your interest in our work.

Let's walk through a step-by-step example to understand the flash-attention implementation.

(1) Understanding the flash-attention implementation
Take batch size = 1 and sequence length = 8192 as an example. In this case, group_size = 2048.

  • The qkv shape is originally (1, 8192, 3, 32, 128) and is reshaped to (2, 8192, 3, 16, 128) after L104.

We split the multi-head dimension into the batch dimension. That is why the batch size is doubled.

  • We get cu_q_lens via L109.
    Its value is torch.tensor([0, 8192, 16384]), meaning the first batch spans tokens 0 to 8191 and the second spans 8192 to 16383.

  • After that, we calculate cu_q_len_tmp via L111.
    Its value is torch.tensor([[0, 2048, 4096, 6144], [9216, 11264, 13312, 15360]]).

This splits the tokens into groups within each batch (the dimension that came from the original attention heads).

  • In the end, we get the final cu_q_lens via L112. Its value is
    tensor([0, 2048, 4096, 6144, 8192, 9216, 11264, 13312, 15360, 16384]). It contains both the batch and the group splitting for the attention computation (see the numeric sketch below).
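
For concreteness, here is a minimal numeric sketch (my illustration, not the repository code) that reproduces the tensors quoted above for batch size = 1, sequence length = 8192 and group_size = 2048. The second (shifted) half of the batch gets its group boundaries offset by group_size // 2, which is how the shift shows up in the offsets:

import torch

bsz, seq_len, group_size = 1, 8192, 2048

# After half of the heads are moved into the batch dimension, the effective
# batch size is 2 * bsz, and each entry holds seq_len tokens (L109).
cu_q_lens = torch.arange(0, (2 * bsz + 1) * seq_len, seq_len, dtype=torch.int32)
# -> [0, 8192, 16384]

# Group start offsets inside each batch; the shifted half is offset by
# group_size // 2 (L111).
group_starts = torch.arange(0, seq_len, group_size, dtype=torch.int32)
cu_q_len_tmp = torch.stack([group_starts, group_starts + group_size // 2])
cu_q_len_tmp = cu_q_len_tmp + cu_q_lens[:-1].unsqueeze(-1)
# -> [[0, 2048, 4096, 6144], [9216, 11264, 13312, 15360]]

# Final cumulative lengths: the group starts of each batch followed by that
# batch's end offset (L112).
final_cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1)
# -> [0, 2048, 4096, 6144, 8192, 9216, 11264, 13312, 15360, 16384]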

(2) Let's debug why cu_q_lens = [0, 8, 520, 16] in your case.

The strangest thing is the 520 in cu_q_lens. Could you please print out qkv.shape, x.shape, group_size, cu_q_lens and cu_q_len_tmp from L104 to L112? With those values we can easily find out why this is the case.

My guess: because this function was designed for continued pre-training, group_size is set to 1/4 of the sequence length, so it varies with the input. You could change it to a fixed number, such as 2048, which might be more stable; see the sketch below.
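
A hedged sketch of the change I have in mind (illustrative names only, not the exact repository code):

# Assumed current behaviour: the group size is derived from the sequence
# length, e.g. group_size = int(q_len * 1/4), so it changes with the input.

# Possible change: use a fixed group size, capped at the sequence length so
# that short (or heavily unpadded) inputs still yield a valid grouping.
def pick_group_size(q_len: int, fixed: int = 2048) -> int:
    return min(fixed, q_len)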


ArturNiederfahrenhorst commented Nov 27, 2023

Hi @yukang2017, and sorry for the late response. Thanks for taking the time to look at this issue.
Here are the values that you requested:

qkv.shape torch.Size([2, 4096, 3, 16, 128])
x.shape torch.Size([2, 4096, 6144])
group_size 1024
cu_q_lens tensor([  0,   8, 520,  16], device='cuda:0', dtype=torch.int32)
cu_q_len_tmp tensor([[  0], [520]], device='cuda:0', dtype=torch.int32)

This is for Llama 2 with a context length of 4096 and an input sequence of 8 tokens padded to max_length (4096).
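
For what it's worth, here is a quick sanity check on those values (my own sketch, not LongLoRA code). As far as I understand flash-attention's varlen interface, the cumulative-length tensor must be non-decreasing and its entries must stay within the total number of unpadded tokens; the cu_q_lens above violates both, which would be consistent with an illegal memory access inside the kernel:

import torch

cu_q_lens = torch.tensor([0, 8, 520, 16], dtype=torch.int32)
total_tokens = 16  # 2 batch halves with 8 unpadded tokens each (assumed)

# Offsets must never decrease and must stay within the packed token count.
monotonic = bool((cu_q_lens[1:] >= cu_q_lens[:-1]).all())
in_bounds = int(cu_q_lens.max()) <= total_tokens

print(monotonic)  # False: 520 -> 16 decreases
print(in_bounds)  # False: 520 exceeds the 16 unpadded tokens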


ArturNiederfahrenhorst commented Nov 27, 2023

Here's the repro script I mentioned:

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

from llama_attn_replace import replace_llama_attn

flash_attn = True

replace_llama_attn(flash_attn=flash_attn)

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto", use_flash_attention_2=flash_attn)

model.resize_token_embeddings(len(tokenizer))
optimizer = torch.optim.AdamW(model.parameters())

dataset = load_dataset("databricks/databricks-dolly-15k", split="train", streaming=True)
dataset = dataset.map(lambda x: tokenizer([x["instruction"]], truncation=True, padding='max_length', max_length=4096, return_tensors="pt"))
train_dataloader = DataLoader(dataset)

model.train()

for step, batch in enumerate(train_dataloader):
    output = model(
        input_ids=batch["input_ids"].to(model.device).squeeze(1), attention_mask=batch.get("attention_mask").to(model.device).squeeze(1)
    )
    loss = torch.sum(output.logits - output.logits)  # dummy loss, just to exercise the backward pass
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print("step: ", step)import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

from llmforge.llama_attn_replace import replace_llama_attn

flash_attn = True

replace_llama_attn(flash_attn=flash_attn)

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto", use_flash_attention_2=flash_attn)

model.resize_token_embeddings(len(tokenizer))
optimizer = torch.optim.AdamW(model.parameters())

dataset = load_dataset("databricks/databricks-dolly-15k", split="train", streaming=True)
dataset = dataset.map(lambda x: tokenizer([x["instruction"]], truncation=True, padding='max_length', max_length=4096, return_tensors="pt"))
train_dataloader = DataLoader(dataset)

model.train()

for step, batch in enumerate(train_dataloader):
    output = model(
        input_ids=batch["input_ids"].to(model.device).squeeze(1), attention_mask=batch.get("attention_mask").to(model.device).squeeze(1)
    )
    loss = torch.sum(output.logits - output.logits)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print("step: ", step)


jcao-ai commented Nov 30, 2023

@ArturNiederfahrenhorst Hi, did you fix the issue? I'm running into the same one.

@ArturNiederfahrenhorst

@jcao-ai No, I'm waiting for a response from @yukang2017


CxsGhost commented Jan 5, 2024

I also encountered this issue, especially when increasing the per_device_train_batch_size parameter. After repeated experiments, I confirmed that it is not caused by insufficient GPU/CPU memory.

@agokrani

I am also getting a similar error when trying to extend the code base to support phi-2. Is anyone interested in jumping on a call to try to solve this? I'm not sure about the solution yet, but it would be good to brainstorm and fix it.
