10 changes: 10 additions & 0 deletions bench_dc_ulysses/.gitignore
@@ -0,0 +1,10 @@
*.log
*.pyc
profiles
results
slurm_jobs
slurm*
experiments
logs
*.
*.pt
51 changes: 51 additions & 0 deletions bench_dc_ulysses/README.md
@@ -0,0 +1,51 @@
# Benchmark for DeepCompile

## Setup

These experiment scripts require one node with two A100 or A40 GPUs.
We tested the scripts with Python 3.10.12 and CUDA 12.3.

### Libraries

In addition, you need to install the following:

- PyTorch 2.5.1
- [modified version of DeepSpeed](https://github.com/tohtana/DeepSpeed-internal/tree/neeld2/debug-loss)

Here is an example of the installation commands:

```bash
pip3 install torch==2.5.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip3 install datasets==3.1 accelerate

# Install the modified Transformers and DeepSpeed (with DeepCompile)
git clone -b neeld2/debug-loss https://github.com/tohtana/DeepSpeed-internal.git
cd DeepSpeed-internal
pip install -e transformers
cd ..
pip install -e DeepSpeed-internal

# Clone this repository
git clone https://github.com/neeldani/bench_dc_ulysses.git
```

## Running the scripts

Test the setup by running the script:
```bash
bash run_ulysses.sh 6 [compile|deepcompile|eager|ringattn]
```

Here, `6` is the sequence length. It is hardcoded because the input sequence inside `run_acc_lm.py` is also hardcoded, which makes it easy to verify the Q, K, and V tensors before and after the all-to-all. Pass `compile` to run compiled Ulysses (Ulysses with graph breaks) or `deepcompile` to run DeepCompiled Ulysses (the all-to-all is inserted within the compiler pass).

We save the Q, K, and V tensors before and after the all-to-all:
For deepcompiled Ulysses, the tensors are saved here: https://github.com/tohtana/DeepSpeed-internal/blob/60feb352a6b0e22cf9a781b4e387d3919dc76833/deepspeed/compile/patch_aot_module.py#L243

For compiled Ulysses, the tensors are saved here: https://github.com/tohtana/DeepSpeed-internal/blob/60feb352a6b0e22cf9a781b4e387d3919dc76833/deepspeed/sequence/layer.py#L381

You can then run the script [check_qkv.py](https://github.com/neeldani/bench_dc_ulysses/blob/main/check_qkv.py) to compare the tensors at various stages (before the all-to-all, after the all-to-all, attention outputs, etc.).
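
The comparison itself can be as simple as the sketch below. The file paths here are placeholders, not the actual names; those are set at the save sites linked above:

```python
import torch

# Placeholder paths; the real file names are set at the save sites linked above.
ref = torch.load("compiled/q_after_all2all_rank0.pt")
test = torch.load("deepcompiled/q_after_all2all_rank0.pt")

print("shapes:", ref.shape, test.shape)
print("max abs diff:", (ref.float() - test.float()).abs().max().item())
# Training runs in bf16, so the tolerances are loose.
assert torch.allclose(ref.float(), test.float(), rtol=1e-2, atol=1e-3)
```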

## Code walkthrough
1. Script: [run_ulysses.sh](https://github.com/neeldani/bench_dc_ulysses/blob/main/run_ulysses.sh)
2. The script calls [run_acc_lm.py](https://github.com/neeldani/bench_dc_ulysses/blob/main/run_acc_lm.py). We added support for another attention backend in Hugging Face, called `"ulysses"`, which uses `DistributedAttention`. The implementation can be found here: https://github.com/tohtana/DeepSpeed-internal/blob/60feb352a6b0e22cf9a781b4e387d3919dc76833/transformers/src/transformers/models/llama/modeling_llama.py#L306
3. If the `deepcompile` option is enabled in the config file, a compiler pass adds the all-to-alls directly at the Torch IR level (a toy illustration of this kind of rewrite is sketched after this list). The code for it can be found here: https://github.com/tohtana/DeepSpeed-internal/blob/neeld2/debug-loss/deepspeed/compile/patch_aot_module.py
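
The following is an illustration only, not the actual DeepCompile pass: a toy `torch.fx` rewrite that wraps the inputs and output of `scaled_dot_product_attention` with an all-to-all. `_all_to_all` is an identity stand-in for the real `torch.distributed` collective, so the sketch runs on a single process:

```python
import torch
import torch.fx as fx
import torch.nn.functional as F

def _all_to_all(x: torch.Tensor) -> torch.Tensor:
    # Identity stand-in for the head-scatter / sequence-gather all-to-all
    # that the real pass inserts via torch.distributed.
    return x

def insert_all_to_all(gm: fx.GraphModule) -> fx.GraphModule:
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target is F.scaled_dot_product_attention:
            # Route q, k, v through an all-to-all before the attention node...
            with gm.graph.inserting_before(node):
                qkv = [gm.graph.call_function(_all_to_all, (a,)) for a in node.args[:3]]
            node.args = tuple(qkv) + tuple(node.args[3:])
            # ...and route the attention output through the inverse all-to-all.
            with gm.graph.inserting_after(node):
                out = gm.graph.call_function(_all_to_all, (node,))
            node.replace_all_uses_with(out, delete_user_cb=lambda u: u is not out)
    gm.graph.lint()
    gm.recompile()
    return gm

class ToyAttention(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)

gm = insert_all_to_all(fx.symbolic_trace(ToyAttention()))
print(gm.graph)  # _all_to_all now wraps the q/k/v inputs and the sdpa output
```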
16 changes: 16 additions & 0 deletions bench_dc_ulysses/configs/config.yaml
@@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
deepspeed_config_file: configs/ds_config.json
distributed_type: DEEPSPEED
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
19 changes: 19 additions & 0 deletions bench_dc_ulysses/configs/deepcompile_config.json
@@ -0,0 +1,19 @@
{

"bf16": {
"enabled": true
},

"zero_optimization":{
"stage": 0
},
"compile": {
"deepcompile": true
},
"gradient_accumulation_steps": 1,
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
16 changes: 16 additions & 0 deletions bench_dc_ulysses/configs/deepcompile_config.yaml
@@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
deepspeed_config_file: configs/deepcompile_config.json
distributed_type: DEEPSPEED
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
25 changes: 25 additions & 0 deletions bench_dc_ulysses/configs/ds_config.json
@@ -0,0 +1,25 @@
{

"bf16": {
"enabled": true
},

"zero_optimization":{
"stage": 0
},
"compile": {
"deepcompile": true,
"offload_activation": false,
"offload_opt_states": false,
"double_buffer": true,
"symmetric_memory": false,
"free_activation": false,
"dump_graphs": false
},
"gradient_accumulation_steps": 1,
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
30 changes: 30 additions & 0 deletions bench_dc_ulysses/configs/ds_config.json.template
@@ -0,0 +1,30 @@
{
{% if fp16 %}
"fp16": {
"enabled": true,
"initial_scale_power": 8
},
{% else %}
"bf16": {
"enabled": true
},
{% endif %}
"zero_optimization":{
"stage": 0
},
"compile": {
"deepcompile": {{ deepcompile }},
"offload_activation": false,
"offload_opt_states": false,
"double_buffer": true,
"symmetric_memory": false,
"free_activation": false,
"dump_graphs": false
},
"gradient_accumulation_steps": {{ gradient_accumulation_steps }},
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
19 changes: 19 additions & 0 deletions bench_dc_ulysses/configs/ds_config.yaml.template
@@ -0,0 +1,19 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
{%- if zero_stage == 3 %}
zero3_init_flag: true
{%- endif %}
deepspeed_config_file: configs/ds_config.json
distributed_type: DEEPSPEED
machine_rank: {{ machine_rank }}
main_training_function: main
num_machines: {{ num_machines }}
num_processes: {{ num_processes }}
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
14 changes: 14 additions & 0 deletions bench_dc_ulysses/configs/torchcompile_config.json
@@ -0,0 +1,14 @@
{
"bf16": {
"enabled": true
},
"zero_optimization":{
"stage": 0
},
"gradient_accumulation_steps": 1,
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
16 changes: 16 additions & 0 deletions bench_dc_ulysses/configs/torchcompile_config.yaml
@@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
deepspeed_config_file: configs/torchcompile_config.json
distributed_type: DEEPSPEED
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
104 changes: 104 additions & 0 deletions bench_dc_ulysses/distributed_attention.py
@@ -0,0 +1,104 @@
import os
import torch
import torch.distributed as dist
from deepspeed.sequence.layer import DistributedAttention
from sp_dp_registry import get_group, is_setup, sp_size

# TODO: hacky global context for the padding mask; needs a cleaner mechanism
_padding_mask_context = None

def set_padding_mask(mask):
global _padding_mask_context
_padding_mask_context = mask

def get_padding_mask():
    # Read-only access, so no `global` declaration is needed.
    return _padding_mask_context

def clear_padding_mask():
global _padding_mask_context
_padding_mask_context = None
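
# Expected call pattern (a sketch; the actual call site is presumably in
# run_acc_lm.py):
#   set_padding_mask(batch["attention_mask"])  # full, unsharded [B, S] mask
#   outputs = model(**batch)
#   clear_padding_mask()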

def ulysses_attention_forward(
self,
query_states,
key_states,
value_states,
attention_mask,
scaling=None,
dropout=0.0,
is_causal=True,
**kwargs,
):
    assert is_setup(), 'SP/DP groups are not set up correctly.'

gid = dist.get_rank() // sp_size()
group = get_group(gid)

# Ulysses expects (batch, seq, heads, dim)
# HF standard provides (batch, heads, seq, dim)
q = query_states.transpose(1, 2).contiguous()
k = key_states.transpose(1, 2).contiguous()
v = value_states.transpose(1, 2).contiguous()

if not hasattr(self, "ulysses_engine"):
self.ulysses_engine = DistributedAttention(
sdpa_wrapper,
group,
scatter_idx=2, # Shard heads
gather_idx=1 # Gather sequences
)

# b, s, n, h
# Note: we don't pass attention_mask here because it's the 4D mask created by HF
# based on sharded dimensions. We'll create the correct mask in sdpa_wrapper
# using the original unsharded padding mask stored in context.
attn_output = self.ulysses_engine(
q, k, v,
batch_dim_idx=0,
attn_mask=None,
dropout_p=dropout,
is_causal=False,
scale=scaling
)

    # The gathered output is already (batch, seq, heads, dim), which is what
    # Llama's forward expects before its final reshape, so no transpose is
    # needed here.
    return attn_output, None

def sdpa_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True, scale=None):
# Permute from [b, s, n, h] to [b, n, s, h] for SDPA
q = query.permute(0, 2, 1, 3).contiguous()
k = key.permute(0, 2, 1, 3).contiguous()
v = value.permute(0, 2, 1, 3).contiguous()

# Create the attention mask from padding mask + causal mask
padding_mask = get_padding_mask()
combined_mask = None

if padding_mask is not None:
B, S = padding_mask.shape # [B, S]
device = padding_mask.device

causal_mask = torch.tril(torch.ones(S, S, device=device, dtype=torch.bool))
padding_mask_bool = (padding_mask != 0).unsqueeze(1) # [B, 1, S]
causal_expanded = causal_mask.unsqueeze(0) # [1, S, S]
combined_mask = causal_expanded & padding_mask_bool # [B, S, S]
combined_mask = combined_mask.unsqueeze(1) # [B, 1, S, S]
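        # Note: torch SDPA attends where a boolean attn_mask is True, which
        # matches the True-means-keep convention built above.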

    elif is_causal:
        # No padding mask in context; rely on SDPA's built-in causal masking
        # via the is_causal flag below.
        pass

output = torch.nn.functional.scaled_dot_product_attention(
q, k, v,
attn_mask=combined_mask,
dropout_p=dropout_p,
is_causal=(combined_mask is None and is_causal),
scale=scale,
enable_gqa=False
)

# Permute back from [b, n, s, h] to [b, s, n, h] for all-to-all on output
output = output.permute(0, 2, 1, 3).contiguous()
return output