Merge branch 'NVIDIA:main' into fused_out_correction
xiaoyao0115 authored Dec 24, 2024
2 parents 89bbeb7 + 838345e commit d346d9c
Showing 30 changed files with 950 additions and 385 deletions.
1 change: 1 addition & 0 deletions .github/workflows/trigger-ci.yml
@@ -42,6 +42,7 @@ jobs:
|| github.actor == 'kocchop'
|| github.actor == 'youngeunkwon0405'
|| github.actor == 'KshitijLakhani'
|| github.actor == 'jberchtold-nvidia'
)
steps:
- name: Check if comment is issued by authorized person
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 43 files
+1 −1 CMakeLists.txt
+10 −0 docs/operations/Attention.md
+3 −2 include/cudnn_backend_base.h
+1 −0 include/cudnn_frontend.h
+24 −2 include/cudnn_frontend/graph_helpers.h
+28 −0 include/cudnn_frontend/graph_interface.h
+32 −1 include/cudnn_frontend/graph_properties.h
+6 −0 include/cudnn_frontend/node/paged_cache_load.h
+3 −0 include/cudnn_frontend/node/resample.h
+372 −481 include/cudnn_frontend/node/scaled_dot_product_flash_attention.h
+4 −1 include/cudnn_frontend/node/sdpa_fp8.h
+5 −1 include/cudnn_frontend/node/sdpa_fp8_bwd.h
+7 −3 include/cudnn_frontend/plans.h
+387 −0 include/cudnn_frontend/utils/attn_score_modifiers.h
+3 −3 include/cudnn_frontend_EngineFallbackList.h
+3 −3 include/cudnn_frontend_ExecutionPlan.h
+3 −4 include/cudnn_frontend_Operation.h
+1 −1 include/cudnn_frontend_OperationGraph.h
+3 −4 include/cudnn_frontend_get_plan.h
+2 −0 include/cudnn_frontend_shim.h
+1 −1 include/cudnn_frontend_utils.h
+1 −1 include/cudnn_frontend_version.h
+2 −2 pyproject.toml
+1 −1 python/cudnn/__init__.py
+16 −0 python/pygraph/pygraph.cpp
+3 −0 python/pygraph/pygraph.h
+2 −2 python/pygraph/sdpa.cpp
+3 −0 samples/cpp/CMakeLists.txt
+205 −0 samples/cpp/convolution/conv_dynamic_shape_benchmark.cpp
+2 −1 samples/cpp/convolution/fp8_fprop.cpp
+4 −0 samples/cpp/convolution/fprop.cpp
+5 −1 samples/cpp/convolution/wgrads.cpp
+144 −0 samples/cpp/norm/layernorm.cpp
+207 −0 samples/cpp/sdpa/fp16_bwd_with_flexible_graphs.cpp
+198 −0 samples/cpp/sdpa/fp16_fwd_with_flexible_graphs.cpp
+1 −1 samples/cpp/utils/helpers.h
+5 −3 samples/legacy_samples/fp16_emu.cpp
+1 −1 samples/legacy_samples/helpers.cpp
+5 −0 samples/legacy_samples/test_list.cpp
+3 −1 samples/python/50_scaled_dot_product_attention.ipynb
+5 −3 samples/python/52_scaled_dot_product_attention_with_paged_caches.ipynb
+7 −0 test/python/test_conv_bias.py
+112 −60 test/python/test_mhas.py
4 changes: 2 additions & 2 deletions examples/pytorch/comm_gemm_overlap/README.md
@@ -16,7 +16,7 @@
Forward and backward passes with layer weights distributed over all GPUs in a single node.

```bash
$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_with_overlap.py
$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_layer_with_overlap.py

# Sample output on 8x H100s:
# [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3, 4, 5, 6, 7]
@@ -70,7 +70,7 @@ Uses `torch.nn.parallel.DistributedDataParallel` for replicating the model across
groups in a single node.

```bash
$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_overlap.py --num-replicas 2
$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_layer_with_overlap.py --num-replicas 2

# Sample output on 8x H100s:
# [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3]
15 changes: 15 additions & 0 deletions qa/L0_jax_distributed_unittest/test.sh
@@ -0,0 +1,15 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

set -xe

: ${TE_PATH:=/opt/transformerengine}

pip install -r $TE_PATH/examples/jax/encoder/requirements.txt

# Make encoder tests run-to-run deterministic for stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
3 changes: 1 addition & 2 deletions qa/L0_jax_unittest/test.sh
@@ -20,5 +20,4 @@ pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist

# Make encoder tests run-to-run deterministic for stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py
2 changes: 1 addition & 1 deletion qa/L0_pytorch_unittest/test.sh
@@ -13,7 +13,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py
@@ -22,3 +21,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py
NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
1 change: 1 addition & 0 deletions qa/L1_pytorch_distributed_unittest/test.sh
@@ -11,4 +11,5 @@ pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
2 changes: 1 addition & 1 deletion tests/jax/conftest.py
@@ -20,7 +20,7 @@ def clear_live_arrays():


@pytest.fixture(autouse=True, scope="module")
def enable_fused_attn():
def enable_fused_attn_after_hopper():
"""
Enable fused attn for hopper+ arch.
Fused attn kernels on pre-hopper arch are not deterministic.
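For context on the renamed fixture above, here is a minimal sketch of what an architecture-gated autouse fixture could look like; the compute-capability probe and the `NVTE_FUSED_ATTN` environment toggle are assumptions for illustration, not the exact upstream implementation.

```python
import os

import jax
import pytest


def _is_hopper_or_newer() -> bool:
    """Assumed detection helper: Hopper GPUs report compute capability 9.0+."""
    device = jax.devices()[0]
    compute_capability = getattr(device, "compute_capability", None)  # e.g. "9.0" on H100
    return compute_capability is not None and float(compute_capability) >= 9.0


@pytest.fixture(autouse=True, scope="module")
def enable_fused_attn_after_hopper():
    """Enable fused attn for Hopper+ arch only; pre-Hopper kernels are not deterministic."""
    if _is_hopper_or_newer():
        os.environ["NVTE_FUSED_ATTN"] = "1"
    yield
    os.environ.pop("NVTE_FUSED_ATTN", None)
```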
6 changes: 2 additions & 4 deletions tests/jax/test_distributed_fused_attn.py
@@ -20,7 +20,6 @@
from utils import (
make_causal_mask,
make_self_mask,
assert_tree_like_allclose,
assert_allclose,
print_debug_tensor_stats,
)
@@ -32,7 +31,6 @@
AttnMaskType,
QKVLayout,
QKVFormat,
get_qkv_format,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
CPStrategy,
@@ -421,7 +419,7 @@ def impl_test_contex_parallel_attn(
dropout_prob = 0.0
is_training = True
dp_size, cp_size, tp_size = mesh_shape
qkv_format = get_qkv_format(qkv_layout)
qkv_format = qkv_layout.get_qkv_format()

batch, seqlen, num_head, hidden = data_shape

@@ -503,7 +501,7 @@ def grad_func(func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the gradient
_, max_seq_len, num_heads, _ = data_shape
gradient_multiplier = max_seq_len * num_heads
if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK]:
if attn_mask_type.is_causal():
gradient_multiplier /= 10
ret_valid = func(*args, **kwargs)
return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype)
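For readers tracking the API shift in this test file, a small hedged sketch of the two updated call patterns; the import path and the `BSHD_BSHD_BSHD` enum member are assumptions used only for illustration, not taken from the diff.

```python
from transformer_engine.jax.attention import AttnMaskType, QKVLayout

# QKV format lookup moves from a free function to a method on the layout enum:
#   old: qkv_format = get_qkv_format(qkv_layout)
qkv_layout = QKVLayout.BSHD_BSHD_BSHD  # assumed member name for illustration
qkv_format = qkv_layout.get_qkv_format()

# The explicit mask-type list collapses into a helper method:
#   old: if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK]:
attn_mask_type = AttnMaskType.CAUSAL_MASK
if attn_mask_type.is_causal():
    print(f"{qkv_format}: causal mask, reduced gradient multiplier")
```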