[JAX] Refine MHA API and add DPA API (#653)

* Refine MHA API

Signed-off-by: Reese Wang <[email protected]>

* Reuse func from the flax

Signed-off-by: Reese Wang <[email protected]>

* DPA draft

Signed-off-by: Reese Wang <[email protected]>

* qkv packed draft

Signed-off-by: Reese Wang <[email protected]>

* Fix test_layer with fused attn

Signed-off-by: Reese Wang <[email protected]>

* Add attn_bias_type and improve a few code paths

Signed-off-by: Reese Wang <[email protected]>

* Move scale_factor from __call__ to init

Signed-off-by: Reese Wang <[email protected]>

* Enhance the docs

Signed-off-by: Reese Wang <[email protected]>

* Add DPA public API and tests

Signed-off-by: Reese Wang <[email protected]>

* Refine docs

Signed-off-by: Reese Wang <[email protected]>

* Refine docs

Signed-off-by: Reese Wang <[email protected]>

* Fix conflict

Signed-off-by: Reese Wang <[email protected]>

* Add qkv separate fused attn

Signed-off-by: Reese Wang <[email protected]>

* Apply BSHD_BSHD_BSHD format

Signed-off-by: Reese Wang <[email protected]>

* Remove debug log

Signed-off-by: Reese Wang <[email protected]>

* Add fused attention layer tests

Signed-off-by: Reese Wang <[email protected]>

* Add NVTE_FUSED_ATTN docs

Signed-off-by: Reese Wang <[email protected]>

* Fine-grained fused attn settings

Signed-off-by: Reese Wang <[email protected]>

* Remove the default value of num_attention_head and head_dim

Signed-off-by: Reese Wang <[email protected]>

* Add teardown for fused attn env

Signed-off-by: Reese Wang <[email protected]>

* Unify the Optional notation

Signed-off-by: Reese Wang <[email protected]>

* Fix Pre/Post scale bias comments

Signed-off-by: Reese Wang <[email protected]>

* Add no_mask tests

Signed-off-by: Reese Wang <[email protected]>

* Add checkpoint_name for fused attn

Signed-off-by: Reese Wang <[email protected]>

* Fix the fused attn batcher

Signed-off-by: Reese Wang <[email protected]>

---------

Signed-off-by: Reese Wang <[email protected]>
zlsh80826 authored Feb 22, 2024
1 parent fb2f952 commit 9b2fed5
Showing 15 changed files with 1,820 additions and 477 deletions.
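
For context, the headline change is a new public DotProductAttention flax module (documented in docs/api/jax.rst below) alongside the refined MultiHeadAttention. The following is a minimal usage sketch, not the confirmed API: head_dim and num_heads come from the docs entry and scale_factor from the commit message, while the __call__ arguments and the (batch, seqlen, heads, head_dim) input layout are assumptions for illustration.

    # Hypothetical sketch of the new DotProductAttention module; __call__ arguments
    # and the tensor layout are assumptions, only head_dim/num_heads/scale_factor
    # are taken from this commit.
    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.flax import DotProductAttention

    batch, seqlen, num_heads, head_dim = 2, 128, 8, 64
    q = jnp.zeros((batch, seqlen, num_heads, head_dim), jnp.bfloat16)
    k = jnp.zeros((batch, seqlen, num_heads, head_dim), jnp.bfloat16)
    v = jnp.zeros((batch, seqlen, num_heads, head_dim), jnp.bfloat16)

    # scale_factor is now a module attribute (moved from __call__ to __init__).
    dpa = DotProductAttention(head_dim=head_dim,
                              num_heads=num_heads,
                              scale_factor=1.0 / head_dim**0.5)

    variables = dpa.init(jax.random.PRNGKey(0), q, k, v)
    out = dpa.apply(variables, q, k, v)
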
3 changes: 3 additions & 0 deletions docs/api/jax.rst
@@ -45,6 +45,9 @@ Modules
.. autoapiclass:: transformer_engine.jax.flax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs)
:members: __call__

.. autoapiclass:: transformer_engine.jax.flax.DotProductAttention(head_dim, num_heads, **kwargs)
:members: __call__

.. autoapiclass:: transformer_engine.jax.flax.MultiHeadAttention(head_dim, num_heads, **kwargs)
:members: __call__

6 changes: 5 additions & 1 deletion tests/jax/test_fused_attn.py
@@ -20,7 +20,7 @@
from jax.typing import ArrayLike, DTypeLike

from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn, fused_attn
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available


@@ -144,6 +144,9 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
kv = jnp.concatenate((key, value), axis=-3)
return cross_fused_attn(query, kv, bias, mask, dropout_rng,
**kwargs).astype(query.dtype)
case QKVLayout.BSHD_BSHD_BSHD:
return fused_attn(query, key, value, bias, mask, dropout_rng,
**kwargs).astype(query.dtype)


@dataclass
@@ -337,6 +340,7 @@ def check_dqkv(primitive, reference, valid_len):
@pytest.mark.parametrize('qkv_layout', [
pytest.param(QKVLayout.BS3HD, id='qkvpacked'),
pytest.param(QKVLayout.BSHD_BS2HD, id='kvpacked'),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id='separate'),
])
@pytest.mark.parametrize('dropout_prob', [0., 0.1])
@pytest.mark.parametrize('is_training',
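
For reference, the new 'separate' layout exercised above routes through fused_attn with three BSHD tensors. A rough sketch under assumptions: the positional order fused_attn(query, key, value, bias, mask, dropout_rng, ...) mirrors the test wrapper above, but the keyword names and the mask shape below are assumed for illustration.

    # Sketch of the BSHD_BSHD_BSHD (separate q/k/v) path added here. Positional args
    # follow the customcall_fused_dpa wrapper above; keyword names and the mask
    # shape are assumptions, not the confirmed signature.
    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.fused_attn import fused_attn, AttnBiasType, AttnMaskType

    b, s, h, d = 4, 512, 16, 64
    query = jnp.zeros((b, s, h, d), jnp.float16)   # B S H D
    key = jnp.zeros((b, s, h, d), jnp.float16)
    value = jnp.zeros((b, s, h, d), jnp.float16)
    mask = jnp.zeros((b, 1, s, s), jnp.uint8)      # assumed padding-mask layout

    out = fused_attn(query, key, value,
                     None,                          # no bias
                     mask,
                     jax.random.PRNGKey(0),         # dropout rng
                     attn_bias_type=AttnBiasType.NO_BIAS,
                     attn_mask_type=AttnMaskType.PADDING_MASK,
                     scaling_factor=1.0 / d**0.5,
                     dropout_probability=0.0,
                     is_training=False)
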
47 changes: 31 additions & 16 deletions tests/jax/test_layer.py
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.

import os
from functools import partial

import flax
@@ -20,6 +21,16 @@
is_fp8_supported, reason = is_fp8_available()


@pytest.fixture(autouse=True, scope='module')
def enable_fused_attn():
"""
Enable fused attention
"""
os.environ["NVTE_FUSED_ATTN"] = "1"
yield
del os.environ["NVTE_FUSED_ATTN"]


@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
@@ -93,6 +104,7 @@ def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_DROPOUT_RATE: 0,
}

ATTRS = [{
@@ -221,7 +233,8 @@ def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)

assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)

del data_rng, init_rng, apply_rng

@@ -282,9 +295,6 @@ def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-
test_out, test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer,
apply_rng)

assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad

def reorganize_test_wgrad(test_wgrad, attrs):
num_heads = attrs.get(_KEY_OF_NUM_HEADS)
num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
@@ -328,10 +338,14 @@ def reorganize_test_wgrad(test_wgrad, attrs):
del unfreeze_test_wgrad['mlp']['wo_kernel']
return unfreeze_test_wgrad

compare_dict(ref_grads[1],
reorganize_test_wgrad(test_grads[1], attrs),
rtol=rtol,
atol=atol) # wgrad
if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad

compare_dict(ref_grads[1],
reorganize_test_wgrad(test_grads[1], attrs),
rtol=rtol,
atol=atol) # wgrad

del data_rng, init_rng, apply_rng

@@ -430,7 +444,8 @@ def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)

assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)

del data_rng, init_rng, apply_rng

@@ -492,9 +507,6 @@ def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-
test_out, test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer,
apply_rng)

assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad

def reorganize_test_wgrad(test_wgrad, attrs):
num_heads = attrs.get(_KEY_OF_NUM_HEADS)
num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
@@ -547,10 +559,13 @@ def reorganize_test_wgrad(test_wgrad, attrs):
del unfreeze_test_wgrad['mlp']['wo_kernel']
return unfreeze_test_wgrad

compare_dict(ref_grads[1],
reorganize_test_wgrad(test_grads[1], attrs),
rtol=rtol,
atol=atol) # wgrad
if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad
compare_dict(ref_grads[1],
reorganize_test_wgrad(test_grads[1], attrs),
rtol=rtol,
atol=atol) # wgrad

del data_rng, init_rng, apply_rng

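
The layer tests opt into the fused backend through the NVTE_FUSED_ATTN environment variable (see the enable_fused_attn fixture above). A minimal sketch of the same toggle outside pytest; whether it must be set before the first transformer_engine.jax call is an assumption here.

    # Mirrors the enable_fused_attn fixture: NVTE_FUSED_ATTN=1 opts into fused
    # attention, removing it falls back to the unfused path. Setting it before the
    # first transformer_engine.jax call is assumed to be required.
    import os

    os.environ["NVTE_FUSED_ATTN"] = "1"    # enable fused attention kernels
    # ... build and run transformer_engine.jax layers here ...
    del os.environ["NVTE_FUSED_ATTN"]      # restore the default (unfused) path
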