-
Hi everyone, I am learning to use TMA, and while doing so I wrote a very simple program that copies data from a GMEM tensor to an SMEM tensor with the same layout. The program crashes at runtime, and I don't know how to debug it (I have tried a few things, details below), so I would really appreciate some help or insights here. Thank you! Here's a self-contained repro (I promise it's really short):

/*******************
Self-contained example to study and debug TMA.
To build and run (for nvcc 12.3, but perhaps others work as well):
$ nvcc main.cu -G -g \
--expt-relaxed-constexpr \
--generate-code=arch=compute_90a,code=sm_90a \
-lcuda \
-w \
-Xcompiler=-Wconversion \
-Xcompiler=-fno-strict-aliasing \
-Xcompiler=-Wfatal-errors \
-Xcompiler=-Wno-abi \
-std=c++17 \
-arch=sm_90 \
-I/usr/local/cuda/include \
-I/local/path/to/cutlass-3.4/include \
-I/local/path/to/cutlass-3.4/tools/util/include
$ ./a.out
********************/
#include <cstdio>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <cutlass/cutlass.h>
#include "cute/tensor.hpp"
#include "cute/arch/cluster_sm90.hpp"
template <
  class T,
  class TensorX,
  class GmemLayout,
  class SmemLayout,
  class TmaLoad
>
__global__ static void
tma_kernel(
    TensorX tX,
    GmemLayout gmem_layout,
    SmemLayout smem_layout,
    TmaLoad tma_load
) {
  using namespace cute;

  __shared__ T smem[cosize_v<SmemLayout>];
  __shared__ uint64_t tma_load_mbar[1];

  auto sX = make_tensor(make_smem_ptr(smem), smem_layout);
  auto mX = tma_load.get_tma_tensor(shape(gmem_layout));
  auto gX = local_tile(mX, shape(smem_layout), make_coord(0, 0)); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)

  auto cta_tma_load = tma_load.get_slice(0);
  auto tXgX = cta_tma_load.partition_S(gX); // (TMA,TMA_M,TMA_N,REST_M,REST_N)
  auto tXsX = cta_tma_load.partition_D(sX); // (TMA,TMA_M,TMA_N)

  // A single elected thread initializes the mbarrier, arms it with the
  // expected transaction size, and issues the TMA load.
  auto warp_idx = cutlass::canonical_warp_idx_sync();
  auto lane_predicate = cute::elect_one_sync();
  if (warp_idx == 0 && lane_predicate) {
    constexpr int k_tma_transaction_bytes = size(sX) * sizeof_bits_v<T> / 8;
    tma_load_mbar[0] = 0;
    cute::initialize_barrier(tma_load_mbar[0], 1 /*numThreads*/);
    cute::set_barrier_transaction_bytes(tma_load_mbar[0], k_tma_transaction_bytes);
    cute::copy(tma_load.with(tma_load_mbar[0]), tXgX, tXsX);
  }
  // Make the initialized mbarrier visible to all threads before waiting on it.
  __syncthreads();

  // All threads wait for the TMA transaction to complete (phase 0).
  constexpr int k_phase_bit = 0;
  cute::wait_barrier(tma_load_mbar[0], k_phase_bit);
}
int main() {
  using namespace cute;

  using T = float;
  constexpr int m = 4;
  constexpr int n = 4;

  // create data
  thrust::host_vector<T> cpu_data(m * n);
  for (int i = 0; i < m * n; ++i) {
    cpu_data[i] = static_cast<T>(i);
  }
  thrust::device_vector<T> gpu_data = cpu_data;
  cudaDeviceSynchronize();

  // create tensors
  auto gmem_layout = Layout<Shape<Int<n>, Int<m>>>{};
  auto smem_layout = Layout<Shape<Int<n>, Int<m>>>{};
  auto pX = reinterpret_cast<const T*>(gpu_data.data().get());
  auto gX = make_tensor(make_gmem_ptr(pX), gmem_layout);

  // create the TMA object
  auto tma_load = make_tma_copy(SM90_TMA_LOAD{}, gX, smem_layout);

  // launch the kernel
  dim3 blk_dim{1, 1, 1};
  dim3 grd_dim{1, 1, 1};
  tma_kernel<
    T,
    decltype(gX),
    decltype(gmem_layout),
    decltype(smem_layout),
    decltype(tma_load)
  >
  <<<grd_dim, blk_dim>>>
  (gX, gmem_layout, smem_layout, tma_load);
  CUTE_CHECK_LAST();

  return 0;
}
-
You need to replace
TmaLoad tma_load
with
CUTE_GRID_CONSTANT TmaLoad const tma_load
or
__grid_constant__ TmaLoad const tma_load
in the kernel's parameter list. The program should then run.
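For reference, here is a minimal sketch of what that change looks like (only the kernel's parameter list changes; the kernel body and the launch in main() stay exactly as in the repro above). As I understand it, the TMA unit needs to read the descriptor from the constant kernel-parameter space; a plain by-value parameter gets copied into thread-local memory instead, which is what makes the unmodified program fail.

// Sketch of the corrected kernel signature, assuming the rest of the repro is unchanged.
template <
  class T,
  class TensorX,
  class GmemLayout,
  class SmemLayout,
  class TmaLoad
>
__global__ static void
tma_kernel(
    TensorX tX,
    GmemLayout gmem_layout,
    SmemLayout smem_layout,
    CUTE_GRID_CONSTANT TmaLoad const tma_load  // was: TmaLoad tma_load
) {
  // ... kernel body unchanged from the repro above ...
}

CUTE_GRID_CONSTANT expands to __grid_constant__ on compilers that support it, so either spelling should behave the same here, and the host-side call in main() does not need to change.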