The project is built on [zhuzilin/ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention).

## Why not apply Ulysses and Ring Attention Individually?

- Ulysses is sensitive to the number of attention heads. The parallelism degree in Ulysses cannot exceed the number of heads.

Even with the communication and computation processes fully overlapped, the total […]
Furthermore, Ring-Attention utilizes asynchronous peer-to-peer communication, which not only has a lower bandwidth utilization compared to collective communication methods but also poses the risk of potential communication deadlocks in large-scale deployments.

## LongContextAttention, also known as Unified Sequence Parallelism and Hybrid Sequence Parallelism

`LongContextAttention` is a **unified sequence parallel** approach, also known as **hybrid sequence parallel**, that hybridizes DeepSpeed-Ulysses-Attention and Ring-Attention, thereby addressing the limitations of both methods.

<p align="center">
<img src="./media/usp.png">
</p>

### Usage

Please refer to [test/test_hybrid_qkvpacked_attn.py](./test/test_hybrid_qkvpacked_attn.py) and [test/test_hybrid_attn.py](./test/test_hybrid_attn.py) for usage.

In short, we take the `zigzag` ring attention implementation as an example:

1. Apply `set_seq_parallel_pg` to set up the process group.
2. Extract local tensors with `zigzag_extract_local`. We need to reorder the input tokens (or input tensors) for load-balanced ring attention.
3. Then apply `LongContextAttention(ring_impl_type="zigzag")` as a drop-in replacement for the attention implementation, as in the sketch below.
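
The snippet below is a minimal sketch of how these three steps might fit together on one rank. The import path of `zigzag_extract_local`, the `set_seq_parallel_pg(ulysses_degree, ring_degree, rank, world_size)` call signature, the extract-helper arguments, the attention-call keyword arguments, and the tensor shapes are all assumptions for illustration; the test files above remain the authoritative reference.

```python
import torch
import torch.distributed as dist

# Import paths are assumed; check the package layout of your installed version.
from yunchang import LongContextAttention, set_seq_parallel_pg
from yunchang.comm.extract_local import zigzag_extract_local  # assumed module path

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

# Step 1: split the sequence-parallel group into a Ulysses degree and a Ring degree.
ulysses_degree = 2                                # illustrative choice
ring_degree = world_size // ulysses_degree
set_seq_parallel_pg(ulysses_degree, ring_degree, rank, world_size)

# Full (batch, seqlen, heads, head_dim) tensors; shapes are illustrative.
q = torch.randn(1, 8192, 32, 128, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Step 2: reorder and slice the sequence dimension for load-balanced zigzag ring attention.
# (The helper's argument list is assumed here.)
local_q = zigzag_extract_local(q, rank, world_size)
local_k = zigzag_extract_local(k, rank, world_size)
local_v = zigzag_extract_local(v, rank, world_size)

# Step 3: drop-in replacement for the attention call on the local shards
# (keyword arguments are illustrative).
usp_attn = LongContextAttention(ring_impl_type="zigzag")
local_out = usp_attn(local_q, local_k, local_v, causal=True)
```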

### Install

Option 1: pip install from PyPI.

`pip install yunchang` (flash_attn >= 2.6.0)

`pip install yunchang==0.2` (flash_attn < 2.6.0)

Option 2: build from source.

`pip install .`

### Install for AMD GPU

Supported GPUs: MI300X, MI308X

GPU arch: gfx942

Step 1: prepare the Docker environment.

Two recommended Docker containers to start with:

- `rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0`: hosted on Docker Hub, no conda
- [dockerhub repo](https://github.com/yiakwy-xpu-ml-framework-team/Tools-dockerhub/blob/main/rocm/Dockerfile.rocm62.ubuntu-22.04): customized Dockerfile with conda virtual env and development kit support

Update the ROCm SDK using this [script](https://github.com/yiakwy-xpu-ml-framework-team/Tools-dockerhub/blob/main/rocm/update_sdk.sh):

```bash
# e.g.:
ROCM_VERSION=6.3 bash rocm/update_sdk.sh
```

Step 2: build from source.

> MAX_JOBS=$(nproc) pip install .[amd] --verbose

**Features:**

1. No Limitation on the Number of Heads: Our approach does not impose a restriction on the number of heads, providing greater flexibility for various attention mechanisms.
2. Cover the Capability of both Ulysses and Ring: By setting the `ulysses_degree` to the sequence parallel degree, the system operates identically to Ulysses. Conversely, setting the `ulysses_degree` to 1 mirrors the functionality of Ring (see the configuration sketch after this list).
3. Enhanced Performance: We achieve superior performance benchmarks over both Ulysses and Ring, offering a more efficient solution for attention mechanism computations.
4. Compatibility with Advanced Parallel Strategies: LongContextAttention is fully compatible with other sophisticated parallelization techniques, including Tensor Parallelism, ZeRO, and Pipeline Parallelism, ensuring seamless integration with the latest advancements in parallel computing.
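
As a rough illustration of point 2, the calls below show the two degenerate configurations and a hybrid one. They assume the same `set_seq_parallel_pg(ulysses_degree, ring_degree, rank, world_size)` signature as in the Usage sketch above, and the degrees are illustrative rather than recommendations.

```python
import torch.distributed as dist
from yunchang import set_seq_parallel_pg  # import path assumed, as in the Usage sketch

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()  # e.g. world_size == 8

# In practice you would call exactly one of the following:

# Pure Ulysses: the Ulysses degree takes the whole sequence-parallel group, ring degree 1.
set_seq_parallel_pg(world_size, 1, rank, world_size)

# Pure Ring: Ulysses degree 1, the whole group runs ring attention.
set_seq_parallel_pg(1, world_size, rank, world_size)

# Hybrid: e.g. 2-way Ulysses (typically intra-node) x (world_size // 2)-way Ring.
set_seq_parallel_pg(2, world_size // 2, rank, world_size)
```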

The AMD installation guide is also available as a standalone document: [install_amd.md](./docs/install_amd.md).

### Verified in Megatron-LM

The loss curves for Data Parallel (DP) and Unified Sequence Parallel (ulysses=2+ring=2) are closely aligned, as illustrated in the figure. This alignment confirms the accuracy of the unified sequence parallel implementation.

In Megatron-LM, you can reorder the input tokens before feeding them into the model.
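
To make the reordering concrete, here is a small, self-contained sketch of the zigzag load-balancing idea: the sequence is cut into `2 * world_size` chunks, and rank `r` keeps chunks `r` and `2 * world_size - 1 - r`, so every rank receives a comparable amount of causal-attention work. This is only an illustration of the scheme, not the repository's exact helper.

```python
import torch

def zigzag_reorder(input_ids: torch.Tensor, world_size: int) -> torch.Tensor:
    """Reorder a (batch, seqlen) tensor so that contiguous slices of size
    seqlen // world_size correspond to each rank's zigzag shard."""
    chunks = list(input_ids.chunk(2 * world_size, dim=1))
    pieces = []
    for r in range(world_size):
        pieces.append(chunks[r])                       # an "early" chunk
        pieces.append(chunks[2 * world_size - 1 - r])  # the matching "late" chunk
    return torch.cat(pieces, dim=1)

# 1 sample, 16 tokens, 4 ranks: rank 0 gets tokens [0, 1, 14, 15], rank 1 gets [2, 3, 12, 13], ...
ids = torch.arange(16).unsqueeze(0)
print(zigzag_reorder(ids, world_size=4))
```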

## Best Practice for 4D Parallelism

We analyze the impact of introducing Sequence Parallelism to Data/ZeRO/Tensor/Pipeline Parallelism in a technical report, which can be found [here](https://arxiv.org/abs/2405.07719).

Some best practices are listed here: