[QST] Global variable inside conv2d kernel #1987
I expect that a better strategy would be to add a new variable to the Arguments struct in that file. This allows you to pass the value of this parameter in on the host. You'll also need to add something similar to the corresponding Params struct so that it is available on the device side. If you simply want a compile-time constant, you should be able to just use a
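As a rough picture of what that first suggestion could look like (a sketch only; the member names below are made up for illustration and are not the actual Conv2dFprop definitions):

```cpp
// Hypothetical sketch (not actual CUTLASS code): a new device pointer added to
// the host-side Arguments and mirrored into the device-side Params so the
// kernel can read it.
struct Arguments {
  void const *ptr_filters = nullptr;       // stands in for the existing operand pointers
  float const *multiply_tensor = nullptr;  // new: user-supplied pointer to device memory
};

struct Params {
  void const *ptr_filters = nullptr;
  float const *multiply_tensor = nullptr;

  Params() = default;
  explicit Params(Arguments const &args)
      : ptr_filters(args.ptr_filters),
        multiply_tensor(args.multiply_tensor) {}
};
```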
@jackkosaian, I prefer the first approach: modifying the Arguments struct and the Params struct and adding a new tensor, given what I want to achieve. This way, I could assign the tensor and save it in GPU memory. Theoretically, on the next call to a convolution it should already be saved, as I understand it, right? Another matter: if I need to perform an operation across all convolutional filters, for example, it would be easier to do so with GPU threads. Would it be possible, for example, for each thread to multiply a specific filter by an element of the new tensor saved in GPU memory? And am I right in thinking that the best place in the code to do so is the operator() function?
I don't fully understand what you're thinking of here, unfortunately.
Yes, this is possible.
The place that you'd want to perform this transformation is once operands have been loaded from shared memory into registers (e.g., here).
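For illustration only, an element-wise transform applied at that point could be factored like this (a sketch; the fragment type and the scaling operation are assumptions, not code from the linked location):

```cpp
#include <cutlass/cutlass.h>
#include <cutlass/array.h>

// Sketch: apply a per-element transform to a fragment after it has been loaded
// from shared memory into registers (types and the operation are illustrative).
template <typename Element, int kElements>
CUTLASS_DEVICE void scale_fragment(cutlass::Array<Element, kElements> &frag,
                                   Element scale) {
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < kElements; ++i) {
    frag[i] = frag[i] * scale;
  }
}
```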
I am thinking that if I have two or more calls in a program that execute a convolution operation (for example, if I am reproducing a neural network structure like ResNet-50 or VGG-16, which have multiple convolutions, with different inputs and iterations over the structure), then the first iteration through each convolution would transform its filters using the new "multiply tensor" argument and save them in GPU shared memory, and the following iterations would find the filters already modified and would not need to do anything else. Do you think it could be done?
What is this code doing exactly? It seems to select parts from frag A and B, but I am unfamiliar with the pipe_state variable. According to what I explained in the previous answer, it seems more natural to me to transform the filters before splitting the matrices.
Thanks for the additional details.
Any modifications done in shared memory will not persist through multiple kernel calls (e.g., for different layers). If you're updating the values in global memory, then what you're describing should be possible. That said, have you considered just multiplying each input for the different "iterations" rather than trying to perform the multiplication once and saving the results back out for later reuse? Multiplication should be much cheaper in this case than storing results back to global memory.
@jackkosaian, sorry for the delay in answering you.
My concern is saving execution time when running all the convolution layers of a model. It is not the input that I intend to modify, but only the filters of the convolution, which are always the same and do not change. That is why I only want to perform the transformation once and store the result in global memory. A use case would be a neural network model like ResNet-50, with 50 convolutions, and, for example, performing inference on this model with a dataset of 5000 images. Each image is an iteration across the 50 convolutions, so that means 5000 iterations (with larger batch sizes there would be fewer iterations). If I multiplied and modified the filters of each convolution (which are always the same for that convolution) before the convolution in that layer every time it is executed, it would add a huge amount of extra time. I believe it is better to do the first iteration for each of those 50 convolutions, store the modified filters, and reuse them for the following 4999 iterations. I hope this clears up my intentions. Could you share some feedback about it?
Thanks for the additional details. I think I understand your intention a bit better. I agree that it makes sense to modify the filters once and reuse them across the remaining iterations. However, I don't know that it will be most beneficial to do this by fusing the multiplication within the convolution kernel itself. Doing so will require a fair amount of modification to the target CUTLASS convolution kernel, and it's unclear whether it would be any more performant than simply running a separate kernel to modify the filters before running all iterations. Before endeavoring to modify the CUTLASS kernel, have you benchmarked running a separate kernel to modify the filters before running the iterations? Is the performance unsatisfactory? That approach will be much easier to implement.
@jackkosaian thanks for the feedback.
I want to modify the least amount of source code possible; I assume it has already been optimized, and my changes could undo that. However, I agree that the best possible solution is to fuse the multiplication (it would be more like performing bit operations: OR and AND) into the filters. That is what I asked about in the previous comments: performing this in the operator() function, because with Params I assume you can manipulate the filters and save them before the convolution kernel runs. The rest of the code would therefore stay the same.
Yes, I tried a similar approach using example 59: I modified the ampere_conv_kernel.h file, in the operator() function, and the TFLOPS performance was almost identical. I decided to modify the filters there to take advantage of the threads. Here is the code. My idea is to replicate this, but in the convolution kernel instead of in an example:
The approach that you prototyped is slightly different than what I had in mind. I was suggesting that you write one kernel that does:

```cpp
for (int i = start_idx; i < end_idx; ++i) {
  uint32_t original_bits = reinterpret_cast<uint32_t&>(gA(i));
  original_bits = (original_bits & and_mask) | or_mask;
  gA(i) = reinterpret_cast<tfloat32_t&>(original_bits);
}
```

And then just call into the CUTLASS kernel for however many iterations you'd like:

```cpp
manipulate_filters<<<nBlocks, nThreads>>>(args);
conv_op.run();
conv_op.run();
conv_op.run();
...
```

This would require no modification of CUTLASS source.
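Filled out as a standalone kernel, that filter-manipulation step might look roughly like this (a sketch assuming the filters are tfloat32_t stored contiguously in global memory; the kernel name simply mirrors the launch above):

```cpp
#include <cstdint>
#include <cutlass/tfloat32.h>

// Sketch: grid-stride loop applying the AND/OR masks to every filter element in
// global memory once, before the convolution iterations are launched.
__global__ void manipulate_filters(cutlass::tfloat32_t *filters,
                                   int64_t num_elements,
                                   uint32_t and_mask,
                                   uint32_t or_mask) {
  int64_t stride = int64_t(gridDim.x) * blockDim.x;
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elements; i += stride) {
    uint32_t bits = reinterpret_cast<uint32_t &>(filters[i]);
    bits = (bits & and_mask) | or_mask;
    filters[i] = reinterpret_cast<cutlass::tfloat32_t &>(bits);
  }
}
```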
@jackkosaian I get your point, but I think that in the operator() function you have access to all threads, and the functionality you are describing would be basically the same:
I say this because my final goal is to call this convolution kernel from an outside program. I'm actually planning to export it to Python, so it'll be easier for an external user to just call the convolution, and the kernel will perform both the manipulation of the filters and the convolution.
I don't think I could achieve this with that structure. I'd need to run two kernels: one to manipulate the filters and another to run the convolutions. And, as I said, I use it from Python, so I would only call the convolution, as in this example.
Ok. I personally think it might still be easier to have two kernels, but if you want to fuse this into a CUTLASS kernel, you'll want to perform the scaling at the location I mention here: #1987 (comment)
Yes @jackkosaian, it is imperative for me to merge the two and make the call transparent to the end user, who will only need to call the convolution kernel with one extra parameter (i.e., the masks). So I am thinking of changing the definition of the CUTLASS convolution. Is there any particular part of the repository where the kernel is declared, so I can change it? And what is the hierarchy of files I would need to change?
Yes, you commented about accessing global memory, but I have some doubts: how can I access global memory (and which variables), and in what part of the code?
@jackkosaian, could you give me a hand with this? As a first step, I intend to print the filters' tensor in the operator() function on the first call to a convolution kernel. My code is the following, but apparently neither the filters (all elements are printed as 0s) nor the bool flag is OK:
Moreover, could you tell me whether I am accessing the filters pointer (ptr_b) in the right way, and how you would implement a bool flag so that this code only executes the first time a convolution kernel runs? (This is a preliminary step: in the future, instead of only printing values, I will modify the filters as I commented in previous comments.)
Why not instead just have a flag that you pass in as part of the
I'm not sure what error you're getting in accessing |
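A rough illustration of the flag idea (purely a sketch; the member and function names are assumptions, not CUTLASS API):

```cpp
// Hypothetical: a host-side flag carried in Arguments so the kernel only
// transforms the filters on the first launch (names are illustrative).
struct Arguments {
  bool transform_filters = false;  // true only for the first convolution call
  // ... operand pointers, problem size, etc. ...
};

// Host-side usage sketch:
//   Arguments args;
//   args.transform_filters = true;   // first call transforms and writes back the filters
//   run_convolution(args);
//   args.transform_filters = false;  // subsequent calls reuse the transformed filters
//   run_convolution(args);
```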
@jackkosaian Thank you for the info. I will try to implement what you suggested, and I will let you know about it.
However, I am more concerned about this: I find no information about how to use the Iterators or how to access the elements of the tensor inside ptr_B. Besides that, I find there is a bigger issue. I have seen that IteratorB serves as a parameter to Mma, but I don't know exactly what it does or where its source code is. Actually, the extra code I want to add is very simple, as I have shown you in other posts: a few lines, and for them I only need a part of the code that is accessible to the GPU threads (like the
Yes. Regarding what the iterators do, I would suggest looking through the source code for the iterator being used for the B operand. You can determine what datatype is used for its elements, and you can then look for the definition of this class. It will likely be in one of the headers under include/cutlass/conv/threadblock/. You'll see that these iterators define how each thread iterates over the filter tensor.
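For example, a deliberately failing static_assert is one way to make the compiler spell out the instantiated types (a sketch, not the exact code used in this thread):

```cpp
#include <type_traits>

// Sketch: a static_assert that fails for any type it is instantiated with, so the
// compiler error message names the concrete IteratorB (and its Element) in use.
template <typename T>
struct always_false : std::false_type {};

template <typename IteratorB>
struct ReportIteratorType {
  static_assert(always_false<IteratorB>::value,
                "read the compiler error to see which IteratorB was instantiated");
};

// Possible usage inside the kernel (Mma::IteratorB is an assumption about the
// surrounding template parameters):
//   ReportIteratorType<typename Mma::IteratorB> report;
```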
@jackkosaian I have implemented what you suggested in order to find out what data type is used for IteratorB. Of course, when testing example 16 I already know the data type, so I can check against it, but the point is finding it out from the device. These are the parameters from the host CPU:
However, a lot of errors are printed, and I don't understand exactly what is happening. The iterator class for the device is Conv2dFpropFilterTileAccessIteratorOptimized, but there are also errors with Conv2dFpropFilterTileAccessIteratorAnalytic. The same thing applies to the data type: cutlass::half_t seems to be the one, but there are the following errors with multiple data types:
This is the complete error output from
I have checked whether the following code works, but it also gives me some errors:
These are the errors. Again, it seems like cutlass::half_t appears a lot, but data types like int or float also appear, so I don't quite understand how it works:
Sorry for the confusion. The sole point of introducing the static_assert was to determine which iterator type was being used. Now that you've determined this, you can remove the static_assert from your code. It looks like you'll want to familiarize yourself with how the iterator traverses memory here: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h
Yes @jackkosaian, I understood your point, but I am not sure that I have determined anything. I mean, I know what kind of operator and data type I am using because I assign them directly from the host CPU. But that is useless, because from inside the device code in https://github.com/NVIDIA/cutlass/blob/24f991e87930e1159f1f5a47e329d43bcfbd76b9/include/cutlass/conv/kernel/implicit_gemm_convolution.h I don't know what data type or element type I am using. All I can deduce is that the first and most repeated error mentions cutlass::half_t, so that is probably the data type, but it could vary. So, again, I don't know how to do a cast with a generic data type (i.e., with reinterpret_cast<>?) so that I can manipulate elements one by one. I have also checked https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h; there are methods like
I was thinking of something like this (pseudocode):
But again, I cannot access the values if they are const. So I am very lost on how to do something simple: accessing the tensor and changing its values. The Params in https://github.com/NVIDIA/cutlass/blob/24f991e87930e1159f1f5a47e329d43bcfbd76b9/include/cutlass/conv/kernel/implicit_gemm_convolution.h save ptr_B in the form of tensor.data(), a pointer to all the data, but accessing that tensor as
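For what it's worth, one possible shape for that generic cast is sketched below, under the assumption (based on the Params definition mentioned above) that ptr_B is the raw filter pointer and Mma::IteratorB::Element names its element type; neither is verified API here:

```cpp
// Sketch: reinterpret the const filter pointer stored in Params as mutable
// storage of the kernel's element type. Whether casting away const is safe
// depends on how the tensor was allocated; that is an assumption, not a
// guarantee from the library.
template <typename Element>
__device__ Element *mutable_filter_view(Element const *ptr_b) {
  return const_cast<Element *>(ptr_b);
}

// Possible usage inside operator() (names assumed):
//   using ElementB = typename Mma::IteratorB::Element;
//   ElementB *filters = mutable_filter_view<ElementB>(params.ptr_B);
//   // filters[i] could then be read and rewritten, e.g. with the AND/OR masks.
```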
@jackkosaian I posted a new issue with updates about the Iterator. In case you are able to help me, this is the post: #2067 |
What is your question?
Hello, good day. I am currently researching the Conv2dFprop kernel, as I intend to modify its implementation in the library, specifically in the file https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/conv/kernel/implicit_gemm_convolution.h. This is because, whether called from a C++ or Python program on the host, this .h file is the last in the execution hierarchy and directly implements the convolution operation to be run on the GPU (in my case, NVIDIA A100 and V100).

My question is: can a global variable be implemented within this class? I intend to assign a specific number of elements to this global variable, called multiply_tensor, which would then multiply the convolution parameters on that call and on the following calls to the class while the host code is still running. I aim for this variable to be stored in GPU memory, initialized and processed on the GPU during the first call, and reused in subsequent ones. I am unsure whether a global variable is the solution or whether a new kernel parameter would be better.
Is this feasible?
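For reference, the "global variable" alternative mentioned above would typically be a __device__ symbol that the host sets once before the first call (a sketch; the symbol and function names are illustrative):

```cpp
#include <cuda_runtime.h>

// Sketch: a device-global pointer set once from the host; every subsequent
// kernel launch can read it without receiving it as an argument.
__device__ float *g_multiply_tensor = nullptr;

void set_multiply_tensor(float *device_ptr) {
  // Copy the pointer value into the device symbol (done once, before the first call).
  cudaMemcpyToSymbol(g_multiply_tensor, &device_ptr, sizeof(device_ptr));
}
```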