Commit 691e037

Author: Zhu Jiale (committed)

complex number multiplication that supports 3D RoPE Triton kernel
1 parent c7245de commit 691e037

File tree

12 files changed: +942 -0 lines changed


aiter/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -74,3 +74,4 @@ def getLogger():
 from .ops.gradlib import *
 from .ops.trans_ragged_layout import *
 from . import mla
+from .ops.groupnorm import *

aiter/install_mode

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+develop

aiter/jit/optCompilerConfig.json

Lines changed: 12 additions & 0 deletions
@@ -913,5 +913,17 @@
         ],
         "verbose": "False",
         "blob_gen_cmd": "''"
+    },
+    "module_groupnorm": {
+        "srcs": [
+            "f'{AITER_CSRC_DIR}/pybind/groupnorm_pybind.cu'",
+            "f'{AITER_CSRC_DIR}/kernels/groupnorm.cu'"
+        ],
+        "flags_extra_cc": [],
+        "flags_extra_hip": [],
+        "extra_ldflags": "None",
+        "extra_include": [],
+        "verbose": "True",
+        "blob_gen_cmd": "''"
     }
 }

aiter/ops/groupnorm.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@ (new file)
from ..jit.core import compile_ops
import torch
from typing import Optional


@compile_ops("module_groupnorm")
def _groupnorm_run(
    input: torch.Tensor,
    num_groups: int,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
) -> torch.Tensor:
    """Placeholder function, will be replaced by JIT."""
    pass


class GroupNorm(torch.nn.Module):
    def __init__(
        self,
        num_groups: int,
        num_channels: int,
        eps: float = 1e-5,
        affine: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps
        self.affine = affine

        if affine:
            self.weight = torch.nn.Parameter(
                torch.ones(num_channels, device=device, dtype=dtype)
            )
            self.bias = torch.nn.Parameter(
                torch.zeros(num_channels, device=device, dtype=dtype)
            )
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor, use_torch: bool = False) -> torch.Tensor:
        if use_torch or not self.affine:
            # fallback to PyTorch for non-affine or debug mode
            return torch.nn.functional.group_norm(
                x,
                self.num_groups,
                weight=self.weight if self.affine else None,
                bias=self.bias if self.affine else None,
                eps=self.eps,
            )
        else:
            return _groupnorm_run(x, self.num_groups, self.weight, self.bias, self.eps)
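A minimal usage sketch (not part of the commit), assuming the JIT extension builds on first call on a ROCm device; the (N, C, H, W) float32 input layout and the comparison tolerance are assumptions, not taken from the kernel sources:

# Illustrative only: shapes, dtype, device, and tolerance are assumed.
import torch
from aiter.ops.groupnorm import GroupNorm

x = torch.randn(2, 32, 64, 64, device="cuda", dtype=torch.float32)  # assumed (N, C, H, W) layout
gn = GroupNorm(num_groups=8, num_channels=32, device="cuda", dtype=torch.float32)

y_hip = gn(x)                  # dispatches to the JIT-compiled HIP kernel via _groupnorm_run
y_ref = gn(x, use_torch=True)  # PyTorch fallback path in the same module
print(torch.allclose(y_hip, y_ref, atol=1e-4))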

aiter/ops/triton/_triton_kernels/rope.py

Lines changed: 81 additions & 0 deletions
@@ -1934,3 +1934,84 @@ def _rope_fwd_2d_kernel_neox(
 
     # store output
     tl.store(out_ptr + offs_x, out)
+
+@triton.jit
+def _rope_fwd_3d_kernel(
+    x_ptr, freqs_real_ptr, freqs_imag_ptr, grid_sizes_ptr, out_ptr,
+    stride_x_b, stride_x_l, stride_x_n, stride_x_c,
+    stride_freqs_s, stride_freqs_c,
+    stride_grid_b, stride_grid_d,
+    stride_out_b, stride_out_l, stride_out_n, stride_out_c,
+    L: tl.constexpr, N_HEADS: tl.constexpr, C: tl.constexpr, c_total: tl.constexpr,
+    sp_size: tl.constexpr, sp_rank: tl.constexpr,
+    max_freq_seq_len: tl.constexpr, s_per_rank: tl.constexpr,
+    pad_freq_val_r: tl.constexpr, pad_freq_val_i: tl.constexpr,
+    BLOCK_L: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_C: tl.constexpr,
+    C1: tl.constexpr, C2: tl.constexpr,
+):
+    pid_b = tl.program_id(0)
+    pid_n = tl.program_id(1)
+    pid_l = tl.program_id(2)
+
+    l_start = pid_l * BLOCK_L
+    l_off = l_start + tl.arange(0, BLOCK_L)
+    s_mask = l_off < L
+
+    c_off = tl.arange(0, BLOCK_C)
+    c_mask = c_off < c_total
+
+    # head mask
+    n_mask = pid_n < N_HEADS
+
+    # broadcast to (BLOCK_L, BLOCK_C)
+    l_b = tl.broadcast_to(l_off[:, None], (BLOCK_L, BLOCK_C))
+    c_b = tl.broadcast_to(c_off[None, :], (BLOCK_L, BLOCK_C))
+
+    # read grid_sizes (frames, height, width) for this batch element
+    f_grid = tl.load(grid_sizes_ptr + pid_b * stride_grid_b + 0 * stride_grid_d,
+                     mask=n_mask, other=0)
+    h_grid = tl.load(grid_sizes_ptr + pid_b * stride_grid_b + 1 * stride_grid_d,
+                     mask=n_mask, other=0)
+    w_grid = tl.load(grid_sizes_ptr + pid_b * stride_grid_b + 2 * stride_grid_d,
+                     mask=n_mask, other=0)
+    h_w = h_grid * w_grid
+
+    global_tid = sp_rank * s_per_rank + l_b
+    valid_global_tid = global_tid < f_grid * h_w
+
+    # calculate the f, h, w indices from the flattened token position
+    f_idx = tl.where(valid_global_tid, global_tid // h_w, 0)
+    rem = tl.where(valid_global_tid, global_tid % h_w, 0)
+    h_idx = tl.where(valid_global_tid, rem // w_grid, 0)
+    w_idx = tl.where(valid_global_tid, rem % w_grid, 0)
+
+    freq_row = tl.where(c_b < C1, f_idx,
+                        tl.where(c_b < C1 + C2, h_idx, w_idx))
+    freq_row = tl.where(freq_row >= max_freq_seq_len, max_freq_seq_len - 1, freq_row)
+
+    mask_rope = s_mask[:, None] & c_mask[None, :] & n_mask & valid_global_tid[:, :]
+
+    # load freqs_real and freqs_imag
+    off_freq = freq_row * stride_freqs_s + c_b * stride_freqs_c
+    freq_r = tl.load(freqs_real_ptr + off_freq, mask=mask_rope, other=pad_freq_val_r)
+    freq_i = tl.load(freqs_imag_ptr + off_freq, mask=mask_rope, other=pad_freq_val_i)
+
+    off_x_base = pid_b * stride_x_b + pid_n * stride_x_n
+    off_x_r = off_x_base + l_b * stride_x_l + (2 * c_b) * stride_x_c
+    off_x_i = off_x_base + l_b * stride_x_l + (2 * c_b + 1) * stride_x_c
+
+    x_r = tl.load(x_ptr + off_x_r, mask=mask_rope, other=0.0)
+    x_i = tl.load(x_ptr + off_x_i, mask=mask_rope, other=0.0)
+
+    # complex number multiplication
+    out_r = x_r * freq_r - x_i * freq_i
+    out_i = x_r * freq_i + x_i * freq_r
+
+    # write result
+    off_out_base = pid_b * stride_out_b + pid_n * stride_out_n
+    off_out_r = off_out_base + l_b * stride_out_l + (2 * c_b) * stride_out_c
+    off_out_i = off_out_base + l_b * stride_out_l + (2 * c_b + 1) * stride_out_c
+
+    tl.store(out_ptr + off_out_r, out_r, mask=mask_rope)
+    tl.store(out_ptr + off_out_i, out_i, mask=mask_rope)
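The out_r / out_i lines above are ordinary complex multiplication of each interleaved (real, imaginary) channel pair by its precomputed frequency factor,

(x_r + i x_i)(f_r + i f_i) = (x_r f_r - x_i f_i) + i (x_r f_i + x_i f_r),

which is the same rotation the reference path applies as x_i * freqs_i_rank in rope3d.py below.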

aiter/ops/triton/rope3d.py

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@ (new file)
import torch
import triton
import triton.language as tl
from aiter.ops.triton._triton_kernels.rope import _rope_fwd_3d_kernel


def rope_params(max_seq_len, dim, theta=10000):
    assert dim % 2 == 0
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta,
                        torch.arange(0, dim, 2).to(torch.float32).div(dim))
    )
    freqs = torch.polar(torch.ones_like(freqs), freqs)  # complex
    return freqs


def pad_freqs(original_tensor, target_len):
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    padding_tensor = torch.ones(
        pad_size, s1, s2, dtype=original_tensor.dtype, device=original_tensor.device)
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
    return padded_tensor


def rope_apply_triton(x, grid_sizes: tl.constexpr, freqs: tl.constexpr, sp_size: tl.constexpr, sp_rank: tl.constexpr):
    B, s, n_heads, C = x.shape
    c_total = C // 2  # 64
    c1 = c_total - 2 * (c_total // 3)  # 22
    c2 = c_total // 3  # 21
    c3 = c_total // 3  # 21
    device = x.device

    grid_sizes = grid_sizes.to(device=device, dtype=torch.int32).contiguous()

    freqs_real = freqs.real.to(dtype=torch.float32, device=device).contiguous()
    freqs_imag = freqs.imag.to(dtype=torch.float32, device=device).contiguous()
    out = torch.empty_like(x, dtype=torch.float32, device=device)

    BLOCK_L, BLOCK_N, BLOCK_C = 32, 4, 64

    grid = (
        B,
        n_heads,
        triton.cdiv(s, BLOCK_L)
    )

    num_warps = 4
    waves_per_eu = 1

    _rope_fwd_3d_kernel[grid](
        x, freqs_real, freqs_imag, grid_sizes, out,
        *x.stride(),
        freqs_real.stride(0), freqs_real.stride(1),
        *grid_sizes.stride(),
        *out.stride(),
        s, n_heads, C, c_total,
        sp_size, sp_rank,
        freqs.shape[0], s,
        1.0, 0.0,
        BLOCK_L=BLOCK_L, BLOCK_N=BLOCK_N, BLOCK_C=BLOCK_C,
        C1=c1, C2=c2,
        num_warps=num_warps,
        waves_per_eu=waves_per_eu,
    )

    return out


def rope_apply_original(x, grid_sizes, freqs, sp_size, sp_rank):
    B = x.size(0)
    s = x.size(1)
    n = x.size(2)
    c = x.size(3) // 2

    c1 = c - 2 * (c // 3)
    c2 = (c // 3)
    c3 = (c // 3)
    freqs = freqs.split([c1, c2, c3], dim=1)

    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2))

        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ], dim=-1).reshape(seq_len, 1, -1)
        merged_real_sum = freqs_i.real.sum()
        freqs_i = pad_freqs(freqs_i, s * sp_size)
        s_per_rank = s
        freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]

        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
        x_i = torch.cat([x_i, x[i, s:]])
        output.append(x_i)

    out = torch.stack(output).float()
    return out


def test_rope_consistency():
    B, s, n, C = 1, 9450, 40, 128
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sp_size = 8
    max_seq_len = 1024

    x = torch.arange(B * s * n * C, dtype=torch.float32, device=device).reshape(B, s, n, C)
    x = x / (B * s * n * C)

    grid_sizes = torch.tensor([[21, 45, 80]], dtype=torch.int32, device=device)

    d_total = 128
    d1 = d_total - 4 * (d_total // 6)
    d2 = 2 * (d_total // 6)
    d3 = 2 * (d_total // 6)

    freqs_f = rope_params(max_seq_len, d1)
    freqs_h = rope_params(max_seq_len, d2)
    freqs_w = rope_params(max_seq_len, d3)
    freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=1).to(device)

    sp_rank = 0
    out_orig = rope_apply_original(x.clone(), grid_sizes.clone(), freqs.clone(), sp_size, sp_rank)
    out_triton = rope_apply_triton(x.clone(), grid_sizes.clone(), freqs.clone(), sp_size, sp_rank)

    print(f"result comparison: sp_rank={sp_rank}")
    print("=" * 50)
    shape_ok = (out_orig.shape == out_triton.shape)
    sum_orig = out_orig.sum().item()
    sum_triton = out_triton.sum().item()
    sum_diff = abs(sum_orig - sum_triton) / abs(sum_orig)
    sum_ok = sum_diff < 1e-2
    feat_orig = out_orig[0, 0, 0, :4]
    feat_triton = out_triton[0, 0, 0, :4]
    feat_diff = torch.abs(feat_orig - feat_triton).max().item()
    feat_ok = feat_diff < 1e-3

    print(f"shapes match: {'yes' if shape_ok else 'no'}")
    print(f"sum diff < 1%: {'yes' if sum_ok else 'no'}")
    print(f" - Original sum: {sum_orig:.6f}")
    print(f" - Triton sum: {sum_triton:.6f}")
    print(f" - relative sum diff: {sum_diff * 100:.2f}%")
    print(f"first 4 values match: {'yes' if feat_ok else 'no'}")
    print(f" - Original: {feat_orig.cpu().numpy()}")
    print(f" - Triton: {feat_triton.cpu().numpy()}")
    print(f" - max diff: {feat_diff:.6f}")

    if shape_ok and sum_ok and feat_ok:
        print(f"\nsp_rank={sp_rank} test passed")
    else:
        print(f"\nsp_rank={sp_rank} test failed")
    print("=" * 60)


if __name__ == "__main__":
    test_rope_consistency()
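A quick way to exercise the consistency check above (a sketch; it assumes aiter and Triton are installed and that PyTorch sees a GPU, since the Triton path needs one):

# Assumes the aiter package is importable and a GPU device is available to PyTorch.
from aiter.ops.triton.rope3d import test_rope_consistency

test_rope_consistency()  # prints the shape / sum / first-4-values comparison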

csrc/include/common.hpp

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@ (new file)
#include <hip/hip_runtime.h>

#include <cstdint>
#include <iostream>
#include <exception>

#define CHECK_COND(x)                                   \
    do {                                                \
        if (!(x)) {                                     \
            std::cerr << "check failed, file="          \
                      << __FILE__ << ", line="          \
                      << __LINE__ << std::endl;         \
            std::terminate();                           \
        }                                               \
    } while(false)

#define CHECK_HIP(x)                                    \
    do {                                                \
        hipError_t __err_code = (x);                    \
        if( __err_code != hipSuccess ) {                \
            std::cerr << "call hip api failed, file="   \
                      << __FILE__ << ", line="          \
                      << __LINE__ << ", name="          \
                      << hipGetErrorName(__err_code)    \
                      << std::endl;                     \
            std::terminate();                           \
        }                                               \
    } while(false)

csrc/include/groupnorm.hpp

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@ (new file)
#include <torch/extension.h>

#include "common.hpp"

#include <optional>

namespace rocm_torch_x {

class __attribute__ ((visibility("hidden"))) GroupNorm final
{
public:
    explicit GroupNorm() = default;
    ~GroupNorm() = default;
public:
    // return empty if not supported
    std::optional<torch::Tensor> Run(
        torch::Tensor x,
        int num_groups,
        torch::Tensor weights,
        torch::Tensor bias,
        float epsilon);
private:
    template<typename T>
    torch::Tensor launchGroupNormKernel(uint32_t num_groups, float epsilon,
        const torch::Tensor x, const torch::Tensor weights, const torch::Tensor bias, hipStream_t stream);

    void reserveMeanAccumulator(uint32_t nums_to_reserve, torch::Device device);
private:
    torch::Tensor mean_accumulator_;
};

} // namespace rocm_torch_x
