import torch
import triton
from aiter.ops.triton._triton_kernels.rope import _rope_fwd_3d_kernel
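
# Consistency test for 3D rotary position embeddings (RoPE): the aiter Triton
# kernel (_rope_fwd_3d_kernel) is checked against a pure-PyTorch reference
# under sequence parallelism, where each of sp_size ranks rotates only its own
# s-token slice of the full f*h*w token sequence.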

def rope_params(max_seq_len, dim, theta=10000):
    """Precompute complex RoPE rotations of shape (max_seq_len, dim // 2)."""
    assert dim % 2 == 0
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta,
                        torch.arange(0, dim, 2).to(torch.float32).div(dim))
    )
    freqs = torch.polar(torch.ones_like(freqs), freqs)  # unit-magnitude complex rotations
    return freqs
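
# A minimal sketch of the closed form rope_params implements: for position t
# and channel pair k, freqs[t, k] = exp(i * t * theta**(-2k / dim)). The helper
# below is hypothetical (not called by the test); it only illustrates that identity.
def _check_rope_params_closed_form(t=3, k=1, dim=8, theta=10000):
    f = rope_params(t + 1, dim, theta)
    angle = torch.tensor(t * theta ** (-2 * k / dim), dtype=torch.float32)
    assert torch.allclose(f[t, k], torch.polar(torch.ones(()), angle))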

def pad_freqs(original_tensor, target_len):
    """Pad along dim 0 with ones so padded positions get identity rotations."""
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    padding_tensor = torch.ones(
        pad_size, s1, s2, dtype=original_tensor.dtype, device=original_tensor.device)
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
    return padded_tensor
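
# For complex freqs the pad value 1 is 1 + 0j, the identity rotation, so padded
# positions pass through unchanged. A quick illustration with hypothetical values:
#     p = pad_freqs(torch.polar(torch.ones(2, 1, 3), torch.rand(2, 1, 3)), 5)
#     assert torch.all(p[2:] == 1)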

def rope_apply_triton(x, grid_sizes, freqs, sp_size, sp_rank):
    """Apply 3D RoPE to x with the aiter Triton kernel.

    x: (B, s, n_heads, C) activations held by this sequence-parallel rank.
    grid_sizes: (B, 3) tensor of (frames, height, width) per sample.
    freqs: complex (max_seq_len, C // 2) rotations from rope_params.
    """
    B, s, n_heads, C = x.shape
    c_total = C // 2                   # complex pairs per head (64 for C = 128)
    c1 = c_total - 2 * (c_total // 3)  # temporal pairs (22)
    c2 = c_total // 3                  # height pairs (21)
    c3 = c_total // 3                  # width pairs (21); implied by c_total - c1 - c2, so only C1 and C2 are passed below
    device = x.device

    grid_sizes = grid_sizes.to(device=device, dtype=torch.int32).contiguous()

    # The kernel takes the real and imaginary parts as separate float32 tensors.
    freqs_real = freqs.real.to(dtype=torch.float32, device=device).contiguous()
    freqs_imag = freqs.imag.to(dtype=torch.float32, device=device).contiguous()
    out = torch.empty_like(x, dtype=torch.float32, device=device)

    BLOCK_L, BLOCK_N, BLOCK_C = 32, 4, 64

    # One program per (batch, head, sequence-block) tile.
    grid = (
        B,
        n_heads,
        triton.cdiv(s, BLOCK_L)
    )

    num_warps = 4
    waves_per_eu = 1

    _rope_fwd_3d_kernel[grid](
        x, freqs_real, freqs_imag, grid_sizes, out,
        *x.stride(),
        freqs_real.stride(0), freqs_real.stride(1),
        *grid_sizes.stride(),
        *out.stride(),
        s, n_heads, C, c_total,
        sp_size, sp_rank,
        freqs.shape[0], s,
        1.0, 0.0,
        BLOCK_L=BLOCK_L, BLOCK_N=BLOCK_N, BLOCK_C=BLOCK_C,
        C1=c1, C2=c2,
        num_warps=num_warps,
        waves_per_eu=waves_per_eu,
    )

    return out
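
# Worked channel split for the shapes used in the test (C = 128): c_total = 64
# complex pairs, c1 = 64 - 2 * (64 // 3) = 22 temporal, c2 = c3 = 64 // 3 = 21
# for height and width; exactly half of the real dims d1 = 44, d2 = d3 = 42
# built in test_rope_consistency below.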

def rope_apply_original(x, grid_sizes, freqs, sp_size, sp_rank):
    """Reference PyTorch implementation of sequence-parallel 3D RoPE."""
    B = x.size(0)
    s = x.size(1)
    n = x.size(2)
    c = x.size(3) // 2

    # Split the rotation channels across the temporal, height, and width axes.
    c1 = c - 2 * (c // 3)
    c2 = c // 3
    c3 = c // 3
    freqs = freqs.split([c1, c2, c3], dim=1)

    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2))

        # Broadcast the per-axis rotations over the full (f, h, w) grid.
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ], dim=-1).reshape(seq_len, 1, -1)

        # Pad to the global sequence length, then keep this rank's slice.
        freqs_i = pad_freqs(freqs_i, s * sp_size)
        s_per_rank = s
        freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]

        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
        x_i = torch.cat([x_i, x[i, s:]])
        output.append(x_i)

    out = torch.stack(output).float()
    return out
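
# In both implementations the full sequence has f * h * w tokens split evenly
# across sp_size ranks; rank r applies the rotations freqs_i[r * s : (r + 1) * s].
# With the test shapes, 21 * 45 * 80 = 75600 = 9450 * 8, so rank 0 rotates
# tokens 0..9449.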

def test_rope_consistency():
    B, s, n, C = 1, 9450, 40, 128
    # The Triton path needs a GPU; only the reference path runs on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sp_size = 8
    max_seq_len = 1024

    # Deterministic input in [0, 1) so runs are reproducible.
    x = torch.arange(B * s * n * C, dtype=torch.float32, device=device).reshape(B, s, n, C)
    x = x / (B * s * n * C)

    # One sample with a (frames, height, width) grid of 21 x 45 x 80 = 75600
    # tokens, i.e. s * sp_size.
    grid_sizes = torch.tensor([[21, 45, 80]], dtype=torch.int32, device=device)

    d_total = 128
    d1 = d_total - 4 * (d_total // 6)  # 44 temporal dims
    d2 = 2 * (d_total // 6)            # 42 height dims
    d3 = 2 * (d_total // 6)            # 42 width dims

    freqs_f = rope_params(max_seq_len, d1)
    freqs_h = rope_params(max_seq_len, d2)
    freqs_w = rope_params(max_seq_len, d3)
    freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=1).to(device)

    sp_rank = 0
    out_orig = rope_apply_original(x.clone(), grid_sizes.clone(), freqs.clone(), sp_size, sp_rank)
    out_triton = rope_apply_triton(x.clone(), grid_sizes.clone(), freqs.clone(), sp_size, sp_rank)

    print(f"Result comparison: sp_rank={sp_rank}")
    print("=" * 60)
    shape_ok = (out_orig.shape == out_triton.shape)
    sum_orig = out_orig.sum().item()
    sum_triton = out_triton.sum().item()
    sum_diff = abs(sum_orig - sum_triton) / abs(sum_orig)
    sum_ok = sum_diff < 1e-2
    feat_orig = out_orig[0, 0, 0, :4]
    feat_triton = out_triton[0, 0, 0, :4]
    feat_diff = torch.abs(feat_orig - feat_triton).max().item()
    feat_ok = feat_diff < 1e-3

    print(f"shapes match: {'yes' if shape_ok else 'no'}")
    print(f"sums match (relative diff < 1%): {'yes' if sum_ok else 'no'}")
    print(f"  - Original sum: {sum_orig:.6f}")
    print(f"  - Triton sum:   {sum_triton:.6f}")
    print(f"  - relative diff: {sum_diff * 100:.2f}%")
    print(f"first 4 elements match: {'yes' if feat_ok else 'no'}")
    print(f"  - Original: {feat_orig.cpu().numpy()}")
    print(f"  - Triton:   {feat_triton.cpu().numpy()}")
    print(f"  - max diff: {feat_diff:.6f}")

    if shape_ok and sum_ok and feat_ok:
        print(f"\nsp_rank={sp_rank} test passed")
    else:
        print(f"\nsp_rank={sp_rank} test failed")
    print("=" * 60)


if __name__ == "__main__":
    test_rope_consistency()