Improve sm90 mixed dtype kernel (#1883)
sklevtsov-nvidia authored Oct 18, 2024
1 parent 755194a commit 08101d9
Showing 11 changed files with 992 additions and 78 deletions.
657 changes: 657 additions & 0 deletions examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu

Large diffs are not rendered by default.

61 changes: 52 additions & 9 deletions examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu
@@ -47,6 +47,14 @@
Note that in this example, we explicitly swap A and B in order to use TMA epilogues. We do this since TMA epilogues are more performant on problem sizes of interest.
As an additional optimization, we can reorder the narrow data type tensor such that elements read into the register file by the same thread are contiguous in global and shared memory.
This promotes vectorization of shared memory loads and removes additional instructions on the critical path. For example, when the MMA is performed in the FP8 data type, each thread reads
4 groups of 4 elements that are logically contiguous in the same row (refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n32-a for the thread-value layout).
If the narrow type is INT4 and the tensor is K-major, only 16 bits can be read at a time, leading to extra load instructions and suboptimal utilization of shared memory throughput.
If we reorder the data offline so that all 16 elements read by a thread are contiguous in memory, a single 64-bit load is sufficient. This reordering is often feasible when the quantized
tensor is static (e.g. the weight tensor of an NN layer at inference time). This example demonstrates how such a reordering can be performed and communicated to the kernel when the macro
OPTIMIZE_WEIGHT_LAYOUT is set to 1.
It is expected that the scale's K dimension be scale_k = ceil_div(problem_k, group_size).
Scales are always expected to be MN major. This means the fastest changing dimension must be M if A is scaled or N if B is scaled.
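
A quick numeric check of the two statements above, as a standalone sketch (the problem_k and group_size values below are illustrative, not taken from this example):

#include <cstdio>

int main() {
  // 16 INT4 values per thread fragment: 16 * 4 bits = 64 bits, i.e. one
  // 64-bit shared memory load once the values are placed contiguously.
  int fragment_bits = 16 * 4;

  // scale_k = ceil_div(problem_k, group_size)
  int problem_k = 4096, group_size = 128;
  int scale_k = (problem_k + group_size - 1) / group_size;

  printf("fragment bits: %d, scale_k: %d\n", fragment_bits, scale_k);
  return 0;
}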
@@ -104,9 +112,12 @@
#include "helper.h"
#include "unfused_weight_dequantize.hpp"
#include "packed_scale.hpp"
#include "reorder_utils.hpp"

using namespace cute;

#define OPTIMIZE_WEIGHT_LAYOUT 1

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -130,6 +141,17 @@ constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // M
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;

using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;

#if OPTIMIZE_WEIGHT_LAYOUT
// Define the CuTe layout for the reordered quantized tensor B
// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory.
// It specifies the reordering within a single warp's fragment.
using LayoutAtomQuant = decltype(compute_memory_reordering_atom<MmaType>());
using LayoutB_Reordered = decltype(tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,int>, StrideB>{}));
#endif

using ElementScale = MmaType;
using ElementZero = ElementScale; // only for verify
using LayoutScale = cutlass::layout::RowMajor;
@@ -172,7 +194,11 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui
// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information.
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
cute::tuple<ElementB, cutlass::Array<ElementScale, 8> >, LayoutB_Transpose, AlignmentB,
#if OPTIMIZE_WEIGHT_LAYOUT
cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>, LayoutB_Reordered, AlignmentB,
#else
cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>, LayoutB_Transpose, AlignmentB,
#endif
ElementA, LayoutA_Transpose, AlignmentA,
ElementAccumulator,
TileShape, ClusterShape,
@@ -190,8 +216,6 @@ using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<

using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;

using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
using StrideC = typename GemmKernelScaleOnly::StrideC;
using StrideD = typename GemmKernelScaleOnly::StrideD;

@@ -211,6 +235,10 @@ StrideD stride_D;
StrideD_ref stride_D_ref;
uint64_t seed;

#if OPTIMIZE_WEIGHT_LAYOUT
LayoutB_Reordered layout_B_reordered;
#endif

using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
StrideS stride_S;
@@ -399,7 +427,7 @@ bool unify_quant_encoding(
d = out;
}

cutlass::device_memory::copy_to_device((uint8_t*)block_out.get(), data.data(), block_out.size() / 2);
cutlass::device_memory::copy_to_device((StorageType*)block_out.get(), data.data(), block_out.size() / pack);
return true;
}
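
For reference, a small sketch of the pack factor implied by the corrected copy size above (assuming the quantized type is cutlass::int4b_t stored in uint8_t words, as in this example; the tensor size below is illustrative):

#include <cstdio>
#include "cutlass/numeric_types.h"

int main() {
  // Number of quantized values packed into one storage element: 8 / 4 = 2.
  constexpr int pack = cutlass::sizeof_bits<uint8_t>::value /
                       cutlass::sizeof_bits<cutlass::int4b_t>::value;
  int num_values = 1024;  // illustrative number of quantized values
  printf("pack = %d, storage elements to copy = %d\n", pack, num_values / pack);
  return 0;
}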

@@ -461,17 +489,19 @@ bool initialize_zero(
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(Options const& options) {

auto shape_b = cute::make_shape(options.n, options.k, options.l);
auto shape_B = cute::make_shape(options.n, options.k, options.l);
int const scale_k = (options.k + options.g - 1) / options.g;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b);
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
// Reverse stride here due to swap and transpose
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l));
stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l));
// Reverse stride here due to swap and transpose
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l));
stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l));

auto layout_B = make_layout(shape_B, stride_B);

auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
@@ -496,14 +526,22 @@ void initialize(Options const& options) {
initialize_packed_scale(block_scale, block_scale_packed);
initialize_zero(block_zero, options);

auto layout_B = make_layout(shape_b, stride_B);

auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l);
stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l));
stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);

dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);

#if OPTIMIZE_WEIGHT_LAYOUT
// Repeat the reorder layout atom to tile the whole tensor shape
layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B);
reorder_tensor(block_B_modified.get(), layout_B, layout_B_reordered);

print("Quantized tensor layout: ");
print(layout_B_reordered);
print("\n");
#endif
}

/// Populates a Gemm::Arguments structure from the given commandline options
@@ -515,7 +553,11 @@ Args args_from_options(Options const& options)
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
{block_B_modified.get(), stride_B, block_A.get(), stride_A, block_scale_packed.get(), stride_S, options.g},
#if OPTIMIZE_WEIGHT_LAYOUT
{block_B_modified.get(), layout_B_reordered, block_A.get(), stride_A, block_scale_packed.get(), stride_S, options.g},
#else
{block_B_modified.get(), stride_B, block_A.get(), stride_A, block_scale_packed.get(), stride_S, options.g},
#endif
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
}
@@ -581,6 +623,7 @@ bool verify(Options const& options) {
ElementD const epsilon(1e-2f);
ElementD const non_zero_floor(1e-4f);
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);

return passed;
}

11 changes: 11 additions & 0 deletions examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt
@@ -68,3 +68,14 @@ cutlass_example_add_executable(
TEST_SCALE_RESIDUE
# TEST_ALPHA_BETA
)

cutlass_example_add_executable(
55_hopper_int4_bf16_gemm
55_hopper_int4_bf16_gemm.cu
TEST_COMMAND_OPTIONS
TEST_DIRECT_BATCHED
TEST_SCALE_PERCOL
TEST_SCALE_GROUP
TEST_SCALE_RESIDUE
# TEST_ALPHA_BETA
)
1 change: 0 additions & 1 deletion examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp
@@ -31,7 +31,6 @@

#pragma once

#include <iostream>
#include <cstdint>

#include "cutlass/float8.h"
122 changes: 122 additions & 0 deletions examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp
@@ -0,0 +1,122 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/

#include "cute/layout.hpp"
#include "cute/tensor.hpp"

#include "cutlass/util/device_memory.h"

// Given a type of MMA instruction, compute a memory reordering atom that places all values
// owned by each thread in contiguous memory locations. This improves smem load vectorization,
// particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order
// of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses.
template<class MmaType>
auto compute_memory_reordering_atom()
{
using namespace cute;

// 1. Choose an MMA atom to access TV layout and MN shape
// Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary
using MmaAtom = decltype(SM90::GMMA::rs_op_selector<MmaType, MmaType, float, Shape<_64,_16,_32>>());
using MmaTraits = MMA_Traits<MmaAtom>;
auto shape_MK = select<0,2>(typename MmaTraits::Shape_MNK{});
auto tv_layout_mma = typename MmaTraits::ALayout{};

// 2. Create a single warp's TV layout from that of the whole MMA
// Note: this assumes A is partitioned between warps along M mode
auto tile_TV_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma));
auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tile_TV_warp));

// 3. Invert warp's TV layout to get MK layout (m,k -> thr,val)
auto shape_MK_warp = shape_div(shape_MK, size(typename MmaTraits::ThrID{}) / Int<32>{});
auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(shape_MK_warp);

// 4. Compose with a contiguous layout of values in each thread (required for smem vectorization)
auto tv_to_offset = make_ordered_layout(shape(tv_layout_mma_warp), Step<_1,_0>{});
auto layout_atom = composition(tv_to_offset, mk_layout_mma_warp);

return layout_atom;
}

template<class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst>
__global__ void reorder_tensor_kernel(
cute::Tensor<EngineSrc, LayoutSrc> src,
cute::Tensor<EngineDst, LayoutDst> dst)
{
auto i = blockIdx.x;
auto k = blockIdx.y;
for (int j = threadIdx.x; j < cute::size<1>(src); j += blockDim.x) {
dst(i,j,k) = src(i,j,k);
}
}

template<class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst>
void reorder_tensor(
cute::Tensor<EngineSrc, LayoutSrc> t_src,
cute::Tensor<EngineDst, LayoutDst> t_dst)
{
using T = typename EngineDst::value_type;
static_assert(cute::is_same_v<cute::remove_const_t<typename EngineSrc::value_type>, T>, "Type mismatch");
using V = cute::uint_bit_t<cute::max(8, cute::sizeof_bits_v<T>)>;

cute::Tensor v_src = cute::recast<V>(t_src);
cute::Tensor v_dst = cute::recast<V>(t_dst);

int threads = 256;
dim3 blocks{unsigned(cute::size<0>(v_src)), unsigned(cute::size<2>(v_src)), 1u};

reorder_tensor_kernel<<<blocks, threads>>>(v_src, v_dst);
CUDA_CHECK(cudaDeviceSynchronize());
}

// Out-of-place version taking raw pointers and separate source/destination layouts
template<class T, class LayoutSrc, class LayoutDst>
void reorder_tensor(
T const* src,
LayoutSrc const& layout_src,
T * dst,
LayoutDst const& layout_dst)
{
reorder_tensor(make_tensor(src, layout_src),
make_tensor(dst, layout_dst));
}

// In-place version
template<class T, class LayoutSrc, class LayoutDst>
void reorder_tensor(
T * data,
LayoutSrc const& layout_src,
LayoutDst const& layout_dst)
{
cutlass::DeviceAllocation<T> temp(cute::size(layout_src));
reorder_tensor(data, layout_src, temp.get(), layout_dst);
cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast<size_t>(cute::size(layout_src)));
}
2 changes: 1 addition & 1 deletion include/cute/algorithm/tuple_algorithms.hpp
@@ -340,7 +340,7 @@ auto
all_of(T const& t, F&& f)
{
if constexpr (is_tuple<T>::value) {
return detail::apply(t, [&] (auto const&... a) { return (true_type{} && ... && f(a)); }, tuple_seq<T>{});
return detail::apply(cute::transform(t, f), [&] (auto const&... a) { return (true_type{} && ... && a); }, tuple_seq<T>{});
} else {
return f(t);
}
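
A small usage sketch of cute::all_of after this change (the tuple and predicate below are illustrative):

#include <cstdio>
#include "cute/algorithm/tuple_algorithms.hpp"

int main() {
  // The predicate is applied to every tuple element and the results are folded with &&.
  auto t = cute::make_tuple(1, 2, 3);
  bool all_positive = cute::all_of(t, [](auto const& x) { return x > 0; });
  printf("all_positive: %s\n", all_positive ? "true" : "false");  // true
  return 0;
}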
