From 9b2fed514ea419141146f843ab2c84b22b86bfd7 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Fri, 23 Feb 2024 02:09:00 +0800 Subject: [PATCH] [JAX] Refine MHA API and add DPA API (#653) * Refine MHA API Signed-off-by: Reese Wang * Reuse func from the flax Signed-off-by: Reese Wang * DPA draft Signed-off-by: Reese Wang * qkv packed draft Signed-off-by: Reese Wang * Fix test_layer with fused attn Signed-off-by: Reese Wang * Add attn_bias_type and enhance a few code flow Signed-off-by: Reese Wang * Move scale_factor from __call__ to init Signed-off-by: Reese Wang * Enhance the docs Signed-off-by: Reese Wang * Add DPA public API and tests Signed-off-by: Reese Wang * Refine docs Signed-off-by: Reese Wang * Refine docs Signed-off-by: Reese Wang * Fix conflict Signed-off-by: Reese Wang * Add qkv separate fused attn Signed-off-by: Reese Wang * Apply BSHD_BSHD_BSHD format Signed-off-by: Reese Wang * Remove debug log Signed-off-by: Reese Wang * Add fused attention layer tests Signed-off-by: Reese Wang * Add NVTE_FUSED_ATTN docs Signed-off-by: Reese Wang * Fine-grained fused attn settings Signed-off-by: Reese Wang * Remove the default value of num_attetnion_head and head_dim Signed-off-by: Reese Wang * Add teardown for fused attn env Signed-off-by: Reese Wang * Unify the Optional notation Signed-off-by: Reese Wang * Fix Pre/Post scale bias comments Signed-off-by: Reese Wang * Add no_mask tests Signed-off-by: Reese Wang * Add checkpoint_name for fused attn Signed-off-by: Reese Wang * Fix the fused attn batcher Signed-off-by: Reese Wang --------- Signed-off-by: Reese Wang --- docs/api/jax.rst | 3 + tests/jax/test_fused_attn.py | 6 +- tests/jax/test_layer.py | 47 +- tests/jax/test_praxis_layers.py | 168 ++- tests/jax/utils.py | 4 +- .../common/fused_attn/fused_attn.cpp | 4 +- transformer_engine/jax/cpp_extensions.py | 445 ++++++- transformer_engine/jax/csrc/extensions.cpp | 7 +- transformer_engine/jax/csrc/modules.cpp | 261 +++- transformer_engine/jax/csrc/modules.h | 14 + transformer_engine/jax/flax/__init__.py | 16 +- transformer_engine/jax/flax/transformer.py | 1053 +++++++++++------ transformer_engine/jax/fused_attn.py | 144 ++- transformer_engine/jax/praxis/__init__.py | 3 +- transformer_engine/jax/praxis/transformer.py | 122 +- 15 files changed, 1820 insertions(+), 477 deletions(-) diff --git a/docs/api/jax.rst b/docs/api/jax.rst index ae19bfa2bc..7d39b15929 100644 --- a/docs/api/jax.rst +++ b/docs/api/jax.rst @@ -45,6 +45,9 @@ Modules .. autoapiclass:: transformer_engine.jax.flax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs) :members: __call__ +.. autoapiclass:: transformer_engine.jax.flax.DotProductAttention(head_dim, num_heads, **kwargs) + :members: __call__ + .. 
autoapiclass:: transformer_engine.jax.flax.MultiHeadAttention(head_dim, num_heads, **kwargs) :members: __call__ diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 56bb73d4d6..e709ecc8ec 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -20,7 +20,7 @@ from jax.typing import ArrayLike, DTypeLike from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout -from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn +from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn, fused_attn from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available @@ -144,6 +144,9 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng kv = jnp.concatenate((key, value), axis=-3) return cross_fused_attn(query, kv, bias, mask, dropout_rng, **kwargs).astype(query.dtype) + case QKVLayout.BSHD_BSHD_BSHD: + return fused_attn(query, key, value, bias, mask, dropout_rng, + **kwargs).astype(query.dtype) @dataclass @@ -337,6 +340,7 @@ def check_dqkv(primitive, reference, valid_len): @pytest.mark.parametrize('qkv_layout', [ pytest.param(QKVLayout.BS3HD, id='qkvpacked'), pytest.param(QKVLayout.BSHD_BS2HD, id='kvpacked'), + pytest.param(QKVLayout.BSHD_BSHD_BSHD, id='separate'), ]) @pytest.mark.parametrize('dropout_prob', [0., 0.1]) @pytest.mark.parametrize('is_training', diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index ebc478d4a5..6eb659ed4e 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. +import os from functools import partial import flax @@ -20,6 +21,16 @@ is_fp8_supported, reason = is_fp8_available() +@pytest.fixture(autouse=True, scope='module') +def enable_fused_attn(): + """ + Enable fused attention + """ + os.environ["NVTE_FUSED_ATTN"] = "1" + yield + del os.environ["NVTE_FUSED_ATTN"] + + @pytest.fixture(autouse=True, scope='function') def clear_live_arrays(): """ @@ -93,6 +104,7 @@ def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08): BASE_ATTRS = { _KEY_OF_TRANSPOSE_BS: True, _KEY_OF_NUM_HEADS: 8, + _KEY_OF_DROPOUT_RATE: 0, } ATTRS = [{ @@ -221,7 +233,8 @@ def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng) test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng) - assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) + if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout + assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) del data_rng, init_rng, apply_rng @@ -282,9 +295,6 @@ def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e- test_out, test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng) - assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) - assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad - def reorganize_test_wgrad(test_wgrad, attrs): num_heads = attrs.get(_KEY_OF_NUM_HEADS) num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads) @@ -328,10 +338,14 @@ def reorganize_test_wgrad(test_wgrad, attrs): del unfreeze_test_wgrad['mlp']['wo_kernel'] return unfreeze_test_wgrad - compare_dict(ref_grads[1], - reorganize_test_wgrad(test_grads[1], attrs), - rtol=rtol, - atol=atol) # wgrad + if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout + 
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) + assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad + + compare_dict(ref_grads[1], + reorganize_test_wgrad(test_grads[1], attrs), + rtol=rtol, + atol=atol) # wgrad del data_rng, init_rng, apply_rng @@ -430,7 +444,8 @@ def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng) test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng) - assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) + if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout + assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) del data_rng, init_rng, apply_rng @@ -492,9 +507,6 @@ def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e- test_out, test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng) - assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) - assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad - def reorganize_test_wgrad(test_wgrad, attrs): num_heads = attrs.get(_KEY_OF_NUM_HEADS) num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads) @@ -547,10 +559,13 @@ def reorganize_test_wgrad(test_wgrad, attrs): del unfreeze_test_wgrad['mlp']['wo_kernel'] return unfreeze_test_wgrad - compare_dict(ref_grads[1], - reorganize_test_wgrad(test_grads[1], attrs), - rtol=rtol, - atol=atol) # wgrad + if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout + assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) + assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad + compare_dict(ref_grads[1], + reorganize_test_wgrad(test_grads[1], attrs), + rtol=rtol, + atol=atol) # wgrad del data_rng, init_rng, apply_rng diff --git a/tests/jax/test_praxis_layers.py b/tests/jax/test_praxis_layers.py index 7a42d1f8c3..848f0b06c3 100644 --- a/tests/jax/test_praxis_layers.py +++ b/tests/jax/test_praxis_layers.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. 
+import os from functools import partial from typing import Dict @@ -14,12 +15,14 @@ from utils import assert_allclose +from transformer_engine_jax import get_device_compute_capability from transformer_engine.common.recipe import DelayedScaling, Format from transformer_engine.jax import fp8_autocast, update_fp8_metas, update_collections from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral from transformer_engine.jax.flax import LayerNorm as flax_LayerNorm from transformer_engine.jax.flax import LayerNormMLP as flax_LayerNormMLP from transformer_engine.jax.flax import MultiHeadAttention as flax_MultiHeadAttention +from transformer_engine.jax.flax import DotProductAttention as flax_DotProductAttention from transformer_engine.jax.flax import RelativePositionBiases as flax_RelativePositionBiases from transformer_engine.jax.flax import TransformerLayer as flax_TransformerLayer from transformer_engine.jax.flax.module import Softmax @@ -27,8 +30,8 @@ from transformer_engine.jax.praxis import LayerNorm from transformer_engine.jax.praxis import FusedSoftmax from transformer_engine.jax.praxis import LayerNormLinear, LayerNormMLP, Linear -from transformer_engine.jax.praxis import MultiHeadAttention, RelativePositionBiases -from transformer_engine.jax.praxis import TransformerEngineBaseLayer +from transformer_engine.jax.praxis import DotProductAttention, MultiHeadAttention +from transformer_engine.jax.praxis import RelativePositionBiases, TransformerEngineBaseLayer from transformer_engine.jax.praxis import TransformerLayer, TransformerLayerType from transformer_engine.jax.softmax import SoftmaxType @@ -40,6 +43,19 @@ FP8_FORMATS = [Format.E4M3, Format.HYBRID] +@pytest.fixture(autouse=True, scope='module') +def enable_fused_attn(): + """ + Enable fused attn for hopper+ arch. + Fused attn kernels on pre-hopper arch are not deterministic. 
+ """ + if get_device_compute_capability(0) >= 90: + os.environ["NVTE_FUSED_ATTN"] = "1" + yield + if "NVTE_FUSED_ATTN" in os.environ: + del os.environ["NVTE_FUSED_ATTN"] + + @pytest.fixture(autouse=True, scope='function') def clear_live_arrays(): """ @@ -101,8 +117,9 @@ def sync_variables(self, praxis_variables, flax_variables): lyr_name = self.get_layer_name() - synced_praxis_variables['params'][lyr_name]['cld'] = \ - flax.core.unfreeze(flax_variables['params']) + if 'params' in flax_variables: + synced_praxis_variables['params'][lyr_name]['cld'] = \ + flax.core.unfreeze(flax_variables['params']) return synced_praxis_variables, flax_variables @@ -111,8 +128,9 @@ def sync_wgrads(self, praxis_wgrads, flax_wgrads): lyr_name = self.get_layer_name() - synced_praxis_grads['params'] = \ - synced_praxis_grads['params'][lyr_name]['cld'] + if 'params' in synced_praxis_grads: + synced_praxis_grads['params'] = \ + synced_praxis_grads['params'][lyr_name]['cld'] if FP8Helper.is_fp8_enabled(): synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME] = \ @@ -671,6 +689,86 @@ def test_forward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol) +class DotProductAttnAttr: + ATTN_MASK_TYPE = 'attn_mask_type' + NUM_GQA_GROUPS = 'num_gqa_groups' + TRANSPOSE_BS = 'transpose_batch_sequence' + SCALE_FACTOR = 'scale_factor' + ATTRS = [{ + ATTN_MASK_TYPE: 'padding', + TRANSPOSE_BS: True, + SCALE_FACTOR: 0.125, + }, { + ATTN_MASK_TYPE: 'padding_causal', + TRANSPOSE_BS: True, + SCALE_FACTOR: 0.125, + }, { + ATTN_MASK_TYPE: 'causal', + TRANSPOSE_BS: True, + SCALE_FACTOR: 0.125, + }, { + ATTN_MASK_TYPE: 'padding', + TRANSPOSE_BS: False, + SCALE_FACTOR: 0.125, + }, { + ATTN_MASK_TYPE: 'padding_causal', + TRANSPOSE_BS: False, + SCALE_FACTOR: 2., + }, { + ATTN_MASK_TYPE: 'causal', + TRANSPOSE_BS: False, + SCALE_FACTOR: 1., + }, { + ATTN_MASK_TYPE: 'no_mask', + TRANSPOSE_BS: False, + SCALE_FACTOR: 1., + }] + + +class TestDotProductAttn(TestLayer): + + def input_getter(self, shape, dtype): + key = jax.random.PRNGKey(seed=1234) + q_key, k_key, v_key = jax.random.split(key, 3) + return list(map(partial(jax.random.normal, shape=shape, dtype=dtype), + [q_key, k_key, v_key])) + + def get_layer_name(self): + return 'dot_product_attn' + + def generate_praxis_p_and_flax_cls(self, dtype, attrs): + head_dim = 64 + num_attention_heads = 16 + num_gqa_groups = num_attention_heads + attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE] + transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS] + + praxis_p = pax_fiddle.Config(DotProductAttention, + name='mha', + dtype=dtype, + head_dim=head_dim, + num_attention_heads=num_attention_heads, + num_gqa_groups=num_gqa_groups, + attn_mask_type=attn_mask_type, + transpose_batch_sequence=transpose_batch_sequence) + flax_cls = partial(flax_DotProductAttention, + dtype=dtype, + head_dim=head_dim, + num_attention_heads=num_attention_heads, + num_gqa_groups=num_gqa_groups, + attn_mask_type=attn_mask_type, + transpose_batch_sequence=transpose_batch_sequence) + + return praxis_p, flax_cls + + @pytest.mark.parametrize('data_shape', [(32, 128, 16, 64)]) + @pytest.mark.parametrize('dtype', DTYPE) + @pytest.mark.parametrize('attrs', DotProductAttnAttr.ATTRS) + def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): + praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) + self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) + + class MultiHeadAttnAttr: USE_BIAS = 
'use_bias' LN_TYPE = 'layernorm_type' @@ -730,54 +828,57 @@ def get_layer_name(self): def generate_praxis_p_and_flax_cls(self, dtype, attrs): head_dim = 64 - num_heads = 16 + num_attention_heads = 16 + num_gqa_groups = attrs[MultiHeadAttnAttr.NUM_GQA_GROUPS] \ + if MultiHeadAttnAttr.NUM_GQA_GROUPS in attrs else None layernorm_type = attrs[MultiHeadAttnAttr.LN_TYPE] zero_centered_gamma = attrs[MultiHeadAttnAttr.ZERO_CEN] kernel_init = WeightInit.Gaussian(1.0) use_bias = attrs[MultiHeadAttnAttr.USE_BIAS] bias_init = WeightInit.Constant(0.0) - apply_residual_connection_post_layernorm = False - output_layernorm = False + input_layernorm = False + return_layernorm_output = False attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE] - fuse_qkv: bool = True + fuse_qkv_params = True transpose_batch_sequence = True scale_attn_logits = False scaled_query_init = True float32_logits = False - praxis_p = pax_fiddle.Config( - MultiHeadAttention, - name='mha', - dtype=dtype, - head_dim=head_dim, - num_heads=num_heads, - layernorm_type=layernorm_type, - zero_centered_gamma=zero_centered_gamma, - params_init=kernel_init, - use_bias=use_bias, - bias_init=bias_init, - apply_residual_connection_post_layernorm=apply_residual_connection_post_layernorm, - output_layernorm=output_layernorm, - attn_mask_type=attn_mask_type, - fuse_qkv=fuse_qkv, - transpose_batch_sequence=transpose_batch_sequence, - scale_attn_logits=scale_attn_logits, - scaled_query_init=scaled_query_init, - float32_logits=float32_logits) + praxis_p = pax_fiddle.Config(MultiHeadAttention, + name='mha', + dtype=dtype, + head_dim=head_dim, + num_attention_heads=num_attention_heads, + num_gqa_groups=num_gqa_groups, + layernorm_type=layernorm_type, + zero_centered_gamma=zero_centered_gamma, + params_init=kernel_init, + use_bias=use_bias, + bias_init=bias_init, + return_layernorm_output=return_layernorm_output, + input_layernorm=input_layernorm, + attn_mask_type=attn_mask_type, + fuse_qkv_params=fuse_qkv_params, + transpose_batch_sequence=transpose_batch_sequence, + scale_attn_logits=scale_attn_logits, + scaled_query_init=scaled_query_init, + float32_logits=float32_logits) flax_cls = partial( flax_MultiHeadAttention, dtype=dtype, head_dim=head_dim, - num_heads=num_heads, + num_attention_heads=num_attention_heads, + num_gqa_groups=num_gqa_groups, layernorm_type=layernorm_type, zero_centered_gamma=zero_centered_gamma, kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init), use_bias=use_bias, bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init), - apply_residual_connection_post_layernorm=apply_residual_connection_post_layernorm, - output_layernorm=output_layernorm, + return_layernorm_output=return_layernorm_output, + input_layernorm=input_layernorm, attn_mask_type=attn_mask_type, - fuse_qkv=fuse_qkv, + fuse_qkv_params=fuse_qkv_params, transpose_batch_sequence=transpose_batch_sequence, scale_attn_logits=scale_attn_logits, scaled_query_init=scaled_query_init, @@ -1024,6 +1125,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE] enable_relative_embedding = True relative_embedding = pax_fiddle.Config(RelativePositionBiases, + dtype=dtype, num_attention_heads=num_attention_heads) drop_path = 0.0 transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS] diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 74cd23bf9e..098eb2aeac 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -934,7 +934,7 @@ def 
__call__(self, inputs, encoder_mask=None, deterministic=False): y = LayerNorm(layernorm_type=self.layernorm_type, zero_centered_gamma=self.zero_centered_gamma, dtype=self.dtype, - name="output_layer_norm")(y) + name="output_layernorm")(y) return y @@ -1090,7 +1090,7 @@ def __call__(self, z = LayerNorm(layernorm_type=self.layernorm_type, zero_centered_gamma=self.zero_centered_gamma, dtype=self.dtype, - name="output_layer_norm")(z) + name="output_layernorm")(z) return z diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 5b6ae5d505..43e7d17350 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -105,8 +105,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool flag_m512 = false; bool flag_arb = false; if ((sm_arch_ == 80 || sm_arch_ == 90) - && (max_seqlen_q <= 512) - && (max_seqlen_kv <= 512) + && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0) + && (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0) && (head_dim == 64) && (num_attn_heads == num_gqa_groups) && ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) diff --git a/transformer_engine/jax/cpp_extensions.py b/transformer_engine/jax/cpp_extensions.py index d2cdc4b432..a8315bbc50 100644 --- a/transformer_engine/jax/cpp_extensions.py +++ b/transformer_engine/jax/cpp_extensions.py @@ -1885,6 +1885,7 @@ def abstract(qkv_aval, bias_aval, seqlen_or_cu_seqlen_aval, seed_aval, *, attn_b backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type, attn_mask_type, dropout_probability, num_heads, num_heads, max_seqlen, max_seqlen, head_dim).get_fused_attn_backend() + if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: softmax_shape = (*batch_shape, num_heads, max_seqlen, max_seqlen) softmax_dtype = qkv_dtype @@ -2029,7 +2030,7 @@ def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, seqlen: jnp.ndarray scaling_factor: float, dropout_probability: float, is_training: bool): """ Wrapper for TE self fused attention fwd - Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 + Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2 """ checker = _FusedAttnRNGStateChecker() seed = checker.check_seed(seed, dropout_probability, is_training) @@ -2273,6 +2274,7 @@ def abstract(q_aval, kv_aval, bias_aval, q_seqlen_or_cu_seqlen_aval, attn_bias_type, attn_mask_type, dropout_probability, num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, q_head_dim).get_fused_attn_backend() + if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, kv_max_seqlen) softmax_dtype = q_dtype @@ -2426,7 +2428,7 @@ def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, q_s dropout_probability: float, is_training: bool): """ Wrapper for TE cross fused attention fwd - Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 + Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2 """ checker = _FusedAttnRNGStateChecker() seed = checker.check_seed(seed, dropout_probability, is_training) @@ -2662,6 +2664,445 @@ def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, is_training=is_training) +class FusedAttnFwdPrimitive(BasePrimitive): + """ + Fused Attention Forward Primitive + Query, key, value are seperated tensors + """ + name = "te_fused_attn_forward" + 
multiple_results = True + impl_static_args = (7, 8, 9, 10, 11) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(q_aval, k_aval, v_aval, bias_aval, q_seqlen_or_cu_seqlen_aval, + kv_seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type, + scaling_factor, dropout_probability, is_training): + """ + Fused attention fwd abstract + """ + q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) + k_dtype = dtypes.canonicalize_dtype(k_aval.dtype) + v_dtype = dtypes.canonicalize_dtype(v_aval.dtype) + bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) + assert q_dtype == k_dtype == v_dtype == bias_dtype + assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype + + *q_batch_shape, q_max_seqlen, num_heads, q_head_dim = q_aval.shape + *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape + assert q_batch_shape == kv_batch_shape + assert q_head_dim == kv_head_dim + assert k_aval.shape == v_aval.shape + out_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) + + # backend determines the softmax buffer shape/dtype + backend = FusedAttnHelper(q_dtype, k_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD, attn_bias_type, + attn_mask_type, dropout_probability, num_heads, num_gqa_groups, + q_max_seqlen, kv_max_seqlen, q_head_dim).get_fused_attn_backend() + + if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: + softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, kv_max_seqlen) + softmax_dtype = q_dtype + elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: + softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, 1) + softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) + else: + raise ValueError(f'Unsupported {backend=}') + softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype) + + # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with + # 32-bit unsigned int to get the buffer size we need in the C++ kernel + checker = _FusedAttnRNGStateChecker() + seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype) + assert seed_dtype == checker.rng_state_dtype + rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) + rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype) + + # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to + # prepare for the active fused-attn backend + batch_size = reduce(operator.mul, q_batch_shape) + wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes( + batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, q_head_dim, + scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, + jax_dtype_to_te_dtype(q_aval.dtype), is_training) + wkspace_aval = q_aval.update(shape=wkspace_info[0], + dtype=te_dtype_to_jax_dtype(wkspace_info[1])) + + return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + Fused attention fwd outer primitive abstract + """ + out_aval, softmax_aux_aval, rng_state_aval, _ = \ + FusedAttnFwdPrimitive.abstract(*args, **kwargs) + return out_aval, softmax_aux_aval, rng_state_aval + + @staticmethod + def lowering(ctx, q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type, + attn_mask_type, scaling_factor, dropout_probability, is_training): + """ + Fused attention fwd lowering rules + """ + operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed] + operand_shapes = map(lambda x: x.type.shape, operands) + out_types = [ + 
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) + for output in ctx.avals_out + ] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + q_aval, k_aval, v_aval, *_ = ctx.avals_in + *batch_shape, q_max_seqlen, num_heads, head_dim = q_aval.shape + *_, kv_max_seqlen, num_gqa_groups, _ = k_aval.shape + assert k_aval.shape == v_aval.shape + batch_size = reduce(operator.mul, batch_shape) + + wkspace_aval = ctx.avals_out[-1] + + opaque = transformer_engine_jax.pack_fused_attn_descriptor( + batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, head_dim, + wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, + jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), + is_training) + + out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) + + return out + + @staticmethod + def impl(q, k, v, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type, + scaling_factor, dropout_probability, is_training): + assert FusedAttnFwdPrimitive.inner_primitive is not None + + q_cu_seqlen = generate_cu_seqlen(q_seqlen) + kv_cu_seqlen = generate_cu_seqlen(kv_seqlen) + + output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind( + q, + k, + v, + bias, + q_cu_seqlen, + kv_cu_seqlen, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) + return output, softmax_aux, rng_state + + @staticmethod + def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training): + _check_valid_batch_dims(batch_dims) + assert FusedAttnFwdPrimitive.outer_primitive is not None + q_bdim, *_, seed_bdim = batch_dims + + out_bdims = q_bdim, q_bdim, seed_bdim + return FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training), out_bdims + + @staticmethod + def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training, mesh, arg_infos, + result_infos): + del attn_bias_type, attn_mask_type, scaling_factor + del dropout_probability, is_training, result_infos + q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden) + k_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, head, hidden) + out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3])) + rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) + return (out_sharding, softmax_aux_sharding, rng_state_sharding) + + @staticmethod + def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, + mesh, arg_infos, result_infos): + del result_infos + q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden) + k_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, head, hidden) + out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3])) + rng_state_sharding = seed_sharding = NamedSharding(mesh, + PartitionSpec(get_all_mesh_axes(), None)) + arg_shardings = tuple([arg_i.sharding for arg_i in 
arg_infos[:-1]] + [seed_sharding]) + out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) + impl = partial(FusedAttnFwdPrimitive.impl, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) + return mesh, impl, out_shardings, arg_shardings + + +register_primitive(FusedAttnFwdPrimitive) + + +def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, + q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray, seed: jnp.ndarray, + attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, + scaling_factor: float, dropout_probability: float, is_training: bool): + """ + Wrapper for TE fused attention fwd, where query, key, value are seperated tensors + Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2 + """ + checker = _FusedAttnRNGStateChecker() + seed = checker.check_seed(seed, dropout_probability, is_training) + + if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + assert bias is None + bias = jnp.zeros(0, dtype=q.dtype) + + return FusedAttnFwdPrimitive.outer_primitive.bind(q, + k, + v, + bias, + q_seqlen, + kv_seqlen, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) + + +class FusedAttnBwdPrimitive(BasePrimitive): + """ + Fused Attention Backward Primitive + """ + name = "te_fused_attn_backward" + multiple_results = True + impl_static_args = (10, 11, 12, 13, 14) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(q_aval, k_aval, v_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, + doutput_aval, q_cu_seqlen_aval, kv_cu_seqlen_aval, *, attn_bias_type, + attn_mask_type, scaling_factor, dropout_probability, is_training): + """ + Fused attention bwd abstract + """ + del softmax_aux_aval, rng_state_aval, output_aval + + q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) + k_dtype = dtypes.canonicalize_dtype(k_aval.dtype) + v_dtype = dtypes.canonicalize_dtype(v_aval.dtype) + bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) + doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype) + assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype + assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype + + *q_batch_shape, q_max_seqlen, num_heads, q_head_dim = q_aval.shape + *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape + assert q_batch_shape == kv_batch_shape + assert q_head_dim == kv_head_dim + assert k_aval.shape == v_aval.shape + + batch_size = reduce(operator.mul, q_batch_shape) + wkspace_shape, wkspace_dtype = \ + transformer_engine_jax.get_fused_attn_bwd_workspace_sizes( + batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, q_head_dim, + scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, + jax_dtype_to_te_dtype(q_aval.dtype), is_training + ) + + dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) + dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype) + dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype) + dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) + wkspace_aval = q_aval.update(shape=wkspace_shape, + dtype=te_dtype_to_jax_dtype(wkspace_dtype)) + + return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + Fused attention fwd outer primitive 
abstract + """ + dq_aval, dk_aval, dv_aval, dbias_aval, _ = \ + FusedAttnBwdPrimitive.abstract(*args, **kwargs) + return dq_aval, dk_aval, dv_aval, dbias_aval + + @staticmethod + def lowering(ctx, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, + kv_cu_seqlen, *, attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training): + """ + Fused attention bwd lowering rules + """ + operands = [ + q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen + ] + operand_shapes = map(lambda x: x.type.shape, operands) + out_types = [ + ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) + for output in ctx.avals_out + ] + + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + q_aval, k_aval, v_aval, *_ = ctx.avals_in + *batch_shape, q_max_seqlen, num_heads, head_dim = q_aval.shape + *_, kv_max_seqlen, num_gqa_groups, _ = k_aval.shape + assert k_aval.shape == v_aval.shape + batch_size = reduce(operator.mul, batch_shape) + + wkspace_aval = ctx.avals_out[-1] + + opaque = transformer_engine_jax.pack_fused_attn_descriptor( + batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, head_dim, + wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, + jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), + is_training) + + out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) + + return out + + @staticmethod + def impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen, + attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training): + assert FusedAttnBwdPrimitive.inner_primitive is not None + + q_cu_seqlen = generate_cu_seqlen(q_seqlen) + kv_cu_seqlen = generate_cu_seqlen(kv_seqlen) + + dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind( + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_cu_seqlen, + kv_cu_seqlen, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) + return dq, dk, dv, dbias + + @staticmethod + def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training): + _check_valid_batch_dims(batch_dims) + assert FusedAttnBwdPrimitive.outer_primitive is not None + q_bdim, k_bdim, v_bdim, *_ = batch_dims + + out_bdims = q_bdim, k_bdim, v_bdim, q_bdim + return FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training), out_bdims + + @staticmethod + def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training, mesh, arg_infos, + result_infos): + del attn_bias_type, attn_mask_type, scaling_factor + del dropout_probability, is_training, result_infos + q_spec = get_padded_spec(arg_infos[0]) + k_spec = get_padded_spec(arg_infos[1]) + v_spec = get_padded_spec(arg_infos[2]) + bias_spec = get_padded_spec(arg_infos[3]) + dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) + dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) + dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) + return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) + + 
@staticmethod + def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, + mesh, arg_infos, result_infos): + del result_infos + q_spec = get_padded_spec(arg_infos[0]) + k_spec = get_padded_spec(arg_infos[1]) + v_spec = get_padded_spec(arg_infos[2]) + bias_spec = get_padded_spec(arg_infos[3]) + dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) + dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) + dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) + + def sharded_impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, + kv_cu_seqlen): + local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl( + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_cu_seqlen, + kv_cu_seqlen, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) + global_dbias = local_dbias + if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: + global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) + return local_dq, local_dk, local_dv, global_dbias + + return mesh, sharded_impl, out_shardings, arg_shardings + + +register_primitive(FusedAttnBwdPrimitive) + + +def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, + softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray, + doutput: jnp.ndarray, q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray, + attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, + scaling_factor: float, dropout_probability: float, is_training: bool): + """ + Wrapper for TE fused attention bwd + Return the gradients of fused attention with seperated query, key, value tensors + """ + if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + assert bias is None + bias = jnp.zeros(0, dtype=q.dtype) + return FusedAttnBwdPrimitive.outer_primitive.bind(q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) + + class GeluPrimitive(BasePrimitive): """ Gelu Froward Primitive diff --git a/transformer_engine/jax/csrc/extensions.cpp b/transformer_engine/jax/csrc/extensions.cpp index 3011b43f6f..5faec6fd10 100644 --- a/transformer_engine/jax/csrc/extensions.cpp +++ b/transformer_engine/jax/csrc/extensions.cpp @@ -53,6 +53,8 @@ pybind11::dict Registrations() { dict["te_self_fused_attn_backward"] = EncapsulateFunction(SelfFusedAttnBackward); dict["te_cross_fused_attn_forward"] = EncapsulateFunction(CrossFusedAttnForward); dict["te_cross_fused_attn_backward"] = EncapsulateFunction(CrossFusedAttnBackward); + dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward); + dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward); return dict; } @@ -74,6 +76,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_self_fused_attn_bwd_workspace_sizes", &GetSelfFusedAttnBackwardWorkspaceSizes); m.def("get_cross_fused_attn_fwd_workspace_sizes", &GetCrossFusedAttnForwardWorkspaceSizes); m.def("get_cross_fused_attn_bwd_workspace_sizes", &GetCrossFusedAttnBackwardWorkspaceSizes); + m.def("get_fused_attn_fwd_workspace_sizes", 
&GetFusedAttnForwardWorkspaceSizes); + m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); pybind11::enum_(m, "DType", pybind11::module_local()) .value("kByte", DType::kByte) @@ -98,7 +102,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) { pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD); + .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) + .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD); pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend) diff --git a/transformer_engine/jax/csrc/modules.cpp b/transformer_engine/jax/csrc/modules.cpp index 9e8d5c3d3c..e5e4c72437 100644 --- a/transformer_engine/jax/csrc/modules.cpp +++ b/transformer_engine/jax/csrc/modules.cpp @@ -1253,7 +1253,6 @@ pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes( auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); - // TODO(rewang): add bias for cross attn? auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype); // FP16/BF16 doesn't use this tensor @@ -1488,5 +1487,265 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa nvte_tensor_pack_destroy(&aux_input_tensors); } +pybind11::tuple GetFusedAttnForwardWorkspaceSizes( + size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads, + size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) { + constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD; + + auto q_shape = std::vector{batch_size * q_max_seqlen, num_heads, head_dim}; + auto k_shape = std::vector{batch_size * kv_max_seqlen, num_gqa_groups, head_dim}; + auto v_shape = k_shape; + auto bias_shape = std::vector{1, num_heads, q_max_seqlen, kv_max_seqlen}; + + auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); + auto v_tensor = TensorWrapper(nullptr, v_shape, dtype); + + auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype); + + // F16 doesn't use this tensor + auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); + auto o_tensor = TensorWrapper(nullptr, q_shape, dtype); + + auto q_cu_seqlens_tensor = + TensorWrapper(nullptr, std::vector{batch_size + 1}, DType::kInt32); + auto kv_cu_seqlens_tensor = + TensorWrapper(nullptr, std::vector{batch_size + 1}, DType::kInt32); + + auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector{2}, DType::kInt64); + + NVTETensorPack aux_output_tensors; + nvte_tensor_pack_create(&aux_output_tensors); + + TensorWrapper query_workspace_tensor; + nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), + s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + query_workspace_tensor.data(), nullptr); + + auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); + return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); +} + +void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { + const 
CustomCallFusedAttnDescriptor &descriptor = + *UnpackOpaque(opaque, opaque_len); + + // input buffers from XLA + void *q = buffers[0]; + void *k = buffers[1]; + void *v = buffers[2]; + void *bias = buffers[3]; + void *q_cu_seqlens = buffers[4]; + void *kv_cu_seqlens = buffers[5]; + void *seed = buffers[6]; + + // output buffers from XLA + void *output = buffers[7]; + void *softmax_aux = buffers[8]; + void *rng_state = buffers[9]; + void *workspace = buffers[10]; + + // tensor sizes + auto batch_size = descriptor.batch_size; + auto q_max_seqlen = descriptor.q_max_seqlen; + auto kv_max_seqlen = descriptor.kv_max_seqlen; + auto num_heads = descriptor.num_heads; + auto num_gqa_groups = descriptor.num_gqa_groups; + auto head_dim = descriptor.head_dim; + auto scaling_factor = descriptor.scaling_factor; + auto dropout_probability = descriptor.dropout_probability; + auto bias_type = descriptor.bias_type; + auto mask_type = descriptor.mask_type; + + auto q_shape = std::vector{batch_size * q_max_seqlen, num_heads, head_dim}; + auto k_shape = std::vector{batch_size * kv_max_seqlen, num_gqa_groups, head_dim}; + auto v_shape = k_shape; + auto bias_shape = std::vector{1, num_heads, q_max_seqlen, kv_max_seqlen}; + + // input tensors + auto dtype = descriptor.dtype; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto k_tensor = TensorWrapper(k, k_shape, dtype); + auto v_tensor = TensorWrapper(v, v_shape, dtype); + auto bias_tensor = TensorWrapper(bias, bias_shape, dtype); + + // output tensors + auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in F16 + auto o_tensor = TensorWrapper(output, q_shape, dtype); + auto q_cu_seqlens_tensor = + TensorWrapper(q_cu_seqlens, std::vector{batch_size + 1}, DType::kInt32); + auto kv_cu_seqlens_tensor = + TensorWrapper(kv_cu_seqlens, std::vector{batch_size + 1}, DType::kInt32); + + // prep RNG state + constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD; + auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); + auto backend = nvte_get_fused_attn_backend( + static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, + mask_type, dropout_probability, num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, + head_dim); + PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); + + // auxiliary tensors (to be propagated to the backward pass later) + NVTETensorPack aux_output_tensors; + nvte_tensor_pack_create(&aux_output_tensors); + PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend, + softmax_aux); + + // cuDNN workspace + auto workspace_tensor = TensorWrapper(workspace, std::vector{descriptor.wkspace_size}, + descriptor.wkspace_dtype); + + nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), + s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, + descriptor.is_training, scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, workspace_tensor.data(), stream); + + nvte_tensor_pack_destroy(&aux_output_tensors); +} + +pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( + size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads, + size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) { + constexpr auto qkv_layout = 
NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD; + + auto q_shape = std::vector{batch_size * q_max_seqlen, num_heads, head_dim}; + auto k_shape = std::vector{batch_size * kv_max_seqlen, num_gqa_groups, head_dim}; + auto v_shape = k_shape; + auto output_shape = std::vector{batch_size * q_max_seqlen, num_heads, head_dim}; + auto bias_shape = std::vector{1, num_heads, q_max_seqlen, kv_max_seqlen}; + + auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); + auto v_tensor = TensorWrapper(nullptr, v_shape, dtype); + auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype); + auto output_tensor = TensorWrapper(nullptr, output_shape, dtype); + // F16 doesn't use this tensor + auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); + + auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype); + auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype); + auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype); + + auto q_cu_seqlens_tensor = + TensorWrapper(nullptr, std::vector{batch_size + 1}, DType::kInt32); + auto kv_cu_seqlens_tensor = + TensorWrapper(nullptr, std::vector{batch_size + 1}, DType::kInt32); + + NVTETensorPack aux_input_tensors; + nvte_tensor_pack_create(&aux_input_tensors); + + TensorWrapper query_workspace_tensor; + nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), + dbias_tensor.data(), q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, + query_workspace_tensor.data(), nullptr); + + auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); + return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); +} + +void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { + const CustomCallFusedAttnDescriptor &descriptor = + *UnpackOpaque(opaque, opaque_len); + + // input buffers from XLA + void *q = buffers[0]; + void *k = buffers[1]; + void *v = buffers[2]; + void *bias = buffers[3]; + void *softmax_aux = buffers[4]; + void *rng_state = buffers[5]; + void *output = buffers[6]; + void *doutput = buffers[7]; + void *q_cu_seqlens = buffers[8]; + void *kv_cu_seqlens = buffers[9]; + + // output buffers from XLA + void *dq = buffers[10]; + void *dk = buffers[11]; + void *dv = buffers[12]; + void *dbias = buffers[13]; + void *workspace = buffers[14]; + + // tensor sizes + auto batch_size = descriptor.batch_size; + auto q_max_seqlen = descriptor.q_max_seqlen; + auto kv_max_seqlen = descriptor.kv_max_seqlen; + auto num_heads = descriptor.num_heads; + auto num_gqa_groups = descriptor.num_gqa_groups; + auto head_dim = descriptor.head_dim; + auto scaling_factor = descriptor.scaling_factor; + auto dropout_probability = descriptor.dropout_probability; + auto bias_type = descriptor.bias_type; + auto mask_type = descriptor.mask_type; + + auto q_shape = std::vector{batch_size * q_max_seqlen, num_heads, head_dim}; + auto k_shape = std::vector{batch_size * kv_max_seqlen, num_gqa_groups, head_dim}; + auto v_shape = k_shape; + auto output_shape = std::vector{batch_size * q_max_seqlen, num_heads, head_dim}; + auto bias_shape = std::vector{1, num_heads, q_max_seqlen, kv_max_seqlen}; + + // input 
tensors + auto dtype = descriptor.dtype; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto k_tensor = TensorWrapper(k, k_shape, dtype); + auto v_tensor = TensorWrapper(v, v_shape, dtype); + auto output_tensor = TensorWrapper(output, output_shape, dtype); + auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype); + + // output tensors + auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in F16 + auto dq_tensor = TensorWrapper(dq, q_shape, dtype); + auto dk_tensor = TensorWrapper(dk, k_shape, dtype); + auto dv_tensor = TensorWrapper(dv, v_shape, dtype); + auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); + auto q_cu_seqlens_tensor = + TensorWrapper(q_cu_seqlens, std::vector{batch_size + 1}, DType::kInt32); + auto kv_cu_seqlens_tensor = + TensorWrapper(kv_cu_seqlens, std::vector{batch_size + 1}, DType::kInt32); + + // auxiliary tensors (propagated from the forward pass) + NVTETensorPack aux_input_tensors; + nvte_tensor_pack_create(&aux_input_tensors); + constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD; + auto backend = nvte_get_fused_attn_backend( + static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, + mask_type, dropout_probability, num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, + head_dim); + PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux, + rng_state, bias); + + // cuDNN workspace + auto wkspace_size = std::vector{descriptor.wkspace_size}; + auto wkspace_dtype = descriptor.wkspace_dtype; + auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype); + + nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), + dbias_tensor.data(), q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, + workspace_tensor.data(), stream); + + nvte_tensor_pack_destroy(&aux_input_tensors); +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/modules.h b/transformer_engine/jax/csrc/modules.h index 2a572c5784..109ea45f9c 100644 --- a/transformer_engine/jax/csrc/modules.h +++ b/transformer_engine/jax/csrc/modules.h @@ -236,6 +236,20 @@ pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes( void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +pybind11::tuple GetFusedAttnForwardWorkspaceSizes( + size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads, + size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training); + +void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); + +pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( + size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads, + size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training); + +void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); + } // namespace jax } // namespace transformer_engine diff --git 
a/transformer_engine/jax/flax/__init__.py b/transformer_engine/jax/flax/__init__.py index edb87447e0..914bce00b3 100644 --- a/transformer_engine/jax/flax/__init__.py +++ b/transformer_engine/jax/flax/__init__.py @@ -5,11 +5,19 @@ from .module import DenseGeneral, LayerNorm from .module import LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase from .transformer import extend_logical_axis_rules -from .transformer import MultiHeadAttention, RelativePositionBiases +from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases from .transformer import TransformerLayer, TransformerLayerType __all__ = [ - 'DenseGeneral', 'LayerNorm', 'LayerNormDenseGeneral', 'LayerNormMLP', - 'TransformerEngineBase', 'extend_logical_axis_rules', 'MultiHeadAttention', - 'RelativePositionBiases', 'TransformerLayer', 'TransformerLayerType', + 'DenseGeneral', + 'LayerNorm', + 'LayerNormDenseGeneral', + 'LayerNormMLP', + 'TransformerEngineBase', + 'extend_logical_axis_rules', + 'DotProductAttention', + 'MultiHeadAttention', + 'RelativePositionBiases', + 'TransformerLayer', + 'TransformerLayerType', ] diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 907a758d61..8d68604cc4 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -16,6 +16,7 @@ import numpy as np from flax import linen as nn from flax.linen import partitioning as nn_partitioning +from flax.linen.attention import combine_masks from jax import nn as jax_nn from jax import random as jax_random from jax import lax, vmap @@ -24,8 +25,8 @@ from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP from .module import LayerNorm, Softmax from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout -from ..fused_attn import is_fused_attn_kernel_available -from ..fused_attn import self_fused_attn, cross_fused_attn +from ..fused_attn import is_fused_attn_kernel_available, canonicalize_attn_mask_type +from ..fused_attn import self_fused_attn, cross_fused_attn, fused_attn from ..softmax import SoftmaxType from ..sharding import num_of_devices from ..sharding import get_sharding_map_logic_axis_to_mesh_axis @@ -71,12 +72,12 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: Parameters ---------- - rules : Sequence[Tuple[str, Union[str, None]]] + rules: Sequence[Tuple[str, Union[str, None]]] the base Flax logical axis rules to extend. Returns ------- - extended_rules : Sequence[Tuple[str, Union[str, None]]] + extended_rules: Sequence[Tuple[str, Union[str, None]]] the extended Flax logical axis rules. 
""" rules_map = {} @@ -108,122 +109,436 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: return tuple(extended_rules) -def _merge_mask(func, *masks: Optional[Array]): - masks = [m for m in masks if m is not None] - if not masks: - return None - assert all(map(lambda x: x.ndim == masks[0].ndim, - masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}') - mask, *other_masks = masks - for other_mask in other_masks: - mask = func(mask, other_mask) - return mask - - -def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32): - """Combine attention masks.""" - func = jnp.logical_and - return _merge_mask(func, *masks).astype(dtype) - - -def combine_biases(*masks: Optional[Array]): - """Combine attention biases.""" - - def func(a, b): - return a + b - - return _merge_mask(func, *masks) - - -def core_attention(query: Array, - key: Array, - value: Array, - scale_factor: float, - transpose_batch_sequence: bool, - softmax_type: SoftmaxType = SoftmaxType.SCALED, - mask: Optional[Array] = None, - bias: Optional[Array] = None, - dropout_rng: Optional[PRNGKey] = None, - dropout_rate: float = 0., - deterministic: bool = False, - dtype: DType = jnp.float32, - float32_logits: bool = False): - """Core attention""" - assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' - batch_dim = 1 if transpose_batch_sequence else 0 - assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], ( - 'q, k, v batch dims must match.') - sequence_dim = 0 if transpose_batch_sequence else 1 - assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.' - assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.' - assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.' - - if float32_logits: - query = query.astype(jnp.float32) - key = key.astype(jnp.float32) - - h_q, h_kv = query.shape[-2], key.shape[-2] - # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv. - # Therefore, we have to maintain two code paths. - is_gqa = (h_q != h_kv) - - if is_gqa: - assert (h_q % h_kv == 0) and (h_q >= h_kv) - group_size = h_q // h_kv - grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1])) - - if transpose_batch_sequence: +class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-methods + attention_dropout: float = 0. + attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK + attn_bias_type: Optional[AttnBiasType] = None + dtype: DType = jnp.float32 + float32_logits: bool = False + scale_factor: Optional[float] = None + transpose_batch_sequence: bool = True + + @nn.compact + def __call__(self, + query: Array, + key: Array, + value: Array, + mask: Optional[Array] = None, + bias: Optional[Array] = None, + *, + dropout_rng: Optional[PRNGKey] = None, + deterministic: bool = False) -> Array: + assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' + batch_dim = 1 if self.transpose_batch_sequence else 0 + assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], ( + 'q, k, v batch dims must match.') + sequence_dim = 0 if self.transpose_batch_sequence else 1 + assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.' + assert key.shape[-2] == value.shape[-2], 'k, v num_attention_heads must match.' + assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.' 
+ + if self.scale_factor is None: + scale_factor = 1.0 / sqrt(query.shape[-1]) + else: + scale_factor = self.scale_factor + del self.scale_factor + + if self.float32_logits: + query = query.astype(jnp.float32) + key = key.astype(jnp.float32) + h_q, h_kv = query.shape[-2], key.shape[-2] + # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv. + # Therefore, we have to maintain two code paths. + is_gqa = (h_q != h_kv) + if is_gqa: - attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key) + assert (h_q % h_kv == 0) and (h_q >= h_kv) + group_size = h_q // h_kv + grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1])) + + if self.transpose_batch_sequence: + if is_gqa: + attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key) + else: + attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key) else: - attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key) - else: + if is_gqa: + attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key) + else: + attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key) + + attn_weights = checkpoint_name(attn_weights, 'logits') + if is_gqa: - attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key) + b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape + attn_weights_without_groups_shape = (b, h * g, q, k) + attn_weights = attn_weights.reshape(attn_weights_without_groups_shape) + + attn_weights = with_sharding_constraint_by_logical_axes( + attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES)) + + # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias) + # In this case, the scale can not fused into the Softmax module. + if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS: + attn_weights = attn_weights * scale_factor + fused_scale_factor = 1. else: - attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key) - - attn_weights = checkpoint_name(attn_weights, 'logits') - - if is_gqa: - b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape - attn_weights_without_groups_shape = (b, h * g, q, k) - attn_weights = attn_weights.reshape(attn_weights_without_groups_shape) - - attn_weights = with_sharding_constraint_by_logical_axes( - attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES)) - - # When a bias is present, the computation is performed as Softmax(attn_weights * scale + bias). - # In this case, the scale can not fused into the Softmax module. - if bias is not None: - attn_weights = attn_weights * scale_factor - fused_scale_factor = 1. 
- else: - # If no bias, the scale can be fused into Softmax module - fused_scale_factor = scale_factor - - attn_weights = Softmax(softmax_type=softmax_type, - scale_factor=fused_scale_factor)(attn_weights, mask, bias).astype(dtype) - - if is_gqa: - attn_weights = attn_weights.reshape(attn_weights_with_groups_shape) - - if not deterministic and dropout_rate > 0.: - keep_prob = 1.0 - dropout_rate - dropout_shape = list(attn_weights.shape) - # TODO(rewang): add attention dropout broadcast dimension arguments for users - keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape) - multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)) - attn_weights = attn_weights * multiplier - - if transpose_batch_sequence: + # If not post_scale_bias, the scale can be fused into Softmax module + fused_scale_factor = scale_factor + if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS: + attn_weights += bias + + def convert_to_softmax_type(attn_mask_type, mask): + """Convert the attn_mask_type to SoftmaxType""" + if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]: + return SoftmaxType.SCALED_UPPER_TRIANG_MASKED + if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]: + if mask is not None: + return SoftmaxType.SCALED_MASKED + return SoftmaxType.SCALED + raise ValueError(f"Unsupported {attn_mask_type=}, " + "supported attn_mask_type = {'causal', 'padding'}") + + softmax_type = convert_to_softmax_type(self.attn_mask_type, mask) + + attn_weights = Softmax(softmax_type=softmax_type, + scale_factor=fused_scale_factor)(attn_weights, mask, + bias).astype(self.dtype) + if is_gqa: - return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape) - return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value) + attn_weights = attn_weights.reshape(attn_weights_with_groups_shape) + + if not deterministic and self.attention_dropout > 0.: + keep_prob = 1.0 - self.attention_dropout + dropout_shape = list(attn_weights.shape) + # TODO(rewang): add attention dropout broadcast dimension arguments for users + keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape) + multiplier = (keep.astype(attn_weights.dtype) / + jnp.asarray(keep_prob, dtype=self.dtype)) + attn_weights = attn_weights * multiplier + + if self.transpose_batch_sequence: + if is_gqa: + return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape) + return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value) + + if is_gqa: + return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape) + return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value) + + +class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-methods + attention_dropout: float = 0. 
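For reference, in the plain non-GQA, batch-first case the unfused module above reduces to the following computation. This is a simplified sketch that omits masking, dropout, and sharding constraints; it is not the module itself:

    import jax.numpy as jnp
    from jax import nn as jax_nn

    def unfused_attention_reference(query, key, value, bias=None,
                                    scale_factor=None, post_scale_bias=True):
        # query/key/value: [batch, seqlen, heads, head_dim]
        if scale_factor is None:
            scale_factor = 1.0 / jnp.sqrt(query.shape[-1])
        logits = jnp.einsum('bqhd,bkhd->bhqk', query, key)
        if bias is None:
            logits = logits * scale_factor
        elif post_scale_bias:
            # post_scale_bias: softmax(logits * scale + bias)
            logits = logits * scale_factor + bias
        else:
            # pre_scale_bias: softmax((logits + bias) * scale)
            logits = (logits + bias) * scale_factor
        weights = jax_nn.softmax(logits, axis=-1)
        return jnp.einsum('bhqk,bkhd->bqhd', weights, value)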
+ attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK + attn_bias_type: Optional[AttnBiasType] = None + dtype: DType = jnp.float32 + qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD + scale_factor: Optional[float] = None + transpose_batch_sequence: bool = False + + @nn.compact + def __call__(self, + query: Array, + key: Array, + value: Array, + mask: Optional[Array] = None, + bias: Optional[Array] = None, + *, + dropout_rng: Optional[PRNGKey] = None, + deterministic: bool = False) -> Array: + + seed = None + if dropout_rng is not None: + seed = jax.random.split(dropout_rng, num_of_devices()) + + if self.scale_factor is None: + scale_factor = 1.0 / sqrt(query.shape[-1]) + else: + scale_factor = self.scale_factor + del self.scale_factor + + if self.qkv_layout == QKVLayout.BS3HD: + """qkvpacked format, treat + query: qkvpacked tensor, shape = [..., 3, h, d] + key: ignore + value: ignore + """ + qkv_packed = query + if self.transpose_batch_sequence: + qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4]) + x = self_fused_attn(qkv_packed, + bias, + mask, + seed, + attn_mask_type=self.attn_mask_type, + attn_bias_type=self.attn_bias_type, + scaling_factor=scale_factor, + dropout_probability=self.attention_dropout, + is_training=not deterministic) + elif self.qkv_layout == QKVLayout.BSHD_BS2HD: + """kvpacked format, treat + query: query tensor, shape = [..., h, d] + key: kvpacked tensor, shape = [..., 2, h, d] + value: ignore + """ + kv_packed = key + if self.transpose_batch_sequence: + query = query.transpose([1, 0, 2, 3]) + kv_packed = kv_packed.transpose([1, 0, 2, 3, 4]) + x = cross_fused_attn(query, + kv_packed, + bias, + mask, + seed, + attn_mask_type=self.attn_mask_type, + attn_bias_type=self.attn_bias_type, + scaling_factor=scale_factor, + dropout_probability=self.attention_dropout, + is_training=not deterministic) + elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD: + if self.transpose_batch_sequence: + query = query.transpose([1, 0, 2, 3]) + key = key.transpose([1, 0, 2, 3]) + value = value.transpose([1, 0, 2, 3]) + x = fused_attn(query, + key, + value, + bias, + mask, + seed, + attn_mask_type=self.attn_mask_type, + attn_bias_type=self.attn_bias_type, + scaling_factor=scale_factor, + dropout_probability=self.attention_dropout, + is_training=not deterministic) + else: + raise ValueError(f"Unsupported {self.qkv_layout=}.") + + if self.transpose_batch_sequence: + x = x.transpose([1, 0, 2, 3]) + + return x + + +class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods + r""" + Dot Product Attention (DPA). Allows the model to jointly attend to information from different + representation subspaces as described in the paper: + `Attention Is All You Need `_. + + .. note:: + The DotProductAttention module supports two backends: the unfused and the fused attention + mechanisms. The unfused attention is implemented using JAX native operations, providing + broad compatibility and flexibility. In contrast, the fused attention uses `cuDNN fused + attention + `_ for + higher performance and lower memory usage on the supported hardwares. + Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment + variable: + + * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention (default). + * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention. If the required cuDNN fused attention + kernel is not available on the system, a warning will be issued, and the module will + automatically fall back to the unfused backend. 
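A usage sketch for the backend selection described in the note above. Sizes and dtype are illustrative; when no fused kernel is available, the module falls back to the unfused path and emits a warning:

    import os
    os.environ["NVTE_FUSED_ATTN"] = "1"    # opt in to the cuDNN fused backend

    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.flax import DotProductAttention

    b, s, h, d = 2, 128, 16, 64
    # transpose_batch_sequence defaults to True, so inputs are [s, b, h, d].
    q = k = v = jnp.zeros((s, b, h, d), jnp.bfloat16)

    layer = DotProductAttention(head_dim=d,
                                num_attention_heads=h,
                                num_gqa_groups=h,
                                attn_mask_type='causal',
                                dtype=jnp.bfloat16)
    variables = layer.init(jax.random.PRNGKey(0), q, k, v, deterministic=True)
    out = layer.apply(variables, q, k, v, deterministic=True)    # [s, b, h, d]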
+
+    Parameters
+    ----------
+    head_dim: int
+        The hidden dimension of each attention head.
+    num_attention_heads: int
+        The number of attention heads.
+    num_gqa_groups: int, default = `None`
+        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
+        Grouped Query Attention is described in
+        `this paper `_.
+        This only affects the keys and values, not the queries.
+        GQA-1 is equivalent to Multi-Query Attention
+        (`MQA `_), while GQA-H
+        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
+    attention_dropout: float, default = 0.0
+        Dropout probability for the dropout op after the softmax.
+    attn_mask_type: str, default = 'causal'
+        Type of the attention mask passed into the softmax operation in the self attention.
+        Available options: {'no_mask', 'padding', 'causal', 'padding_causal'}
+        Introduced in v0.10.0.
+    attn_bias_type: Optional[str], default = None
+        Type of the attention bias passed in the self attention.
+        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
+        When the default `None` is given, the type is decided automatically by the bias argument:
+        :attr:`post_scale_bias` if a bias is provided, otherwise :attr:`no_bias`.
+    dropout_rng_name: str, default = 'dropout'
+        The key in given RNGs via flax.linen.Module.apply that is used
+        to generate Dropout masks in the core attention.
+    float32_logits: bool, default = False
+        Whether to compute attention logits in float32 for the unfused attention backend.
+        For the fused attention backend, the accumulation is always float32 without the perf overhead.
+    qkv_layout: str, default = 'bshd_bshd_bshd'
+        Specifies the dimensional layout format for the query, key, and value tensors in __call__().
+        It indicates how the inputs are processed.
+        Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd'}. Where
+
+        * bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d].
+          key and value arguments in :attr:`__call__()` are ignored in this layout.
+        * bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treated as a kvpacked
+          tensor with shape = [b, s, 2, h, d]. `value` argument in :attr:`__call__()` is ignored.
+        * bshd_bshd_bshd: query, key, and value are separate tensors, each with shape = [b, s, h, d].
+
+        Explanation of denotations:
+
+        * b: batch size
+        * s: sequence length
+        * h: num_attention_heads or num_gqa_groups
+        * d: head dimension
+
+    scale_factor: Optional[float], default = None
+        Scale factor to apply on query. When :attr:`None` is present, the scale factor is equal
+        to :math:`\frac{1}{\sqrt{head\_dim}}`. This is useful for models like T5X, which do not
+        apply an extra scale on query; in that case, set :attr:`scale_factor=1.`.
+    transpose_batch_sequence: bool, default = True
+        Indicate whether the input tensors were switched axis of batch
+        and sequence length dimension. If set to True, the input tensors
+        should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
+
+    Optimization parameters
+    -----------------------
+    dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type used to allocate the initial parameters.
+    """
+    head_dim: int
+    num_attention_heads: int
+    num_gqa_groups: Optional[int] = None
+    attention_dropout: float = 0.
+    attn_mask_type: AttnMaskType = 'causal'
+    attn_bias_type: AttnBiasType = None
+    dtype: DType = jnp.float32
+    dropout_rng_name: str = 'dropout'
+    float32_logits: bool = False
+    qkv_layout: str = 'bshd_bshd_bshd'
+    scale_factor: Optional[float] = None
+    transpose_batch_sequence: bool = True
+
+    @nn.compact
+    def __call__(self,
+                 query: Array,
+                 key: Array,
+                 value: Array,
+                 mask: Optional[Array] = None,
+                 bias: Optional[Array] = None,
+                 *,
+                 deterministic: bool = False) -> Array:
+        """
+        Parameters
+        ----------
+        query: jax.numpy.ndarray
+            The details of the query tensor representation are described in :attr:`qkv_layout`.
+        key: jax.numpy.ndarray
+            The details of the key tensor representation are described in :attr:`qkv_layout`.
+        value: jax.numpy.ndarray
+            The details of the value tensor representation are described in :attr:`qkv_layout`.
+        mask: jax.numpy.ndarray, default = None
+            Boolean tensor used to mask out the attention softmax input.
+            :attr:`True` means to mask out the corresponding values.
+        bias: jax.numpy.ndarray, default = None
+            A tensor used to shift the attention softmax input.
+        *:
+            The parameters below are keyword-only.
+        deterministic: bool, default = False
+            Disable dropout layers if set to True.
+
+        Returns
+        -------
+        outputs: jax.numpy.ndarray
+            Output tensors.
+        """
+
+        # For the internal API, use enums to maintain consistency.
+        if self.attn_bias_type is None:
+            attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
+        else:
+            attn_bias_type = AttnBiasType[self.attn_bias_type.upper()]
+        attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
+        qkv_layout = QKVLayout[self.qkv_layout.upper()]
+        del self.attn_bias_type, self.attn_mask_type, self.qkv_layout
+
+        if attn_bias_type == AttnBiasType.NO_BIAS:
+            assert bias is None
+        else:
+            assert bias is not None
+
+        enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))
+
+        sequence_dim = 0 if self.transpose_batch_sequence else 1
+        seqlen_q = query.shape[sequence_dim]
+        if qkv_layout == QKVLayout.BS3HD:
+            seqlen_kv = seqlen_q
+        else:
+            seqlen_kv = key.shape[sequence_dim]
+
+        has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout,
+                                                               attn_bias_type, attn_mask_type,
+                                                               self.attention_dropout,
+                                                               self.num_attention_heads,
+                                                               self.num_gqa_groups, seqlen_q,
+                                                               seqlen_kv, self.head_dim)
+
+        use_fused_attn = (enable_fused_attn and has_fused_attn_kernel)
+
+        if enable_fused_attn and not has_fused_attn_kernel:
+            warnings.warn("Fused attention is not enabled because there is no available kernel.\n"
+                          "Falling back to the unfused attention.\n"
+                          "Please try to update cuDNN and TE to the latest version.\n"
+                          f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n"
+                          f"{self.attention_dropout=}\n{self.num_attention_heads=}\n"
+                          f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{self.head_dim=}\n")
+
+        dropout_rng = None
+        if not deterministic and self.attention_dropout > 0.:
+            dropout_rng = self.make_rng(self.dropout_rng_name)
+
+        if self.scale_factor is None:
+            scale_factor = 1.0 / sqrt(self.head_dim)
+        else:
+            scale_factor = self.scale_factor
+        del self.scale_factor
+
+        if not use_fused_attn:
+            # The unfused attention only supports separate query, key and value tensors.
+            if qkv_layout == QKVLayout.BS3HD:
+                query, key, value = jnp.split(query, [1, 2], axis=-3)
+                query, key, value = map(functools.partial(jnp.squeeze, axis=-3),
+                                        [query, key, value])
+            elif qkv_layout == QKVLayout.BSHD_BS2HD:
+                key, value = jnp.split(key, [1], axis=-3)
+                key, value = map(functools.partial(jnp.squeeze,
axis=-3), [key, value]) + else: + assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD + + x = _UnfusedDotProductAttention(attention_dropout=self.attention_dropout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + dtype=self.dtype, + float32_logits=self.float32_logits, + scale_factor=scale_factor, + transpose_batch_sequence=self.transpose_batch_sequence)( + query, + key, + value, + mask, + bias, + dropout_rng=dropout_rng, + deterministic=deterministic) + else: + x = _FusedDotProductAttention( + attention_dropout=self.attention_dropout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + dtype=self.dtype, + scale_factor=scale_factor, + transpose_batch_sequence=self.transpose_batch_sequence, + qkv_layout=qkv_layout, + )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic) - if is_gqa: - return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape) - return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value) + return x def rotary_pos_emb(x: Array, windows: Tuple[int, int], transpose_batch_sequence: bool): @@ -259,43 +574,44 @@ def rotary_pos_emb(x: Array, windows: Tuple[int, int], transpose_batch_sequence: return jnp.concatenate([part_1, part_2], axis=-1) -dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) - - class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods r""" Multi-head Attention (MHA), including Query, Key, Value and Output projection. - .. note:: - - Argument :attr:`mask` will be ignored when - :attr:`attn_mask_type` is set to `"causal"`. - Parameters ---------- - head_dim : int + head_dim: int The hidden dimension of each attention head. - num_heads : int - The number of attention heads - num_gqa_groups : int, default = `None` - Number of GQA groups. When `None` is present, it is equal to num_heads. + num_attention_heads: int + The number of attention heads. + num_gqa_groups: int, default = `None` + Number of GQA groups. When `None` is present, it is equal to num_attention_heads. Grouped Query Attention is described in `this paper `_. This only affects the keys and values, not the querys. GQA-1 is equivalent to Multi-Query Attention (`MQA `_), while GQA-H is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. - dropout_rate : float, default = 0.0 - Dropout probability for the dropout op during multi-head attention. + attention_dropout: float, default = 0.0 + Dropout probability for the dropout op after the softmax. + attn_mask_type: str, default = 'causal' + Type of the attention mask passed into softmax operation in the attention. + Available options: {'no_mask', 'padding', 'causal', 'causal_padding'} + Introduced in v0.10.0. + attn_bias_type: Optional[str], default = None + Type of the attention bias passed in the attention. + Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}. + When default is present, the type is automatically decided by the MHA's bias parameter. + Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used. dropout_rng_name: str, default = 'dropout' The key in given RNGs via flax.linen.Module.apply that is used to generate Dropout masks in the core attention. - layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm' + layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm' Indicate the type of layer normalization. layernorm_epsilon: float, default = 1e-6 A value added to the denominator of layer normalization for numerical stability. 
- zero_centered_gamma : bool, default = False + zero_centered_gamma: bool, default = False If set to `True`, the LayerNorm formula changes to .. math:: @@ -305,21 +621,20 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods This parameter is only applicable for 'layernorm'. kernel_init: Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal') - Used for initializing the QKV and Output projection weights. + Used for initializing the QKV and output projection weights. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). use_bias: bool, default = False - Indicate whether or not to enable bias shifting for QKVO projections. + Indicate whether or not to enable bias shifting for QKV and output projections. If set to False, the layer will not learn additive biases. bias_init: Initializer, default = flax.linen.initializers.zeros Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). - apply_residual_connection_post_layernorm : bool, default = False - Indicate if apply residual connection with the output of layer normalization. - output_layernorm : bool, default = False - Indicate if apply a layer normalization at the end of MHA. - attn_mask_type: {'causal', 'padding'}, default = 'causal' - Type of attention mask passed into softmax operation. - Introduced in v0.10.0. + input_layernorm: bool, default = True + If set to False, layer normalization to the input is not applied. + return_layernorm_output: bool, default = False + If set to True, output of layernorm is returned from the forward together with the output + of the linear transformation. + Example use case: residual connection for transformer module is taken post layernorm. enable_rotary_pos_emb: bool, default = False Whether to enable rotary position embedding to projected query and key. rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000) @@ -327,58 +642,101 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods only used when :attr:`enable_rotary_pos_emb=True` enable_sequence_parallel: bool, default = False Whether to enable sequence parallelism to operations except dot. + num_heads: int, default = None + Deprecated. Please refer `num_attention_heads`. + dropout_rate: float, default = None + Deprecated. Please refer `attention_dropout`. + output_layernorm: bool, default = None + Deprecated. Please refer `input_layernorm` + apply_residual_connection_post_layernorm: bool, default = None + Deprecated. Please refer `return_layernorm_output`. Optimization parameters ----------------------- - dtype :jax.numpy.dtype, default = jax.numpy.float32 + dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. - fuse_qkv: bool, default = True + fuse_qkv_params: bool, default = True If set to True, this module exposes a single fused parameter for query-key-value for self-attention and key-value for cross-attention. - transpose_batch_sequence : bool, default = True + transpose_batch_sequence: bool, default = True Indicate whether the input tensors were switched axis of batch and sequence length dimension. if set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). scale_attn_logits: bool, default = False Indicate whether to scale attention logits. 
-        If set to True, :math:`\frac{Q}{\sqrt{head_dim}*K}`,
+        If set to True, :math:`\frac{Q}{\sqrt{head\_dim}}*K`,
         else :math:`Q*K`
-    scaled_query_init: bool, default = `True`
-        Whether to scale WQ on initialization by :math:`\sqrt{head_dim}`
-    float32_logits : bool, default = False
-        Whether to compute attention logits in float32.
+    scaled_query_init: bool, default = True
+        Whether to scale WQ on initialization by :math:`\frac{1}{\sqrt{head\_dim}}`
+    float32_logits: bool, default = False
+        Whether to compute attention logits in float32 for the unfused attention backend.
+        For the fused attention backend, the accumulation is always float32 without the perf overhead.
+    fuse_qkv: bool, default = None
+        Deprecated. Please refer to `fuse_qkv_params`.
     """
     head_dim: int
-    num_heads: int
-    num_gqa_groups: int | None = None
-    dropout_rate: float = 0.
+    num_attention_heads: int
+    num_gqa_groups: Optional[int] = None
+    attention_dropout: float = 0.
     dropout_rng_name: str = 'dropout'
+    input_layernorm: bool = True
     layernorm_type: str = "layernorm"
     layernorm_epsilon: float = 1e-6
+    return_layernorm_output: bool = False
     zero_centered_gamma: bool = False
     kernel_init: Initializer = None
     use_bias: bool = False
     bias_init: Initializer = nn.initializers.zeros
-    apply_residual_connection_post_layernorm: bool = False
-    output_layernorm: bool = False
     attn_mask_type: str = 'causal'
+    attn_bias_type: Optional[str] = None
     enable_rotary_pos_emb: bool = False
     rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
     dtype: DType = jnp.float32
-    fuse_qkv: bool = True
+    fuse_qkv_params: bool = True
     transpose_batch_sequence: bool = True
     enable_sequence_parallel: bool = False
     scale_attn_logits: bool = False
     scaled_query_init: bool = True
-    float32_logits: bool = False    # computes logits in float32 for stability.
+    float32_logits: bool = False
+
+    # Deprecated parameters
+    num_heads: Optional[int] = None
+    dropout_rate: Optional[float] = None
+    output_layernorm: Optional[bool] = None
+    apply_residual_connection_post_layernorm: Optional[bool] = None
+    fuse_qkv: Optional[bool] = None

     def __post_init__(self):
+        # Deal with the deprecated parameters
+        if self.num_heads is not None:
+            self.num_attention_heads = self.num_heads
+            warnings.warn(
+                f"{__class__}.num_heads is deprecated. It will be removed in a future release. "
+                f"Please use {__class__}.num_attention_heads as the new API.", DeprecationWarning)
+        if self.dropout_rate is not None:
+            self.attention_dropout = self.dropout_rate
+            warnings.warn(
+                f"{__class__}.dropout_rate is deprecated. It will be removed in a future release. "
+                f"Please use {__class__}.attention_dropout as the new API.", DeprecationWarning)
+        if self.apply_residual_connection_post_layernorm is not None:
+            warnings.warn(
+                f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
+                f"It will be removed in a future release; please use {__class__}.return_layernorm_output.",
+                DeprecationWarning)
+        if self.fuse_qkv is not None:
+            warnings.warn(
+                f"{__class__}.fuse_qkv is deprecated. It will be removed in a future release. "
+                f"Please use {__class__}.fuse_qkv_params as the new API.", DeprecationWarning)
+        assert self.output_layernorm is None, (
+            f"{__class__}.output_layernorm is deprecated. It will be removed in a future release.
" + f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm.") + if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal') if self.num_gqa_groups is None: - self.num_gqa_groups = self.num_heads + self.num_gqa_groups = self.num_attention_heads super().__post_init__() @nn.compact @@ -396,23 +754,24 @@ def __call__(self, Parameters ---------- - inputs_q : jax.numpy.ndarray + inputs_q: jax.numpy.ndarray Input tensor for query projection. - inputs_kv : jax.numpy.ndarray + inputs_kv: jax.numpy.ndarray Input tensor for key/value projection. - mask : jax.numpy.ndarray, default = None - Boolean tensor used to mask out self-attention softmax input. - bias : jax.numpy.ndarray, default = None - A tensor used to shift self-attention softmax input. + mask: jax.numpy.ndarray, default = None + Boolean tensor used to mask out the attention softmax input. + :attr:`True` means mask out the corresponding values. + bias: jax.numpy.ndarray, default = None + A tensor used to shift the attention softmax input. * - decode : bool,default = False + decode: bool, default = False Indicate whether to prepare and use an autoregressive cache. - deterministic : bool,default = False + deterministic: bool, default = False Disable dropout layers if set to True. Returns ------- - outputs : jax.numpy.ndarray + outputs: jax.numpy.ndarray Output tensors. """ @@ -450,56 +809,6 @@ def kv_init(key, shape, dtype): return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype) - # TODO(rewang): make it configurable for pre_scale_bias - attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS - - def canonicalize_attn_mask_type(attn_mask_type): - """ - Convert the string to AttnMaskType - """ - if attn_mask_type == 'causal': - return AttnMaskType.PADDING_CAUSAL_MASK - if attn_mask_type == 'padding': - return AttnMaskType.PADDING_MASK - raise ValueError(f"Unsupported {attn_mask_type=}, " - "supported attn_mask_type = {'causal', 'padding'}") - - is_self_attn = (inputs_q is inputs_kv) - is_gqa = (self.num_heads != self.num_gqa_groups) - is_qkvpack = (is_self_attn and not is_gqa) - qkv_layout = QKVLayout.BS3HD if is_self_attn else QKVLayout.BSHD_BS2HD - attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type) - - q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1] - kv_seqlen = inputs_kv.shape[0] if self.transpose_batch_sequence else inputs_kv.shape[1] - enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0")) - - has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout, - attn_bias_type, attn_mask_type, - self.dropout_rate, self.num_heads, - self.num_gqa_groups, q_seqlen, - kv_seqlen, self.head_dim) - - use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \ - has_fused_attn_kernel and \ - enable_fused_attn - - if enable_fused_attn and not use_fused_attn: - reason = "" - if decode: - reason += f"decode=False is required but got {decode}, " - if self.transpose_batch_sequence: - reason += f"transpose_batch_sequence=False is required " \ - f"but got {self.transpose_batch_sequence}, " - if not self.fuse_qkv: - reason += f"fuse_qkv=True is required but got {self.fuse_qkv}, " - if not has_fused_attn_kernel: - reason += "no fused attention kernel is available, " - - warnings.warn( - f"Fused attention is not enabled. 
Because " \ - f"{reason}fall back to unfused attention.") - def generate_batch_seqlen_logical_axes(is_sharded_seq): sequence_dim = 0 if self.transpose_batch_sequence else 1 batch_dim = 1 - sequence_dim @@ -510,24 +819,27 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES return tuple(axes) + is_self_attn = (inputs_q is inputs_kv) + is_gqa = (self.num_attention_heads != self.num_gqa_groups) + is_qkvpack = (is_self_attn and not is_gqa) + inputs_logical_axes_maybe_sp = (*generate_batch_seqlen_logical_axes( self.enable_sequence_parallel), HIDDEN_AXES) inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES) inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp) - residual = inputs_q - if self.fuse_qkv: + if self.fuse_qkv_params: if is_qkvpack: qkv_proj, ln_out = LayerNormDenseGeneral( - enable_layernorm=not self.output_layernorm, + enable_layernorm=self.input_layernorm, layernorm_type=self.layernorm_type, zero_centered_gamma=self.zero_centered_gamma, epsilon=self.layernorm_epsilon, axis=-1, - features=(3, self.num_heads * self.head_dim), + features=(3, self.num_attention_heads * self.head_dim), transpose_batch_sequence=self.transpose_batch_sequence, - return_layernorm_output=self.apply_residual_connection_post_layernorm, + return_layernorm_output=self.return_layernorm_output, scale_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,), kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), @@ -540,19 +852,17 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): name='qkv', dtype=self.dtype)(inputs_q) qkv_proj = checkpoint_name(qkv_proj, 'combined_qkv_proj') - if not use_fused_attn: - query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2) + qkv_layout = QKVLayout.BS3HD else: query, ln_out = LayerNormDenseGeneral( - enable_layernorm=not self.output_layernorm, + enable_layernorm=self.input_layernorm, layernorm_type=self.layernorm_type, zero_centered_gamma=self.zero_centered_gamma, epsilon=self.layernorm_epsilon, axis=-1, - features=self.num_heads * self.head_dim, + features=self.num_attention_heads * self.head_dim, transpose_batch_sequence=self.transpose_batch_sequence, - return_layernorm_output=(self.apply_residual_connection_post_layernorm - or is_self_attn), + return_layernorm_output=(self.return_layernorm_output or is_self_attn), scale_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,), kernel_axes=(W_FSDP_AXES, W_TP_AXES), @@ -580,8 +890,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): name='kv', dtype=self.dtype)(inputs_kv) kv_proj = checkpoint_name(kv_proj, 'combined_kv_proj') - if not use_fused_attn: - key, value = jnp.split(kv_proj, [1], axis=-2) + qkv_layout = QKVLayout.BSHD_BS2HD else: kv_projection = functools.partial( DenseGeneral, @@ -594,12 +903,12 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): bias_axes=(W_TP_AXES,), dtype=self.dtype) query, ln_out = LayerNormDenseGeneral( - enable_layernorm=not self.output_layernorm, + enable_layernorm=self.input_layernorm, layernorm_type=self.layernorm_type, zero_centered_gamma=self.zero_centered_gamma, epsilon=self.layernorm_epsilon, axis=-1, - features=self.num_heads * self.head_dim, + features=self.num_attention_heads * self.head_dim, transpose_batch_sequence=self.transpose_batch_sequence, return_layernorm_output=True, scale_axes=(W_NO_SHARD_AXES,), @@ -620,44 +929,31 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): key = 
kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv) value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv) - - if self.apply_residual_connection_post_layernorm: - assert ln_out is not None - residual = ln_out + query = checkpoint_name(query, 'query_proj') + key = checkpoint_name(key, 'key_proj') + value = checkpoint_name(value, 'value_proj') + qkv_layout = QKVLayout.BSHD_BSHD_BSHD if self.enable_rotary_pos_emb: - if self.fuse_qkv and use_fused_attn: - if is_qkvpack: - query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2) - else: - key, value = jnp.split(kv_proj, [1], axis=-2) + if qkv_layout == QKVLayout.BS3HD: + query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2) + elif qkv_layout == QKVLayout.BSHD_BS2HD: + key, value = jnp.split(kv_proj, [1], axis=-2) + else: + assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD query = rotary_pos_emb(query, self.rotary_pos_emb_windows, self.transpose_batch_sequence) key = rotary_pos_emb(key, self.rotary_pos_emb_windows, self.transpose_batch_sequence) + qkv_layout = QKVLayout.BSHD_BSHD_BSHD - if use_fused_attn: - if is_qkvpack: - qkv_proj = jnp.concatenate([query, key, value], axis=-2) - else: - kv_proj = jnp.concatenate([key, value], axis=-2) - - if not use_fused_attn: - query = checkpoint_name(query, 'query_proj') - key = checkpoint_name(key, 'key_proj') - value = checkpoint_name(value, 'value_proj') - query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim)) + if qkv_layout == QKVLayout.BSHD_BSHD_BSHD: + query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim)) key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim)) value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim)) - qkv_sharding_constraint = \ - (SEQLEN_AXES, BATCH_AXES, HEAD_AXES, HIDDEN_AXES) \ - if self.transpose_batch_sequence \ - else (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES) - query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint) - key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint) - value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint) if decode: + assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD is_initialized = self.has_variable('cache', 'cached_key') cached_key = self.variable('cache', 'cached_key', jnp.zeros, key.shape, key.dtype) @@ -667,12 +963,12 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): lambda: jnp.array(0, dtype=jnp.int32)) if is_initialized: if self.transpose_batch_sequence: - length, batch, num_heads, head_dim = cached_key.value.shape - expected_shape = (1, batch, num_heads, head_dim) + length, batch, num_attention_heads, head_dim = cached_key.value.shape + expected_shape = (1, batch, num_attention_heads, head_dim) one_hot_indices_shape = (length, 1, 1, 1) else: - batch, length, num_heads, head_dim = cached_key.value.shape - expected_shape = (batch, 1, num_heads, head_dim) + batch, length, num_attention_heads, head_dim = cached_key.value.shape + expected_shape = (batch, 1, num_attention_heads, head_dim) one_hot_indices_shape = (1, length, 1, 1) # Sanity shape check of cached key against input query. 
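The decode branch above keeps cached_key/cached_value variables and advances a cache index by one token per call. A generic sketch of the single-token cache-update pattern this guards (an illustration of the idea, not the exact implementation used here; with transpose_batch_sequence the sequence axis leads and the index order changes accordingly):

    import jax.numpy as jnp
    from jax import lax

    def update_kv_cache(cached_key, cached_value, new_key, new_value, cur_index):
        # cached_*: [batch, max_length, heads, head_dim], pre-allocated with zeros
        # new_*:    [batch, 1, heads, head_dim], the projection of the current token
        start = (0, cur_index, 0, 0)
        cached_key = lax.dynamic_update_slice(cached_key, new_key, start)
        cached_value = lax.dynamic_update_slice(cached_value, new_value, start)
        return cached_key, cached_value, cur_index + 1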
@@ -694,100 +990,58 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length))) if bias is not None: + dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, + in_axes=(None, 0, None, None)) bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2) scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0 - dropout_rng = None - if not deterministic and self.dropout_rate > 0.: - dropout_rng = self.make_rng(self.dropout_rng_name) - - if use_fused_attn: - assert mask is not None and mask.ndim == 4 # (b, 1, s_q, s_kv) - assert not self.transpose_batch_sequence - - seed = None - if dropout_rng is not None: - seed = jax.random.split(dropout_rng, num_of_devices()) - # ensure the old key never used - del dropout_rng - - if is_qkvpack: - qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim)) - qkv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES, - HIDDEN_AXES) - qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj, - qkv_sharding_constraint) - - x = self_fused_attn(qkv_proj, - bias, - mask, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scale_factor, - dropout_probability=self.dropout_rate, - is_training=not deterministic) - else: - assert bias is None - query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim)) - kv_proj = kv_proj.reshape((*kv_proj.shape[:-1], self.num_gqa_groups, self.head_dim)) - q_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES) - kv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES, - HIDDEN_AXES) - query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint) - kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint) - - x = cross_fused_attn(query, - kv_proj, - bias, - mask, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scale_factor, - dropout_probability=self.dropout_rate, - is_training=not deterministic) + LEADING_AXES = (BATCH_AXES, SEQLEN_AXES) + if self.transpose_batch_sequence: + LEADING_AXES = (SEQLEN_AXES, BATCH_AXES) + + if qkv_layout == QKVLayout.BS3HD: + qkv_proj = qkv_proj.reshape(*qkv_proj.shape[:2], 3, self.num_attention_heads, + self.head_dim) + qkv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES) + qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj, qkv_sharding_constraint) + dpa_args = [qkv_proj, None, None] + elif qkv_layout == QKVLayout.BSHD_BS2HD: + query = query.reshape(*query.shape[:2], self.num_attention_heads, self.head_dim) + kv_proj = kv_proj.reshape(*kv_proj.shape[:2], 2, self.num_gqa_groups, self.head_dim) + q_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES) + kv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES) + query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint) + kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint) + dpa_args = [query, kv_proj, None] else: - - def convert_to_softmax_type(attn_mask_type, mask): - """ - Convert the string to SoftmaxType - """ - if attn_mask_type == 'causal': - return SoftmaxType.SCALED_UPPER_TRIANG_MASKED - if attn_mask_type == 'padding': - if mask is not None: - return SoftmaxType.SCALED_MASKED - return SoftmaxType.SCALED - raise ValueError(f"Unsupported {attn_mask_type=}, " - "supported 
attn_mask_type = {'causal', 'padding'}") - - softmax_type = convert_to_softmax_type(self.attn_mask_type, mask) - - x = core_attention(query, - key, - value, - scale_factor=scale_factor, - transpose_batch_sequence=self.transpose_batch_sequence, - softmax_type=softmax_type, - mask=mask, - bias=bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout_rate, - deterministic=deterministic, - dtype=self.dtype, - float32_logits=self.float32_logits) - - x = checkpoint_name(x, 'context') - + assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD + query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim)) + key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim)) + value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim)) + qkv_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES) + query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint) + key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint) + value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint) + dpa_args = [query, key, value] + + x = DotProductAttention(head_dim=self.head_dim, + num_attention_heads=self.num_attention_heads, + num_gqa_groups=self.num_gqa_groups, + attn_mask_type=self.attn_mask_type, + attn_bias_type=self.attn_bias_type, + attention_dropout=self.attention_dropout, + dtype=self.dtype, + dropout_rng_name=self.dropout_rng_name, + float32_logits=self.float32_logits, + qkv_layout=qkv_layout.name, + scale_factor=scale_factor, + transpose_batch_sequence=self.transpose_batch_sequence)( + *dpa_args, mask, bias, deterministic=deterministic) x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])) - attn_context_sharding_constraint = \ - (SEQLEN_AXES, BATCH_AXES, HIDDEN_TP_AXES) \ - if self.transpose_batch_sequence \ - else (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES) + attn_context_sharding_constraint = (*LEADING_AXES, HIDDEN_TP_AXES) x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint) out = DenseGeneral(features=inputs_q.shape[-1], @@ -801,7 +1055,8 @@ def convert_to_softmax_type(attn_mask_type, mask): dtype=self.dtype, name='out')(x) out = checkpoint_name(out, 'out_proj') - return out, residual + + return out, ln_out class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-methods @@ -810,21 +1065,21 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-met Parameters ---------- - num_buckets : int + num_buckets: int The number of buckets to bucket distances between key and query positions into. - max_distance : int + max_distance: int The maximum distance before everything is lumped into the last distance bucket. - num_attention_heads : int + num_attention_heads: int Number of attention heads in the transformer layer. - embedding_init : Initializer, default = flax.linen.linear.default_embed_init + embedding_init: Initializer, default = flax.linen.linear.default_embed_init Used for initializing relative embedding tables. - embedding_axes : Tuple[str, ...], default = ('heads', 'relpos_buckets') + embedding_axes: Tuple[str, ...], default = ('heads', 'relpos_buckets') The name of axes used to shard embedding attention bias with a corresponding mesh. Optimization parameters ----------------------- - dtype : jax.numpy.dtype, default = jax.numpy.float32 + dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. 
""" num_buckets: int @@ -841,11 +1096,11 @@ def __call__(self, q_seqlen, k_seqlen, bidirectional=True): Parameters ---------- - q_seqlen : int + q_seqlen: int The sequence length of query. - k_seqlen : int + k_seqlen: int The sequence length of key. - bidirectional : bool, default = True + bidirectional: bool, default = True Indicate whether to allow positive memory-query relative position embeddings. @@ -917,11 +1172,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods an attention block and a feedforward network (MLP). This standard layer is based on the paper “Attention Is All You Need”. - .. note:: - - Argument :attr:`attention_mask` will be ignored when - :attr:`self_attn_mask_type` is set to `"causal"`. - Parameters ---------- hidden_size: int, default = 512 @@ -930,7 +1180,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods Intermediate size to which input samples are projected. num_attention_heads: int, default = 8 Number of attention heads in the transformer layer. - num_gqa_groups : int, default = `None` + num_gqa_groups: int, default = `None` Number of GQA groups. When `None` is present, it is equal to num_attention_heads. Grouped Query Attention is described in `this paper `_. @@ -938,11 +1188,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods GQA-1 is equivalent to Multi-Query Attention (`MQA `_), while GQA-H is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. - layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm' + layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm' Indicate the type of layer normalization. layernorm_epsilon: float, default = 1e-6 A value added to the denominator of layer normalization for numerical stability. - zero_centered_gamma : bool, default = False + zero_centered_gamma: bool, default = False If set to `True`, the LayerNorm formula changes to .. math:: @@ -989,14 +1239,21 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods after the final dropout-add. default behavior is to apply layer normalization on the input side, before the QKV transformation. float32_attention_logits: bool, default = False - If set to True, attention logits are executed in jax.numpy.float32. + Whether to compute attention logits in float32 for the unfused attention backend. + For fused attention backend, the accumulation is always float32 without the perf overhead. layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER If set to TransformerLayerType.DECODER, an additional cross-attention block is added after self-attention.this can be used for structures like `T5` Transformer in conjunction with the TransformerLayerType.ENCODER option. - self_attn_mask_type: {'causal', 'padding'}, default = 'causal' - Type of attention mask passed into softmax operation. + self_attn_mask_type: str, default = 'causal' + Type of the attention mask passed into softmax operation in the self attention. + Available options: {'no_mask', 'padding', 'causal', 'causal_padding'} Introduced in v0.10.0. + self_attn_bias_type: Optional[str], default = None + Type of the attention bias passed into the self attention. + Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}. + When default is present, the type is automatically decided by the MHA's bias parameter. + Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used. 
enable_relative_embedding: bool, default = True Whether to enable relative embedding as shifting of attention logits. relative_embedding: flax.linen.Module, default = None @@ -1017,7 +1274,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods Optimization parameters ----------------------- - dtype :jax.numpy.dtype, default = jax.numpy.float32 + dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. drop_path: float, default = 0.0 When > 0.0, applies stochastic depth per sample in the main @@ -1026,7 +1283,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods If set to True, `TransformerLayer` module exposes a single fused parameter for query-key-value for self-attention and key-value for cross-attention. - transpose_batch_sequence : bool, default = False + transpose_batch_sequence: bool, default = False Indicate whether the input tensors were switched axis of batch and sequence length dimension. if set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). @@ -1041,7 +1298,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods hidden_size: int = 512 mlp_hidden_size: int = 2048 num_attention_heads: int = 8 - num_gqa_groups: int | None = None + num_gqa_groups: Optional[int] = None layernorm_type: str = 'layernorm' layernorm_epsilon: float = 1e-6 zero_centered_gamma: bool = False @@ -1061,6 +1318,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods float32_attention_logits: bool = False layer_type: TransformerLayerType = TransformerLayerType.ENCODER self_attn_mask_type: str = 'causal' + self_attn_bias_type: Optional[str] = None enable_relative_embedding: bool = True relative_embedding: nn.Module = None enable_rotary_pos_emb: bool = False @@ -1097,29 +1355,29 @@ def __call__(self, Parameters ---------- - inputs : jax.numpy.ndarray + inputs: jax.numpy.ndarray Input tensor. - encoded : jax.numpy.ndarray, default = None + encoded: jax.numpy.ndarray, default = None Output tensors of the encoder block to be fed into the decoder block if using :attr:`layer_type=TransformerLayerType.DECODER`. attention_mask : jax.numpy.ndarray, default = None Boolean tensor used to mask out self-attention softmax input. - encoder_decoder_mask : jax.numpy.ndarray, default = None + encoder_decoder_mask: jax.numpy.ndarray, default = None Boolean tensor used to mask out cross-attention softmax input when :attr:`layer_type=TransformerLayerType.DECODER`. deterministic: bool, default = False Disable dropout layers if set to True. - decode: bool,default = False + decode: bool, default = False Indicate whether to prepare and use an autoregressive cache in Multi-head attention (MHA). - max_decode_length : bool, default = None + max_decode_length: bool, default = None The maximum length to generate relative embedding biases when :attr:`layer_type=TransformerLayerType.DECODER` and :attr:`enable_relative_embedding=True`. Returns ------- - outputs : jax.numpy.ndarray + outputs: jax.numpy.ndarray Output tensors. 
""" assert self.layer_type in TransformerLayerType, \ @@ -1184,14 +1442,15 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)) # [batch, length, emb_dim] -> [batch, length, emb_dim] - x, residual = MultiHeadAttention( - num_heads=self.num_attention_heads, + residual = inputs + x, ln_out = MultiHeadAttention( + num_attention_heads=self.num_attention_heads, dtype=self.dtype, head_dim=head_dim, num_gqa_groups=self.num_gqa_groups, transpose_batch_sequence=self.transpose_batch_sequence, enable_sequence_parallel=self.enable_sequence_parallel, - dropout_rate=self.attention_dropout, + attention_dropout=self.attention_dropout, dropout_rng_name=self.dropout_rng_name, float32_logits=self.float32_attention_logits, scale_attn_logits=self.scale_attn_logits, @@ -1199,12 +1458,13 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): layernorm_type=self.layernorm_type, layernorm_epsilon=self.layernorm_epsilon, zero_centered_gamma=self.zero_centered_gamma, - apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm, - output_layernorm=self.output_layernorm, + return_layernorm_output=self.apply_residual_connection_post_layernorm, + input_layernorm=not self.output_layernorm, attn_mask_type=self.self_attn_mask_type, + attn_bias_type=self.self_attn_bias_type, enable_rotary_pos_emb=self.enable_rotary_pos_emb, rotary_pos_emb_windows=self.rotary_pos_emb_windows, - fuse_qkv=self.fuse_qkv_params, + fuse_qkv_params=self.fuse_qkv_params, kernel_init=self.mha_kernel_init, use_bias=self.use_bias, bias_init=self.bias_init, @@ -1236,6 +1496,11 @@ def hidden_dropout(x, deterministic): x = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape, rng_collection=self.dropout_rng_name)(x, deterministic=deterministic) + + if self.apply_residual_connection_post_layernorm: + assert ln_out is not None + residual = ln_out + x = x + residual mlp_input = x @@ -1246,28 +1511,29 @@ def hidden_dropout(x, deterministic): x = with_sharding_constraint_by_logical_axes( x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)) - y, residual = MultiHeadAttention( - num_heads=self.num_attention_heads, + residual = x + y, ln_out = MultiHeadAttention( + num_attention_heads=self.num_attention_heads, dtype=self.dtype, head_dim=head_dim, num_gqa_groups=self.num_gqa_groups, transpose_batch_sequence=self.transpose_batch_sequence, enable_sequence_parallel=self.enable_sequence_parallel, - dropout_rate=self.attention_dropout, + attention_dropout=self.attention_dropout, dropout_rng_name=self.dropout_rng_name, layernorm_type=self.layernorm_type, layernorm_epsilon=self.layernorm_epsilon, zero_centered_gamma=self.zero_centered_gamma, - apply_residual_connection_post_layernorm=self. - apply_residual_connection_post_layernorm, - output_layernorm=False, # Must do LayerNorm before MHA. + return_layernorm_output=self.apply_residual_connection_post_layernorm, + input_layernorm=True, # Must do LayerNorm before MHA. 
             attn_mask_type='padding',
+            attn_bias_type='no_bias',
             enable_rotary_pos_emb=self.enable_rotary_pos_emb,
             rotary_pos_emb_windows=self.rotary_pos_emb_windows,
             float32_logits=self.float32_attention_logits,
             scale_attn_logits=self.scale_attn_logits,
             scaled_query_init=self.scaled_query_init,
-            fuse_qkv=self.fuse_qkv_params,
+            fuse_qkv_params=self.fuse_qkv_params,
             kernel_init=self.mha_kernel_init,
             use_bias=self.use_bias,
             bias_init=self.bias_init,
@@ -1282,6 +1548,11 @@ def hidden_dropout(x, deterministic):
                 residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))

             y = hidden_dropout(y, deterministic)
+
+            if self.apply_residual_connection_post_layernorm:
+                assert ln_out is not None
+                residual = ln_out
+
             mlp_input = y + residual

             mlp_input = with_sharding_constraint_by_logical_axes(
@@ -1342,6 +1613,6 @@ def hidden_dropout(x, deterministic):
                       bias_axes=(W_NO_SHARD_AXES,),
                       transpose_batch_sequence=self.transpose_batch_sequence,
                       dtype=self.dtype,
-                      name="output_layer_norm")(z)
+                      name="output_layernorm")(z)

         return z
diff --git a/transformer_engine/jax/fused_attn.py b/transformer_engine/jax/fused_attn.py
index 7e6160f9e0..de8df768d1 100644
--- a/transformer_engine/jax/fused_attn.py
+++ b/transformer_engine/jax/fused_attn.py
@@ -16,6 +16,7 @@
 from .cpp_extensions import FusedAttnHelper
 from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
 from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
+from .cpp_extensions import fused_attn_fwd, fused_attn_bwd


 class AttnBiasType(Enum):
@@ -37,6 +38,21 @@ class QKVLayout(Enum):
     """QKV layout"""
     BS3HD = NVTE_QKV_Layout.NVTE_BS3HD
     BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD
+    BSHD_BSHD_BSHD = NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD
+
+
+def canonicalize_attn_mask_type(attn_mask_type: str):
+    """Convert a string attn_mask_type to AttnMaskType.
+    TE-JAX currently falls back to the padding-version kernels for the library integration.
+    The overhead between the padding and non-padding versions should be small.
+    However, we will lift this limitation in the near future.
+ """ + if attn_mask_type in ['causal', 'padding_causal']: + return AttnMaskType.PADDING_CAUSAL_MASK + if attn_mask_type in ['no_mask', 'padding']: + return AttnMaskType.PADDING_MASK + raise ValueError(f"Unsupported {attn_mask_type=}, " + "supported attn_mask_type={'no_mask', 'padding', 'causal', 'padding_causal'}") def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type, @@ -83,8 +99,12 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, scaling_factor: float, dropout_probability: float, is_training: bool): - mask = jnp.logical_not(mask) - actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,) + if mask is None: + batch, seqlen, *_ = qkv.shape + actual_seqlen = jnp.full((batch,), seqlen, dtype=jnp.int32) + else: + mask = jnp.logical_not(mask) + actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,) output, softmax_aux, rng_state = self_fused_attn_fwd(qkv, bias, actual_seqlen, @@ -159,14 +179,19 @@ def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training): - - mask = jnp.logical_not(mask) - q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,) - if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]: - kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,) + if mask is None: + batch, s_q, *_ = q.shape + s_kv = kv.shape[1] + q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32) + kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32) else: - # When mask is padding + causal, the actual seqlen is not the last row, use max to find it - kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2)) + mask = jnp.logical_not(mask) + q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,) + if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]: + kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,) + else: + # When mask is causal, the actual seqlen is not the last row, use max to find it + kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2)) output, softmax_aux, rng_state = cross_fused_attn_fwd(q, kv, @@ -179,7 +204,9 @@ def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mas scaling_factor=scaling_factor, dropout_probability=dropout_probability, is_training=is_training) - + output = checkpoint_name(output, 'context') + softmax_aux = checkpoint_name(softmax_aux, 'context') + rng_state = checkpoint_name(rng_state, 'context') return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen) @@ -209,3 +236,100 @@ def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, d _cross_fused_attn.defvjp(_cross_fused_attn_fwd_rule, _cross_fused_attn_bwd_rule) + + +def fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, + seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, + scaling_factor: float, dropout_probability: float, is_training: bool): + """ + Dot product attention with the seperated query, key, value + """ + + output = _fused_attn(q, + k, + v, + bias, + 
+
+
+@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10))
+def _fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray,
+                mask: jnp.ndarray, seed: jnp.ndarray, attn_bias_type: AttnBiasType,
+                attn_mask_type: AttnMaskType, scaling_factor: float, dropout_probability: float,
+                is_training: bool):
+
+    output, _ = _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type,
+                                     scaling_factor, dropout_probability, is_training)
+    return output
+
+
+def _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
+                         dropout_probability, is_training):
+    if mask is None:
+        batch, s_q, *_ = q.shape
+        s_kv = k.shape[1]
+        q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
+        kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
+    else:
+        mask = jnp.logical_not(mask)
+        q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
+        if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
+            kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
+        else:
+            # When mask is causal, the actual seqlen is not the last row, use max to find it
+            kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
+
+    output, softmax_aux, rng_state = fused_attn_fwd(q,
+                                                    k,
+                                                    v,
+                                                    bias,
+                                                    q_actual_seqlen,
+                                                    kv_actual_seqlen,
+                                                    seed,
+                                                    attn_bias_type=attn_bias_type.value,
+                                                    attn_mask_type=attn_mask_type.value,
+                                                    scaling_factor=scaling_factor,
+                                                    dropout_probability=dropout_probability,
+                                                    is_training=is_training)
+    output = checkpoint_name(output, 'context')
+    softmax_aux = checkpoint_name(softmax_aux, 'context')
+    rng_state = checkpoint_name(rng_state, 'context')
+    return output, (q, k, v, bias, softmax_aux, rng_state, output, q_actual_seqlen,
+                    kv_actual_seqlen)
+
+
+def _fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
+                         is_training, ctx, dz):
+    q, k, v, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx
+
+    grad_q, grad_k, grad_v, grad_bias = fused_attn_bwd(q,
+                                                       k,
+                                                       v,
+                                                       bias,
+                                                       softmax_aux,
+                                                       rng_state,
+                                                       output,
+                                                       dz,
+                                                       q_actual_seqlen,
+                                                       kv_actual_seqlen,
+                                                       attn_bias_type=attn_bias_type.value,
+                                                       attn_mask_type=attn_mask_type.value,
+                                                       scaling_factor=scaling_factor,
+                                                       dropout_probability=dropout_probability,
+                                                       is_training=is_training)
+
+    if attn_bias_type == AttnBiasType.NO_BIAS:
+        grad_bias = None
+
+    return grad_q, grad_k, grad_v, grad_bias, None, None
+
+
+_fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule)
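[Editor's note: illustration only, not part of the patch] Because _fused_attn is registered as a jax.custom_vjp, the public fused_attn entry point is differentiable end to end. The sketch below takes gradients with respect to q, k and v; the shapes, the bias=None-with-NO_BIAS convention, and the PRNGKey-as-seed handling follow the unit tests and are assumptions, and the fused kernel must actually be available for these settings on the target GPU.

    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, fused_attn

    b, s, h, d = 2, 128, 16, 64                     # illustrative shapes
    q = jnp.ones((b, s, h, d), jnp.bfloat16)
    k = jnp.ones((b, s, h, d), jnp.bfloat16)
    v = jnp.ones((b, s, h, d), jnp.bfloat16)
    mask = jnp.zeros((b, 1, s, s), jnp.uint8)       # 0 = attend, nonzero = padded out
    seed = jax.random.PRNGKey(0)

    def loss(q, k, v):
        out = fused_attn(q, k, v, None, mask, seed,
                         attn_bias_type=AttnBiasType.NO_BIAS,
                         attn_mask_type=AttnMaskType.PADDING_MASK,
                         scaling_factor=d**-0.5,
                         dropout_probability=0.0,
                         is_training=True)
        return jnp.mean(out.astype(jnp.float32))

    # grad_q/k/v flow through _fused_attn_bwd_rule; mask and seed receive no gradient.
    dq, dk, dv = jax.grad(loss, argnums=(0, 1, 2))(q, k, v)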
diff --git a/transformer_engine/jax/praxis/__init__.py b/transformer_engine/jax/praxis/__init__.py
index 6da6ca5f4d..5be51a6d71 100644
--- a/transformer_engine/jax/praxis/__init__.py
+++ b/transformer_engine/jax/praxis/__init__.py
@@ -4,5 +4,6 @@
 """Praxis related Modules"""
 from .module import FusedSoftmax, LayerNorm
 from .module import LayerNormLinear, LayerNormMLP, Linear, TransformerEngineBaseLayer
-from .transformer import MultiHeadAttention, RelativePositionBiases, TransformerLayer
+from .transformer import DotProductAttention, MultiHeadAttention
+from .transformer import RelativePositionBiases, TransformerLayer
 from ..flax.transformer import TransformerLayerType
diff --git a/transformer_engine/jax/praxis/transformer.py b/transformer_engine/jax/praxis/transformer.py
index e98d2c422c..053b7768e4 100644
--- a/transformer_engine/jax/praxis/transformer.py
+++ b/transformer_engine/jax/praxis/transformer.py
@@ -6,6 +6,7 @@
 """
 from functools import partial
 from typing import Optional, Sequence, Tuple
+import warnings
 
 from praxis import pax_fiddle
 from praxis.base_layer import WeightInit
@@ -13,9 +14,11 @@
 
 from .module import TransformerEngineBaseLayer
 from ..flax.transformer import TransformerLayerType
+from ..flax.transformer import DotProductAttention as flax_DotProductAttention
 from ..flax.transformer import MultiHeadAttention as flax_MultiHeadAttention
 from ..flax.transformer import RelativePositionBiases as flax_RelativePositionBiases
 from ..flax.transformer import TransformerLayer as flax_TransformerLayer
+from ..fused_attn import AttnBiasType, AttnMaskType
 
 
 class RelativePositionBiases(TransformerEngineBaseLayer):
@@ -59,30 +62,117 @@ def __call__(self, q_seqlen: JTensor, k_seqlen: JTensor, bidirectional: bool = T
         return self.relative_position_bias(q_seqlen, k_seqlen, bidirectional)
 
 
+class DotProductAttention(TransformerEngineBaseLayer):
+    """DotProductAttention"""
+
+    head_dim: int = 0
+    num_attention_heads: int = 0
+    num_gqa_groups: Optional[int] = None
+    attention_dropout: float = 0.
+    attn_mask_type: AttnMaskType = 'causal'
+    attn_bias_type: AttnBiasType = None
+    dropout_rng_name: str = 'dropout'
+    float32_logits: bool = False
+    qkv_layout: str = 'bshd_bshd_bshd'
+    scale_factor: Optional[float] = None
+    transpose_batch_sequence: bool = True
+
+    def setup(self) -> None:
+        """setup"""
+        super().setup()
+
+        assert self.head_dim > 0, f'{self.head_dim=}'
+        assert self.num_attention_heads > 0, f'{self.num_attention_heads=}'
+
+        dpa_cls = partial(flax_DotProductAttention,
+                          head_dim=self.head_dim,
+                          num_attention_heads=self.num_attention_heads,
+                          num_gqa_groups=self.num_gqa_groups,
+                          attn_mask_type=self.attn_mask_type,
+                          attn_bias_type=self.attn_bias_type,
+                          attention_dropout=self.attention_dropout,
+                          dtype=self.dtype,
+                          dropout_rng_name=self.dropout_rng_name,
+                          float32_logits=self.float32_logits,
+                          qkv_layout=self.qkv_layout,
+                          scale_factor=self.scale_factor,
+                          transpose_batch_sequence=self.transpose_batch_sequence)
+
+        self.create_layer("dot_product_attention", dpa_cls)
+
+    def __call__(self,
+                 query: JTensor,
+                 key: JTensor,
+                 value: JTensor,
+                 mask: Optional[JTensor] = None,
+                 bias: Optional[JTensor] = None,
+                 *,
+                 deterministic: bool = False) -> JTensor:
+        """__call__"""
+        return self.dot_product_attention(query,
+                                          key,
+                                          value,
+                                          mask,
+                                          bias,
+                                          deterministic=deterministic)
+
+
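[Editor's note: illustration only, not part of the patch] One way the new Praxis DotProductAttention wrapper might be configured. The pax_fiddle.Config / Instantiate pattern and the concrete field values are illustrative assumptions; only the field names come from the class above.

    from praxis import pax_fiddle
    from transformer_engine.jax.praxis import DotProductAttention

    # head_dim and num_attention_heads are required (asserted > 0 in setup()).
    dpa_p = pax_fiddle.Config(DotProductAttention,
                              name='dpa',
                              head_dim=64,
                              num_attention_heads=16,
                              attn_mask_type='padding',
                              attn_bias_type='no_bias',
                              attention_dropout=0.0,
                              qkv_layout='bshd_bshd_bshd',
                              transpose_batch_sequence=False)
    dpa = dpa_p.Instantiate()
    # dpa.init(...) / dpa.apply(...) then follow the usual Praxis/Flax flow with (query, key, value, mask).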
     dropout_rng_name: str = 'dropout'
+    input_layernorm: bool = True
     layernorm_type: str = "layernorm"
     layernorm_epsilon: float = 1e-6
     zero_centered_gamma: bool = False
+    return_layernorm_output: bool = False
     use_bias: bool = False
     bias_init: WeightInit = WeightInit.Constant(0.0)
-    apply_residual_connection_post_layernorm: bool = False
-    output_layernorm: bool = False
     attn_mask_type: str = 'causal'
-    fuse_qkv: bool = True
+    attn_bias_type: Optional[str] = None
+    fuse_qkv_params: bool = True
     transpose_batch_sequence: bool = True
     enable_sequence_parallel: bool = False
     scale_attn_logits: bool = False
     scaled_query_init: bool = True
     float32_logits: bool = False
+    # Deprecated parameters
+    num_heads: Optional[int] = None
+    dropout_rate: Optional[float] = None
+    output_layernorm: Optional[bool] = None
+    apply_residual_connection_post_layernorm: Optional[bool] = None
+    fuse_qkv: Optional[bool] = None
+
+    def __post_init__(self):
+        # Handle the deprecated parameters
+        if self.num_heads is not None:
+            self.num_attention_heads = self.num_heads
+            warnings.warn(
+                f"{__class__}.num_heads is deprecated. It will be removed in a future release. "
+                f"Please use {__class__}.num_attention_heads as the new API.", DeprecationWarning)
+        if self.dropout_rate is not None:
+            self.attention_dropout = self.dropout_rate
+            warnings.warn(
+                f"{__class__}.dropout_rate is deprecated. It will be removed in a future release. "
+                f"Please use {__class__}.attention_dropout as the new API.", DeprecationWarning)
+        if self.apply_residual_connection_post_layernorm is not None:
+            warnings.warn(
+                f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
+                f"It will be removed in a future release; please use "
+                f"{__class__}.return_layernorm_output.", DeprecationWarning)
+        if self.fuse_qkv is not None:
+            self.fuse_qkv_params = self.fuse_qkv
+            warnings.warn(
+                f"{__class__}.fuse_qkv is deprecated. It will be removed in a future release. "
+                f"Please use {__class__}.fuse_qkv_params as the new API.", DeprecationWarning)
+        assert self.output_layernorm is None, (
+            f"{__class__}.output_layernorm is deprecated. It will be removed in a future release. "
+            f"Please use {__class__}.input_layernorm to control whether to apply layernorm.")
+
         if self.num_gqa_groups is None:
             self.num_gqa_groups = self.num_heads
         super().__post_init__()
@@ -91,24 +181,28 @@ def setup(self) -> None:
         """setup"""
         super().setup()
 
+        assert self.head_dim > 0, f'{self.head_dim=}'
+        assert self.num_attention_heads > 0, f'{self.num_attention_heads=}'
+
         mha_cls = partial(
             flax_MultiHeadAttention,
             dtype=self.dtype,
             head_dim=self.head_dim,
-            num_heads=self.num_heads,
+            num_attention_heads=self.num_attention_heads,
             num_gqa_groups=self.num_gqa_groups,
-            dropout_rate=self.dropout_rate,
+            attention_dropout=self.attention_dropout,
             dropout_rng_name=self.dropout_rng_name,
+            input_layernorm=self.input_layernorm,
             layernorm_type=self.layernorm_type,
             layernorm_epsilon=self.layernorm_epsilon,
             zero_centered_gamma=self.zero_centered_gamma,
+            return_layernorm_output=self.return_layernorm_output,
             kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
             use_bias=self.use_bias,
             bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
-            apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
-            output_layernorm=self.output_layernorm,
             attn_mask_type=self.attn_mask_type,
-            fuse_qkv=self.fuse_qkv,
+            attn_bias_type=self.attn_bias_type,
+            fuse_qkv_params=self.fuse_qkv_params,
             transpose_batch_sequence=self.transpose_batch_sequence,
             enable_sequence_parallel=self.enable_sequence_parallel,
             scale_attn_logits=self.scale_attn_logits,
@@ -140,7 +234,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
     hidden_size: int = 512
     mlp_hidden_size: int = 2048
     num_attention_heads: int = 8
-    num_gqa_groups: int | None = None
+    num_gqa_groups: Optional[int] = None
     layernorm_type: str = 'layernorm'
     layernorm_epsilon: float = 1e-6
     zero_centered_gamma: bool = False
@@ -158,6 +252,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
     float32_attention_logits: bool = False
     layer_type: TransformerLayerType = TransformerLayerType.ENCODER
     self_attn_mask_type: str = 'causal'
+    self_attn_bias_type: Optional[str] = None
     enable_rotary_pos_emb: bool = False
     rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
     enable_relative_embedding: bool = True
@@ -226,6 +321,7 @@ def setup(self) -> None:
             float32_attention_logits=self.float32_attention_logits,
             layer_type=self.layer_type,
             self_attn_mask_type=self.self_attn_mask_type,
+            self_attn_bias_type=self.self_attn_bias_type,
             enable_rotary_pos_emb=self.enable_rotary_pos_emb,
             rotary_pos_emb_windows=self.rotary_pos_emb_windows,
             enable_relative_embedding=self.enable_relative_embedding,