
[helion][feature] support maxnreg #5912

@jianyizh

Description

The latest Helion generates code like this (note maxnreg=128 in the _launcher call near the bottom):

from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl
from torch._inductor.runtime.triton_compat import libdevice
from helion.runtime import default_launcher as _default_launcher

# src[rms_norm.py:33]: def rms_norm_fwd(
# src[rms_norm.py:34]:     x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5
# src[rms_norm.py:35]: ) -> tuple[torch.Tensor, torch.Tensor]:
# src[rms_norm.py:33-70]: ...
helion.runtime.set_triton_allocator()

@triton.jit
def _helion_rms_norm_fwd(x, inv_rms, weight, out, eps, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _REDUCTION_BLOCK_1: tl.constexpr):
    # src[rms_norm.py:58]: x_tile = x[tile_m, :].to(torch.float32)
    x_desc = tl.make_tensor_descriptor(x, [2048, 1024], [1024, 1], [_BLOCK_SIZE_0, _REDUCTION_BLOCK_1])
    # src[rms_norm.py:67]: out[tile_m, :] = (normalized * weight[:].to(torch.float32)).to(out.dtype)
    out_desc = tl.make_tensor_descriptor(out, [2048, 1024], [1024, 1], [_BLOCK_SIZE_0, _REDUCTION_BLOCK_1])
    # src[rms_norm.py:57]: for tile_m in hl.tile(m):
    # src[rms_norm.py:58]:     x_tile = x[tile_m, :].to(torch.float32)
    # src[rms_norm.py:57-68]: ...
    total_pids = tl.cdiv(2048, _BLOCK_SIZE_0)
    for virtual_pid in tl.range(tl.program_id(0), total_pids, _NUM_SM * 2, loop_unroll_factor=3, disallow_acc_multi_buffer=True, flatten=True):
        # src[rms_norm.py:57]: for tile_m in hl.tile(m):
        pid_0 = virtual_pid
        offset_0 = pid_0 * _BLOCK_SIZE_0
        indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
        # src[rms_norm.py:62]: mean_x_squared = torch.mean(x_squared, dim=-1)
        mean_x_squared_extra_acc = tl.full([_BLOCK_SIZE_0, _REDUCTION_BLOCK_1], 0, tl.float32)
        # src[rms_norm.py:58]: x_tile = x[tile_m, :].to(torch.float32)
        for roffset_1 in tl.range(0, 1024, _REDUCTION_BLOCK_1):
            rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
            x_tile = x_desc.load([offset_0, roffset_1])
            # src[rms_norm.py:61]: x_squared = x_tile * x_tile
            v_0 = x_tile * x_tile
            # src[rms_norm.py:62]: mean_x_squared = torch.mean(x_squared, dim=-1)
            v_1 = mean_x_squared_extra_acc + v_0
            mean_x_squared_extra_acc = v_1
        mean_x_squared_extra = tl.cast(tl.sum(mean_x_squared_extra_acc, 1), tl.float32)
        v_2 = 1024
        v_3 = mean_x_squared_extra / v_2.to(tl.float32)
        # src[rms_norm.py:63]: inv_rms_tile = torch.rsqrt(mean_x_squared + eps)
        v_4 = v_3 + eps
        v_5 = libdevice.rsqrt(v_4)
        # src[rms_norm.py:66]: normalized = x_tile * inv_rms_tile[:, None]
        subscript = v_5[:, None]
        # src[rms_norm.py:68]: inv_rms[tile_m] = inv_rms_tile.to(out.dtype)
        tl.store(inv_rms + indices_0 * 1, v_5, None)
        # src[rms_norm.py:58]: x_tile = x[tile_m, :].to(torch.float32)
        for roffset_1 in tl.range(0, 1024, _REDUCTION_BLOCK_1):
            rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
            subscript_copy = subscript
            x_tile_1 = tl.load(x + (indices_0[:, None] * 1024 + rindex_1[None, :] * 1), None)
            # src[rms_norm.py:66]: normalized = x_tile * inv_rms_tile[:, None]
            v_6 = x_tile_1 * subscript_copy
            # src[rms_norm.py:67]: out[tile_m, :] = (normalized * weight[:].to(torch.float32)).to(out.dtype)
            load_1 = tl.load(weight + rindex_1 * 1, None)
            v_7 = load_1[None, :]
            v_8 = v_6 * v_7
            out_desc.store([offset_0, roffset_1], v_8)

def rms_norm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
    """
    Performs Root Mean Square (RMS) normalization on the input tensor.

    RMS normalization normalizes by the root mean square of the elements:
    output = x / sqrt(mean(x^2) + eps) * weight

    Args:
        x: Input tensor of shape [M, N]
        weight: Scale parameter of shape [N]
        eps: Small constant for numerical stability

    Returns:
        Output tensor of shape [M, N] with RMS normalization applied
        RMS tensor of shape [M, 1] with RMS values for each element
    """
    # src[rms_norm.py:51]: m, n = x.size()
    m, n = x.size()
    # src[rms_norm.py:52]: assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
    assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}'
    # src[rms_norm.py:54]: out = torch.empty_like(x)
    out = torch.empty_like(x)
    # src[rms_norm.py:55]: inv_rms = torch.empty([m], dtype=x.dtype, device=x.device)
    inv_rms = torch.empty([m], dtype=x.dtype, device=x.device)
    # src[rms_norm.py:57]: for tile_m in hl.tile(m):
    _NUM_SM = helion.runtime.get_num_sm(x.device)
    _BLOCK_SIZE_0 = 64
    # src[rms_norm.py:58]: x_tile = x[tile_m, :].to(torch.float32)
    _REDUCTION_BLOCK_1 = 256
    # src[rms_norm.py:57]: for tile_m in hl.tile(m):
    # src[rms_norm.py:58]:     x_tile = x[tile_m, :].to(torch.float32)
    # src[rms_norm.py:57-68]: ...
    _launcher(_helion_rms_norm_fwd, (_NUM_SM * 2,), x, inv_rms, weight, out, eps, _NUM_SM, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=16, num_stages=2, maxnreg=128)
    # src[rms_norm.py:70]: return out, inv_rms.reshape(-1, 1)
    return (out, inv_rms.reshape(-1, 1))

Launching this on triton-xpu 3.6.0 fails with KeyError: 'Keyword argument maxnreg was specified but unrecognised'.
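For context, the failure comes from Triton's launch path rejecting keyword arguments that the target backend's options do not declare. A minimal sketch of the failure mode, independent of Helion (the kernel, sizes, and device here are illustrative):

import torch
import triton
import triton.language as tl

@triton.jit
def _copy(dst, src, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst + offs, tl.load(src + offs, mask=mask), mask=mask)

x = torch.randn(1024, device="xpu")
y = torch.empty_like(x)
# maxnreg is a CUDA-only launch option; the triton-xpu 3.6.0 backend
# does not declare it, so the launch raises at call time:
#   KeyError: 'Keyword argument maxnreg was specified but unrecognised'
_copy[(triton.cdiv(1024, 256),)](y, x, 1024, BLOCK=256, maxnreg=128)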

Do we plan to support maxnreg on XPU? For now I will work around it the same way as the existing AMD handling: https://github.com/pytorch/helion/blob/8b7bc74256e63264e964b4971b2ce24464dff19b/helion/_compiler/device_function.py#L698
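A minimal sketch of what that workaround looks like, assuming the decision happens where the launcher kwargs are built; the helper names and the capability check are illustrative, not Helion's actual API:

import torch

def _backend_supports_maxnreg(device: torch.device) -> bool:
    # maxnreg maps to a CUDA-only launch option; ROCm (the existing
    # AMD workaround) and triton-xpu 3.6.0 both reject it.
    return device.type == "cuda" and torch.version.hip is None

def build_launch_kwargs(device: torch.device, maxnreg: int | None = None) -> dict:
    kwargs = {"num_warps": 16, "num_stages": 2}
    if maxnreg is not None and _backend_supports_maxnreg(device):
        kwargs["maxnreg"] = maxnreg  # only forwarded on backends that declare it
    return kwargs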
