
[QST] Global variable inside conv2d kernel #1987

Open · IzanCatalan opened this issue Dec 15, 2024 · 21 comments

@IzanCatalan commented Dec 15, 2024

What is your question?
Hello, good day. I am currently researching the Conv2dFprop kernel as I intend to modify its implementation in the library, specifically in the file https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/conv/kernel/implicit_gemm_convolution.h.

This is because, whether called from a C++ or Python program on the host, this .h file is the last in the execution hierarchy and directly implements the convolution operation to be run on the GPU (in my case, NVIDIA A100 and V100).

My question is: Can a global variable be implemented within this class? I intend to assign a specific number of elements to this global variable, called multiply_tensor, which would then multiply the convolution operands in that call and in all subsequent calls to the class while the host program is still running.

I aim for this variable to be stored in GPU memory, initialized and processed on the GPU during the first call, and reused in subsequent ones. I am unsure whether a global variable is the right solution or whether a new kernel parameter would be better.

Is this feasible?

@jackkosaian (Contributor) commented:

I expect that a better strategy would be to add a new variable to the Arguments struct in that file. This allows you to pass in the value of this parameter from the host.

You'll also need to add something similar to the corresponding Params struct so that it is available on the device side.

If you simply want a compile-time constant, you should be able to just use a #define macro.
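
A minimal sketch of that plumbing (member names are illustrative, not the full CUTLASS structs; multiply_tensor is the hypothetical parameter from the question):

    // Illustrative only: thread a new device pointer from the host-facing
    // Arguments struct into the kernel-facing Params struct.
    struct Arguments {
      // ... existing members ...
      float const *multiply_tensor = nullptr;  // hypothetical new parameter
    };

    struct Params {
      // ... existing members ...
      float const *multiply_tensor;

      CUTLASS_HOST_DEVICE
      Params(Arguments const &args /*, ... */)
          : /* ... existing initializers, */ multiply_tensor(args.multiply_tensor) {}
    };

The kernel can then read params.multiply_tensor on the device, since Params is the struct that reaches operator().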

@IzanCatalan (Author) commented Dec 17, 2024

@jackkosaian, I prefer the first approach: modifying the Arguments and Params structs and adding a new tensor for what I want to achieve. This way, I could assign the tensor and save it in GPU memory. Theoretically, in the next call to a convolution, it should already be saved there, if I understood correctly?

Another matter: if I need to perform an operation across all the convolution filters, it would be easier to do so with GPU threads. Would it be possible, for example, for each thread to multiply a specific filter by an element of the new tensor saved in GPU memory? And am I right in thinking that the best place in the code to do so is the operator() function?

void operator()(Params const &params, SharedStorage &shared_storage) {

@jackkosaian (Contributor) commented:

Theoretically, in the next call to a convolution, it should already be saved there, if I understood correctly?

I don't fully understand what you're thinking of here, unfortunately.

Would it be possible, for example, for each thread to multiply a specific filter by an element of the new tensor saved in GPU memory?

Yes, this is possible.

And am I right in thinking that the best place in the code to do so is the operator() function?

The place that you'd want to perform this transformation is once operands have been loaded from shared memory into registers (e.g., here).

@IzanCatalan (Author) commented Dec 20, 2024

@jackkosaian

I don't fully understand what you're thinking of here, unfortunately.

I am thinking of a program with two or more calls to a convolution operation, for example one reproducing a neural network structure like ResNet-50 or VGG-16 (which have multiple convolutions), with different inputs and several iterations over the structure. In the first iteration, each convolution would transform (using the new "multiply tensor" argument) its filters and save them in GPU shared memory; in the following iterations, the filters would already be modified and nothing else would need to be done. Do you think this could be done?

The place that you'd want to perform this transformation is once operands have been loaded from shared memory into registers (e.g., here).

What exactly is this code doing? It seems to select parts of fragments A and B, but I am unfamiliar with the pipe_state variable. Given what I explained in my previous answer, it seems more natural to me to transform the filters before splitting the matrices.

@jackkosaian (Contributor) commented:

Thanks for the additional details.

its filters and save them in GPU shared memory

Any modifications done in shared memory will not persist through multiple kernel calls (e.g., for different layers). If you're updating the values in global memory, then what you're describing should be possible.

That said, have you considered just multiplying each input for the different "iterations" rather than trying to perform the multiplication once and saving the results back out for later reuse? Multiplication should be much cheaper in this case than storing results back to global memory.

What exactly is this code doing?

It is selecting portions previously loaded from shared memory into the register file for computation by the Tensor Cores. By multiplying the values in the register file by your multiply_tensor argument, you will have modified the input before the Tensor Core instruction is performed.
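
As a hedged sketch (illustrative names only, not the actual CUTLASS mainloop code), the register-file modification could look like:

    // Illustrative only: scale every element of an operand fragment after the
    // smem -> register-file copy and before the Tensor Core mma is issued.
    template <typename Fragment, typename Element>
    CUTLASS_DEVICE void scale_fragment(Fragment &frag, Element scale) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Fragment::kElements; ++i) {
        frag[i] = frag[i] * scale;
      }
    }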

@IzanCatalan (Author) commented:

@jackkosaian, sorry for the delay in answering you.

That said, have you considered just multiplying each input for the different "iterations" rather than trying to perform the multiplication once and saving the results back out for later reuse? Multiplication should be much cheaper in this case than storing results back to global memory.

My concern is saving time when running all the convolution layers of a model. It is not the input that I intend to modify, but only the filters of the convolution, which are always the same and do not change. That is why I want to perform the transformation only once and store the result in global memory.

A use case would be a neural network model like ResNet-50, with 50 convolutions, running inference on a dataset of 5000 images. Each image is one pass across the 50 convolutions, so that means 5000 iterations (with larger batch sizes there would be fewer iterations).

If I multiplied and modified the filters of each convolution (which are always the same for that convolution) every time the layer is executed, it would add a huge amount of extra time.

I believe it is better to do the transformation in the first iteration for each of those 50 convolutions, store the modified filters, and reuse them for the remaining 4999 iterations.

I hope this helps clear up my intentions. Could you share some feedback about it?

@jackkosaian (Contributor) commented:

Thanks for the additional details. I think I understand your intention a bit better.

I agree that it makes sense to modify the filters once and reuse them across remaining iterations.

However, I don't know that it will be most beneficial to do this by fusing the multiplication within the convolution kernel itself. Doing so will require a fair amount of modification to the target CUTLASS convolution kernel, and it's unclear whether doing so would be any more performant than simply running a separate kernel to modify the filters before running all iterations.

Before endeavoring to modify the CUTLASS kernel, have you benchmarked running a separate kernel to modify the filters before running the iterations? Is the performance unsatisfactory? This approach would be much easier to implement.

@IzanCatalan (Author) commented Jan 3, 2025

@jackkosaian thanks for the feedback.

However, I don't know that it will be most beneficial to do this by fusing the multiplication within the convolution kernel itself. Doing so will require a fair amount of modification to the target CUTLASS convolution kernel, and it's unclear whether doing so would be any more performant than simply running a separate kernel to modify the filters before running all iterations.

I want to modify as little source code as possible; I assume the existing code is already well optimized, and I could undo that. However, I agree that the best possible solution is to fuse the multiplication (it would really be bit operations: OR and AND) with the filters. That is what I asked about in the previous comments: performing this in the operator() function, because through Params I assume you can manipulate the filters and save them before the convolution kernel runs. The rest of the code would then stay the same.

Before endeavoring to modify the CUTLASS kernel, have you benchmarked running a separate kernel to modify the filters before running the iterations? Is the performance unsatisfactory? This approach would be much easier to implement.

Yes, I tried a similar approach using example 59: I modified the ampere_conv_kernel.h file in its operator() function, and the TFLOPS performance was almost identical. I decided to modify the filters there to take advantage of the threads. Here is the code. My idea is to replicate this, but in the convolution kernel instead of in an example:

template <class EngineFlt, class TensorActivation, class TensorOutput>
  void __device__
  operator()(cute::Tensor<EngineFlt, GmemLayoutFlt> mFlt, // ( K,        (C,T,R,S))
             TensorActivation                       mAct, // ((N,Z,P,Q), (C,T,R,S))
             TensorOutput                           mOut, // ( K,        (N,Z,P,Q))
             char* smem_buf) const {
    using namespace cute;
    using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveMma<
        cutlass::gemm::MainloopSm80CpAsyncUnpredicated<PIPE::value>,
        Shape<TileM,TileN,TileK>,
        ElementFlt,
        Underscore, // Ignore the stride, we are passing full cute::Tensor to operator()
        ElementAct,
        Underscore, // Ignore the stride, we are passing full cute::Tensor to operator()
        TiledMma,
        GmemTiledCopyFlt,
        SmemLayoutAtomFlt,
        SmemCopyAtomFlt,
        cute::identity,
        GmemTiledCopyAct,
        SmemLayoutAtomAct,
        SmemCopyAtomAct,
        cute::identity>;

    TiledMma tiled_mma;
    Tensor accum = partition_fragment_C(tiled_mma, TilerOut{});
    clear(accum);


    // Set up tensors
    // NOTE: blockIdx.x projects onto act-NDHW mode, y along the flt-K mode for the sake of higher dynamic range in NDHW
    Tensor gA_mk = local_tile(mFlt, TilerFlt{}, make_coord(_,_));                              // (BLK_M,BLK_K,m',k')
    Tensor gB_nk = local_tile(mAct, TilerAct{}, make_coord(_,_));                              // (BLK_N,BLK_K,n',_1)
    Tensor gC_mn = local_tile(mOut, TilerOut{}, make_coord(_,_));                              // (BLK_M,BLK_N,m',n')

    // Compute m_coord and n_coord with their post-tiled shapes
    auto m_coord = idx2crd(int(blockIdx.y), shape<2>(gA_mk));
    auto n_coord = idx2crd(int(blockIdx.x), shape<2>(gB_nk));
    Tensor gA = gA_mk(_,_,m_coord,_);                                                          // (BLK_M,BLK_K,k')
    Tensor gB = gB_nk(_,_,n_coord,_);                                                          // (BLK_N,BLK_K,_1)
    Tensor gC = gC_mn(_,_,m_coord,n_coord);                                                    // (BLK_M,BLK_N)

    const uint32_t and_mask = 0xAAAAAAA; // AND mask
    const uint32_t or_mask  = 0x5555555; // OR mask

    int total_elements = size(gA);
    int total_threads = gridDim.x * blockDim.x; // Total threads available
    int thread_id = threadIdx.x + blockIdx.x * blockDim.x;

    int elements_per_thread = (total_elements + total_threads - 1) / total_threads; 
    int start_idx = thread_id * elements_per_thread;
    int end_idx = min(start_idx + elements_per_thread, total_elements);

    // Reinterpret each tf32 filter element as raw bits, apply the masks, and write back
    for (int i = start_idx; i < end_idx; ++i) {
        uint32_t original_bits = reinterpret_cast<uint32_t&>(gA(i));
        original_bits = (original_bits & and_mask) | or_mask;
        gA(i) = reinterpret_cast<tfloat32_t&>(original_bits);
    }

    auto k_tile_iter = cute::make_coord_iterator(size<2>(gA));
    int k_tile_count = size<2>(gA);

    CollectiveMainloop collective_mma;
    collective_mma(
      accum,
      gA,
      gB,
      accum,
      k_tile_iter, k_tile_count,
      Underscore{}, // no residue since we do not support predication
      threadIdx.x,
      smem_buf);
  }

@jackkosaian (Contributor) commented:

The approach that you prototyped is slightly different from what I had in mind. I was suggesting that you write one kernel that does:

    for (int i = start_idx; i < end_idx; ++i) {
        uint32_t original_bits = reinterpret_cast<uint32_t&>(gA(i));
        original_bits = (original_bits & and_mask) | or_mask;
        gA(i) = reinterpret_cast<tfloat32_t&>(original_bits);
    }

And then just call into the CUTLASS kernel for however many iterations you'd like.
At a high level, this would look something like:

manipulate_filters<<<nBlocks, nThreads>>>(args);
conv_op.run();
conv_op.run();
conv_op.run();
...

This would require no modification of CUTLASS source.
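
A minimal sketch of such a standalone kernel, assuming tfloat32_t filters in global memory and the masks from your snippet (the kernel name and signature are illustrative):

    #include <cstdint>
    #include "cutlass/tfloat32.h"

    // Illustrative only: apply the AND/OR bit masks to every filter element
    // in global memory once, before any convolution is launched.
    __global__ void manipulate_filters(cutlass::tfloat32_t *filters, int total_elements) {
      const uint32_t and_mask = 0xAAAAAAA;
      const uint32_t or_mask  = 0x5555555;
      // Grid-stride loop so any launch configuration covers all elements
      for (int i = blockIdx.x * blockDim.x + threadIdx.x;
           i < total_elements;
           i += gridDim.x * blockDim.x) {
        uint32_t bits = reinterpret_cast<uint32_t &>(filters[i]);
        bits = (bits & and_mask) | or_mask;
        filters[i] = reinterpret_cast<cutlass::tfloat32_t &>(bits);
      }
    }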

@IzanCatalan (Author) commented Jan 4, 2025

@jackkosaian I get your point, but in the operator() function you have access to all the threads, and the functionality you are describing would be basically the same:

void operator()(Params const &params, SharedStorage &shared_storage) {

I say this because my final goal is to call this convolution kernel from an outside program. I am actually planning to export it to Python, so it will be easier for an external user to just call the convolution, and the kernel will perform both the filter manipulation and the convolution itself.

manipulate_filters<<<nBlocks, nThreads>>>(args);
conv_op.run();
conv_op.run();
conv_op.run();
...

I don't think I could achieve this with that structure. I would need to run two kernels: one to manipulate the filters and another to run the convolutions. And, as I said, I use it from Python, so I would only call the convolution, as in this example.

@jackkosaian (Contributor) commented:

Ok. I personally think it might still be easier to have two kernels, but if you want to fuse this into a CUTLASS kernel, you'll want to perform the scaling at the location I mentioned here: #1987 (comment)

@IzanCatalan (Author) commented Jan 8, 2025

Yes @jackkosaian, it is imperative for me to merge the two and make the call transparent to the end user, who will only need to call the convolution kernel with one extra parameter (i.e., the masks). So I am thinking of changing the definition of the CUTLASS convolution. Is there a particular part of the repository where the kernel is declared that I should change, and what is the hierarchy of files I would need to modify?

If you're updating the values in global memory, then what you're describing should be possible.

Yes, you mentioned updating values in global memory, but I have some doubts: how can I access global memory (and through which variables), and in which part of the code?

@jackkosaian (Contributor) commented:

If you just want to access the global memory, you can do so via the ptr_A and ptr_B members of Params here.

These are available in the operator() method, which is essentially the kernel entry-point, here.

@IzanCatalan (Author) commented Jan 11, 2025

@jackkosaian, could you give me a hand with this? As a first step, I intend to print the filters' tensor in the operator() function during the first call to a convolution kernel. My code is the following, but apparently neither the filters (all elements are printed as 0s) nor the bool flag works:

struct Params {
  // ...existing code...

  // barrier to only be used in the first call to a convolution
  bool first_call = true;

  // ...existing code...
};

// ...existing code...

CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {

  int threadId = threadIdx.x + blockIdx.x * blockDim.x;

  if (threadId == 0 && params.first_call) {
      for (int k = 0; k < params.problem_size.K; ++k) {
        for (int c = 0; c < params.problem_size.C; ++c) {
          for (int r = 0; r < params.problem_size.R; ++r) {
            for (int s = 0; s < params.problem_size.S; ++s) {
              printf("A[%d, %d, %d, %d] = %f\n", k, c, r, s, params.ptr_B[k * params.problem_size.C * params.problem_size.R * params.problem_size.S + c * params.problem_size.R * params.problem_size.S + r * params.problem_size.S + s]);
            }
          }
        }
      }
    }
    // first call is done
    params.first_call = false;
  }

Moreover, params.first_call = false; gives me an error: error: expression must be a modifiable lvalue.

Could you tell me whether I am accessing the filters pointer (ptr_B) in the right way, and how you would implement a bool flag so that this code executes only the first time a convolution kernel runs? (This is a preliminary step: in the future, instead of only printing values, I will modify the filters as described in my previous comments.)

@jackkosaian (Contributor) commented:

Moreover, params.first_call = false; gives me an error: error: expression must be a modifiable lvalue.

params are declared as const here. They cannot be modified.

how you would implement a bool flag so that this code executes only the first time a convolution kernel runs (this is a preliminary step: in the future, instead of only printing values, I will modify the filters as described in my previous comments)?

Why not instead just have a flag that you pass in as part of the Arguments structure (here) (which then gets set in Params via the constructor here) indicating whether to perform the scaling? You set it to true the first time the conv is called and false in all remaining times. Modifying Params like you're trying to do is unlikely to work out, since Params can be overwritten on each kernel invocation.
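
In host code, that pattern might look roughly like this (apply_scaling is a hypothetical Arguments member; conv_op and workspace follow the device-level API used in the CUTLASS examples):

    // First launch: ask the kernel to transform the filters in global memory.
    args.apply_scaling = true;               // hypothetical new Arguments member
    conv_op.initialize(args, workspace.get());
    conv_op();

    // Remaining launches: filters are already modified, so skip the scaling.
    args.apply_scaling = false;
    conv_op.update(args, workspace.get());
    conv_op();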

Could you tell me whether I am accessing the filters pointer (ptr_B) in the right way

I'm not sure what error you're getting in accessing ptr_B. However, you'll note that there's an iterator associated with ptr_B in the Params struct: iterator_B here. Investigating and utilizing that should be helpful in performing what you'd like.

@IzanCatalan (Author) commented Jan 12, 2025

Why not instead just have a flag that you pass in as part of the Arguments structure (here) (which then gets set in Params via the constructor here) indicating whether to perform the scaling? You set it to true the first time the conv is called and false in all remaining times. Modifying Params like you're trying to do is unlikely to work out, since Params can be overwritten on each kernel invocation.

@jackkosaian Thank you for the info. I will try to implement what you suggested, and I will let you know about it.

I'm not sure what error you're getting in accessing ptr_B. However, you'll note that there's an iterator associated with ptr_B in the Params struct: iterator_B here. Investigating and utilizing that should be helpful in performing what you'd like.

However, I am more concerned about this. When accessing ptr_B I get no error, but with the code I posted all elements are printed as 0, so I assume I am not accessing the params in the right way. You mentioned iterator_B: what exactly is its purpose, and what are its functions?

I find no information about how to use the iterators or how to access the elements of the tensor behind ptr_B.

Besides that, I think there is a bigger issue: if params is const and cannot be modified, how can I then modify the filters (as I described previously) inside the operator() method?

I have seen that IteratorB serves as a parameter to Mma, but I don't know exactly what it does or where its source code is.

Actually, the extra code I want to add is very simple, as I have shown in earlier posts: just a few lines. For it I only need a part of the code that runs on the GPU threads (like the operator() function) where I can iterate over and modify the filters in global memory (through ptr_B). I am not sure how to achieve this because I don't have deep knowledge of the repository. Sorry if my questions are basic, but could you help me clarify these doubts?

@jackkosaian (Contributor) commented:

Besides that, I think there is a bigger issue: if params is const and cannot be modified, how can I then modify the filters (as I described in #1987 (comment)) inside the operator() method?

Yes, params itself is const. Thus, you won't be able to change the ptr_B TensorRef itself. However, you should be able to write to the memory pointed to by ptr_B.data(), since the ptr_ member of the TensorRef does not point to a constant allocation (see here).
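
A hedged sketch of such a write at the top of operator() (assuming ElementB is the kernel's filter element type, as aliased in implicit_gemm_convolution.h):

    // Illustrative only: params is const, but the memory behind ptr_B.data()
    // is writable, so the filter values themselves can be updated in place.
    ElementB *filters = const_cast<ElementB *>(params.ptr_B.data());
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    if (tid == 0) {
      filters[0] = ElementB(1);  // a global-memory write persists across launches
    }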

Regarding what the iterators do, I would suggest looking through the source code for the iterator being used for ptr_B in your case. Here is an example of one such iterator for an fprop.

You can determine what datatype is used for IteratorB in your case by forcing a compilation error (e.g., adding static_assert(Mma::IteratorB::non_existant_method(), ""); here). This should print out the type of the iterator class as part of the error message.

You can then look for the definition of this class. It will likely be in one of the *iterator*.h files here.

You'll see that these iterators define how each thread iterates in the advance() method (e.g., here).
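
For concreteness, the probe from above would sit at the top of the kernel body; it is a deliberate compile error and is not meant to build:

    CUTLASS_DEVICE
    void operator()(Params const &params, SharedStorage &shared_storage) {
      // Referencing a member that does not exist forces the compiler to spell
      // out the concrete type bound to Mma::IteratorB in its error message.
      static_assert(Mma::IteratorB::non_existant_method(), "");
      // ... rest of the kernel ...
    }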

@IzanCatalan (Author) commented Jan 23, 2025

You can determine what datatype is used for IteratorB in your case by forcing a compilation error (e.g., adding static_assert(Mma::IteratorB::non_existant_method(), ""); here). This should print out the type of the iterator class as part of the error message.

@jackkosaian I have implemented what you suggested in order to find out what data type is used for IteratorB. Of course, when testing example 16 I already know the data type, so I can check the result, but the point is finding it out from the device side.

These are the parameters from the host CPU:

using ElementAccumulator = float;                  // Data type of accumulator
using ElementComputeEpilogue = float;              // Data type of epilogue computation (alpha, beta)
using ElementInputA = cutlass::half_t;             // Data type of elements in input tensor
using ElementInputB = cutlass::half_t;             // Data type of elements in input tensor
using ElementOutput = float;                       // Data type of elements in output tensor
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized;

However, there are a lot of errors printed, and I don't understand exactly what is happening. The iterator class for the device is Conv2dFpropFilterTileAccessIteratorOptimized, but there are also errors with Conv2dFpropFilterTileAccessIteratorAnalytic.

The same thing applies to the data type. cutlass::half_t seems to be the right one, but the errors mention multiple data types:

/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<8, 128>, float, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<8, 128>, 256, 1>, cutlass::AlignedArray<float, 1, 4>, false>>" has no member "non_existant_method"

This is the complete error output from make -j$(nproc) 2>&1 | grep "error":


/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<128, 128>, ElementInputB, LayoutInputB, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<128, 128>, 128, cutlass::PitchLinearShape<4, 8>, 32>, cutlass::AlignedArray<ElementInputA, 32, 16>, cutlass::conv::GroupMode::kNone, false>>" has no member "non_existant_method"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 64>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 64>, 128, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<cutlass::half_t, 8, 16>, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<cutlass::half_t, 8, 16>, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<8, 128>, float, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<8, 128>, 256, 1>, cutlass::AlignedArray<float, 1, 4>, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 64>, int8_t, cutlass::layout::TensorCxRSKx<32>, cutlass::transform::TransposePitchLinearThreadMap<cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<2048, 1>, 64, cutlass::PitchLinearShape<32, 1>, 16>, cutlass::PitchLinearShape<2, 16>>, cutlass::AlignedArray<int8_t, 16, 16>, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<8, 128>, float, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<8, 128>, 128, 1>, cutlass::AlignedArray<float, 1, 4>, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, int8_t, cutlass::layout::TensorCxRSKx<32>, cutlass::transform::TransposePitchLinearThreadMap<cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<4096, 1>, 128, cutlass::PitchLinearShape<32, 1>, 16>, cutlass::PitchLinearShape<2, 16>>, cutlass::AlignedArray<int8_t, 16, 16>, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<8, 128>, float, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<8, 128>, 256, 1>, cutlass::AlignedArray<float, 1, 4>, cutlass::conv::GroupMode::kNone, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, ElementInputB, LayoutInputB, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<ElementInputA, 8, 16>, false>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<8, 64>, float, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<8, 64>, 128, 1>, cutlass::AlignedArray<float, 1, 4>, cutlass::conv::GroupMode::kNone, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, ElementInputB, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<128, 32>, 128, cutlass::PitchLinearShape<8, 4>, 8>, cutlass::AlignedArray<ElementInputA, 8, 16>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<8, 64>, float, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<8, 64>, 64, 1>, cutlass::AlignedArray<float, 1, 4>, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<8, 128>, float, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<8, 128>, 64, 1>, cutlass::AlignedArray<float, 1, 4>, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<8, 64>, float, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<8, 64>, 128, 1>, cutlass::AlignedArray<float, 1, 4>, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, ElementInputB, LayoutInputB, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<ElementInputA, 8, 16>, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<8, 32>, float, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<8, 32>, 64, 1>, cutlass::AlignedArray<float, 1, 4>, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<16, 64>, ElementInputB, LayoutInputB, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<16, 64>, 128, cutlass::PitchLinearShape<4, 8>, 4>, cutlass::AlignedArray<ElementInputA, 4, 16>, false>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<16, 64>, ElementInputB, LayoutInputB, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<16, 64>, 128, cutlass::PitchLinearShape<4, 8>, 4>, cutlass::AlignedArray<ElementInputA, 4, 16>, false>" has no member "non_existant_method"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_optimized/cutlass_simt_sfprop_optimized_128x128_8x2_nhwc_align1.cu".
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_analytic/cutlass_simt_sfprop_analytic_128x128_8x2_nhwc_align1.cu".
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_optimized/cutlass_simt_sfprop_optimized_64x128_8x2_nhwc_align1.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<8, 32>, float, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<8, 32>, 64, 1>, cutlass::AlignedArray<float, 1, 4>, cutlass::conv::GroupMode::kNone, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<8, 64>, ElementInputB, LayoutInputB, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<8, 64>, 128, 1>, cutlass::AlignedArray<ElementInputA, 1, 16>, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 64>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 64>, 64, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<cutlass::half_t, 8, 16>, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<64, 64>, int8_t, cutlass::layout::TensorCxRSKx<32>, cutlass::transform::TransposePitchLinearThreadMap<cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<2048, 2>, 64, cutlass::PitchLinearShape<32, 1>, 16>, cutlass::PitchLinearShape<2, 16>>, cutlass::AlignedArray<int8_t, 16, 16>, false>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<cutlass::half_t, 8, 16>, false>>" has no member "non_existant_method"
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_rf.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<64, 128>, int8_t, cutlass::layout::TensorCxRSKx<32>, cutlass::transform::TransposePitchLinearThreadMap<cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<4096, 2>, 128, cutlass::PitchLinearShape<32, 1>, 16>, cutlass::PitchLinearShape<2, 16>>, cutlass::AlignedArray<int8_t, 16, 16>, false>" has no member "non_existant_method"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_analytic/cutlass_simt_sfprop_analytic_128x64_8x2_nhwc_align1.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 64>, int8_t, cutlass::layout::TensorCxRSKx<32>, cutlass::transform::TransposePitchLinearThreadMap<cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<2048, 1>, 64, cutlass::PitchLinearShape<32, 1>, 16>, cutlass::PitchLinearShape<2, 16>>, cutlass::AlignedArray<int8_t, 16, 16>, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<8, 64>, float, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<8, 64>, 64, 1>, cutlass::AlignedArray<float, 1, 4>, cutlass::conv::GroupMode::kNone, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<8, 128>, float, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<8, 128>, 64, 1>, cutlass::AlignedArray<float, 1, 4>, cutlass::conv::GroupMode::kNone, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, int8_t, cutlass::layout::TensorCxRSKx<32>, cutlass::transform::TransposePitchLinearThreadMap<cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<4096, 1>, 128, cutlass::PitchLinearShape<32, 1>, 16>, cutlass::PitchLinearShape<2, 16>>, cutlass::AlignedArray<int8_t, 16, 16>, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<8, 128>, float, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<8, 128>, 128, 1>, cutlass::AlignedArray<float, 1, 4>, cutlass::conv::GroupMode::kNone, false>>" has no member "non_existant_method"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_optimized/cutlass_simt_sfprop_optimized_64x64_8x2_nhwc_align1.cu".
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_optimized/cutlass_simt_sfprop_optimized_32x128_8x2_nhwc_align1.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<64, 64>, int8_t, cutlass::layout::TensorCxRSKx<32>, cutlass::transform::TransposePitchLinearThreadMap<cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<2048, 2>, 64, cutlass::PitchLinearShape<32, 1>, 16>, cutlass::PitchLinearShape<2, 16>>, cutlass::AlignedArray<int8_t, 16, 16>, false>" has no member "non_existant_method"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_optimized/cutlass_simt_sfprop_optimized_128x64_8x2_nhwc_align1.cu".
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_optimized/cutlass_simt_sfprop_optimized_128x32_8x2_nhwc_align1.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<64, 128>, int8_t, cutlass::layout::TensorCxRSKx<32>, cutlass::transform::TransposePitchLinearThreadMap<cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<4096, 2>, 128, cutlass::PitchLinearShape<32, 1>, 16>, cutlass::PitchLinearShape<2, 16>>, cutlass::AlignedArray<int8_t, 16, 16>, false>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 64>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 64>, 128, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<cutlass::half_t, 8, 16>, false>" has no member "non_existant_method"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 64, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<cutlass::half_t, 8, 16>, false>" has no member "non_existant_method"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/30_wgrad_split_k/30_wgrad_split_k.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 64>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 64>, 128, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<cutlass::half_t, 8, 16>, false>" has no member "non_existant_method"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_analytic/cutlass_simt_sfprop_analytic_128x32_8x2_nhwc_align1.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 64, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<cutlass::half_t, 8, 16>, false>" has no member "non_existant_method"
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_shmem.cu".
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu".
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_analytic/cutlass_simt_sfprop_analytic_64x64_8x2_nhwc_align1.cu".
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/22_quaternion_conv/quaternion_conv.cu".
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu".
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_analytic/cutlass_simt_sfprop_analytic_32x128_8x2_nhwc_align1.cu".
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_analytic/cutlass_simt_sfprop_analytic_64x128_8x2_nhwc_align1.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<64, 64>, ElementInputB, LayoutInputB, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<64, 64>, 128, cutlass::PitchLinearShape<8, 4>, 8>, cutlass::AlignedArray<ElementInputA, 8, 16>, cutlass::conv::GroupMode::kMultipleGroup, false>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<64, 64>, ElementInputB, LayoutInputB, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<64, 64>, 128, cutlass::PitchLinearShape<8, 4>, 8>, cutlass::AlignedArray<ElementInputA, 8, 16>, false>" has no member "non_existant_method"
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_shmem.cu".
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_shmem.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<64, 64>, ElementInputB, LayoutInputB, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<64, 64>, 128, cutlass::PitchLinearShape<8, 4>, 8>, cutlass::AlignedArray<ElementInputA, 8, 16>, cutlass::conv::GroupMode::kSingleGroup, false>" has no member "non_existant_method"
3 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/42_ampere_tensorop_group_conv/ampere_tensorop_group_conv.cu".
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu".
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_shmem.cu".
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_rf.cu".
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_rf.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<cutlass::half_t, 8, 16>, cutlass::conv::GroupMode::kNone, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<cutlass::half_t, 8, 16>, cutlass::conv::GroupMode::kNone, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<8, 64>, cutlass::complex<float>, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 8>, 128, 1>, cutlass::conv::StrideSupport::kUnity, cutlass::AlignedArray<cutlass::complex<float>, 1, 8>>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<8, 128>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<8, 128>, 128, 1>, cutlass::AlignedArray<cutlass::half_t, 1, 2>, cutlass::conv::GroupMode::kNone, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<16, 128>, cutlass::half_t, cutlass::layout::TensorNDHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<16, 128>, 128, cutlass::PitchLinearShape<2, 16>, 8>, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<128, 32>, 128, cutlass::PitchLinearShape<8, 4>, 8>, cutlass::conv::StrideSupport::kUnity, cutlass::AlignedArray<cutlass::half_t, 8, 16>>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<8, 64>, cutlass::complex<float>, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<8, 64>, 64, 1>, cutlass::AlignedArray<cutlass::complex<float>, 1, 8>, cutlass::conv::GroupMode::kNone, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<128, 32>, 128, cutlass::PitchLinearShape<8, 4>, 8>, cutlass::conv::StrideSupport::kUnity, cutlass::AlignedArray<cutlass::half_t, 8, 16>>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<128, 32>, 128, cutlass::PitchLinearShape<8, 4>, 8>, cutlass::AlignedArray<cutlass::half_t, 8, 16>>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<128, 32>, 128, cutlass::PitchLinearShape<8, 4>, 8>>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic<cutlass::MatrixShape<8, 64>, cutlass::complex<float>, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 8>, 128, 1>, cutlass::AlignedArray<cutlass::complex<float>, 1, 8>>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<cutlass::half_t, 8, 16>, cutlass::conv::GroupMode::kNone, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<cutlass::MatrixShape<8, 64>, cutlass::complex<float>, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 8>, 32, 1>, cutlass::conv::StrideSupport::kUnity, cutlass::AlignedArray<cutlass::complex<float>, 1, 8>>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<8, 128>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<8, 128>, 128, 1>, cutlass::AlignedArray<cutlass::half_t, 1, 2>, cutlass::conv::GroupMode::kDepthwise, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<8, 128>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<8, 128>, 128, 1>, cutlass::AlignedArray<cutlass::half_t, 1, 2>, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<cutlass::half_t, 8, 16>, cutlass::conv::GroupMode::kNone, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<128, 32>, 128, cutlass::PitchLinearShape<8, 4>, 8>, cutlass::AlignedArray<cutlass::half_t, 8, 16>>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<8, 128>, cutlass::complex<float>, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<8, 128>, 128, 1>, cutlass::AlignedArray<cutlass::complex<float>, 1, 8>, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<8, 128>, float, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<8, 128>, 128, 1>, cutlass::AlignedArray<float, 1, 4>, cutlass::conv::GroupMode::kNone, false>>" has no member "non_existant_method"
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<128, 32>, 128, cutlass::PitchLinearShape<8, 4>, 8>, cutlass::conv::StrideSupport::kUnity, cutlass::AlignedArray<cutlass::half_t, 8, 16>>>" has no member "non_existant_method"
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu".
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_fprop_with_reduction_sm75.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<cutlass::half_t, 8, 16>, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<128, 32>, 128, cutlass::PitchLinearShape<8, 4>, 8>, cutlass::conv::StrideSupport::kUnity, cutlass::AlignedArray<cutlass::half_t, 2, 4>>>" has no member "non_existant_method"
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<128, 32>, 128, cutlass::PitchLinearShape<8, 4>, 8>, cutlass::conv::StrideSupport::kUnity, cutlass::AlignedArray<cutlass::half_t, 2, 4>>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<128, 32>, 128, cutlass::PitchLinearShape<8, 4>, 8>, cutlass::conv::StrideSupport::kUnity, cutlass::AlignedArray<cutlass::half_t, 8, 16>>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized<cutlass::MatrixShape<8, 64>, cutlass::complex<float>, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 8>, 32, 1>, cutlass::AlignedArray<cutlass::complex<float>, 1, 8>>>" has no member "non_existant_method"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu".
4 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<cutlass::half_t, 8, 16>, false>>" has no member "non_existant_method"
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<8, 64>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<8, 64>, 128, 1>, cutlass::AlignedArray<cutlass::half_t, 1, 2>, cutlass::conv::GroupMode::kDepthwise, false>>" has no member "non_existant_method"
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<cutlass::half_t, 8, 16>, false>>" has no member "non_existant_method"
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu".
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<128, 32>, 128, cutlass::PitchLinearShape<8, 4>, 8>, cutlass::AlignedArray<cutlass::half_t, 4, 8>>>" has no member "non_existant_method"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm75.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<cutlass::half_t, 2, 4>, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<128, 32>, 128, cutlass::PitchLinearShape<8, 4>, 8>, cutlass::AlignedArray<cutlass::half_t, 4, 8>>>" has no member "non_existant_method"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm50.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<cutlass::half_t, 4, 8>, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<cutlass::half_t, 4, 8>, cutlass::conv::GroupMode::kNone, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<cutlass::half_t, 8, 16>, cutlass::conv::GroupMode::kNone, false>>" has no member "non_existant_method"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(76): error: class "cutlass::conv::threadblock::TileIterator<cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<32, 128>, cutlass::half_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<cutlass::half_t, 8, 16>, cutlass::conv::GroupMode::kNone, false>>" has no member "non_existant_method"
3 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu".
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm75.cu".
5 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu".
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_fprop_with_broadcast_sm70.cu".
3 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu".

Yes, params itself is const. Thus, you won't be able to change the ptr_B TensorRef itself. However, you should be able to write to the memory pointed to by ptr_B.data(), since the ptr_ member of the TensorRef does not point to a constant allocation (see here).
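
For example, roughly (a minimal host-side sketch with illustrative names, not the actual Params definition; it only demonstrates the const-ness involved):

    #include "cutlass/half.h"
    #include "cutlass/layout/tensor.h"
    #include "cutlass/tensor_ref.h"

    int main() {
      cutlass::half_t storage[8] = {};

      // Packed NHWC layout for a tiny 1x2x2x2 tensor.
      cutlass::layout::TensorNHWC layout =
          cutlass::layout::TensorNHWC::packed({1, 2, 2, 2});

      // The TensorRef itself is const, so it cannot be reseated...
      cutlass::TensorRef<cutlass::half_t, cutlass::layout::TensorNHWC> const ref(storage, layout);

      // ...but the elements it points to are still mutable:
      ref.data()[0] = cutlass::half_t(2.0f);         // write through the raw pointer
      ref.at({0, 0, 0, 1}) = cutlass::half_t(3.0f);  // write through a logical coordinate

      return 0;
    }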

I have checked whether the following code works, but it also gives me some errors:

    for (int n = 0; n < params.problem_size.K; n++) {
      for (int h = 0; h < params.problem_size.R; h++) {
        for (int w = 0; w < params.problem_size.S; w++) {
          for (int c = 0; c < params.problem_size.C; c++) {
            auto value = params.ptr_B.data().at({n, h, w, c});
            float real_value = static_cast<float>(value);
            printf("B[%d, %d, %d, %d] = %f\n", n, h, w, c, real_value);
          }
        }
      }
    }

These are the errors. Again, cutlass::half_t appears a lot, but data types like int8_t and float also appear, so I don't quite understand how it works:

/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const ElementInputB *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const ElementInputB *const"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu".
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const ElementInputB *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const ElementInputB *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const ElementInputB *const"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const int8_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/30_wgrad_split_k/30_wgrad_split_k.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const int8_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const ElementInputB *const"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/22_quaternion_conv/quaternion_conv.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const int8_t *const"
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_rf.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const int8_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const ElementInputB *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const ElementInputB *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const ElementInputB *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const ElementInputB *const"
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_shmem.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const int8_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const int8_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu".
3 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/42_ampere_tensorop_group_conv/ampere_tensorop_group_conv.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const int8_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const int8_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const float *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const float *const"
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_rf.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_shmem.cu".
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_analytic/cutlass_simt_sfprop_analytic_64x64_8x2_nhwc_align1.cu".
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_analytic/cutlass_simt_sfprop_analytic_128x128_8x2_nhwc_align1.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const float *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const float *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const float *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const float *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const float *const"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_analytic/cutlass_simt_sfprop_analytic_64x128_8x2_nhwc_align1.cu".
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_analytic/cutlass_simt_sfprop_analytic_32x128_8x2_nhwc_align1.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const float *const"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_optimized/cutlass_simt_sfprop_optimized_128x32_8x2_nhwc_align1.cu".
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_optimized/cutlass_simt_sfprop_optimized_32x128_8x2_nhwc_align1.cu".
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu".
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_rf.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const float *const"
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_shmem.cu".
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_shmem.cu".
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_analytic/cutlass_simt_sfprop_analytic_128x64_8x2_nhwc_align1.cu".
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_optimized/cutlass_simt_sfprop_optimized_128x64_8x2_nhwc_align1.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const float *const"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_analytic/cutlass_simt_sfprop_analytic_128x32_8x2_nhwc_align1.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const float *const"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_optimized/cutlass_simt_sfprop_optimized_64x64_8x2_nhwc_align1.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const float *const"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_optimized/cutlass_simt_sfprop_optimized_128x128_8x2_nhwc_align1.cu".
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/build/tools/library/generated/conv2d/50/sfprop_optimized/cutlass_simt_sfprop_optimized_64x128_8x2_nhwc_align1.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::complex<float> *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const float *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::complex<float> *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm50.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu".
3 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::complex<float> *const"
4 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::complex<float> *const"
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::half_t *const"
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::complex<float> *const"
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm75.cu".
/mnt/beegfs/gap/[email protected]/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h(369): error: expression must have class type but it has type "const cutlass::complex<float> *const"
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu".
5 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu".
2 errors detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu".
1 error detected in the compilation of "/mnt/beegfs/gap/[email protected]/cutlass/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm75.cu".

@jackkosaian
Copy link
Contributor

Sorry for the confusion. The sole point of introducing the static assert was to determine which iterator type is being used; the `has no member "non_existant_method"` errors above are that trick doing its job, since each diagnostic spells out the full concrete TileIterator type. Now that you've determined this, you can remove the assert from your code.

It looks like you'll want to familiarize yourself with how the iterator traverses memory here: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h
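Roughly, the consuming side of such an iterator follows this pattern (a simplified sketch; the real mainloop is pipelined through shared memory and registers, and the names follow the conventions in the file linked above rather than a specific kernel):

    // Sketch of the canonical tile-iterator pattern.
    typename IteratorB::Fragment frag_B;
    frag_B.clear();

    iterator_B.load(frag_B);  // gather the current filter tile into registers
    ++iterator_B;             // advance to the next tile along the gemm-K dimension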

@IzanCatalan
Copy link
Author

IzanCatalan commented Jan 24, 2025

Yes, @jackkosaian, I understand your point. But I am not sure that I have determined anything. I mean, I know what kind of operator and data type I am using because I directly assign them from the host CPU.

But that is of little use, because from inside the device code in https://github.com/NVIDIA/cutlass/blob/24f991e87930e1159f1f5a47e329d43bcfbd76b9/include/cutlass/conv/kernel/implicit_gemm_convolution.h I don't know which data type or element type I am using; all I know is Mma::IteratorB::Element, and that on its own gives me no information. Furthermore, I don't know which cast I should apply if I want to modify the values, or which data type to use. Your assertion gives me more or less the type (cutlass::half_t), but then the following errors appear, repeating the same with float and int, as you can check in the output I showed you.

Since the first and most repeated error involves cutlass::half_t, I deduce that is the data type, but it could vary. So, again, I don't know how to write a cast for a generic data type (e.g., with reinterpret_cast<>?) that would let me manipulate the elements one by one.
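
What I imagine is something generic along these lines (just a sketch on my side, valid only for scalar element types; I don't know if this is the intended approach):

    // Sketch: take the element type from the kernel's template parameters
    // instead of hard-coding cutlass::half_t.
    using ElementB = typename Mma::IteratorB::Element;

    ElementB const *ptr = reinterpret_cast<ElementB const *>(params.ptr_B);
    float value = static_cast<float>(ptr[0]);  // half_t, float and int8_t all convert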

I have also checked https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h; there are methods like AccessType const *get() const, but the value it returns is const, so I deduce I cannot modify it.

I was thinking of something like (pseudocode):

    var = IteratorB.get()
    // manipulations
    IteratorB.advance()

But again, I cannot modify the values if they are const. Conv2dFpropFilterTileAccessIteratorOptimized &operator++() also accesses memory and advances, but it does not return a concrete element.

So, I am very lost on how to do something simple: accessing the tensor and changing its values. params in https://github.com/NVIDIA/cutlass/blob/24f991e87930e1159f1f5a47e329d43bcfbd76b9/include/cutlass/conv/kernel/implicit_gemm_convolution.h saves ptr_B as tensor.data(), a pointer to all the data. But shouldn't accessing that tensor as ref_B.data().at({n, h, w, c}) work, according to https://github.com/NVIDIA/cutlass/blob/375e284e6aef68b81d58b116dff9e0970f64c5cd/media/docs/layout.md#tensorref ? It gives me a segmentation fault when I iterate over the K, R, S, and C size variables.
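
To be concrete, this is the kind of access I am attempting (pseudocode, assuming ptr_B is a raw pointer to the filter tensor and that a packed NHWC layout matches how it was allocated; the names are illustrative):

    // Walk the filter tensor K x R x S x C, addressing each element
    // through a layout-computed linear offset.
    cutlass::layout::TensorNHWC layout = cutlass::layout::TensorNHWC::packed(
        {params.problem_size.K, params.problem_size.R,
         params.problem_size.S, params.problem_size.C});

    for (int k = 0; k < params.problem_size.K; ++k)
      for (int r = 0; r < params.problem_size.R; ++r)
        for (int s = 0; s < params.problem_size.S; ++s)
          for (int c = 0; c < params.problem_size.C; ++c) {
            auto offset = layout({k, r, s, c});  // linear element offset
            float value = static_cast<float>(params.ptr_B[offset]);
          }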

@IzanCatalan
Copy link
Author

@jackkosaian I posted a new issue with updates about the Iterator. In case you are able to help me, this is the post: #2067
