Skip to content

Commit

Permalink
Compile specific kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Dec 18, 2023
1 parent 4513749 commit 61cf4e6
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 131 deletions.
63 changes: 63 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
{
"files.associations": {
"*.yml": "yaml",
"*.ke": "Kestrel",
"array": "cpp",
"atomic": "cpp",
"bit": "cpp",
"*.tcc": "cpp",
"bitset": "cpp",
"cctype": "cpp",
"chrono": "cpp",
"clocale": "cpp",
"cmath": "cpp",
"compare": "cpp",
"concepts": "cpp",
"cstdarg": "cpp",
"cstddef": "cpp",
"cstdint": "cpp",
"cstdio": "cpp",
"cstdlib": "cpp",
"cstring": "cpp",
"ctime": "cpp",
"cwchar": "cpp",
"cwctype": "cpp",
"deque": "cpp",
"unordered_map": "cpp",
"vector": "cpp",
"exception": "cpp",
"algorithm": "cpp",
"functional": "cpp",
"iterator": "cpp",
"memory": "cpp",
"memory_resource": "cpp",
"numeric": "cpp",
"optional": "cpp",
"random": "cpp",
"ratio": "cpp",
"string": "cpp",
"string_view": "cpp",
"system_error": "cpp",
"tuple": "cpp",
"type_traits": "cpp",
"utility": "cpp",
"initializer_list": "cpp",
"iosfwd": "cpp",
"istream": "cpp",
"limits": "cpp",
"mutex": "cpp",
"new": "cpp",
"numbers": "cpp",
"ostream": "cpp",
"ranges": "cpp",
"stdexcept": "cpp",
"stop_token": "cpp",
"streambuf": "cpp",
"thread": "cpp",
"typeinfo": "cpp",
"__nullptr": "cpp",
"__bit_reference": "cpp",
"__functional_base": "cpp",
"__memory": "cpp"
}
}
7 changes: 0 additions & 7 deletions csrc/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,3 @@ void reshape_and_cache(
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping);

void gather_cached_kv(
torch::Tensor& key,
torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping);
37 changes: 0 additions & 37 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,47 +29,10 @@ void paged_attention_v2(
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes);

void rms_norm(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& weight,
float epsilon);

void fused_add_rms_norm(
torch::Tensor& input,
torch::Tensor& residual,
torch::Tensor& weight,
float epsilon);

void rotary_embedding(
torch::Tensor& positions,
torch::Tensor& query,
torch::Tensor& key,
int head_size,
torch::Tensor& cos_sin_cache,
bool is_neox);

void silu_and_mul(
torch::Tensor& out,
torch::Tensor& input);

void gelu_new(
torch::Tensor& out,
torch::Tensor& input);

void gelu_fast(
torch::Tensor& out,
torch::Tensor& input);

torch::Tensor awq_gemm(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters);

void squeezellm_gemm(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table);
80 changes: 0 additions & 80 deletions csrc/pybind.cpp

This file was deleted.

3 changes: 3 additions & 0 deletions csrc/rustbind.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
17 changes: 10 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,16 @@ def get_torch_arch_list() -> Set[str]:
"csrc/cache_kernels.cu",
"csrc/attention/attention_kernels.cu",
"csrc/pos_encoding_kernels.cu",
"csrc/activation_kernels.cu",
"csrc/layernorm_kernels.cu",
"csrc/quantization/awq/gemm_kernels.cu",
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
"csrc/cuda_utils_kernels.cu",
"csrc/ops.h",
"csrc/cache.h",
#"csrc/activation_kernels.cu",
#"csrc/layernorm_kernels.cu",
#"csrc/quantization/awq/gemm_kernels.cu",
#"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
#"csrc/cuda_utils_kernels.cu",

#"csrc/ops.h",
#"csrc/cache.h",

"csrc/rustbind.cpp"
],
extra_compile_args={
"cxx": CXX_FLAGS,
Expand Down

0 comments on commit 61cf4e6

Please sign in to comment.