Skip to content
This repository has been archived by the owner on Feb 4, 2024. It is now read-only.

Commit

Permalink
load inlien now works
Browse files Browse the repository at this point in the history
  • Loading branch information
msaroufim committed Jan 12, 2024
1 parent db23aa6 commit 154ea8a
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 13 deletions.
30 changes: 17 additions & 13 deletions load_inline.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Look at this test for inspiration
# https://github.com/pytorch/pytorch/blob/main/test/test_cpp_extensions_jit.py

import torch
from torch.utils.cpp_extension import load_inline

# Define the CUDA kernel and C++ wrapper
cuda_source = '''
#include <torch/extension.h>
__global__ void square_matrix_kernel(const float* matrix, float* result, int width, int height) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
Expand All @@ -29,23 +30,26 @@
matrix.data_ptr<float>(), result.data_ptr<float>(), width, height);
return result;
}
}
'''

cpp_source = "torch::Tensor square_matrix(torch::Tensor matrix);"

# Load the CUDA kernel as a PyTorch extension
square_matrix_extension = load_inline(
name='square_matrix_extension',
cpp_sources=cuda_source,
cpp_sources=cpp_source,
cuda_sources=cuda_source,
functions=['square_matrix'],
verbose=True,
extra_cuda_cflags=['--expt-relaxed-constexpr']
# verbose=True,
with_cuda=True,
build_directory='./load_inline_cuda',
# extra_cuda_cflags=['--expt-relaxed-constexpr']
)

# Create a sample PyTorch tensor
matrix = torch.randn(10, 10, dtype=torch.float32)

# Call the CUDA kernel through the loaded extension
result = square_matrix_extension.square_matrix(matrix)
a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], device='cuda')
print(square_matrix_extension.square_matrix(a))

print("Original Matrix:\n", matrix)
print("Squared Matrix:\n", result)
# (cudamode) ubuntu@ip-172-31-9-217:~/cudamode/cudamodelecture1$ python load_inline.py
# tensor([[ 1., 4., 9.],
# [16., 25., 36.]], device='cuda:0')
Binary file added load_inline_cuda/.ninja_deps
Binary file not shown.
4 changes: 4 additions & 0 deletions load_inline_cuda/.ninja_log
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# ninja log v5
0 14129 1705021139520500980 main.o 4279c12607ecc039
0 22396 1705021147784612815 cuda.cuda.o a37a24931cf259f6
22396 22710 1705021148100617092 square_matrix_extension.so ff0a0db8e20d3c15
34 changes: 34 additions & 0 deletions load_inline_cuda/build.ninja
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
ninja_required_version = 1.3
cxx = c++
nvcc = /usr/local/cuda/bin/nvcc

cflags = -DTORCH_EXTENSION_NAME=square_matrix_extension -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /opt/conda/envs/cudamode/lib/python3.10/site-packages/torch/include -isystem /opt/conda/envs/cudamode/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/envs/cudamode/lib/python3.10/site-packages/torch/include/TH -isystem /opt/conda/envs/cudamode/lib/python3.10/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /opt/conda/envs/cudamode/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17
post_cflags =
cuda_cflags = -DTORCH_EXTENSION_NAME=square_matrix_extension -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /opt/conda/envs/cudamode/lib/python3.10/site-packages/torch/include -isystem /opt/conda/envs/cudamode/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/envs/cudamode/lib/python3.10/site-packages/torch/include/TH -isystem /opt/conda/envs/cudamode/lib/python3.10/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /opt/conda/envs/cudamode/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 --compiler-options '-fPIC' -std=c++17
cuda_post_cflags =
cuda_dlink_post_cflags =
ldflags = -shared -L/opt/conda/envs/cudamode/lib/python3.10/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/local/cuda/lib64 -lcudart

rule compile
command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
depfile = $out.d
deps = gcc

rule cuda_compile
depfile = $out.d
deps = gcc
command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags



rule link
command = $cxx $in $ldflags -o $out

build main.o: compile /home/ubuntu/cudamode/cudamodelecture1/load_inline_cuda/main.cpp
build cuda.cuda.o: cuda_compile /home/ubuntu/cudamode/cudamodelecture1/load_inline_cuda/cuda.cu



build square_matrix_extension.so: link main.o cuda.cuda.o

default square_matrix_extension.so
29 changes: 29 additions & 0 deletions load_inline_cuda/cuda.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#include <torch/types.h>
#include <cuda.h>
#include <cuda_runtime.h>

__global__ void square_matrix_kernel(const float* matrix, float* result, int width, int height) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;

if (row < height && col < width) {
int idx = row * width + col;
result[idx] = matrix[idx] * matrix[idx];
}
}

torch::Tensor square_matrix(torch::Tensor matrix) {
const auto height = matrix.size(0);
const auto width = matrix.size(1);

auto result = torch::empty_like(matrix);

dim3 threads_per_block(16, 16);
dim3 number_of_blocks((width + threads_per_block.x - 1) / threads_per_block.x,
(height + threads_per_block.y - 1) / threads_per_block.y);

square_matrix_kernel<<<number_of_blocks, threads_per_block>>>(
matrix.data_ptr<float>(), result.data_ptr<float>(), width, height);

return result;
}
Binary file added load_inline_cuda/cuda.cuda.o
Binary file not shown.
5 changes: 5 additions & 0 deletions load_inline_cuda/main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include <torch/extension.h>
torch::Tensor square_matrix(torch::Tensor matrix);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("square_matrix", torch::wrap_pybind_function(square_matrix), "square_matrix");
}
Binary file added load_inline_cuda/main.o
Binary file not shown.

0 comments on commit 154ea8a

Please sign in to comment.