Skip to content

Commit 2fccbdb

Browse files
committed
more tests and reformatting
1 parent 311e086 commit 2fccbdb

File tree

3 files changed

+272
-237
lines changed

3 files changed

+272
-237
lines changed
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
# SPDX-License-Identifier: MIT
2+
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3+
4+
import triton.language as tl
5+
from triton.experimental import gluon
6+
from triton.experimental.gluon import language as gl
7+
from triton.experimental.gluon.language.amd.cdna4 import async_copy as acp
8+
9+
10+
@gluon.jit
def _issue_loads(
    copy_idx,
    cols_smem,
    row_start_ptr,
    n_cols,
    layout: gl.constexpr,
    BLOCK_SIZE: gl.constexpr,
    NUM_STAGES: gl.constexpr,
    USE_ASYNC_COPY: gl.constexpr = True,
):
    """Issue the global->shared copy of one BLOCK_SIZE-wide tile of the row.

    Tile `copy_idx` of the row at `row_start_ptr` is copied into the
    shared-memory pipeline slot `copy_idx % NUM_STAGES`. Out-of-bounds
    lanes (beyond `n_cols`) are filled with -inf so they are neutral for
    the downstream max/sum reductions.

    With USE_ASYNC_COPY the copy is asynchronous and is committed as one
    group; the caller is responsible for pairing it with acp.wait_group()
    before reading the slot. Otherwise a synchronous load+store is used.

    Returns the incremented copy index (copy_idx + 1).
    """
    col_offsets = copy_idx * BLOCK_SIZE + gl.arange(0, BLOCK_SIZE, layout=layout)
    mask = col_offsets < n_cols

    if USE_ASYNC_COPY:
        # NOTE(review): acp.buffer_load_to_shared(ptr, offsets, ...) was a
        # previously tried alternative here; global_load_to_shared with a
        # precomputed pointer is the retained variant.
        acp.global_load_to_shared(
            cols_smem.index(copy_idx % NUM_STAGES),
            row_start_ptr + col_offsets,
            mask=mask,
            other=-float("inf"),
            cache_modifier=".cg",
        )
        # One commit group per tile so wait_group(k) can target it precisely.
        acp.commit_group()
    else:
        # Synchronous fallback: load to registers, then store to shared memory.
        cols_smem.index(copy_idx % NUM_STAGES).store(
            gl.amd.cdna4.buffer_load(
                ptr=row_start_ptr,
                offsets=col_offsets,
                mask=mask,
                other=-float("inf"),
                cache=".cg",
            )
        )
    return copy_idx + 1
52+
53+
54+
@gluon.jit
def _perform_loop1(
    m, row_sum, read_idx, cols_smem, layout: gl.constexpr, NUM_STAGES: gl.constexpr
):
    """One accumulation step of the online-softmax reduction.

    Loads the tile in pipeline slot `read_idx % NUM_STAGES` and folds it
    into the running statistics: the running maximum `m` and the running
    sum of exponentials `row_sum` (rescaled whenever the max grows).

    Returns the updated (m, row_sum, read_idx + 1).
    """
    tile = cols_smem.index(read_idx % NUM_STAGES).load(layout)

    # New running max: this tile's max combined with the previous running max.
    new_max = gl.maximum(m, gl.max(tile, axis=0))

    # Online-softmax update: rescale the old sum to the new max, then add
    # this tile's exponentials (both relative to new_max).
    row_sum = row_sum * gl.exp(m - new_max) + gl.sum(gl.exp(tile - new_max), axis=0)

    return new_max, row_sum, read_idx + 1
77+
78+
79+
@gluon.jit
def _perform_loop2(
    m,
    row_sum,
    read_idx,
    cols_smem,
    output_row_start_ptr,
    n_cols,
    output_dtype,
    layout: gl.constexpr,
    BLOCK_SIZE: gl.constexpr,
    NUM_STAGES: gl.constexpr,
):
    """Normalize one tile with the final (m, row_sum) and store the result.

    Reads the tile in pipeline slot `read_idx % NUM_STAGES`, computes
    exp(x - m) / row_sum, casts to `output_dtype`, and writes it back to
    the output row with out-of-bounds lanes masked off.

    Returns the incremented read index (read_idx + 1).
    """
    offsets = read_idx * BLOCK_SIZE + gl.arange(0, BLOCK_SIZE, layout=layout)
    in_bounds = offsets < n_cols
    tile = cols_smem.index(read_idx % NUM_STAGES).load(layout)

    # softmax(x) = exp(x - max) / sum(exp(x - max)), cast to the output dtype.
    result = (gl.exp(tile - m) / row_sum).to(output_dtype)

    gl.amd.cdna4.buffer_store(
        stored_value=result,
        ptr=output_row_start_ptr,
        offsets=offsets,
        mask=in_bounds,
        cache=".cg",
    )

    return read_idx + 1
111+
112+
113+
@gluon.jit
def _softmax_kernel_online(
    output_ptr,
    input_ptr,
    input_row_stride,
    output_row_stride,
    n_rows,
    n_cols,
    SIZE_PER_THREAD: gl.constexpr,
    THREADS_PER_WARP: gl.constexpr,
    BLOCK_SIZE: gl.constexpr,
    NUM_STAGES: gl.constexpr,
    USE_ASYNC_COPY: gl.constexpr,
):
    """Two-pass online softmax over one row, software-pipelined.

    Each program handles the row given by program_id(0). Pass 1 streams the
    row through a NUM_STAGES-deep shared-memory pipeline to build the row
    max and the normalizing sum (online-softmax recurrence); pass 2 streams
    the row again to write exp(x - max) / sum. Each pass follows the usual
    pipeline shape: prefill (NUM_STAGES - 1 copies in flight), steady state
    (issue one copy, wait, consume one tile), then drain the remaining
    in-flight tiles.
    """
    row_idx = gl.program_id(0)

    # 1-D register layout for a BLOCK_SIZE tile, and the matching
    # (non-swizzled) shared-memory layout for the pipeline slots.
    blocked_cols: gl.constexpr = gl.BlockedLayout(
        size_per_thread=[SIZE_PER_THREAD],
        threads_per_warp=[THREADS_PER_WARP],
        warps_per_cta=[gl.num_warps()],
        order=[0],
    )
    shared_cols: gl.constexpr = gl.SwizzledSharedLayout(
        vec=1, per_phase=1, max_phase=1, order=[0]
    )
    cols_smem = gl.allocate_shared_memory(
        input_ptr.type.element_ty, [NUM_STAGES, BLOCK_SIZE], layout=shared_cols
    )

    copy_idx = 0
    read_idx = 0
    num_tiles = gl.cdiv(n_cols, BLOCK_SIZE)

    # ---- pass 1: running max and normalizing sum of the row ----
    running_max = -float("inf")
    row_sum = 0.0
    row_start_ptr = input_ptr + row_idx * input_row_stride

    # Prefill: put NUM_STAGES - 1 copies in flight before consuming anything.
    for _ in gl.static_range(NUM_STAGES - 1):
        copy_idx = _issue_loads(
            copy_idx,
            cols_smem,
            row_start_ptr,
            n_cols,
            blocked_cols,
            BLOCK_SIZE,
            NUM_STAGES,
            USE_ASYNC_COPY,
        )

    # Steady state: overlap the next copy with consumption of the oldest tile.
    for _ in range(num_tiles - (NUM_STAGES - 1)):
        copy_idx = _issue_loads(
            copy_idx,
            cols_smem,
            row_start_ptr,
            n_cols,
            blocked_cols,
            BLOCK_SIZE,
            NUM_STAGES,
            USE_ASYNC_COPY,
        )

        # Block until the oldest in-flight copy has landed in shared memory.
        acp.wait_group(NUM_STAGES - 1)
        running_max, row_sum, read_idx = _perform_loop1(
            running_max, row_sum, read_idx, cols_smem, blocked_cols, NUM_STAGES
        )

    # Drain: consume the tiles still in flight.
    for i in gl.static_range(NUM_STAGES - 1):
        acp.wait_group(NUM_STAGES - 2 - i)
        running_max, row_sum, read_idx = _perform_loop1(
            running_max, row_sum, read_idx, cols_smem, blocked_cols, NUM_STAGES
        )

    # ---- pass 2: normalize each tile by the final stats and store ----
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    copy_idx = 0
    read_idx = 0

    # Prefill the pipeline again for the second sweep over the row.
    for _ in gl.static_range(NUM_STAGES - 1):
        copy_idx = _issue_loads(
            copy_idx,
            cols_smem,
            row_start_ptr,
            n_cols,
            blocked_cols,
            BLOCK_SIZE,
            NUM_STAGES,
            USE_ASYNC_COPY,
        )

    # Steady state: overlap the next copy with normalize+store of the oldest.
    for _ in range(num_tiles - (NUM_STAGES - 1)):
        copy_idx = _issue_loads(
            copy_idx,
            cols_smem,
            row_start_ptr,
            n_cols,
            blocked_cols,
            BLOCK_SIZE,
            NUM_STAGES,
            USE_ASYNC_COPY,
        )

        # Block until the oldest in-flight copy has landed in shared memory.
        acp.wait_group(NUM_STAGES - 1)
        read_idx = _perform_loop2(
            running_max,
            row_sum,
            read_idx,
            cols_smem,
            output_row_start_ptr,
            n_cols,
            output_ptr.type.element_ty,
            blocked_cols,
            BLOCK_SIZE,
            NUM_STAGES,
        )

    # Drain: normalize and store the tiles still in flight.
    for i in gl.static_range(NUM_STAGES - 1):
        acp.wait_group(NUM_STAGES - 2 - i)
        read_idx = _perform_loop2(
            running_max,
            row_sum,
            read_idx,
            cols_smem,
            output_row_start_ptr,
            n_cols,
            output_ptr.type.element_ty,
            blocked_cols,
            BLOCK_SIZE,
            NUM_STAGES,
        )

0 commit comments

Comments
 (0)