[C/PyTorch/Jax] Add support for more bias shapes #677

Merged on Feb 28, 2024 (20 commits; changes shown from 16 commits).

Commits
70ebda6
added support for arbitrary bias shapes for fused_attn
denera Feb 21, 2024
994d7ab
Fix linting
denera Feb 21, 2024
cb19e89
Add b1ss/bhss/11ss bias shapes when not requiring dBias
cyanguwa Feb 21, 2024
bfc22c0
fix lint
cyanguwa Feb 21, 2024
f7ff976
Merge branch 'NVIDIA:main' into fused_attn/add_dbias_shapes_c_pytorch
cyanguwa Feb 21, 2024
a12184f
Merge remote-tracking branch 'denera/jax-fused-attn-arbitrary-bias-si…
denera Feb 21, 2024
fd87a0b
add bias_b/h to plan cache
cyanguwa Feb 22, 2024
ee221cf
Merge branch 'main' into fused_attn/add_dbias_shapes_c_pytorch
denera Feb 22, 2024
f615dfc
fixed compile errors after PR653 merge
denera Feb 22, 2024
9bc00f1
updated JAX unittests for new bias shapes
denera Feb 22, 2024
02a8613
fixed mismatched mask type checking
denera Feb 23, 2024
4049e63
corrected skip condition
denera Feb 23, 2024
6ab94bc
fix selection logic for A100s
cyanguwa Feb 23, 2024
3c08250
corrected skip checks for bias shapes
denera Feb 24, 2024
f9d5c76
resolved test issues but neginf with float16 is still problematic wit…
denera Feb 26, 2024
59b71c6
new bias shapes passing TE JAX CI for seqlen <= 512, seq_q == seq_kv …
denera Feb 27, 2024
554426a
TE/JAX fused attn tests for new bias shapes passing with neg_inf=-2**…
denera Feb 27, 2024
fe74430
code style fixes and test parameter ID cleanup
denera Feb 27, 2024
5dbf923
Merge branch 'main' into fused_attn/add_dbias_shapes_c_pytorch
cyanguwa Feb 27, 2024
d47f6e8
fixed incorrect skip condition for backward fused attn test
denera Feb 28, 2024
214 changes: 137 additions & 77 deletions tests/jax/test_fused_attn.py

Large diffs are not rendered by default.

43 changes: 39 additions & 4 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -85,6 +85,7 @@ def __init__(
attn_bias_type: str,
alibi_type: str = "none",
num_layers: int = 1,
bias_shape: str = "1hss",
):
self.batch_size = batch_size
self.num_heads = num_heads
@@ -100,6 +101,7 @@ def __init__(
self.alibi_type = alibi_type
self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
self.num_layers = num_layers
self.bias_shape = bias_shape

def _is_fused_attention_supported(
config: ModelConfig,
@@ -379,6 +381,31 @@ def test_dpa_bias(dtype, model_configs, model):
"""Test DotProductAttention module with different bias types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)

model_configs_bias_shapes = {
# test: b, h, hg, d, sq, skv, p,
"bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0,
# mask, bias, bias_shape,
"no_mask", "post_scale_bias", bias_shape='11ss'),
"bias_1_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0,
"no_mask", "post_scale_bias", bias_shape='1hss'),
"bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,
"no_mask", "post_scale_bias", bias_shape='b1ss'),
"bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0,
"no_mask", "post_scale_bias", bias_shape='bhss'),
"bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,
"causal", "alibi", bias_shape='1hss', alibi_type='custom'),
"bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0,
"causal", "alibi", bias_shape='bhss', alibi_type='custom'),
}

@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias_shapes])
@pytest.mark.parametrize("model", model_configs_bias_shapes.keys())
def test_dpa_bias_shapes(dtype, model_configs, model):
"""Test DotProductAttention module with different bias types and shapes"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)

model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
@@ -510,10 +537,13 @@ def _run_dot_product_attention(
window_size, attention_mask = None, None

alibi_slopes = None
- if config.attn_bias_type == "alibi":
- if config.alibi_type == "custom":
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
if config.bias_shape == "1hss":
alibi_slopes = torch.randn(
config.num_heads).abs().to(dtype=torch.float32, device="cuda")
if config.bias_shape == "bhss":
alibi_slopes = torch.randn(
config.batch_size, config.num_heads).abs().to(dtype=torch.float32, device="cuda")

# Create input tensors
dim_to_num = {
@@ -527,6 +557,7 @@
'tg' : cu_seqlens_kv[-1],
'3' : 3,
'2' : 2,
'1' : 1,
}
inp = []
for i,layout in enumerate(qkv_layout.split('_')):
@@ -566,8 +597,12 @@ def _run_dot_product_attention(
if config.attn_bias_type in ['no_bias', 'alibi']:
bias = None
if config.attn_bias_type == 'post_scale_bias':
- bias = torch.randn(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv,
- dtype=dtype, device="cuda")
shape = '_'.join(config.bias_shape)
shape = shape.replace('_s_s', '_sq_skv')
tensor_shape = [dim_to_num[j] for j in shape.split('_')]
bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
if config.bias_shape != '1hss':
bias.requires_grad = False

# Create RNG
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
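The PyTorch test changes above reduce to a simple mapping: a bias_shape string ('11ss', '1hss', 'b1ss', or 'bhss') is expanded character by character into a concrete [bias_b, bias_h, s_q, s_kv] tensor shape via the dim_to_num lookup. A minimal sketch of that mapping, using a hypothetical standalone helper (make_post_scale_bias) instead of the test's ModelConfig plumbing:

import torch

def make_post_scale_bias(bias_shape, b, h, sq, skv,
                         dtype=torch.float16, device="cuda"):
    """Sketch: expand a bias_shape string into a post_scale_bias tensor."""
    dim_to_num = {'b': b, 'h': h, 'sq': sq, 'skv': skv, '1': 1}
    # '1hss' -> '1_h_s_s' -> '1_h_sq_skv' -> [1, h, sq, skv]
    shape = '_'.join(bias_shape).replace('_s_s', '_sq_skv')
    tensor_shape = [dim_to_num[j] for j in shape.split('_')]
    bias = torch.randn(tensor_shape, dtype=dtype, device=device)
    # Only the [1, h, sq, skv] layout gets a gradient; the other shapes are
    # exercised forward-only in the tests above.
    bias.requires_grad = (bias_shape == '1hss')
    return bias

# e.g. make_post_scale_bias('b1ss', b=4, h=24, sq=2048, skv=2048).shape
#   -> torch.Size([4, 1, 2048, 2048])

For the custom-alibi cases, the slopes follow the same idea: shape [h] for '1hss' and [b, h] for 'bhss'.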
@@ -72,6 +72,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
FADescriptor_v1 descriptor{b, h,
hg, s_q,
s_kv, d,
bias_b, bias_h,
scaling_factor, is_training,
dropout_probability, layout,
bias_type, mask_type,
@@ -316,6 +317,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
FADescriptor_v1 descriptor{b, h,
hg, s_q,
s_kv, d,
bias_b, bias_h,
scaling_factor, true,
dropout_probability, layout,
bias_type, mask_type,
@@ -426,7 +428,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_dim({bias_b, bias_h, s_q, s_kv})
.set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
sdpa_backward_options.set_bias(bias);
- sdpa_backward_options.set_dbias(dBias);
// shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s]
// are not supported for dbias calculation but they are
// supported for forward bias calculation
if ((bias_b == 1) && (bias_h == h)) {
sdpa_backward_options.set_dbias(dBias);
}
}

if (is_padding) {
@@ -541,7 +548,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(

if (is_bias) {
variant_pack[bias] = devPtrBias;
- variant_pack[dBias] = devPtrdBias;
if ((bias_b == 1) && (bias_h == h)) {
variant_pack[dBias] = devPtrdBias;
} else {
variant_pack[dBias] = nullptr;
}
}

if (is_padding) {
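The backward-pass changes above encode one rule: cuDNN is asked for a bias gradient only when the bias is laid out as [1, h, s_q, s_kv] (bias_b == 1 and bias_h == h); the other accepted shapes ([1, 1, s, s], [b, 1, s, s], [b, h, s, s]) are still applied in the forward pass but skip dBias. A small illustrative check of that rule in plain Python (not library code; the helper name is made up):

def dbias_supported(bias_shape, batch_size, num_heads):
    """Sketch: will the fused-attention backward pass produce dBias?"""
    bias_b, bias_h = {
        '11ss': (1, 1),
        '1hss': (1, num_heads),
        'b1ss': (batch_size, 1),
        'bhss': (batch_size, num_heads),
    }[bias_shape]
    # mirrors the (bias_b == 1) && (bias_h == h) guard above
    return bias_b == 1 and bias_h == num_heads

assert dbias_supported('1hss', batch_size=4, num_heads=16)
assert not dbias_supported('bhss', batch_size=4, num_heads=16)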
5 changes: 4 additions & 1 deletion transformer_engine/common/fused_attn/utils.h
@@ -103,6 +103,8 @@ struct FADescriptor_v1 {
std::int64_t s_q;
std::int64_t s_kv;
std::int64_t d;
std::int64_t bias_b;
std::int64_t bias_h;
float attnScale;
bool isTraining;
float dropoutProbability;
@@ -112,11 +114,12 @@
cudnn_frontend::DataType_t tensor_type;

bool operator<(const FADescriptor_v1 &rhs) const {
- return std::tie(b, h, hg, s_q, s_kv, d,
return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h,
attnScale, isTraining, dropoutProbability,
layout, mask_type, bias_type, tensor_type)
< std::tie(
rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d,
rhs.bias_b, rhs.bias_h,
rhs.attnScale, rhs.isTraining,
rhs.dropoutProbability, rhs.layout,
rhs.mask_type, rhs.bias_type,
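FADescriptor_v1 is compared with std::tie, which suggests it serves as an ordered-map key for cached cuDNN execution plans; folding bias_b and bias_h into that comparison keeps plans built for different bias shapes from colliding. A rough Python analogy of the cache-key idea (illustrative only, not the actual C++ implementation; all names below are made up):

from collections import namedtuple

# bias_b and bias_h are part of the key, so a [1, h, s, s] bias and a
# [b, 1, s, s] bias with otherwise identical problem sizes get separate plans.
FAKey = namedtuple('FAKey', 'b h hg s_q s_kv d bias_b bias_h '
                            'scale training dropout layout bias_type mask_type dtype')

_plan_cache = {}

def get_plan(key, build_plan):
    if key not in _plan_cache:       # cache miss: build and remember the plan
        _plan_cache[key] = build_plan(key)
    return _plan_cache[key]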
@@ -44,7 +44,7 @@ enum NVTE_QKV_Layout {
};

/*! \enum NVTE_QKV_Layout_Group
* \brief QKV layout groups
* \brief QKV layout groups
*/
enum NVTE_QKV_Layout_Group {
/*! 3HD QKV layouts, i.e. BS3HD, SB3HD, T3HD */