This repository has been archived by the owner on Feb 4, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
89 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# ninja log v5 | ||
0 14129 1705021139520500980 main.o 4279c12607ecc039 | ||
0 22396 1705021147784612815 cuda.cuda.o a37a24931cf259f6 | ||
22396 22710 1705021148100617092 square_matrix_extension.so ff0a0db8e20d3c15 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
ninja_required_version = 1.3 | ||
cxx = c++ | ||
nvcc = /usr/local/cuda/bin/nvcc | ||
|
||
cflags = -DTORCH_EXTENSION_NAME=square_matrix_extension -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /opt/conda/envs/cudamode/lib/python3.10/site-packages/torch/include -isystem /opt/conda/envs/cudamode/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/envs/cudamode/lib/python3.10/site-packages/torch/include/TH -isystem /opt/conda/envs/cudamode/lib/python3.10/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /opt/conda/envs/cudamode/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 | ||
post_cflags = | ||
cuda_cflags = -DTORCH_EXTENSION_NAME=square_matrix_extension -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /opt/conda/envs/cudamode/lib/python3.10/site-packages/torch/include -isystem /opt/conda/envs/cudamode/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/envs/cudamode/lib/python3.10/site-packages/torch/include/TH -isystem /opt/conda/envs/cudamode/lib/python3.10/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /opt/conda/envs/cudamode/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 --compiler-options '-fPIC' -std=c++17 | ||
cuda_post_cflags = | ||
cuda_dlink_post_cflags = | ||
ldflags = -shared -L/opt/conda/envs/cudamode/lib/python3.10/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/local/cuda/lib64 -lcudart | ||
|
||
rule compile | ||
command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags | ||
depfile = $out.d | ||
deps = gcc | ||
|
||
rule cuda_compile | ||
depfile = $out.d | ||
deps = gcc | ||
command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags | ||
|
||
|
||
|
||
rule link | ||
command = $cxx $in $ldflags -o $out | ||
|
||
build main.o: compile /home/ubuntu/cudamode/cudamodelecture1/load_inline_cuda/main.cpp | ||
build cuda.cuda.o: cuda_compile /home/ubuntu/cudamode/cudamodelecture1/load_inline_cuda/cuda.cu | ||
|
||
|
||
|
||
build square_matrix_extension.so: link main.o cuda.cuda.o | ||
|
||
default square_matrix_extension.so |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
#include <torch/types.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <limits>
|
||
__global__ void square_matrix_kernel(const float* matrix, float* result, int width, int height) { | ||
int row = blockIdx.y * blockDim.y + threadIdx.y; | ||
int col = blockIdx.x * blockDim.x + threadIdx.x; | ||
|
||
if (row < height && col < width) { | ||
int idx = row * width + col; | ||
result[idx] = matrix[idx] * matrix[idx]; | ||
} | ||
} | ||
|
||
torch::Tensor square_matrix(torch::Tensor matrix) { | ||
const auto height = matrix.size(0); | ||
const auto width = matrix.size(1); | ||
|
||
auto result = torch::empty_like(matrix); | ||
|
||
dim3 threads_per_block(16, 16); | ||
dim3 number_of_blocks((width + threads_per_block.x - 1) / threads_per_block.x, | ||
(height + threads_per_block.y - 1) / threads_per_block.y); | ||
|
||
square_matrix_kernel<<<number_of_blocks, threads_per_block>>>( | ||
matrix.data_ptr<float>(), result.data_ptr<float>(), width, height); | ||
|
||
return result; | ||
} |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#include <torch/extension.h> | ||
torch::Tensor square_matrix(torch::Tensor matrix); | ||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||
m.def("square_matrix", torch::wrap_pybind_function(square_matrix), "square_matrix"); | ||
} |
Binary file not shown.