
[QST] Adding a flag in Tensor Ref Class #2080

Open
IzanCatalan opened this issue Feb 5, 2025 · 4 comments

Comments

@IzanCatalan

IzanCatalan commented Feb 5, 2025

What is your question?
Hi, I want to define an extra parameter in the TensorRef class: in my case, a flag in the form of an integer pointer to be accessed when the convolution is performed in https://github.com/NVIDIA/cutlass/blob/24f991e87930e1159f1f5a47e329d43bcfbd76b9/include/cutlass/conv/kernel/implicit_gemm_convolution.h:

  /// Pointer
  Element* ptr_;
  int* check;

The flag is to check whether the GPU has used the tensor. I have also modified the constructors of the class (in my case, since it is only intended to be used on the GPU, I use cudaMalloc):

  /// Constructs a TensorRef with a pointer and layout object.
  CUTLASS_HOST_DEVICE
  TensorRef(
    Element *ptr,                   ///< pointer to start of tensor
    Layout const &layout            ///< layout object containing stride and mapping function
  ):
    ptr_(ptr), layout_(layout), check(nullptr){ 
      cudaMalloc((void**)&check, sizeof(int)); 
      cudaMemset(check, 0, sizeof(int)); // initialize to 0
    }


  /// Converting constructor from TensorRef to non-constant data.
  template<typename _Magic = int>
  CUTLASS_HOST_DEVICE
  TensorRef(
    NonConstTensorRef const &ref,              ///< TensorRef to non-const data
    ///SFINAE trick to avoid creating a copy-constructor when Element_ is already non-const
    _Magic magic = (typename platform::enable_if< ! platform::is_same<NonConstTensorRef, TensorRef<Element_, Layout_> >::value, _Magic>::type)0
  ):
    ptr_(ref.data()), layout_(ref.layout()), check(nullptr){ 
      cudaMalloc((void**)&check, sizeof(int)); 
      cudaMemset(check, 0, sizeof(int)); // initialize to 0
    } 

Finally, I have added a new function, similar to data(), to pass the pointer to the convolution parameters, much as is done in implicit_gemm_convolution.h:

  /// Returns the check pointer
  CUTLASS_HOST_DEVICE
  int * isChecked() const { return check; }


As a result, I have also modified different parts of the convolution kernel, assigning in Params the pointer to a new member first_call: https://github.com/NVIDIA/cutlass/blob/24f991e87930e1159f1f5a47e329d43bcfbd76b9/include/cutlass/conv/kernel/implicit_gemm_convolution.h:

struct Params {
    ...
    int *first_call;
    ...

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params(): swizzle_log_tile(0), gemm_k_iterations(0) { }

    /// 
    CUTLASS_HOST_DEVICE
    Params(
      Arguments const &args,
      int *semaphore = nullptr
    ):
     ...
      first_call(args.ref_B.isChecked()),
     ...
    {

Later, in the operator() function, since both ptr_ and check are non-const pointers in the TensorRef class, they can be accessed.

Only ptr_B works fine; the program suddenly gets stuck when I access and modify first_call.

I am executing example 16 to check the implementation.

This is the code where I modify the flag using the first_call parameter:

  void operator()(Params const &params, SharedStorage &shared_storage) {
    int threadId = threadIdx.x + blockIdx.x * blockDim.x;
    if (threadId == 0) {
      if (*params.first_call == 0) {
        *params.first_call = 1;
      }
    }

first_call is indeed modified (I have printed its value after this piece of code), but when execution reaches line 343 the process gets stuck. The GPU is in use; I have checked that, but I don't know why it stops there.

I think it is perhaps some kind of memory-management problem due to how I reserve memory in the constructor of the TensorRef class. Maybe this is not the proper way to do it, because I never call any free method on the new integer pointer.

Should I modify the HostTensor and device_memory classes, which are the ones used to define tensors from the host, as described in example 16?

cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);

Any help would be appreciated.

Izan.

@hwu36
Collaborator

hwu36 commented Feb 6, 2025

You just want to add a flag in device memory to tag whether the tensor has been used for the first time? Just pass a pointer through Params and change the value in the mainloop. No need to hack TensorRef to do simple things like this.

  CUTLASS_HOST_DEVICE
  TensorRef(
    Element *ptr,                   ///< pointer to start of tensor
    Layout const &layout            ///< layout object containing stride and mapping function
  ):
    ptr_(ptr), layout_(layout), check(nullptr){ 
      cudaMalloc((void**)&check, sizeof(int)); 
      cudaMemset(check, 0, sizeof(int)); // initialize to 0
    }

This is hacky and ugly. The ctor is CUTLASS_HOST_DEVICE; I don't know what happens when you call cudaMalloc and cudaMemset from there.

@IzanCatalan
Author

IzanCatalan commented Feb 6, 2025

Hi @hwu36, I will explain my final goal; maybe that will give you a better idea of what I am trying to do, because I have seen you have replied to me in another issue, #2067, and these two are related.

My use case is to execute the same convolution multiple times, which is something you can do in example 16 by launching several iterations over the same convolution kernel.

I explain this because my idea is to check and apply some modifications to the filters of a convolution, but only the first time the convolution is launched; that is what the check flag is for: to know whether the operator() function is being executed by the first iteration or not.

Therefore, in the following iterations, the filters would already be stored in GPU memory (this is related to my issue #1987). So, to know whether an iteration is the first one, I thought of adding an extra parameter to the filter tensor so it can be saved in GPU memory (this is the check variable). I decided to modify the TensorRef class because I believe both variables (the pointer to the filter parameters of the tensor and the check flag) should be members of the same class.

In addition, I plan to add another parameter to the TensorRef class, which will contain the modified values.

This is my goal. I'm sorry if my previous question didn't make this clear.

If I just pass a pointer through Params without doing cudaMalloc, I get a segmentation fault or a cudaEventSynchronize() failed error when accessing the variable in the operator() function, because it is a CPU variable being used in GPU code.

Now that you perhaps have a clearer and more concise idea of my doubt, could you give any advice on how I should proceed? If I need to reserve GPU memory, it is likely that I need to add something to the HostTensor and device_memory classes, but I have no clear idea of how CUTLASS manages all of this.

Thank you for your help.

@hwu36
Collaborator

hwu36 commented Feb 6, 2025

You can just create another HostTensor for check and pass it to the kernel as a parameter.

@IzanCatalan
Author

@hwu36 Yes, that seems like the easiest thing to do. However, I have a few questions:

  1. Would adding an extra parameter to the kernel mean changing the function definition, the arguments, or anything else, so that CUTLASS can compile and import the kernel?

In this case, apart from here, do I need to change any other declaration in the kernel hierarchy?

  2. What should I do about the check flag in this case?

  3. Suppose I still wanted to group both the flag and the new tensor with the pointer to the data (meaning the TensorRef class would now have three private attributes):

  /// Pointer
  Element* ptr_;
  Element* ptr_modified;
  int* check;

Even if this is a tricky thing to do, how should I proceed?
