-
Notifications
You must be signed in to change notification settings - Fork 83
Open
Labels
Milestone
Description
The latest Helion generates code like this:
from __future__ import annotations
import torch
import helion
import triton
import triton.language as tl
from torch._inductor.runtime.triton_compat import libdevice
from helion.runtime import default_launcher as _default_launcher
# src[rms_norm.py:33]: def rms_norm_fwd(
# src[rms_norm.py:34]: x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5
# src[rms_norm.py:35]: ) -> tuple[torch.Tensor, torch.Tensor]:
# src[rms_norm.py:33-70]: ...
helion.runtime.set_triton_allocator()
@triton.jit
def _helion_rms_norm_fwd(x, inv_rms, weight, out, eps, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _REDUCTION_BLOCK_1: tl.constexpr):
    # Helion-generated Triton kernel for rms_norm_fwd (rms_norm.py:33-70).
    # For every block of _BLOCK_SIZE_0 rows it walks the columns twice in
    # chunks of _REDUCTION_BLOCK_1:
    #   pass 1 accumulates x * x, reduces it per row and computes
    #          rsqrt(sum / 1024 + eps), which is stored to inv_rms;
    #   pass 2 multiplies x by that per-row value and by weight, and stores
    #          the product to out.
    # The extents 2048 x 1024 and the strides [1024, 1] appear as literals, so
    # this instance is tied to that layout (presumably specialized for the
    # example input -- confirm against the caller).
    # src[rms_norm.py:58]: x_tile = x[tile_m, :].to(torch.float32)
    x_desc = tl.make_tensor_descriptor(x, [2048, 1024], [1024, 1], [_BLOCK_SIZE_0, _REDUCTION_BLOCK_1])
    # src[rms_norm.py:67]: out[tile_m, :] = (normalized * weight[:].to(torch.float32)).to(out.dtype)
    out_desc = tl.make_tensor_descriptor(out, [2048, 1024], [1024, 1], [_BLOCK_SIZE_0, _REDUCTION_BLOCK_1])
    # src[rms_norm.py:57]: for tile_m in hl.tile(m):
    # src[rms_norm.py:58]: x_tile = x[tile_m, :].to(torch.float32)
    # src[rms_norm.py:57-68]: ...
    # Number of row blocks needed to cover the 2048 rows.
    total_pids = tl.cdiv(2048, _BLOCK_SIZE_0)
    # Each program starts at its own program_id and advances by _NUM_SM * 2,
    # the same value the host wrapper uses as the launch grid size.
    for virtual_pid in tl.range(tl.program_id(0), total_pids, _NUM_SM * 2, loop_unroll_factor=3, disallow_acc_multi_buffer=True, flatten=True):
        # src[rms_norm.py:57]: for tile_m in hl.tile(m):
        pid_0 = virtual_pid
        offset_0 = pid_0 * _BLOCK_SIZE_0
        # Row indices handled by this iteration.
        indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
        # src[rms_norm.py:62]: mean_x_squared = torch.mean(x_squared, dim=-1)
        # float32 accumulator with one slot per (row, column-in-chunk); it is
        # reduced along axis 1 only after the loop below has finished.
        mean_x_squared_extra_acc = tl.full([_BLOCK_SIZE_0, _REDUCTION_BLOCK_1], 0, tl.float32)
        # src[rms_norm.py:58]: x_tile = x[tile_m, :].to(torch.float32)
        # Pass 1: accumulate squared values chunk by chunk.
        for roffset_1 in tl.range(0, 1024, _REDUCTION_BLOCK_1):
            # rindex_1 is not read in this loop; the descriptor load below
            # takes the block offsets directly.
            rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
            x_tile = x_desc.load([offset_0, roffset_1])
            # src[rms_norm.py:61]: x_squared = x_tile * x_tile
            v_0 = x_tile * x_tile
            # src[rms_norm.py:62]: mean_x_squared = torch.mean(x_squared, dim=-1)
            v_1 = mean_x_squared_extra_acc + v_0
            mean_x_squared_extra_acc = v_1
        # Collapse the accumulator to one sum per row, then divide by the
        # column count (1024) to obtain the mean.
        mean_x_squared_extra = tl.cast(tl.sum(mean_x_squared_extra_acc, 1), tl.float32)
        v_2 = 1024
        v_3 = mean_x_squared_extra / v_2.to(tl.float32)
        # src[rms_norm.py:63]: inv_rms_tile = torch.rsqrt(mean_x_squared + eps)
        v_4 = v_3 + eps
        v_5 = libdevice.rsqrt(v_4)
        # src[rms_norm.py:66]: normalized = x_tile * inv_rms_tile[:, None]
        # Add a trailing axis so the per-row factor broadcasts over columns.
        subscript = v_5[:, None]
        # src[rms_norm.py:68]: inv_rms[tile_m] = inv_rms_tile.to(out.dtype)
        # One value per row; None is passed as the mask argument.
        tl.store(inv_rms + indices_0 * 1, v_5, None)
        # src[rms_norm.py:58]: x_tile = x[tile_m, :].to(torch.float32)
        # Pass 2: re-read x (here through pointer arithmetic rather than
        # x_desc), scale it and write the result through out_desc.
        for roffset_1 in tl.range(0, 1024, _REDUCTION_BLOCK_1):
            rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
            subscript_copy = subscript
            x_tile_1 = tl.load(x + (indices_0[:, None] * 1024 + rindex_1[None, :] * 1), None)
            # src[rms_norm.py:66]: normalized = x_tile * inv_rms_tile[:, None]
            v_6 = x_tile_1 * subscript_copy
            # src[rms_norm.py:67]: out[tile_m, :] = (normalized * weight[:].to(torch.float32)).to(out.dtype)
            load_1 = tl.load(weight + rindex_1 * 1, None)
            # Add a leading axis so weight broadcasts over rows.
            v_7 = load_1[None, :]
            v_8 = v_6 * v_7
            out_desc.store([offset_0, roffset_1], v_8)
def rms_norm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
    """
    Performs Root Mean Square (RMS) normalization on the input tensor.

    RMS normalization normalizes by the root mean square of the elements:
    output = x / sqrt(mean(x^2) + eps) * weight

    Args:
        x: Input tensor of shape [M, N]
        weight: Scale parameter of shape [N]
        eps: Small constant for numerical stability
        _launcher: Callable used to launch the Triton kernel (keyword-only)

    Returns:
        Output tensor of shape [M, N] with RMS normalization applied
        Inverse-RMS tensor of shape [M, 1] holding
        1 / sqrt(mean(x^2) + eps) for each row
    """
    # src[rms_norm.py:51]: m, n = x.size()
    # m and n are only used on the host side (the assert and the inv_rms
    # allocation); they are not passed to the kernel, which has its own
    # literal extents.
    m, n = x.size()
    # src[rms_norm.py:52]: assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
    assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}'
    # src[rms_norm.py:54]: out = torch.empty_like(x)
    out = torch.empty_like(x)
    # src[rms_norm.py:55]: inv_rms = torch.empty([m], dtype=x.dtype, device=x.device)
    inv_rms = torch.empty([m], dtype=x.dtype, device=x.device)
    # src[rms_norm.py:57]: for tile_m in hl.tile(m):
    _NUM_SM = helion.runtime.get_num_sm(x.device)
    # Tile sizes forwarded to the kernel's constexpr parameters (presumably
    # the values picked by Helion's configuration/autotuning -- not shown here).
    _BLOCK_SIZE_0 = 64
    # src[rms_norm.py:58]: x_tile = x[tile_m, :].to(torch.float32)
    _REDUCTION_BLOCK_1 = 256
    # src[rms_norm.py:57]: for tile_m in hl.tile(m):
    # src[rms_norm.py:58]: x_tile = x[tile_m, :].to(torch.float32)
    # src[rms_norm.py:57-68]: ...
    # Launch on a 1-D grid of _NUM_SM * 2 programs.
    # NOTE: maxnreg=128 is the keyword argument that the report below says is
    # rejected on triton-xpu 3.6.0.
    _launcher(_helion_rms_norm_fwd, (_NUM_SM * 2,), x, inv_rms, weight, out, eps, _NUM_SM, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=16, num_stages=2, maxnreg=128)
    # src[rms_norm.py:70]: return out, inv_rms.reshape(-1, 1)
    return (out, inv_rms.reshape(-1, 1))

It will return KeyError: 'Keyword argument maxnreg was specified but unrecognised' on triton-xpu 3.6.0
Do we plan to support it? For now I will work around it the same way it is done for AMD: https://github.com/pytorch/helion/blob/8b7bc74256e63264e964b4971b2ce24464dff19b/helion/_compiler/device_function.py#L698