Commit 5ac27d8

use extract_local for test_hybrid_attn.py (#74)
1 parent 567ca32 commit 5ac27d8

11 files changed: +179, -136 lines

README.md (+15, -68)
@@ -12,7 +12,7 @@ The project is built on [zhuzilin/ring-flash-attention](https://github.com/zhuzi
 
 
 
-## What's wrong with Ulysses and Ring?
+## Why not apply Ulysses and Ring Attention Individually?
 
 - Ulysses is sensitive to the number of attention heads.
 The parallelism degree in Ulysses cannot exceed the number of heads.
@@ -25,89 +25,37 @@ Even with the communication and computation processes fully overlapped, the tota
 Furthermore, Ring-Attention utilizes asynchronous peer-to-peer communication, which not only has a lower bandwidth utilization compared to collective communication methods but also poses the risk of potential communication deadlocks in large-scale deployments.
 
 
-## LongContextAttention, a.k.a Unified Sequence Parallelism and Hybrid Sequence Parallelism
+## LongContextAttention, also known as Unified Sequence Parallelism and Hybrid Sequence Parallelism
 
 `LongContextAttention` is a **unified sequence parallel**, also known as **hybrid sequence parallel**, method that hybridizes DeepSpeed-Ulysses-Attention and Ring-Attention, thereby addressing the limitations of both approaches.
 
 <p align="center">
-<img src="./media/hybrid_seqparallel.png">
+<img src="./media/usp.png">
 </p>
 
+### Usage
+
+Please refer to [test/test_hybrid_qkvpacked_attn.py](./test/test_hybrid_qkvpacked_attn.py) and [test/test_hybrid_attn.py](./test/test_hybrid_attn.py) for usage.
+
+In short, taking the `zigzag` ring attention implementation as an example:
+
+1. apply `set_seq_parallel_pg` to set up the process groups;
+2. extract the local tensors with `zigzag_extract_local`, which reorders the input tokens (or input tensors) as required for load-balanced ring attention;
+3. then apply `LongContextAttention(ring_impl_type="zigzag")` as a drop-in replacement for the attention implementation.
 
 ### Install
 
 Option 1: pip install from pypi.
 
-`pip install yunchang==0.3` (flash_attn >= 2.6.0)
+`pip install yunchang` (flash_attn >= 2.6.0)
 
 `pip install yunchang==0.2` (flash_attn < 2.6.0)
 
 Option 2: build from local.
 
 `pip install .`
 
-### Install for AMD GPU
-
-Supported GPU : MI300X, MI308X
-
-GPU arch : gfx942
-
-Step 1: prepare docker envrionment
-
-Tow recommended docker container to start with
-
-- rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0 : hosted in dockerhub, no conda
-- [dockerhub repo](https://github.com/yiakwy-xpu-ml-framework-team/Tools-dockerhub/blob/main/rocm/Dockerfile.rocm62.ubuntu-22.04) : Customerized Dockerfile with conda virtual env and develop kit support
-
-An example to create an docker container :
-
-```bash
-# create docker container
-IMG=rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0
-tag=py310-rocm6.2-distattn-dev
-
-docker_args=$(echo -it --privileged \
-    --name $tag \
-    --ulimit memlock=-1:-1 --net=host --cap-add=IPC_LOCK \
-    --device=/dev/kfd --device=/dev/dri \
-    --ipc=host \
-    --security-opt seccomp=unconfined \
-    --shm-size 16G \
-    --group-add video \
-    -v $(readlink -f `pwd`):/workspace \
-    --workdir /workspace \
-    --cpus=$((`nproc` / 2 - 1)) \
-    $IMG
-)
-
-docker_args=($docker_args)
-
-docker container create "${docker_args[@]}"
-
-# start it
-docker start -a -i $tag
-```
-
-Update ROCM SDK using this [script](https://github.com/yiakwy-xpu-ml-framework-team/Tools-dockerhub/blob/main/rocm/update_sdk.sh):
-
-```bash
-# e.g.:
-ROCM_VERSION=6.3 bash rocm/update_sdk.sh
-```
-
-Step 2 : build from local.
-
-> MAX_JOBS=$(nproc) pip install .[amd] --verbose
-
-**Features:**
-
-1. No Limitation on the Number of Heads: Our approach does not impose a restriction on the number of heads, providing greater flexibility for various attention mechanisms.
-
-2. Cover the Capability of either Ulysses and Ring: By setting the ulysses_degree to the sequence parallel degree, the system operates identically to Ulysses. Conversely, setting the ulysses_degree to 1 mirrors the functionality of Ring.
-
-3. Enhanced Performance: We achieve superior performance benchmarks over both Ulysses and Ring, offering a more efficient solution for attention mechanism computations.
-
-4. Compatibility with Advanced Parallel Strategies: LongContextAttention is fully compatible with other sophisticated parallelization techniques, including Tensor Parallelism, ZeRO, and Pipeline Parallelism, ensuring seamless integration with the latest advancements in parallel computing.
+Install for AMD GPU: [install_amd.md](./docs/install_amd.md)
 
 ### Verified in Megatron-LM
 The loss curves for Data Parallel (DP) and Unified Sequence Parallel (ulysses=2+ring=2) are closely aligned, as illustrated in the figure. This alignment confirms the accuracy of the unified sequence parallel.
@@ -121,7 +69,6 @@ In the Megatron-LM, you can reorder the input tokens before feed them into the m
 
 ## Best Practice for 4D Parallelism
 
-
 We analyze the impact of introducing Sequence Parallelism to Data/ZeRO/Tensor/Pipeline Parallelism in a technical report, which can be found [here](https://arxiv.org/abs/2405.07719).
 
 Some best practices are listed here:
@@ -183,7 +130,7 @@ I am honored that this repository has contributed to the following projects:
 6. [FlagOpen/FlagScale](https://github.com/FlagOpen/FlagScale/commit/f98ee1e293bd906cc77f512f7a884b2030c10a12)
 7. [zhiyuanhubj/LongRecipe](https://github.com/zhiyuanhubj/LongRecipe)
 
-## Citation
+## Cite Us
 ```
 @article{fang2024unified,
   title={USP: A Unified Sequence Parallelism Approach for Long Context Generative AI},
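
To make the three-step usage added to the README above concrete, here is a minimal sketch modeled on the updated test/test_hybrid_attn.py in this commit. It assumes four GPUs, that `LongContextAttention`, `set_seq_parallel_pg`, and `EXTRACT_FUNC_DICT` are importable from the `yunchang` package (as in the test's import block), that `EXTRACT_FUNC_DICT["zigzag"]` corresponds to the `zigzag_extract_local` helper named in the README, and a flash-attention-style `(batch, seqlen, nheads, head_dim)` layout; the full set of forward arguments is spelled out in the test file.

```python
# Minimal usage sketch (see assumptions above); launch with:
#   torchrun --nproc_per_node=4 usp_example.py
import torch
import torch.distributed as dist
from yunchang import LongContextAttention, set_seq_parallel_pg, EXTRACT_FUNC_DICT

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

# Step 1: split the sequence-parallel group into a Ulysses and a Ring dimension (2 x 2 on 4 GPUs).
sp_ulysses_degree = 2
sp_ring_degree = world_size // sp_ulysses_degree
set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)

# Full (batch, seqlen, nheads, head_dim) tensors, identical on every rank.
q = torch.randn(2, 1024, 4, 128, device=f"cuda:{rank}", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Step 2: reorder and shard the sequence dimension for load-balanced zigzag ring attention.
extract_local = EXTRACT_FUNC_DICT["zigzag"]
local_q = extract_local(q, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree)
local_k = extract_local(k, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree)
local_v = extract_local(v, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree)

# Step 3: drop-in hybrid Ulysses + Ring attention over the local shards.
# Additional flash-attention-style arguments (dropout_p, causal, ...) follow
# as in test/test_hybrid_attn.py.
usp_attn = LongContextAttention(ring_impl_type="zigzag")
local_out = usp_attn(local_q, local_k, local_v)

dist.destroy_process_group()
```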

docs/install_amd.md (new file, +62)

@@ -0,0 +1,62 @@
+## Install for AMD GPU
+
+Supported GPUs: MI300X, MI308X
+
+GPU arch: gfx942
+
+Step 1: prepare the docker environment
+
+Two recommended docker containers to start with:
+
+- rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0 : hosted on dockerhub, no conda
+- [dockerhub repo](https://github.com/yiakwy-xpu-ml-framework-team/Tools-dockerhub/blob/main/rocm/Dockerfile.rocm62.ubuntu-22.04) : customized Dockerfile with a conda virtual env and development kit support
+
+An example of creating a docker container:
+
+```bash
+# create docker container
+IMG=rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0
+tag=py310-rocm6.2-distattn-dev
+
+docker_args=$(echo -it --privileged \
+    --name $tag \
+    --ulimit memlock=-1:-1 --net=host --cap-add=IPC_LOCK \
+    --device=/dev/kfd --device=/dev/dri \
+    --ipc=host \
+    --security-opt seccomp=unconfined \
+    --shm-size 16G \
+    --group-add video \
+    -v $(readlink -f `pwd`):/workspace \
+    --workdir /workspace \
+    --cpus=$((`nproc` / 2 - 1)) \
+    $IMG
+)
+
+docker_args=($docker_args)
+
+docker container create "${docker_args[@]}"
+
+# start it
+docker start -a -i $tag
+```
+
+Update the ROCm SDK using this [script](https://github.com/yiakwy-xpu-ml-framework-team/Tools-dockerhub/blob/main/rocm/update_sdk.sh):
+
+```bash
+# e.g.:
+ROCM_VERSION=6.3 bash rocm/update_sdk.sh
+```
+
+Step 2: build from local.
+
+> MAX_JOBS=$(nproc) pip install .[amd] --verbose
+
+**Features:**
+
+1. No Limitation on the Number of Heads: Our approach does not impose a restriction on the number of heads, providing greater flexibility for various attention mechanisms.
+
+2. Covers the Capability of both Ulysses and Ring: By setting the ulysses_degree to the sequence parallel degree, the system operates identically to Ulysses. Conversely, setting the ulysses_degree to 1 mirrors the functionality of Ring.
+
+3. Enhanced Performance: We achieve superior performance benchmarks over both Ulysses and Ring, offering a more efficient solution for attention mechanism computations.
+
+4. Compatibility with Advanced Parallel Strategies: LongContextAttention is fully compatible with other sophisticated parallelization techniques, including Tensor Parallelism, ZeRO, and Pipeline Parallelism, ensuring seamless integration with the latest advancements in parallel computing.
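
As a rough illustration of feature 2 (a sketch, again assuming `set_seq_parallel_pg` is importable from `yunchang` as in test/test_hybrid_attn.py), the two degenerate configurations differ only in how the sequence-parallel degree is split:

```python
import torch.distributed as dist
from yunchang import set_seq_parallel_pg

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()

# ulysses_degree == sequence-parallel degree -> behaves like pure Ulysses
# (its degree cannot exceed the number of attention heads);
# ulysses_degree == 1 -> behaves like pure Ring-Attention.
sp_ulysses_degree = world_size  # set to 1 for the Ring-only case
sp_ring_degree = world_size // sp_ulysses_degree
set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)
```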

media/hybrid_seqparallel.png (-66.8 KB)

Binary file not shown.

media/usp.png (150 KB)

setup.py (+8, -2)
@@ -1,9 +1,15 @@
 from setuptools import setup, find_packages
+import os
+
+# Read the version information
+version_file = os.path.join(os.path.dirname(__file__), 'yunchang', '__version__.py')
+with open(version_file, 'r') as f:
+    exec(f.read())
 
 setup(
     name="yunchang",
-    version="0.3",
-    author="Jiarui Fang, Zilin Zhu, Yang Yu",
+    version=__version__,
+
     url="https://github.com/feifeibear/long-context-attention",
     packages=find_packages(exclude=['test', 'benchmark']),
     install_requires=[
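
The new setup.py obtains `__version__` by exec-ing `yunchang/__version__.py`. That file is not part of this diff, but it presumably contains a single assignment along these lines (the value shown here is hypothetical):

```python
# yunchang/__version__.py -- hypothetical contents; the actual version string
# is not shown in this commit.
__version__ = "0.3.0"
```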

test/test_hybrid_attn.py (+62, -31)
@@ -2,6 +2,7 @@
     AsyncLongContextAttention,
     LongContextAttention,
     set_seq_parallel_pg,
+    EXTRACT_FUNC_DICT
 )
 import torch
 import torch.distributed as dist
@@ -34,31 +35,35 @@ def log(msg, a, rank0_only=False):
     )
     dist.barrier()
 
-
+# test it with:
+# torchrun --nproc_per_node=4 test/test_hybrid_attn_v2.py
 if __name__ == "__main__":
     torch.random.manual_seed(0)
 
-    use_bwd = False
+    use_bwd = True
     dist.init_process_group("nccl")
 
     rank = dist.get_rank()
     world_size = dist.get_world_size()
+
+    assert world_size == 4, f"torchrun --nproc_per_node=4 test/test_hybrid_attn_v2.py"
     # Inference mainly uses fp16; ROCM flash attention with bf16 precision is slightly larger, will be fixed soon
-    dtype = torch.float16
+    dtype = torch.bfloat16
     device = torch.device(f"cuda:{rank}")
 
     batch_size = 2
-    seqlen = 3816
-    nheads = 2
+    seqlen = 1024
+    nheads = 4
     d = 128
     dropout_p = 0
     causal = True
+
     deterministic = False
 
-    use_async_all_to_all = True
     assert seqlen % world_size == 0
     assert d % 8 == 0
-    # assert batch_size == 1
+
+    ring_impl_type = "zigzag"  # You can change this to "basic" or "zigzag" if needed
 
     # Prepare inputs
     q = torch.randn(
@@ -77,30 +82,34 @@ def log(msg, a, rank0_only=False):
     dist.broadcast(v, src=0)
     dist.broadcast(dout, src=0)
 
-    local_q = q.chunk(world_size, dim=1)[rank].detach().clone()
-    local_q.requires_grad = True
-    local_k = k.chunk(world_size, dim=1)[rank].detach().clone()
-    local_k.requires_grad = True
-    local_v = v.chunk(world_size, dim=1)[rank].detach().clone()
-    local_v.requires_grad = True
-
-    local_dout = dout.chunk(world_size, dim=1)[rank].detach().clone()
-
     # prepare process group for hybrid sequence parallelism
     use_ring_low_dim = True
 
-    sp_ulysses_degree = min(nheads, world_size)
+    sp_ulysses_degree = 2
     sp_ring_degree = world_size // sp_ulysses_degree
+
     print(
         f"rank {rank}, sp_ulysses_degree: {sp_ulysses_degree}, sp_ring_degree: {sp_ring_degree}"
     )
 
     set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)
 
-    if use_async_all_to_all:
-        hybrid_seq_parallel_attn = AsyncLongContextAttention()
-    else:
-        hybrid_seq_parallel_attn = LongContextAttention()
+    # Use EXTRACT_FUNC_DICT to shard the tensors
+    local_q = EXTRACT_FUNC_DICT[ring_impl_type](
+        q, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
+    ).detach().clone()
+    local_q.requires_grad = True
+
+    local_k = EXTRACT_FUNC_DICT[ring_impl_type](
+        k, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
+    ).detach().clone()
+    local_k.requires_grad = True
+
+    local_v = EXTRACT_FUNC_DICT[ring_impl_type](
+        v, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
+    ).detach().clone()
+    local_v.requires_grad = True
+    usp_attn = LongContextAttention(ring_impl_type=ring_impl_type)
 
     if rank == 0:
         print("#" * 30)
@@ -112,7 +121,9 @@ def log(msg, a, rank0_only=False):
     alibi_slopes, attn_bias = None, None
     dropout_mask = None
 
-    local_out = hybrid_seq_parallel_attn(
+    print(f"before usp attn forward: {local_q.shape} {local_k.shape} {local_v.shape}")
+    # usp attn forward
+    local_out = usp_attn(
         local_q,
         local_k,
         local_v,
@@ -125,11 +136,17 @@ def log(msg, a, rank0_only=False):
         return_attn_probs=True,
     )
 
+    # extract local dout
+    local_dout = EXTRACT_FUNC_DICT[ring_impl_type](
+        dout, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
+    ).detach().clone()
+
     if rank == 0:
         print("#" * 30)
         print("# ds-ulysses backward:")
         print("#" * 30)
 
+    # usp attn backward
     if use_bwd:
         local_out.backward(local_dout)
@@ -177,26 +194,40 @@ def log(msg, a, rank0_only=False):
     dist.barrier()
 
     # check correctness
-
-    local_out_ref = out_ref.chunk(world_size, dim=1)[rank]
-    local_out_pt_ref = out_ref.chunk(world_size, dim=1)[rank]
+    # When checking correctness, use EXTRACT_FUNC_DICT for reference outputs
+    local_out_ref = EXTRACT_FUNC_DICT[ring_impl_type](
+        out_ref, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
+    )
+    local_out_pt_ref = EXTRACT_FUNC_DICT[ring_impl_type](
+        out_pt_ref, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
+    )
 
     log("local (rank) out", local_out, rank0_only=True)
     log("out (distributed) - out_ref (non-distributed) diff", local_out_ref - local_out)
-    log("out_ref (non-distributed) - out_pt_ref (gpu) diff", local_out_ref - local_out_pt_ref)
+
+    # log("out_ref (non-distributed) - out_pt_ref (gpu) diff", local_out_ref - local_out_pt_ref)
 
     torch.testing.assert_close(local_out, local_out_ref, atol=1e-2, rtol=0)
-    torch.testing.assert_close(out_ref, out_pt_ref, atol=1e-2, rtol=0)
+    # torch.testing.assert_close(out_ref, out_pt_ref, atol=1e-2, rtol=0)
 
     if use_bwd:
-        local_dq_ref = q.grad.chunk(world_size, dim=1)[rank]
+        local_dq_ref = EXTRACT_FUNC_DICT[ring_impl_type](
+            q.grad, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
+        )
         log("load_dq", local_q.grad)
         log("dq diff", local_dq_ref - local_q.grad)
 
-        local_dk_ref = k.grad.chunk(world_size, dim=1)[rank]
+        local_dk_ref = EXTRACT_FUNC_DICT[ring_impl_type](
+            k.grad, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
+        )
         log("load_dk", local_k.grad)
         log("dk diff", local_dk_ref - local_k.grad)
 
-        local_dv_ref = v.grad.chunk(world_size, dim=1)[rank]
-        log("load_dk", local_v.grad)
+        local_dv_ref = EXTRACT_FUNC_DICT[ring_impl_type](
+            v.grad, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
+        )
+        log("load_dv", local_v.grad)
         log("dv diff", local_dv_ref - local_v.grad)
+
+    if dist.is_initialized():
+        dist.destroy_process_group()
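
For reference, the updated test asserts `world_size == 4`, so it is presumably launched with `torchrun` on four GPUs; note that the in-file comment still refers to `test_hybrid_attn_v2.py`, while the path in this commit is `test/test_hybrid_attn.py`:

```bash
# Launch the hybrid (USP) attention test on 4 GPUs.
torchrun --nproc_per_node=4 test/test_hybrid_attn.py
```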
