diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index 3e547f8869..ec5b2a9c61 100644 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -15,8 +15,12 @@ from wave_lang.kernel.wave.compile import wave_compile from wave_lang.kernel.wave.utils.run_utils import set_default_run_config -from wave_lang.kernel.wave.templates import get_tagged_mxfp4_gemm -from wave_lang.kernel.wave.schedules import get_mxfp4_dbuf_schedule +from wave_lang.kernel.wave.templates import get_tagged_mxfp4_gemm, get_preshuffle_kernel +from wave_lang.kernel.wave.schedules import ( + get_mxfp4_dbuf_schedule, + get_mxfp4_dbuf_schedule_shuffle, +) +from wave_lang.kernel.wave.schedules import get_mxfp4_triplebuf_schedule from wave_lang.kernel.wave.utils.mxfp_utils import ( generate_gemm_afp4wfp4_inputs, torchScaledGemmMXFP4, @@ -25,16 +29,69 @@ from utils import parse_args, list_tests, run_test -def _run_mxfp_gemm(gemm, shape): - """Run compiled GEMM kernel and verify against reference.""" +def e8m0_shuffle(scale): + """ + Shuffle the scale tensor for e8m0 format. + + This particular shuffle is taken from + https://github.com/ROCm/rocm-libraries/blob/4348901528fe100a84975b89c247eece553a2a2d/shared/mxdatagenerator/lib/include/mxDataGenerator/PreSwizzle.hpp#L403 + + The e8m0_shuffle operation transforms a matrix with shape (m, n) as follows: + 1. Pads to shape ((m+255)//256*256, (n+7)//8*8) + 2. Reshapes to (sm//32, 2, 16, sn//8, 2, 4) + 3. Permutes dimensions: (0, 3, 5, 2, 4, 1) + 4. 
Flattens back to (sm, sn) + + Args: + scale: A 2D tensor to be shuffled + + Returns: + Shuffled tensor with the same padded shape + """ + if scale is None: + return scale + if scale.dtype == torch.float32: + return scale + assert scale.ndim == 2, "scale must be a 2D tensor" + m, n = scale.shape + scale_padded = torch.zeros( + (m + 255) // 256 * 256, + (n + 7) // 8 * 8, + dtype=scale.dtype, + device=scale.device, + ) + + scale_padded[:m, :n] = scale + scale = scale_padded + sm, sn = scale.shape + scale = scale.view(sm // 32, 2, 16, sn // 8, 2, 4) + scale = scale.permute(0, 3, 5, 2, 4, 1).contiguous() + scale = scale.view(sm, sn) + return scale + + +def _run_mxfp_gemm(gemm, shape, shuffle_scales=False): + """Run compiled GEMM kernel and verify against reference. + + Args: + gemm: Compiled GEMM kernel function. + shape: (M, N, K) problem dimensions. + shuffle_scales: If True, shuffle the scale tensors using e8m0_shuffle. + """ x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape) torch_out = torchScaledGemmMXFP4(x, w, x_scales, w_scales) + if shuffle_scales: + # x_scales = e8m0_shuffle(x_scales) + w_scales = e8m0_shuffle(w_scales) + x, w = x.cuda(), w.cuda() x_scales, w_scales = x_scales.cuda(), w_scales.cuda() out = torch.zeros(x.shape[0], w.shape[1], dtype=torch.float32).cuda() - gemm(x, x_scales, w.T.contiguous(), w_scales, out) + for i in range(100): + gemm(x, x_scales, w.T.contiguous(), w_scales, out) + torch.testing.assert_close( torch_out, out.cpu(), check_dtype=False, check_device=False ) @@ -45,6 +102,7 @@ def test_dbuf_4wave_mxfp_gemm( ): """Double-buffered MXFP4 GEMM, 4 waves, no stagger.""" gemm, options = get_tagged_mxfp4_gemm(shape, block, num_waves=4) + schedule = get_mxfp4_dbuf_schedule(use_stagger=False) options.print_ir_after = "all" if is_debug else [] @@ -58,20 +116,7764 @@ def test_dbuf_4wave_mxfp_gemm( def test_dbuf_8wave_mxfp_gemm( - is_debug=False, shape=(1024, 1024, 8192), block=(256, 256, 256) + is_debug=False, shape=(4096, 57344, 
16384), block=(256, 256, 256) ): """Double-buffered MXFP4 GEMM, 8 waves, with stagger.""" + + mlir = """ + #map = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8) floordiv 256) * 256)> + #map1 = affine_map<()[s0] -> ((s0 floordiv 8) mod 8)> + #map2 = affine_map<()[s0] -> (s0 mod 8)> + #map3 = affine_map<()[s0] -> (s0 * 16)> + #map4 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64) floordiv 32) * 256)> + #map5 = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 64) floordiv 256) * 256 + 64)> + #map6 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64 + 8) floordiv 32) * 256 + 64)> + #map7 = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 128) floordiv 256) * 256 + 128)> + #map8 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64 + 16) floordiv 32) * 256 + 128)> + #map9 = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 192) floordiv 256) * 256 + 192)> + #map10 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64 + 24) floordiv 32) * 256 + 192)> + #map11 = affine_map<()[s0, s1, s2] -> (s1 * 128 + s2 * 256 + s0 floordiv 2 - ((s1 * 128 + s0 floordiv 2) floordiv 256) * 256)> + #map12 = affine_map<()[s0] -> ((s0 floordiv 2) mod 2)> + #map13 = affine_map<()[s0] -> (s0 mod 2)> + #map14 = affine_map<()[s0] -> (s0 * 4)> + #map15 = affine_map<()[s0, s1] -> (s1 * 128 + (s0 floordiv 64) * 32 - ((s1 * 4 + s0 floordiv 64) floordiv 8) * 256)> + #map16 = affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 64)> + #map17 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64)> + #map18 = affine_map<()[s0] -> ((s0 mod 64) floordiv 16)> + #map19 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64 + 16)> + #map20 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 
64) * 64 + 32)> + #map21 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64 + 48)> + #map22 = affine_map<()[s0] -> (s0 * 4 + (s0 mod 64) floordiv 16 - (s0 floordiv 2) * 8)> + #map23 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16)> + #map24 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 16)> + #map25 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 32)> + #map26 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 48)> + #map27 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 64)> + #map28 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 80)> + #map29 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 96)> + #map30 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 112)> + #map31 = affine_map<()[s0] -> ((s0 mod 64) floordiv 16 + 4)> + #map32 = affine_map<()[s0, s1] -> (s1 * 4 + (s0 mod 64) floordiv 16)> + #map33 = affine_map<()[s0, s1] -> (s0 * 128 + s1 * 16 + 128)> + #map34 = affine_map<()[s0, s1] -> (s0 * 8 + s1 * 4 + 8)> + #map35 = affine_map<()[s0] -> (s0 * 256)> + #map36 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4)> + #map37 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 1)> + #map38 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 2)> + #map39 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 3)> + #map40 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 16)> + #map41 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 17)> + #map42 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 18)> + #map43 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 19)> + #map44 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 32)> + #map45 = 
affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 33)> + #map46 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 34)> + #map47 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 35)> + #map48 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 48)> + #map49 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 49)> + #map50 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 50)> + #map51 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 51)> + #translation = #iree_codegen.translation_info + module attributes {transform.with_named_sequence} { + stream.executable private @gemm { + stream.executable.export public @gemm workgroups() -> (index, index, index) { + %c64 = arith.constant 64 : index + %c1 = arith.constant 1 : index + stream.return %c64, %c64, %c1 : index, index, index + } + builtin.module { + func.func @gemm(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: !stream.binding, %arg3: !stream.binding, %arg4: !stream.binding) attributes {translation_info = #translation} { + %c4_i32 = arith.constant 4 : i32 + %c512_i14 = arith.constant 512 : i14 + %c-8192_i14 = arith.constant -8192 : i14 + %c2147483643_i64 = arith.constant 2147483643 : i64 + %c16384 = arith.constant 16384 : index + %c63 = arith.constant 63 : index + %c512 = arith.constant 512 : index + %c2147483646_i64 = arith.constant 2147483646 : i64 + %c8192 = arith.constant 8192 : index + %c1 = arith.constant 1 : index + %cst = arith.constant dense<0.000000e+00> : vector<4xf32> + %c0 = arith.constant 0 : index + %0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> memref + %1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> memref + %2 = stream.binding.subspan %arg2[%c0] : !stream.binding -> memref + %3 = stream.binding.subspan %arg3[%c0] : !stream.binding -> memref + %4 = 
stream.binding.subspan %arg4[%c0] : !stream.binding -> memref + %block_id_x = gpu.block_id x upper_bound 64 + %block_id_y = gpu.block_id y upper_bound 64 + %thread_id_x = gpu.thread_id x upper_bound 256 + %thread_id_y = gpu.thread_id y upper_bound 2 + %alloc = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_0 = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_1 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %alloc_2 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %alloc_3 = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_4 = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_5 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %alloc_6 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %5 = affine.apply #map()[%thread_id_x, %thread_id_y, %block_id_x] + %6 = affine.apply #map1()[%thread_id_x] + %7 = affine.apply #map2()[%thread_id_x] + %8 = arith.xori %7, %6 : index + %9 = affine.apply #map3()[%8] + %10 = affine.apply #map4()[%thread_id_x, %thread_id_y] + %11 = gpu.subgroup_broadcast %10, first_active_lane : index + %12 = gpu.subgroup_broadcast %c0, first_active_lane : index + %13 = arith.muli %5, %c8192 overflow : index + %14 = arith.addi %13, %9 overflow : index + %reinterpret_cast = memref.reinterpret_cast %0 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast = memref.cast %reinterpret_cast : memref<2147483646xi8, strided<[1]>> to memref> + %15 = amdgpu.fat_raw_buffer_cast %cast validBytes(%c2147483646_i64) cacheSwizzleStride(%c-8192_i14) resetOffset : memref> to memref> + amdgpu.gather_to_lds %15[%14], %alloc_6[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %16 = affine.apply #map5()[%thread_id_x, %thread_id_y, %block_id_x] + %17 = affine.apply #map6()[%thread_id_x, %thread_id_y] + %18 = gpu.subgroup_broadcast %17, first_active_lane : index + %19 = arith.muli %16, %c8192 overflow : 
index + %20 = arith.addi %19, %9 overflow : index + amdgpu.gather_to_lds %15[%20], %alloc_6[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %21 = affine.apply #map7()[%thread_id_x, %thread_id_y, %block_id_x] + %22 = affine.apply #map8()[%thread_id_x, %thread_id_y] + %23 = gpu.subgroup_broadcast %22, first_active_lane : index + %24 = arith.muli %21, %c8192 overflow : index + %25 = arith.addi %24, %9 overflow : index + amdgpu.gather_to_lds %15[%25], %alloc_6[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %26 = affine.apply #map9()[%thread_id_x, %thread_id_y, %block_id_x] + %27 = affine.apply #map10()[%thread_id_x, %thread_id_y] + %28 = gpu.subgroup_broadcast %27, first_active_lane : index + %29 = arith.muli %26, %c8192 overflow : index + %30 = arith.addi %29, %9 overflow : index + amdgpu.gather_to_lds %15[%30], %alloc_6[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %31 = affine.apply #map11()[%thread_id_x, %thread_id_y, %block_id_x] + %32 = affine.apply #map12()[%thread_id_x] + %33 = affine.apply #map13()[%thread_id_x] + %34 = arith.xori %33, %32 : index + %35 = affine.apply #map14()[%34] + %36 = affine.apply #map15()[%thread_id_x, %thread_id_y] + %37 = gpu.subgroup_broadcast %36, first_active_lane : index + %38 = arith.muli %31, %c512 overflow : index + %39 = arith.addi %38, %35 overflow : index + %reinterpret_cast_7 = memref.reinterpret_cast %1 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast_8 = memref.cast %reinterpret_cast_7 : memref<2147483646xi8, strided<[1]>> to memref> + %40 = amdgpu.fat_raw_buffer_cast %cast_8 validBytes(%c2147483646_i64) cacheSwizzleStride(%c512_i14) resetOffset : memref> to memref> + amdgpu.gather_to_lds %40[%39], %alloc_4[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + %41 = affine.apply #map()[%thread_id_x, %thread_id_y, %block_id_y] + %42 = arith.muli %41, %c8192 
overflow : index + %43 = arith.addi %42, %9 overflow : index + %reinterpret_cast_9 = memref.reinterpret_cast %2 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast_10 = memref.cast %reinterpret_cast_9 : memref<2147483646xi8, strided<[1]>> to memref> + %44 = amdgpu.fat_raw_buffer_cast %cast_10 validBytes(%c2147483646_i64) cacheSwizzleStride(%c-8192_i14) resetOffset : memref> to memref> + amdgpu.gather_to_lds %44[%43], %alloc_2[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %45 = affine.apply #map5()[%thread_id_x, %thread_id_y, %block_id_y] + %46 = arith.muli %45, %c8192 overflow : index + %47 = arith.addi %46, %9 overflow : index + amdgpu.gather_to_lds %44[%47], %alloc_2[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %48 = affine.apply #map7()[%thread_id_x, %thread_id_y, %block_id_y] + %49 = arith.muli %48, %c8192 overflow : index + %50 = arith.addi %49, %9 overflow : index + amdgpu.gather_to_lds %44[%50], %alloc_2[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %51 = affine.apply #map9()[%thread_id_x, %thread_id_y, %block_id_y] + %52 = arith.muli %51, %c8192 overflow : index + %53 = arith.addi %52, %9 overflow : index + amdgpu.gather_to_lds %44[%53], %alloc_2[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %54 = affine.apply #map11()[%thread_id_x, %thread_id_y, %block_id_y] + %55 = arith.muli %54, %c512 overflow : index + %56 = arith.addi %55, %35 overflow : index + %reinterpret_cast_11 = memref.reinterpret_cast %3 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast_12 = memref.cast %reinterpret_cast_11 : memref<2147483646xi8, strided<[1]>> to memref> + %57 = amdgpu.fat_raw_buffer_cast %cast_12 validBytes(%c2147483646_i64) cacheSwizzleStride(%c512_i14) resetOffset : memref> to memref> + amdgpu.gather_to_lds %57[%56], %alloc_0[%37, %12] : 
vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + rocdl.s.barrier + %58 = affine.apply #map16()[%thread_id_x, %thread_id_y] + %59 = arith.index_cast %58 : index to i32 + %60 = arith.cmpi sge, %59, %c4_i32 : i32 + %61 = arith.cmpi slt, %59, %c4_i32 : i32 + scf.if %60 { + rocdl.s.barrier + } + %62 = affine.apply #map17()[%thread_id_x] + %63 = affine.apply #map18()[%thread_id_x] + %64 = arith.xori %63, %7 : index + %65 = affine.apply #map3()[%64] + %66 = affine.apply #map19()[%thread_id_x] + %67 = affine.apply #map20()[%thread_id_x] + %68 = affine.apply #map21()[%thread_id_x] + %69 = affine.apply #map22()[%thread_id_x] + %70 = affine.apply #map23()[%thread_id_x, %thread_id_y] + %71 = affine.apply #map24()[%thread_id_x, %thread_id_y] + %72 = affine.apply #map25()[%thread_id_x, %thread_id_y] + %73 = affine.apply #map26()[%thread_id_x, %thread_id_y] + %74 = affine.apply #map27()[%thread_id_x, %thread_id_y] + %75 = affine.apply #map28()[%thread_id_x, %thread_id_y] + %76 = affine.apply #map29()[%thread_id_x, %thread_id_y] + %77 = affine.apply #map30()[%thread_id_x, %thread_id_y] + %78 = affine.apply #map31()[%thread_id_x] + %79 = arith.xori %78, %7 : index + %80 = affine.apply #map3()[%79] + %81 = arith.xori %33, %c1 : index + %82 = affine.apply #map32()[%thread_id_x, %81] + %83:40 = scf.for %arg5 = %c0 to %c63 step %c1 iter_args(%arg6 = %cst, %arg7 = %cst, %arg8 = %cst, %arg9 = %cst, %arg10 = %cst, %arg11 = %cst, %arg12 = %cst, %arg13 = %cst, %arg14 = %cst, %arg15 = %cst, %arg16 = %cst, %arg17 = %cst, %arg18 = %cst, %arg19 = %cst, %arg20 = %cst, %arg21 = %cst, %arg22 = %cst, %arg23 = %cst, %arg24 = %cst, %arg25 = %cst, %arg26 = %cst, %arg27 = %cst, %arg28 = %cst, %arg29 = %cst, %arg30 = %cst, %arg31 = %cst, %arg32 = %cst, %arg33 = %cst, %arg34 = %cst, %arg35 = %cst, %arg36 = %cst, %arg37 = %cst, %arg38 = %alloc_6, %arg39 = %alloc_5, %arg40 = %alloc_4, %arg41 = %alloc_3, %arg42 = %alloc_2, %arg43 = %alloc_1, %arg44 = %alloc_0, %arg45 = %alloc) -> 
(vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>) { + rocdl.sched.barrier 0 + amdgpu.memory_counter_wait load(0) + rocdl.s.barrier + %582 = affine.apply #map33()[%arg5, %8] + %583 = arith.addi %13, %582 overflow : index + amdgpu.gather_to_lds %15[%583], %arg39[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %584 = arith.addi %19, %582 overflow : index + amdgpu.gather_to_lds %15[%584], %arg39[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %585 = arith.addi %24, %582 overflow : index + amdgpu.gather_to_lds %15[%585], %arg39[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %586 = arith.addi %29, %582 overflow : index + amdgpu.gather_to_lds %15[%586], %arg39[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %587 = affine.apply #map34()[%arg5, %34] + %588 = arith.addi %38, %587 overflow : index + amdgpu.gather_to_lds %40[%588], %arg41[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + %589 = arith.addi %42, %582 overflow : index + amdgpu.gather_to_lds %44[%589], %arg43[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %590 = arith.addi %46, %582 overflow : index + 
amdgpu.gather_to_lds %44[%590], %arg43[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %591 = arith.addi %49, %582 overflow : index + amdgpu.gather_to_lds %44[%591], %arg43[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %592 = arith.addi %52, %582 overflow : index + amdgpu.gather_to_lds %44[%592], %arg43[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %593 = arith.addi %55, %587 overflow : index + amdgpu.gather_to_lds %57[%593], %arg45[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + rocdl.sched.barrier 0 + %594 = vector.load %arg38[%62, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %595 = vector.load %arg38[%66, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %596 = vector.load %arg38[%67, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %597 = vector.load %arg38[%68, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %598 = vector.load %arg40[%62, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %599 = vector.load %arg40[%66, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %600 = vector.load %arg40[%67, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %601 = vector.load %arg40[%68, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %602 = vector.load %arg42[%70, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %603 = vector.load %arg42[%71, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %604 = vector.load %arg42[%72, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %605 = vector.load %arg42[%73, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %606 = vector.load %arg42[%74, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %607 = vector.load %arg42[%75, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %608 = vector.load %arg42[%76, %65] : memref<256x128xi8, 
#gpu.address_space>, vector<16xi8> + %609 = vector.load %arg42[%77, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %610 = vector.load %arg44[%70, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %611 = vector.load %arg44[%71, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %612 = vector.load %arg44[%72, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %613 = vector.load %arg44[%73, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %614 = vector.load %arg44[%74, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %615 = vector.load %arg44[%75, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %616 = vector.load %arg44[%76, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %617 = vector.load %arg44[%77, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %618 = vector.bitcast %594 : vector<16xi8> to vector<32xf4E2M1FN> + %619 = vector.bitcast %595 : vector<16xi8> to vector<32xf4E2M1FN> + %620 = vector.bitcast %596 : vector<16xi8> to vector<32xf4E2M1FN> + %621 = vector.bitcast %597 : vector<16xi8> to vector<32xf4E2M1FN> + %622 = vector.bitcast %598 : vector<1xi8> to vector<1xf8E8M0FNU> + %623 = vector.bitcast %599 : vector<1xi8> to vector<1xf8E8M0FNU> + %624 = vector.bitcast %600 : vector<1xi8> to vector<1xf8E8M0FNU> + %625 = vector.bitcast %601 : vector<1xi8> to vector<1xf8E8M0FNU> + %626 = vector.bitcast %602 : vector<16xi8> to vector<32xf4E2M1FN> + %627 = vector.bitcast %603 : vector<16xi8> to vector<32xf4E2M1FN> + %628 = vector.bitcast %604 : vector<16xi8> to vector<32xf4E2M1FN> + %629 = vector.bitcast %605 : vector<16xi8> to vector<32xf4E2M1FN> + %630 = vector.bitcast %606 : vector<16xi8> to vector<32xf4E2M1FN> + %631 = vector.bitcast %607 : vector<16xi8> to vector<32xf4E2M1FN> + %632 = vector.bitcast %608 : vector<16xi8> to vector<32xf4E2M1FN> + %633 = vector.bitcast %609 : vector<16xi8> to vector<32xf4E2M1FN> + %634 = vector.bitcast %610 : vector<1xi8> to 
vector<1xf8E8M0FNU> + %635 = vector.bitcast %611 : vector<1xi8> to vector<1xf8E8M0FNU> + %636 = vector.bitcast %612 : vector<1xi8> to vector<1xf8E8M0FNU> + %637 = vector.bitcast %613 : vector<1xi8> to vector<1xf8E8M0FNU> + %638 = vector.bitcast %614 : vector<1xi8> to vector<1xf8E8M0FNU> + %639 = vector.bitcast %615 : vector<1xi8> to vector<1xf8E8M0FNU> + %640 = vector.bitcast %616 : vector<1xi8> to vector<1xf8E8M0FNU> + %641 = vector.bitcast %617 : vector<1xi8> to vector<1xf8E8M0FNU> + rocdl.sched.barrier 0 + rocdl.s.barrier + rocdl.sched.barrier 0 + rocdl.s.setprio 1 + %642 = vector.extract %622[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %643 = vector.extract %634[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %644 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%643[0] * %626) + %arg6 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %645 = vector.extract %635[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %646 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%645[0] * %627) + %arg7 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %647 = vector.extract %636[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %648 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%647[0] * %628) + %arg8 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %649 = vector.extract %637[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %650 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%649[0] * %629) + %arg9 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %651 = vector.extract %638[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %652 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%651[0] * %630) + %arg10 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %653 = vector.extract %639[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %654 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%653[0] * %631) + %arg11 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, 
vector<32xf4E2M1FN>, vector<4xf32> + %655 = vector.extract %640[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %656 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%655[0] * %632) + %arg12 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %657 = vector.extract %641[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %658 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%657[0] * %633) + %arg13 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %659 = vector.extract %623[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %660 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%643[0] * %626) + %arg14 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %661 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%645[0] * %627) + %arg15 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %662 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%647[0] * %628) + %arg16 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %663 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%649[0] * %629) + %arg17 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %664 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%651[0] * %630) + %arg18 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %665 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%653[0] * %631) + %arg19 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %666 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%655[0] * %632) + %arg20 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %667 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%657[0] * %633) + %arg21 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %668 = vector.extract %624[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %669 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * 
(%643[0] * %626) + %arg22 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %670 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%645[0] * %627) + %arg23 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %671 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%647[0] * %628) + %arg24 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %672 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%649[0] * %629) + %arg25 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %673 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%651[0] * %630) + %arg26 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %674 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%653[0] * %631) + %arg27 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %675 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%655[0] * %632) + %arg28 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %676 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%657[0] * %633) + %arg29 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %677 = vector.extract %625[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %678 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%643[0] * %626) + %arg30 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %679 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%645[0] * %627) + %arg31 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %680 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%647[0] * %628) + %arg32 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %681 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%649[0] * %629) + %arg33 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %682 = amdgpu.scaled_mfma 
16x16x128 (%677[0] * %621) * (%651[0] * %630) + %arg34 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %683 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%653[0] * %631) + %arg35 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %684 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%655[0] * %632) + %arg36 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %685 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%657[0] * %633) + %arg37 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + rocdl.s.setprio 0 + rocdl.sched.barrier 0 + rocdl.s.barrier + rocdl.sched.barrier 0 + %686 = vector.load %arg38[%62, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %687 = vector.load %arg38[%66, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %688 = vector.load %arg38[%67, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %689 = vector.load %arg38[%68, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %690 = vector.load %arg40[%62, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %691 = vector.load %arg40[%66, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %692 = vector.load %arg40[%67, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %693 = vector.load %arg40[%68, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %694 = vector.load %arg42[%70, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %695 = vector.load %arg42[%71, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %696 = vector.load %arg42[%72, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %697 = vector.load %arg42[%73, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %698 = vector.load %arg42[%74, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %699 = vector.load %arg42[%75, %80] : memref<256x128xi8, #gpu.address_space>, 
vector<16xi8> + %700 = vector.load %arg42[%76, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %701 = vector.load %arg42[%77, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %702 = vector.load %arg44[%70, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %703 = vector.load %arg44[%71, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %704 = vector.load %arg44[%72, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %705 = vector.load %arg44[%73, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %706 = vector.load %arg44[%74, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %707 = vector.load %arg44[%75, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %708 = vector.load %arg44[%76, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %709 = vector.load %arg44[%77, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %710 = vector.bitcast %686 : vector<16xi8> to vector<32xf4E2M1FN> + %711 = vector.bitcast %687 : vector<16xi8> to vector<32xf4E2M1FN> + %712 = vector.bitcast %688 : vector<16xi8> to vector<32xf4E2M1FN> + %713 = vector.bitcast %689 : vector<16xi8> to vector<32xf4E2M1FN> + %714 = vector.bitcast %690 : vector<1xi8> to vector<1xf8E8M0FNU> + %715 = vector.bitcast %691 : vector<1xi8> to vector<1xf8E8M0FNU> + %716 = vector.bitcast %692 : vector<1xi8> to vector<1xf8E8M0FNU> + %717 = vector.bitcast %693 : vector<1xi8> to vector<1xf8E8M0FNU> + %718 = vector.bitcast %694 : vector<16xi8> to vector<32xf4E2M1FN> + %719 = vector.bitcast %695 : vector<16xi8> to vector<32xf4E2M1FN> + %720 = vector.bitcast %696 : vector<16xi8> to vector<32xf4E2M1FN> + %721 = vector.bitcast %697 : vector<16xi8> to vector<32xf4E2M1FN> + %722 = vector.bitcast %698 : vector<16xi8> to vector<32xf4E2M1FN> + %723 = vector.bitcast %699 : vector<16xi8> to vector<32xf4E2M1FN> + %724 = vector.bitcast %700 : vector<16xi8> to vector<32xf4E2M1FN> + %725 = vector.bitcast %701 : vector<16xi8> to 
vector<32xf4E2M1FN> + %726 = vector.bitcast %702 : vector<1xi8> to vector<1xf8E8M0FNU> + %727 = vector.bitcast %703 : vector<1xi8> to vector<1xf8E8M0FNU> + %728 = vector.bitcast %704 : vector<1xi8> to vector<1xf8E8M0FNU> + %729 = vector.bitcast %705 : vector<1xi8> to vector<1xf8E8M0FNU> + %730 = vector.bitcast %706 : vector<1xi8> to vector<1xf8E8M0FNU> + %731 = vector.bitcast %707 : vector<1xi8> to vector<1xf8E8M0FNU> + %732 = vector.bitcast %708 : vector<1xi8> to vector<1xf8E8M0FNU> + %733 = vector.bitcast %709 : vector<1xi8> to vector<1xf8E8M0FNU> + rocdl.sched.barrier 0 + rocdl.s.barrier + rocdl.sched.barrier 0 + rocdl.s.setprio 1 + %734 = vector.extract %714[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %735 = vector.extract %726[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %736 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%735[0] * %718) + %644 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %737 = vector.extract %727[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %738 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%737[0] * %719) + %646 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %739 = vector.extract %728[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %740 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%739[0] * %720) + %648 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %741 = vector.extract %729[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %742 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%741[0] * %721) + %650 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %743 = vector.extract %730[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %744 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%743[0] * %722) + %652 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %745 = vector.extract %731[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %746 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%745[0] * 
%723) + %654 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %747 = vector.extract %732[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %748 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%747[0] * %724) + %656 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %749 = vector.extract %733[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %750 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%749[0] * %725) + %658 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %751 = vector.extract %715[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %752 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%735[0] * %718) + %660 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %753 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%737[0] * %719) + %661 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %754 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%739[0] * %720) + %662 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %755 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%741[0] * %721) + %663 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %756 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%743[0] * %722) + %664 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %757 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%745[0] * %723) + %665 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %758 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%747[0] * %724) + %666 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %759 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%749[0] * %725) + %667 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %760 = vector.extract %716[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %761 = 
amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%735[0] * %718) + %669 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %762 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%737[0] * %719) + %670 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %763 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%739[0] * %720) + %671 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %764 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%741[0] * %721) + %672 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %765 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%743[0] * %722) + %673 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %766 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%745[0] * %723) + %674 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %767 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%747[0] * %724) + %675 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %768 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%749[0] * %725) + %676 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %769 = vector.extract %717[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %770 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%735[0] * %718) + %678 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %771 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%737[0] * %719) + %679 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %772 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%739[0] * %720) + %680 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %773 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%741[0] * %721) + %681 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %774 
= amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%743[0] * %722) + %682 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %775 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%745[0] * %723) + %683 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %776 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%747[0] * %724) + %684 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %777 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%749[0] * %725) + %685 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + rocdl.s.setprio 0 + rocdl.sched.barrier 0 + scf.yield %736, %738, %740, %742, %744, %746, %748, %750, %752, %753, %754, %755, %756, %757, %758, %759, %761, %762, %763, %764, %765, %766, %767, %768, %770, %771, %772, %773, %774, %775, %776, %777, %arg39, %arg38, %arg41, %arg40, %arg43, %arg42, %arg45, %arg44 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space> + } + amdgpu.lds_barrier + scf.if %61 { + rocdl.s.barrier + } + %84 = affine.apply #map23()[%thread_id_x, %thread_id_y] + %85 = affine.apply #map22()[%thread_id_x] + %86 = vector.load %83#38[%84, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + 
%87 = arith.xori %33, %c1 : index + %88 = affine.apply #map32()[%thread_id_x, %87] + %89 = vector.load %83#38[%84, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %90 = affine.apply #map24()[%thread_id_x, %thread_id_y] + %91 = vector.load %83#38[%90, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %92 = vector.load %83#38[%90, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %93 = affine.apply #map25()[%thread_id_x, %thread_id_y] + %94 = vector.load %83#38[%93, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %95 = vector.load %83#38[%93, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %96 = affine.apply #map26()[%thread_id_x, %thread_id_y] + %97 = vector.load %83#38[%96, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %98 = vector.load %83#38[%96, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %99 = affine.apply #map27()[%thread_id_x, %thread_id_y] + %100 = vector.load %83#38[%99, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %101 = vector.load %83#38[%99, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %102 = affine.apply #map28()[%thread_id_x, %thread_id_y] + %103 = vector.load %83#38[%102, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %104 = vector.load %83#38[%102, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %105 = affine.apply #map29()[%thread_id_x, %thread_id_y] + %106 = vector.load %83#38[%105, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %107 = vector.load %83#38[%105, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %108 = affine.apply #map30()[%thread_id_x, %thread_id_y] + %109 = vector.load %83#38[%108, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %110 = vector.load %83#38[%108, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %111 = affine.apply #map18()[%thread_id_x] + %112 = arith.xori %111, %7 : index + %113 = affine.apply #map3()[%112] + %114 = vector.load 
%83#36[%84, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %115 = affine.apply #map31()[%thread_id_x] + %116 = arith.xori %115, %7 : index + %117 = affine.apply #map3()[%116] + %118 = vector.load %83#36[%84, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %119 = vector.load %83#36[%90, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %120 = vector.load %83#36[%90, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %121 = vector.load %83#36[%93, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %122 = vector.load %83#36[%93, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %123 = vector.load %83#36[%96, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %124 = vector.load %83#36[%96, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %125 = vector.load %83#36[%99, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %126 = vector.load %83#36[%99, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %127 = vector.load %83#36[%102, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %128 = vector.load %83#36[%102, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %129 = vector.load %83#36[%105, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %130 = vector.load %83#36[%105, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %131 = vector.load %83#36[%108, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %132 = vector.load %83#36[%108, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %133 = affine.apply #map17()[%thread_id_x] + %134 = vector.load %83#34[%133, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %135 = vector.load %83#34[%133, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %136 = affine.apply #map19()[%thread_id_x] + %137 = vector.load %83#34[%136, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %138 = 
vector.load %83#34[%136, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %139 = affine.apply #map20()[%thread_id_x] + %140 = vector.load %83#34[%139, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %141 = vector.load %83#34[%139, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %142 = affine.apply #map21()[%thread_id_x] + %143 = vector.load %83#34[%142, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %144 = vector.load %83#34[%142, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %145 = vector.load %83#32[%133, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %146 = vector.load %83#32[%133, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %147 = vector.load %83#32[%136, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %148 = vector.load %83#32[%136, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %149 = vector.load %83#32[%139, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %150 = vector.load %83#32[%139, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %151 = vector.load %83#32[%142, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %152 = vector.load %83#32[%142, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %153 = vector.bitcast %145 : vector<16xi8> to vector<32xf4E2M1FN> + %154 = vector.bitcast %146 : vector<16xi8> to vector<32xf4E2M1FN> + %155 = vector.bitcast %147 : vector<16xi8> to vector<32xf4E2M1FN> + %156 = vector.bitcast %148 : vector<16xi8> to vector<32xf4E2M1FN> + %157 = vector.bitcast %149 : vector<16xi8> to vector<32xf4E2M1FN> + %158 = vector.bitcast %150 : vector<16xi8> to vector<32xf4E2M1FN> + %159 = vector.bitcast %151 : vector<16xi8> to vector<32xf4E2M1FN> + %160 = vector.bitcast %152 : vector<16xi8> to vector<32xf4E2M1FN> + %161 = vector.bitcast %134 : vector<1xi8> to vector<1xf8E8M0FNU> + %162 = vector.bitcast %135 : vector<1xi8> to vector<1xf8E8M0FNU> + %163 = 
vector.bitcast %137 : vector<1xi8> to vector<1xf8E8M0FNU> + %164 = vector.bitcast %138 : vector<1xi8> to vector<1xf8E8M0FNU> + %165 = vector.bitcast %140 : vector<1xi8> to vector<1xf8E8M0FNU> + %166 = vector.bitcast %141 : vector<1xi8> to vector<1xf8E8M0FNU> + %167 = vector.bitcast %143 : vector<1xi8> to vector<1xf8E8M0FNU> + %168 = vector.bitcast %144 : vector<1xi8> to vector<1xf8E8M0FNU> + %169 = vector.bitcast %114 : vector<16xi8> to vector<32xf4E2M1FN> + %170 = vector.bitcast %118 : vector<16xi8> to vector<32xf4E2M1FN> + %171 = vector.bitcast %119 : vector<16xi8> to vector<32xf4E2M1FN> + %172 = vector.bitcast %120 : vector<16xi8> to vector<32xf4E2M1FN> + %173 = vector.bitcast %121 : vector<16xi8> to vector<32xf4E2M1FN> + %174 = vector.bitcast %122 : vector<16xi8> to vector<32xf4E2M1FN> + %175 = vector.bitcast %123 : vector<16xi8> to vector<32xf4E2M1FN> + %176 = vector.bitcast %124 : vector<16xi8> to vector<32xf4E2M1FN> + %177 = vector.bitcast %125 : vector<16xi8> to vector<32xf4E2M1FN> + %178 = vector.bitcast %126 : vector<16xi8> to vector<32xf4E2M1FN> + %179 = vector.bitcast %127 : vector<16xi8> to vector<32xf4E2M1FN> + %180 = vector.bitcast %128 : vector<16xi8> to vector<32xf4E2M1FN> + %181 = vector.bitcast %129 : vector<16xi8> to vector<32xf4E2M1FN> + %182 = vector.bitcast %130 : vector<16xi8> to vector<32xf4E2M1FN> + %183 = vector.bitcast %131 : vector<16xi8> to vector<32xf4E2M1FN> + %184 = vector.bitcast %132 : vector<16xi8> to vector<32xf4E2M1FN> + %185 = vector.bitcast %86 : vector<1xi8> to vector<1xf8E8M0FNU> + %186 = vector.bitcast %89 : vector<1xi8> to vector<1xf8E8M0FNU> + %187 = vector.bitcast %91 : vector<1xi8> to vector<1xf8E8M0FNU> + %188 = vector.bitcast %92 : vector<1xi8> to vector<1xf8E8M0FNU> + %189 = vector.bitcast %94 : vector<1xi8> to vector<1xf8E8M0FNU> + %190 = vector.bitcast %95 : vector<1xi8> to vector<1xf8E8M0FNU> + %191 = vector.bitcast %97 : vector<1xi8> to vector<1xf8E8M0FNU> + %192 = vector.bitcast %98 : vector<1xi8> to 
vector<1xf8E8M0FNU> + %193 = vector.bitcast %100 : vector<1xi8> to vector<1xf8E8M0FNU> + %194 = vector.bitcast %101 : vector<1xi8> to vector<1xf8E8M0FNU> + %195 = vector.bitcast %103 : vector<1xi8> to vector<1xf8E8M0FNU> + %196 = vector.bitcast %104 : vector<1xi8> to vector<1xf8E8M0FNU> + %197 = vector.bitcast %106 : vector<1xi8> to vector<1xf8E8M0FNU> + %198 = vector.bitcast %107 : vector<1xi8> to vector<1xf8E8M0FNU> + %199 = vector.bitcast %109 : vector<1xi8> to vector<1xf8E8M0FNU> + %200 = vector.bitcast %110 : vector<1xi8> to vector<1xf8E8M0FNU> + %201 = vector.extract %161[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %202 = vector.extract %185[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %203 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%202[0] * %169) + %83#0 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %204 = vector.extract %162[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %205 = vector.extract %186[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %206 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%205[0] * %170) + %203 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %207 = vector.extract %187[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %208 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%207[0] * %171) + %83#1 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %209 = vector.extract %188[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %210 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%209[0] * %172) + %208 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %211 = vector.extract %189[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %212 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%211[0] * %173) + %83#2 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %213 = vector.extract %190[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %214 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%213[0] * %174) + %212 : 
f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %215 = vector.extract %191[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %216 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%215[0] * %175) + %83#3 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %217 = vector.extract %192[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %218 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%217[0] * %176) + %216 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %219 = vector.extract %193[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %220 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%219[0] * %177) + %83#4 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %221 = vector.extract %194[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %222 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%221[0] * %178) + %220 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %223 = vector.extract %195[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %224 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%223[0] * %179) + %83#5 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %225 = vector.extract %196[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %226 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%225[0] * %180) + %224 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %227 = vector.extract %197[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %228 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%227[0] * %181) + %83#6 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %229 = vector.extract %198[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %230 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%229[0] * %182) + %228 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %231 = vector.extract %199[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %232 = 
amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%231[0] * %183) + %83#7 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %233 = vector.extract %200[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %234 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%233[0] * %184) + %232 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %235 = vector.extract %163[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %236 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%202[0] * %169) + %83#8 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %237 = vector.extract %164[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %238 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%205[0] * %170) + %236 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %239 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%207[0] * %171) + %83#9 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %240 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%209[0] * %172) + %239 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %241 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%211[0] * %173) + %83#10 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %242 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%213[0] * %174) + %241 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %243 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%215[0] * %175) + %83#11 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %244 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%217[0] * %176) + %243 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %245 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%219[0] * %177) + %83#12 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %246 = 
amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%221[0] * %178) + %245 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %247 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%223[0] * %179) + %83#13 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %248 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%225[0] * %180) + %247 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %249 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%227[0] * %181) + %83#14 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %250 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%229[0] * %182) + %249 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %251 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%231[0] * %183) + %83#15 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %252 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%233[0] * %184) + %251 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %253 = vector.extract %165[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %254 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%202[0] * %169) + %83#16 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %255 = vector.extract %166[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %256 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%205[0] * %170) + %254 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %257 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%207[0] * %171) + %83#17 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %258 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%209[0] * %172) + %257 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %259 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%211[0] * %173) + %83#18 : 
f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %260 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%213[0] * %174) + %259 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %261 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%215[0] * %175) + %83#19 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %262 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%217[0] * %176) + %261 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %263 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%219[0] * %177) + %83#20 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %264 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%221[0] * %178) + %263 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %265 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%223[0] * %179) + %83#21 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %266 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%225[0] * %180) + %265 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %267 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%227[0] * %181) + %83#22 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %268 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%229[0] * %182) + %267 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %269 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%231[0] * %183) + %83#23 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %270 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%233[0] * %184) + %269 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %271 = vector.extract %167[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %272 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%202[0] * 
%169) + %83#24 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %273 = vector.extract %168[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %274 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%205[0] * %170) + %272 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %275 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%207[0] * %171) + %83#25 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %276 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%209[0] * %172) + %275 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %277 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%211[0] * %173) + %83#26 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %278 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%213[0] * %174) + %277 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %279 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%215[0] * %175) + %83#27 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %280 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%217[0] * %176) + %279 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %281 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%219[0] * %177) + %83#28 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %282 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%221[0] * %178) + %281 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %283 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%223[0] * %179) + %83#29 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %284 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%225[0] * %180) + %283 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %285 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) 
* (%227[0] * %181) + %83#30 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %286 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%229[0] * %182) + %285 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %287 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%231[0] * %183) + %83#31 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %288 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%233[0] * %184) + %287 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %289 = vector.extract_strided_slice %206 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %290 = affine.apply #map35()[%block_id_x] + %291 = affine.apply #map35()[%block_id_y] + %292 = affine.apply #map36()[%thread_id_x] + %293 = arith.muli %290, %c16384 overflow : index + %294 = arith.muli %292, %c16384 overflow : index + %295 = arith.addi %293, %291 overflow : index + %296 = arith.addi %294, %84 overflow : index + %reinterpret_cast_13 = memref.reinterpret_cast %4 to offset: [%295], sizes: [536870910], strides: [1] : memref to memref<536870910xf32, strided<[1], offset: ?>> + %cast_14 = memref.cast %reinterpret_cast_13 : memref<536870910xf32, strided<[1], offset: ?>> to memref> + %297 = amdgpu.fat_raw_buffer_cast %cast_14 validBytes(%c2147483643_i64) resetOffset : memref> to memref> + vector.store %289, %297[%296] : memref>, vector<1xf32> + %298 = vector.extract_strided_slice %206 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %299 = affine.apply #map37()[%thread_id_x] + %300 = arith.muli %299, %c16384 overflow : index + %301 = arith.addi %300, %84 overflow : index + vector.store %298, %297[%301] : memref>, vector<1xf32> + %302 = vector.extract_strided_slice %206 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %303 = affine.apply #map38()[%thread_id_x] + %304 = arith.muli %303, %c16384 
overflow : index + %305 = arith.addi %304, %84 overflow : index + vector.store %302, %297[%305] : memref>, vector<1xf32> + %306 = vector.extract_strided_slice %206 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %307 = affine.apply #map39()[%thread_id_x] + %308 = arith.muli %307, %c16384 overflow : index + %309 = arith.addi %308, %84 overflow : index + vector.store %306, %297[%309] : memref>, vector<1xf32> + %310 = vector.extract_strided_slice %210 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %311 = arith.addi %294, %90 overflow : index + vector.store %310, %297[%311] : memref>, vector<1xf32> + %312 = vector.extract_strided_slice %210 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %313 = arith.addi %300, %90 overflow : index + vector.store %312, %297[%313] : memref>, vector<1xf32> + %314 = vector.extract_strided_slice %210 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %315 = arith.addi %304, %90 overflow : index + vector.store %314, %297[%315] : memref>, vector<1xf32> + %316 = vector.extract_strided_slice %210 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %317 = arith.addi %308, %90 overflow : index + vector.store %316, %297[%317] : memref>, vector<1xf32> + %318 = vector.extract_strided_slice %214 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %319 = arith.addi %294, %93 overflow : index + vector.store %318, %297[%319] : memref>, vector<1xf32> + %320 = vector.extract_strided_slice %214 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %321 = arith.addi %300, %93 overflow : index + vector.store %320, %297[%321] : memref>, vector<1xf32> + %322 = vector.extract_strided_slice %214 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %323 = arith.addi %304, %93 overflow : index + vector.store %322, %297[%323] : memref>, 
vector<1xf32> + %324 = vector.extract_strided_slice %214 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %325 = arith.addi %308, %93 overflow : index + vector.store %324, %297[%325] : memref>, vector<1xf32> + %326 = vector.extract_strided_slice %218 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %327 = arith.addi %294, %96 overflow : index + vector.store %326, %297[%327] : memref>, vector<1xf32> + %328 = vector.extract_strided_slice %218 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %329 = arith.addi %300, %96 overflow : index + vector.store %328, %297[%329] : memref>, vector<1xf32> + %330 = vector.extract_strided_slice %218 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %331 = arith.addi %304, %96 overflow : index + vector.store %330, %297[%331] : memref>, vector<1xf32> + %332 = vector.extract_strided_slice %218 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %333 = arith.addi %308, %96 overflow : index + vector.store %332, %297[%333] : memref>, vector<1xf32> + %334 = vector.extract_strided_slice %222 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %335 = arith.addi %294, %99 overflow : index + vector.store %334, %297[%335] : memref>, vector<1xf32> + %336 = vector.extract_strided_slice %222 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %337 = arith.addi %300, %99 overflow : index + vector.store %336, %297[%337] : memref>, vector<1xf32> + %338 = vector.extract_strided_slice %222 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %339 = arith.addi %304, %99 overflow : index + vector.store %338, %297[%339] : memref>, vector<1xf32> + %340 = vector.extract_strided_slice %222 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %341 = arith.addi %308, %99 overflow : index + vector.store %340, 
%297[%341] : memref>, vector<1xf32> + %342 = vector.extract_strided_slice %226 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %343 = arith.addi %294, %102 overflow : index + vector.store %342, %297[%343] : memref>, vector<1xf32> + %344 = vector.extract_strided_slice %226 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %345 = arith.addi %300, %102 overflow : index + vector.store %344, %297[%345] : memref>, vector<1xf32> + %346 = vector.extract_strided_slice %226 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %347 = arith.addi %304, %102 overflow : index + vector.store %346, %297[%347] : memref>, vector<1xf32> + %348 = vector.extract_strided_slice %226 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %349 = arith.addi %308, %102 overflow : index + vector.store %348, %297[%349] : memref>, vector<1xf32> + %350 = vector.extract_strided_slice %230 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %351 = arith.addi %294, %105 overflow : index + vector.store %350, %297[%351] : memref>, vector<1xf32> + %352 = vector.extract_strided_slice %230 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %353 = arith.addi %300, %105 overflow : index + vector.store %352, %297[%353] : memref>, vector<1xf32> + %354 = vector.extract_strided_slice %230 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %355 = arith.addi %304, %105 overflow : index + vector.store %354, %297[%355] : memref>, vector<1xf32> + %356 = vector.extract_strided_slice %230 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %357 = arith.addi %308, %105 overflow : index + vector.store %356, %297[%357] : memref>, vector<1xf32> + %358 = vector.extract_strided_slice %234 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %359 = arith.addi %294, %108 overflow : 
index + vector.store %358, %297[%359] : memref>, vector<1xf32> + %360 = vector.extract_strided_slice %234 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %361 = arith.addi %300, %108 overflow : index + vector.store %360, %297[%361] : memref>, vector<1xf32> + %362 = vector.extract_strided_slice %234 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %363 = arith.addi %304, %108 overflow : index + vector.store %362, %297[%363] : memref>, vector<1xf32> + %364 = vector.extract_strided_slice %234 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %365 = arith.addi %308, %108 overflow : index + vector.store %364, %297[%365] : memref>, vector<1xf32> + %366 = vector.extract_strided_slice %238 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %367 = affine.apply #map40()[%thread_id_x] + %368 = arith.muli %367, %c16384 overflow : index + %369 = arith.addi %368, %84 overflow : index + vector.store %366, %297[%369] : memref>, vector<1xf32> + %370 = vector.extract_strided_slice %238 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %371 = affine.apply #map41()[%thread_id_x] + %372 = arith.muli %371, %c16384 overflow : index + %373 = arith.addi %372, %84 overflow : index + vector.store %370, %297[%373] : memref>, vector<1xf32> + %374 = vector.extract_strided_slice %238 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %375 = affine.apply #map42()[%thread_id_x] + %376 = arith.muli %375, %c16384 overflow : index + %377 = arith.addi %376, %84 overflow : index + vector.store %374, %297[%377] : memref>, vector<1xf32> + %378 = vector.extract_strided_slice %238 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %379 = affine.apply #map43()[%thread_id_x] + %380 = arith.muli %379, %c16384 overflow : index + %381 = arith.addi %380, %84 overflow : index + vector.store %378, %297[%381] : 
memref>, vector<1xf32> + %382 = vector.extract_strided_slice %240 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %383 = arith.addi %368, %90 overflow : index + vector.store %382, %297[%383] : memref>, vector<1xf32> + %384 = vector.extract_strided_slice %240 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %385 = arith.addi %372, %90 overflow : index + vector.store %384, %297[%385] : memref>, vector<1xf32> + %386 = vector.extract_strided_slice %240 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %387 = arith.addi %376, %90 overflow : index + vector.store %386, %297[%387] : memref>, vector<1xf32> + %388 = vector.extract_strided_slice %240 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %389 = arith.addi %380, %90 overflow : index + vector.store %388, %297[%389] : memref>, vector<1xf32> + %390 = vector.extract_strided_slice %242 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %391 = arith.addi %368, %93 overflow : index + vector.store %390, %297[%391] : memref>, vector<1xf32> + %392 = vector.extract_strided_slice %242 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %393 = arith.addi %372, %93 overflow : index + vector.store %392, %297[%393] : memref>, vector<1xf32> + %394 = vector.extract_strided_slice %242 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %395 = arith.addi %376, %93 overflow : index + vector.store %394, %297[%395] : memref>, vector<1xf32> + %396 = vector.extract_strided_slice %242 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %397 = arith.addi %380, %93 overflow : index + vector.store %396, %297[%397] : memref>, vector<1xf32> + %398 = vector.extract_strided_slice %244 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %399 = arith.addi %368, %96 overflow : index + vector.store 
%398, %297[%399] : memref>, vector<1xf32> + %400 = vector.extract_strided_slice %244 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %401 = arith.addi %372, %96 overflow : index + vector.store %400, %297[%401] : memref>, vector<1xf32> + %402 = vector.extract_strided_slice %244 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %403 = arith.addi %376, %96 overflow : index + vector.store %402, %297[%403] : memref>, vector<1xf32> + %404 = vector.extract_strided_slice %244 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %405 = arith.addi %380, %96 overflow : index + vector.store %404, %297[%405] : memref>, vector<1xf32> + %406 = vector.extract_strided_slice %246 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %407 = arith.addi %368, %99 overflow : index + vector.store %406, %297[%407] : memref>, vector<1xf32> + %408 = vector.extract_strided_slice %246 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %409 = arith.addi %372, %99 overflow : index + vector.store %408, %297[%409] : memref>, vector<1xf32> + %410 = vector.extract_strided_slice %246 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %411 = arith.addi %376, %99 overflow : index + vector.store %410, %297[%411] : memref>, vector<1xf32> + %412 = vector.extract_strided_slice %246 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %413 = arith.addi %380, %99 overflow : index + vector.store %412, %297[%413] : memref>, vector<1xf32> + %414 = vector.extract_strided_slice %248 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %415 = arith.addi %368, %102 overflow : index + vector.store %414, %297[%415] : memref>, vector<1xf32> + %416 = vector.extract_strided_slice %248 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %417 = arith.addi %372, %102 overflow : 
index + vector.store %416, %297[%417] : memref>, vector<1xf32> + %418 = vector.extract_strided_slice %248 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %419 = arith.addi %376, %102 overflow : index + vector.store %418, %297[%419] : memref>, vector<1xf32> + %420 = vector.extract_strided_slice %248 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %421 = arith.addi %380, %102 overflow : index + vector.store %420, %297[%421] : memref>, vector<1xf32> + %422 = vector.extract_strided_slice %250 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %423 = arith.addi %368, %105 overflow : index + vector.store %422, %297[%423] : memref>, vector<1xf32> + %424 = vector.extract_strided_slice %250 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %425 = arith.addi %372, %105 overflow : index + vector.store %424, %297[%425] : memref>, vector<1xf32> + %426 = vector.extract_strided_slice %250 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %427 = arith.addi %376, %105 overflow : index + vector.store %426, %297[%427] : memref>, vector<1xf32> + %428 = vector.extract_strided_slice %250 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %429 = arith.addi %380, %105 overflow : index + vector.store %428, %297[%429] : memref>, vector<1xf32> + %430 = vector.extract_strided_slice %252 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %431 = arith.addi %368, %108 overflow : index + vector.store %430, %297[%431] : memref>, vector<1xf32> + %432 = vector.extract_strided_slice %252 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %433 = arith.addi %372, %108 overflow : index + vector.store %432, %297[%433] : memref>, vector<1xf32> + %434 = vector.extract_strided_slice %252 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %435 = 
arith.addi %376, %108 overflow : index + vector.store %434, %297[%435] : memref>, vector<1xf32> + %436 = vector.extract_strided_slice %252 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %437 = arith.addi %380, %108 overflow : index + vector.store %436, %297[%437] : memref>, vector<1xf32> + %438 = vector.extract_strided_slice %256 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %439 = affine.apply #map44()[%thread_id_x] + %440 = arith.muli %439, %c16384 overflow : index + %441 = arith.addi %440, %84 overflow : index + vector.store %438, %297[%441] : memref>, vector<1xf32> + %442 = vector.extract_strided_slice %256 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %443 = affine.apply #map45()[%thread_id_x] + %444 = arith.muli %443, %c16384 overflow : index + %445 = arith.addi %444, %84 overflow : index + vector.store %442, %297[%445] : memref>, vector<1xf32> + %446 = vector.extract_strided_slice %256 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %447 = affine.apply #map46()[%thread_id_x] + %448 = arith.muli %447, %c16384 overflow : index + %449 = arith.addi %448, %84 overflow : index + vector.store %446, %297[%449] : memref>, vector<1xf32> + %450 = vector.extract_strided_slice %256 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %451 = affine.apply #map47()[%thread_id_x] + %452 = arith.muli %451, %c16384 overflow : index + %453 = arith.addi %452, %84 overflow : index + vector.store %450, %297[%453] : memref>, vector<1xf32> + %454 = vector.extract_strided_slice %258 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %455 = arith.addi %440, %90 overflow : index + vector.store %454, %297[%455] : memref>, vector<1xf32> + %456 = vector.extract_strided_slice %258 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %457 = arith.addi %444, %90 overflow : index + 
vector.store %456, %297[%457] : memref>, vector<1xf32> + %458 = vector.extract_strided_slice %258 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %459 = arith.addi %448, %90 overflow : index + vector.store %458, %297[%459] : memref>, vector<1xf32> + %460 = vector.extract_strided_slice %258 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %461 = arith.addi %452, %90 overflow : index + vector.store %460, %297[%461] : memref>, vector<1xf32> + %462 = vector.extract_strided_slice %260 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %463 = arith.addi %440, %93 overflow : index + vector.store %462, %297[%463] : memref>, vector<1xf32> + %464 = vector.extract_strided_slice %260 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %465 = arith.addi %444, %93 overflow : index + vector.store %464, %297[%465] : memref>, vector<1xf32> + %466 = vector.extract_strided_slice %260 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %467 = arith.addi %448, %93 overflow : index + vector.store %466, %297[%467] : memref>, vector<1xf32> + %468 = vector.extract_strided_slice %260 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %469 = arith.addi %452, %93 overflow : index + vector.store %468, %297[%469] : memref>, vector<1xf32> + %470 = vector.extract_strided_slice %262 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %471 = arith.addi %440, %96 overflow : index + vector.store %470, %297[%471] : memref>, vector<1xf32> + %472 = vector.extract_strided_slice %262 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %473 = arith.addi %444, %96 overflow : index + vector.store %472, %297[%473] : memref>, vector<1xf32> + %474 = vector.extract_strided_slice %262 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %475 = arith.addi %448, %96 
overflow : index + vector.store %474, %297[%475] : memref>, vector<1xf32> + %476 = vector.extract_strided_slice %262 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %477 = arith.addi %452, %96 overflow : index + vector.store %476, %297[%477] : memref>, vector<1xf32> + %478 = vector.extract_strided_slice %264 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %479 = arith.addi %440, %99 overflow : index + vector.store %478, %297[%479] : memref>, vector<1xf32> + %480 = vector.extract_strided_slice %264 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %481 = arith.addi %444, %99 overflow : index + vector.store %480, %297[%481] : memref>, vector<1xf32> + %482 = vector.extract_strided_slice %264 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %483 = arith.addi %448, %99 overflow : index + vector.store %482, %297[%483] : memref>, vector<1xf32> + %484 = vector.extract_strided_slice %264 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %485 = arith.addi %452, %99 overflow : index + vector.store %484, %297[%485] : memref>, vector<1xf32> + %486 = vector.extract_strided_slice %266 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %487 = arith.addi %440, %102 overflow : index + vector.store %486, %297[%487] : memref>, vector<1xf32> + %488 = vector.extract_strided_slice %266 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %489 = arith.addi %444, %102 overflow : index + vector.store %488, %297[%489] : memref>, vector<1xf32> + %490 = vector.extract_strided_slice %266 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %491 = arith.addi %448, %102 overflow : index + vector.store %490, %297[%491] : memref>, vector<1xf32> + %492 = vector.extract_strided_slice %266 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %493 = 
arith.addi %452, %102 overflow : index + vector.store %492, %297[%493] : memref>, vector<1xf32> + %494 = vector.extract_strided_slice %268 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %495 = arith.addi %440, %105 overflow : index + vector.store %494, %297[%495] : memref>, vector<1xf32> + %496 = vector.extract_strided_slice %268 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %497 = arith.addi %444, %105 overflow : index + vector.store %496, %297[%497] : memref>, vector<1xf32> + %498 = vector.extract_strided_slice %268 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %499 = arith.addi %448, %105 overflow : index + vector.store %498, %297[%499] : memref>, vector<1xf32> + %500 = vector.extract_strided_slice %268 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %501 = arith.addi %452, %105 overflow : index + vector.store %500, %297[%501] : memref>, vector<1xf32> + %502 = vector.extract_strided_slice %270 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %503 = arith.addi %440, %108 overflow : index + vector.store %502, %297[%503] : memref>, vector<1xf32> + %504 = vector.extract_strided_slice %270 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %505 = arith.addi %444, %108 overflow : index + vector.store %504, %297[%505] : memref>, vector<1xf32> + %506 = vector.extract_strided_slice %270 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %507 = arith.addi %448, %108 overflow : index + vector.store %506, %297[%507] : memref>, vector<1xf32> + %508 = vector.extract_strided_slice %270 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %509 = arith.addi %452, %108 overflow : index + vector.store %508, %297[%509] : memref>, vector<1xf32> + %510 = vector.extract_strided_slice %274 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> 
to vector<1xf32> + %511 = affine.apply #map48()[%thread_id_x] + %512 = arith.muli %511, %c16384 overflow : index + %513 = arith.addi %512, %84 overflow : index + vector.store %510, %297[%513] : memref>, vector<1xf32> + %514 = vector.extract_strided_slice %274 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %515 = affine.apply #map49()[%thread_id_x] + %516 = arith.muli %515, %c16384 overflow : index + %517 = arith.addi %516, %84 overflow : index + vector.store %514, %297[%517] : memref>, vector<1xf32> + %518 = vector.extract_strided_slice %274 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %519 = affine.apply #map50()[%thread_id_x] + %520 = arith.muli %519, %c16384 overflow : index + %521 = arith.addi %520, %84 overflow : index + vector.store %518, %297[%521] : memref>, vector<1xf32> + %522 = vector.extract_strided_slice %274 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %523 = affine.apply #map51()[%thread_id_x] + %524 = arith.muli %523, %c16384 overflow : index + %525 = arith.addi %524, %84 overflow : index + vector.store %522, %297[%525] : memref>, vector<1xf32> + %526 = vector.extract_strided_slice %276 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %527 = arith.addi %512, %90 overflow : index + vector.store %526, %297[%527] : memref>, vector<1xf32> + %528 = vector.extract_strided_slice %276 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %529 = arith.addi %516, %90 overflow : index + vector.store %528, %297[%529] : memref>, vector<1xf32> + %530 = vector.extract_strided_slice %276 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %531 = arith.addi %520, %90 overflow : index + vector.store %530, %297[%531] : memref>, vector<1xf32> + %532 = vector.extract_strided_slice %276 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %533 = arith.addi %524, %90 
overflow : index + vector.store %532, %297[%533] : memref>, vector<1xf32> + %534 = vector.extract_strided_slice %278 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %535 = arith.addi %512, %93 overflow : index + vector.store %534, %297[%535] : memref>, vector<1xf32> + %536 = vector.extract_strided_slice %278 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %537 = arith.addi %516, %93 overflow : index + vector.store %536, %297[%537] : memref>, vector<1xf32> + %538 = vector.extract_strided_slice %278 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %539 = arith.addi %520, %93 overflow : index + vector.store %538, %297[%539] : memref>, vector<1xf32> + %540 = vector.extract_strided_slice %278 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %541 = arith.addi %524, %93 overflow : index + vector.store %540, %297[%541] : memref>, vector<1xf32> + %542 = vector.extract_strided_slice %280 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %543 = arith.addi %512, %96 overflow : index + vector.store %542, %297[%543] : memref>, vector<1xf32> + %544 = vector.extract_strided_slice %280 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %545 = arith.addi %516, %96 overflow : index + vector.store %544, %297[%545] : memref>, vector<1xf32> + %546 = vector.extract_strided_slice %280 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %547 = arith.addi %520, %96 overflow : index + vector.store %546, %297[%547] : memref>, vector<1xf32> + %548 = vector.extract_strided_slice %280 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %549 = arith.addi %524, %96 overflow : index + vector.store %548, %297[%549] : memref>, vector<1xf32> + %550 = vector.extract_strided_slice %282 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %551 = 
arith.addi %512, %99 overflow : index + vector.store %550, %297[%551] : memref>, vector<1xf32> + %552 = vector.extract_strided_slice %282 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %553 = arith.addi %516, %99 overflow : index + vector.store %552, %297[%553] : memref>, vector<1xf32> + %554 = vector.extract_strided_slice %282 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %555 = arith.addi %520, %99 overflow : index + vector.store %554, %297[%555] : memref>, vector<1xf32> + %556 = vector.extract_strided_slice %282 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %557 = arith.addi %524, %99 overflow : index + vector.store %556, %297[%557] : memref>, vector<1xf32> + %558 = vector.extract_strided_slice %284 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %559 = arith.addi %512, %102 overflow : index + vector.store %558, %297[%559] : memref>, vector<1xf32> + %560 = vector.extract_strided_slice %284 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %561 = arith.addi %516, %102 overflow : index + vector.store %560, %297[%561] : memref>, vector<1xf32> + %562 = vector.extract_strided_slice %284 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %563 = arith.addi %520, %102 overflow : index + vector.store %562, %297[%563] : memref>, vector<1xf32> + %564 = vector.extract_strided_slice %284 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %565 = arith.addi %524, %102 overflow : index + vector.store %564, %297[%565] : memref>, vector<1xf32> + %566 = vector.extract_strided_slice %286 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %567 = arith.addi %512, %105 overflow : index + vector.store %566, %297[%567] : memref>, vector<1xf32> + %568 = vector.extract_strided_slice %286 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to 
vector<1xf32> + %569 = arith.addi %516, %105 overflow : index + vector.store %568, %297[%569] : memref>, vector<1xf32> + %570 = vector.extract_strided_slice %286 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %571 = arith.addi %520, %105 overflow : index + vector.store %570, %297[%571] : memref>, vector<1xf32> + %572 = vector.extract_strided_slice %286 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %573 = arith.addi %524, %105 overflow : index + vector.store %572, %297[%573] : memref>, vector<1xf32> + %574 = vector.extract_strided_slice %288 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %575 = arith.addi %512, %108 overflow : index + vector.store %574, %297[%575] : memref>, vector<1xf32> + %576 = vector.extract_strided_slice %288 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %577 = arith.addi %516, %108 overflow : index + vector.store %576, %297[%577] : memref>, vector<1xf32> + %578 = vector.extract_strided_slice %288 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %579 = arith.addi %520, %108 overflow : index + vector.store %578, %297[%579] : memref>, vector<1xf32> + %580 = vector.extract_strided_slice %288 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %581 = arith.addi %524, %108 overflow : index + vector.store %580, %297[%581] : memref>, vector<1xf32> + return + } + } + } + func.func @isolated_benchmark$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view, %arg4: !hal.buffer_view, %arg5: !hal.fence, %arg6: !hal.fence) -> !hal.buffer_view { + %0 = hal.tensor.import wait(%arg5) => %arg0 : !hal.buffer_view -> tensor<16384x8192xi8> + %1 = hal.tensor.import wait(%arg5) => %arg1 : !hal.buffer_view -> tensor<16384x512xi8> + %2 = hal.tensor.import wait(%arg5) => %arg2 : !hal.buffer_view -> tensor<16384x8192xi8> + %3 = 
hal.tensor.import wait(%arg5) => %arg3 : !hal.buffer_view -> tensor<16384x512xi8> + %4 = hal.tensor.import wait(%arg5) => %arg4 : !hal.buffer_view -> tensor<16384x16384xf32> + %5 = flow.dispatch @gemm::@gemm(%0, %1, %2, %3, %4) : (tensor<16384x8192xi8>, tensor<16384x512xi8>, tensor<16384x8192xi8>, tensor<16384x512xi8>, tensor<16384x16384xf32>) -> %4 + %6 = hal.tensor.barrier join(%5 : tensor<16384x16384xf32>) => %arg6 : !hal.fence + %7 = hal.tensor.export %6 : tensor<16384x16384xf32> -> !hal.buffer_view + return %7 : !hal.buffer_view + } + } + """ + mlir2 = """ + #map = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8) floordiv 256) * 256)> + #map1 = affine_map<()[s0] -> ((s0 floordiv 8) mod 8)> + #map2 = affine_map<()[s0] -> (s0 mod 8)> + #map3 = affine_map<()[s0] -> (s0 * 16)> + #map4 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64) floordiv 32) * 256)> + #map5 = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 64) floordiv 256) * 256 + 64)> + #map6 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64 + 8) floordiv 32) * 256 + 64)> + #map7 = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 128) floordiv 256) * 256 + 128)> + #map8 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64 + 16) floordiv 32) * 256 + 128)> + #map9 = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 192) floordiv 256) * 256 + 192)> + #map10 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64 + 24) floordiv 32) * 256 + 192)> + #map11 = affine_map<()[s0, s1, s2] -> (s1 * 128 + s2 * 256 + s0 floordiv 2 - ((s1 * 128 + s0 floordiv 2) floordiv 256) * 256)> + #map12 = affine_map<()[s0] -> ((s0 floordiv 2) mod 2)> + #map13 = affine_map<()[s0] -> (s0 mod 2)> + #map14 = 
affine_map<()[s0] -> (s0 * 4)> + #map15 = affine_map<()[s0, s1] -> (s1 * 128 + (s0 floordiv 64) * 32 - ((s1 * 4 + s0 floordiv 64) floordiv 8) * 256)> + #map16 = affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 64)> + #map17 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64)> + #map18 = affine_map<()[s0] -> ((s0 mod 64) floordiv 16)> + #map19 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64 + 16)> + #map20 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64 + 32)> + #map21 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64 + 48)> + #map22 = affine_map<()[s0] -> (s0 * 4 + (s0 mod 64) floordiv 16 - (s0 floordiv 2) * 8)> + #map23 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16)> + #map24 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 16)> + #map25 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 32)> + #map26 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 48)> + #map27 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 64)> + #map28 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 80)> + #map29 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 96)> + #map30 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 112)> + #map31 = affine_map<()[s0] -> ((s0 mod 64) floordiv 16 + 4)> + #map32 = affine_map<()[s0, s1] -> (s1 * 4 + (s0 mod 64) floordiv 16)> + #map33 = affine_map<()[s0, s1] -> (s0 * 128 + s1 * 16 + 128)> + #map34 = affine_map<()[s0, s1] -> (s0 * 8 + s1 * 4 + 8)> + #map35 = affine_map<()[s0] -> (s0 * 256)> + #map36 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4)> + #map37 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 1)> + #map38 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 2)> + #map39 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 3)> + #map40 = 
affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 16)> + #map41 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 17)> + #map42 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 18)> + #map43 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 19)> + #map44 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 32)> + #map45 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 33)> + #map46 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 34)> + #map47 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 35)> + #map48 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 48)> + #map49 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 49)> + #map50 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 50)> + #map51 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 51)> + #translation = #iree_codegen.translation_info + module attributes {transform.with_named_sequence} { + stream.executable private @gemm { + stream.executable.export public @gemm workgroups() -> (index, index, index) { + %c64 = arith.constant 64 : index + %c1 = arith.constant 1 : index + stream.return %c64, %c64, %c1 : index, index, index + } + builtin.module { + func.func @gemm(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: !stream.binding, %arg3: !stream.binding, %arg4: !stream.binding) attributes {translation_info = #translation} { + %c4_i32 = arith.constant 4 : i32 + %c512_i14 = arith.constant 512 : i14 + %c-8192_i14 = arith.constant -8192 : i14 + %c2147483643_i64 = arith.constant 2147483643 : i64 + %c16384 = arith.constant 16384 : index + %c63 = arith.constant 63 : index + %c512 = arith.constant 512 : index + %c2147483646_i64 = arith.constant 2147483646 : 
i64 + %c8192 = arith.constant 8192 : index + %c1 = arith.constant 1 : index + %cst = arith.constant dense<0.000000e+00> : vector<4xf32> + %c0 = arith.constant 0 : index + %0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> memref + %1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> memref + %2 = stream.binding.subspan %arg2[%c0] : !stream.binding -> memref + %3 = stream.binding.subspan %arg3[%c0] : !stream.binding -> memref + %4 = stream.binding.subspan %arg4[%c0] : !stream.binding -> memref + %block_id_x = gpu.block_id x upper_bound 64 + %block_id_y = gpu.block_id y upper_bound 64 + %thread_id_x = gpu.thread_id x upper_bound 256 + %thread_id_y = gpu.thread_id y upper_bound 2 + %alloc = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_0 = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_1 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %alloc_2 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %alloc_3 = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_4 = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_5 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %alloc_6 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %5 = affine.apply #map()[%thread_id_x, %thread_id_y, %block_id_x] + %6 = affine.apply #map1()[%thread_id_x] + %7 = affine.apply #map2()[%thread_id_x] + %8 = arith.xori %7, %6 : index + %9 = affine.apply #map3()[%8] + %10 = affine.apply #map4()[%thread_id_x, %thread_id_y] + %11 = gpu.subgroup_broadcast %10, first_active_lane : index + %12 = gpu.subgroup_broadcast %c0, first_active_lane : index + %13 = arith.muli %5, %c8192 overflow : index + %14 = arith.addi %13, %9 overflow : index + %reinterpret_cast = memref.reinterpret_cast %0 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast = memref.cast %reinterpret_cast : memref<2147483646xi8, strided<[1]>> to memref> + %15 = 
amdgpu.fat_raw_buffer_cast %cast validBytes(%c2147483646_i64) cacheSwizzleStride(%c-8192_i14) resetOffset : memref> to memref> + amdgpu.gather_to_lds %15[%14], %alloc_6[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %16 = affine.apply #map5()[%thread_id_x, %thread_id_y, %block_id_x] + %17 = affine.apply #map6()[%thread_id_x, %thread_id_y] + %18 = gpu.subgroup_broadcast %17, first_active_lane : index + %19 = arith.muli %16, %c8192 overflow : index + %20 = arith.addi %19, %9 overflow : index + amdgpu.gather_to_lds %15[%20], %alloc_6[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %21 = affine.apply #map7()[%thread_id_x, %thread_id_y, %block_id_x] + %22 = affine.apply #map8()[%thread_id_x, %thread_id_y] + %23 = gpu.subgroup_broadcast %22, first_active_lane : index + %24 = arith.muli %21, %c8192 overflow : index + %25 = arith.addi %24, %9 overflow : index + amdgpu.gather_to_lds %15[%25], %alloc_6[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %26 = affine.apply #map9()[%thread_id_x, %thread_id_y, %block_id_x] + %27 = affine.apply #map10()[%thread_id_x, %thread_id_y] + %28 = gpu.subgroup_broadcast %27, first_active_lane : index + %29 = arith.muli %26, %c8192 overflow : index + %30 = arith.addi %29, %9 overflow : index + amdgpu.gather_to_lds %15[%30], %alloc_6[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %31 = affine.apply #map11()[%thread_id_x, %thread_id_y, %block_id_x] + %32 = affine.apply #map12()[%thread_id_x] + %33 = affine.apply #map13()[%thread_id_x] + %34 = arith.xori %33, %32 : index + %35 = affine.apply #map14()[%34] + %36 = affine.apply #map15()[%thread_id_x, %thread_id_y] + %37 = gpu.subgroup_broadcast %36, first_active_lane : index + %38 = arith.muli %31, %c512 overflow : index + %39 = arith.addi %38, %35 overflow : index + %reinterpret_cast_7 = memref.reinterpret_cast %1 to offset: [0], sizes: [2147483646], strides: [1] : memref to 
memref<2147483646xi8, strided<[1]>> + %cast_8 = memref.cast %reinterpret_cast_7 : memref<2147483646xi8, strided<[1]>> to memref> + %40 = amdgpu.fat_raw_buffer_cast %cast_8 validBytes(%c2147483646_i64) cacheSwizzleStride(%c512_i14) resetOffset : memref> to memref> + amdgpu.gather_to_lds %40[%39], %alloc_4[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + %41 = affine.apply #map()[%thread_id_x, %thread_id_y, %block_id_y] + %42 = arith.muli %41, %c8192 overflow : index + %43 = arith.addi %42, %9 overflow : index + %reinterpret_cast_9 = memref.reinterpret_cast %2 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast_10 = memref.cast %reinterpret_cast_9 : memref<2147483646xi8, strided<[1]>> to memref> + %44 = amdgpu.fat_raw_buffer_cast %cast_10 validBytes(%c2147483646_i64) cacheSwizzleStride(%c-8192_i14) resetOffset : memref> to memref> + amdgpu.gather_to_lds %44[%43], %alloc_2[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %45 = affine.apply #map5()[%thread_id_x, %thread_id_y, %block_id_y] + %46 = arith.muli %45, %c8192 overflow : index + %47 = arith.addi %46, %9 overflow : index + amdgpu.gather_to_lds %44[%47], %alloc_2[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %48 = affine.apply #map7()[%thread_id_x, %thread_id_y, %block_id_y] + %49 = arith.muli %48, %c8192 overflow : index + %50 = arith.addi %49, %9 overflow : index + amdgpu.gather_to_lds %44[%50], %alloc_2[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %51 = affine.apply #map9()[%thread_id_x, %thread_id_y, %block_id_y] + %52 = arith.muli %51, %c8192 overflow : index + %53 = arith.addi %52, %9 overflow : index + amdgpu.gather_to_lds %44[%53], %alloc_2[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %54 = affine.apply #map11()[%thread_id_x, %thread_id_y, %block_id_y] + %55 = arith.muli %54, %c512 overflow : index + %56 = 
arith.addi %55, %35 overflow : index + %reinterpret_cast_11 = memref.reinterpret_cast %3 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast_12 = memref.cast %reinterpret_cast_11 : memref<2147483646xi8, strided<[1]>> to memref> + %57 = amdgpu.fat_raw_buffer_cast %cast_12 validBytes(%c2147483646_i64) cacheSwizzleStride(%c512_i14) resetOffset : memref> to memref> + amdgpu.gather_to_lds %57[%56], %alloc_0[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + rocdl.s.barrier + %58 = affine.apply #map16()[%thread_id_x, %thread_id_y] + %59 = arith.index_cast %58 : index to i32 + %60 = arith.cmpi sge, %59, %c4_i32 : i32 + %61 = arith.cmpi slt, %59, %c4_i32 : i32 + scf.if %60 { + rocdl.s.barrier + } + %62 = affine.apply #map17()[%thread_id_x] + %63 = affine.apply #map18()[%thread_id_x] + %64 = arith.xori %63, %7 : index + %65 = affine.apply #map3()[%64] + %66 = affine.apply #map19()[%thread_id_x] + %67 = affine.apply #map20()[%thread_id_x] + %68 = affine.apply #map21()[%thread_id_x] + %69 = affine.apply #map22()[%thread_id_x] + %70 = affine.apply #map23()[%thread_id_x, %thread_id_y] + %71 = affine.apply #map24()[%thread_id_x, %thread_id_y] + %72 = affine.apply #map25()[%thread_id_x, %thread_id_y] + %73 = affine.apply #map26()[%thread_id_x, %thread_id_y] + %74 = affine.apply #map27()[%thread_id_x, %thread_id_y] + %75 = affine.apply #map28()[%thread_id_x, %thread_id_y] + %76 = affine.apply #map29()[%thread_id_x, %thread_id_y] + %77 = affine.apply #map30()[%thread_id_x, %thread_id_y] + %78 = affine.apply #map31()[%thread_id_x] + %79 = arith.xori %78, %7 : index + %80 = affine.apply #map3()[%79] + %81 = arith.xori %33, %c1 : index + %82 = affine.apply #map32()[%thread_id_x, %81] + %83:40 = scf.for %arg5 = %c0 to %c63 step %c1 iter_args(%arg6 = %cst, %arg7 = %cst, %arg8 = %cst, %arg9 = %cst, %arg10 = %cst, %arg11 = %cst, %arg12 = %cst, %arg13 = %cst, %arg14 = %cst, %arg15 = %cst, %arg16 = %cst, %arg17 = 
%cst, %arg18 = %cst, %arg19 = %cst, %arg20 = %cst, %arg21 = %cst, %arg22 = %cst, %arg23 = %cst, %arg24 = %cst, %arg25 = %cst, %arg26 = %cst, %arg27 = %cst, %arg28 = %cst, %arg29 = %cst, %arg30 = %cst, %arg31 = %cst, %arg32 = %cst, %arg33 = %cst, %arg34 = %cst, %arg35 = %cst, %arg36 = %cst, %arg37 = %cst, %arg38 = %alloc_6, %arg39 = %alloc_5, %arg40 = %alloc_4, %arg41 = %alloc_3, %arg42 = %alloc_2, %arg43 = %alloc_1, %arg44 = %alloc_0, %arg45 = %alloc) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>) { + rocdl.sched.barrier 0 + rocdl.s.barrier + rocdl.s.barrier + //amdgpu.lds_barrier + %582 = affine.apply #map33()[%arg5, %8] + %583 = arith.addi %13, %582 overflow : index + amdgpu.gather_to_lds %15[%583], %arg39[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %584 = arith.addi %19, %582 overflow : index + amdgpu.gather_to_lds %15[%584], %arg39[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %585 = arith.addi %24, %582 overflow : index + amdgpu.gather_to_lds %15[%585], %arg39[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %586 = arith.addi %29, %582 overflow : index + amdgpu.gather_to_lds %15[%586], %arg39[%28, %12] : vector<16xi8>, memref>, 
memref<256x128xi8, #gpu.address_space> + %587 = affine.apply #map34()[%arg5, %34] + %588 = arith.addi %38, %587 overflow : index + amdgpu.gather_to_lds %40[%588], %arg41[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + %589 = arith.addi %42, %582 overflow : index + amdgpu.gather_to_lds %44[%589], %arg43[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %590 = arith.addi %46, %582 overflow : index + amdgpu.gather_to_lds %44[%590], %arg43[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %591 = arith.addi %49, %582 overflow : index + amdgpu.gather_to_lds %44[%591], %arg43[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %592 = arith.addi %52, %582 overflow : index + amdgpu.gather_to_lds %44[%592], %arg43[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %593 = arith.addi %55, %587 overflow : index + amdgpu.gather_to_lds %57[%593], %arg45[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + rocdl.sched.barrier 0 + amdgpu.memory_counter_wait load(10) + //rocdl.s.waitcnt 16368 + //amdgpu.lds_barrier + %594 = vector.load %arg38[%62, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %595 = vector.load %arg38[%66, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %596 = vector.load %arg38[%67, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %597 = vector.load %arg38[%68, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %598 = vector.load %arg40[%62, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %599 = vector.load %arg40[%66, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %600 = vector.load %arg40[%67, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %601 = vector.load %arg40[%68, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %602 = vector.load %arg42[%70, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> 
+ %603 = vector.load %arg42[%71, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %604 = vector.load %arg42[%72, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %605 = vector.load %arg42[%73, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %606 = vector.load %arg42[%74, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %607 = vector.load %arg42[%75, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %608 = vector.load %arg42[%76, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %609 = vector.load %arg42[%77, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %610 = vector.load %arg44[%70, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %611 = vector.load %arg44[%71, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %612 = vector.load %arg44[%72, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %613 = vector.load %arg44[%73, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %614 = vector.load %arg44[%74, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %615 = vector.load %arg44[%75, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %616 = vector.load %arg44[%76, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %617 = vector.load %arg44[%77, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %618 = vector.bitcast %594 : vector<16xi8> to vector<32xf4E2M1FN> + %619 = vector.bitcast %595 : vector<16xi8> to vector<32xf4E2M1FN> + %620 = vector.bitcast %596 : vector<16xi8> to vector<32xf4E2M1FN> + %621 = vector.bitcast %597 : vector<16xi8> to vector<32xf4E2M1FN> + %622 = vector.bitcast %598 : vector<1xi8> to vector<1xf8E8M0FNU> + %623 = vector.bitcast %599 : vector<1xi8> to vector<1xf8E8M0FNU> + %624 = vector.bitcast %600 : vector<1xi8> to vector<1xf8E8M0FNU> + %625 = vector.bitcast %601 : vector<1xi8> to vector<1xf8E8M0FNU> + %626 = vector.bitcast %602 : vector<16xi8> to vector<32xf4E2M1FN> + 
%627 = vector.bitcast %603 : vector<16xi8> to vector<32xf4E2M1FN> + %628 = vector.bitcast %604 : vector<16xi8> to vector<32xf4E2M1FN> + %629 = vector.bitcast %605 : vector<16xi8> to vector<32xf4E2M1FN> + %630 = vector.bitcast %606 : vector<16xi8> to vector<32xf4E2M1FN> + %631 = vector.bitcast %607 : vector<16xi8> to vector<32xf4E2M1FN> + %632 = vector.bitcast %608 : vector<16xi8> to vector<32xf4E2M1FN> + %633 = vector.bitcast %609 : vector<16xi8> to vector<32xf4E2M1FN> + %634 = vector.bitcast %610 : vector<1xi8> to vector<1xf8E8M0FNU> + %635 = vector.bitcast %611 : vector<1xi8> to vector<1xf8E8M0FNU> + %636 = vector.bitcast %612 : vector<1xi8> to vector<1xf8E8M0FNU> + %637 = vector.bitcast %613 : vector<1xi8> to vector<1xf8E8M0FNU> + %638 = vector.bitcast %614 : vector<1xi8> to vector<1xf8E8M0FNU> + %639 = vector.bitcast %615 : vector<1xi8> to vector<1xf8E8M0FNU> + %640 = vector.bitcast %616 : vector<1xi8> to vector<1xf8E8M0FNU> + %641 = vector.bitcast %617 : vector<1xi8> to vector<1xf8E8M0FNU> + rocdl.sched.barrier 0 + rocdl.s.barrier + rocdl.sched.barrier 0 + rocdl.s.setprio 1 + %642 = vector.extract %622[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %643 = vector.extract %634[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %644 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%643[0] * %626) + %arg6 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %645 = vector.extract %635[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %646 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%645[0] * %627) + %arg7 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %647 = vector.extract %636[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %648 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%647[0] * %628) + %arg8 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %649 = vector.extract %637[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %650 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%649[0] * 
%629) + %arg9 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %651 = vector.extract %638[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %652 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%651[0] * %630) + %arg10 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %653 = vector.extract %639[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %654 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%653[0] * %631) + %arg11 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %655 = vector.extract %640[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %656 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%655[0] * %632) + %arg12 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %657 = vector.extract %641[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %658 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%657[0] * %633) + %arg13 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %659 = vector.extract %623[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %660 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%643[0] * %626) + %arg14 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %661 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%645[0] * %627) + %arg15 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %662 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%647[0] * %628) + %arg16 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %663 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%649[0] * %629) + %arg17 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %664 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%651[0] * %630) + %arg18 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %665 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%653[0] * %631) + %arg19 : 
f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %666 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%655[0] * %632) + %arg20 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %667 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%657[0] * %633) + %arg21 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %668 = vector.extract %624[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %669 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%643[0] * %626) + %arg22 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %670 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%645[0] * %627) + %arg23 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %671 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%647[0] * %628) + %arg24 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %672 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%649[0] * %629) + %arg25 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %673 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%651[0] * %630) + %arg26 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %674 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%653[0] * %631) + %arg27 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %675 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%655[0] * %632) + %arg28 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %676 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%657[0] * %633) + %arg29 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %677 = vector.extract %625[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %678 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%643[0] * %626) + %arg30 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, 
vector<4xf32> + %679 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%645[0] * %627) + %arg31 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %680 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%647[0] * %628) + %arg32 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %681 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%649[0] * %629) + %arg33 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %682 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%651[0] * %630) + %arg34 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %683 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%653[0] * %631) + %arg35 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %684 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%655[0] * %632) + %arg36 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %685 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%657[0] * %633) + %arg37 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + rocdl.s.setprio 0 + rocdl.sched.barrier 0 + rocdl.s.barrier + rocdl.sched.barrier 0 + %686 = vector.load %arg38[%62, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %687 = vector.load %arg38[%66, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %688 = vector.load %arg38[%67, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %689 = vector.load %arg38[%68, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %690 = vector.load %arg40[%62, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %691 = vector.load %arg40[%66, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %692 = vector.load %arg40[%67, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %693 = vector.load %arg40[%68, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %694 = 
vector.load %arg42[%70, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %695 = vector.load %arg42[%71, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %696 = vector.load %arg42[%72, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %697 = vector.load %arg42[%73, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %698 = vector.load %arg42[%74, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %699 = vector.load %arg42[%75, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %700 = vector.load %arg42[%76, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %701 = vector.load %arg42[%77, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %702 = vector.load %arg44[%70, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %703 = vector.load %arg44[%71, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %704 = vector.load %arg44[%72, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %705 = vector.load %arg44[%73, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %706 = vector.load %arg44[%74, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %707 = vector.load %arg44[%75, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %708 = vector.load %arg44[%76, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %709 = vector.load %arg44[%77, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %710 = vector.bitcast %686 : vector<16xi8> to vector<32xf4E2M1FN> + %711 = vector.bitcast %687 : vector<16xi8> to vector<32xf4E2M1FN> + %712 = vector.bitcast %688 : vector<16xi8> to vector<32xf4E2M1FN> + %713 = vector.bitcast %689 : vector<16xi8> to vector<32xf4E2M1FN> + %714 = vector.bitcast %690 : vector<1xi8> to vector<1xf8E8M0FNU> + %715 = vector.bitcast %691 : vector<1xi8> to vector<1xf8E8M0FNU> + %716 = vector.bitcast %692 : vector<1xi8> to vector<1xf8E8M0FNU> + %717 = vector.bitcast %693 : vector<1xi8> to 
vector<1xf8E8M0FNU> + %718 = vector.bitcast %694 : vector<16xi8> to vector<32xf4E2M1FN> + %719 = vector.bitcast %695 : vector<16xi8> to vector<32xf4E2M1FN> + %720 = vector.bitcast %696 : vector<16xi8> to vector<32xf4E2M1FN> + %721 = vector.bitcast %697 : vector<16xi8> to vector<32xf4E2M1FN> + %722 = vector.bitcast %698 : vector<16xi8> to vector<32xf4E2M1FN> + %723 = vector.bitcast %699 : vector<16xi8> to vector<32xf4E2M1FN> + %724 = vector.bitcast %700 : vector<16xi8> to vector<32xf4E2M1FN> + %725 = vector.bitcast %701 : vector<16xi8> to vector<32xf4E2M1FN> + %726 = vector.bitcast %702 : vector<1xi8> to vector<1xf8E8M0FNU> + %727 = vector.bitcast %703 : vector<1xi8> to vector<1xf8E8M0FNU> + %728 = vector.bitcast %704 : vector<1xi8> to vector<1xf8E8M0FNU> + %729 = vector.bitcast %705 : vector<1xi8> to vector<1xf8E8M0FNU> + %730 = vector.bitcast %706 : vector<1xi8> to vector<1xf8E8M0FNU> + %731 = vector.bitcast %707 : vector<1xi8> to vector<1xf8E8M0FNU> + %732 = vector.bitcast %708 : vector<1xi8> to vector<1xf8E8M0FNU> + %733 = vector.bitcast %709 : vector<1xi8> to vector<1xf8E8M0FNU> + rocdl.sched.barrier 0 + rocdl.s.barrier + rocdl.sched.barrier 0 + rocdl.s.setprio 1 + %734 = vector.extract %714[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %735 = vector.extract %726[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %736 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%735[0] * %718) + %644 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %737 = vector.extract %727[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %738 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%737[0] * %719) + %646 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %739 = vector.extract %728[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %740 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%739[0] * %720) + %648 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %741 = vector.extract %729[0] : f8E8M0FNU from 
vector<1xf8E8M0FNU> + %742 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%741[0] * %721) + %650 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %743 = vector.extract %730[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %744 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%743[0] * %722) + %652 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %745 = vector.extract %731[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %746 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%745[0] * %723) + %654 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %747 = vector.extract %732[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %748 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%747[0] * %724) + %656 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %749 = vector.extract %733[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %750 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%749[0] * %725) + %658 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %751 = vector.extract %715[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %752 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%735[0] * %718) + %660 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %753 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%737[0] * %719) + %661 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %754 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%739[0] * %720) + %662 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %755 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%741[0] * %721) + %663 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %756 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%743[0] * %722) + %664 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %757 = 
amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%745[0] * %723) + %665 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %758 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%747[0] * %724) + %666 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %759 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%749[0] * %725) + %667 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %760 = vector.extract %716[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %761 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%735[0] * %718) + %669 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %762 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%737[0] * %719) + %670 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %763 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%739[0] * %720) + %671 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %764 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%741[0] * %721) + %672 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %765 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%743[0] * %722) + %673 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %766 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%745[0] * %723) + %674 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %767 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%747[0] * %724) + %675 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %768 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%749[0] * %725) + %676 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %769 = vector.extract %717[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %770 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%735[0] * %718) + %678 : f8E8M0FNU, 
vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %771 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%737[0] * %719) + %679 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %772 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%739[0] * %720) + %680 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %773 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%741[0] * %721) + %681 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %774 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%743[0] * %722) + %682 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %775 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%745[0] * %723) + %683 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %776 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%747[0] * %724) + %684 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %777 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%749[0] * %725) + %685 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + rocdl.s.setprio 0 + rocdl.sched.barrier 0 + scf.yield %736, %738, %740, %742, %744, %746, %748, %750, %752, %753, %754, %755, %756, %757, %758, %759, %761, %762, %763, %764, %765, %766, %767, %768, %770, %771, %772, %773, %774, %775, %776, %777, %arg39, %arg38, %arg41, %arg40, %arg43, %arg42, %arg45, %arg44 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, 
memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space> + } + scf.if %61 { + rocdl.s.barrier + } + amdgpu.lds_barrier + %84 = affine.apply #map23()[%thread_id_x, %thread_id_y] + %85 = affine.apply #map22()[%thread_id_x] + %86 = vector.load %83#38[%84, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %87 = arith.xori %33, %c1 : index + %88 = affine.apply #map32()[%thread_id_x, %87] + %89 = vector.load %83#38[%84, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %90 = affine.apply #map24()[%thread_id_x, %thread_id_y] + %91 = vector.load %83#38[%90, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %92 = vector.load %83#38[%90, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %93 = affine.apply #map25()[%thread_id_x, %thread_id_y] + %94 = vector.load %83#38[%93, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %95 = vector.load %83#38[%93, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %96 = affine.apply #map26()[%thread_id_x, %thread_id_y] + %97 = vector.load %83#38[%96, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %98 = vector.load %83#38[%96, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %99 = affine.apply #map27()[%thread_id_x, %thread_id_y] + %100 = vector.load %83#38[%99, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %101 = vector.load %83#38[%99, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %102 = affine.apply #map28()[%thread_id_x, %thread_id_y] + %103 = vector.load %83#38[%102, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %104 = vector.load %83#38[%102, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %105 = affine.apply #map29()[%thread_id_x, %thread_id_y] + 
%106 = vector.load %83#38[%105, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %107 = vector.load %83#38[%105, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %108 = affine.apply #map30()[%thread_id_x, %thread_id_y] + %109 = vector.load %83#38[%108, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %110 = vector.load %83#38[%108, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %111 = affine.apply #map18()[%thread_id_x] + %112 = arith.xori %111, %7 : index + %113 = affine.apply #map3()[%112] + %114 = vector.load %83#36[%84, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %115 = affine.apply #map31()[%thread_id_x] + %116 = arith.xori %115, %7 : index + %117 = affine.apply #map3()[%116] + %118 = vector.load %83#36[%84, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %119 = vector.load %83#36[%90, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %120 = vector.load %83#36[%90, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %121 = vector.load %83#36[%93, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %122 = vector.load %83#36[%93, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %123 = vector.load %83#36[%96, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %124 = vector.load %83#36[%96, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %125 = vector.load %83#36[%99, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %126 = vector.load %83#36[%99, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %127 = vector.load %83#36[%102, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %128 = vector.load %83#36[%102, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %129 = vector.load %83#36[%105, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %130 = vector.load %83#36[%105, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + 
%131 = vector.load %83#36[%108, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %132 = vector.load %83#36[%108, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %133 = affine.apply #map17()[%thread_id_x] + %134 = vector.load %83#34[%133, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %135 = vector.load %83#34[%133, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %136 = affine.apply #map19()[%thread_id_x] + %137 = vector.load %83#34[%136, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %138 = vector.load %83#34[%136, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %139 = affine.apply #map20()[%thread_id_x] + %140 = vector.load %83#34[%139, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %141 = vector.load %83#34[%139, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %142 = affine.apply #map21()[%thread_id_x] + %143 = vector.load %83#34[%142, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %144 = vector.load %83#34[%142, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %145 = vector.load %83#32[%133, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %146 = vector.load %83#32[%133, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %147 = vector.load %83#32[%136, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %148 = vector.load %83#32[%136, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %149 = vector.load %83#32[%139, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %150 = vector.load %83#32[%139, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %151 = vector.load %83#32[%142, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %152 = vector.load %83#32[%142, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %153 = vector.bitcast %145 : vector<16xi8> to vector<32xf4E2M1FN> + %154 = vector.bitcast %146 : vector<16xi8> to 
vector<32xf4E2M1FN> + %155 = vector.bitcast %147 : vector<16xi8> to vector<32xf4E2M1FN> + %156 = vector.bitcast %148 : vector<16xi8> to vector<32xf4E2M1FN> + %157 = vector.bitcast %149 : vector<16xi8> to vector<32xf4E2M1FN> + %158 = vector.bitcast %150 : vector<16xi8> to vector<32xf4E2M1FN> + %159 = vector.bitcast %151 : vector<16xi8> to vector<32xf4E2M1FN> + %160 = vector.bitcast %152 : vector<16xi8> to vector<32xf4E2M1FN> + %161 = vector.bitcast %134 : vector<1xi8> to vector<1xf8E8M0FNU> + %162 = vector.bitcast %135 : vector<1xi8> to vector<1xf8E8M0FNU> + %163 = vector.bitcast %137 : vector<1xi8> to vector<1xf8E8M0FNU> + %164 = vector.bitcast %138 : vector<1xi8> to vector<1xf8E8M0FNU> + %165 = vector.bitcast %140 : vector<1xi8> to vector<1xf8E8M0FNU> + %166 = vector.bitcast %141 : vector<1xi8> to vector<1xf8E8M0FNU> + %167 = vector.bitcast %143 : vector<1xi8> to vector<1xf8E8M0FNU> + %168 = vector.bitcast %144 : vector<1xi8> to vector<1xf8E8M0FNU> + %169 = vector.bitcast %114 : vector<16xi8> to vector<32xf4E2M1FN> + %170 = vector.bitcast %118 : vector<16xi8> to vector<32xf4E2M1FN> + %171 = vector.bitcast %119 : vector<16xi8> to vector<32xf4E2M1FN> + %172 = vector.bitcast %120 : vector<16xi8> to vector<32xf4E2M1FN> + %173 = vector.bitcast %121 : vector<16xi8> to vector<32xf4E2M1FN> + %174 = vector.bitcast %122 : vector<16xi8> to vector<32xf4E2M1FN> + %175 = vector.bitcast %123 : vector<16xi8> to vector<32xf4E2M1FN> + %176 = vector.bitcast %124 : vector<16xi8> to vector<32xf4E2M1FN> + %177 = vector.bitcast %125 : vector<16xi8> to vector<32xf4E2M1FN> + %178 = vector.bitcast %126 : vector<16xi8> to vector<32xf4E2M1FN> + %179 = vector.bitcast %127 : vector<16xi8> to vector<32xf4E2M1FN> + %180 = vector.bitcast %128 : vector<16xi8> to vector<32xf4E2M1FN> + %181 = vector.bitcast %129 : vector<16xi8> to vector<32xf4E2M1FN> + %182 = vector.bitcast %130 : vector<16xi8> to vector<32xf4E2M1FN> + %183 = vector.bitcast %131 : vector<16xi8> to vector<32xf4E2M1FN> + %184 = 
vector.bitcast %132 : vector<16xi8> to vector<32xf4E2M1FN> + %185 = vector.bitcast %86 : vector<1xi8> to vector<1xf8E8M0FNU> + %186 = vector.bitcast %89 : vector<1xi8> to vector<1xf8E8M0FNU> + %187 = vector.bitcast %91 : vector<1xi8> to vector<1xf8E8M0FNU> + %188 = vector.bitcast %92 : vector<1xi8> to vector<1xf8E8M0FNU> + %189 = vector.bitcast %94 : vector<1xi8> to vector<1xf8E8M0FNU> + %190 = vector.bitcast %95 : vector<1xi8> to vector<1xf8E8M0FNU> + %191 = vector.bitcast %97 : vector<1xi8> to vector<1xf8E8M0FNU> + %192 = vector.bitcast %98 : vector<1xi8> to vector<1xf8E8M0FNU> + %193 = vector.bitcast %100 : vector<1xi8> to vector<1xf8E8M0FNU> + %194 = vector.bitcast %101 : vector<1xi8> to vector<1xf8E8M0FNU> + %195 = vector.bitcast %103 : vector<1xi8> to vector<1xf8E8M0FNU> + %196 = vector.bitcast %104 : vector<1xi8> to vector<1xf8E8M0FNU> + %197 = vector.bitcast %106 : vector<1xi8> to vector<1xf8E8M0FNU> + %198 = vector.bitcast %107 : vector<1xi8> to vector<1xf8E8M0FNU> + %199 = vector.bitcast %109 : vector<1xi8> to vector<1xf8E8M0FNU> + %200 = vector.bitcast %110 : vector<1xi8> to vector<1xf8E8M0FNU> + %201 = vector.extract %161[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %202 = vector.extract %185[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %203 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%202[0] * %169) + %83#0 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %204 = vector.extract %162[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %205 = vector.extract %186[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %206 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%205[0] * %170) + %203 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %207 = vector.extract %187[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %208 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%207[0] * %171) + %83#1 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %209 = vector.extract %188[0] : f8E8M0FNU 
from vector<1xf8E8M0FNU> + %210 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%209[0] * %172) + %208 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %211 = vector.extract %189[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %212 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%211[0] * %173) + %83#2 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %213 = vector.extract %190[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %214 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%213[0] * %174) + %212 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %215 = vector.extract %191[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %216 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%215[0] * %175) + %83#3 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %217 = vector.extract %192[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %218 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%217[0] * %176) + %216 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %219 = vector.extract %193[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %220 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%219[0] * %177) + %83#4 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %221 = vector.extract %194[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %222 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%221[0] * %178) + %220 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %223 = vector.extract %195[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %224 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%223[0] * %179) + %83#5 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %225 = vector.extract %196[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %226 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%225[0] * %180) + %224 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, 
vector<32xf4E2M1FN>, vector<4xf32> + %227 = vector.extract %197[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %228 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%227[0] * %181) + %83#6 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %229 = vector.extract %198[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %230 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%229[0] * %182) + %228 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %231 = vector.extract %199[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %232 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%231[0] * %183) + %83#7 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %233 = vector.extract %200[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %234 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%233[0] * %184) + %232 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %235 = vector.extract %163[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %236 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%202[0] * %169) + %83#8 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %237 = vector.extract %164[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %238 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%205[0] * %170) + %236 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %239 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%207[0] * %171) + %83#9 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %240 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%209[0] * %172) + %239 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %241 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%211[0] * %173) + %83#10 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %242 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%213[0] * %174) + %241 : f8E8M0FNU, 
vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %243 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%215[0] * %175) + %83#11 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %244 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%217[0] * %176) + %243 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %245 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%219[0] * %177) + %83#12 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %246 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%221[0] * %178) + %245 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %247 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%223[0] * %179) + %83#13 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %248 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%225[0] * %180) + %247 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %249 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%227[0] * %181) + %83#14 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %250 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%229[0] * %182) + %249 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %251 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%231[0] * %183) + %83#15 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %252 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%233[0] * %184) + %251 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %253 = vector.extract %165[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %254 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%202[0] * %169) + %83#16 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %255 = vector.extract %166[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %256 = 
amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%205[0] * %170) + %254 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %257 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%207[0] * %171) + %83#17 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %258 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%209[0] * %172) + %257 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %259 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%211[0] * %173) + %83#18 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %260 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%213[0] * %174) + %259 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %261 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%215[0] * %175) + %83#19 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %262 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%217[0] * %176) + %261 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %263 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%219[0] * %177) + %83#20 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %264 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%221[0] * %178) + %263 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %265 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%223[0] * %179) + %83#21 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %266 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%225[0] * %180) + %265 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %267 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%227[0] * %181) + %83#22 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %268 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * 
(%229[0] * %182) + %267 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %269 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%231[0] * %183) + %83#23 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %270 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%233[0] * %184) + %269 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %271 = vector.extract %167[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %272 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%202[0] * %169) + %83#24 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %273 = vector.extract %168[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %274 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%205[0] * %170) + %272 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %275 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%207[0] * %171) + %83#25 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %276 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%209[0] * %172) + %275 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %277 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%211[0] * %173) + %83#26 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %278 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%213[0] * %174) + %277 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %279 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%215[0] * %175) + %83#27 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %280 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%217[0] * %176) + %279 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %281 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%219[0] * %177) + %83#28 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, 
vector<32xf4E2M1FN>, vector<4xf32> + %282 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%221[0] * %178) + %281 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %283 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%223[0] * %179) + %83#29 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %284 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%225[0] * %180) + %283 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %285 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%227[0] * %181) + %83#30 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %286 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%229[0] * %182) + %285 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %287 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%231[0] * %183) + %83#31 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %288 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%233[0] * %184) + %287 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %289 = vector.extract_strided_slice %206 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %290 = affine.apply #map35()[%block_id_x] + %291 = affine.apply #map35()[%block_id_y] + %292 = affine.apply #map36()[%thread_id_x] + %293 = arith.muli %290, %c16384 overflow : index + %294 = arith.muli %292, %c16384 overflow : index + %295 = arith.addi %293, %291 overflow : index + %296 = arith.addi %294, %84 overflow : index + %reinterpret_cast_13 = memref.reinterpret_cast %4 to offset: [%295], sizes: [536870910], strides: [1] : memref to memref<536870910xf32, strided<[1], offset: ?>> + %cast_14 = memref.cast %reinterpret_cast_13 : memref<536870910xf32, strided<[1], offset: ?>> to memref> + %297 = amdgpu.fat_raw_buffer_cast %cast_14 validBytes(%c2147483643_i64) resetOffset : memref> to memref> + 
vector.store %289, %297[%296] : memref>, vector<1xf32> + %298 = vector.extract_strided_slice %206 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %299 = affine.apply #map37()[%thread_id_x] + %300 = arith.muli %299, %c16384 overflow : index + %301 = arith.addi %300, %84 overflow : index + vector.store %298, %297[%301] : memref>, vector<1xf32> + %302 = vector.extract_strided_slice %206 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %303 = affine.apply #map38()[%thread_id_x] + %304 = arith.muli %303, %c16384 overflow : index + %305 = arith.addi %304, %84 overflow : index + vector.store %302, %297[%305] : memref>, vector<1xf32> + %306 = vector.extract_strided_slice %206 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %307 = affine.apply #map39()[%thread_id_x] + %308 = arith.muli %307, %c16384 overflow : index + %309 = arith.addi %308, %84 overflow : index + vector.store %306, %297[%309] : memref>, vector<1xf32> + %310 = vector.extract_strided_slice %210 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %311 = arith.addi %294, %90 overflow : index + vector.store %310, %297[%311] : memref>, vector<1xf32> + %312 = vector.extract_strided_slice %210 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %313 = arith.addi %300, %90 overflow : index + vector.store %312, %297[%313] : memref>, vector<1xf32> + %314 = vector.extract_strided_slice %210 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %315 = arith.addi %304, %90 overflow : index + vector.store %314, %297[%315] : memref>, vector<1xf32> + %316 = vector.extract_strided_slice %210 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %317 = arith.addi %308, %90 overflow : index + vector.store %316, %297[%317] : memref>, vector<1xf32> + %318 = vector.extract_strided_slice %214 {offsets = [0], sizes = [1], strides = [1]} : 
vector<4xf32> to vector<1xf32> + %319 = arith.addi %294, %93 overflow : index + vector.store %318, %297[%319] : memref>, vector<1xf32> + %320 = vector.extract_strided_slice %214 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %321 = arith.addi %300, %93 overflow : index + vector.store %320, %297[%321] : memref>, vector<1xf32> + %322 = vector.extract_strided_slice %214 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %323 = arith.addi %304, %93 overflow : index + vector.store %322, %297[%323] : memref>, vector<1xf32> + %324 = vector.extract_strided_slice %214 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %325 = arith.addi %308, %93 overflow : index + vector.store %324, %297[%325] : memref>, vector<1xf32> + %326 = vector.extract_strided_slice %218 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %327 = arith.addi %294, %96 overflow : index + vector.store %326, %297[%327] : memref>, vector<1xf32> + %328 = vector.extract_strided_slice %218 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %329 = arith.addi %300, %96 overflow : index + vector.store %328, %297[%329] : memref>, vector<1xf32> + %330 = vector.extract_strided_slice %218 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %331 = arith.addi %304, %96 overflow : index + vector.store %330, %297[%331] : memref>, vector<1xf32> + %332 = vector.extract_strided_slice %218 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %333 = arith.addi %308, %96 overflow : index + vector.store %332, %297[%333] : memref>, vector<1xf32> + %334 = vector.extract_strided_slice %222 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %335 = arith.addi %294, %99 overflow : index + vector.store %334, %297[%335] : memref>, vector<1xf32> + %336 = vector.extract_strided_slice %222 {offsets = [1], sizes = [1], 
strides = [1]} : vector<4xf32> to vector<1xf32> + %337 = arith.addi %300, %99 overflow : index + vector.store %336, %297[%337] : memref>, vector<1xf32> + %338 = vector.extract_strided_slice %222 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %339 = arith.addi %304, %99 overflow : index + vector.store %338, %297[%339] : memref>, vector<1xf32> + %340 = vector.extract_strided_slice %222 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %341 = arith.addi %308, %99 overflow : index + vector.store %340, %297[%341] : memref>, vector<1xf32> + %342 = vector.extract_strided_slice %226 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %343 = arith.addi %294, %102 overflow : index + vector.store %342, %297[%343] : memref>, vector<1xf32> + %344 = vector.extract_strided_slice %226 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %345 = arith.addi %300, %102 overflow : index + vector.store %344, %297[%345] : memref>, vector<1xf32> + %346 = vector.extract_strided_slice %226 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %347 = arith.addi %304, %102 overflow : index + vector.store %346, %297[%347] : memref>, vector<1xf32> + %348 = vector.extract_strided_slice %226 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %349 = arith.addi %308, %102 overflow : index + vector.store %348, %297[%349] : memref>, vector<1xf32> + %350 = vector.extract_strided_slice %230 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %351 = arith.addi %294, %105 overflow : index + vector.store %350, %297[%351] : memref>, vector<1xf32> + %352 = vector.extract_strided_slice %230 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %353 = arith.addi %300, %105 overflow : index + vector.store %352, %297[%353] : memref>, vector<1xf32> + %354 = vector.extract_strided_slice %230 
{offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %355 = arith.addi %304, %105 overflow : index + vector.store %354, %297[%355] : memref>, vector<1xf32> + %356 = vector.extract_strided_slice %230 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %357 = arith.addi %308, %105 overflow : index + vector.store %356, %297[%357] : memref>, vector<1xf32> + %358 = vector.extract_strided_slice %234 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %359 = arith.addi %294, %108 overflow : index + vector.store %358, %297[%359] : memref>, vector<1xf32> + %360 = vector.extract_strided_slice %234 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %361 = arith.addi %300, %108 overflow : index + vector.store %360, %297[%361] : memref>, vector<1xf32> + %362 = vector.extract_strided_slice %234 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %363 = arith.addi %304, %108 overflow : index + vector.store %362, %297[%363] : memref>, vector<1xf32> + %364 = vector.extract_strided_slice %234 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %365 = arith.addi %308, %108 overflow : index + vector.store %364, %297[%365] : memref>, vector<1xf32> + %366 = vector.extract_strided_slice %238 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %367 = affine.apply #map40()[%thread_id_x] + %368 = arith.muli %367, %c16384 overflow : index + %369 = arith.addi %368, %84 overflow : index + vector.store %366, %297[%369] : memref>, vector<1xf32> + %370 = vector.extract_strided_slice %238 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %371 = affine.apply #map41()[%thread_id_x] + %372 = arith.muli %371, %c16384 overflow : index + %373 = arith.addi %372, %84 overflow : index + vector.store %370, %297[%373] : memref>, vector<1xf32> + %374 = vector.extract_strided_slice %238 {offsets 
= [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %375 = affine.apply #map42()[%thread_id_x] + %376 = arith.muli %375, %c16384 overflow : index + %377 = arith.addi %376, %84 overflow : index + vector.store %374, %297[%377] : memref>, vector<1xf32> + %378 = vector.extract_strided_slice %238 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %379 = affine.apply #map43()[%thread_id_x] + %380 = arith.muli %379, %c16384 overflow : index + %381 = arith.addi %380, %84 overflow : index + vector.store %378, %297[%381] : memref>, vector<1xf32> + %382 = vector.extract_strided_slice %240 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %383 = arith.addi %368, %90 overflow : index + vector.store %382, %297[%383] : memref>, vector<1xf32> + %384 = vector.extract_strided_slice %240 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %385 = arith.addi %372, %90 overflow : index + vector.store %384, %297[%385] : memref>, vector<1xf32> + %386 = vector.extract_strided_slice %240 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %387 = arith.addi %376, %90 overflow : index + vector.store %386, %297[%387] : memref>, vector<1xf32> + %388 = vector.extract_strided_slice %240 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %389 = arith.addi %380, %90 overflow : index + vector.store %388, %297[%389] : memref>, vector<1xf32> + %390 = vector.extract_strided_slice %242 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %391 = arith.addi %368, %93 overflow : index + vector.store %390, %297[%391] : memref>, vector<1xf32> + %392 = vector.extract_strided_slice %242 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %393 = arith.addi %372, %93 overflow : index + vector.store %392, %297[%393] : memref>, vector<1xf32> + %394 = vector.extract_strided_slice %242 {offsets = [2], sizes = 
[1], strides = [1]} : vector<4xf32> to vector<1xf32> + %395 = arith.addi %376, %93 overflow : index + vector.store %394, %297[%395] : memref>, vector<1xf32> + %396 = vector.extract_strided_slice %242 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %397 = arith.addi %380, %93 overflow : index + vector.store %396, %297[%397] : memref>, vector<1xf32> + %398 = vector.extract_strided_slice %244 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %399 = arith.addi %368, %96 overflow : index + vector.store %398, %297[%399] : memref>, vector<1xf32> + %400 = vector.extract_strided_slice %244 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %401 = arith.addi %372, %96 overflow : index + vector.store %400, %297[%401] : memref>, vector<1xf32> + %402 = vector.extract_strided_slice %244 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %403 = arith.addi %376, %96 overflow : index + vector.store %402, %297[%403] : memref>, vector<1xf32> + %404 = vector.extract_strided_slice %244 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %405 = arith.addi %380, %96 overflow : index + vector.store %404, %297[%405] : memref>, vector<1xf32> + %406 = vector.extract_strided_slice %246 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %407 = arith.addi %368, %99 overflow : index + vector.store %406, %297[%407] : memref>, vector<1xf32> + %408 = vector.extract_strided_slice %246 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %409 = arith.addi %372, %99 overflow : index + vector.store %408, %297[%409] : memref>, vector<1xf32> + %410 = vector.extract_strided_slice %246 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %411 = arith.addi %376, %99 overflow : index + vector.store %410, %297[%411] : memref>, vector<1xf32> + %412 = vector.extract_strided_slice %246 
{offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %413 = arith.addi %380, %99 overflow : index + vector.store %412, %297[%413] : memref>, vector<1xf32> + %414 = vector.extract_strided_slice %248 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %415 = arith.addi %368, %102 overflow : index + vector.store %414, %297[%415] : memref>, vector<1xf32> + %416 = vector.extract_strided_slice %248 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %417 = arith.addi %372, %102 overflow : index + vector.store %416, %297[%417] : memref>, vector<1xf32> + %418 = vector.extract_strided_slice %248 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %419 = arith.addi %376, %102 overflow : index + vector.store %418, %297[%419] : memref>, vector<1xf32> + %420 = vector.extract_strided_slice %248 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %421 = arith.addi %380, %102 overflow : index + vector.store %420, %297[%421] : memref>, vector<1xf32> + %422 = vector.extract_strided_slice %250 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %423 = arith.addi %368, %105 overflow : index + vector.store %422, %297[%423] : memref>, vector<1xf32> + %424 = vector.extract_strided_slice %250 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %425 = arith.addi %372, %105 overflow : index + vector.store %424, %297[%425] : memref>, vector<1xf32> + %426 = vector.extract_strided_slice %250 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %427 = arith.addi %376, %105 overflow : index + vector.store %426, %297[%427] : memref>, vector<1xf32> + %428 = vector.extract_strided_slice %250 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %429 = arith.addi %380, %105 overflow : index + vector.store %428, %297[%429] : memref>, vector<1xf32> + %430 = 
vector.extract_strided_slice %252 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %431 = arith.addi %368, %108 overflow : index + vector.store %430, %297[%431] : memref>, vector<1xf32> + %432 = vector.extract_strided_slice %252 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %433 = arith.addi %372, %108 overflow : index + vector.store %432, %297[%433] : memref>, vector<1xf32> + %434 = vector.extract_strided_slice %252 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %435 = arith.addi %376, %108 overflow : index + vector.store %434, %297[%435] : memref>, vector<1xf32> + %436 = vector.extract_strided_slice %252 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %437 = arith.addi %380, %108 overflow : index + vector.store %436, %297[%437] : memref>, vector<1xf32> + %438 = vector.extract_strided_slice %256 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %439 = affine.apply #map44()[%thread_id_x] + %440 = arith.muli %439, %c16384 overflow : index + %441 = arith.addi %440, %84 overflow : index + vector.store %438, %297[%441] : memref>, vector<1xf32> + %442 = vector.extract_strided_slice %256 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %443 = affine.apply #map45()[%thread_id_x] + %444 = arith.muli %443, %c16384 overflow : index + %445 = arith.addi %444, %84 overflow : index + vector.store %442, %297[%445] : memref>, vector<1xf32> + %446 = vector.extract_strided_slice %256 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %447 = affine.apply #map46()[%thread_id_x] + %448 = arith.muli %447, %c16384 overflow : index + %449 = arith.addi %448, %84 overflow : index + vector.store %446, %297[%449] : memref>, vector<1xf32> + %450 = vector.extract_strided_slice %256 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %451 = affine.apply 
#map47()[%thread_id_x] + %452 = arith.muli %451, %c16384 overflow : index + %453 = arith.addi %452, %84 overflow : index + vector.store %450, %297[%453] : memref>, vector<1xf32> + %454 = vector.extract_strided_slice %258 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %455 = arith.addi %440, %90 overflow : index + vector.store %454, %297[%455] : memref>, vector<1xf32> + %456 = vector.extract_strided_slice %258 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %457 = arith.addi %444, %90 overflow : index + vector.store %456, %297[%457] : memref>, vector<1xf32> + %458 = vector.extract_strided_slice %258 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %459 = arith.addi %448, %90 overflow : index + vector.store %458, %297[%459] : memref>, vector<1xf32> + %460 = vector.extract_strided_slice %258 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %461 = arith.addi %452, %90 overflow : index + vector.store %460, %297[%461] : memref>, vector<1xf32> + %462 = vector.extract_strided_slice %260 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %463 = arith.addi %440, %93 overflow : index + vector.store %462, %297[%463] : memref>, vector<1xf32> + %464 = vector.extract_strided_slice %260 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %465 = arith.addi %444, %93 overflow : index + vector.store %464, %297[%465] : memref>, vector<1xf32> + %466 = vector.extract_strided_slice %260 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %467 = arith.addi %448, %93 overflow : index + vector.store %466, %297[%467] : memref>, vector<1xf32> + %468 = vector.extract_strided_slice %260 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %469 = arith.addi %452, %93 overflow : index + vector.store %468, %297[%469] : memref>, vector<1xf32> + %470 = 
vector.extract_strided_slice %262 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %471 = arith.addi %440, %96 overflow : index + vector.store %470, %297[%471] : memref>, vector<1xf32> + %472 = vector.extract_strided_slice %262 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %473 = arith.addi %444, %96 overflow : index + vector.store %472, %297[%473] : memref>, vector<1xf32> + %474 = vector.extract_strided_slice %262 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %475 = arith.addi %448, %96 overflow : index + vector.store %474, %297[%475] : memref>, vector<1xf32> + %476 = vector.extract_strided_slice %262 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %477 = arith.addi %452, %96 overflow : index + vector.store %476, %297[%477] : memref>, vector<1xf32> + %478 = vector.extract_strided_slice %264 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %479 = arith.addi %440, %99 overflow : index + vector.store %478, %297[%479] : memref>, vector<1xf32> + %480 = vector.extract_strided_slice %264 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %481 = arith.addi %444, %99 overflow : index + vector.store %480, %297[%481] : memref>, vector<1xf32> + %482 = vector.extract_strided_slice %264 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %483 = arith.addi %448, %99 overflow : index + vector.store %482, %297[%483] : memref>, vector<1xf32> + %484 = vector.extract_strided_slice %264 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %485 = arith.addi %452, %99 overflow : index + vector.store %484, %297[%485] : memref>, vector<1xf32> + %486 = vector.extract_strided_slice %266 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %487 = arith.addi %440, %102 overflow : index + vector.store %486, %297[%487] : memref>, 
vector<1xf32> + %488 = vector.extract_strided_slice %266 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %489 = arith.addi %444, %102 overflow : index + vector.store %488, %297[%489] : memref>, vector<1xf32> + %490 = vector.extract_strided_slice %266 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %491 = arith.addi %448, %102 overflow : index + vector.store %490, %297[%491] : memref>, vector<1xf32> + %492 = vector.extract_strided_slice %266 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %493 = arith.addi %452, %102 overflow : index + vector.store %492, %297[%493] : memref>, vector<1xf32> + %494 = vector.extract_strided_slice %268 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %495 = arith.addi %440, %105 overflow : index + vector.store %494, %297[%495] : memref>, vector<1xf32> + %496 = vector.extract_strided_slice %268 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %497 = arith.addi %444, %105 overflow : index + vector.store %496, %297[%497] : memref>, vector<1xf32> + %498 = vector.extract_strided_slice %268 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %499 = arith.addi %448, %105 overflow : index + vector.store %498, %297[%499] : memref>, vector<1xf32> + %500 = vector.extract_strided_slice %268 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %501 = arith.addi %452, %105 overflow : index + vector.store %500, %297[%501] : memref>, vector<1xf32> + %502 = vector.extract_strided_slice %270 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %503 = arith.addi %440, %108 overflow : index + vector.store %502, %297[%503] : memref>, vector<1xf32> + %504 = vector.extract_strided_slice %270 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %505 = arith.addi %444, %108 overflow : index + vector.store 
%504, %297[%505] : memref>, vector<1xf32> + %506 = vector.extract_strided_slice %270 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %507 = arith.addi %448, %108 overflow : index + vector.store %506, %297[%507] : memref>, vector<1xf32> + %508 = vector.extract_strided_slice %270 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %509 = arith.addi %452, %108 overflow : index + vector.store %508, %297[%509] : memref>, vector<1xf32> + %510 = vector.extract_strided_slice %274 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %511 = affine.apply #map48()[%thread_id_x] + %512 = arith.muli %511, %c16384 overflow : index + %513 = arith.addi %512, %84 overflow : index + vector.store %510, %297[%513] : memref>, vector<1xf32> + %514 = vector.extract_strided_slice %274 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %515 = affine.apply #map49()[%thread_id_x] + %516 = arith.muli %515, %c16384 overflow : index + %517 = arith.addi %516, %84 overflow : index + vector.store %514, %297[%517] : memref>, vector<1xf32> + %518 = vector.extract_strided_slice %274 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %519 = affine.apply #map50()[%thread_id_x] + %520 = arith.muli %519, %c16384 overflow : index + %521 = arith.addi %520, %84 overflow : index + vector.store %518, %297[%521] : memref>, vector<1xf32> + %522 = vector.extract_strided_slice %274 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %523 = affine.apply #map51()[%thread_id_x] + %524 = arith.muli %523, %c16384 overflow : index + %525 = arith.addi %524, %84 overflow : index + vector.store %522, %297[%525] : memref>, vector<1xf32> + %526 = vector.extract_strided_slice %276 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %527 = arith.addi %512, %90 overflow : index + vector.store %526, %297[%527] : memref>, vector<1xf32> + 
%528 = vector.extract_strided_slice %276 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %529 = arith.addi %516, %90 overflow : index + vector.store %528, %297[%529] : memref>, vector<1xf32> + %530 = vector.extract_strided_slice %276 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %531 = arith.addi %520, %90 overflow : index + vector.store %530, %297[%531] : memref>, vector<1xf32> + %532 = vector.extract_strided_slice %276 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %533 = arith.addi %524, %90 overflow : index + vector.store %532, %297[%533] : memref>, vector<1xf32> + %534 = vector.extract_strided_slice %278 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %535 = arith.addi %512, %93 overflow : index + vector.store %534, %297[%535] : memref>, vector<1xf32> + %536 = vector.extract_strided_slice %278 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %537 = arith.addi %516, %93 overflow : index + vector.store %536, %297[%537] : memref>, vector<1xf32> + %538 = vector.extract_strided_slice %278 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %539 = arith.addi %520, %93 overflow : index + vector.store %538, %297[%539] : memref>, vector<1xf32> + %540 = vector.extract_strided_slice %278 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %541 = arith.addi %524, %93 overflow : index + vector.store %540, %297[%541] : memref>, vector<1xf32> + %542 = vector.extract_strided_slice %280 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %543 = arith.addi %512, %96 overflow : index + vector.store %542, %297[%543] : memref>, vector<1xf32> + %544 = vector.extract_strided_slice %280 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %545 = arith.addi %516, %96 overflow : index + vector.store %544, %297[%545] : memref>, 
vector<1xf32> + %546 = vector.extract_strided_slice %280 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %547 = arith.addi %520, %96 overflow : index + vector.store %546, %297[%547] : memref>, vector<1xf32> + %548 = vector.extract_strided_slice %280 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %549 = arith.addi %524, %96 overflow : index + vector.store %548, %297[%549] : memref>, vector<1xf32> + %550 = vector.extract_strided_slice %282 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %551 = arith.addi %512, %99 overflow : index + vector.store %550, %297[%551] : memref>, vector<1xf32> + %552 = vector.extract_strided_slice %282 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %553 = arith.addi %516, %99 overflow : index + vector.store %552, %297[%553] : memref>, vector<1xf32> + %554 = vector.extract_strided_slice %282 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %555 = arith.addi %520, %99 overflow : index + vector.store %554, %297[%555] : memref>, vector<1xf32> + %556 = vector.extract_strided_slice %282 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %557 = arith.addi %524, %99 overflow : index + vector.store %556, %297[%557] : memref>, vector<1xf32> + %558 = vector.extract_strided_slice %284 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %559 = arith.addi %512, %102 overflow : index + vector.store %558, %297[%559] : memref>, vector<1xf32> + %560 = vector.extract_strided_slice %284 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %561 = arith.addi %516, %102 overflow : index + vector.store %560, %297[%561] : memref>, vector<1xf32> + %562 = vector.extract_strided_slice %284 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %563 = arith.addi %520, %102 overflow : index + vector.store %562, 
%297[%563] : memref>, vector<1xf32> + %564 = vector.extract_strided_slice %284 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %565 = arith.addi %524, %102 overflow : index + vector.store %564, %297[%565] : memref>, vector<1xf32> + %566 = vector.extract_strided_slice %286 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %567 = arith.addi %512, %105 overflow : index + vector.store %566, %297[%567] : memref>, vector<1xf32> + %568 = vector.extract_strided_slice %286 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %569 = arith.addi %516, %105 overflow : index + vector.store %568, %297[%569] : memref>, vector<1xf32> + %570 = vector.extract_strided_slice %286 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %571 = arith.addi %520, %105 overflow : index + vector.store %570, %297[%571] : memref>, vector<1xf32> + %572 = vector.extract_strided_slice %286 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %573 = arith.addi %524, %105 overflow : index + vector.store %572, %297[%573] : memref>, vector<1xf32> + %574 = vector.extract_strided_slice %288 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %575 = arith.addi %512, %108 overflow : index + vector.store %574, %297[%575] : memref>, vector<1xf32> + %576 = vector.extract_strided_slice %288 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %577 = arith.addi %516, %108 overflow : index + vector.store %576, %297[%577] : memref>, vector<1xf32> + %578 = vector.extract_strided_slice %288 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %579 = arith.addi %520, %108 overflow : index + vector.store %578, %297[%579] : memref>, vector<1xf32> + %580 = vector.extract_strided_slice %288 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %581 = arith.addi %524, %108 overflow : 
index + vector.store %580, %297[%581] : memref>, vector<1xf32> + return + } + } + } + func.func @isolated_benchmark$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view, %arg4: !hal.buffer_view, %arg5: !hal.fence, %arg6: !hal.fence) -> !hal.buffer_view { + %0 = hal.tensor.import wait(%arg5) => %arg0 : !hal.buffer_view -> tensor<16384x8192xi8> + %1 = hal.tensor.import wait(%arg5) => %arg1 : !hal.buffer_view -> tensor<16384x512xi8> + %2 = hal.tensor.import wait(%arg5) => %arg2 : !hal.buffer_view -> tensor<16384x8192xi8> + %3 = hal.tensor.import wait(%arg5) => %arg3 : !hal.buffer_view -> tensor<16384x512xi8> + %4 = hal.tensor.import wait(%arg5) => %arg4 : !hal.buffer_view -> tensor<16384x16384xf32> + %5 = flow.dispatch @gemm::@gemm(%0, %1, %2, %3, %4) : (tensor<16384x8192xi8>, tensor<16384x512xi8>, tensor<16384x8192xi8>, tensor<16384x512xi8>, tensor<16384x16384xf32>) -> %4 + %6 = hal.tensor.barrier join(%5 : tensor<16384x16384xf32>) => %arg6 : !hal.fence + %7 = hal.tensor.export %6 : tensor<16384x16384xf32> -> !hal.buffer_view + return %7 : !hal.buffer_view + } + } + """ + ##cluster 0 loads most of it + mlir_different_mapping = """ + + #map = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8) floordiv 256) * 256)> + #map1 = affine_map<()[s0] -> ((s0 floordiv 8) mod 8)> + #map2 = affine_map<()[s0] -> (s0 mod 8)> + #map3 = affine_map<()[s0] -> (s0 * 16)> + #map4 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64) floordiv 32) * 256)> + #map5 = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 64) floordiv 256) * 256 + 64)> + #map6 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64 + 8) floordiv 32) * 256 + 64)> + #map7 = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 128) floordiv 256) * 256 + 128)> + #map8 = 
affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64 + 16) floordiv 32) * 256 + 128)> + #map9 = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 192) floordiv 256) * 256 + 192)> + #map10 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64 + 24) floordiv 32) * 256 + 192)> + #map11 = affine_map<()[s0, s1, s2] -> (s1 * 128 + s2 * 256 + s0 floordiv 2 - ((s1 * 128 + s0 floordiv 2) floordiv 256) * 256)> + #map12 = affine_map<()[s0] -> ((s0 floordiv 2) mod 2)> + #map13 = affine_map<()[s0] -> (s0 mod 2)> + #map14 = affine_map<()[s0] -> (s0 * 4)> + #map15 = affine_map<()[s0, s1] -> (s1 * 128 + (s0 floordiv 64) * 32 - ((s1 * 4 + s0 floordiv 64) floordiv 8) * 256)> + #map16 = affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 64)> + #map17 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64)> + #map18 = affine_map<()[s0] -> ((s0 mod 64) floordiv 16)> + #map19 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64 + 16)> + #map20 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64 + 32)> + #map21 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64 + 48)> + #map22 = affine_map<()[s0] -> (s0 * 4 + (s0 mod 64) floordiv 16 - (s0 floordiv 2) * 8)> + #map23 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16)> + #map24 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 16)> + #map25 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 32)> + #map26 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 48)> + #map27 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 64)> + #map28 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 80)> + #map29 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 96)> + #map30 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 112)> + #map31 = affine_map<()[s0] -> ((s0 mod 64) floordiv 16 + 4)> + 
#map32 = affine_map<()[s0, s1] -> (s1 * 4 + (s0 mod 64) floordiv 16)> + #map33 = affine_map<()[s0, s1] -> (s0 * 128 + s1 * 16 + 128)> + #map34 = affine_map<()[s0, s1] -> (s0 * 8 + s1 * 4 + 8)> + #map35 = affine_map<()[s0] -> (s0 * 256)> + #map36 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4)> + #map37 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 1)> + #map38 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 2)> + #map39 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 3)> + #map40 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 16)> + #map41 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 17)> + #map42 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 18)> + #map43 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 19)> + #map44 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 32)> + #map45 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 33)> + #map46 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 34)> + #map47 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 35)> + #map48 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 48)> + #map49 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 49)> + #map50 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 50)> + #map51 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 51)> + #translation = #iree_codegen.translation_info + module attributes {transform.with_named_sequence} { + stream.executable private @gemm { + stream.executable.export public @gemm workgroups() -> (index, index, index) { + %c16 = arith.constant 16 : index + %c224 = 
arith.constant 224 : index + %c1 = arith.constant 1 : index + stream.return %c16, %c224, %c1 : index, index, index + } + builtin.module { + func.func @gemm(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: !stream.binding, %arg3: !stream.binding, %arg4: !stream.binding) attributes {translation_info = #translation} { + %c4_i32 = arith.constant 4 : i32 + %c512_i14 = arith.constant 512 : i14 + %c-8192_i14 = arith.constant -8192 : i14 + %c2147483643_i64 = arith.constant 2147483643 : i64 + %c57344 = arith.constant 57344 : index + %c63 = arith.constant 63 : index + %c512 = arith.constant 512 : index + %c2147483646_i64 = arith.constant 2147483646 : i64 + %c8192 = arith.constant 8192 : index + %c1 = arith.constant 1 : index + %cst = arith.constant dense<0.000000e+00> : vector<4xf32> + %c0 = arith.constant 0 : index + %0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> memref + %1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> memref + %2 = stream.binding.subspan %arg2[%c0] : !stream.binding -> memref + %3 = stream.binding.subspan %arg3[%c0] : !stream.binding -> memref + %4 = stream.binding.subspan %arg4[%c0] : !stream.binding -> memref + %block_id_x = gpu.block_id x upper_bound 16 + %block_id_y = gpu.block_id y upper_bound 224 + %thread_id_x = gpu.thread_id x upper_bound 256 + %thread_id_y = gpu.thread_id y upper_bound 2 + %alloc = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_0 = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_1 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %alloc_2 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %alloc_3 = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_4 = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_5 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %alloc_6 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %c32_idx = arith.constant 32 : index + %c128_idx = arith.constant 128 : index + %c262144 = 
arith.constant 262144 : index + %c65536 = arith.constant 65536 : index + %is_cluster0 = arith.cmpi eq, %thread_id_y, %c0 : index + %5 = affine.apply #map()[%thread_id_x, %thread_id_y, %block_id_x] + %6 = affine.apply #map1()[%thread_id_x] + %7 = affine.apply #map2()[%thread_id_x] + %8 = arith.xori %7, %6 : index + %9 = affine.apply #map3()[%8] + %10 = affine.apply #map4()[%thread_id_x, %thread_id_y] + %11 = gpu.subgroup_broadcast %10, first_active_lane : index + %12 = gpu.subgroup_broadcast %c0, first_active_lane : index + %13 = arith.muli %5, %c8192 overflow : index + %14 = arith.addi %13, %9 overflow : index + %reinterpret_cast = memref.reinterpret_cast %0 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast = memref.cast %reinterpret_cast : memref<2147483646xi8, strided<[1]>> to memref> + %15 = amdgpu.fat_raw_buffer_cast %cast validBytes(%c2147483646_i64) cacheSwizzleStride(%c-8192_i14) resetOffset : memref> to memref> + // --- Address computations (all waves) --- + %16 = affine.apply #map5()[%thread_id_x, %thread_id_y, %block_id_x] + %17 = affine.apply #map6()[%thread_id_x, %thread_id_y] + %18 = gpu.subgroup_broadcast %17, first_active_lane : index + %19 = arith.muli %16, %c8192 overflow : index + %20 = arith.addi %19, %9 overflow : index + %21 = affine.apply #map7()[%thread_id_x, %thread_id_y, %block_id_x] + %22 = affine.apply #map8()[%thread_id_x, %thread_id_y] + %23 = gpu.subgroup_broadcast %22, first_active_lane : index + %24 = arith.muli %21, %c8192 overflow : index + %25 = arith.addi %24, %9 overflow : index + %26 = affine.apply #map9()[%thread_id_x, %thread_id_y, %block_id_x] + %27 = affine.apply #map10()[%thread_id_x, %thread_id_y] + %28 = gpu.subgroup_broadcast %27, first_active_lane : index + %29 = arith.muli %26, %c8192 overflow : index + %30 = arith.addi %29, %9 overflow : index + %31 = affine.apply #map11()[%thread_id_x, %thread_id_y, %block_id_x] + %32 = affine.apply #map12()[%thread_id_x] + %33 
= affine.apply #map13()[%thread_id_x] + %34 = arith.xori %33, %32 : index + %35 = affine.apply #map14()[%34] + %36 = affine.apply #map15()[%thread_id_x, %thread_id_y] + %37 = gpu.subgroup_broadcast %36, first_active_lane : index + %38 = arith.muli %31, %c512 overflow : index + %39 = arith.addi %38, %35 overflow : index + %reinterpret_cast_7 = memref.reinterpret_cast %1 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast_8 = memref.cast %reinterpret_cast_7 : memref<2147483646xi8, strided<[1]>> to memref> + %40 = amdgpu.fat_raw_buffer_cast %cast_8 validBytes(%c2147483646_i64) cacheSwizzleStride(%c512_i14) resetOffset : memref> to memref> + %41 = affine.apply #map()[%thread_id_x, %thread_id_y, %block_id_y] + %42 = arith.muli %41, %c8192 overflow : index + %43 = arith.addi %42, %9 overflow : index + %reinterpret_cast_9 = memref.reinterpret_cast %2 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast_10 = memref.cast %reinterpret_cast_9 : memref<2147483646xi8, strided<[1]>> to memref> + %44 = amdgpu.fat_raw_buffer_cast %cast_10 validBytes(%c2147483646_i64) cacheSwizzleStride(%c-8192_i14) resetOffset : memref> to memref> + %45 = affine.apply #map5()[%thread_id_x, %thread_id_y, %block_id_y] + %46 = arith.muli %45, %c8192 overflow : index + %47 = arith.addi %46, %9 overflow : index + %48 = affine.apply #map7()[%thread_id_x, %thread_id_y, %block_id_y] + %49 = arith.muli %48, %c8192 overflow : index + %50 = arith.addi %49, %9 overflow : index + %51 = affine.apply #map9()[%thread_id_x, %thread_id_y, %block_id_y] + %52 = arith.muli %51, %c8192 overflow : index + %53 = arith.addi %52, %9 overflow : index + %54 = affine.apply #map11()[%thread_id_x, %thread_id_y, %block_id_y] + %55 = arith.muli %54, %c512 overflow : index + %56 = arith.addi %55, %35 overflow : index + %reinterpret_cast_11 = memref.reinterpret_cast %3 to offset: [0], sizes: [2147483646], strides: [1] : memref 
to memref<2147483646xi8, strided<[1]>> + %cast_12 = memref.cast %reinterpret_cast_11 : memref<2147483646xi8, strided<[1]>> to memref> + %57 = amdgpu.fat_raw_buffer_cast %cast_12 validBytes(%c2147483646_i64) cacheSwizzleStride(%c512_i14) resetOffset : memref> to memref> + // --- Cluster 0 only: A data (8), A scale (2), B data (8) gathers --- + scf.if %is_cluster0 { + // A data: 4 original gathers (ty=0 addresses) + amdgpu.gather_to_lds %15[%14], %alloc_6[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %15[%20], %alloc_6[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %15[%25], %alloc_6[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %15[%30], %alloc_6[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + // A data: 4 extra gathers (ty=1 addresses: global +262144, LDS row +32) + %ea_g0 = arith.addi %14, %c262144 overflow : index + %ea_l0 = arith.addi %11, %c32_idx overflow : index + amdgpu.gather_to_lds %15[%ea_g0], %alloc_6[%ea_l0, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %ea_g1 = arith.addi %20, %c262144 overflow : index + %ea_l1 = arith.addi %18, %c32_idx overflow : index + amdgpu.gather_to_lds %15[%ea_g1], %alloc_6[%ea_l1, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %ea_g2 = arith.addi %25, %c262144 overflow : index + %ea_l2 = arith.addi %23, %c32_idx overflow : index + amdgpu.gather_to_lds %15[%ea_g2], %alloc_6[%ea_l2, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %ea_g3 = arith.addi %30, %c262144 overflow : index + %ea_l3 = arith.addi %28, %c32_idx overflow : index + amdgpu.gather_to_lds %15[%ea_g3], %alloc_6[%ea_l3, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + // A scale: 1 original gather (ty=0) + amdgpu.gather_to_lds %40[%39], %alloc_4[%37, %12] : vector<4xi8>, memref>, 
memref<256x8xi8, #gpu.address_space> + // A scale: 1 extra gather (ty=1: global +65536, LDS row +128) + %eas_g0 = arith.addi %39, %c65536 overflow : index + %eas_l0 = arith.addi %37, %c128_idx overflow : index + amdgpu.gather_to_lds %40[%eas_g0], %alloc_4[%eas_l0, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + // B data: 4 original gathers (ty=0 addresses) + amdgpu.gather_to_lds %44[%43], %alloc_2[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %44[%47], %alloc_2[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %44[%50], %alloc_2[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %44[%53], %alloc_2[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + // B data: 4 extra gathers (ty=1: global +262144, LDS row +32) + %eb_g0 = arith.addi %43, %c262144 overflow : index + amdgpu.gather_to_lds %44[%eb_g0], %alloc_2[%ea_l0, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %eb_g1 = arith.addi %47, %c262144 overflow : index + amdgpu.gather_to_lds %44[%eb_g1], %alloc_2[%ea_l1, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %eb_g2 = arith.addi %50, %c262144 overflow : index + amdgpu.gather_to_lds %44[%eb_g2], %alloc_2[%ea_l2, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %eb_g3 = arith.addi %53, %c262144 overflow : index + amdgpu.gather_to_lds %44[%eb_g3], %alloc_2[%ea_l3, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + } + // B scale: unchanged (both clusters, already cluster-aligned) + amdgpu.gather_to_lds %57[%56], %alloc_0[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + rocdl.s.barrier + %58 = affine.apply #map16()[%thread_id_x, %thread_id_y] + %59 = arith.index_cast %58 : index to i32 + %60 = arith.cmpi sge, %59, %c4_i32 : i32 + %61 = arith.cmpi slt, %59, 
%c4_i32 : i32 + scf.if %60 { + rocdl.s.barrier + } + %62 = affine.apply #map17()[%thread_id_x] + %63 = affine.apply #map18()[%thread_id_x] + %64 = arith.xori %63, %7 : index + %65 = affine.apply #map3()[%64] + %66 = affine.apply #map19()[%thread_id_x] + %67 = affine.apply #map20()[%thread_id_x] + %68 = affine.apply #map21()[%thread_id_x] + %69 = affine.apply #map22()[%thread_id_x] + %70 = affine.apply #map23()[%thread_id_x, %thread_id_y] + %71 = affine.apply #map24()[%thread_id_x, %thread_id_y] + %72 = affine.apply #map25()[%thread_id_x, %thread_id_y] + %73 = affine.apply #map26()[%thread_id_x, %thread_id_y] + %74 = affine.apply #map27()[%thread_id_x, %thread_id_y] + %75 = affine.apply #map28()[%thread_id_x, %thread_id_y] + %76 = affine.apply #map29()[%thread_id_x, %thread_id_y] + %77 = affine.apply #map30()[%thread_id_x, %thread_id_y] + %78 = affine.apply #map31()[%thread_id_x] + %79 = arith.xori %78, %7 : index + %80 = affine.apply #map3()[%79] + %81 = arith.xori %33, %c1 : index + %82 = affine.apply #map32()[%thread_id_x, %81] + %83:40 = scf.for %arg5 = %c0 to %c63 step %c1 iter_args(%arg6 = %cst, %arg7 = %cst, %arg8 = %cst, %arg9 = %cst, %arg10 = %cst, %arg11 = %cst, %arg12 = %cst, %arg13 = %cst, %arg14 = %cst, %arg15 = %cst, %arg16 = %cst, %arg17 = %cst, %arg18 = %cst, %arg19 = %cst, %arg20 = %cst, %arg21 = %cst, %arg22 = %cst, %arg23 = %cst, %arg24 = %cst, %arg25 = %cst, %arg26 = %cst, %arg27 = %cst, %arg28 = %cst, %arg29 = %cst, %arg30 = %cst, %arg31 = %cst, %arg32 = %cst, %arg33 = %cst, %arg34 = %cst, %arg35 = %cst, %arg36 = %cst, %arg37 = %cst, %arg38 = %alloc_6, %arg39 = %alloc_5, %arg40 = %alloc_4, %arg41 = %alloc_3, %arg42 = %alloc_2, %arg43 = %alloc_1, %arg44 = %alloc_0, %arg45 = %alloc) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, 
vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>) { + rocdl.sched.barrier 0 + rocdl.s.barrier + // --- Address computations (all waves) --- + %582 = affine.apply #map33()[%arg5, %8] + %583 = arith.addi %13, %582 overflow : index + %584 = arith.addi %19, %582 overflow : index + %585 = arith.addi %24, %582 overflow : index + %586 = arith.addi %29, %582 overflow : index + %587 = affine.apply #map34()[%arg5, %34] + %588 = arith.addi %38, %587 overflow : index + %589 = arith.addi %42, %582 overflow : index + %590 = arith.addi %46, %582 overflow : index + %591 = arith.addi %49, %582 overflow : index + %592 = arith.addi %52, %582 overflow : index + %593 = arith.addi %55, %587 overflow : index + // --- Cluster 0 only: A data (8), A scale (2), B data (8) gathers --- + scf.if %is_cluster0 { + // A data: 4 original gathers (ty=0 addresses) + amdgpu.gather_to_lds %15[%583], %arg39[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %15[%584], %arg39[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %15[%585], %arg39[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %15[%586], %arg39[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + // A data: 4 extra gathers (ty=1: global +262144, LDS row +32) + %lea_g0 = arith.addi %583, %c262144 overflow : index + %lea_l0 = arith.addi %11, %c32_idx overflow : index + amdgpu.gather_to_lds 
%15[%lea_g0], %arg39[%lea_l0, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %lea_g1 = arith.addi %584, %c262144 overflow : index + %lea_l1 = arith.addi %18, %c32_idx overflow : index + amdgpu.gather_to_lds %15[%lea_g1], %arg39[%lea_l1, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %lea_g2 = arith.addi %585, %c262144 overflow : index + %lea_l2 = arith.addi %23, %c32_idx overflow : index + amdgpu.gather_to_lds %15[%lea_g2], %arg39[%lea_l2, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %lea_g3 = arith.addi %586, %c262144 overflow : index + %lea_l3 = arith.addi %28, %c32_idx overflow : index + amdgpu.gather_to_lds %15[%lea_g3], %arg39[%lea_l3, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + // A scale: 1 original gather (ty=0) + amdgpu.gather_to_lds %40[%588], %arg41[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + // A scale: 1 extra gather (ty=1: global +65536, LDS row +128) + %leas_g0 = arith.addi %588, %c65536 overflow : index + %leas_l0 = arith.addi %37, %c128_idx overflow : index + amdgpu.gather_to_lds %40[%leas_g0], %arg41[%leas_l0, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + // B data: 4 original gathers (ty=0 addresses) + amdgpu.gather_to_lds %44[%589], %arg43[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %44[%590], %arg43[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %44[%591], %arg43[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %44[%592], %arg43[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + // B data: 4 extra gathers (ty=1: global +262144, LDS row +32) + %leb_g0 = arith.addi %589, %c262144 overflow : index + amdgpu.gather_to_lds %44[%leb_g0], %arg43[%lea_l0, %12] : vector<16xi8>, memref>, memref<256x128xi8, 
#gpu.address_space> + %leb_g1 = arith.addi %590, %c262144 overflow : index + amdgpu.gather_to_lds %44[%leb_g1], %arg43[%lea_l1, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %leb_g2 = arith.addi %591, %c262144 overflow : index + amdgpu.gather_to_lds %44[%leb_g2], %arg43[%lea_l2, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %leb_g3 = arith.addi %592, %c262144 overflow : index + amdgpu.gather_to_lds %44[%leb_g3], %arg43[%lea_l3, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + } + // B scale: unchanged (both clusters) + amdgpu.gather_to_lds %57[%593], %arg45[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + rocdl.sched.barrier 0 + amdgpu.memory_counter_wait load(10) + %594 = vector.load %arg38[%62, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %595 = vector.load %arg38[%66, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %596 = vector.load %arg38[%67, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %597 = vector.load %arg38[%68, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %598 = vector.load %arg40[%62, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %599 = vector.load %arg40[%66, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %600 = vector.load %arg40[%67, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %601 = vector.load %arg40[%68, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %602 = vector.load %arg42[%70, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %603 = vector.load %arg42[%71, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %604 = vector.load %arg42[%72, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %605 = vector.load %arg42[%73, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %606 = vector.load %arg42[%74, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %607 = 
vector.load %arg42[%75, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %608 = vector.load %arg42[%76, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %609 = vector.load %arg42[%77, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %610 = vector.load %arg44[%70, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %611 = vector.load %arg44[%71, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %612 = vector.load %arg44[%72, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %613 = vector.load %arg44[%73, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %614 = vector.load %arg44[%74, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %615 = vector.load %arg44[%75, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %616 = vector.load %arg44[%76, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %617 = vector.load %arg44[%77, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %618 = vector.bitcast %594 : vector<16xi8> to vector<32xf4E2M1FN> + %619 = vector.bitcast %595 : vector<16xi8> to vector<32xf4E2M1FN> + %620 = vector.bitcast %596 : vector<16xi8> to vector<32xf4E2M1FN> + %621 = vector.bitcast %597 : vector<16xi8> to vector<32xf4E2M1FN> + %622 = vector.bitcast %598 : vector<1xi8> to vector<1xf8E8M0FNU> + %623 = vector.bitcast %599 : vector<1xi8> to vector<1xf8E8M0FNU> + %624 = vector.bitcast %600 : vector<1xi8> to vector<1xf8E8M0FNU> + %625 = vector.bitcast %601 : vector<1xi8> to vector<1xf8E8M0FNU> + %626 = vector.bitcast %602 : vector<16xi8> to vector<32xf4E2M1FN> + %627 = vector.bitcast %603 : vector<16xi8> to vector<32xf4E2M1FN> + %628 = vector.bitcast %604 : vector<16xi8> to vector<32xf4E2M1FN> + %629 = vector.bitcast %605 : vector<16xi8> to vector<32xf4E2M1FN> + %630 = vector.bitcast %606 : vector<16xi8> to vector<32xf4E2M1FN> + %631 = vector.bitcast %607 : vector<16xi8> to vector<32xf4E2M1FN> + %632 = vector.bitcast %608 : vector<16xi8> to 
vector<32xf4E2M1FN> + %633 = vector.bitcast %609 : vector<16xi8> to vector<32xf4E2M1FN> + %634 = vector.bitcast %610 : vector<1xi8> to vector<1xf8E8M0FNU> + %635 = vector.bitcast %611 : vector<1xi8> to vector<1xf8E8M0FNU> + %636 = vector.bitcast %612 : vector<1xi8> to vector<1xf8E8M0FNU> + %637 = vector.bitcast %613 : vector<1xi8> to vector<1xf8E8M0FNU> + %638 = vector.bitcast %614 : vector<1xi8> to vector<1xf8E8M0FNU> + %639 = vector.bitcast %615 : vector<1xi8> to vector<1xf8E8M0FNU> + %640 = vector.bitcast %616 : vector<1xi8> to vector<1xf8E8M0FNU> + %641 = vector.bitcast %617 : vector<1xi8> to vector<1xf8E8M0FNU> + rocdl.sched.barrier 0 + rocdl.s.barrier + rocdl.sched.barrier 0 + rocdl.s.setprio 1 + %642 = vector.extract %622[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %643 = vector.extract %634[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %644 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%643[0] * %626) + %arg6 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %645 = vector.extract %635[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %646 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%645[0] * %627) + %arg7 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %647 = vector.extract %636[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %648 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%647[0] * %628) + %arg8 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %649 = vector.extract %637[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %650 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%649[0] * %629) + %arg9 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %651 = vector.extract %638[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %652 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%651[0] * %630) + %arg10 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %653 = vector.extract %639[0] : f8E8M0FNU from 
vector<1xf8E8M0FNU> + %654 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%653[0] * %631) + %arg11 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %655 = vector.extract %640[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %656 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%655[0] * %632) + %arg12 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %657 = vector.extract %641[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %658 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%657[0] * %633) + %arg13 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %659 = vector.extract %623[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %660 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%643[0] * %626) + %arg14 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %661 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%645[0] * %627) + %arg15 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %662 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%647[0] * %628) + %arg16 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %663 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%649[0] * %629) + %arg17 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %664 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%651[0] * %630) + %arg18 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %665 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%653[0] * %631) + %arg19 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %666 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%655[0] * %632) + %arg20 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %667 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%657[0] * %633) + %arg21 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, 
vector<32xf4E2M1FN>, vector<4xf32> + %668 = vector.extract %624[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %669 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%643[0] * %626) + %arg22 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %670 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%645[0] * %627) + %arg23 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %671 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%647[0] * %628) + %arg24 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %672 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%649[0] * %629) + %arg25 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %673 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%651[0] * %630) + %arg26 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %674 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%653[0] * %631) + %arg27 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %675 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%655[0] * %632) + %arg28 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %676 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%657[0] * %633) + %arg29 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %677 = vector.extract %625[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %678 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%643[0] * %626) + %arg30 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %679 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%645[0] * %627) + %arg31 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %680 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%647[0] * %628) + %arg32 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %681 = amdgpu.scaled_mfma 
16x16x128 (%677[0] * %621) * (%649[0] * %629) + %arg33 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %682 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%651[0] * %630) + %arg34 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %683 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%653[0] * %631) + %arg35 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %684 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%655[0] * %632) + %arg36 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %685 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%657[0] * %633) + %arg37 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + rocdl.s.setprio 0 + rocdl.sched.barrier 0 + rocdl.s.barrier + rocdl.sched.barrier 0 + rocdl.sched.barrier 0 + %686 = vector.load %arg38[%62, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %687 = vector.load %arg38[%66, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %688 = vector.load %arg38[%67, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %689 = vector.load %arg38[%68, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %690 = vector.load %arg40[%62, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %691 = vector.load %arg40[%66, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %692 = vector.load %arg40[%67, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %693 = vector.load %arg40[%68, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %694 = vector.load %arg42[%70, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %695 = vector.load %arg42[%71, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %696 = vector.load %arg42[%72, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %697 = vector.load %arg42[%73, %80] : memref<256x128xi8, #gpu.address_space>, 
vector<16xi8> + %698 = vector.load %arg42[%74, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %699 = vector.load %arg42[%75, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %700 = vector.load %arg42[%76, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %701 = vector.load %arg42[%77, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %702 = vector.load %arg44[%70, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %703 = vector.load %arg44[%71, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %704 = vector.load %arg44[%72, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %705 = vector.load %arg44[%73, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %706 = vector.load %arg44[%74, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %707 = vector.load %arg44[%75, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %708 = vector.load %arg44[%76, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %709 = vector.load %arg44[%77, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %710 = vector.bitcast %686 : vector<16xi8> to vector<32xf4E2M1FN> + %711 = vector.bitcast %687 : vector<16xi8> to vector<32xf4E2M1FN> + %712 = vector.bitcast %688 : vector<16xi8> to vector<32xf4E2M1FN> + %713 = vector.bitcast %689 : vector<16xi8> to vector<32xf4E2M1FN> + %714 = vector.bitcast %690 : vector<1xi8> to vector<1xf8E8M0FNU> + %715 = vector.bitcast %691 : vector<1xi8> to vector<1xf8E8M0FNU> + %716 = vector.bitcast %692 : vector<1xi8> to vector<1xf8E8M0FNU> + %717 = vector.bitcast %693 : vector<1xi8> to vector<1xf8E8M0FNU> + %718 = vector.bitcast %694 : vector<16xi8> to vector<32xf4E2M1FN> + %719 = vector.bitcast %695 : vector<16xi8> to vector<32xf4E2M1FN> + %720 = vector.bitcast %696 : vector<16xi8> to vector<32xf4E2M1FN> + %721 = vector.bitcast %697 : vector<16xi8> to vector<32xf4E2M1FN> + %722 = vector.bitcast %698 : vector<16xi8> to vector<32xf4E2M1FN> + 
%723 = vector.bitcast %699 : vector<16xi8> to vector<32xf4E2M1FN> + %724 = vector.bitcast %700 : vector<16xi8> to vector<32xf4E2M1FN> + %725 = vector.bitcast %701 : vector<16xi8> to vector<32xf4E2M1FN> + %726 = vector.bitcast %702 : vector<1xi8> to vector<1xf8E8M0FNU> + %727 = vector.bitcast %703 : vector<1xi8> to vector<1xf8E8M0FNU> + %728 = vector.bitcast %704 : vector<1xi8> to vector<1xf8E8M0FNU> + %729 = vector.bitcast %705 : vector<1xi8> to vector<1xf8E8M0FNU> + %730 = vector.bitcast %706 : vector<1xi8> to vector<1xf8E8M0FNU> + %731 = vector.bitcast %707 : vector<1xi8> to vector<1xf8E8M0FNU> + %732 = vector.bitcast %708 : vector<1xi8> to vector<1xf8E8M0FNU> + %733 = vector.bitcast %709 : vector<1xi8> to vector<1xf8E8M0FNU> + rocdl.sched.barrier 0 + rocdl.s.barrier + rocdl.sched.barrier 0 + rocdl.s.setprio 1 + %734 = vector.extract %714[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %735 = vector.extract %726[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %736 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%735[0] * %718) + %644 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %737 = vector.extract %727[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %738 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%737[0] * %719) + %646 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %739 = vector.extract %728[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %740 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%739[0] * %720) + %648 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %741 = vector.extract %729[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %742 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%741[0] * %721) + %650 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %743 = vector.extract %730[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %744 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%743[0] * %722) + %652 : f8E8M0FNU, vector<32xf4E2M1FN>, 
f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %745 = vector.extract %731[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %746 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%745[0] * %723) + %654 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %747 = vector.extract %732[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %748 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%747[0] * %724) + %656 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %749 = vector.extract %733[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %750 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%749[0] * %725) + %658 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %751 = vector.extract %715[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %752 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%735[0] * %718) + %660 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %753 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%737[0] * %719) + %661 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %754 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%739[0] * %720) + %662 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %755 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%741[0] * %721) + %663 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %756 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%743[0] * %722) + %664 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %757 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%745[0] * %723) + %665 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %758 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%747[0] * %724) + %666 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %759 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%749[0] * 
%725) + %667 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %760 = vector.extract %716[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %761 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%735[0] * %718) + %669 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %762 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%737[0] * %719) + %670 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %763 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%739[0] * %720) + %671 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %764 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%741[0] * %721) + %672 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %765 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%743[0] * %722) + %673 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %766 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%745[0] * %723) + %674 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %767 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%747[0] * %724) + %675 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %768 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%749[0] * %725) + %676 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %769 = vector.extract %717[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %770 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%735[0] * %718) + %678 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %771 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%737[0] * %719) + %679 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %772 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%739[0] * %720) + %680 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, 
vector<4xf32> + %773 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%741[0] * %721) + %681 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %774 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%743[0] * %722) + %682 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %775 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%745[0] * %723) + %683 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %776 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%747[0] * %724) + %684 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %777 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%749[0] * %725) + %685 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + rocdl.s.setprio 0 + rocdl.sched.barrier 0 + scf.yield %736, %738, %740, %742, %744, %746, %748, %750, %752, %753, %754, %755, %756, %757, %758, %759, %761, %762, %763, %764, %765, %766, %767, %768, %770, %771, %772, %773, %774, %775, %776, %777, %arg39, %arg38, %arg41, %arg40, %arg43, %arg42, %arg45, %arg44 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space> + } + scf.if %61 { + rocdl.s.barrier + } + amdgpu.lds_barrier + %84 = 
affine.apply #map23()[%thread_id_x, %thread_id_y] + %85 = affine.apply #map22()[%thread_id_x] + %86 = vector.load %83#38[%84, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %87 = arith.xori %33, %c1 : index + %88 = affine.apply #map32()[%thread_id_x, %87] + %89 = vector.load %83#38[%84, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %90 = affine.apply #map24()[%thread_id_x, %thread_id_y] + %91 = vector.load %83#38[%90, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %92 = vector.load %83#38[%90, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %93 = affine.apply #map25()[%thread_id_x, %thread_id_y] + %94 = vector.load %83#38[%93, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %95 = vector.load %83#38[%93, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %96 = affine.apply #map26()[%thread_id_x, %thread_id_y] + %97 = vector.load %83#38[%96, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %98 = vector.load %83#38[%96, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %99 = affine.apply #map27()[%thread_id_x, %thread_id_y] + %100 = vector.load %83#38[%99, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %101 = vector.load %83#38[%99, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %102 = affine.apply #map28()[%thread_id_x, %thread_id_y] + %103 = vector.load %83#38[%102, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %104 = vector.load %83#38[%102, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %105 = affine.apply #map29()[%thread_id_x, %thread_id_y] + %106 = vector.load %83#38[%105, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %107 = vector.load %83#38[%105, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %108 = affine.apply #map30()[%thread_id_x, %thread_id_y] + %109 = vector.load %83#38[%108, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %110 = vector.load %83#38[%108, %88] : 
memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %111 = affine.apply #map18()[%thread_id_x] + %112 = arith.xori %111, %7 : index + %113 = affine.apply #map3()[%112] + %114 = vector.load %83#36[%84, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %115 = affine.apply #map31()[%thread_id_x] + %116 = arith.xori %115, %7 : index + %117 = affine.apply #map3()[%116] + %118 = vector.load %83#36[%84, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %119 = vector.load %83#36[%90, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %120 = vector.load %83#36[%90, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %121 = vector.load %83#36[%93, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %122 = vector.load %83#36[%93, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %123 = vector.load %83#36[%96, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %124 = vector.load %83#36[%96, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %125 = vector.load %83#36[%99, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %126 = vector.load %83#36[%99, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %127 = vector.load %83#36[%102, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %128 = vector.load %83#36[%102, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %129 = vector.load %83#36[%105, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %130 = vector.load %83#36[%105, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %131 = vector.load %83#36[%108, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %132 = vector.load %83#36[%108, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %133 = affine.apply #map17()[%thread_id_x] + %134 = vector.load %83#34[%133, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %135 = vector.load %83#34[%133, %88] : 
memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %136 = affine.apply #map19()[%thread_id_x] + %137 = vector.load %83#34[%136, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %138 = vector.load %83#34[%136, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %139 = affine.apply #map20()[%thread_id_x] + %140 = vector.load %83#34[%139, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %141 = vector.load %83#34[%139, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %142 = affine.apply #map21()[%thread_id_x] + %143 = vector.load %83#34[%142, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %144 = vector.load %83#34[%142, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %145 = vector.load %83#32[%133, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %146 = vector.load %83#32[%133, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %147 = vector.load %83#32[%136, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %148 = vector.load %83#32[%136, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %149 = vector.load %83#32[%139, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %150 = vector.load %83#32[%139, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %151 = vector.load %83#32[%142, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %152 = vector.load %83#32[%142, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %153 = vector.bitcast %145 : vector<16xi8> to vector<32xf4E2M1FN> + %154 = vector.bitcast %146 : vector<16xi8> to vector<32xf4E2M1FN> + %155 = vector.bitcast %147 : vector<16xi8> to vector<32xf4E2M1FN> + %156 = vector.bitcast %148 : vector<16xi8> to vector<32xf4E2M1FN> + %157 = vector.bitcast %149 : vector<16xi8> to vector<32xf4E2M1FN> + %158 = vector.bitcast %150 : vector<16xi8> to vector<32xf4E2M1FN> + %159 = vector.bitcast %151 : vector<16xi8> to vector<32xf4E2M1FN> + %160 = 
vector.bitcast %152 : vector<16xi8> to vector<32xf4E2M1FN> + %161 = vector.bitcast %134 : vector<1xi8> to vector<1xf8E8M0FNU> + %162 = vector.bitcast %135 : vector<1xi8> to vector<1xf8E8M0FNU> + %163 = vector.bitcast %137 : vector<1xi8> to vector<1xf8E8M0FNU> + %164 = vector.bitcast %138 : vector<1xi8> to vector<1xf8E8M0FNU> + %165 = vector.bitcast %140 : vector<1xi8> to vector<1xf8E8M0FNU> + %166 = vector.bitcast %141 : vector<1xi8> to vector<1xf8E8M0FNU> + %167 = vector.bitcast %143 : vector<1xi8> to vector<1xf8E8M0FNU> + %168 = vector.bitcast %144 : vector<1xi8> to vector<1xf8E8M0FNU> + %169 = vector.bitcast %114 : vector<16xi8> to vector<32xf4E2M1FN> + %170 = vector.bitcast %118 : vector<16xi8> to vector<32xf4E2M1FN> + %171 = vector.bitcast %119 : vector<16xi8> to vector<32xf4E2M1FN> + %172 = vector.bitcast %120 : vector<16xi8> to vector<32xf4E2M1FN> + %173 = vector.bitcast %121 : vector<16xi8> to vector<32xf4E2M1FN> + %174 = vector.bitcast %122 : vector<16xi8> to vector<32xf4E2M1FN> + %175 = vector.bitcast %123 : vector<16xi8> to vector<32xf4E2M1FN> + %176 = vector.bitcast %124 : vector<16xi8> to vector<32xf4E2M1FN> + %177 = vector.bitcast %125 : vector<16xi8> to vector<32xf4E2M1FN> + %178 = vector.bitcast %126 : vector<16xi8> to vector<32xf4E2M1FN> + %179 = vector.bitcast %127 : vector<16xi8> to vector<32xf4E2M1FN> + %180 = vector.bitcast %128 : vector<16xi8> to vector<32xf4E2M1FN> + %181 = vector.bitcast %129 : vector<16xi8> to vector<32xf4E2M1FN> + %182 = vector.bitcast %130 : vector<16xi8> to vector<32xf4E2M1FN> + %183 = vector.bitcast %131 : vector<16xi8> to vector<32xf4E2M1FN> + %184 = vector.bitcast %132 : vector<16xi8> to vector<32xf4E2M1FN> + %185 = vector.bitcast %86 : vector<1xi8> to vector<1xf8E8M0FNU> + %186 = vector.bitcast %89 : vector<1xi8> to vector<1xf8E8M0FNU> + %187 = vector.bitcast %91 : vector<1xi8> to vector<1xf8E8M0FNU> + %188 = vector.bitcast %92 : vector<1xi8> to vector<1xf8E8M0FNU> + %189 = vector.bitcast %94 : vector<1xi8> to 
vector<1xf8E8M0FNU> + %190 = vector.bitcast %95 : vector<1xi8> to vector<1xf8E8M0FNU> + %191 = vector.bitcast %97 : vector<1xi8> to vector<1xf8E8M0FNU> + %192 = vector.bitcast %98 : vector<1xi8> to vector<1xf8E8M0FNU> + %193 = vector.bitcast %100 : vector<1xi8> to vector<1xf8E8M0FNU> + %194 = vector.bitcast %101 : vector<1xi8> to vector<1xf8E8M0FNU> + %195 = vector.bitcast %103 : vector<1xi8> to vector<1xf8E8M0FNU> + %196 = vector.bitcast %104 : vector<1xi8> to vector<1xf8E8M0FNU> + %197 = vector.bitcast %106 : vector<1xi8> to vector<1xf8E8M0FNU> + %198 = vector.bitcast %107 : vector<1xi8> to vector<1xf8E8M0FNU> + %199 = vector.bitcast %109 : vector<1xi8> to vector<1xf8E8M0FNU> + %200 = vector.bitcast %110 : vector<1xi8> to vector<1xf8E8M0FNU> + %201 = vector.extract %161[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %202 = vector.extract %185[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %203 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%202[0] * %169) + %83#0 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %204 = vector.extract %162[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %205 = vector.extract %186[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %206 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%205[0] * %170) + %203 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %207 = vector.extract %187[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %208 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%207[0] * %171) + %83#1 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %209 = vector.extract %188[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %210 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%209[0] * %172) + %208 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %211 = vector.extract %189[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %212 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%211[0] * %173) + %83#2 : f8E8M0FNU, vector<32xf4E2M1FN>, 
f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %213 = vector.extract %190[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %214 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%213[0] * %174) + %212 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %215 = vector.extract %191[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %216 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%215[0] * %175) + %83#3 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %217 = vector.extract %192[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %218 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%217[0] * %176) + %216 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %219 = vector.extract %193[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %220 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%219[0] * %177) + %83#4 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %221 = vector.extract %194[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %222 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%221[0] * %178) + %220 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %223 = vector.extract %195[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %224 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%223[0] * %179) + %83#5 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %225 = vector.extract %196[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %226 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%225[0] * %180) + %224 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %227 = vector.extract %197[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %228 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%227[0] * %181) + %83#6 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %229 = vector.extract %198[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %230 = amdgpu.scaled_mfma 16x16x128 
(%204[0] * %154) * (%229[0] * %182) + %228 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %231 = vector.extract %199[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %232 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%231[0] * %183) + %83#7 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %233 = vector.extract %200[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %234 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%233[0] * %184) + %232 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %235 = vector.extract %163[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %236 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%202[0] * %169) + %83#8 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %237 = vector.extract %164[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %238 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%205[0] * %170) + %236 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %239 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%207[0] * %171) + %83#9 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %240 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%209[0] * %172) + %239 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %241 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%211[0] * %173) + %83#10 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %242 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%213[0] * %174) + %241 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %243 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%215[0] * %175) + %83#11 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %244 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%217[0] * %176) + %243 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, 
vector<32xf4E2M1FN>, vector<4xf32> + %245 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%219[0] * %177) + %83#12 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %246 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%221[0] * %178) + %245 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %247 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%223[0] * %179) + %83#13 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %248 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%225[0] * %180) + %247 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %249 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%227[0] * %181) + %83#14 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %250 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%229[0] * %182) + %249 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %251 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%231[0] * %183) + %83#15 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %252 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%233[0] * %184) + %251 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %253 = vector.extract %165[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %254 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%202[0] * %169) + %83#16 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %255 = vector.extract %166[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %256 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%205[0] * %170) + %254 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %257 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%207[0] * %171) + %83#17 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %258 = amdgpu.scaled_mfma 16x16x128 
(%255[0] * %158) * (%209[0] * %172) + %257 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %259 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%211[0] * %173) + %83#18 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %260 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%213[0] * %174) + %259 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %261 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%215[0] * %175) + %83#19 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %262 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%217[0] * %176) + %261 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %263 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%219[0] * %177) + %83#20 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %264 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%221[0] * %178) + %263 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %265 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%223[0] * %179) + %83#21 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %266 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%225[0] * %180) + %265 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %267 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%227[0] * %181) + %83#22 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %268 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%229[0] * %182) + %267 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %269 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%231[0] * %183) + %83#23 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %270 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%233[0] * %184) + %269 : f8E8M0FNU, 
vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %271 = vector.extract %167[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %272 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%202[0] * %169) + %83#24 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %273 = vector.extract %168[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %274 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%205[0] * %170) + %272 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %275 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%207[0] * %171) + %83#25 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %276 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%209[0] * %172) + %275 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %277 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%211[0] * %173) + %83#26 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %278 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%213[0] * %174) + %277 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %279 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%215[0] * %175) + %83#27 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %280 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%217[0] * %176) + %279 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %281 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%219[0] * %177) + %83#28 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %282 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%221[0] * %178) + %281 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %283 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%223[0] * %179) + %83#29 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %284 = 
amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%225[0] * %180) + %283 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %285 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%227[0] * %181) + %83#30 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %286 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%229[0] * %182) + %285 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %287 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%231[0] * %183) + %83#31 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %288 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%233[0] * %184) + %287 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %289 = vector.extract_strided_slice %206 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %290 = affine.apply #map35()[%block_id_x] + %291 = affine.apply #map35()[%block_id_y] + %292 = affine.apply #map36()[%thread_id_x] + %293 = arith.muli %290, %c57344 overflow : index + %294 = arith.muli %292, %c57344 overflow : index + %295 = arith.addi %293, %291 overflow : index + %296 = arith.addi %294, %84 overflow : index + %reinterpret_cast_13 = memref.reinterpret_cast %4 to offset: [%295], sizes: [536870910], strides: [1] : memref to memref<536870910xf32, strided<[1], offset: ?>> + %cast_14 = memref.cast %reinterpret_cast_13 : memref<536870910xf32, strided<[1], offset: ?>> to memref> + %297 = amdgpu.fat_raw_buffer_cast %cast_14 validBytes(%c2147483643_i64) resetOffset : memref> to memref> + vector.store %289, %297[%296] : memref>, vector<1xf32> + %298 = vector.extract_strided_slice %206 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %299 = affine.apply #map37()[%thread_id_x] + %300 = arith.muli %299, %c57344 overflow : index + %301 = arith.addi %300, %84 overflow : index + vector.store %298, %297[%301] : memref>, 
vector<1xf32> + %302 = vector.extract_strided_slice %206 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %303 = affine.apply #map38()[%thread_id_x] + %304 = arith.muli %303, %c57344 overflow : index + %305 = arith.addi %304, %84 overflow : index + vector.store %302, %297[%305] : memref>, vector<1xf32> + %306 = vector.extract_strided_slice %206 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %307 = affine.apply #map39()[%thread_id_x] + %308 = arith.muli %307, %c57344 overflow : index + %309 = arith.addi %308, %84 overflow : index + vector.store %306, %297[%309] : memref>, vector<1xf32> + %310 = vector.extract_strided_slice %210 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %311 = arith.addi %294, %90 overflow : index + vector.store %310, %297[%311] : memref>, vector<1xf32> + %312 = vector.extract_strided_slice %210 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %313 = arith.addi %300, %90 overflow : index + vector.store %312, %297[%313] : memref>, vector<1xf32> + %314 = vector.extract_strided_slice %210 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %315 = arith.addi %304, %90 overflow : index + vector.store %314, %297[%315] : memref>, vector<1xf32> + %316 = vector.extract_strided_slice %210 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %317 = arith.addi %308, %90 overflow : index + vector.store %316, %297[%317] : memref>, vector<1xf32> + %318 = vector.extract_strided_slice %214 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %319 = arith.addi %294, %93 overflow : index + vector.store %318, %297[%319] : memref>, vector<1xf32> + %320 = vector.extract_strided_slice %214 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %321 = arith.addi %300, %93 overflow : index + vector.store %320, %297[%321] : memref>, vector<1xf32> + 
%322 = vector.extract_strided_slice %214 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %323 = arith.addi %304, %93 overflow : index + vector.store %322, %297[%323] : memref>, vector<1xf32> + %324 = vector.extract_strided_slice %214 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %325 = arith.addi %308, %93 overflow : index + vector.store %324, %297[%325] : memref>, vector<1xf32> + %326 = vector.extract_strided_slice %218 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %327 = arith.addi %294, %96 overflow : index + vector.store %326, %297[%327] : memref>, vector<1xf32> + %328 = vector.extract_strided_slice %218 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %329 = arith.addi %300, %96 overflow : index + vector.store %328, %297[%329] : memref>, vector<1xf32> + %330 = vector.extract_strided_slice %218 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %331 = arith.addi %304, %96 overflow : index + vector.store %330, %297[%331] : memref>, vector<1xf32> + %332 = vector.extract_strided_slice %218 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %333 = arith.addi %308, %96 overflow : index + vector.store %332, %297[%333] : memref>, vector<1xf32> + %334 = vector.extract_strided_slice %222 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %335 = arith.addi %294, %99 overflow : index + vector.store %334, %297[%335] : memref>, vector<1xf32> + %336 = vector.extract_strided_slice %222 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %337 = arith.addi %300, %99 overflow : index + vector.store %336, %297[%337] : memref>, vector<1xf32> + %338 = vector.extract_strided_slice %222 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %339 = arith.addi %304, %99 overflow : index + vector.store %338, %297[%339] : memref>, 
vector<1xf32> + %340 = vector.extract_strided_slice %222 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %341 = arith.addi %308, %99 overflow : index + vector.store %340, %297[%341] : memref>, vector<1xf32> + %342 = vector.extract_strided_slice %226 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %343 = arith.addi %294, %102 overflow : index + vector.store %342, %297[%343] : memref>, vector<1xf32> + %344 = vector.extract_strided_slice %226 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %345 = arith.addi %300, %102 overflow : index + vector.store %344, %297[%345] : memref>, vector<1xf32> + %346 = vector.extract_strided_slice %226 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %347 = arith.addi %304, %102 overflow : index + vector.store %346, %297[%347] : memref>, vector<1xf32> + %348 = vector.extract_strided_slice %226 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %349 = arith.addi %308, %102 overflow : index + vector.store %348, %297[%349] : memref>, vector<1xf32> + %350 = vector.extract_strided_slice %230 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %351 = arith.addi %294, %105 overflow : index + vector.store %350, %297[%351] : memref>, vector<1xf32> + %352 = vector.extract_strided_slice %230 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %353 = arith.addi %300, %105 overflow : index + vector.store %352, %297[%353] : memref>, vector<1xf32> + %354 = vector.extract_strided_slice %230 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %355 = arith.addi %304, %105 overflow : index + vector.store %354, %297[%355] : memref>, vector<1xf32> + %356 = vector.extract_strided_slice %230 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %357 = arith.addi %308, %105 overflow : index + vector.store 
%356, %297[%357] : memref>, vector<1xf32> + %358 = vector.extract_strided_slice %234 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %359 = arith.addi %294, %108 overflow : index + vector.store %358, %297[%359] : memref>, vector<1xf32> + %360 = vector.extract_strided_slice %234 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %361 = arith.addi %300, %108 overflow : index + vector.store %360, %297[%361] : memref>, vector<1xf32> + %362 = vector.extract_strided_slice %234 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %363 = arith.addi %304, %108 overflow : index + vector.store %362, %297[%363] : memref>, vector<1xf32> + %364 = vector.extract_strided_slice %234 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %365 = arith.addi %308, %108 overflow : index + vector.store %364, %297[%365] : memref>, vector<1xf32> + %366 = vector.extract_strided_slice %238 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %367 = affine.apply #map40()[%thread_id_x] + %368 = arith.muli %367, %c57344 overflow : index + %369 = arith.addi %368, %84 overflow : index + vector.store %366, %297[%369] : memref>, vector<1xf32> + %370 = vector.extract_strided_slice %238 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %371 = affine.apply #map41()[%thread_id_x] + %372 = arith.muli %371, %c57344 overflow : index + %373 = arith.addi %372, %84 overflow : index + vector.store %370, %297[%373] : memref>, vector<1xf32> + %374 = vector.extract_strided_slice %238 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %375 = affine.apply #map42()[%thread_id_x] + %376 = arith.muli %375, %c57344 overflow : index + %377 = arith.addi %376, %84 overflow : index + vector.store %374, %297[%377] : memref>, vector<1xf32> + %378 = vector.extract_strided_slice %238 {offsets = [3], sizes = [1], strides = [1]} : 
vector<4xf32> to vector<1xf32> + %379 = affine.apply #map43()[%thread_id_x] + %380 = arith.muli %379, %c57344 overflow : index + %381 = arith.addi %380, %84 overflow : index + vector.store %378, %297[%381] : memref>, vector<1xf32> + %382 = vector.extract_strided_slice %240 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %383 = arith.addi %368, %90 overflow : index + vector.store %382, %297[%383] : memref>, vector<1xf32> + %384 = vector.extract_strided_slice %240 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %385 = arith.addi %372, %90 overflow : index + vector.store %384, %297[%385] : memref>, vector<1xf32> + %386 = vector.extract_strided_slice %240 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %387 = arith.addi %376, %90 overflow : index + vector.store %386, %297[%387] : memref>, vector<1xf32> + %388 = vector.extract_strided_slice %240 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %389 = arith.addi %380, %90 overflow : index + vector.store %388, %297[%389] : memref>, vector<1xf32> + %390 = vector.extract_strided_slice %242 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %391 = arith.addi %368, %93 overflow : index + vector.store %390, %297[%391] : memref>, vector<1xf32> + %392 = vector.extract_strided_slice %242 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %393 = arith.addi %372, %93 overflow : index + vector.store %392, %297[%393] : memref>, vector<1xf32> + %394 = vector.extract_strided_slice %242 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %395 = arith.addi %376, %93 overflow : index + vector.store %394, %297[%395] : memref>, vector<1xf32> + %396 = vector.extract_strided_slice %242 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %397 = arith.addi %380, %93 overflow : index + vector.store %396, %297[%397] : 
memref>, vector<1xf32> + %398 = vector.extract_strided_slice %244 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %399 = arith.addi %368, %96 overflow : index + vector.store %398, %297[%399] : memref>, vector<1xf32> + %400 = vector.extract_strided_slice %244 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %401 = arith.addi %372, %96 overflow : index + vector.store %400, %297[%401] : memref>, vector<1xf32> + %402 = vector.extract_strided_slice %244 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %403 = arith.addi %376, %96 overflow : index + vector.store %402, %297[%403] : memref>, vector<1xf32> + %404 = vector.extract_strided_slice %244 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %405 = arith.addi %380, %96 overflow : index + vector.store %404, %297[%405] : memref>, vector<1xf32> + %406 = vector.extract_strided_slice %246 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %407 = arith.addi %368, %99 overflow : index + vector.store %406, %297[%407] : memref>, vector<1xf32> + %408 = vector.extract_strided_slice %246 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %409 = arith.addi %372, %99 overflow : index + vector.store %408, %297[%409] : memref>, vector<1xf32> + %410 = vector.extract_strided_slice %246 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %411 = arith.addi %376, %99 overflow : index + vector.store %410, %297[%411] : memref>, vector<1xf32> + %412 = vector.extract_strided_slice %246 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %413 = arith.addi %380, %99 overflow : index + vector.store %412, %297[%413] : memref>, vector<1xf32> + %414 = vector.extract_strided_slice %248 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %415 = arith.addi %368, %102 overflow : index + vector.store 
%414, %297[%415] : memref>, vector<1xf32> + %416 = vector.extract_strided_slice %248 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %417 = arith.addi %372, %102 overflow : index + vector.store %416, %297[%417] : memref>, vector<1xf32> + %418 = vector.extract_strided_slice %248 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %419 = arith.addi %376, %102 overflow : index + vector.store %418, %297[%419] : memref>, vector<1xf32> + %420 = vector.extract_strided_slice %248 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %421 = arith.addi %380, %102 overflow : index + vector.store %420, %297[%421] : memref>, vector<1xf32> + %422 = vector.extract_strided_slice %250 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %423 = arith.addi %368, %105 overflow : index + vector.store %422, %297[%423] : memref>, vector<1xf32> + %424 = vector.extract_strided_slice %250 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %425 = arith.addi %372, %105 overflow : index + vector.store %424, %297[%425] : memref>, vector<1xf32> + %426 = vector.extract_strided_slice %250 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %427 = arith.addi %376, %105 overflow : index + vector.store %426, %297[%427] : memref>, vector<1xf32> + %428 = vector.extract_strided_slice %250 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %429 = arith.addi %380, %105 overflow : index + vector.store %428, %297[%429] : memref>, vector<1xf32> + %430 = vector.extract_strided_slice %252 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %431 = arith.addi %368, %108 overflow : index + vector.store %430, %297[%431] : memref>, vector<1xf32> + %432 = vector.extract_strided_slice %252 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %433 = arith.addi %372, %108 
overflow : index + vector.store %432, %297[%433] : memref>, vector<1xf32> + %434 = vector.extract_strided_slice %252 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %435 = arith.addi %376, %108 overflow : index + vector.store %434, %297[%435] : memref>, vector<1xf32> + %436 = vector.extract_strided_slice %252 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %437 = arith.addi %380, %108 overflow : index + vector.store %436, %297[%437] : memref>, vector<1xf32> + %438 = vector.extract_strided_slice %256 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %439 = affine.apply #map44()[%thread_id_x] + %440 = arith.muli %439, %c57344 overflow : index + %441 = arith.addi %440, %84 overflow : index + vector.store %438, %297[%441] : memref>, vector<1xf32> + %442 = vector.extract_strided_slice %256 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %443 = affine.apply #map45()[%thread_id_x] + %444 = arith.muli %443, %c57344 overflow : index + %445 = arith.addi %444, %84 overflow : index + vector.store %442, %297[%445] : memref>, vector<1xf32> + %446 = vector.extract_strided_slice %256 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %447 = affine.apply #map46()[%thread_id_x] + %448 = arith.muli %447, %c57344 overflow : index + %449 = arith.addi %448, %84 overflow : index + vector.store %446, %297[%449] : memref>, vector<1xf32> + %450 = vector.extract_strided_slice %256 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %451 = affine.apply #map47()[%thread_id_x] + %452 = arith.muli %451, %c57344 overflow : index + %453 = arith.addi %452, %84 overflow : index + vector.store %450, %297[%453] : memref>, vector<1xf32> + %454 = vector.extract_strided_slice %258 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %455 = arith.addi %440, %90 overflow : index + vector.store %454, 
%297[%455] : memref>, vector<1xf32> + %456 = vector.extract_strided_slice %258 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %457 = arith.addi %444, %90 overflow : index + vector.store %456, %297[%457] : memref>, vector<1xf32> + %458 = vector.extract_strided_slice %258 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %459 = arith.addi %448, %90 overflow : index + vector.store %458, %297[%459] : memref>, vector<1xf32> + %460 = vector.extract_strided_slice %258 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %461 = arith.addi %452, %90 overflow : index + vector.store %460, %297[%461] : memref>, vector<1xf32> + %462 = vector.extract_strided_slice %260 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %463 = arith.addi %440, %93 overflow : index + vector.store %462, %297[%463] : memref>, vector<1xf32> + %464 = vector.extract_strided_slice %260 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %465 = arith.addi %444, %93 overflow : index + vector.store %464, %297[%465] : memref>, vector<1xf32> + %466 = vector.extract_strided_slice %260 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %467 = arith.addi %448, %93 overflow : index + vector.store %466, %297[%467] : memref>, vector<1xf32> + %468 = vector.extract_strided_slice %260 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %469 = arith.addi %452, %93 overflow : index + vector.store %468, %297[%469] : memref>, vector<1xf32> + %470 = vector.extract_strided_slice %262 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %471 = arith.addi %440, %96 overflow : index + vector.store %470, %297[%471] : memref>, vector<1xf32> + %472 = vector.extract_strided_slice %262 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %473 = arith.addi %444, %96 overflow : index + 
vector.store %472, %297[%473] : memref>, vector<1xf32> + %474 = vector.extract_strided_slice %262 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %475 = arith.addi %448, %96 overflow : index + vector.store %474, %297[%475] : memref>, vector<1xf32> + %476 = vector.extract_strided_slice %262 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %477 = arith.addi %452, %96 overflow : index + vector.store %476, %297[%477] : memref>, vector<1xf32> + %478 = vector.extract_strided_slice %264 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %479 = arith.addi %440, %99 overflow : index + vector.store %478, %297[%479] : memref>, vector<1xf32> + %480 = vector.extract_strided_slice %264 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %481 = arith.addi %444, %99 overflow : index + vector.store %480, %297[%481] : memref>, vector<1xf32> + %482 = vector.extract_strided_slice %264 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %483 = arith.addi %448, %99 overflow : index + vector.store %482, %297[%483] : memref>, vector<1xf32> + %484 = vector.extract_strided_slice %264 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %485 = arith.addi %452, %99 overflow : index + vector.store %484, %297[%485] : memref>, vector<1xf32> + %486 = vector.extract_strided_slice %266 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %487 = arith.addi %440, %102 overflow : index + vector.store %486, %297[%487] : memref>, vector<1xf32> + %488 = vector.extract_strided_slice %266 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %489 = arith.addi %444, %102 overflow : index + vector.store %488, %297[%489] : memref>, vector<1xf32> + %490 = vector.extract_strided_slice %266 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %491 = arith.addi %448, %102 
overflow : index + vector.store %490, %297[%491] : memref>, vector<1xf32> + %492 = vector.extract_strided_slice %266 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %493 = arith.addi %452, %102 overflow : index + vector.store %492, %297[%493] : memref>, vector<1xf32> + %494 = vector.extract_strided_slice %268 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %495 = arith.addi %440, %105 overflow : index + vector.store %494, %297[%495] : memref>, vector<1xf32> + %496 = vector.extract_strided_slice %268 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %497 = arith.addi %444, %105 overflow : index + vector.store %496, %297[%497] : memref>, vector<1xf32> + %498 = vector.extract_strided_slice %268 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %499 = arith.addi %448, %105 overflow : index + vector.store %498, %297[%499] : memref>, vector<1xf32> + %500 = vector.extract_strided_slice %268 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %501 = arith.addi %452, %105 overflow : index + vector.store %500, %297[%501] : memref>, vector<1xf32> + %502 = vector.extract_strided_slice %270 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %503 = arith.addi %440, %108 overflow : index + vector.store %502, %297[%503] : memref>, vector<1xf32> + %504 = vector.extract_strided_slice %270 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %505 = arith.addi %444, %108 overflow : index + vector.store %504, %297[%505] : memref>, vector<1xf32> + %506 = vector.extract_strided_slice %270 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %507 = arith.addi %448, %108 overflow : index + vector.store %506, %297[%507] : memref>, vector<1xf32> + %508 = vector.extract_strided_slice %270 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + 
%509 = arith.addi %452, %108 overflow : index + vector.store %508, %297[%509] : memref>, vector<1xf32> + %510 = vector.extract_strided_slice %274 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %511 = affine.apply #map48()[%thread_id_x] + %512 = arith.muli %511, %c57344 overflow : index + %513 = arith.addi %512, %84 overflow : index + vector.store %510, %297[%513] : memref>, vector<1xf32> + %514 = vector.extract_strided_slice %274 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %515 = affine.apply #map49()[%thread_id_x] + %516 = arith.muli %515, %c57344 overflow : index + %517 = arith.addi %516, %84 overflow : index + vector.store %514, %297[%517] : memref>, vector<1xf32> + %518 = vector.extract_strided_slice %274 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %519 = affine.apply #map50()[%thread_id_x] + %520 = arith.muli %519, %c57344 overflow : index + %521 = arith.addi %520, %84 overflow : index + vector.store %518, %297[%521] : memref>, vector<1xf32> + %522 = vector.extract_strided_slice %274 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %523 = affine.apply #map51()[%thread_id_x] + %524 = arith.muli %523, %c57344 overflow : index + %525 = arith.addi %524, %84 overflow : index + vector.store %522, %297[%525] : memref>, vector<1xf32> + %526 = vector.extract_strided_slice %276 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %527 = arith.addi %512, %90 overflow : index + vector.store %526, %297[%527] : memref>, vector<1xf32> + %528 = vector.extract_strided_slice %276 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %529 = arith.addi %516, %90 overflow : index + vector.store %528, %297[%529] : memref>, vector<1xf32> + %530 = vector.extract_strided_slice %276 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %531 = arith.addi %520, %90 overflow : index + 
vector.store %530, %297[%531] : memref>, vector<1xf32> + %532 = vector.extract_strided_slice %276 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %533 = arith.addi %524, %90 overflow : index + vector.store %532, %297[%533] : memref>, vector<1xf32> + %534 = vector.extract_strided_slice %278 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %535 = arith.addi %512, %93 overflow : index + vector.store %534, %297[%535] : memref>, vector<1xf32> + %536 = vector.extract_strided_slice %278 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %537 = arith.addi %516, %93 overflow : index + vector.store %536, %297[%537] : memref>, vector<1xf32> + %538 = vector.extract_strided_slice %278 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %539 = arith.addi %520, %93 overflow : index + vector.store %538, %297[%539] : memref>, vector<1xf32> + %540 = vector.extract_strided_slice %278 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %541 = arith.addi %524, %93 overflow : index + vector.store %540, %297[%541] : memref>, vector<1xf32> + %542 = vector.extract_strided_slice %280 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %543 = arith.addi %512, %96 overflow : index + vector.store %542, %297[%543] : memref>, vector<1xf32> + %544 = vector.extract_strided_slice %280 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %545 = arith.addi %516, %96 overflow : index + vector.store %544, %297[%545] : memref>, vector<1xf32> + %546 = vector.extract_strided_slice %280 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %547 = arith.addi %520, %96 overflow : index + vector.store %546, %297[%547] : memref>, vector<1xf32> + %548 = vector.extract_strided_slice %280 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %549 = arith.addi %524, %96 
overflow : index + vector.store %548, %297[%549] : memref>, vector<1xf32> + %550 = vector.extract_strided_slice %282 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %551 = arith.addi %512, %99 overflow : index + vector.store %550, %297[%551] : memref>, vector<1xf32> + %552 = vector.extract_strided_slice %282 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %553 = arith.addi %516, %99 overflow : index + vector.store %552, %297[%553] : memref>, vector<1xf32> + %554 = vector.extract_strided_slice %282 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %555 = arith.addi %520, %99 overflow : index + vector.store %554, %297[%555] : memref>, vector<1xf32> + %556 = vector.extract_strided_slice %282 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %557 = arith.addi %524, %99 overflow : index + vector.store %556, %297[%557] : memref>, vector<1xf32> + %558 = vector.extract_strided_slice %284 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %559 = arith.addi %512, %102 overflow : index + vector.store %558, %297[%559] : memref>, vector<1xf32> + %560 = vector.extract_strided_slice %284 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %561 = arith.addi %516, %102 overflow : index + vector.store %560, %297[%561] : memref>, vector<1xf32> + %562 = vector.extract_strided_slice %284 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %563 = arith.addi %520, %102 overflow : index + vector.store %562, %297[%563] : memref>, vector<1xf32> + %564 = vector.extract_strided_slice %284 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %565 = arith.addi %524, %102 overflow : index + vector.store %564, %297[%565] : memref>, vector<1xf32> + %566 = vector.extract_strided_slice %286 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %567 = 
arith.addi %512, %105 overflow : index + vector.store %566, %297[%567] : memref>, vector<1xf32> + %568 = vector.extract_strided_slice %286 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %569 = arith.addi %516, %105 overflow : index + vector.store %568, %297[%569] : memref>, vector<1xf32> + %570 = vector.extract_strided_slice %286 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %571 = arith.addi %520, %105 overflow : index + vector.store %570, %297[%571] : memref>, vector<1xf32> + %572 = vector.extract_strided_slice %286 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %573 = arith.addi %524, %105 overflow : index + vector.store %572, %297[%573] : memref>, vector<1xf32> + %574 = vector.extract_strided_slice %288 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %575 = arith.addi %512, %108 overflow : index + vector.store %574, %297[%575] : memref>, vector<1xf32> + %576 = vector.extract_strided_slice %288 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %577 = arith.addi %516, %108 overflow : index + vector.store %576, %297[%577] : memref>, vector<1xf32> + %578 = vector.extract_strided_slice %288 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %579 = arith.addi %520, %108 overflow : index + vector.store %578, %297[%579] : memref>, vector<1xf32> + %580 = vector.extract_strided_slice %288 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %581 = arith.addi %524, %108 overflow : index + vector.store %580, %297[%581] : memref>, vector<1xf32> + return + } + } + } + func.func @isolated_benchmark$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view, %arg4: !hal.buffer_view, %arg5: !hal.fence, %arg6: !hal.fence) -> !hal.buffer_view { + %0 = hal.tensor.import wait(%arg5) => %arg0 : !hal.buffer_view -> 
tensor<4096x8192xi8> + %1 = hal.tensor.import wait(%arg5) => %arg1 : !hal.buffer_view -> tensor<4096x512xi8> + %2 = hal.tensor.import wait(%arg5) => %arg2 : !hal.buffer_view -> tensor<57344x8192xi8> + %3 = hal.tensor.import wait(%arg5) => %arg3 : !hal.buffer_view -> tensor<57344x512xi8> + %4 = hal.tensor.import wait(%arg5) => %arg4 : !hal.buffer_view -> tensor<4096x57344xf32> + %5 = flow.dispatch @gemm::@gemm(%0, %1, %2, %3, %4) : (tensor<4096x8192xi8>, tensor<4096x512xi8>, tensor<57344x8192xi8>, tensor<57344x512xi8>, tensor<4096x57344xf32>) -> %4 + %6 = hal.tensor.barrier join(%5 : tensor<4096x57344xf32>) => %arg6 : !hal.fence + %7 = hal.tensor.export %6 : tensor<4096x57344xf32> -> !hal.buffer_view + return %7 : !hal.buffer_view + } + } + """ + # split into depdendent and independent loads + mlir_pingpong_mixed = """ + #map = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8) floordiv 256) * 256)> + #map1 = affine_map<()[s0] -> ((s0 floordiv 8) mod 8)> + #map2 = affine_map<()[s0] -> (s0 mod 8)> + #map3 = affine_map<()[s0] -> (s0 * 16)> + #map4 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64) floordiv 32) * 256)> + #map5 = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 64) floordiv 256) * 256 + 64)> + #map6 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64 + 8) floordiv 32) * 256 + 64)> + #map7 = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 128) floordiv 256) * 256 + 128)> + #map8 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64 + 16) floordiv 32) * 256 + 128)> + #map9 = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 192) floordiv 256) * 256 + 192)> + #map10 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64 + 24) floordiv 32) * 256 
+ 192)> + #map11 = affine_map<()[s0, s1, s2] -> (s1 * 128 + s2 * 256 + s0 floordiv 2 - ((s1 * 128 + s0 floordiv 2) floordiv 256) * 256)> + #map12 = affine_map<()[s0] -> ((s0 floordiv 2) mod 2)> + #map13 = affine_map<()[s0] -> (s0 mod 2)> + #map14 = affine_map<()[s0] -> (s0 * 4)> + #map15 = affine_map<()[s0, s1] -> (s1 * 128 + (s0 floordiv 64) * 32 - ((s1 * 4 + s0 floordiv 64) floordiv 8) * 256)> + #map16 = affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 64)> + #map17 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64)> + #map18 = affine_map<()[s0] -> ((s0 mod 64) floordiv 16)> + #map19 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64 + 16)> + #map20 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64 + 32)> + #map21 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64 + 48)> + #map22 = affine_map<()[s0] -> (s0 * 4 + (s0 mod 64) floordiv 16 - (s0 floordiv 2) * 8)> + #map23 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16)> + #map24 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 16)> + #map25 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 32)> + #map26 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 48)> + #map27 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 64)> + #map28 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 80)> + #map29 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 96)> + #map30 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 112)> + #map31 = affine_map<()[s0] -> ((s0 mod 64) floordiv 16 + 4)> + #map32 = affine_map<()[s0, s1] -> (s1 * 4 + (s0 mod 64) floordiv 16)> + #map33 = affine_map<()[s0, s1] -> (s0 * 128 + s1 * 16 + 128)> + #map34 = affine_map<()[s0, s1] -> (s0 * 8 + s1 * 4 + 8)> + #map35 = affine_map<()[s0] -> (s0 * 256)> + #map36 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4)> + #map37 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 
+ ((s0 mod 64) floordiv 16) * 4 + 1)> + #map38 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 2)> + #map39 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 3)> + #map40 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 16)> + #map41 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 17)> + #map42 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 18)> + #map43 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 19)> + #map44 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 32)> + #map45 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 33)> + #map46 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 34)> + #map47 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 35)> + #map48 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 48)> + #map49 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 49)> + #map50 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 50)> + #map51 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 51)> + #translation = #iree_codegen.translation_info + module attributes {transform.with_named_sequence} { + stream.executable private @gemm { + stream.executable.export public @gemm workgroups() -> (index, index, index) { + %c16 = arith.constant 16 : index + %c224 = arith.constant 224 : index + %c1 = arith.constant 1 : index + stream.return %c16, %c224, %c1 : index, index, index + } + builtin.module { + func.func @gemm(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: !stream.binding, %arg3: !stream.binding, %arg4: !stream.binding) attributes {translation_info = #translation} { + %c4_i32 = arith.constant 4 : i32 + %c512_i14 = 
arith.constant 512 : i14 + %c-8192_i14 = arith.constant -8192 : i14 + %c2147483643_i64 = arith.constant 2147483643 : i64 + %c57344 = arith.constant 57344 : index + %c63 = arith.constant 63 : index + %c512 = arith.constant 512 : index + %c2147483646_i64 = arith.constant 2147483646 : i64 + %c8192 = arith.constant 8192 : index + %c1 = arith.constant 1 : index + %cst = arith.constant dense<0.000000e+00> : vector<4xf32> + %c0 = arith.constant 0 : index + %0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> memref + %1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> memref + %2 = stream.binding.subspan %arg2[%c0] : !stream.binding -> memref + %3 = stream.binding.subspan %arg3[%c0] : !stream.binding -> memref + %4 = stream.binding.subspan %arg4[%c0] : !stream.binding -> memref + %block_id_x = gpu.block_id x upper_bound 16 + %block_id_y = gpu.block_id y upper_bound 224 + %thread_id_x = gpu.thread_id x upper_bound 256 + %thread_id_y = gpu.thread_id y upper_bound 2 + %alloc = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_0 = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_1 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %alloc_2 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %alloc_3 = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_4 = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_5 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %alloc_6 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %5 = affine.apply #map()[%thread_id_x, %thread_id_y, %block_id_x] + %6 = affine.apply #map1()[%thread_id_x] + %7 = affine.apply #map2()[%thread_id_x] + %8 = arith.xori %7, %6 : index + %9 = affine.apply #map3()[%8] + %10 = affine.apply #map4()[%thread_id_x, %thread_id_y] + %11 = gpu.subgroup_broadcast %10, first_active_lane : index + %12 = gpu.subgroup_broadcast %c0, first_active_lane : index + %13 = arith.muli %5, %c8192 overflow : index + %14 = arith.addi %13, %9 
overflow : index + %reinterpret_cast = memref.reinterpret_cast %0 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast = memref.cast %reinterpret_cast : memref<2147483646xi8, strided<[1]>> to memref> + %15 = amdgpu.fat_raw_buffer_cast %cast validBytes(%c2147483646_i64) cacheSwizzleStride(%c-8192_i14) resetOffset : memref> to memref> + amdgpu.gather_to_lds %15[%14], %alloc_6[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %16 = affine.apply #map5()[%thread_id_x, %thread_id_y, %block_id_x] + %17 = affine.apply #map6()[%thread_id_x, %thread_id_y] + %18 = gpu.subgroup_broadcast %17, first_active_lane : index + %19 = arith.muli %16, %c8192 overflow : index + %20 = arith.addi %19, %9 overflow : index + amdgpu.gather_to_lds %15[%20], %alloc_6[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %21 = affine.apply #map7()[%thread_id_x, %thread_id_y, %block_id_x] + %22 = affine.apply #map8()[%thread_id_x, %thread_id_y] + %23 = gpu.subgroup_broadcast %22, first_active_lane : index + %24 = arith.muli %21, %c8192 overflow : index + %25 = arith.addi %24, %9 overflow : index + amdgpu.gather_to_lds %15[%25], %alloc_6[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %26 = affine.apply #map9()[%thread_id_x, %thread_id_y, %block_id_x] + %27 = affine.apply #map10()[%thread_id_x, %thread_id_y] + %28 = gpu.subgroup_broadcast %27, first_active_lane : index + %29 = arith.muli %26, %c8192 overflow : index + %30 = arith.addi %29, %9 overflow : index + amdgpu.gather_to_lds %15[%30], %alloc_6[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %31 = affine.apply #map11()[%thread_id_x, %thread_id_y, %block_id_x] + %32 = affine.apply #map12()[%thread_id_x] + %33 = affine.apply #map13()[%thread_id_x] + %34 = arith.xori %33, %32 : index + %35 = affine.apply #map14()[%34] + %36 = affine.apply #map15()[%thread_id_x, %thread_id_y] + %37 = 
gpu.subgroup_broadcast %36, first_active_lane : index + %38 = arith.muli %31, %c512 overflow : index + %39 = arith.addi %38, %35 overflow : index + %reinterpret_cast_7 = memref.reinterpret_cast %1 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast_8 = memref.cast %reinterpret_cast_7 : memref<2147483646xi8, strided<[1]>> to memref> + %40 = amdgpu.fat_raw_buffer_cast %cast_8 validBytes(%c2147483646_i64) cacheSwizzleStride(%c512_i14) resetOffset : memref> to memref> + amdgpu.gather_to_lds %40[%39], %alloc_4[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + %41 = affine.apply #map()[%thread_id_x, %thread_id_y, %block_id_y] + %42 = arith.muli %41, %c8192 overflow : index + %43 = arith.addi %42, %9 overflow : index + %reinterpret_cast_9 = memref.reinterpret_cast %2 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast_10 = memref.cast %reinterpret_cast_9 : memref<2147483646xi8, strided<[1]>> to memref> + %44 = amdgpu.fat_raw_buffer_cast %cast_10 validBytes(%c2147483646_i64) cacheSwizzleStride(%c-8192_i14) resetOffset : memref> to memref> + amdgpu.gather_to_lds %44[%43], %alloc_2[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %45 = affine.apply #map5()[%thread_id_x, %thread_id_y, %block_id_y] + %46 = arith.muli %45, %c8192 overflow : index + %47 = arith.addi %46, %9 overflow : index + amdgpu.gather_to_lds %44[%47], %alloc_2[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %48 = affine.apply #map7()[%thread_id_x, %thread_id_y, %block_id_y] + %49 = arith.muli %48, %c8192 overflow : index + %50 = arith.addi %49, %9 overflow : index + amdgpu.gather_to_lds %44[%50], %alloc_2[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %51 = affine.apply #map9()[%thread_id_x, %thread_id_y, %block_id_y] + %52 = arith.muli %51, %c8192 overflow : index + %53 = arith.addi %52, %9 
overflow : index + amdgpu.gather_to_lds %44[%53], %alloc_2[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %54 = affine.apply #map11()[%thread_id_x, %thread_id_y, %block_id_y] + %55 = arith.muli %54, %c512 overflow : index + %56 = arith.addi %55, %35 overflow : index + %reinterpret_cast_11 = memref.reinterpret_cast %3 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast_12 = memref.cast %reinterpret_cast_11 : memref<2147483646xi8, strided<[1]>> to memref> + %57 = amdgpu.fat_raw_buffer_cast %cast_12 validBytes(%c2147483646_i64) cacheSwizzleStride(%c512_i14) resetOffset : memref> to memref> + amdgpu.gather_to_lds %57[%56], %alloc_0[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + rocdl.s.barrier + %58 = affine.apply #map16()[%thread_id_x, %thread_id_y] + %59 = arith.index_cast %58 : index to i32 + %60 = arith.cmpi sge, %59, %c4_i32 : i32 + %61 = arith.cmpi slt, %59, %c4_i32 : i32 + scf.if %60 { + rocdl.s.barrier + } + %62 = affine.apply #map17()[%thread_id_x] + %63 = affine.apply #map18()[%thread_id_x] + %64 = arith.xori %63, %7 : index + %65 = affine.apply #map3()[%64] + %66 = affine.apply #map19()[%thread_id_x] + %67 = affine.apply #map20()[%thread_id_x] + %68 = affine.apply #map21()[%thread_id_x] + %69 = affine.apply #map22()[%thread_id_x] + %70 = affine.apply #map23()[%thread_id_x, %thread_id_y] + %71 = affine.apply #map24()[%thread_id_x, %thread_id_y] + %72 = affine.apply #map25()[%thread_id_x, %thread_id_y] + %73 = affine.apply #map26()[%thread_id_x, %thread_id_y] + %74 = affine.apply #map27()[%thread_id_x, %thread_id_y] + %75 = affine.apply #map28()[%thread_id_x, %thread_id_y] + %76 = affine.apply #map29()[%thread_id_x, %thread_id_y] + %77 = affine.apply #map30()[%thread_id_x, %thread_id_y] + %78 = affine.apply #map31()[%thread_id_x] + %79 = arith.xori %78, %7 : index + %80 = affine.apply #map3()[%79] + %81 = arith.xori %33, %c1 : index + %82 = 
affine.apply #map32()[%thread_id_x, %81] + %83:40 = scf.for %arg5 = %c0 to %c63 step %c1 iter_args(%arg6 = %cst, %arg7 = %cst, %arg8 = %cst, %arg9 = %cst, %arg10 = %cst, %arg11 = %cst, %arg12 = %cst, %arg13 = %cst, %arg14 = %cst, %arg15 = %cst, %arg16 = %cst, %arg17 = %cst, %arg18 = %cst, %arg19 = %cst, %arg20 = %cst, %arg21 = %cst, %arg22 = %cst, %arg23 = %cst, %arg24 = %cst, %arg25 = %cst, %arg26 = %cst, %arg27 = %cst, %arg28 = %cst, %arg29 = %cst, %arg30 = %cst, %arg31 = %cst, %arg32 = %cst, %arg33 = %cst, %arg34 = %cst, %arg35 = %cst, %arg36 = %cst, %arg37 = %cst, %arg38 = %alloc_6, %arg39 = %alloc_5, %arg40 = %alloc_4, %arg41 = %alloc_3, %arg42 = %alloc_2, %arg43 = %alloc_1, %arg44 = %alloc_0, %arg45 = %alloc) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>) { + rocdl.sched.barrier 0 + rocdl.s.barrier + %582 = affine.apply #map33()[%arg5, %8] + %583 = arith.addi %13, %582 overflow : index + amdgpu.gather_to_lds %15[%583], %arg39[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %584 = arith.addi %19, %582 overflow : index + amdgpu.gather_to_lds %15[%584], %arg39[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %585 = arith.addi %24, %582 overflow : index + 
amdgpu.gather_to_lds %15[%585], %arg39[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %586 = arith.addi %29, %582 overflow : index + amdgpu.gather_to_lds %15[%586], %arg39[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %587 = affine.apply #map34()[%arg5, %34] + %588 = arith.addi %38, %587 overflow : index + amdgpu.gather_to_lds %40[%588], %arg41[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + %589 = arith.addi %42, %582 overflow : index + amdgpu.gather_to_lds %44[%589], %arg43[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %590 = arith.addi %46, %582 overflow : index + amdgpu.gather_to_lds %44[%590], %arg43[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %591 = arith.addi %49, %582 overflow : index + amdgpu.gather_to_lds %44[%591], %arg43[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %592 = arith.addi %52, %582 overflow : index + amdgpu.gather_to_lds %44[%592], %arg43[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %593 = arith.addi %55, %587 overflow : index + amdgpu.gather_to_lds %57[%593], %arg45[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + rocdl.sched.barrier 0 + amdgpu.memory_counter_wait load(10) + // --- SAFE vector.loads: A(M0,M1), Ascale(M0,M1), B(N0,N1,N4,N5), Bscale(N0,N1,N4,N5) --- + %594 = vector.load %arg38[%62, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %595 = vector.load %arg38[%66, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %598 = vector.load %arg40[%62, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %599 = vector.load %arg40[%66, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %602 = vector.load %arg42[%70, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %603 = vector.load %arg42[%71, %65] : memref<256x128xi8, #gpu.address_space>, 
vector<16xi8> + %606 = vector.load %arg42[%74, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %607 = vector.load %arg42[%75, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %610 = vector.load %arg44[%70, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %611 = vector.load %arg44[%71, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %614 = vector.load %arg44[%74, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %615 = vector.load %arg44[%75, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + // --- SAFE bitcasts --- + %618 = vector.bitcast %594 : vector<16xi8> to vector<32xf4E2M1FN> + %619 = vector.bitcast %595 : vector<16xi8> to vector<32xf4E2M1FN> + %622 = vector.bitcast %598 : vector<1xi8> to vector<1xf8E8M0FNU> + %623 = vector.bitcast %599 : vector<1xi8> to vector<1xf8E8M0FNU> + %626 = vector.bitcast %602 : vector<16xi8> to vector<32xf4E2M1FN> + %627 = vector.bitcast %603 : vector<16xi8> to vector<32xf4E2M1FN> + %630 = vector.bitcast %606 : vector<16xi8> to vector<32xf4E2M1FN> + %631 = vector.bitcast %607 : vector<16xi8> to vector<32xf4E2M1FN> + %634 = vector.bitcast %610 : vector<1xi8> to vector<1xf8E8M0FNU> + %635 = vector.bitcast %611 : vector<1xi8> to vector<1xf8E8M0FNU> + %638 = vector.bitcast %614 : vector<1xi8> to vector<1xf8E8M0FNU> + %639 = vector.bitcast %615 : vector<1xi8> to vector<1xf8E8M0FNU> + rocdl.sched.barrier 0 + rocdl.s.barrier + rocdl.sched.barrier 0 + rocdl.s.setprio 1 + // --- SAFE MFMAs: M0,M1 x N0,N1,N4,N5 (cluster 0 data only) --- + %642 = vector.extract %622[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %643 = vector.extract %634[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %644 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%643[0] * %626) + %arg6 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %645 = vector.extract %635[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %646 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%645[0] * %627) 
+ %arg7 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %651 = vector.extract %638[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %652 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%651[0] * %630) + %arg10 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %653 = vector.extract %639[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %654 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%653[0] * %631) + %arg11 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %659 = vector.extract %623[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %660 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%643[0] * %626) + %arg14 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %661 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%645[0] * %627) + %arg15 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %664 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%651[0] * %630) + %arg18 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %665 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%653[0] * %631) + %arg19 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + rocdl.s.setprio 0 + // --- DEPENDENT vector.loads: A(M2,M3), Ascale(M2,M3), B(N2,N3,N6,N7), Bscale(N2,N3,N6,N7) --- + %596 = vector.load %arg38[%67, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %597 = vector.load %arg38[%68, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %600 = vector.load %arg40[%67, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %601 = vector.load %arg40[%68, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %604 = vector.load %arg42[%72, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %605 = vector.load %arg42[%73, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %608 = vector.load %arg42[%76, %65] : 
memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %609 = vector.load %arg42[%77, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %612 = vector.load %arg44[%72, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %613 = vector.load %arg44[%73, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %616 = vector.load %arg44[%76, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %617 = vector.load %arg44[%77, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + // --- DEPENDENT bitcasts --- + %620 = vector.bitcast %596 : vector<16xi8> to vector<32xf4E2M1FN> + %621 = vector.bitcast %597 : vector<16xi8> to vector<32xf4E2M1FN> + %624 = vector.bitcast %600 : vector<1xi8> to vector<1xf8E8M0FNU> + %625 = vector.bitcast %601 : vector<1xi8> to vector<1xf8E8M0FNU> + %628 = vector.bitcast %604 : vector<16xi8> to vector<32xf4E2M1FN> + %629 = vector.bitcast %605 : vector<16xi8> to vector<32xf4E2M1FN> + %632 = vector.bitcast %608 : vector<16xi8> to vector<32xf4E2M1FN> + %633 = vector.bitcast %609 : vector<16xi8> to vector<32xf4E2M1FN> + %636 = vector.bitcast %612 : vector<1xi8> to vector<1xf8E8M0FNU> + %637 = vector.bitcast %613 : vector<1xi8> to vector<1xf8E8M0FNU> + %640 = vector.bitcast %616 : vector<1xi8> to vector<1xf8E8M0FNU> + %641 = vector.bitcast %617 : vector<1xi8> to vector<1xf8E8M0FNU> + rocdl.s.setprio 1 + // --- DEPENDENT MFMAs: M0,M1 x N2,N3,N6,N7 (cluster 1 B data) --- + %647 = vector.extract %636[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %648 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%647[0] * %628) + %arg8 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %649 = vector.extract %637[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %650 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%649[0] * %629) + %arg9 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %655 = vector.extract %640[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %656 = 
amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%655[0] * %632) + %arg12 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %657 = vector.extract %641[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %658 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%657[0] * %633) + %arg13 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %662 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%647[0] * %628) + %arg16 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %663 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%649[0] * %629) + %arg17 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %666 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%655[0] * %632) + %arg20 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %667 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%657[0] * %633) + %arg21 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + // --- DEPENDENT MFMAs: M2 x all N (cluster 1 A data) --- + %668 = vector.extract %624[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %669 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%643[0] * %626) + %arg22 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %670 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%645[0] * %627) + %arg23 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %671 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%647[0] * %628) + %arg24 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %672 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%649[0] * %629) + %arg25 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %673 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%651[0] * %630) + %arg26 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %674 = 
amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%653[0] * %631) + %arg27 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %675 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%655[0] * %632) + %arg28 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %676 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%657[0] * %633) + %arg29 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + // --- DEPENDENT MFMAs: M3 x all N (cluster 1 A data) --- + %677 = vector.extract %625[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %678 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%643[0] * %626) + %arg30 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %679 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%645[0] * %627) + %arg31 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %680 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%647[0] * %628) + %arg32 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %681 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%649[0] * %629) + %arg33 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %682 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%651[0] * %630) + %arg34 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %683 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%653[0] * %631) + %arg35 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %684 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%655[0] * %632) + %arg36 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %685 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%657[0] * %633) + %arg37 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + rocdl.s.setprio 0 + rocdl.sched.barrier 0 + rocdl.s.barrier + rocdl.sched.barrier 0 
+ rocdl.sched.barrier 0 + // --- PHASE 2 SAFE vector.loads: A(M0,M1), Ascale(M0,M1), B(N0,N1,N4,N5), Bscale(N0,N1,N4,N5) --- + %686 = vector.load %arg38[%62, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %687 = vector.load %arg38[%66, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %690 = vector.load %arg40[%62, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %691 = vector.load %arg40[%66, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %694 = vector.load %arg42[%70, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %695 = vector.load %arg42[%71, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %698 = vector.load %arg42[%74, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %699 = vector.load %arg42[%75, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %702 = vector.load %arg44[%70, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %703 = vector.load %arg44[%71, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %706 = vector.load %arg44[%74, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %707 = vector.load %arg44[%75, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + // --- PHASE 2 SAFE bitcasts --- + %710 = vector.bitcast %686 : vector<16xi8> to vector<32xf4E2M1FN> + %711 = vector.bitcast %687 : vector<16xi8> to vector<32xf4E2M1FN> + %714 = vector.bitcast %690 : vector<1xi8> to vector<1xf8E8M0FNU> + %715 = vector.bitcast %691 : vector<1xi8> to vector<1xf8E8M0FNU> + %718 = vector.bitcast %694 : vector<16xi8> to vector<32xf4E2M1FN> + %719 = vector.bitcast %695 : vector<16xi8> to vector<32xf4E2M1FN> + %722 = vector.bitcast %698 : vector<16xi8> to vector<32xf4E2M1FN> + %723 = vector.bitcast %699 : vector<16xi8> to vector<32xf4E2M1FN> + %726 = vector.bitcast %702 : vector<1xi8> to vector<1xf8E8M0FNU> + %727 = vector.bitcast %703 : vector<1xi8> to vector<1xf8E8M0FNU> + %730 = vector.bitcast %706 : vector<1xi8> to 
vector<1xf8E8M0FNU> + %731 = vector.bitcast %707 : vector<1xi8> to vector<1xf8E8M0FNU> + rocdl.sched.barrier 0 + rocdl.s.barrier + rocdl.sched.barrier 0 + rocdl.s.setprio 1 + // --- PHASE 2 SAFE MFMAs: M0,M1 x N0,N1,N4,N5 (cluster 0 data only) --- + %734 = vector.extract %714[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %735 = vector.extract %726[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %736 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%735[0] * %718) + %644 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %737 = vector.extract %727[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %738 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%737[0] * %719) + %646 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %743 = vector.extract %730[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %744 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%743[0] * %722) + %652 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %745 = vector.extract %731[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %746 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%745[0] * %723) + %654 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %751 = vector.extract %715[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %752 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%735[0] * %718) + %660 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %753 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%737[0] * %719) + %661 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %756 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%743[0] * %722) + %664 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %757 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%745[0] * %723) + %665 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + rocdl.s.setprio 0 + // --- PHASE 2 DEPENDENT 
vector.loads: A(M2,M3), Ascale(M2,M3), B(N2,N3,N6,N7), Bscale(N2,N3,N6,N7) --- + %688 = vector.load %arg38[%67, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %689 = vector.load %arg38[%68, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %692 = vector.load %arg40[%67, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %693 = vector.load %arg40[%68, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %696 = vector.load %arg42[%72, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %697 = vector.load %arg42[%73, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %700 = vector.load %arg42[%76, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %701 = vector.load %arg42[%77, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %704 = vector.load %arg44[%72, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %705 = vector.load %arg44[%73, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %708 = vector.load %arg44[%76, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %709 = vector.load %arg44[%77, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + // --- PHASE 2 DEPENDENT bitcasts --- + %712 = vector.bitcast %688 : vector<16xi8> to vector<32xf4E2M1FN> + %713 = vector.bitcast %689 : vector<16xi8> to vector<32xf4E2M1FN> + %716 = vector.bitcast %692 : vector<1xi8> to vector<1xf8E8M0FNU> + %717 = vector.bitcast %693 : vector<1xi8> to vector<1xf8E8M0FNU> + %720 = vector.bitcast %696 : vector<16xi8> to vector<32xf4E2M1FN> + %721 = vector.bitcast %697 : vector<16xi8> to vector<32xf4E2M1FN> + %724 = vector.bitcast %700 : vector<16xi8> to vector<32xf4E2M1FN> + %725 = vector.bitcast %701 : vector<16xi8> to vector<32xf4E2M1FN> + %728 = vector.bitcast %704 : vector<1xi8> to vector<1xf8E8M0FNU> + %729 = vector.bitcast %705 : vector<1xi8> to vector<1xf8E8M0FNU> + %732 = vector.bitcast %708 : vector<1xi8> to vector<1xf8E8M0FNU> + %733 = vector.bitcast 
%709 : vector<1xi8> to vector<1xf8E8M0FNU> + rocdl.s.setprio 1 + // --- PHASE 2 DEPENDENT MFMAs: M0,M1 x N2,N3,N6,N7 (cluster 1 B data) --- + %739 = vector.extract %728[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %740 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%739[0] * %720) + %648 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %741 = vector.extract %729[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %742 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%741[0] * %721) + %650 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %747 = vector.extract %732[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %748 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%747[0] * %724) + %656 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %749 = vector.extract %733[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %750 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%749[0] * %725) + %658 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %754 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%739[0] * %720) + %662 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %755 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%741[0] * %721) + %663 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %758 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%747[0] * %724) + %666 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %759 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%749[0] * %725) + %667 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + // --- PHASE 2 DEPENDENT MFMAs: M2 x all N (cluster 1 A data) --- + %760 = vector.extract %716[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %761 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%735[0] * %718) + %669 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, 
vector<4xf32> + %762 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%737[0] * %719) + %670 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %763 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%739[0] * %720) + %671 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %764 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%741[0] * %721) + %672 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %765 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%743[0] * %722) + %673 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %766 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%745[0] * %723) + %674 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %767 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%747[0] * %724) + %675 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %768 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%749[0] * %725) + %676 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + // --- PHASE 2 DEPENDENT MFMAs: M3 x all N (cluster 1 A data) --- + %769 = vector.extract %717[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %770 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%735[0] * %718) + %678 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %771 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%737[0] * %719) + %679 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %772 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%739[0] * %720) + %680 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %773 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%741[0] * %721) + %681 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %774 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%743[0] * %722) + 
%682 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %775 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%745[0] * %723) + %683 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %776 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%747[0] * %724) + %684 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %777 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%749[0] * %725) + %685 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + rocdl.s.setprio 0 + rocdl.sched.barrier 0 + scf.yield %736, %738, %740, %742, %744, %746, %748, %750, %752, %753, %754, %755, %756, %757, %758, %759, %761, %762, %763, %764, %765, %766, %767, %768, %770, %771, %772, %773, %774, %775, %776, %777, %arg39, %arg38, %arg41, %arg40, %arg43, %arg42, %arg45, %arg44 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space> + } + scf.if %61 { + rocdl.s.barrier + } + amdgpu.lds_barrier + %84 = affine.apply #map23()[%thread_id_x, %thread_id_y] + %85 = affine.apply #map22()[%thread_id_x] + %86 = vector.load %83#38[%84, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %87 = arith.xori %33, %c1 : index + %88 = affine.apply 
#map32()[%thread_id_x, %87] + %89 = vector.load %83#38[%84, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %90 = affine.apply #map24()[%thread_id_x, %thread_id_y] + %91 = vector.load %83#38[%90, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %92 = vector.load %83#38[%90, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %93 = affine.apply #map25()[%thread_id_x, %thread_id_y] + %94 = vector.load %83#38[%93, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %95 = vector.load %83#38[%93, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %96 = affine.apply #map26()[%thread_id_x, %thread_id_y] + %97 = vector.load %83#38[%96, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %98 = vector.load %83#38[%96, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %99 = affine.apply #map27()[%thread_id_x, %thread_id_y] + %100 = vector.load %83#38[%99, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %101 = vector.load %83#38[%99, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %102 = affine.apply #map28()[%thread_id_x, %thread_id_y] + %103 = vector.load %83#38[%102, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %104 = vector.load %83#38[%102, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %105 = affine.apply #map29()[%thread_id_x, %thread_id_y] + %106 = vector.load %83#38[%105, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %107 = vector.load %83#38[%105, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %108 = affine.apply #map30()[%thread_id_x, %thread_id_y] + %109 = vector.load %83#38[%108, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %110 = vector.load %83#38[%108, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %111 = affine.apply #map18()[%thread_id_x] + %112 = arith.xori %111, %7 : index + %113 = affine.apply #map3()[%112] + %114 = vector.load %83#36[%84, %113] : memref<256x128xi8, #gpu.address_space>, 
vector<16xi8> + %115 = affine.apply #map31()[%thread_id_x] + %116 = arith.xori %115, %7 : index + %117 = affine.apply #map3()[%116] + %118 = vector.load %83#36[%84, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %119 = vector.load %83#36[%90, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %120 = vector.load %83#36[%90, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %121 = vector.load %83#36[%93, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %122 = vector.load %83#36[%93, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %123 = vector.load %83#36[%96, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %124 = vector.load %83#36[%96, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %125 = vector.load %83#36[%99, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %126 = vector.load %83#36[%99, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %127 = vector.load %83#36[%102, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %128 = vector.load %83#36[%102, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %129 = vector.load %83#36[%105, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %130 = vector.load %83#36[%105, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %131 = vector.load %83#36[%108, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %132 = vector.load %83#36[%108, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %133 = affine.apply #map17()[%thread_id_x] + %134 = vector.load %83#34[%133, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %135 = vector.load %83#34[%133, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %136 = affine.apply #map19()[%thread_id_x] + %137 = vector.load %83#34[%136, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %138 = vector.load %83#34[%136, %88] : memref<256x8xi8, 
#gpu.address_space>, vector<1xi8> + %139 = affine.apply #map20()[%thread_id_x] + %140 = vector.load %83#34[%139, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %141 = vector.load %83#34[%139, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %142 = affine.apply #map21()[%thread_id_x] + %143 = vector.load %83#34[%142, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %144 = vector.load %83#34[%142, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %145 = vector.load %83#32[%133, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %146 = vector.load %83#32[%133, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %147 = vector.load %83#32[%136, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %148 = vector.load %83#32[%136, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %149 = vector.load %83#32[%139, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %150 = vector.load %83#32[%139, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %151 = vector.load %83#32[%142, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %152 = vector.load %83#32[%142, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %153 = vector.bitcast %145 : vector<16xi8> to vector<32xf4E2M1FN> + %154 = vector.bitcast %146 : vector<16xi8> to vector<32xf4E2M1FN> + %155 = vector.bitcast %147 : vector<16xi8> to vector<32xf4E2M1FN> + %156 = vector.bitcast %148 : vector<16xi8> to vector<32xf4E2M1FN> + %157 = vector.bitcast %149 : vector<16xi8> to vector<32xf4E2M1FN> + %158 = vector.bitcast %150 : vector<16xi8> to vector<32xf4E2M1FN> + %159 = vector.bitcast %151 : vector<16xi8> to vector<32xf4E2M1FN> + %160 = vector.bitcast %152 : vector<16xi8> to vector<32xf4E2M1FN> + %161 = vector.bitcast %134 : vector<1xi8> to vector<1xf8E8M0FNU> + %162 = vector.bitcast %135 : vector<1xi8> to vector<1xf8E8M0FNU> + %163 = vector.bitcast %137 : vector<1xi8> to 
vector<1xf8E8M0FNU> + %164 = vector.bitcast %138 : vector<1xi8> to vector<1xf8E8M0FNU> + %165 = vector.bitcast %140 : vector<1xi8> to vector<1xf8E8M0FNU> + %166 = vector.bitcast %141 : vector<1xi8> to vector<1xf8E8M0FNU> + %167 = vector.bitcast %143 : vector<1xi8> to vector<1xf8E8M0FNU> + %168 = vector.bitcast %144 : vector<1xi8> to vector<1xf8E8M0FNU> + %169 = vector.bitcast %114 : vector<16xi8> to vector<32xf4E2M1FN> + %170 = vector.bitcast %118 : vector<16xi8> to vector<32xf4E2M1FN> + %171 = vector.bitcast %119 : vector<16xi8> to vector<32xf4E2M1FN> + %172 = vector.bitcast %120 : vector<16xi8> to vector<32xf4E2M1FN> + %173 = vector.bitcast %121 : vector<16xi8> to vector<32xf4E2M1FN> + %174 = vector.bitcast %122 : vector<16xi8> to vector<32xf4E2M1FN> + %175 = vector.bitcast %123 : vector<16xi8> to vector<32xf4E2M1FN> + %176 = vector.bitcast %124 : vector<16xi8> to vector<32xf4E2M1FN> + %177 = vector.bitcast %125 : vector<16xi8> to vector<32xf4E2M1FN> + %178 = vector.bitcast %126 : vector<16xi8> to vector<32xf4E2M1FN> + %179 = vector.bitcast %127 : vector<16xi8> to vector<32xf4E2M1FN> + %180 = vector.bitcast %128 : vector<16xi8> to vector<32xf4E2M1FN> + %181 = vector.bitcast %129 : vector<16xi8> to vector<32xf4E2M1FN> + %182 = vector.bitcast %130 : vector<16xi8> to vector<32xf4E2M1FN> + %183 = vector.bitcast %131 : vector<16xi8> to vector<32xf4E2M1FN> + %184 = vector.bitcast %132 : vector<16xi8> to vector<32xf4E2M1FN> + %185 = vector.bitcast %86 : vector<1xi8> to vector<1xf8E8M0FNU> + %186 = vector.bitcast %89 : vector<1xi8> to vector<1xf8E8M0FNU> + %187 = vector.bitcast %91 : vector<1xi8> to vector<1xf8E8M0FNU> + %188 = vector.bitcast %92 : vector<1xi8> to vector<1xf8E8M0FNU> + %189 = vector.bitcast %94 : vector<1xi8> to vector<1xf8E8M0FNU> + %190 = vector.bitcast %95 : vector<1xi8> to vector<1xf8E8M0FNU> + %191 = vector.bitcast %97 : vector<1xi8> to vector<1xf8E8M0FNU> + %192 = vector.bitcast %98 : vector<1xi8> to vector<1xf8E8M0FNU> + %193 = vector.bitcast %100 
: vector<1xi8> to vector<1xf8E8M0FNU> + %194 = vector.bitcast %101 : vector<1xi8> to vector<1xf8E8M0FNU> + %195 = vector.bitcast %103 : vector<1xi8> to vector<1xf8E8M0FNU> + %196 = vector.bitcast %104 : vector<1xi8> to vector<1xf8E8M0FNU> + %197 = vector.bitcast %106 : vector<1xi8> to vector<1xf8E8M0FNU> + %198 = vector.bitcast %107 : vector<1xi8> to vector<1xf8E8M0FNU> + %199 = vector.bitcast %109 : vector<1xi8> to vector<1xf8E8M0FNU> + %200 = vector.bitcast %110 : vector<1xi8> to vector<1xf8E8M0FNU> + %201 = vector.extract %161[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %202 = vector.extract %185[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %203 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%202[0] * %169) + %83#0 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %204 = vector.extract %162[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %205 = vector.extract %186[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %206 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%205[0] * %170) + %203 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %207 = vector.extract %187[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %208 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%207[0] * %171) + %83#1 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %209 = vector.extract %188[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %210 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%209[0] * %172) + %208 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %211 = vector.extract %189[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %212 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%211[0] * %173) + %83#2 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %213 = vector.extract %190[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %214 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%213[0] * %174) + %212 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, 
vector<32xf4E2M1FN>, vector<4xf32> + %215 = vector.extract %191[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %216 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%215[0] * %175) + %83#3 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %217 = vector.extract %192[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %218 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%217[0] * %176) + %216 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %219 = vector.extract %193[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %220 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%219[0] * %177) + %83#4 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %221 = vector.extract %194[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %222 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%221[0] * %178) + %220 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %223 = vector.extract %195[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %224 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%223[0] * %179) + %83#5 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %225 = vector.extract %196[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %226 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%225[0] * %180) + %224 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %227 = vector.extract %197[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %228 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%227[0] * %181) + %83#6 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %229 = vector.extract %198[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %230 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%229[0] * %182) + %228 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %231 = vector.extract %199[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %232 = amdgpu.scaled_mfma 16x16x128 (%201[0] * 
%153) * (%231[0] * %183) + %83#7 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %233 = vector.extract %200[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %234 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%233[0] * %184) + %232 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %235 = vector.extract %163[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %236 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%202[0] * %169) + %83#8 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %237 = vector.extract %164[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %238 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%205[0] * %170) + %236 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %239 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%207[0] * %171) + %83#9 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %240 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%209[0] * %172) + %239 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %241 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%211[0] * %173) + %83#10 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %242 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%213[0] * %174) + %241 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %243 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%215[0] * %175) + %83#11 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %244 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%217[0] * %176) + %243 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %245 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%219[0] * %177) + %83#12 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %246 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * 
(%221[0] * %178) + %245 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %247 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%223[0] * %179) + %83#13 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %248 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%225[0] * %180) + %247 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %249 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%227[0] * %181) + %83#14 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %250 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%229[0] * %182) + %249 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %251 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%231[0] * %183) + %83#15 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %252 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%233[0] * %184) + %251 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %253 = vector.extract %165[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %254 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%202[0] * %169) + %83#16 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %255 = vector.extract %166[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %256 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%205[0] * %170) + %254 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %257 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%207[0] * %171) + %83#17 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %258 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%209[0] * %172) + %257 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %259 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%211[0] * %173) + %83#18 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, 
vector<32xf4E2M1FN>, vector<4xf32> + %260 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%213[0] * %174) + %259 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %261 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%215[0] * %175) + %83#19 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %262 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%217[0] * %176) + %261 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %263 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%219[0] * %177) + %83#20 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %264 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%221[0] * %178) + %263 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %265 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%223[0] * %179) + %83#21 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %266 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%225[0] * %180) + %265 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %267 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%227[0] * %181) + %83#22 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %268 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%229[0] * %182) + %267 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %269 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%231[0] * %183) + %83#23 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %270 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%233[0] * %184) + %269 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %271 = vector.extract %167[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %272 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%202[0] * %169) + %83#24 : f8E8M0FNU, 
vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %273 = vector.extract %168[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %274 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%205[0] * %170) + %272 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %275 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%207[0] * %171) + %83#25 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %276 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%209[0] * %172) + %275 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %277 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%211[0] * %173) + %83#26 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %278 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%213[0] * %174) + %277 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %279 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%215[0] * %175) + %83#27 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %280 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%217[0] * %176) + %279 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %281 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%219[0] * %177) + %83#28 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %282 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%221[0] * %178) + %281 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %283 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%223[0] * %179) + %83#29 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %284 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%225[0] * %180) + %283 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %285 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%227[0] * %181) + %83#30 
: f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %286 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%229[0] * %182) + %285 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %287 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%231[0] * %183) + %83#31 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %288 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%233[0] * %184) + %287 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %289 = vector.extract_strided_slice %206 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %290 = affine.apply #map35()[%block_id_x] + %291 = affine.apply #map35()[%block_id_y] + %292 = affine.apply #map36()[%thread_id_x] + %293 = arith.muli %290, %c57344 overflow : index + %294 = arith.muli %292, %c57344 overflow : index + %295 = arith.addi %293, %291 overflow : index + %296 = arith.addi %294, %84 overflow : index + %reinterpret_cast_13 = memref.reinterpret_cast %4 to offset: [%295], sizes: [536870910], strides: [1] : memref to memref<536870910xf32, strided<[1], offset: ?>> + %cast_14 = memref.cast %reinterpret_cast_13 : memref<536870910xf32, strided<[1], offset: ?>> to memref> + %297 = amdgpu.fat_raw_buffer_cast %cast_14 validBytes(%c2147483643_i64) resetOffset : memref> to memref> + vector.store %289, %297[%296] : memref>, vector<1xf32> + %298 = vector.extract_strided_slice %206 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %299 = affine.apply #map37()[%thread_id_x] + %300 = arith.muli %299, %c57344 overflow : index + %301 = arith.addi %300, %84 overflow : index + vector.store %298, %297[%301] : memref>, vector<1xf32> + %302 = vector.extract_strided_slice %206 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %303 = affine.apply #map38()[%thread_id_x] + %304 = arith.muli %303, %c57344 overflow : index + %305 = 
arith.addi %304, %84 overflow : index + vector.store %302, %297[%305] : memref>, vector<1xf32> + %306 = vector.extract_strided_slice %206 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %307 = affine.apply #map39()[%thread_id_x] + %308 = arith.muli %307, %c57344 overflow : index + %309 = arith.addi %308, %84 overflow : index + vector.store %306, %297[%309] : memref>, vector<1xf32> + %310 = vector.extract_strided_slice %210 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %311 = arith.addi %294, %90 overflow : index + vector.store %310, %297[%311] : memref>, vector<1xf32> + %312 = vector.extract_strided_slice %210 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %313 = arith.addi %300, %90 overflow : index + vector.store %312, %297[%313] : memref>, vector<1xf32> + %314 = vector.extract_strided_slice %210 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %315 = arith.addi %304, %90 overflow : index + vector.store %314, %297[%315] : memref>, vector<1xf32> + %316 = vector.extract_strided_slice %210 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %317 = arith.addi %308, %90 overflow : index + vector.store %316, %297[%317] : memref>, vector<1xf32> + %318 = vector.extract_strided_slice %214 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %319 = arith.addi %294, %93 overflow : index + vector.store %318, %297[%319] : memref>, vector<1xf32> + %320 = vector.extract_strided_slice %214 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %321 = arith.addi %300, %93 overflow : index + vector.store %320, %297[%321] : memref>, vector<1xf32> + %322 = vector.extract_strided_slice %214 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %323 = arith.addi %304, %93 overflow : index + vector.store %322, %297[%323] : memref>, vector<1xf32> + %324 = 
vector.extract_strided_slice %214 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %325 = arith.addi %308, %93 overflow : index + vector.store %324, %297[%325] : memref>, vector<1xf32> + %326 = vector.extract_strided_slice %218 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %327 = arith.addi %294, %96 overflow : index + vector.store %326, %297[%327] : memref>, vector<1xf32> + %328 = vector.extract_strided_slice %218 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %329 = arith.addi %300, %96 overflow : index + vector.store %328, %297[%329] : memref>, vector<1xf32> + %330 = vector.extract_strided_slice %218 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %331 = arith.addi %304, %96 overflow : index + vector.store %330, %297[%331] : memref>, vector<1xf32> + %332 = vector.extract_strided_slice %218 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %333 = arith.addi %308, %96 overflow : index + vector.store %332, %297[%333] : memref>, vector<1xf32> + %334 = vector.extract_strided_slice %222 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %335 = arith.addi %294, %99 overflow : index + vector.store %334, %297[%335] : memref>, vector<1xf32> + %336 = vector.extract_strided_slice %222 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %337 = arith.addi %300, %99 overflow : index + vector.store %336, %297[%337] : memref>, vector<1xf32> + %338 = vector.extract_strided_slice %222 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %339 = arith.addi %304, %99 overflow : index + vector.store %338, %297[%339] : memref>, vector<1xf32> + %340 = vector.extract_strided_slice %222 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %341 = arith.addi %308, %99 overflow : index + vector.store %340, %297[%341] : memref>, 
vector<1xf32> + %342 = vector.extract_strided_slice %226 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %343 = arith.addi %294, %102 overflow : index + vector.store %342, %297[%343] : memref>, vector<1xf32> + %344 = vector.extract_strided_slice %226 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %345 = arith.addi %300, %102 overflow : index + vector.store %344, %297[%345] : memref>, vector<1xf32> + %346 = vector.extract_strided_slice %226 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %347 = arith.addi %304, %102 overflow : index + vector.store %346, %297[%347] : memref>, vector<1xf32> + %348 = vector.extract_strided_slice %226 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %349 = arith.addi %308, %102 overflow : index + vector.store %348, %297[%349] : memref>, vector<1xf32> + %350 = vector.extract_strided_slice %230 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %351 = arith.addi %294, %105 overflow : index + vector.store %350, %297[%351] : memref>, vector<1xf32> + %352 = vector.extract_strided_slice %230 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %353 = arith.addi %300, %105 overflow : index + vector.store %352, %297[%353] : memref>, vector<1xf32> + %354 = vector.extract_strided_slice %230 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %355 = arith.addi %304, %105 overflow : index + vector.store %354, %297[%355] : memref>, vector<1xf32> + %356 = vector.extract_strided_slice %230 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %357 = arith.addi %308, %105 overflow : index + vector.store %356, %297[%357] : memref>, vector<1xf32> + %358 = vector.extract_strided_slice %234 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %359 = arith.addi %294, %108 overflow : index + vector.store 
%358, %297[%359] : memref>, vector<1xf32> + %360 = vector.extract_strided_slice %234 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %361 = arith.addi %300, %108 overflow : index + vector.store %360, %297[%361] : memref>, vector<1xf32> + %362 = vector.extract_strided_slice %234 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %363 = arith.addi %304, %108 overflow : index + vector.store %362, %297[%363] : memref>, vector<1xf32> + %364 = vector.extract_strided_slice %234 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %365 = arith.addi %308, %108 overflow : index + vector.store %364, %297[%365] : memref>, vector<1xf32> + %366 = vector.extract_strided_slice %238 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %367 = affine.apply #map40()[%thread_id_x] + %368 = arith.muli %367, %c57344 overflow : index + %369 = arith.addi %368, %84 overflow : index + vector.store %366, %297[%369] : memref>, vector<1xf32> + %370 = vector.extract_strided_slice %238 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %371 = affine.apply #map41()[%thread_id_x] + %372 = arith.muli %371, %c57344 overflow : index + %373 = arith.addi %372, %84 overflow : index + vector.store %370, %297[%373] : memref>, vector<1xf32> + %374 = vector.extract_strided_slice %238 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %375 = affine.apply #map42()[%thread_id_x] + %376 = arith.muli %375, %c57344 overflow : index + %377 = arith.addi %376, %84 overflow : index + vector.store %374, %297[%377] : memref>, vector<1xf32> + %378 = vector.extract_strided_slice %238 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %379 = affine.apply #map43()[%thread_id_x] + %380 = arith.muli %379, %c57344 overflow : index + %381 = arith.addi %380, %84 overflow : index + vector.store %378, %297[%381] : memref>, vector<1xf32> + 
%382 = vector.extract_strided_slice %240 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %383 = arith.addi %368, %90 overflow : index + vector.store %382, %297[%383] : memref>, vector<1xf32> + %384 = vector.extract_strided_slice %240 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %385 = arith.addi %372, %90 overflow : index + vector.store %384, %297[%385] : memref>, vector<1xf32> + %386 = vector.extract_strided_slice %240 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %387 = arith.addi %376, %90 overflow : index + vector.store %386, %297[%387] : memref>, vector<1xf32> + %388 = vector.extract_strided_slice %240 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %389 = arith.addi %380, %90 overflow : index + vector.store %388, %297[%389] : memref>, vector<1xf32> + %390 = vector.extract_strided_slice %242 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %391 = arith.addi %368, %93 overflow : index + vector.store %390, %297[%391] : memref>, vector<1xf32> + %392 = vector.extract_strided_slice %242 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %393 = arith.addi %372, %93 overflow : index + vector.store %392, %297[%393] : memref>, vector<1xf32> + %394 = vector.extract_strided_slice %242 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %395 = arith.addi %376, %93 overflow : index + vector.store %394, %297[%395] : memref>, vector<1xf32> + %396 = vector.extract_strided_slice %242 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %397 = arith.addi %380, %93 overflow : index + vector.store %396, %297[%397] : memref>, vector<1xf32> + %398 = vector.extract_strided_slice %244 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %399 = arith.addi %368, %96 overflow : index + vector.store %398, %297[%399] : memref>, 
vector<1xf32> + %400 = vector.extract_strided_slice %244 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %401 = arith.addi %372, %96 overflow : index + vector.store %400, %297[%401] : memref>, vector<1xf32> + %402 = vector.extract_strided_slice %244 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %403 = arith.addi %376, %96 overflow : index + vector.store %402, %297[%403] : memref>, vector<1xf32> + %404 = vector.extract_strided_slice %244 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %405 = arith.addi %380, %96 overflow : index + vector.store %404, %297[%405] : memref>, vector<1xf32> + %406 = vector.extract_strided_slice %246 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %407 = arith.addi %368, %99 overflow : index + vector.store %406, %297[%407] : memref>, vector<1xf32> + %408 = vector.extract_strided_slice %246 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %409 = arith.addi %372, %99 overflow : index + vector.store %408, %297[%409] : memref>, vector<1xf32> + %410 = vector.extract_strided_slice %246 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %411 = arith.addi %376, %99 overflow : index + vector.store %410, %297[%411] : memref>, vector<1xf32> + %412 = vector.extract_strided_slice %246 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %413 = arith.addi %380, %99 overflow : index + vector.store %412, %297[%413] : memref>, vector<1xf32> + %414 = vector.extract_strided_slice %248 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %415 = arith.addi %368, %102 overflow : index + vector.store %414, %297[%415] : memref>, vector<1xf32> + %416 = vector.extract_strided_slice %248 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %417 = arith.addi %372, %102 overflow : index + vector.store %416, 
%297[%417] : memref>, vector<1xf32> + %418 = vector.extract_strided_slice %248 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %419 = arith.addi %376, %102 overflow : index + vector.store %418, %297[%419] : memref>, vector<1xf32> + %420 = vector.extract_strided_slice %248 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %421 = arith.addi %380, %102 overflow : index + vector.store %420, %297[%421] : memref>, vector<1xf32> + %422 = vector.extract_strided_slice %250 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %423 = arith.addi %368, %105 overflow : index + vector.store %422, %297[%423] : memref>, vector<1xf32> + %424 = vector.extract_strided_slice %250 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %425 = arith.addi %372, %105 overflow : index + vector.store %424, %297[%425] : memref>, vector<1xf32> + %426 = vector.extract_strided_slice %250 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %427 = arith.addi %376, %105 overflow : index + vector.store %426, %297[%427] : memref>, vector<1xf32> + %428 = vector.extract_strided_slice %250 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %429 = arith.addi %380, %105 overflow : index + vector.store %428, %297[%429] : memref>, vector<1xf32> + %430 = vector.extract_strided_slice %252 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %431 = arith.addi %368, %108 overflow : index + vector.store %430, %297[%431] : memref>, vector<1xf32> + %432 = vector.extract_strided_slice %252 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %433 = arith.addi %372, %108 overflow : index + vector.store %432, %297[%433] : memref>, vector<1xf32> + %434 = vector.extract_strided_slice %252 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %435 = arith.addi %376, %108 overflow : 
index + vector.store %434, %297[%435] : memref>, vector<1xf32> + %436 = vector.extract_strided_slice %252 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %437 = arith.addi %380, %108 overflow : index + vector.store %436, %297[%437] : memref>, vector<1xf32> + %438 = vector.extract_strided_slice %256 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %439 = affine.apply #map44()[%thread_id_x] + %440 = arith.muli %439, %c57344 overflow : index + %441 = arith.addi %440, %84 overflow : index + vector.store %438, %297[%441] : memref>, vector<1xf32> + %442 = vector.extract_strided_slice %256 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %443 = affine.apply #map45()[%thread_id_x] + %444 = arith.muli %443, %c57344 overflow : index + %445 = arith.addi %444, %84 overflow : index + vector.store %442, %297[%445] : memref>, vector<1xf32> + %446 = vector.extract_strided_slice %256 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %447 = affine.apply #map46()[%thread_id_x] + %448 = arith.muli %447, %c57344 overflow : index + %449 = arith.addi %448, %84 overflow : index + vector.store %446, %297[%449] : memref>, vector<1xf32> + %450 = vector.extract_strided_slice %256 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %451 = affine.apply #map47()[%thread_id_x] + %452 = arith.muli %451, %c57344 overflow : index + %453 = arith.addi %452, %84 overflow : index + vector.store %450, %297[%453] : memref>, vector<1xf32> + %454 = vector.extract_strided_slice %258 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %455 = arith.addi %440, %90 overflow : index + vector.store %454, %297[%455] : memref>, vector<1xf32> + %456 = vector.extract_strided_slice %258 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %457 = arith.addi %444, %90 overflow : index + vector.store %456, %297[%457] : 
memref>, vector<1xf32> + %458 = vector.extract_strided_slice %258 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %459 = arith.addi %448, %90 overflow : index + vector.store %458, %297[%459] : memref>, vector<1xf32> + %460 = vector.extract_strided_slice %258 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %461 = arith.addi %452, %90 overflow : index + vector.store %460, %297[%461] : memref>, vector<1xf32> + %462 = vector.extract_strided_slice %260 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %463 = arith.addi %440, %93 overflow : index + vector.store %462, %297[%463] : memref>, vector<1xf32> + %464 = vector.extract_strided_slice %260 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %465 = arith.addi %444, %93 overflow : index + vector.store %464, %297[%465] : memref>, vector<1xf32> + %466 = vector.extract_strided_slice %260 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %467 = arith.addi %448, %93 overflow : index + vector.store %466, %297[%467] : memref>, vector<1xf32> + %468 = vector.extract_strided_slice %260 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %469 = arith.addi %452, %93 overflow : index + vector.store %468, %297[%469] : memref>, vector<1xf32> + %470 = vector.extract_strided_slice %262 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %471 = arith.addi %440, %96 overflow : index + vector.store %470, %297[%471] : memref>, vector<1xf32> + %472 = vector.extract_strided_slice %262 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %473 = arith.addi %444, %96 overflow : index + vector.store %472, %297[%473] : memref>, vector<1xf32> + %474 = vector.extract_strided_slice %262 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %475 = arith.addi %448, %96 overflow : index + vector.store 
%474, %297[%475] : memref>, vector<1xf32> + %476 = vector.extract_strided_slice %262 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %477 = arith.addi %452, %96 overflow : index + vector.store %476, %297[%477] : memref>, vector<1xf32> + %478 = vector.extract_strided_slice %264 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %479 = arith.addi %440, %99 overflow : index + vector.store %478, %297[%479] : memref>, vector<1xf32> + %480 = vector.extract_strided_slice %264 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %481 = arith.addi %444, %99 overflow : index + vector.store %480, %297[%481] : memref>, vector<1xf32> + %482 = vector.extract_strided_slice %264 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %483 = arith.addi %448, %99 overflow : index + vector.store %482, %297[%483] : memref>, vector<1xf32> + %484 = vector.extract_strided_slice %264 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %485 = arith.addi %452, %99 overflow : index + vector.store %484, %297[%485] : memref>, vector<1xf32> + %486 = vector.extract_strided_slice %266 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %487 = arith.addi %440, %102 overflow : index + vector.store %486, %297[%487] : memref>, vector<1xf32> + %488 = vector.extract_strided_slice %266 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %489 = arith.addi %444, %102 overflow : index + vector.store %488, %297[%489] : memref>, vector<1xf32> + %490 = vector.extract_strided_slice %266 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %491 = arith.addi %448, %102 overflow : index + vector.store %490, %297[%491] : memref>, vector<1xf32> + %492 = vector.extract_strided_slice %266 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %493 = arith.addi %452, %102 overflow : 
index + vector.store %492, %297[%493] : memref>, vector<1xf32> + %494 = vector.extract_strided_slice %268 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %495 = arith.addi %440, %105 overflow : index + vector.store %494, %297[%495] : memref>, vector<1xf32> + %496 = vector.extract_strided_slice %268 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %497 = arith.addi %444, %105 overflow : index + vector.store %496, %297[%497] : memref>, vector<1xf32> + %498 = vector.extract_strided_slice %268 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %499 = arith.addi %448, %105 overflow : index + vector.store %498, %297[%499] : memref>, vector<1xf32> + %500 = vector.extract_strided_slice %268 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %501 = arith.addi %452, %105 overflow : index + vector.store %500, %297[%501] : memref>, vector<1xf32> + %502 = vector.extract_strided_slice %270 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %503 = arith.addi %440, %108 overflow : index + vector.store %502, %297[%503] : memref>, vector<1xf32> + %504 = vector.extract_strided_slice %270 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %505 = arith.addi %444, %108 overflow : index + vector.store %504, %297[%505] : memref>, vector<1xf32> + %506 = vector.extract_strided_slice %270 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %507 = arith.addi %448, %108 overflow : index + vector.store %506, %297[%507] : memref>, vector<1xf32> + %508 = vector.extract_strided_slice %270 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %509 = arith.addi %452, %108 overflow : index + vector.store %508, %297[%509] : memref>, vector<1xf32> + %510 = vector.extract_strided_slice %274 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %511 = 
affine.apply #map48()[%thread_id_x] + %512 = arith.muli %511, %c57344 overflow : index + %513 = arith.addi %512, %84 overflow : index + vector.store %510, %297[%513] : memref>, vector<1xf32> + %514 = vector.extract_strided_slice %274 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %515 = affine.apply #map49()[%thread_id_x] + %516 = arith.muli %515, %c57344 overflow : index + %517 = arith.addi %516, %84 overflow : index + vector.store %514, %297[%517] : memref>, vector<1xf32> + %518 = vector.extract_strided_slice %274 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %519 = affine.apply #map50()[%thread_id_x] + %520 = arith.muli %519, %c57344 overflow : index + %521 = arith.addi %520, %84 overflow : index + vector.store %518, %297[%521] : memref>, vector<1xf32> + %522 = vector.extract_strided_slice %274 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %523 = affine.apply #map51()[%thread_id_x] + %524 = arith.muli %523, %c57344 overflow : index + %525 = arith.addi %524, %84 overflow : index + vector.store %522, %297[%525] : memref>, vector<1xf32> + %526 = vector.extract_strided_slice %276 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %527 = arith.addi %512, %90 overflow : index + vector.store %526, %297[%527] : memref>, vector<1xf32> + %528 = vector.extract_strided_slice %276 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %529 = arith.addi %516, %90 overflow : index + vector.store %528, %297[%529] : memref>, vector<1xf32> + %530 = vector.extract_strided_slice %276 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %531 = arith.addi %520, %90 overflow : index + vector.store %530, %297[%531] : memref>, vector<1xf32> + %532 = vector.extract_strided_slice %276 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %533 = arith.addi %524, %90 overflow : index + 
vector.store %532, %297[%533] : memref>, vector<1xf32> + %534 = vector.extract_strided_slice %278 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %535 = arith.addi %512, %93 overflow : index + vector.store %534, %297[%535] : memref>, vector<1xf32> + %536 = vector.extract_strided_slice %278 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %537 = arith.addi %516, %93 overflow : index + vector.store %536, %297[%537] : memref>, vector<1xf32> + %538 = vector.extract_strided_slice %278 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %539 = arith.addi %520, %93 overflow : index + vector.store %538, %297[%539] : memref>, vector<1xf32> + %540 = vector.extract_strided_slice %278 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %541 = arith.addi %524, %93 overflow : index + vector.store %540, %297[%541] : memref>, vector<1xf32> + %542 = vector.extract_strided_slice %280 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %543 = arith.addi %512, %96 overflow : index + vector.store %542, %297[%543] : memref>, vector<1xf32> + %544 = vector.extract_strided_slice %280 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %545 = arith.addi %516, %96 overflow : index + vector.store %544, %297[%545] : memref>, vector<1xf32> + %546 = vector.extract_strided_slice %280 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %547 = arith.addi %520, %96 overflow : index + vector.store %546, %297[%547] : memref>, vector<1xf32> + %548 = vector.extract_strided_slice %280 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %549 = arith.addi %524, %96 overflow : index + vector.store %548, %297[%549] : memref>, vector<1xf32> + %550 = vector.extract_strided_slice %282 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %551 = arith.addi %512, %99 
overflow : index + vector.store %550, %297[%551] : memref>, vector<1xf32> + %552 = vector.extract_strided_slice %282 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %553 = arith.addi %516, %99 overflow : index + vector.store %552, %297[%553] : memref>, vector<1xf32> + %554 = vector.extract_strided_slice %282 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %555 = arith.addi %520, %99 overflow : index + vector.store %554, %297[%555] : memref>, vector<1xf32> + %556 = vector.extract_strided_slice %282 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %557 = arith.addi %524, %99 overflow : index + vector.store %556, %297[%557] : memref>, vector<1xf32> + %558 = vector.extract_strided_slice %284 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %559 = arith.addi %512, %102 overflow : index + vector.store %558, %297[%559] : memref>, vector<1xf32> + %560 = vector.extract_strided_slice %284 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %561 = arith.addi %516, %102 overflow : index + vector.store %560, %297[%561] : memref>, vector<1xf32> + %562 = vector.extract_strided_slice %284 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %563 = arith.addi %520, %102 overflow : index + vector.store %562, %297[%563] : memref>, vector<1xf32> + %564 = vector.extract_strided_slice %284 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %565 = arith.addi %524, %102 overflow : index + vector.store %564, %297[%565] : memref>, vector<1xf32> + %566 = vector.extract_strided_slice %286 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %567 = arith.addi %512, %105 overflow : index + vector.store %566, %297[%567] : memref>, vector<1xf32> + %568 = vector.extract_strided_slice %286 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %569 = 
arith.addi %516, %105 overflow : index + vector.store %568, %297[%569] : memref>, vector<1xf32> + %570 = vector.extract_strided_slice %286 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %571 = arith.addi %520, %105 overflow : index + vector.store %570, %297[%571] : memref>, vector<1xf32> + %572 = vector.extract_strided_slice %286 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %573 = arith.addi %524, %105 overflow : index + vector.store %572, %297[%573] : memref>, vector<1xf32> + %574 = vector.extract_strided_slice %288 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %575 = arith.addi %512, %108 overflow : index + vector.store %574, %297[%575] : memref>, vector<1xf32> + %576 = vector.extract_strided_slice %288 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %577 = arith.addi %516, %108 overflow : index + vector.store %576, %297[%577] : memref>, vector<1xf32> + %578 = vector.extract_strided_slice %288 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %579 = arith.addi %520, %108 overflow : index + vector.store %578, %297[%579] : memref>, vector<1xf32> + %580 = vector.extract_strided_slice %288 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %581 = arith.addi %524, %108 overflow : index + vector.store %580, %297[%581] : memref>, vector<1xf32> + return + } + } + } + func.func @isolated_benchmark$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view, %arg4: !hal.buffer_view, %arg5: !hal.fence, %arg6: !hal.fence) -> !hal.buffer_view { + %0 = hal.tensor.import wait(%arg5) => %arg0 : !hal.buffer_view -> tensor<4096x8192xi8> + %1 = hal.tensor.import wait(%arg5) => %arg1 : !hal.buffer_view -> tensor<4096x512xi8> + %2 = hal.tensor.import wait(%arg5) => %arg2 : !hal.buffer_view -> tensor<57344x8192xi8> + %3 = hal.tensor.import wait(%arg5) => 
%arg3 : !hal.buffer_view -> tensor<57344x512xi8> + %4 = hal.tensor.import wait(%arg5) => %arg4 : !hal.buffer_view -> tensor<4096x57344xf32> + %5 = flow.dispatch @gemm::@gemm(%0, %1, %2, %3, %4) : (tensor<4096x8192xi8>, tensor<4096x512xi8>, tensor<57344x8192xi8>, tensor<57344x512xi8>, tensor<4096x57344xf32>) -> %4 + %6 = hal.tensor.barrier join(%5 : tensor<4096x57344xf32>) => %arg6 : !hal.fence + %7 = hal.tensor.export %6 : tensor<4096x57344xf32> -> !hal.buffer_view + return %7 : !hal.buffer_view + } + } + """ + + mlir_claude = """ + #map = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8) floordiv 256) * 256)> + #map1 = affine_map<()[s0] -> ((s0 floordiv 8) mod 8)> + #map2 = affine_map<()[s0] -> (s0 mod 8)> + #map3 = affine_map<()[s0] -> (s0 * 16)> + #map4 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64) floordiv 32) * 256)> + #map5 = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 64) floordiv 256) * 256 + 64)> + #map6 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64 + 8) floordiv 32) * 256 + 64)> + #map7 = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 128) floordiv 256) * 256 + 128)> + #map8 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64 + 16) floordiv 32) * 256 + 128)> + #map9 = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 192) floordiv 256) * 256 + 192)> + #map10 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64 + 24) floordiv 32) * 256 + 192)> + #map11 = affine_map<()[s0, s1, s2] -> (s1 * 128 + s2 * 256 + s0 floordiv 2 - ((s1 * 128 + s0 floordiv 2) floordiv 256) * 256)> + #map12 = affine_map<()[s0] -> ((s0 floordiv 2) mod 2)> + #map13 = affine_map<()[s0] -> (s0 mod 2)> + #map14 = affine_map<()[s0] -> (s0 * 4)> + #map15 = 
affine_map<()[s0, s1] -> (s1 * 128 + (s0 floordiv 64) * 32 - ((s1 * 4 + s0 floordiv 64) floordiv 8) * 256)> + #map16 = affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 64)> + #map17 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64)> + #map18 = affine_map<()[s0] -> ((s0 mod 64) floordiv 16)> + #map19 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64 + 16)> + #map20 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64 + 32)> + #map21 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64 + 48)> + #map22 = affine_map<()[s0] -> (s0 * 4 + (s0 mod 64) floordiv 16 - (s0 floordiv 2) * 8)> + #map23 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16)> + #map24 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 16)> + #map25 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 32)> + #map26 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 48)> + #map27 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 64)> + #map28 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 80)> + #map29 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 96)> + #map30 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 112)> + #map31 = affine_map<()[s0] -> ((s0 mod 64) floordiv 16 + 4)> + #map32 = affine_map<()[s0, s1] -> (s1 * 4 + (s0 mod 64) floordiv 16)> + #map33 = affine_map<()[s0, s1] -> (s0 * 128 + s1 * 16 + 128)> + #map34 = affine_map<()[s0, s1] -> (s0 * 8 + s1 * 4 + 8)> + #map35 = affine_map<()[s0] -> (s0 * 256)> + #map36 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4)> + #map37 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 1)> + #map38 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 2)> + #map39 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 3)> + #map40 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 
64) floordiv 16) * 4 + 16)> + #map41 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 17)> + #map42 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 18)> + #map43 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 19)> + #map44 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 32)> + #map45 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 33)> + #map46 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 34)> + #map47 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 35)> + #map48 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 48)> + #map49 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 49)> + #map50 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 50)> + #map51 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 51)> + #translation = #iree_codegen.translation_info + module attributes {transform.with_named_sequence} { + stream.executable private @gemm { + stream.executable.export public @gemm workgroups() -> (index, index, index) { + %c16 = arith.constant 16 : index + %c224 = arith.constant 224 : index + %c1 = arith.constant 1 : index + stream.return %c16, %c224, %c1 : index, index, index + } + builtin.module { + func.func @gemm(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: !stream.binding, %arg3: !stream.binding, %arg4: !stream.binding) attributes {translation_info = #translation} { + %c4_i32 = arith.constant 4 : i32 + %c512_i14 = arith.constant 512 : i14 + %c-8192_i14 = arith.constant -8192 : i14 + %c2147483643_i64 = arith.constant 2147483643 : i64 + %c57344 = arith.constant 57344 : index + %c63 = arith.constant 63 : index + %c512 = arith.constant 512 : index + %c2147483646_i64 = arith.constant 2147483646 : i64 + %c8192 = 
arith.constant 8192 : index + %c1 = arith.constant 1 : index + %cst = arith.constant dense<0.000000e+00> : vector<4xf32> + %c0 = arith.constant 0 : index + %0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> memref + %1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> memref + %2 = stream.binding.subspan %arg2[%c0] : !stream.binding -> memref + %3 = stream.binding.subspan %arg3[%c0] : !stream.binding -> memref + %4 = stream.binding.subspan %arg4[%c0] : !stream.binding -> memref + %block_id_x = gpu.block_id x upper_bound 16 + %block_id_y = gpu.block_id y upper_bound 224 + %thread_id_x = gpu.thread_id x upper_bound 256 + %thread_id_y = gpu.thread_id y upper_bound 2 + %alloc = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_0 = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_1 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %alloc_2 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %alloc_3 = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_4 = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_5 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %alloc_6 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %c32_idx = arith.constant 32 : index + %c128_idx = arith.constant 128 : index + %c262144 = arith.constant 262144 : index + %c65536 = arith.constant 65536 : index + %is_cluster0 = arith.cmpi eq, %thread_id_y, %c0 : index + %5 = affine.apply #map()[%thread_id_x, %thread_id_y, %block_id_x] + %6 = affine.apply #map1()[%thread_id_x] + %7 = affine.apply #map2()[%thread_id_x] + %8 = arith.xori %7, %6 : index + %9 = affine.apply #map3()[%8] + %10 = affine.apply #map4()[%thread_id_x, %thread_id_y] + %11 = gpu.subgroup_broadcast %10, first_active_lane : index + %12 = gpu.subgroup_broadcast %c0, first_active_lane : index + %13 = arith.muli %5, %c8192 overflow : index + %14 = arith.addi %13, %9 overflow : index + %reinterpret_cast = memref.reinterpret_cast %0 to 
offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast = memref.cast %reinterpret_cast : memref<2147483646xi8, strided<[1]>> to memref> + %15 = amdgpu.fat_raw_buffer_cast %cast validBytes(%c2147483646_i64) cacheSwizzleStride(%c-8192_i14) resetOffset : memref> to memref> + // --- Address computations (all waves) --- + %16 = affine.apply #map5()[%thread_id_x, %thread_id_y, %block_id_x] + %17 = affine.apply #map6()[%thread_id_x, %thread_id_y] + %18 = gpu.subgroup_broadcast %17, first_active_lane : index + %19 = arith.muli %16, %c8192 overflow : index + %20 = arith.addi %19, %9 overflow : index + %21 = affine.apply #map7()[%thread_id_x, %thread_id_y, %block_id_x] + %22 = affine.apply #map8()[%thread_id_x, %thread_id_y] + %23 = gpu.subgroup_broadcast %22, first_active_lane : index + %24 = arith.muli %21, %c8192 overflow : index + %25 = arith.addi %24, %9 overflow : index + %26 = affine.apply #map9()[%thread_id_x, %thread_id_y, %block_id_x] + %27 = affine.apply #map10()[%thread_id_x, %thread_id_y] + %28 = gpu.subgroup_broadcast %27, first_active_lane : index + %29 = arith.muli %26, %c8192 overflow : index + %30 = arith.addi %29, %9 overflow : index + %31 = affine.apply #map11()[%thread_id_x, %thread_id_y, %block_id_x] + %32 = affine.apply #map12()[%thread_id_x] + %33 = affine.apply #map13()[%thread_id_x] + %34 = arith.xori %33, %32 : index + %35 = affine.apply #map14()[%34] + %36 = affine.apply #map15()[%thread_id_x, %thread_id_y] + %37 = gpu.subgroup_broadcast %36, first_active_lane : index + %38 = arith.muli %31, %c512 overflow : index + %39 = arith.addi %38, %35 overflow : index + %reinterpret_cast_7 = memref.reinterpret_cast %1 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast_8 = memref.cast %reinterpret_cast_7 : memref<2147483646xi8, strided<[1]>> to memref> + %40 = amdgpu.fat_raw_buffer_cast %cast_8 validBytes(%c2147483646_i64) cacheSwizzleStride(%c512_i14) 
resetOffset : memref> to memref> + %41 = affine.apply #map()[%thread_id_x, %thread_id_y, %block_id_y] + %42 = arith.muli %41, %c8192 overflow : index + %43 = arith.addi %42, %9 overflow : index + %reinterpret_cast_9 = memref.reinterpret_cast %2 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast_10 = memref.cast %reinterpret_cast_9 : memref<2147483646xi8, strided<[1]>> to memref> + %44 = amdgpu.fat_raw_buffer_cast %cast_10 validBytes(%c2147483646_i64) cacheSwizzleStride(%c-8192_i14) resetOffset : memref> to memref> + %45 = affine.apply #map5()[%thread_id_x, %thread_id_y, %block_id_y] + %46 = arith.muli %45, %c8192 overflow : index + %47 = arith.addi %46, %9 overflow : index + %48 = affine.apply #map7()[%thread_id_x, %thread_id_y, %block_id_y] + %49 = arith.muli %48, %c8192 overflow : index + %50 = arith.addi %49, %9 overflow : index + %51 = affine.apply #map9()[%thread_id_x, %thread_id_y, %block_id_y] + %52 = arith.muli %51, %c8192 overflow : index + %53 = arith.addi %52, %9 overflow : index + %54 = affine.apply #map11()[%thread_id_x, %thread_id_y, %block_id_y] + %55 = arith.muli %54, %c512 overflow : index + %56 = arith.addi %55, %35 overflow : index + %reinterpret_cast_11 = memref.reinterpret_cast %3 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast_12 = memref.cast %reinterpret_cast_11 : memref<2147483646xi8, strided<[1]>> to memref> + %57 = amdgpu.fat_raw_buffer_cast %cast_12 validBytes(%c2147483646_i64) cacheSwizzleStride(%c512_i14) resetOffset : memref> to memref> + // --- Cluster 0 only: A data (8), A scale (2), B data (8) gathers --- + scf.if %is_cluster0 { + // A data: 4 original gathers (ty=0 addresses) + amdgpu.gather_to_lds %15[%14], %alloc_6[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %15[%20], %alloc_6[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + 
amdgpu.gather_to_lds %15[%25], %alloc_6[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %15[%30], %alloc_6[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + // A data: 4 extra gathers (ty=1 addresses: global +262144, LDS row +32) + %ea_g0 = arith.addi %14, %c262144 overflow : index + %ea_l0 = arith.addi %11, %c32_idx overflow : index + amdgpu.gather_to_lds %15[%ea_g0], %alloc_6[%ea_l0, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %ea_g1 = arith.addi %20, %c262144 overflow : index + %ea_l1 = arith.addi %18, %c32_idx overflow : index + amdgpu.gather_to_lds %15[%ea_g1], %alloc_6[%ea_l1, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %ea_g2 = arith.addi %25, %c262144 overflow : index + %ea_l2 = arith.addi %23, %c32_idx overflow : index + amdgpu.gather_to_lds %15[%ea_g2], %alloc_6[%ea_l2, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %ea_g3 = arith.addi %30, %c262144 overflow : index + %ea_l3 = arith.addi %28, %c32_idx overflow : index + amdgpu.gather_to_lds %15[%ea_g3], %alloc_6[%ea_l3, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + // A scale: 1 original gather (ty=0) + amdgpu.gather_to_lds %40[%39], %alloc_4[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + // A scale: 1 extra gather (ty=1: global +65536, LDS row +128) + %eas_g0 = arith.addi %39, %c65536 overflow : index + %eas_l0 = arith.addi %37, %c128_idx overflow : index + amdgpu.gather_to_lds %40[%eas_g0], %alloc_4[%eas_l0, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + // B data: 4 original gathers (ty=0 addresses) + amdgpu.gather_to_lds %44[%43], %alloc_2[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %44[%47], %alloc_2[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %44[%50], 
%alloc_2[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %44[%53], %alloc_2[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + // B data: 4 extra gathers (ty=1: global +262144, LDS row +32) + %eb_g0 = arith.addi %43, %c262144 overflow : index + amdgpu.gather_to_lds %44[%eb_g0], %alloc_2[%ea_l0, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %eb_g1 = arith.addi %47, %c262144 overflow : index + amdgpu.gather_to_lds %44[%eb_g1], %alloc_2[%ea_l1, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %eb_g2 = arith.addi %50, %c262144 overflow : index + amdgpu.gather_to_lds %44[%eb_g2], %alloc_2[%ea_l2, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %eb_g3 = arith.addi %53, %c262144 overflow : index + amdgpu.gather_to_lds %44[%eb_g3], %alloc_2[%ea_l3, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + } + // B scale: unchanged (both clusters, already cluster-aligned) + amdgpu.gather_to_lds %57[%56], %alloc_0[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + rocdl.s.barrier + %58 = affine.apply #map16()[%thread_id_x, %thread_id_y] + %59 = arith.index_cast %58 : index to i32 + %60 = arith.cmpi sge, %59, %c4_i32 : i32 + %61 = arith.cmpi slt, %59, %c4_i32 : i32 + scf.if %60 { + rocdl.s.barrier + } + %62 = affine.apply #map17()[%thread_id_x] + %63 = affine.apply #map18()[%thread_id_x] + %64 = arith.xori %63, %7 : index + %65 = affine.apply #map3()[%64] + %66 = affine.apply #map19()[%thread_id_x] + %67 = affine.apply #map20()[%thread_id_x] + %68 = affine.apply #map21()[%thread_id_x] + %69 = affine.apply #map22()[%thread_id_x] + %70 = affine.apply #map23()[%thread_id_x, %thread_id_y] + %71 = affine.apply #map24()[%thread_id_x, %thread_id_y] + %72 = affine.apply #map25()[%thread_id_x, %thread_id_y] + %73 = affine.apply #map26()[%thread_id_x, %thread_id_y] + %74 = affine.apply 
#map27()[%thread_id_x, %thread_id_y] + %75 = affine.apply #map28()[%thread_id_x, %thread_id_y] + %76 = affine.apply #map29()[%thread_id_x, %thread_id_y] + %77 = affine.apply #map30()[%thread_id_x, %thread_id_y] + %78 = affine.apply #map31()[%thread_id_x] + %79 = arith.xori %78, %7 : index + %80 = affine.apply #map3()[%79] + %81 = arith.xori %33, %c1 : index + %82 = affine.apply #map32()[%thread_id_x, %81] + %83:40 = scf.for %arg5 = %c0 to %c63 step %c1 iter_args(%arg6 = %cst, %arg7 = %cst, %arg8 = %cst, %arg9 = %cst, %arg10 = %cst, %arg11 = %cst, %arg12 = %cst, %arg13 = %cst, %arg14 = %cst, %arg15 = %cst, %arg16 = %cst, %arg17 = %cst, %arg18 = %cst, %arg19 = %cst, %arg20 = %cst, %arg21 = %cst, %arg22 = %cst, %arg23 = %cst, %arg24 = %cst, %arg25 = %cst, %arg26 = %cst, %arg27 = %cst, %arg28 = %cst, %arg29 = %cst, %arg30 = %cst, %arg31 = %cst, %arg32 = %cst, %arg33 = %cst, %arg34 = %cst, %arg35 = %cst, %arg36 = %cst, %arg37 = %cst, %arg38 = %alloc_6, %arg39 = %alloc_5, %arg40 = %alloc_4, %arg41 = %alloc_3, %arg42 = %alloc_2, %arg43 = %alloc_1, %arg44 = %alloc_0, %arg45 = %alloc) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>) { + rocdl.sched.barrier 0 + rocdl.s.barrier + // --- Address computations (all waves) --- + %582 = 
affine.apply #map33()[%arg5, %8] + %583 = arith.addi %13, %582 overflow : index + %584 = arith.addi %19, %582 overflow : index + %585 = arith.addi %24, %582 overflow : index + %586 = arith.addi %29, %582 overflow : index + %587 = affine.apply #map34()[%arg5, %34] + %588 = arith.addi %38, %587 overflow : index + %589 = arith.addi %42, %582 overflow : index + %590 = arith.addi %46, %582 overflow : index + %591 = arith.addi %49, %582 overflow : index + %592 = arith.addi %52, %582 overflow : index + %593 = arith.addi %55, %587 overflow : index + // --- Cluster 0 only: A data (8), A scale (2), B data (8) gathers --- + scf.if %is_cluster0 { + // A data: 4 original gathers (ty=0 addresses) + amdgpu.gather_to_lds %15[%583], %arg39[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %15[%584], %arg39[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %15[%585], %arg39[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %15[%586], %arg39[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + // A data: 4 extra gathers (ty=1: global +262144, LDS row +32) + %lea_g0 = arith.addi %583, %c262144 overflow : index + %lea_l0 = arith.addi %11, %c32_idx overflow : index + amdgpu.gather_to_lds %15[%lea_g0], %arg39[%lea_l0, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %lea_g1 = arith.addi %584, %c262144 overflow : index + %lea_l1 = arith.addi %18, %c32_idx overflow : index + amdgpu.gather_to_lds %15[%lea_g1], %arg39[%lea_l1, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %lea_g2 = arith.addi %585, %c262144 overflow : index + %lea_l2 = arith.addi %23, %c32_idx overflow : index + amdgpu.gather_to_lds %15[%lea_g2], %arg39[%lea_l2, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %lea_g3 = arith.addi %586, %c262144 overflow : index + %lea_l3 = arith.addi %28, 
%c32_idx overflow : index + amdgpu.gather_to_lds %15[%lea_g3], %arg39[%lea_l3, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + // A scale: 1 original gather (ty=0) + amdgpu.gather_to_lds %40[%588], %arg41[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + // A scale: 1 extra gather (ty=1: global +65536, LDS row +128) + %leas_g0 = arith.addi %588, %c65536 overflow : index + %leas_l0 = arith.addi %37, %c128_idx overflow : index + amdgpu.gather_to_lds %40[%leas_g0], %arg41[%leas_l0, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + // B data: 4 original gathers (ty=0 addresses) + amdgpu.gather_to_lds %44[%589], %arg43[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %44[%590], %arg43[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %44[%591], %arg43[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %44[%592], %arg43[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + // B data: 4 extra gathers (ty=1: global +262144, LDS row +32) + %leb_g0 = arith.addi %589, %c262144 overflow : index + amdgpu.gather_to_lds %44[%leb_g0], %arg43[%lea_l0, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %leb_g1 = arith.addi %590, %c262144 overflow : index + amdgpu.gather_to_lds %44[%leb_g1], %arg43[%lea_l1, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %leb_g2 = arith.addi %591, %c262144 overflow : index + amdgpu.gather_to_lds %44[%leb_g2], %arg43[%lea_l2, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %leb_g3 = arith.addi %592, %c262144 overflow : index + amdgpu.gather_to_lds %44[%leb_g3], %arg43[%lea_l3, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + } + // B scale: unchanged (both clusters) + amdgpu.gather_to_lds %57[%593], %arg45[%37, %12] : vector<4xi8>, 
memref>, memref<256x8xi8, #gpu.address_space> + rocdl.sched.barrier 0 + amdgpu.memory_counter_wait load(10) + %594 = vector.load %arg38[%62, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %595 = vector.load %arg38[%66, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %596 = vector.load %arg38[%67, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %597 = vector.load %arg38[%68, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %598 = vector.load %arg40[%62, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %599 = vector.load %arg40[%66, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %600 = vector.load %arg40[%67, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %601 = vector.load %arg40[%68, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %602 = vector.load %arg42[%70, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %603 = vector.load %arg42[%71, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %604 = vector.load %arg42[%72, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %605 = vector.load %arg42[%73, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %606 = vector.load %arg42[%74, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %607 = vector.load %arg42[%75, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %608 = vector.load %arg42[%76, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %609 = vector.load %arg42[%77, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %610 = vector.load %arg44[%70, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %611 = vector.load %arg44[%71, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %612 = vector.load %arg44[%72, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %613 = vector.load %arg44[%73, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %614 = vector.load %arg44[%74, 
%69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %615 = vector.load %arg44[%75, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %616 = vector.load %arg44[%76, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %617 = vector.load %arg44[%77, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %618 = vector.bitcast %594 : vector<16xi8> to vector<32xf4E2M1FN> + %619 = vector.bitcast %595 : vector<16xi8> to vector<32xf4E2M1FN> + %620 = vector.bitcast %596 : vector<16xi8> to vector<32xf4E2M1FN> + %621 = vector.bitcast %597 : vector<16xi8> to vector<32xf4E2M1FN> + %622 = vector.bitcast %598 : vector<1xi8> to vector<1xf8E8M0FNU> + %623 = vector.bitcast %599 : vector<1xi8> to vector<1xf8E8M0FNU> + %624 = vector.bitcast %600 : vector<1xi8> to vector<1xf8E8M0FNU> + %625 = vector.bitcast %601 : vector<1xi8> to vector<1xf8E8M0FNU> + %626 = vector.bitcast %602 : vector<16xi8> to vector<32xf4E2M1FN> + %627 = vector.bitcast %603 : vector<16xi8> to vector<32xf4E2M1FN> + %628 = vector.bitcast %604 : vector<16xi8> to vector<32xf4E2M1FN> + %629 = vector.bitcast %605 : vector<16xi8> to vector<32xf4E2M1FN> + %630 = vector.bitcast %606 : vector<16xi8> to vector<32xf4E2M1FN> + %631 = vector.bitcast %607 : vector<16xi8> to vector<32xf4E2M1FN> + %632 = vector.bitcast %608 : vector<16xi8> to vector<32xf4E2M1FN> + %633 = vector.bitcast %609 : vector<16xi8> to vector<32xf4E2M1FN> + %634 = vector.bitcast %610 : vector<1xi8> to vector<1xf8E8M0FNU> + %635 = vector.bitcast %611 : vector<1xi8> to vector<1xf8E8M0FNU> + %636 = vector.bitcast %612 : vector<1xi8> to vector<1xf8E8M0FNU> + %637 = vector.bitcast %613 : vector<1xi8> to vector<1xf8E8M0FNU> + %638 = vector.bitcast %614 : vector<1xi8> to vector<1xf8E8M0FNU> + %639 = vector.bitcast %615 : vector<1xi8> to vector<1xf8E8M0FNU> + %640 = vector.bitcast %616 : vector<1xi8> to vector<1xf8E8M0FNU> + %641 = vector.bitcast %617 : vector<1xi8> to vector<1xf8E8M0FNU> + rocdl.sched.barrier 0 + rocdl.s.barrier + 
rocdl.sched.barrier 0 + rocdl.s.setprio 1 + %642 = vector.extract %622[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %643 = vector.extract %634[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %644 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%643[0] * %626) + %arg6 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %645 = vector.extract %635[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %646 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%645[0] * %627) + %arg7 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %647 = vector.extract %636[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %648 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%647[0] * %628) + %arg8 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %649 = vector.extract %637[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %650 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%649[0] * %629) + %arg9 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %651 = vector.extract %638[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %652 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%651[0] * %630) + %arg10 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %653 = vector.extract %639[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %654 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%653[0] * %631) + %arg11 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %655 = vector.extract %640[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %656 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%655[0] * %632) + %arg12 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %657 = vector.extract %641[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %658 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%657[0] * %633) + %arg13 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %659 = vector.extract %623[0] : 
f8E8M0FNU from vector<1xf8E8M0FNU> + %660 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%643[0] * %626) + %arg14 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %661 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%645[0] * %627) + %arg15 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %662 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%647[0] * %628) + %arg16 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %663 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%649[0] * %629) + %arg17 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %664 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%651[0] * %630) + %arg18 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %665 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%653[0] * %631) + %arg19 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %666 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%655[0] * %632) + %arg20 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %667 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%657[0] * %633) + %arg21 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %668 = vector.extract %624[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %669 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%643[0] * %626) + %arg22 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %670 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%645[0] * %627) + %arg23 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %671 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%647[0] * %628) + %arg24 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %672 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%649[0] * %629) + %arg25 : f8E8M0FNU, 
vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %673 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%651[0] * %630) + %arg26 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %674 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%653[0] * %631) + %arg27 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %675 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%655[0] * %632) + %arg28 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %676 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%657[0] * %633) + %arg29 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %677 = vector.extract %625[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %678 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%643[0] * %626) + %arg30 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %679 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%645[0] * %627) + %arg31 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %680 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%647[0] * %628) + %arg32 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %681 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%649[0] * %629) + %arg33 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %682 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%651[0] * %630) + %arg34 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %683 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%653[0] * %631) + %arg35 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %684 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%655[0] * %632) + %arg36 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %685 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%657[0] * 
%633) + %arg37 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + rocdl.s.setprio 0 + rocdl.sched.barrier 0 + rocdl.s.barrier + rocdl.sched.barrier 0 + rocdl.sched.barrier 0 + %686 = vector.load %arg38[%62, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %687 = vector.load %arg38[%66, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %688 = vector.load %arg38[%67, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %689 = vector.load %arg38[%68, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %690 = vector.load %arg40[%62, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %691 = vector.load %arg40[%66, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %692 = vector.load %arg40[%67, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %693 = vector.load %arg40[%68, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %694 = vector.load %arg42[%70, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %695 = vector.load %arg42[%71, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %696 = vector.load %arg42[%72, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %697 = vector.load %arg42[%73, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %698 = vector.load %arg42[%74, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %699 = vector.load %arg42[%75, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %700 = vector.load %arg42[%76, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %701 = vector.load %arg42[%77, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %702 = vector.load %arg44[%70, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %703 = vector.load %arg44[%71, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %704 = vector.load %arg44[%72, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %705 = vector.load 
%arg44[%73, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %706 = vector.load %arg44[%74, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %707 = vector.load %arg44[%75, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %708 = vector.load %arg44[%76, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %709 = vector.load %arg44[%77, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %710 = vector.bitcast %686 : vector<16xi8> to vector<32xf4E2M1FN> + %711 = vector.bitcast %687 : vector<16xi8> to vector<32xf4E2M1FN> + %712 = vector.bitcast %688 : vector<16xi8> to vector<32xf4E2M1FN> + %713 = vector.bitcast %689 : vector<16xi8> to vector<32xf4E2M1FN> + %714 = vector.bitcast %690 : vector<1xi8> to vector<1xf8E8M0FNU> + %715 = vector.bitcast %691 : vector<1xi8> to vector<1xf8E8M0FNU> + %716 = vector.bitcast %692 : vector<1xi8> to vector<1xf8E8M0FNU> + %717 = vector.bitcast %693 : vector<1xi8> to vector<1xf8E8M0FNU> + %718 = vector.bitcast %694 : vector<16xi8> to vector<32xf4E2M1FN> + %719 = vector.bitcast %695 : vector<16xi8> to vector<32xf4E2M1FN> + %720 = vector.bitcast %696 : vector<16xi8> to vector<32xf4E2M1FN> + %721 = vector.bitcast %697 : vector<16xi8> to vector<32xf4E2M1FN> + %722 = vector.bitcast %698 : vector<16xi8> to vector<32xf4E2M1FN> + %723 = vector.bitcast %699 : vector<16xi8> to vector<32xf4E2M1FN> + %724 = vector.bitcast %700 : vector<16xi8> to vector<32xf4E2M1FN> + %725 = vector.bitcast %701 : vector<16xi8> to vector<32xf4E2M1FN> + %726 = vector.bitcast %702 : vector<1xi8> to vector<1xf8E8M0FNU> + %727 = vector.bitcast %703 : vector<1xi8> to vector<1xf8E8M0FNU> + %728 = vector.bitcast %704 : vector<1xi8> to vector<1xf8E8M0FNU> + %729 = vector.bitcast %705 : vector<1xi8> to vector<1xf8E8M0FNU> + %730 = vector.bitcast %706 : vector<1xi8> to vector<1xf8E8M0FNU> + %731 = vector.bitcast %707 : vector<1xi8> to vector<1xf8E8M0FNU> + %732 = vector.bitcast %708 : vector<1xi8> to vector<1xf8E8M0FNU> + %733 = 
vector.bitcast %709 : vector<1xi8> to vector<1xf8E8M0FNU> + rocdl.sched.barrier 0 + rocdl.s.barrier + rocdl.sched.barrier 0 + rocdl.s.setprio 1 + %734 = vector.extract %714[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %735 = vector.extract %726[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %736 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%735[0] * %718) + %644 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %737 = vector.extract %727[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %738 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%737[0] * %719) + %646 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %739 = vector.extract %728[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %740 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%739[0] * %720) + %648 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %741 = vector.extract %729[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %742 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%741[0] * %721) + %650 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %743 = vector.extract %730[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %744 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%743[0] * %722) + %652 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %745 = vector.extract %731[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %746 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%745[0] * %723) + %654 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %747 = vector.extract %732[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %748 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%747[0] * %724) + %656 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %749 = vector.extract %733[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %750 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%749[0] * %725) + %658 : f8E8M0FNU, 
vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %751 = vector.extract %715[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %752 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%735[0] * %718) + %660 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %753 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%737[0] * %719) + %661 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %754 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%739[0] * %720) + %662 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %755 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%741[0] * %721) + %663 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %756 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%743[0] * %722) + %664 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %757 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%745[0] * %723) + %665 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %758 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%747[0] * %724) + %666 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %759 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%749[0] * %725) + %667 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %760 = vector.extract %716[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %761 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%735[0] * %718) + %669 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %762 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%737[0] * %719) + %670 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %763 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%739[0] * %720) + %671 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %764 = 
amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%741[0] * %721) + %672 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %765 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%743[0] * %722) + %673 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %766 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%745[0] * %723) + %674 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %767 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%747[0] * %724) + %675 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %768 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%749[0] * %725) + %676 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %769 = vector.extract %717[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %770 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%735[0] * %718) + %678 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %771 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%737[0] * %719) + %679 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %772 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%739[0] * %720) + %680 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %773 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%741[0] * %721) + %681 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %774 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%743[0] * %722) + %682 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %775 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%745[0] * %723) + %683 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %776 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%747[0] * %724) + %684 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %777 
= amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%749[0] * %725) + %685 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + rocdl.s.setprio 0 + rocdl.sched.barrier 0 + scf.yield %736, %738, %740, %742, %744, %746, %748, %750, %752, %753, %754, %755, %756, %757, %758, %759, %761, %762, %763, %764, %765, %766, %767, %768, %770, %771, %772, %773, %774, %775, %776, %777, %arg39, %arg38, %arg41, %arg40, %arg43, %arg42, %arg45, %arg44 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space> + } + scf.if %61 { + rocdl.s.barrier + } + amdgpu.lds_barrier + %84 = affine.apply #map23()[%thread_id_x, %thread_id_y] + %85 = affine.apply #map22()[%thread_id_x] + %86 = vector.load %83#38[%84, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %87 = arith.xori %33, %c1 : index + %88 = affine.apply #map32()[%thread_id_x, %87] + %89 = vector.load %83#38[%84, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %90 = affine.apply #map24()[%thread_id_x, %thread_id_y] + %91 = vector.load %83#38[%90, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %92 = vector.load %83#38[%90, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %93 = affine.apply #map25()[%thread_id_x, %thread_id_y] + %94 = 
vector.load %83#38[%93, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %95 = vector.load %83#38[%93, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %96 = affine.apply #map26()[%thread_id_x, %thread_id_y] + %97 = vector.load %83#38[%96, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %98 = vector.load %83#38[%96, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %99 = affine.apply #map27()[%thread_id_x, %thread_id_y] + %100 = vector.load %83#38[%99, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %101 = vector.load %83#38[%99, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %102 = affine.apply #map28()[%thread_id_x, %thread_id_y] + %103 = vector.load %83#38[%102, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %104 = vector.load %83#38[%102, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %105 = affine.apply #map29()[%thread_id_x, %thread_id_y] + %106 = vector.load %83#38[%105, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %107 = vector.load %83#38[%105, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %108 = affine.apply #map30()[%thread_id_x, %thread_id_y] + %109 = vector.load %83#38[%108, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %110 = vector.load %83#38[%108, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %111 = affine.apply #map18()[%thread_id_x] + %112 = arith.xori %111, %7 : index + %113 = affine.apply #map3()[%112] + %114 = vector.load %83#36[%84, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %115 = affine.apply #map31()[%thread_id_x] + %116 = arith.xori %115, %7 : index + %117 = affine.apply #map3()[%116] + %118 = vector.load %83#36[%84, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %119 = vector.load %83#36[%90, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %120 = vector.load %83#36[%90, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + 
%121 = vector.load %83#36[%93, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %122 = vector.load %83#36[%93, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %123 = vector.load %83#36[%96, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %124 = vector.load %83#36[%96, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %125 = vector.load %83#36[%99, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %126 = vector.load %83#36[%99, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %127 = vector.load %83#36[%102, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %128 = vector.load %83#36[%102, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %129 = vector.load %83#36[%105, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %130 = vector.load %83#36[%105, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %131 = vector.load %83#36[%108, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %132 = vector.load %83#36[%108, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %133 = affine.apply #map17()[%thread_id_x] + %134 = vector.load %83#34[%133, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %135 = vector.load %83#34[%133, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %136 = affine.apply #map19()[%thread_id_x] + %137 = vector.load %83#34[%136, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %138 = vector.load %83#34[%136, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %139 = affine.apply #map20()[%thread_id_x] + %140 = vector.load %83#34[%139, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %141 = vector.load %83#34[%139, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %142 = affine.apply #map21()[%thread_id_x] + %143 = vector.load %83#34[%142, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %144 = vector.load 
%83#34[%142, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %145 = vector.load %83#32[%133, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %146 = vector.load %83#32[%133, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %147 = vector.load %83#32[%136, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %148 = vector.load %83#32[%136, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %149 = vector.load %83#32[%139, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %150 = vector.load %83#32[%139, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %151 = vector.load %83#32[%142, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %152 = vector.load %83#32[%142, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %153 = vector.bitcast %145 : vector<16xi8> to vector<32xf4E2M1FN> + %154 = vector.bitcast %146 : vector<16xi8> to vector<32xf4E2M1FN> + %155 = vector.bitcast %147 : vector<16xi8> to vector<32xf4E2M1FN> + %156 = vector.bitcast %148 : vector<16xi8> to vector<32xf4E2M1FN> + %157 = vector.bitcast %149 : vector<16xi8> to vector<32xf4E2M1FN> + %158 = vector.bitcast %150 : vector<16xi8> to vector<32xf4E2M1FN> + %159 = vector.bitcast %151 : vector<16xi8> to vector<32xf4E2M1FN> + %160 = vector.bitcast %152 : vector<16xi8> to vector<32xf4E2M1FN> + %161 = vector.bitcast %134 : vector<1xi8> to vector<1xf8E8M0FNU> + %162 = vector.bitcast %135 : vector<1xi8> to vector<1xf8E8M0FNU> + %163 = vector.bitcast %137 : vector<1xi8> to vector<1xf8E8M0FNU> + %164 = vector.bitcast %138 : vector<1xi8> to vector<1xf8E8M0FNU> + %165 = vector.bitcast %140 : vector<1xi8> to vector<1xf8E8M0FNU> + %166 = vector.bitcast %141 : vector<1xi8> to vector<1xf8E8M0FNU> + %167 = vector.bitcast %143 : vector<1xi8> to vector<1xf8E8M0FNU> + %168 = vector.bitcast %144 : vector<1xi8> to vector<1xf8E8M0FNU> + %169 = vector.bitcast %114 : vector<16xi8> to vector<32xf4E2M1FN> + %170 = 
vector.bitcast %118 : vector<16xi8> to vector<32xf4E2M1FN> + %171 = vector.bitcast %119 : vector<16xi8> to vector<32xf4E2M1FN> + %172 = vector.bitcast %120 : vector<16xi8> to vector<32xf4E2M1FN> + %173 = vector.bitcast %121 : vector<16xi8> to vector<32xf4E2M1FN> + %174 = vector.bitcast %122 : vector<16xi8> to vector<32xf4E2M1FN> + %175 = vector.bitcast %123 : vector<16xi8> to vector<32xf4E2M1FN> + %176 = vector.bitcast %124 : vector<16xi8> to vector<32xf4E2M1FN> + %177 = vector.bitcast %125 : vector<16xi8> to vector<32xf4E2M1FN> + %178 = vector.bitcast %126 : vector<16xi8> to vector<32xf4E2M1FN> + %179 = vector.bitcast %127 : vector<16xi8> to vector<32xf4E2M1FN> + %180 = vector.bitcast %128 : vector<16xi8> to vector<32xf4E2M1FN> + %181 = vector.bitcast %129 : vector<16xi8> to vector<32xf4E2M1FN> + %182 = vector.bitcast %130 : vector<16xi8> to vector<32xf4E2M1FN> + %183 = vector.bitcast %131 : vector<16xi8> to vector<32xf4E2M1FN> + %184 = vector.bitcast %132 : vector<16xi8> to vector<32xf4E2M1FN> + %185 = vector.bitcast %86 : vector<1xi8> to vector<1xf8E8M0FNU> + %186 = vector.bitcast %89 : vector<1xi8> to vector<1xf8E8M0FNU> + %187 = vector.bitcast %91 : vector<1xi8> to vector<1xf8E8M0FNU> + %188 = vector.bitcast %92 : vector<1xi8> to vector<1xf8E8M0FNU> + %189 = vector.bitcast %94 : vector<1xi8> to vector<1xf8E8M0FNU> + %190 = vector.bitcast %95 : vector<1xi8> to vector<1xf8E8M0FNU> + %191 = vector.bitcast %97 : vector<1xi8> to vector<1xf8E8M0FNU> + %192 = vector.bitcast %98 : vector<1xi8> to vector<1xf8E8M0FNU> + %193 = vector.bitcast %100 : vector<1xi8> to vector<1xf8E8M0FNU> + %194 = vector.bitcast %101 : vector<1xi8> to vector<1xf8E8M0FNU> + %195 = vector.bitcast %103 : vector<1xi8> to vector<1xf8E8M0FNU> + %196 = vector.bitcast %104 : vector<1xi8> to vector<1xf8E8M0FNU> + %197 = vector.bitcast %106 : vector<1xi8> to vector<1xf8E8M0FNU> + %198 = vector.bitcast %107 : vector<1xi8> to vector<1xf8E8M0FNU> + %199 = vector.bitcast %109 : vector<1xi8> to 
vector<1xf8E8M0FNU> + %200 = vector.bitcast %110 : vector<1xi8> to vector<1xf8E8M0FNU> + %201 = vector.extract %161[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %202 = vector.extract %185[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %203 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%202[0] * %169) + %83#0 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %204 = vector.extract %162[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %205 = vector.extract %186[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %206 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%205[0] * %170) + %203 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %207 = vector.extract %187[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %208 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%207[0] * %171) + %83#1 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %209 = vector.extract %188[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %210 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%209[0] * %172) + %208 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %211 = vector.extract %189[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %212 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%211[0] * %173) + %83#2 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %213 = vector.extract %190[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %214 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%213[0] * %174) + %212 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %215 = vector.extract %191[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %216 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%215[0] * %175) + %83#3 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %217 = vector.extract %192[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %218 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%217[0] * %176) + %216 : 
f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %219 = vector.extract %193[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %220 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%219[0] * %177) + %83#4 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %221 = vector.extract %194[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %222 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%221[0] * %178) + %220 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %223 = vector.extract %195[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %224 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%223[0] * %179) + %83#5 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %225 = vector.extract %196[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %226 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%225[0] * %180) + %224 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %227 = vector.extract %197[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %228 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%227[0] * %181) + %83#6 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %229 = vector.extract %198[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %230 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%229[0] * %182) + %228 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %231 = vector.extract %199[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %232 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%231[0] * %183) + %83#7 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %233 = vector.extract %200[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %234 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%233[0] * %184) + %232 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %235 = vector.extract %163[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %236 = 
amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%202[0] * %169) + %83#8 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %237 = vector.extract %164[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %238 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%205[0] * %170) + %236 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %239 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%207[0] * %171) + %83#9 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %240 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%209[0] * %172) + %239 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %241 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%211[0] * %173) + %83#10 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %242 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%213[0] * %174) + %241 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %243 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%215[0] * %175) + %83#11 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %244 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%217[0] * %176) + %243 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %245 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%219[0] * %177) + %83#12 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %246 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%221[0] * %178) + %245 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %247 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%223[0] * %179) + %83#13 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %248 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%225[0] * %180) + %247 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, 
vector<4xf32> + %249 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%227[0] * %181) + %83#14 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %250 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%229[0] * %182) + %249 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %251 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%231[0] * %183) + %83#15 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %252 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%233[0] * %184) + %251 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %253 = vector.extract %165[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %254 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%202[0] * %169) + %83#16 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %255 = vector.extract %166[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %256 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%205[0] * %170) + %254 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %257 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%207[0] * %171) + %83#17 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %258 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%209[0] * %172) + %257 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %259 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%211[0] * %173) + %83#18 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %260 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%213[0] * %174) + %259 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %261 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%215[0] * %175) + %83#19 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %262 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * 
(%217[0] * %176) + %261 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %263 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%219[0] * %177) + %83#20 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %264 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%221[0] * %178) + %263 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %265 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%223[0] * %179) + %83#21 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %266 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%225[0] * %180) + %265 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %267 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%227[0] * %181) + %83#22 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %268 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%229[0] * %182) + %267 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %269 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%231[0] * %183) + %83#23 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %270 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%233[0] * %184) + %269 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %271 = vector.extract %167[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %272 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%202[0] * %169) + %83#24 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %273 = vector.extract %168[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %274 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%205[0] * %170) + %272 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %275 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%207[0] * %171) + %83#25 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, 
vector<32xf4E2M1FN>, vector<4xf32> + %276 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%209[0] * %172) + %275 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %277 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%211[0] * %173) + %83#26 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %278 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%213[0] * %174) + %277 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %279 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%215[0] * %175) + %83#27 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %280 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%217[0] * %176) + %279 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %281 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%219[0] * %177) + %83#28 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %282 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%221[0] * %178) + %281 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %283 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%223[0] * %179) + %83#29 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %284 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%225[0] * %180) + %283 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %285 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%227[0] * %181) + %83#30 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %286 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%229[0] * %182) + %285 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %287 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%231[0] * %183) + %83#31 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %288 = 
amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%233[0] * %184) + %287 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %289 = vector.extract_strided_slice %206 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %290 = affine.apply #map35()[%block_id_x] + %291 = affine.apply #map35()[%block_id_y] + %292 = affine.apply #map36()[%thread_id_x] + %293 = arith.muli %290, %c57344 overflow : index + %294 = arith.muli %292, %c57344 overflow : index + %295 = arith.addi %293, %291 overflow : index + %296 = arith.addi %294, %84 overflow : index + %reinterpret_cast_13 = memref.reinterpret_cast %4 to offset: [%295], sizes: [536870910], strides: [1] : memref to memref<536870910xf32, strided<[1], offset: ?>> + %cast_14 = memref.cast %reinterpret_cast_13 : memref<536870910xf32, strided<[1], offset: ?>> to memref> + %297 = amdgpu.fat_raw_buffer_cast %cast_14 validBytes(%c2147483643_i64) resetOffset : memref> to memref> + vector.store %289, %297[%296] : memref>, vector<1xf32> + %298 = vector.extract_strided_slice %206 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %299 = affine.apply #map37()[%thread_id_x] + %300 = arith.muli %299, %c57344 overflow : index + %301 = arith.addi %300, %84 overflow : index + vector.store %298, %297[%301] : memref>, vector<1xf32> + %302 = vector.extract_strided_slice %206 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %303 = affine.apply #map38()[%thread_id_x] + %304 = arith.muli %303, %c57344 overflow : index + %305 = arith.addi %304, %84 overflow : index + vector.store %302, %297[%305] : memref>, vector<1xf32> + %306 = vector.extract_strided_slice %206 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %307 = affine.apply #map39()[%thread_id_x] + %308 = arith.muli %307, %c57344 overflow : index + %309 = arith.addi %308, %84 overflow : index + vector.store %306, %297[%309] : memref>, vector<1xf32> + 
%310 = vector.extract_strided_slice %210 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %311 = arith.addi %294, %90 overflow : index + vector.store %310, %297[%311] : memref>, vector<1xf32> + %312 = vector.extract_strided_slice %210 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %313 = arith.addi %300, %90 overflow : index + vector.store %312, %297[%313] : memref>, vector<1xf32> + %314 = vector.extract_strided_slice %210 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %315 = arith.addi %304, %90 overflow : index + vector.store %314, %297[%315] : memref>, vector<1xf32> + %316 = vector.extract_strided_slice %210 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %317 = arith.addi %308, %90 overflow : index + vector.store %316, %297[%317] : memref>, vector<1xf32> + %318 = vector.extract_strided_slice %214 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %319 = arith.addi %294, %93 overflow : index + vector.store %318, %297[%319] : memref>, vector<1xf32> + %320 = vector.extract_strided_slice %214 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %321 = arith.addi %300, %93 overflow : index + vector.store %320, %297[%321] : memref>, vector<1xf32> + %322 = vector.extract_strided_slice %214 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %323 = arith.addi %304, %93 overflow : index + vector.store %322, %297[%323] : memref>, vector<1xf32> + %324 = vector.extract_strided_slice %214 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %325 = arith.addi %308, %93 overflow : index + vector.store %324, %297[%325] : memref>, vector<1xf32> + %326 = vector.extract_strided_slice %218 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %327 = arith.addi %294, %96 overflow : index + vector.store %326, %297[%327] : memref>, 
vector<1xf32> + %328 = vector.extract_strided_slice %218 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %329 = arith.addi %300, %96 overflow : index + vector.store %328, %297[%329] : memref>, vector<1xf32> + %330 = vector.extract_strided_slice %218 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %331 = arith.addi %304, %96 overflow : index + vector.store %330, %297[%331] : memref>, vector<1xf32> + %332 = vector.extract_strided_slice %218 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %333 = arith.addi %308, %96 overflow : index + vector.store %332, %297[%333] : memref>, vector<1xf32> + %334 = vector.extract_strided_slice %222 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %335 = arith.addi %294, %99 overflow : index + vector.store %334, %297[%335] : memref>, vector<1xf32> + %336 = vector.extract_strided_slice %222 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %337 = arith.addi %300, %99 overflow : index + vector.store %336, %297[%337] : memref>, vector<1xf32> + %338 = vector.extract_strided_slice %222 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %339 = arith.addi %304, %99 overflow : index + vector.store %338, %297[%339] : memref>, vector<1xf32> + %340 = vector.extract_strided_slice %222 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %341 = arith.addi %308, %99 overflow : index + vector.store %340, %297[%341] : memref>, vector<1xf32> + %342 = vector.extract_strided_slice %226 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %343 = arith.addi %294, %102 overflow : index + vector.store %342, %297[%343] : memref>, vector<1xf32> + %344 = vector.extract_strided_slice %226 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %345 = arith.addi %300, %102 overflow : index + vector.store %344, 
%297[%345] : memref>, vector<1xf32> + %346 = vector.extract_strided_slice %226 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %347 = arith.addi %304, %102 overflow : index + vector.store %346, %297[%347] : memref>, vector<1xf32> + %348 = vector.extract_strided_slice %226 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %349 = arith.addi %308, %102 overflow : index + vector.store %348, %297[%349] : memref>, vector<1xf32> + %350 = vector.extract_strided_slice %230 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %351 = arith.addi %294, %105 overflow : index + vector.store %350, %297[%351] : memref>, vector<1xf32> + %352 = vector.extract_strided_slice %230 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %353 = arith.addi %300, %105 overflow : index + vector.store %352, %297[%353] : memref>, vector<1xf32> + %354 = vector.extract_strided_slice %230 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %355 = arith.addi %304, %105 overflow : index + vector.store %354, %297[%355] : memref>, vector<1xf32> + %356 = vector.extract_strided_slice %230 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %357 = arith.addi %308, %105 overflow : index + vector.store %356, %297[%357] : memref>, vector<1xf32> + %358 = vector.extract_strided_slice %234 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %359 = arith.addi %294, %108 overflow : index + vector.store %358, %297[%359] : memref>, vector<1xf32> + %360 = vector.extract_strided_slice %234 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %361 = arith.addi %300, %108 overflow : index + vector.store %360, %297[%361] : memref>, vector<1xf32> + %362 = vector.extract_strided_slice %234 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %363 = arith.addi %304, %108 overflow : 
index + vector.store %362, %297[%363] : memref>, vector<1xf32> + %364 = vector.extract_strided_slice %234 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %365 = arith.addi %308, %108 overflow : index + vector.store %364, %297[%365] : memref>, vector<1xf32> + %366 = vector.extract_strided_slice %238 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %367 = affine.apply #map40()[%thread_id_x] + %368 = arith.muli %367, %c57344 overflow : index + %369 = arith.addi %368, %84 overflow : index + vector.store %366, %297[%369] : memref>, vector<1xf32> + %370 = vector.extract_strided_slice %238 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %371 = affine.apply #map41()[%thread_id_x] + %372 = arith.muli %371, %c57344 overflow : index + %373 = arith.addi %372, %84 overflow : index + vector.store %370, %297[%373] : memref>, vector<1xf32> + %374 = vector.extract_strided_slice %238 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %375 = affine.apply #map42()[%thread_id_x] + %376 = arith.muli %375, %c57344 overflow : index + %377 = arith.addi %376, %84 overflow : index + vector.store %374, %297[%377] : memref>, vector<1xf32> + %378 = vector.extract_strided_slice %238 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %379 = affine.apply #map43()[%thread_id_x] + %380 = arith.muli %379, %c57344 overflow : index + %381 = arith.addi %380, %84 overflow : index + vector.store %378, %297[%381] : memref>, vector<1xf32> + %382 = vector.extract_strided_slice %240 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %383 = arith.addi %368, %90 overflow : index + vector.store %382, %297[%383] : memref>, vector<1xf32> + %384 = vector.extract_strided_slice %240 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %385 = arith.addi %372, %90 overflow : index + vector.store %384, %297[%385] : 
memref>, vector<1xf32> + %386 = vector.extract_strided_slice %240 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %387 = arith.addi %376, %90 overflow : index + vector.store %386, %297[%387] : memref>, vector<1xf32> + %388 = vector.extract_strided_slice %240 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %389 = arith.addi %380, %90 overflow : index + vector.store %388, %297[%389] : memref>, vector<1xf32> + %390 = vector.extract_strided_slice %242 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %391 = arith.addi %368, %93 overflow : index + vector.store %390, %297[%391] : memref>, vector<1xf32> + %392 = vector.extract_strided_slice %242 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %393 = arith.addi %372, %93 overflow : index + vector.store %392, %297[%393] : memref>, vector<1xf32> + %394 = vector.extract_strided_slice %242 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %395 = arith.addi %376, %93 overflow : index + vector.store %394, %297[%395] : memref>, vector<1xf32> + %396 = vector.extract_strided_slice %242 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %397 = arith.addi %380, %93 overflow : index + vector.store %396, %297[%397] : memref>, vector<1xf32> + %398 = vector.extract_strided_slice %244 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %399 = arith.addi %368, %96 overflow : index + vector.store %398, %297[%399] : memref>, vector<1xf32> + %400 = vector.extract_strided_slice %244 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %401 = arith.addi %372, %96 overflow : index + vector.store %400, %297[%401] : memref>, vector<1xf32> + %402 = vector.extract_strided_slice %244 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %403 = arith.addi %376, %96 overflow : index + vector.store 
%402, %297[%403] : memref>, vector<1xf32> + %404 = vector.extract_strided_slice %244 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %405 = arith.addi %380, %96 overflow : index + vector.store %404, %297[%405] : memref>, vector<1xf32> + %406 = vector.extract_strided_slice %246 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %407 = arith.addi %368, %99 overflow : index + vector.store %406, %297[%407] : memref>, vector<1xf32> + %408 = vector.extract_strided_slice %246 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %409 = arith.addi %372, %99 overflow : index + vector.store %408, %297[%409] : memref>, vector<1xf32> + %410 = vector.extract_strided_slice %246 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %411 = arith.addi %376, %99 overflow : index + vector.store %410, %297[%411] : memref>, vector<1xf32> + %412 = vector.extract_strided_slice %246 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %413 = arith.addi %380, %99 overflow : index + vector.store %412, %297[%413] : memref>, vector<1xf32> + %414 = vector.extract_strided_slice %248 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %415 = arith.addi %368, %102 overflow : index + vector.store %414, %297[%415] : memref>, vector<1xf32> + %416 = vector.extract_strided_slice %248 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %417 = arith.addi %372, %102 overflow : index + vector.store %416, %297[%417] : memref>, vector<1xf32> + %418 = vector.extract_strided_slice %248 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %419 = arith.addi %376, %102 overflow : index + vector.store %418, %297[%419] : memref>, vector<1xf32> + %420 = vector.extract_strided_slice %248 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %421 = arith.addi %380, %102 overflow : 
index + vector.store %420, %297[%421] : memref>, vector<1xf32> + %422 = vector.extract_strided_slice %250 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %423 = arith.addi %368, %105 overflow : index + vector.store %422, %297[%423] : memref>, vector<1xf32> + %424 = vector.extract_strided_slice %250 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %425 = arith.addi %372, %105 overflow : index + vector.store %424, %297[%425] : memref>, vector<1xf32> + %426 = vector.extract_strided_slice %250 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %427 = arith.addi %376, %105 overflow : index + vector.store %426, %297[%427] : memref>, vector<1xf32> + %428 = vector.extract_strided_slice %250 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %429 = arith.addi %380, %105 overflow : index + vector.store %428, %297[%429] : memref>, vector<1xf32> + %430 = vector.extract_strided_slice %252 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %431 = arith.addi %368, %108 overflow : index + vector.store %430, %297[%431] : memref>, vector<1xf32> + %432 = vector.extract_strided_slice %252 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %433 = arith.addi %372, %108 overflow : index + vector.store %432, %297[%433] : memref>, vector<1xf32> + %434 = vector.extract_strided_slice %252 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %435 = arith.addi %376, %108 overflow : index + vector.store %434, %297[%435] : memref>, vector<1xf32> + %436 = vector.extract_strided_slice %252 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %437 = arith.addi %380, %108 overflow : index + vector.store %436, %297[%437] : memref>, vector<1xf32> + %438 = vector.extract_strided_slice %256 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %439 = 
affine.apply #map44()[%thread_id_x] + %440 = arith.muli %439, %c57344 overflow : index + %441 = arith.addi %440, %84 overflow : index + vector.store %438, %297[%441] : memref>, vector<1xf32> + %442 = vector.extract_strided_slice %256 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %443 = affine.apply #map45()[%thread_id_x] + %444 = arith.muli %443, %c57344 overflow : index + %445 = arith.addi %444, %84 overflow : index + vector.store %442, %297[%445] : memref>, vector<1xf32> + %446 = vector.extract_strided_slice %256 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %447 = affine.apply #map46()[%thread_id_x] + %448 = arith.muli %447, %c57344 overflow : index + %449 = arith.addi %448, %84 overflow : index + vector.store %446, %297[%449] : memref>, vector<1xf32> + %450 = vector.extract_strided_slice %256 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %451 = affine.apply #map47()[%thread_id_x] + %452 = arith.muli %451, %c57344 overflow : index + %453 = arith.addi %452, %84 overflow : index + vector.store %450, %297[%453] : memref>, vector<1xf32> + %454 = vector.extract_strided_slice %258 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %455 = arith.addi %440, %90 overflow : index + vector.store %454, %297[%455] : memref>, vector<1xf32> + %456 = vector.extract_strided_slice %258 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %457 = arith.addi %444, %90 overflow : index + vector.store %456, %297[%457] : memref>, vector<1xf32> + %458 = vector.extract_strided_slice %258 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %459 = arith.addi %448, %90 overflow : index + vector.store %458, %297[%459] : memref>, vector<1xf32> + %460 = vector.extract_strided_slice %258 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %461 = arith.addi %452, %90 overflow : index + 
vector.store %460, %297[%461] : memref>, vector<1xf32> + %462 = vector.extract_strided_slice %260 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %463 = arith.addi %440, %93 overflow : index + vector.store %462, %297[%463] : memref>, vector<1xf32> + %464 = vector.extract_strided_slice %260 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %465 = arith.addi %444, %93 overflow : index + vector.store %464, %297[%465] : memref>, vector<1xf32> + %466 = vector.extract_strided_slice %260 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %467 = arith.addi %448, %93 overflow : index + vector.store %466, %297[%467] : memref>, vector<1xf32> + %468 = vector.extract_strided_slice %260 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %469 = arith.addi %452, %93 overflow : index + vector.store %468, %297[%469] : memref>, vector<1xf32> + %470 = vector.extract_strided_slice %262 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %471 = arith.addi %440, %96 overflow : index + vector.store %470, %297[%471] : memref>, vector<1xf32> + %472 = vector.extract_strided_slice %262 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %473 = arith.addi %444, %96 overflow : index + vector.store %472, %297[%473] : memref>, vector<1xf32> + %474 = vector.extract_strided_slice %262 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %475 = arith.addi %448, %96 overflow : index + vector.store %474, %297[%475] : memref>, vector<1xf32> + %476 = vector.extract_strided_slice %262 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %477 = arith.addi %452, %96 overflow : index + vector.store %476, %297[%477] : memref>, vector<1xf32> + %478 = vector.extract_strided_slice %264 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %479 = arith.addi %440, %99 
overflow : index + vector.store %478, %297[%479] : memref>, vector<1xf32> + %480 = vector.extract_strided_slice %264 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %481 = arith.addi %444, %99 overflow : index + vector.store %480, %297[%481] : memref>, vector<1xf32> + %482 = vector.extract_strided_slice %264 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %483 = arith.addi %448, %99 overflow : index + vector.store %482, %297[%483] : memref>, vector<1xf32> + %484 = vector.extract_strided_slice %264 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %485 = arith.addi %452, %99 overflow : index + vector.store %484, %297[%485] : memref>, vector<1xf32> + %486 = vector.extract_strided_slice %266 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %487 = arith.addi %440, %102 overflow : index + vector.store %486, %297[%487] : memref>, vector<1xf32> + %488 = vector.extract_strided_slice %266 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %489 = arith.addi %444, %102 overflow : index + vector.store %488, %297[%489] : memref>, vector<1xf32> + %490 = vector.extract_strided_slice %266 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %491 = arith.addi %448, %102 overflow : index + vector.store %490, %297[%491] : memref>, vector<1xf32> + %492 = vector.extract_strided_slice %266 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %493 = arith.addi %452, %102 overflow : index + vector.store %492, %297[%493] : memref>, vector<1xf32> + %494 = vector.extract_strided_slice %268 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %495 = arith.addi %440, %105 overflow : index + vector.store %494, %297[%495] : memref>, vector<1xf32> + %496 = vector.extract_strided_slice %268 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %497 = 
arith.addi %444, %105 overflow : index + vector.store %496, %297[%497] : memref>, vector<1xf32> + %498 = vector.extract_strided_slice %268 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %499 = arith.addi %448, %105 overflow : index + vector.store %498, %297[%499] : memref>, vector<1xf32> + %500 = vector.extract_strided_slice %268 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %501 = arith.addi %452, %105 overflow : index + vector.store %500, %297[%501] : memref>, vector<1xf32> + %502 = vector.extract_strided_slice %270 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %503 = arith.addi %440, %108 overflow : index + vector.store %502, %297[%503] : memref>, vector<1xf32> + %504 = vector.extract_strided_slice %270 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %505 = arith.addi %444, %108 overflow : index + vector.store %504, %297[%505] : memref>, vector<1xf32> + %506 = vector.extract_strided_slice %270 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %507 = arith.addi %448, %108 overflow : index + vector.store %506, %297[%507] : memref>, vector<1xf32> + %508 = vector.extract_strided_slice %270 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %509 = arith.addi %452, %108 overflow : index + vector.store %508, %297[%509] : memref>, vector<1xf32> + %510 = vector.extract_strided_slice %274 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %511 = affine.apply #map48()[%thread_id_x] + %512 = arith.muli %511, %c57344 overflow : index + %513 = arith.addi %512, %84 overflow : index + vector.store %510, %297[%513] : memref>, vector<1xf32> + %514 = vector.extract_strided_slice %274 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %515 = affine.apply #map49()[%thread_id_x] + %516 = arith.muli %515, %c57344 overflow : index + %517 = 
arith.addi %516, %84 overflow : index + vector.store %514, %297[%517] : memref>, vector<1xf32> + %518 = vector.extract_strided_slice %274 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %519 = affine.apply #map50()[%thread_id_x] + %520 = arith.muli %519, %c57344 overflow : index + %521 = arith.addi %520, %84 overflow : index + vector.store %518, %297[%521] : memref>, vector<1xf32> + %522 = vector.extract_strided_slice %274 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %523 = affine.apply #map51()[%thread_id_x] + %524 = arith.muli %523, %c57344 overflow : index + %525 = arith.addi %524, %84 overflow : index + vector.store %522, %297[%525] : memref>, vector<1xf32> + %526 = vector.extract_strided_slice %276 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %527 = arith.addi %512, %90 overflow : index + vector.store %526, %297[%527] : memref>, vector<1xf32> + %528 = vector.extract_strided_slice %276 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %529 = arith.addi %516, %90 overflow : index + vector.store %528, %297[%529] : memref>, vector<1xf32> + %530 = vector.extract_strided_slice %276 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %531 = arith.addi %520, %90 overflow : index + vector.store %530, %297[%531] : memref>, vector<1xf32> + %532 = vector.extract_strided_slice %276 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %533 = arith.addi %524, %90 overflow : index + vector.store %532, %297[%533] : memref>, vector<1xf32> + %534 = vector.extract_strided_slice %278 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %535 = arith.addi %512, %93 overflow : index + vector.store %534, %297[%535] : memref>, vector<1xf32> + %536 = vector.extract_strided_slice %278 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %537 = arith.addi 
%516, %93 overflow : index + vector.store %536, %297[%537] : memref>, vector<1xf32> + %538 = vector.extract_strided_slice %278 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %539 = arith.addi %520, %93 overflow : index + vector.store %538, %297[%539] : memref>, vector<1xf32> + %540 = vector.extract_strided_slice %278 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %541 = arith.addi %524, %93 overflow : index + vector.store %540, %297[%541] : memref>, vector<1xf32> + %542 = vector.extract_strided_slice %280 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %543 = arith.addi %512, %96 overflow : index + vector.store %542, %297[%543] : memref>, vector<1xf32> + %544 = vector.extract_strided_slice %280 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %545 = arith.addi %516, %96 overflow : index + vector.store %544, %297[%545] : memref>, vector<1xf32> + %546 = vector.extract_strided_slice %280 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %547 = arith.addi %520, %96 overflow : index + vector.store %546, %297[%547] : memref>, vector<1xf32> + %548 = vector.extract_strided_slice %280 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %549 = arith.addi %524, %96 overflow : index + vector.store %548, %297[%549] : memref>, vector<1xf32> + %550 = vector.extract_strided_slice %282 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %551 = arith.addi %512, %99 overflow : index + vector.store %550, %297[%551] : memref>, vector<1xf32> + %552 = vector.extract_strided_slice %282 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %553 = arith.addi %516, %99 overflow : index + vector.store %552, %297[%553] : memref>, vector<1xf32> + %554 = vector.extract_strided_slice %282 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + 
%555 = arith.addi %520, %99 overflow : index + vector.store %554, %297[%555] : memref>, vector<1xf32> + %556 = vector.extract_strided_slice %282 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %557 = arith.addi %524, %99 overflow : index + vector.store %556, %297[%557] : memref>, vector<1xf32> + %558 = vector.extract_strided_slice %284 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %559 = arith.addi %512, %102 overflow : index + vector.store %558, %297[%559] : memref>, vector<1xf32> + %560 = vector.extract_strided_slice %284 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %561 = arith.addi %516, %102 overflow : index + vector.store %560, %297[%561] : memref>, vector<1xf32> + %562 = vector.extract_strided_slice %284 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %563 = arith.addi %520, %102 overflow : index + vector.store %562, %297[%563] : memref>, vector<1xf32> + %564 = vector.extract_strided_slice %284 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %565 = arith.addi %524, %102 overflow : index + vector.store %564, %297[%565] : memref>, vector<1xf32> + %566 = vector.extract_strided_slice %286 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %567 = arith.addi %512, %105 overflow : index + vector.store %566, %297[%567] : memref>, vector<1xf32> + %568 = vector.extract_strided_slice %286 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %569 = arith.addi %516, %105 overflow : index + vector.store %568, %297[%569] : memref>, vector<1xf32> + %570 = vector.extract_strided_slice %286 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %571 = arith.addi %520, %105 overflow : index + vector.store %570, %297[%571] : memref>, vector<1xf32> + %572 = vector.extract_strided_slice %286 {offsets = [3], sizes = [1], strides = [1]} : 
vector<4xf32> to vector<1xf32> + %573 = arith.addi %524, %105 overflow : index + vector.store %572, %297[%573] : memref>, vector<1xf32> + %574 = vector.extract_strided_slice %288 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %575 = arith.addi %512, %108 overflow : index + vector.store %574, %297[%575] : memref>, vector<1xf32> + %576 = vector.extract_strided_slice %288 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %577 = arith.addi %516, %108 overflow : index + vector.store %576, %297[%577] : memref>, vector<1xf32> + %578 = vector.extract_strided_slice %288 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %579 = arith.addi %520, %108 overflow : index + vector.store %578, %297[%579] : memref>, vector<1xf32> + %580 = vector.extract_strided_slice %288 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %581 = arith.addi %524, %108 overflow : index + vector.store %580, %297[%581] : memref>, vector<1xf32> + return + } + } + } + func.func @isolated_benchmark$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view, %arg4: !hal.buffer_view, %arg5: !hal.fence, %arg6: !hal.fence) -> !hal.buffer_view { + %0 = hal.tensor.import wait(%arg5) => %arg0 : !hal.buffer_view -> tensor<4096x8192xi8> + %1 = hal.tensor.import wait(%arg5) => %arg1 : !hal.buffer_view -> tensor<4096x512xi8> + %2 = hal.tensor.import wait(%arg5) => %arg2 : !hal.buffer_view -> tensor<57344x8192xi8> + %3 = hal.tensor.import wait(%arg5) => %arg3 : !hal.buffer_view -> tensor<57344x512xi8> + %4 = hal.tensor.import wait(%arg5) => %arg4 : !hal.buffer_view -> tensor<4096x57344xf32> + %5 = flow.dispatch @gemm::@gemm(%0, %1, %2, %3, %4) : (tensor<4096x8192xi8>, tensor<4096x512xi8>, tensor<57344x8192xi8>, tensor<57344x512xi8>, tensor<4096x57344xf32>) -> %4 + %6 = hal.tensor.barrier join(%5 : tensor<4096x57344xf32>) => %arg6 : !hal.fence + %7 = 
hal.tensor.export %6 : tensor<4096x57344xf32> -> !hal.buffer_view + return %7 : !hal.buffer_view + } + } + """ + + mlir_claude_rescheduled = """ + #map = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8) floordiv 256) * 256)> + #map1 = affine_map<()[s0] -> ((s0 floordiv 8) mod 8)> + #map2 = affine_map<()[s0] -> (s0 mod 8)> + #map3 = affine_map<()[s0] -> (s0 * 16)> + #map4 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64) floordiv 32) * 256)> + #map5 = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 64) floordiv 256) * 256 + 64)> + #map6 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64 + 8) floordiv 32) * 256 + 64)> + #map7 = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 128) floordiv 256) * 256 + 128)> + #map8 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64 + 16) floordiv 32) * 256 + 128)> + #map9 = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 192) floordiv 256) * 256 + 192)> + #map10 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64 + 24) floordiv 32) * 256 + 192)> + #map11 = affine_map<()[s0, s1, s2] -> (s1 * 128 + s2 * 256 + s0 floordiv 2 - ((s1 * 128 + s0 floordiv 2) floordiv 256) * 256)> + #map12 = affine_map<()[s0] -> ((s0 floordiv 2) mod 2)> + #map13 = affine_map<()[s0] -> (s0 mod 2)> + #map14 = affine_map<()[s0] -> (s0 * 4)> + #map15 = affine_map<()[s0, s1] -> (s1 * 128 + (s0 floordiv 64) * 32 - ((s1 * 4 + s0 floordiv 64) floordiv 8) * 256)> + #map16 = affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 64)> + #map17 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64)> + #map18 = affine_map<()[s0] -> ((s0 mod 64) floordiv 16)> + #map19 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64 + 16)> + #map20 = 
affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64 + 32)> + #map21 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64 + 48)> + #map22 = affine_map<()[s0] -> (s0 * 4 + (s0 mod 64) floordiv 16 - (s0 floordiv 2) * 8)> + #map23 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16)> + #map24 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 16)> + #map25 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 32)> + #map26 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 48)> + #map27 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 64)> + #map28 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 80)> + #map29 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 96)> + #map30 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 112)> + #map31 = affine_map<()[s0] -> ((s0 mod 64) floordiv 16 + 4)> + #map32 = affine_map<()[s0, s1] -> (s1 * 4 + (s0 mod 64) floordiv 16)> + #map33 = affine_map<()[s0, s1] -> (s0 * 128 + s1 * 16 + 128)> + #map34 = affine_map<()[s0, s1] -> (s0 * 8 + s1 * 4 + 8)> + #map35 = affine_map<()[s0] -> (s0 * 256)> + #map36 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4)> + #map37 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 1)> + #map38 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 2)> + #map39 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 3)> + #map40 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 16)> + #map41 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 17)> + #map42 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 18)> + #map43 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 19)> + #map44 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) 
floordiv 16) * 4 + 32)> + #map45 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 33)> + #map46 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 34)> + #map47 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 35)> + #map48 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 48)> + #map49 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 49)> + #map50 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 50)> + #map51 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 51)> + #translation = #iree_codegen.translation_info + module attributes {transform.with_named_sequence} { + stream.executable private @gemm { + stream.executable.export public @gemm workgroups() -> (index, index, index) { + %c16 = arith.constant 16 : index + %c224 = arith.constant 224 : index + %c1 = arith.constant 1 : index + stream.return %c16, %c224, %c1 : index, index, index + } + builtin.module { + func.func @gemm(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: !stream.binding, %arg3: !stream.binding, %arg4: !stream.binding) attributes {translation_info = #translation} { + %c4_i32 = arith.constant 4 : i32 + %c512_i14 = arith.constant 512 : i14 + %c-8192_i14 = arith.constant -8192 : i14 + %c2147483643_i64 = arith.constant 2147483643 : i64 + %c57344 = arith.constant 57344 : index + %c63 = arith.constant 63 : index + %c512 = arith.constant 512 : index + %c2147483646_i64 = arith.constant 2147483646 : i64 + %c8192 = arith.constant 8192 : index + %c1 = arith.constant 1 : index + %cst = arith.constant dense<0.000000e+00> : vector<4xf32> + %c0 = arith.constant 0 : index + %0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> memref + %1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> memref + %2 = stream.binding.subspan %arg2[%c0] : !stream.binding -> memref + %3 = 
stream.binding.subspan %arg3[%c0] : !stream.binding -> memref + %4 = stream.binding.subspan %arg4[%c0] : !stream.binding -> memref + %block_id_x = gpu.block_id x upper_bound 16 + %block_id_y = gpu.block_id y upper_bound 224 + %thread_id_x = gpu.thread_id x upper_bound 256 + %thread_id_y = gpu.thread_id y upper_bound 2 + %alloc = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_0 = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_1 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %alloc_2 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %alloc_3 = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_4 = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_5 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %alloc_6 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %5 = affine.apply #map()[%thread_id_x, %thread_id_y, %block_id_x] + %6 = affine.apply #map1()[%thread_id_x] + %7 = affine.apply #map2()[%thread_id_x] + %8 = arith.xori %7, %6 : index + %9 = affine.apply #map3()[%8] + %10 = affine.apply #map4()[%thread_id_x, %thread_id_y] + %11 = gpu.subgroup_broadcast %10, first_active_lane : index + %12 = gpu.subgroup_broadcast %c0, first_active_lane : index + %13 = arith.muli %5, %c8192 overflow : index + %14 = arith.addi %13, %9 overflow : index + %reinterpret_cast = memref.reinterpret_cast %0 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast = memref.cast %reinterpret_cast : memref<2147483646xi8, strided<[1]>> to memref> + %15 = amdgpu.fat_raw_buffer_cast %cast validBytes(%c2147483646_i64) cacheSwizzleStride(%c-8192_i14) resetOffset : memref> to memref> + %16 = affine.apply #map5()[%thread_id_x, %thread_id_y, %block_id_x] + %17 = affine.apply #map6()[%thread_id_x, %thread_id_y] + %18 = gpu.subgroup_broadcast %17, first_active_lane : index + %19 = arith.muli %16, %c8192 overflow : index + %20 = arith.addi %19, %9 overflow : 
index + %21 = affine.apply #map7()[%thread_id_x, %thread_id_y, %block_id_x] + %22 = affine.apply #map8()[%thread_id_x, %thread_id_y] + %23 = gpu.subgroup_broadcast %22, first_active_lane : index + %24 = arith.muli %21, %c8192 overflow : index + %25 = arith.addi %24, %9 overflow : index + %26 = affine.apply #map9()[%thread_id_x, %thread_id_y, %block_id_x] + %27 = affine.apply #map10()[%thread_id_x, %thread_id_y] + %28 = gpu.subgroup_broadcast %27, first_active_lane : index + %29 = arith.muli %26, %c8192 overflow : index + %30 = arith.addi %29, %9 overflow : index + %31 = affine.apply #map11()[%thread_id_x, %thread_id_y, %block_id_x] + %32 = affine.apply #map12()[%thread_id_x] + %33 = affine.apply #map13()[%thread_id_x] + %34 = arith.xori %33, %32 : index + %35 = affine.apply #map14()[%34] + %36 = affine.apply #map15()[%thread_id_x, %thread_id_y] + %37 = gpu.subgroup_broadcast %36, first_active_lane : index + %38 = arith.muli %31, %c512 overflow : index + %39 = arith.addi %38, %35 overflow : index + %reinterpret_cast_7 = memref.reinterpret_cast %1 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast_8 = memref.cast %reinterpret_cast_7 : memref<2147483646xi8, strided<[1]>> to memref> + %40 = amdgpu.fat_raw_buffer_cast %cast_8 validBytes(%c2147483646_i64) cacheSwizzleStride(%c512_i14) resetOffset : memref> to memref> + %41 = affine.apply #map()[%thread_id_x, %thread_id_y, %block_id_y] + %42 = arith.muli %41, %c8192 overflow : index + %43 = arith.addi %42, %9 overflow : index + %reinterpret_cast_9 = memref.reinterpret_cast %2 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast_10 = memref.cast %reinterpret_cast_9 : memref<2147483646xi8, strided<[1]>> to memref> + %44 = amdgpu.fat_raw_buffer_cast %cast_10 validBytes(%c2147483646_i64) cacheSwizzleStride(%c-8192_i14) resetOffset : memref> to memref> + %45 = affine.apply #map5()[%thread_id_x, %thread_id_y, %block_id_y] 
+ %46 = arith.muli %45, %c8192 overflow : index + %47 = arith.addi %46, %9 overflow : index + %48 = affine.apply #map7()[%thread_id_x, %thread_id_y, %block_id_y] + %49 = arith.muli %48, %c8192 overflow : index + %50 = arith.addi %49, %9 overflow : index + %51 = affine.apply #map9()[%thread_id_x, %thread_id_y, %block_id_y] + %52 = arith.muli %51, %c8192 overflow : index + %53 = arith.addi %52, %9 overflow : index + %54 = affine.apply #map11()[%thread_id_x, %thread_id_y, %block_id_y] + %55 = arith.muli %54, %c512 overflow : index + %56 = arith.addi %55, %35 overflow : index + %reinterpret_cast_11 = memref.reinterpret_cast %3 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast_12 = memref.cast %reinterpret_cast_11 : memref<2147483646xi8, strided<[1]>> to memref> + %57 = amdgpu.fat_raw_buffer_cast %cast_12 validBytes(%c2147483646_i64) cacheSwizzleStride(%c512_i14) resetOffset : memref> to memref> + amdgpu.gather_to_lds %15[%14], %alloc_6[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %15[%20], %alloc_6[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %15[%25], %alloc_6[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %15[%30], %alloc_6[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %40[%39], %alloc_4[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + amdgpu.gather_to_lds %44[%43], %alloc_2[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %44[%47], %alloc_2[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %44[%50], %alloc_2[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %44[%53], %alloc_2[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, 
#gpu.address_space> + amdgpu.gather_to_lds %57[%56], %alloc_0[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + rocdl.s.barrier + %58 = affine.apply #map16()[%thread_id_x, %thread_id_y] + %59 = arith.index_cast %58 : index to i32 + %60 = arith.cmpi sge, %59, %c4_i32 : i32 + %61 = arith.cmpi slt, %59, %c4_i32 : i32 + scf.if %60 { + rocdl.s.barrier + } + %62 = affine.apply #map17()[%thread_id_x] + %63 = affine.apply #map18()[%thread_id_x] + %64 = arith.xori %63, %7 : index + %65 = affine.apply #map3()[%64] + %66 = affine.apply #map19()[%thread_id_x] + %67 = affine.apply #map20()[%thread_id_x] + %68 = affine.apply #map21()[%thread_id_x] + %69 = affine.apply #map22()[%thread_id_x] + %70 = affine.apply #map23()[%thread_id_x, %thread_id_y] + %71 = affine.apply #map24()[%thread_id_x, %thread_id_y] + %72 = affine.apply #map25()[%thread_id_x, %thread_id_y] + %73 = affine.apply #map26()[%thread_id_x, %thread_id_y] + %74 = affine.apply #map27()[%thread_id_x, %thread_id_y] + %75 = affine.apply #map28()[%thread_id_x, %thread_id_y] + %76 = affine.apply #map29()[%thread_id_x, %thread_id_y] + %77 = affine.apply #map30()[%thread_id_x, %thread_id_y] + %78 = affine.apply #map31()[%thread_id_x] + %79 = arith.xori %78, %7 : index + %80 = affine.apply #map3()[%79] + %81 = arith.xori %33, %c1 : index + %82 = affine.apply #map32()[%thread_id_x, %81] + %83:40 = scf.for %arg5 = %c0 to %c63 step %c1 iter_args(%arg6 = %cst, %arg7 = %cst, %arg8 = %cst, %arg9 = %cst, %arg10 = %cst, %arg11 = %cst, %arg12 = %cst, %arg13 = %cst, %arg14 = %cst, %arg15 = %cst, %arg16 = %cst, %arg17 = %cst, %arg18 = %cst, %arg19 = %cst, %arg20 = %cst, %arg21 = %cst, %arg22 = %cst, %arg23 = %cst, %arg24 = %cst, %arg25 = %cst, %arg26 = %cst, %arg27 = %cst, %arg28 = %cst, %arg29 = %cst, %arg30 = %cst, %arg31 = %cst, %arg32 = %cst, %arg33 = %cst, %arg34 = %cst, %arg35 = %cst, %arg36 = %cst, %arg37 = %cst, %arg38 = %alloc_6, %arg39 = %alloc_5, %arg40 = %alloc_4, %arg41 = %alloc_3, %arg42 = 
%alloc_2, %arg43 = %alloc_1, %arg44 = %alloc_0, %arg45 = %alloc) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>) { + rocdl.sched.barrier 0 + rocdl.s.barrier + %582 = affine.apply #map33()[%arg5, %8] + %583 = arith.addi %13, %582 overflow : index + %584 = arith.addi %19, %582 overflow : index + %585 = arith.addi %24, %582 overflow : index + %586 = arith.addi %29, %582 overflow : index + %587 = affine.apply #map34()[%arg5, %34] + %588 = arith.addi %38, %587 overflow : index + %589 = arith.addi %42, %582 overflow : index + %590 = arith.addi %46, %582 overflow : index + %591 = arith.addi %49, %582 overflow : index + %592 = arith.addi %52, %582 overflow : index + %593 = arith.addi %55, %587 overflow : index + amdgpu.gather_to_lds %15[%583], %arg39[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %15[%584], %arg39[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %15[%585], %arg39[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %15[%586], %arg39[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %40[%588], %arg41[%37, %12] : vector<4xi8>, memref>, 
memref<256x8xi8, #gpu.address_space> + amdgpu.gather_to_lds %44[%589], %arg43[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %44[%590], %arg43[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %44[%591], %arg43[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %44[%592], %arg43[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + amdgpu.gather_to_lds %57[%593], %arg45[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + rocdl.sched.barrier 0 + amdgpu.memory_counter_wait load(10) + %594 = vector.load %arg38[%62, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %595 = vector.load %arg38[%66, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %596 = vector.load %arg38[%67, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %597 = vector.load %arg38[%68, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %598 = vector.load %arg40[%62, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %599 = vector.load %arg40[%66, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %600 = vector.load %arg40[%67, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %601 = vector.load %arg40[%68, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %602 = vector.load %arg42[%70, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %603 = vector.load %arg42[%71, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %604 = vector.load %arg42[%72, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %605 = vector.load %arg42[%73, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %606 = vector.load %arg42[%74, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %607 = vector.load %arg42[%75, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %608 = 
vector.load %arg42[%76, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %609 = vector.load %arg42[%77, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %610 = vector.load %arg44[%70, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %611 = vector.load %arg44[%71, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %612 = vector.load %arg44[%72, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %613 = vector.load %arg44[%73, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %614 = vector.load %arg44[%74, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %615 = vector.load %arg44[%75, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %616 = vector.load %arg44[%76, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %617 = vector.load %arg44[%77, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %618 = vector.bitcast %594 : vector<16xi8> to vector<32xf4E2M1FN> + %619 = vector.bitcast %595 : vector<16xi8> to vector<32xf4E2M1FN> + %620 = vector.bitcast %596 : vector<16xi8> to vector<32xf4E2M1FN> + %621 = vector.bitcast %597 : vector<16xi8> to vector<32xf4E2M1FN> + %622 = vector.bitcast %598 : vector<1xi8> to vector<1xf8E8M0FNU> + %623 = vector.bitcast %599 : vector<1xi8> to vector<1xf8E8M0FNU> + %624 = vector.bitcast %600 : vector<1xi8> to vector<1xf8E8M0FNU> + %625 = vector.bitcast %601 : vector<1xi8> to vector<1xf8E8M0FNU> + %626 = vector.bitcast %602 : vector<16xi8> to vector<32xf4E2M1FN> + %627 = vector.bitcast %603 : vector<16xi8> to vector<32xf4E2M1FN> + %628 = vector.bitcast %604 : vector<16xi8> to vector<32xf4E2M1FN> + %629 = vector.bitcast %605 : vector<16xi8> to vector<32xf4E2M1FN> + %630 = vector.bitcast %606 : vector<16xi8> to vector<32xf4E2M1FN> + %631 = vector.bitcast %607 : vector<16xi8> to vector<32xf4E2M1FN> + %632 = vector.bitcast %608 : vector<16xi8> to vector<32xf4E2M1FN> + %633 = vector.bitcast %609 : vector<16xi8> to vector<32xf4E2M1FN> + %634 = 
vector.bitcast %610 : vector<1xi8> to vector<1xf8E8M0FNU> + %635 = vector.bitcast %611 : vector<1xi8> to vector<1xf8E8M0FNU> + %636 = vector.bitcast %612 : vector<1xi8> to vector<1xf8E8M0FNU> + %637 = vector.bitcast %613 : vector<1xi8> to vector<1xf8E8M0FNU> + %638 = vector.bitcast %614 : vector<1xi8> to vector<1xf8E8M0FNU> + %639 = vector.bitcast %615 : vector<1xi8> to vector<1xf8E8M0FNU> + %640 = vector.bitcast %616 : vector<1xi8> to vector<1xf8E8M0FNU> + %641 = vector.bitcast %617 : vector<1xi8> to vector<1xf8E8M0FNU> + rocdl.sched.barrier 0 + rocdl.s.barrier + rocdl.sched.barrier 0 + rocdl.s.setprio 1 + // --- SAFE MFMAs: M0,M1 x N0,N1,N4,N5 (cluster 0 data only) --- + %642 = vector.extract %622[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %643 = vector.extract %634[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %644 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%643[0] * %626) + %arg6 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %645 = vector.extract %635[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %646 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%645[0] * %627) + %arg7 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %651 = vector.extract %638[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %652 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%651[0] * %630) + %arg10 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %653 = vector.extract %639[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %654 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%653[0] * %631) + %arg11 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %659 = vector.extract %623[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %660 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%643[0] * %626) + %arg14 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %661 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%645[0] * %627) + %arg15 : 
f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %664 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%651[0] * %630) + %arg18 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %665 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%653[0] * %631) + %arg19 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + // --- DEPENDENT MFMAs: M0,M1 x N2,N3,N6,N7 (cluster 1 B data) --- + %647 = vector.extract %636[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %648 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%647[0] * %628) + %arg8 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %649 = vector.extract %637[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %650 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%649[0] * %629) + %arg9 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %655 = vector.extract %640[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %656 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%655[0] * %632) + %arg12 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %657 = vector.extract %641[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %658 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%657[0] * %633) + %arg13 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %662 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%647[0] * %628) + %arg16 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %663 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%649[0] * %629) + %arg17 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %666 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%655[0] * %632) + %arg20 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %667 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%657[0] * %633) + %arg21 : f8E8M0FNU, 
vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + // --- DEPENDENT MFMAs: M2 x all N (cluster 1 A data) --- + %668 = vector.extract %624[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %669 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%643[0] * %626) + %arg22 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %670 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%645[0] * %627) + %arg23 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %671 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%647[0] * %628) + %arg24 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %672 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%649[0] * %629) + %arg25 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %673 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%651[0] * %630) + %arg26 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %674 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%653[0] * %631) + %arg27 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %675 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%655[0] * %632) + %arg28 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %676 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%657[0] * %633) + %arg29 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + // --- DEPENDENT MFMAs: M3 x all N (cluster 1 A data) --- + %677 = vector.extract %625[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %678 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%643[0] * %626) + %arg30 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %679 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%645[0] * %627) + %arg31 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %680 = amdgpu.scaled_mfma 16x16x128 
(%677[0] * %621) * (%647[0] * %628) + %arg32 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %681 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%649[0] * %629) + %arg33 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %682 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%651[0] * %630) + %arg34 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %683 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%653[0] * %631) + %arg35 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %684 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%655[0] * %632) + %arg36 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %685 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%657[0] * %633) + %arg37 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + rocdl.s.setprio 0 + rocdl.sched.barrier 0 + rocdl.s.barrier + rocdl.sched.barrier 0 + rocdl.sched.barrier 0 + %686 = vector.load %arg38[%62, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %687 = vector.load %arg38[%66, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %688 = vector.load %arg38[%67, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %689 = vector.load %arg38[%68, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %690 = vector.load %arg40[%62, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %691 = vector.load %arg40[%66, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %692 = vector.load %arg40[%67, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %693 = vector.load %arg40[%68, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %694 = vector.load %arg42[%70, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %695 = vector.load %arg42[%71, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %696 = vector.load 
%arg42[%72, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %697 = vector.load %arg42[%73, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %698 = vector.load %arg42[%74, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %699 = vector.load %arg42[%75, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %700 = vector.load %arg42[%76, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %701 = vector.load %arg42[%77, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %702 = vector.load %arg44[%70, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %703 = vector.load %arg44[%71, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %704 = vector.load %arg44[%72, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %705 = vector.load %arg44[%73, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %706 = vector.load %arg44[%74, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %707 = vector.load %arg44[%75, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %708 = vector.load %arg44[%76, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %709 = vector.load %arg44[%77, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %710 = vector.bitcast %686 : vector<16xi8> to vector<32xf4E2M1FN> + %711 = vector.bitcast %687 : vector<16xi8> to vector<32xf4E2M1FN> + %712 = vector.bitcast %688 : vector<16xi8> to vector<32xf4E2M1FN> + %713 = vector.bitcast %689 : vector<16xi8> to vector<32xf4E2M1FN> + %714 = vector.bitcast %690 : vector<1xi8> to vector<1xf8E8M0FNU> + %715 = vector.bitcast %691 : vector<1xi8> to vector<1xf8E8M0FNU> + %716 = vector.bitcast %692 : vector<1xi8> to vector<1xf8E8M0FNU> + %717 = vector.bitcast %693 : vector<1xi8> to vector<1xf8E8M0FNU> + %718 = vector.bitcast %694 : vector<16xi8> to vector<32xf4E2M1FN> + %719 = vector.bitcast %695 : vector<16xi8> to vector<32xf4E2M1FN> + %720 = vector.bitcast %696 : vector<16xi8> to 
vector<32xf4E2M1FN> + %721 = vector.bitcast %697 : vector<16xi8> to vector<32xf4E2M1FN> + %722 = vector.bitcast %698 : vector<16xi8> to vector<32xf4E2M1FN> + %723 = vector.bitcast %699 : vector<16xi8> to vector<32xf4E2M1FN> + %724 = vector.bitcast %700 : vector<16xi8> to vector<32xf4E2M1FN> + %725 = vector.bitcast %701 : vector<16xi8> to vector<32xf4E2M1FN> + %726 = vector.bitcast %702 : vector<1xi8> to vector<1xf8E8M0FNU> + %727 = vector.bitcast %703 : vector<1xi8> to vector<1xf8E8M0FNU> + %728 = vector.bitcast %704 : vector<1xi8> to vector<1xf8E8M0FNU> + %729 = vector.bitcast %705 : vector<1xi8> to vector<1xf8E8M0FNU> + %730 = vector.bitcast %706 : vector<1xi8> to vector<1xf8E8M0FNU> + %731 = vector.bitcast %707 : vector<1xi8> to vector<1xf8E8M0FNU> + %732 = vector.bitcast %708 : vector<1xi8> to vector<1xf8E8M0FNU> + %733 = vector.bitcast %709 : vector<1xi8> to vector<1xf8E8M0FNU> + rocdl.sched.barrier 0 + rocdl.s.barrier + rocdl.sched.barrier 0 + rocdl.s.setprio 1 + // --- SAFE MFMAs: M0,M1 x N0,N1,N4,N5 (cluster 0 data only) --- + %734 = vector.extract %714[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %735 = vector.extract %726[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %736 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%735[0] * %718) + %644 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %737 = vector.extract %727[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %738 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%737[0] * %719) + %646 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %743 = vector.extract %730[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %744 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%743[0] * %722) + %652 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %745 = vector.extract %731[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %746 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%745[0] * %723) + %654 : f8E8M0FNU, vector<32xf4E2M1FN>, 
f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %751 = vector.extract %715[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %752 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%735[0] * %718) + %660 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %753 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%737[0] * %719) + %661 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %756 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%743[0] * %722) + %664 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %757 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%745[0] * %723) + %665 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + // --- DEPENDENT MFMAs: M0,M1 x N2,N3,N6,N7 (cluster 1 B data) --- + %739 = vector.extract %728[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %740 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%739[0] * %720) + %648 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %741 = vector.extract %729[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %742 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%741[0] * %721) + %650 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %747 = vector.extract %732[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %748 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%747[0] * %724) + %656 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %749 = vector.extract %733[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %750 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%749[0] * %725) + %658 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %754 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%739[0] * %720) + %662 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %755 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%741[0] * %721) + %663 : 
f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %758 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%747[0] * %724) + %666 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %759 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%749[0] * %725) + %667 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + // --- DEPENDENT MFMAs: M2 x all N (cluster 1 A data) --- + %760 = vector.extract %716[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %761 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%735[0] * %718) + %669 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %762 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%737[0] * %719) + %670 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %763 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%739[0] * %720) + %671 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %764 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%741[0] * %721) + %672 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %765 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%743[0] * %722) + %673 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %766 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%745[0] * %723) + %674 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %767 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%747[0] * %724) + %675 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %768 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%749[0] * %725) + %676 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + // --- DEPENDENT MFMAs: M3 x all N (cluster 1 A data) --- + %769 = vector.extract %717[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %770 = amdgpu.scaled_mfma 16x16x128 (%769[0] * 
%713) * (%735[0] * %718) + %678 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %771 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%737[0] * %719) + %679 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %772 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%739[0] * %720) + %680 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %773 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%741[0] * %721) + %681 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %774 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%743[0] * %722) + %682 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %775 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%745[0] * %723) + %683 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %776 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%747[0] * %724) + %684 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %777 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%749[0] * %725) + %685 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + rocdl.s.setprio 0 + rocdl.sched.barrier 0 + scf.yield %736, %738, %740, %742, %744, %746, %748, %750, %752, %753, %754, %755, %756, %757, %758, %759, %761, %762, %763, %764, %765, %766, %767, %768, %770, %771, %772, %773, %774, %775, %776, %777, %arg39, %arg38, %arg41, %arg40, %arg43, %arg42, %arg45, %arg44 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, 
vector<4xf32>, vector<4xf32>, vector<4xf32>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space> + } + scf.if %61 { + rocdl.s.barrier + } + amdgpu.lds_barrier + %84 = affine.apply #map23()[%thread_id_x, %thread_id_y] + %85 = affine.apply #map22()[%thread_id_x] + %86 = vector.load %83#38[%84, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %87 = arith.xori %33, %c1 : index + %88 = affine.apply #map32()[%thread_id_x, %87] + %89 = vector.load %83#38[%84, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %90 = affine.apply #map24()[%thread_id_x, %thread_id_y] + %91 = vector.load %83#38[%90, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %92 = vector.load %83#38[%90, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %93 = affine.apply #map25()[%thread_id_x, %thread_id_y] + %94 = vector.load %83#38[%93, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %95 = vector.load %83#38[%93, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %96 = affine.apply #map26()[%thread_id_x, %thread_id_y] + %97 = vector.load %83#38[%96, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %98 = vector.load %83#38[%96, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %99 = affine.apply #map27()[%thread_id_x, %thread_id_y] + %100 = vector.load %83#38[%99, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %101 = vector.load %83#38[%99, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %102 = affine.apply #map28()[%thread_id_x, %thread_id_y] + %103 = vector.load %83#38[%102, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %104 = vector.load %83#38[%102, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %105 = 
affine.apply #map29()[%thread_id_x, %thread_id_y] + %106 = vector.load %83#38[%105, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %107 = vector.load %83#38[%105, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %108 = affine.apply #map30()[%thread_id_x, %thread_id_y] + %109 = vector.load %83#38[%108, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %110 = vector.load %83#38[%108, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %111 = affine.apply #map18()[%thread_id_x] + %112 = arith.xori %111, %7 : index + %113 = affine.apply #map3()[%112] + %114 = vector.load %83#36[%84, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %115 = affine.apply #map31()[%thread_id_x] + %116 = arith.xori %115, %7 : index + %117 = affine.apply #map3()[%116] + %118 = vector.load %83#36[%84, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %119 = vector.load %83#36[%90, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %120 = vector.load %83#36[%90, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %121 = vector.load %83#36[%93, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %122 = vector.load %83#36[%93, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %123 = vector.load %83#36[%96, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %124 = vector.load %83#36[%96, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %125 = vector.load %83#36[%99, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %126 = vector.load %83#36[%99, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %127 = vector.load %83#36[%102, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %128 = vector.load %83#36[%102, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %129 = vector.load %83#36[%105, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %130 = vector.load %83#36[%105, %117] : 
memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %131 = vector.load %83#36[%108, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %132 = vector.load %83#36[%108, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %133 = affine.apply #map17()[%thread_id_x] + %134 = vector.load %83#34[%133, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %135 = vector.load %83#34[%133, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %136 = affine.apply #map19()[%thread_id_x] + %137 = vector.load %83#34[%136, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %138 = vector.load %83#34[%136, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %139 = affine.apply #map20()[%thread_id_x] + %140 = vector.load %83#34[%139, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %141 = vector.load %83#34[%139, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %142 = affine.apply #map21()[%thread_id_x] + %143 = vector.load %83#34[%142, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %144 = vector.load %83#34[%142, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %145 = vector.load %83#32[%133, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %146 = vector.load %83#32[%133, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %147 = vector.load %83#32[%136, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %148 = vector.load %83#32[%136, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %149 = vector.load %83#32[%139, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %150 = vector.load %83#32[%139, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %151 = vector.load %83#32[%142, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %152 = vector.load %83#32[%142, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %153 = vector.bitcast %145 : vector<16xi8> to vector<32xf4E2M1FN> + 
%154 = vector.bitcast %146 : vector<16xi8> to vector<32xf4E2M1FN> + %155 = vector.bitcast %147 : vector<16xi8> to vector<32xf4E2M1FN> + %156 = vector.bitcast %148 : vector<16xi8> to vector<32xf4E2M1FN> + %157 = vector.bitcast %149 : vector<16xi8> to vector<32xf4E2M1FN> + %158 = vector.bitcast %150 : vector<16xi8> to vector<32xf4E2M1FN> + %159 = vector.bitcast %151 : vector<16xi8> to vector<32xf4E2M1FN> + %160 = vector.bitcast %152 : vector<16xi8> to vector<32xf4E2M1FN> + %161 = vector.bitcast %134 : vector<1xi8> to vector<1xf8E8M0FNU> + %162 = vector.bitcast %135 : vector<1xi8> to vector<1xf8E8M0FNU> + %163 = vector.bitcast %137 : vector<1xi8> to vector<1xf8E8M0FNU> + %164 = vector.bitcast %138 : vector<1xi8> to vector<1xf8E8M0FNU> + %165 = vector.bitcast %140 : vector<1xi8> to vector<1xf8E8M0FNU> + %166 = vector.bitcast %141 : vector<1xi8> to vector<1xf8E8M0FNU> + %167 = vector.bitcast %143 : vector<1xi8> to vector<1xf8E8M0FNU> + %168 = vector.bitcast %144 : vector<1xi8> to vector<1xf8E8M0FNU> + %169 = vector.bitcast %114 : vector<16xi8> to vector<32xf4E2M1FN> + %170 = vector.bitcast %118 : vector<16xi8> to vector<32xf4E2M1FN> + %171 = vector.bitcast %119 : vector<16xi8> to vector<32xf4E2M1FN> + %172 = vector.bitcast %120 : vector<16xi8> to vector<32xf4E2M1FN> + %173 = vector.bitcast %121 : vector<16xi8> to vector<32xf4E2M1FN> + %174 = vector.bitcast %122 : vector<16xi8> to vector<32xf4E2M1FN> + %175 = vector.bitcast %123 : vector<16xi8> to vector<32xf4E2M1FN> + %176 = vector.bitcast %124 : vector<16xi8> to vector<32xf4E2M1FN> + %177 = vector.bitcast %125 : vector<16xi8> to vector<32xf4E2M1FN> + %178 = vector.bitcast %126 : vector<16xi8> to vector<32xf4E2M1FN> + %179 = vector.bitcast %127 : vector<16xi8> to vector<32xf4E2M1FN> + %180 = vector.bitcast %128 : vector<16xi8> to vector<32xf4E2M1FN> + %181 = vector.bitcast %129 : vector<16xi8> to vector<32xf4E2M1FN> + %182 = vector.bitcast %130 : vector<16xi8> to vector<32xf4E2M1FN> + %183 = vector.bitcast %131 : 
vector<16xi8> to vector<32xf4E2M1FN> + %184 = vector.bitcast %132 : vector<16xi8> to vector<32xf4E2M1FN> + %185 = vector.bitcast %86 : vector<1xi8> to vector<1xf8E8M0FNU> + %186 = vector.bitcast %89 : vector<1xi8> to vector<1xf8E8M0FNU> + %187 = vector.bitcast %91 : vector<1xi8> to vector<1xf8E8M0FNU> + %188 = vector.bitcast %92 : vector<1xi8> to vector<1xf8E8M0FNU> + %189 = vector.bitcast %94 : vector<1xi8> to vector<1xf8E8M0FNU> + %190 = vector.bitcast %95 : vector<1xi8> to vector<1xf8E8M0FNU> + %191 = vector.bitcast %97 : vector<1xi8> to vector<1xf8E8M0FNU> + %192 = vector.bitcast %98 : vector<1xi8> to vector<1xf8E8M0FNU> + %193 = vector.bitcast %100 : vector<1xi8> to vector<1xf8E8M0FNU> + %194 = vector.bitcast %101 : vector<1xi8> to vector<1xf8E8M0FNU> + %195 = vector.bitcast %103 : vector<1xi8> to vector<1xf8E8M0FNU> + %196 = vector.bitcast %104 : vector<1xi8> to vector<1xf8E8M0FNU> + %197 = vector.bitcast %106 : vector<1xi8> to vector<1xf8E8M0FNU> + %198 = vector.bitcast %107 : vector<1xi8> to vector<1xf8E8M0FNU> + %199 = vector.bitcast %109 : vector<1xi8> to vector<1xf8E8M0FNU> + %200 = vector.bitcast %110 : vector<1xi8> to vector<1xf8E8M0FNU> + %201 = vector.extract %161[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %202 = vector.extract %185[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %203 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%202[0] * %169) + %83#0 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %204 = vector.extract %162[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %205 = vector.extract %186[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %206 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%205[0] * %170) + %203 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %207 = vector.extract %187[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %208 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%207[0] * %171) + %83#1 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> 
+ %209 = vector.extract %188[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %210 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%209[0] * %172) + %208 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %211 = vector.extract %189[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %212 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%211[0] * %173) + %83#2 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %213 = vector.extract %190[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %214 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%213[0] * %174) + %212 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %215 = vector.extract %191[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %216 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%215[0] * %175) + %83#3 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %217 = vector.extract %192[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %218 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%217[0] * %176) + %216 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %219 = vector.extract %193[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %220 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%219[0] * %177) + %83#4 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %221 = vector.extract %194[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %222 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%221[0] * %178) + %220 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %223 = vector.extract %195[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %224 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%223[0] * %179) + %83#5 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %225 = vector.extract %196[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %226 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%225[0] * %180) + %224 : 
f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %227 = vector.extract %197[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %228 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%227[0] * %181) + %83#6 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %229 = vector.extract %198[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %230 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%229[0] * %182) + %228 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %231 = vector.extract %199[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %232 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%231[0] * %183) + %83#7 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %233 = vector.extract %200[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %234 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%233[0] * %184) + %232 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %235 = vector.extract %163[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %236 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%202[0] * %169) + %83#8 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %237 = vector.extract %164[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %238 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%205[0] * %170) + %236 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %239 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%207[0] * %171) + %83#9 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %240 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%209[0] * %172) + %239 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %241 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%211[0] * %173) + %83#10 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %242 = amdgpu.scaled_mfma 16x16x128 (%237[0] * 
%156) * (%213[0] * %174) + %241 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %243 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%215[0] * %175) + %83#11 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %244 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%217[0] * %176) + %243 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %245 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%219[0] * %177) + %83#12 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %246 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%221[0] * %178) + %245 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %247 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%223[0] * %179) + %83#13 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %248 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%225[0] * %180) + %247 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %249 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%227[0] * %181) + %83#14 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %250 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%229[0] * %182) + %249 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %251 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%231[0] * %183) + %83#15 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %252 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%233[0] * %184) + %251 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %253 = vector.extract %165[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %254 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%202[0] * %169) + %83#16 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %255 = vector.extract %166[0] : 
f8E8M0FNU from vector<1xf8E8M0FNU> + %256 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%205[0] * %170) + %254 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %257 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%207[0] * %171) + %83#17 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %258 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%209[0] * %172) + %257 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %259 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%211[0] * %173) + %83#18 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %260 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%213[0] * %174) + %259 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %261 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%215[0] * %175) + %83#19 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %262 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%217[0] * %176) + %261 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %263 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%219[0] * %177) + %83#20 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %264 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%221[0] * %178) + %263 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %265 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%223[0] * %179) + %83#21 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %266 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%225[0] * %180) + %265 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %267 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%227[0] * %181) + %83#22 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %268 = 
amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%229[0] * %182) + %267 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %269 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%231[0] * %183) + %83#23 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %270 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%233[0] * %184) + %269 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %271 = vector.extract %167[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %272 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%202[0] * %169) + %83#24 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %273 = vector.extract %168[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %274 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%205[0] * %170) + %272 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %275 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%207[0] * %171) + %83#25 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %276 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%209[0] * %172) + %275 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %277 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%211[0] * %173) + %83#26 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %278 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%213[0] * %174) + %277 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %279 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%215[0] * %175) + %83#27 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %280 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%217[0] * %176) + %279 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %281 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%219[0] * %177) + %83#28 : 
f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %282 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%221[0] * %178) + %281 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %283 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%223[0] * %179) + %83#29 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %284 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%225[0] * %180) + %283 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %285 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%227[0] * %181) + %83#30 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %286 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%229[0] * %182) + %285 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %287 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%231[0] * %183) + %83#31 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %288 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%233[0] * %184) + %287 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %289 = vector.extract_strided_slice %206 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %290 = affine.apply #map35()[%block_id_x] + %291 = affine.apply #map35()[%block_id_y] + %292 = affine.apply #map36()[%thread_id_x] + %293 = arith.muli %290, %c57344 overflow : index + %294 = arith.muli %292, %c57344 overflow : index + %295 = arith.addi %293, %291 overflow : index + %296 = arith.addi %294, %84 overflow : index + %reinterpret_cast_13 = memref.reinterpret_cast %4 to offset: [%295], sizes: [536870910], strides: [1] : memref to memref<536870910xf32, strided<[1], offset: ?>> + %cast_14 = memref.cast %reinterpret_cast_13 : memref<536870910xf32, strided<[1], offset: ?>> to memref> + %297 = amdgpu.fat_raw_buffer_cast %cast_14 
validBytes(%c2147483643_i64) resetOffset : memref> to memref> + vector.store %289, %297[%296] : memref>, vector<1xf32> + %298 = vector.extract_strided_slice %206 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %299 = affine.apply #map37()[%thread_id_x] + %300 = arith.muli %299, %c57344 overflow : index + %301 = arith.addi %300, %84 overflow : index + vector.store %298, %297[%301] : memref>, vector<1xf32> + %302 = vector.extract_strided_slice %206 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %303 = affine.apply #map38()[%thread_id_x] + %304 = arith.muli %303, %c57344 overflow : index + %305 = arith.addi %304, %84 overflow : index + vector.store %302, %297[%305] : memref>, vector<1xf32> + %306 = vector.extract_strided_slice %206 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %307 = affine.apply #map39()[%thread_id_x] + %308 = arith.muli %307, %c57344 overflow : index + %309 = arith.addi %308, %84 overflow : index + vector.store %306, %297[%309] : memref>, vector<1xf32> + %310 = vector.extract_strided_slice %210 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %311 = arith.addi %294, %90 overflow : index + vector.store %310, %297[%311] : memref>, vector<1xf32> + %312 = vector.extract_strided_slice %210 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %313 = arith.addi %300, %90 overflow : index + vector.store %312, %297[%313] : memref>, vector<1xf32> + %314 = vector.extract_strided_slice %210 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %315 = arith.addi %304, %90 overflow : index + vector.store %314, %297[%315] : memref>, vector<1xf32> + %316 = vector.extract_strided_slice %210 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %317 = arith.addi %308, %90 overflow : index + vector.store %316, %297[%317] : memref>, vector<1xf32> + %318 = 
vector.extract_strided_slice %214 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %319 = arith.addi %294, %93 overflow : index + vector.store %318, %297[%319] : memref>, vector<1xf32> + %320 = vector.extract_strided_slice %214 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %321 = arith.addi %300, %93 overflow : index + vector.store %320, %297[%321] : memref>, vector<1xf32> + %322 = vector.extract_strided_slice %214 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %323 = arith.addi %304, %93 overflow : index + vector.store %322, %297[%323] : memref>, vector<1xf32> + %324 = vector.extract_strided_slice %214 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %325 = arith.addi %308, %93 overflow : index + vector.store %324, %297[%325] : memref>, vector<1xf32> + %326 = vector.extract_strided_slice %218 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %327 = arith.addi %294, %96 overflow : index + vector.store %326, %297[%327] : memref>, vector<1xf32> + %328 = vector.extract_strided_slice %218 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %329 = arith.addi %300, %96 overflow : index + vector.store %328, %297[%329] : memref>, vector<1xf32> + %330 = vector.extract_strided_slice %218 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %331 = arith.addi %304, %96 overflow : index + vector.store %330, %297[%331] : memref>, vector<1xf32> + %332 = vector.extract_strided_slice %218 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %333 = arith.addi %308, %96 overflow : index + vector.store %332, %297[%333] : memref>, vector<1xf32> + %334 = vector.extract_strided_slice %222 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %335 = arith.addi %294, %99 overflow : index + vector.store %334, %297[%335] : memref>, 
vector<1xf32> + %336 = vector.extract_strided_slice %222 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %337 = arith.addi %300, %99 overflow : index + vector.store %336, %297[%337] : memref>, vector<1xf32> + %338 = vector.extract_strided_slice %222 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %339 = arith.addi %304, %99 overflow : index + vector.store %338, %297[%339] : memref>, vector<1xf32> + %340 = vector.extract_strided_slice %222 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %341 = arith.addi %308, %99 overflow : index + vector.store %340, %297[%341] : memref>, vector<1xf32> + %342 = vector.extract_strided_slice %226 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %343 = arith.addi %294, %102 overflow : index + vector.store %342, %297[%343] : memref>, vector<1xf32> + %344 = vector.extract_strided_slice %226 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %345 = arith.addi %300, %102 overflow : index + vector.store %344, %297[%345] : memref>, vector<1xf32> + %346 = vector.extract_strided_slice %226 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %347 = arith.addi %304, %102 overflow : index + vector.store %346, %297[%347] : memref>, vector<1xf32> + %348 = vector.extract_strided_slice %226 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %349 = arith.addi %308, %102 overflow : index + vector.store %348, %297[%349] : memref>, vector<1xf32> + %350 = vector.extract_strided_slice %230 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %351 = arith.addi %294, %105 overflow : index + vector.store %350, %297[%351] : memref>, vector<1xf32> + %352 = vector.extract_strided_slice %230 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %353 = arith.addi %300, %105 overflow : index + vector.store %352, 
%297[%353] : memref>, vector<1xf32> + %354 = vector.extract_strided_slice %230 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %355 = arith.addi %304, %105 overflow : index + vector.store %354, %297[%355] : memref>, vector<1xf32> + %356 = vector.extract_strided_slice %230 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %357 = arith.addi %308, %105 overflow : index + vector.store %356, %297[%357] : memref>, vector<1xf32> + %358 = vector.extract_strided_slice %234 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %359 = arith.addi %294, %108 overflow : index + vector.store %358, %297[%359] : memref>, vector<1xf32> + %360 = vector.extract_strided_slice %234 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %361 = arith.addi %300, %108 overflow : index + vector.store %360, %297[%361] : memref>, vector<1xf32> + %362 = vector.extract_strided_slice %234 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %363 = arith.addi %304, %108 overflow : index + vector.store %362, %297[%363] : memref>, vector<1xf32> + %364 = vector.extract_strided_slice %234 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %365 = arith.addi %308, %108 overflow : index + vector.store %364, %297[%365] : memref>, vector<1xf32> + %366 = vector.extract_strided_slice %238 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %367 = affine.apply #map40()[%thread_id_x] + %368 = arith.muli %367, %c57344 overflow : index + %369 = arith.addi %368, %84 overflow : index + vector.store %366, %297[%369] : memref>, vector<1xf32> + %370 = vector.extract_strided_slice %238 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %371 = affine.apply #map41()[%thread_id_x] + %372 = arith.muli %371, %c57344 overflow : index + %373 = arith.addi %372, %84 overflow : index + vector.store %370, 
%297[%373] : memref>, vector<1xf32> + %374 = vector.extract_strided_slice %238 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %375 = affine.apply #map42()[%thread_id_x] + %376 = arith.muli %375, %c57344 overflow : index + %377 = arith.addi %376, %84 overflow : index + vector.store %374, %297[%377] : memref>, vector<1xf32> + %378 = vector.extract_strided_slice %238 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %379 = affine.apply #map43()[%thread_id_x] + %380 = arith.muli %379, %c57344 overflow : index + %381 = arith.addi %380, %84 overflow : index + vector.store %378, %297[%381] : memref>, vector<1xf32> + %382 = vector.extract_strided_slice %240 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %383 = arith.addi %368, %90 overflow : index + vector.store %382, %297[%383] : memref>, vector<1xf32> + %384 = vector.extract_strided_slice %240 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %385 = arith.addi %372, %90 overflow : index + vector.store %384, %297[%385] : memref>, vector<1xf32> + %386 = vector.extract_strided_slice %240 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %387 = arith.addi %376, %90 overflow : index + vector.store %386, %297[%387] : memref>, vector<1xf32> + %388 = vector.extract_strided_slice %240 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %389 = arith.addi %380, %90 overflow : index + vector.store %388, %297[%389] : memref>, vector<1xf32> + %390 = vector.extract_strided_slice %242 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %391 = arith.addi %368, %93 overflow : index + vector.store %390, %297[%391] : memref>, vector<1xf32> + %392 = vector.extract_strided_slice %242 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %393 = arith.addi %372, %93 overflow : index + vector.store %392, %297[%393] : 
memref>, vector<1xf32> + %394 = vector.extract_strided_slice %242 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %395 = arith.addi %376, %93 overflow : index + vector.store %394, %297[%395] : memref>, vector<1xf32> + %396 = vector.extract_strided_slice %242 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %397 = arith.addi %380, %93 overflow : index + vector.store %396, %297[%397] : memref>, vector<1xf32> + %398 = vector.extract_strided_slice %244 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %399 = arith.addi %368, %96 overflow : index + vector.store %398, %297[%399] : memref>, vector<1xf32> + %400 = vector.extract_strided_slice %244 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %401 = arith.addi %372, %96 overflow : index + vector.store %400, %297[%401] : memref>, vector<1xf32> + %402 = vector.extract_strided_slice %244 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %403 = arith.addi %376, %96 overflow : index + vector.store %402, %297[%403] : memref>, vector<1xf32> + %404 = vector.extract_strided_slice %244 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %405 = arith.addi %380, %96 overflow : index + vector.store %404, %297[%405] : memref>, vector<1xf32> + %406 = vector.extract_strided_slice %246 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %407 = arith.addi %368, %99 overflow : index + vector.store %406, %297[%407] : memref>, vector<1xf32> + %408 = vector.extract_strided_slice %246 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %409 = arith.addi %372, %99 overflow : index + vector.store %408, %297[%409] : memref>, vector<1xf32> + %410 = vector.extract_strided_slice %246 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %411 = arith.addi %376, %99 overflow : index + vector.store 
%410, %297[%411] : memref>, vector<1xf32> + %412 = vector.extract_strided_slice %246 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %413 = arith.addi %380, %99 overflow : index + vector.store %412, %297[%413] : memref>, vector<1xf32> + %414 = vector.extract_strided_slice %248 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %415 = arith.addi %368, %102 overflow : index + vector.store %414, %297[%415] : memref>, vector<1xf32> + %416 = vector.extract_strided_slice %248 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %417 = arith.addi %372, %102 overflow : index + vector.store %416, %297[%417] : memref>, vector<1xf32> + %418 = vector.extract_strided_slice %248 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %419 = arith.addi %376, %102 overflow : index + vector.store %418, %297[%419] : memref>, vector<1xf32> + %420 = vector.extract_strided_slice %248 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %421 = arith.addi %380, %102 overflow : index + vector.store %420, %297[%421] : memref>, vector<1xf32> + %422 = vector.extract_strided_slice %250 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %423 = arith.addi %368, %105 overflow : index + vector.store %422, %297[%423] : memref>, vector<1xf32> + %424 = vector.extract_strided_slice %250 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %425 = arith.addi %372, %105 overflow : index + vector.store %424, %297[%425] : memref>, vector<1xf32> + %426 = vector.extract_strided_slice %250 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %427 = arith.addi %376, %105 overflow : index + vector.store %426, %297[%427] : memref>, vector<1xf32> + %428 = vector.extract_strided_slice %250 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %429 = arith.addi %380, %105 
overflow : index + vector.store %428, %297[%429] : memref>, vector<1xf32> + %430 = vector.extract_strided_slice %252 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %431 = arith.addi %368, %108 overflow : index + vector.store %430, %297[%431] : memref>, vector<1xf32> + %432 = vector.extract_strided_slice %252 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %433 = arith.addi %372, %108 overflow : index + vector.store %432, %297[%433] : memref>, vector<1xf32> + %434 = vector.extract_strided_slice %252 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %435 = arith.addi %376, %108 overflow : index + vector.store %434, %297[%435] : memref>, vector<1xf32> + %436 = vector.extract_strided_slice %252 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %437 = arith.addi %380, %108 overflow : index + vector.store %436, %297[%437] : memref>, vector<1xf32> + %438 = vector.extract_strided_slice %256 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %439 = affine.apply #map44()[%thread_id_x] + %440 = arith.muli %439, %c57344 overflow : index + %441 = arith.addi %440, %84 overflow : index + vector.store %438, %297[%441] : memref>, vector<1xf32> + %442 = vector.extract_strided_slice %256 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %443 = affine.apply #map45()[%thread_id_x] + %444 = arith.muli %443, %c57344 overflow : index + %445 = arith.addi %444, %84 overflow : index + vector.store %442, %297[%445] : memref>, vector<1xf32> + %446 = vector.extract_strided_slice %256 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %447 = affine.apply #map46()[%thread_id_x] + %448 = arith.muli %447, %c57344 overflow : index + %449 = arith.addi %448, %84 overflow : index + vector.store %446, %297[%449] : memref>, vector<1xf32> + %450 = vector.extract_strided_slice %256 {offsets = [3], sizes 
= [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %451 = affine.apply #map47()[%thread_id_x] + %452 = arith.muli %451, %c57344 overflow : index + %453 = arith.addi %452, %84 overflow : index + vector.store %450, %297[%453] : memref>, vector<1xf32> + %454 = vector.extract_strided_slice %258 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %455 = arith.addi %440, %90 overflow : index + vector.store %454, %297[%455] : memref>, vector<1xf32> + %456 = vector.extract_strided_slice %258 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %457 = arith.addi %444, %90 overflow : index + vector.store %456, %297[%457] : memref>, vector<1xf32> + %458 = vector.extract_strided_slice %258 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %459 = arith.addi %448, %90 overflow : index + vector.store %458, %297[%459] : memref>, vector<1xf32> + %460 = vector.extract_strided_slice %258 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %461 = arith.addi %452, %90 overflow : index + vector.store %460, %297[%461] : memref>, vector<1xf32> + %462 = vector.extract_strided_slice %260 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %463 = arith.addi %440, %93 overflow : index + vector.store %462, %297[%463] : memref>, vector<1xf32> + %464 = vector.extract_strided_slice %260 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %465 = arith.addi %444, %93 overflow : index + vector.store %464, %297[%465] : memref>, vector<1xf32> + %466 = vector.extract_strided_slice %260 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %467 = arith.addi %448, %93 overflow : index + vector.store %466, %297[%467] : memref>, vector<1xf32> + %468 = vector.extract_strided_slice %260 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %469 = arith.addi %452, %93 overflow : index + 
vector.store %468, %297[%469] : memref>, vector<1xf32> + %470 = vector.extract_strided_slice %262 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %471 = arith.addi %440, %96 overflow : index + vector.store %470, %297[%471] : memref>, vector<1xf32> + %472 = vector.extract_strided_slice %262 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %473 = arith.addi %444, %96 overflow : index + vector.store %472, %297[%473] : memref>, vector<1xf32> + %474 = vector.extract_strided_slice %262 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %475 = arith.addi %448, %96 overflow : index + vector.store %474, %297[%475] : memref>, vector<1xf32> + %476 = vector.extract_strided_slice %262 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %477 = arith.addi %452, %96 overflow : index + vector.store %476, %297[%477] : memref>, vector<1xf32> + %478 = vector.extract_strided_slice %264 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %479 = arith.addi %440, %99 overflow : index + vector.store %478, %297[%479] : memref>, vector<1xf32> + %480 = vector.extract_strided_slice %264 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %481 = arith.addi %444, %99 overflow : index + vector.store %480, %297[%481] : memref>, vector<1xf32> + %482 = vector.extract_strided_slice %264 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %483 = arith.addi %448, %99 overflow : index + vector.store %482, %297[%483] : memref>, vector<1xf32> + %484 = vector.extract_strided_slice %264 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %485 = arith.addi %452, %99 overflow : index + vector.store %484, %297[%485] : memref>, vector<1xf32> + %486 = vector.extract_strided_slice %266 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %487 = arith.addi %440, %102 
overflow : index + vector.store %486, %297[%487] : memref>, vector<1xf32> + %488 = vector.extract_strided_slice %266 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %489 = arith.addi %444, %102 overflow : index + vector.store %488, %297[%489] : memref>, vector<1xf32> + %490 = vector.extract_strided_slice %266 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %491 = arith.addi %448, %102 overflow : index + vector.store %490, %297[%491] : memref>, vector<1xf32> + %492 = vector.extract_strided_slice %266 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %493 = arith.addi %452, %102 overflow : index + vector.store %492, %297[%493] : memref>, vector<1xf32> + %494 = vector.extract_strided_slice %268 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %495 = arith.addi %440, %105 overflow : index + vector.store %494, %297[%495] : memref>, vector<1xf32> + %496 = vector.extract_strided_slice %268 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %497 = arith.addi %444, %105 overflow : index + vector.store %496, %297[%497] : memref>, vector<1xf32> + %498 = vector.extract_strided_slice %268 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %499 = arith.addi %448, %105 overflow : index + vector.store %498, %297[%499] : memref>, vector<1xf32> + %500 = vector.extract_strided_slice %268 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %501 = arith.addi %452, %105 overflow : index + vector.store %500, %297[%501] : memref>, vector<1xf32> + %502 = vector.extract_strided_slice %270 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %503 = arith.addi %440, %108 overflow : index + vector.store %502, %297[%503] : memref>, vector<1xf32> + %504 = vector.extract_strided_slice %270 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + 
%505 = arith.addi %444, %108 overflow : index + vector.store %504, %297[%505] : memref>, vector<1xf32> + %506 = vector.extract_strided_slice %270 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %507 = arith.addi %448, %108 overflow : index + vector.store %506, %297[%507] : memref>, vector<1xf32> + %508 = vector.extract_strided_slice %270 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %509 = arith.addi %452, %108 overflow : index + vector.store %508, %297[%509] : memref>, vector<1xf32> + %510 = vector.extract_strided_slice %274 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %511 = affine.apply #map48()[%thread_id_x] + %512 = arith.muli %511, %c57344 overflow : index + %513 = arith.addi %512, %84 overflow : index + vector.store %510, %297[%513] : memref>, vector<1xf32> + %514 = vector.extract_strided_slice %274 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %515 = affine.apply #map49()[%thread_id_x] + %516 = arith.muli %515, %c57344 overflow : index + %517 = arith.addi %516, %84 overflow : index + vector.store %514, %297[%517] : memref>, vector<1xf32> + %518 = vector.extract_strided_slice %274 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %519 = affine.apply #map50()[%thread_id_x] + %520 = arith.muli %519, %c57344 overflow : index + %521 = arith.addi %520, %84 overflow : index + vector.store %518, %297[%521] : memref>, vector<1xf32> + %522 = vector.extract_strided_slice %274 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %523 = affine.apply #map51()[%thread_id_x] + %524 = arith.muli %523, %c57344 overflow : index + %525 = arith.addi %524, %84 overflow : index + vector.store %522, %297[%525] : memref>, vector<1xf32> + %526 = vector.extract_strided_slice %276 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %527 = arith.addi %512, %90 overflow : index 
+ vector.store %526, %297[%527] : memref>, vector<1xf32> + %528 = vector.extract_strided_slice %276 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %529 = arith.addi %516, %90 overflow : index + vector.store %528, %297[%529] : memref>, vector<1xf32> + %530 = vector.extract_strided_slice %276 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %531 = arith.addi %520, %90 overflow : index + vector.store %530, %297[%531] : memref>, vector<1xf32> + %532 = vector.extract_strided_slice %276 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %533 = arith.addi %524, %90 overflow : index + vector.store %532, %297[%533] : memref>, vector<1xf32> + %534 = vector.extract_strided_slice %278 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %535 = arith.addi %512, %93 overflow : index + vector.store %534, %297[%535] : memref>, vector<1xf32> + %536 = vector.extract_strided_slice %278 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %537 = arith.addi %516, %93 overflow : index + vector.store %536, %297[%537] : memref>, vector<1xf32> + %538 = vector.extract_strided_slice %278 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %539 = arith.addi %520, %93 overflow : index + vector.store %538, %297[%539] : memref>, vector<1xf32> + %540 = vector.extract_strided_slice %278 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %541 = arith.addi %524, %93 overflow : index + vector.store %540, %297[%541] : memref>, vector<1xf32> + %542 = vector.extract_strided_slice %280 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %543 = arith.addi %512, %96 overflow : index + vector.store %542, %297[%543] : memref>, vector<1xf32> + %544 = vector.extract_strided_slice %280 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %545 = arith.addi %516, %96 
overflow : index + vector.store %544, %297[%545] : memref>, vector<1xf32> + %546 = vector.extract_strided_slice %280 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %547 = arith.addi %520, %96 overflow : index + vector.store %546, %297[%547] : memref>, vector<1xf32> + %548 = vector.extract_strided_slice %280 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %549 = arith.addi %524, %96 overflow : index + vector.store %548, %297[%549] : memref>, vector<1xf32> + %550 = vector.extract_strided_slice %282 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %551 = arith.addi %512, %99 overflow : index + vector.store %550, %297[%551] : memref>, vector<1xf32> + %552 = vector.extract_strided_slice %282 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %553 = arith.addi %516, %99 overflow : index + vector.store %552, %297[%553] : memref>, vector<1xf32> + %554 = vector.extract_strided_slice %282 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %555 = arith.addi %520, %99 overflow : index + vector.store %554, %297[%555] : memref>, vector<1xf32> + %556 = vector.extract_strided_slice %282 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %557 = arith.addi %524, %99 overflow : index + vector.store %556, %297[%557] : memref>, vector<1xf32> + %558 = vector.extract_strided_slice %284 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %559 = arith.addi %512, %102 overflow : index + vector.store %558, %297[%559] : memref>, vector<1xf32> + %560 = vector.extract_strided_slice %284 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %561 = arith.addi %516, %102 overflow : index + vector.store %560, %297[%561] : memref>, vector<1xf32> + %562 = vector.extract_strided_slice %284 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %563 = 
arith.addi %520, %102 overflow : index + vector.store %562, %297[%563] : memref>, vector<1xf32> + %564 = vector.extract_strided_slice %284 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %565 = arith.addi %524, %102 overflow : index + vector.store %564, %297[%565] : memref>, vector<1xf32> + %566 = vector.extract_strided_slice %286 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %567 = arith.addi %512, %105 overflow : index + vector.store %566, %297[%567] : memref>, vector<1xf32> + %568 = vector.extract_strided_slice %286 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %569 = arith.addi %516, %105 overflow : index + vector.store %568, %297[%569] : memref>, vector<1xf32> + %570 = vector.extract_strided_slice %286 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %571 = arith.addi %520, %105 overflow : index + vector.store %570, %297[%571] : memref>, vector<1xf32> + %572 = vector.extract_strided_slice %286 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %573 = arith.addi %524, %105 overflow : index + vector.store %572, %297[%573] : memref>, vector<1xf32> + %574 = vector.extract_strided_slice %288 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %575 = arith.addi %512, %108 overflow : index + vector.store %574, %297[%575] : memref>, vector<1xf32> + %576 = vector.extract_strided_slice %288 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %577 = arith.addi %516, %108 overflow : index + vector.store %576, %297[%577] : memref>, vector<1xf32> + %578 = vector.extract_strided_slice %288 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %579 = arith.addi %520, %108 overflow : index + vector.store %578, %297[%579] : memref>, vector<1xf32> + %580 = vector.extract_strided_slice %288 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> 
to vector<1xf32> + %581 = arith.addi %524, %108 overflow : index + vector.store %580, %297[%581] : memref>, vector<1xf32> + return + } + } + } + func.func @isolated_benchmark$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view, %arg4: !hal.buffer_view, %arg5: !hal.fence, %arg6: !hal.fence) -> !hal.buffer_view { + %0 = hal.tensor.import wait(%arg5) => %arg0 : !hal.buffer_view -> tensor<4096x8192xi8> + %1 = hal.tensor.import wait(%arg5) => %arg1 : !hal.buffer_view -> tensor<4096x512xi8> + %2 = hal.tensor.import wait(%arg5) => %arg2 : !hal.buffer_view -> tensor<57344x8192xi8> + %3 = hal.tensor.import wait(%arg5) => %arg3 : !hal.buffer_view -> tensor<57344x512xi8> + %4 = hal.tensor.import wait(%arg5) => %arg4 : !hal.buffer_view -> tensor<4096x57344xf32> + %5 = flow.dispatch @gemm::@gemm(%0, %1, %2, %3, %4) : (tensor<4096x8192xi8>, tensor<4096x512xi8>, tensor<57344x8192xi8>, tensor<57344x512xi8>, tensor<4096x57344xf32>) -> %4 + %6 = hal.tensor.barrier join(%5 : tensor<4096x57344xf32>) => %arg6 : !hal.fence + %7 = hal.tensor.export %6 : tensor<4096x57344xf32> -> !hal.buffer_view + return %7 : !hal.buffer_view + } + } + """ + + mlir_claude_rescheduled2 = """ + + + #map = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8) floordiv 256) * 256)> + #map1 = affine_map<()[s0] -> ((s0 floordiv 8) mod 8)> + #map2 = affine_map<()[s0] -> (s0 mod 8)> + #map3 = affine_map<()[s0] -> (s0 * 16)> + #map4 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64) floordiv 32) * 256)> + #map5 = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 64) floordiv 256) * 256 + 64)> + #map6 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64 + 8) floordiv 32) * 256 + 64)> + #map7 = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 128) floordiv 256) * 
256 + 128)> + #map8 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64 + 16) floordiv 32) * 256 + 128)> + #map9 = affine_map<()[s0, s1, s2] -> (s1 * 32 + s2 * 256 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 192) floordiv 256) * 256 + 192)> + #map10 = affine_map<()[s0, s1] -> (s1 * 32 + (s0 floordiv 64) * 8 - ((s1 * 4 + s0 floordiv 64 + 24) floordiv 32) * 256 + 192)> + #map11 = affine_map<()[s0, s1, s2] -> (s1 * 128 + s2 * 256 + s0 floordiv 2 - ((s1 * 128 + s0 floordiv 2) floordiv 256) * 256)> + #map12 = affine_map<()[s0] -> ((s0 floordiv 2) mod 2)> + #map13 = affine_map<()[s0] -> (s0 mod 2)> + #map14 = affine_map<()[s0] -> (s0 * 4)> + #map15 = affine_map<()[s0, s1] -> (s1 * 128 + (s0 floordiv 64) * 32 - ((s1 * 4 + s0 floordiv 64) floordiv 8) * 256)> + #map16 = affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 64)> + #map17 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64)> + #map18 = affine_map<()[s0] -> ((s0 mod 64) floordiv 16)> + #map19 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64 + 16)> + #map20 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64 + 32)> + #map21 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 64 + 48)> + #map22 = affine_map<()[s0] -> (s0 * 4 + (s0 mod 64) floordiv 16 - (s0 floordiv 2) * 8)> + #map23 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16)> + #map24 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 16)> + #map25 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 32)> + #map26 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 48)> + #map27 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 64)> + #map28 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 80)> + #map29 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 96)> + #map30 = affine_map<()[s0, s1] -> (s0 + s1 * 128 - (s0 floordiv 16) * 16 + 112)> + #map31 = affine_map<()[s0] -> ((s0 mod 64) 
floordiv 16 + 4)> + #map32 = affine_map<()[s0, s1] -> (s1 * 4 + (s0 mod 64) floordiv 16)> + #map33 = affine_map<()[s0, s1] -> (s0 * 128 + s1 * 16 + 128)> + #map34 = affine_map<()[s0, s1] -> (s0 * 8 + s1 * 4 + 8)> + #map35 = affine_map<()[s0] -> (s0 * 256)> + #map36 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4)> + #map37 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 1)> + #map38 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 2)> + #map39 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 3)> + #map40 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 16)> + #map41 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 17)> + #map42 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 18)> + #map43 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 19)> + #map44 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 32)> + #map45 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 33)> + #map46 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 34)> + #map47 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 35)> + #map48 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 48)> + #map49 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 49)> + #map50 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 50)> + #map51 = affine_map<()[s0] -> ((s0 floordiv 64) * 64 + ((s0 mod 64) floordiv 16) * 4 + 51)> + #translation = #iree_codegen.translation_info + module attributes {transform.with_named_sequence} { + stream.executable private @gemm { + stream.executable.export public @gemm workgroups() -> (index, index, index) { + %c16 = arith.constant 16 : 
index + %c224 = arith.constant 224 : index + %c1 = arith.constant 1 : index + stream.return %c16, %c224, %c1 : index, index, index + } + builtin.module { + func.func @gemm(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: !stream.binding, %arg3: !stream.binding, %arg4: !stream.binding) attributes {translation_info = #translation} { + %c4_i32 = arith.constant 4 : i32 + %c512_i14 = arith.constant 512 : i14 + %c-8192_i14 = arith.constant -8192 : i14 + %c2147483643_i64 = arith.constant 2147483643 : i64 + %c57344 = arith.constant 57344 : index + %c63 = arith.constant 63 : index + %c512 = arith.constant 512 : index + %c2147483646_i64 = arith.constant 2147483646 : i64 + %c8192 = arith.constant 8192 : index + %c1 = arith.constant 1 : index + %cst = arith.constant dense<0.000000e+00> : vector<4xf32> + %c0 = arith.constant 0 : index + %0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> memref + %1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> memref + %2 = stream.binding.subspan %arg2[%c0] : !stream.binding -> memref + %3 = stream.binding.subspan %arg3[%c0] : !stream.binding -> memref + %4 = stream.binding.subspan %arg4[%c0] : !stream.binding -> memref + %block_id_x = gpu.block_id x upper_bound 16 + %block_id_y = gpu.block_id y upper_bound 224 + %thread_id_x = gpu.thread_id x upper_bound 256 + %thread_id_y = gpu.thread_id y upper_bound 2 + %alloc = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_0 = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_1 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %alloc_2 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %alloc_3 = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_4 = memref.alloc() : memref<256x8xi8, #gpu.address_space> + %alloc_5 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %alloc_6 = memref.alloc() : memref<256x128xi8, #gpu.address_space> + %5 = affine.apply #map()[%thread_id_x, %thread_id_y, %block_id_x] + %6 = 
affine.apply #map1()[%thread_id_x] + %7 = affine.apply #map2()[%thread_id_x] + %8 = arith.xori %7, %6 : index + %9 = affine.apply #map3()[%8] + %10 = affine.apply #map4()[%thread_id_x, %thread_id_y] + %11 = gpu.subgroup_broadcast %10, first_active_lane : index + %12 = gpu.subgroup_broadcast %c0, first_active_lane : index + %13 = arith.muli %5, %c8192 overflow : index + %14 = arith.addi %13, %9 overflow : index + %reinterpret_cast = memref.reinterpret_cast %0 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast = memref.cast %reinterpret_cast : memref<2147483646xi8, strided<[1]>> to memref> + %15 = amdgpu.fat_raw_buffer_cast %cast validBytes(%c2147483646_i64) cacheSwizzleStride(%c-8192_i14) resetOffset : memref> to memref> + amdgpu.gather_to_lds %15[%14], %alloc_6[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %16 = affine.apply #map5()[%thread_id_x, %thread_id_y, %block_id_x] + %17 = affine.apply #map6()[%thread_id_x, %thread_id_y] + %18 = gpu.subgroup_broadcast %17, first_active_lane : index + %19 = arith.muli %16, %c8192 overflow : index + %20 = arith.addi %19, %9 overflow : index + amdgpu.gather_to_lds %15[%20], %alloc_6[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %21 = affine.apply #map7()[%thread_id_x, %thread_id_y, %block_id_x] + %22 = affine.apply #map8()[%thread_id_x, %thread_id_y] + %23 = gpu.subgroup_broadcast %22, first_active_lane : index + %24 = arith.muli %21, %c8192 overflow : index + %25 = arith.addi %24, %9 overflow : index + amdgpu.gather_to_lds %15[%25], %alloc_6[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %26 = affine.apply #map9()[%thread_id_x, %thread_id_y, %block_id_x] + %27 = affine.apply #map10()[%thread_id_x, %thread_id_y] + %28 = gpu.subgroup_broadcast %27, first_active_lane : index + %29 = arith.muli %26, %c8192 overflow : index + %30 = arith.addi %29, %9 overflow : index + 
amdgpu.gather_to_lds %15[%30], %alloc_6[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %31 = affine.apply #map11()[%thread_id_x, %thread_id_y, %block_id_x] + %32 = affine.apply #map12()[%thread_id_x] + %33 = affine.apply #map13()[%thread_id_x] + %34 = arith.xori %33, %32 : index + %35 = affine.apply #map14()[%34] + %36 = affine.apply #map15()[%thread_id_x, %thread_id_y] + %37 = gpu.subgroup_broadcast %36, first_active_lane : index + %38 = arith.muli %31, %c512 overflow : index + %39 = arith.addi %38, %35 overflow : index + %reinterpret_cast_7 = memref.reinterpret_cast %1 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast_8 = memref.cast %reinterpret_cast_7 : memref<2147483646xi8, strided<[1]>> to memref> + %40 = amdgpu.fat_raw_buffer_cast %cast_8 validBytes(%c2147483646_i64) cacheSwizzleStride(%c512_i14) resetOffset : memref> to memref> + amdgpu.gather_to_lds %40[%39], %alloc_4[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + %41 = affine.apply #map()[%thread_id_x, %thread_id_y, %block_id_y] + %42 = arith.muli %41, %c8192 overflow : index + %43 = arith.addi %42, %9 overflow : index + %reinterpret_cast_9 = memref.reinterpret_cast %2 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast_10 = memref.cast %reinterpret_cast_9 : memref<2147483646xi8, strided<[1]>> to memref> + %44 = amdgpu.fat_raw_buffer_cast %cast_10 validBytes(%c2147483646_i64) cacheSwizzleStride(%c-8192_i14) resetOffset : memref> to memref> + amdgpu.gather_to_lds %44[%43], %alloc_2[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %45 = affine.apply #map5()[%thread_id_x, %thread_id_y, %block_id_y] + %46 = arith.muli %45, %c8192 overflow : index + %47 = arith.addi %46, %9 overflow : index + amdgpu.gather_to_lds %44[%47], %alloc_2[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %48 = 
affine.apply #map7()[%thread_id_x, %thread_id_y, %block_id_y] + %49 = arith.muli %48, %c8192 overflow : index + %50 = arith.addi %49, %9 overflow : index + amdgpu.gather_to_lds %44[%50], %alloc_2[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %51 = affine.apply #map9()[%thread_id_x, %thread_id_y, %block_id_y] + %52 = arith.muli %51, %c8192 overflow : index + %53 = arith.addi %52, %9 overflow : index + amdgpu.gather_to_lds %44[%53], %alloc_2[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %54 = affine.apply #map11()[%thread_id_x, %thread_id_y, %block_id_y] + %55 = arith.muli %54, %c512 overflow : index + %56 = arith.addi %55, %35 overflow : index + %reinterpret_cast_11 = memref.reinterpret_cast %3 to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1]>> + %cast_12 = memref.cast %reinterpret_cast_11 : memref<2147483646xi8, strided<[1]>> to memref> + %57 = amdgpu.fat_raw_buffer_cast %cast_12 validBytes(%c2147483646_i64) cacheSwizzleStride(%c512_i14) resetOffset : memref> to memref> + amdgpu.gather_to_lds %57[%56], %alloc_0[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + rocdl.s.barrier + %58 = affine.apply #map16()[%thread_id_x, %thread_id_y] + %59 = arith.index_cast %58 : index to i32 + %60 = arith.cmpi sge, %59, %c4_i32 : i32 + %61 = arith.cmpi slt, %59, %c4_i32 : i32 + scf.if %60 { + rocdl.s.barrier + } + %62 = affine.apply #map17()[%thread_id_x] + %63 = affine.apply #map18()[%thread_id_x] + %64 = arith.xori %63, %7 : index + %65 = affine.apply #map3()[%64] + %66 = affine.apply #map19()[%thread_id_x] + %67 = affine.apply #map20()[%thread_id_x] + %68 = affine.apply #map21()[%thread_id_x] + %69 = affine.apply #map22()[%thread_id_x] + %70 = affine.apply #map23()[%thread_id_x, %thread_id_y] + %71 = affine.apply #map24()[%thread_id_x, %thread_id_y] + %72 = affine.apply #map25()[%thread_id_x, %thread_id_y] + %73 = affine.apply 
#map26()[%thread_id_x, %thread_id_y] + %74 = affine.apply #map27()[%thread_id_x, %thread_id_y] + %75 = affine.apply #map28()[%thread_id_x, %thread_id_y] + %76 = affine.apply #map29()[%thread_id_x, %thread_id_y] + %77 = affine.apply #map30()[%thread_id_x, %thread_id_y] + %78 = affine.apply #map31()[%thread_id_x] + %79 = arith.xori %78, %7 : index + %80 = affine.apply #map3()[%79] + %81 = arith.xori %33, %c1 : index + %82 = affine.apply #map32()[%thread_id_x, %81] + %83:40 = scf.for %arg5 = %c0 to %c63 step %c1 iter_args(%arg6 = %cst, %arg7 = %cst, %arg8 = %cst, %arg9 = %cst, %arg10 = %cst, %arg11 = %cst, %arg12 = %cst, %arg13 = %cst, %arg14 = %cst, %arg15 = %cst, %arg16 = %cst, %arg17 = %cst, %arg18 = %cst, %arg19 = %cst, %arg20 = %cst, %arg21 = %cst, %arg22 = %cst, %arg23 = %cst, %arg24 = %cst, %arg25 = %cst, %arg26 = %cst, %arg27 = %cst, %arg28 = %cst, %arg29 = %cst, %arg30 = %cst, %arg31 = %cst, %arg32 = %cst, %arg33 = %cst, %arg34 = %cst, %arg35 = %cst, %arg36 = %cst, %arg37 = %cst, %arg38 = %alloc_6, %arg39 = %alloc_5, %arg40 = %alloc_4, %arg41 = %alloc_3, %arg42 = %alloc_2, %arg43 = %alloc_1, %arg44 = %alloc_0, %arg45 = %alloc) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>) { + rocdl.sched.barrier 0 + 
amdgpu.memory_counter_wait load(0) + rocdl.s.barrier + //rocdl.s.barrier + %582 = affine.apply #map33()[%arg5, %8] + %583 = arith.addi %13, %582 overflow : index + amdgpu.gather_to_lds %15[%583], %arg39[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %584 = arith.addi %19, %582 overflow : index + amdgpu.gather_to_lds %15[%584], %arg39[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %585 = arith.addi %24, %582 overflow : index + amdgpu.gather_to_lds %15[%585], %arg39[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %586 = arith.addi %29, %582 overflow : index + amdgpu.gather_to_lds %15[%586], %arg39[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %587 = affine.apply #map34()[%arg5, %34] + %588 = arith.addi %38, %587 overflow : index + amdgpu.gather_to_lds %40[%588], %arg41[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + %589 = arith.addi %42, %582 overflow : index + amdgpu.gather_to_lds %44[%589], %arg43[%11, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %590 = arith.addi %46, %582 overflow : index + amdgpu.gather_to_lds %44[%590], %arg43[%18, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %591 = arith.addi %49, %582 overflow : index + amdgpu.gather_to_lds %44[%591], %arg43[%23, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %592 = arith.addi %52, %582 overflow : index + amdgpu.gather_to_lds %44[%592], %arg43[%28, %12] : vector<16xi8>, memref>, memref<256x128xi8, #gpu.address_space> + %593 = arith.addi %55, %587 overflow : index + amdgpu.gather_to_lds %57[%593], %arg45[%37, %12] : vector<4xi8>, memref>, memref<256x8xi8, #gpu.address_space> + rocdl.sched.barrier 0 + //amdgpu.memory_counter_wait load(10) + // --- SAFE vector.loads: A(M0,M1), Ascale(M0,M1), B(N0,N1,N4,N5), Bscale(N0,N1,N4,N5) --- + %594 = vector.load %arg38[%62, %65] : memref<256x128xi8, 
#gpu.address_space>, vector<16xi8> + %595 = vector.load %arg38[%66, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %598 = vector.load %arg40[%62, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %599 = vector.load %arg40[%66, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %602 = vector.load %arg42[%70, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %603 = vector.load %arg42[%71, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %606 = vector.load %arg42[%74, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %607 = vector.load %arg42[%75, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %610 = vector.load %arg44[%70, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %611 = vector.load %arg44[%71, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %614 = vector.load %arg44[%74, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %615 = vector.load %arg44[%75, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + // --- SAFE bitcasts --- + %618 = vector.bitcast %594 : vector<16xi8> to vector<32xf4E2M1FN> + %619 = vector.bitcast %595 : vector<16xi8> to vector<32xf4E2M1FN> + %622 = vector.bitcast %598 : vector<1xi8> to vector<1xf8E8M0FNU> + %623 = vector.bitcast %599 : vector<1xi8> to vector<1xf8E8M0FNU> + %626 = vector.bitcast %602 : vector<16xi8> to vector<32xf4E2M1FN> + %627 = vector.bitcast %603 : vector<16xi8> to vector<32xf4E2M1FN> + %630 = vector.bitcast %606 : vector<16xi8> to vector<32xf4E2M1FN> + %631 = vector.bitcast %607 : vector<16xi8> to vector<32xf4E2M1FN> + %634 = vector.bitcast %610 : vector<1xi8> to vector<1xf8E8M0FNU> + %635 = vector.bitcast %611 : vector<1xi8> to vector<1xf8E8M0FNU> + %638 = vector.bitcast %614 : vector<1xi8> to vector<1xf8E8M0FNU> + %639 = vector.bitcast %615 : vector<1xi8> to vector<1xf8E8M0FNU> + rocdl.sched.barrier 0 + rocdl.s.barrier + rocdl.sched.barrier 0 + rocdl.s.setprio 1 + // --- SAFE MFMAs: 
M0,M1 x N0,N1,N4,N5 (cluster 0 data only) --- + %642 = vector.extract %622[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %643 = vector.extract %634[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %644 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%643[0] * %626) + %arg6 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %645 = vector.extract %635[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %646 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%645[0] * %627) + %arg7 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %651 = vector.extract %638[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %652 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%651[0] * %630) + %arg10 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %653 = vector.extract %639[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %654 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%653[0] * %631) + %arg11 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %659 = vector.extract %623[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %660 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%643[0] * %626) + %arg14 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %661 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%645[0] * %627) + %arg15 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %664 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%651[0] * %630) + %arg18 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %665 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%653[0] * %631) + %arg19 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + rocdl.s.setprio 0 + // --- DEPENDENT vector.loads: A(M2,M3), Ascale(M2,M3), B(N2,N3,N6,N7), Bscale(N2,N3,N6,N7) --- + %596 = vector.load %arg38[%67, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %597 = vector.load 
%arg38[%68, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %600 = vector.load %arg40[%67, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %601 = vector.load %arg40[%68, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %604 = vector.load %arg42[%72, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %605 = vector.load %arg42[%73, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %608 = vector.load %arg42[%76, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %609 = vector.load %arg42[%77, %65] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %612 = vector.load %arg44[%72, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %613 = vector.load %arg44[%73, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %616 = vector.load %arg44[%76, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %617 = vector.load %arg44[%77, %69] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + // --- DEPENDENT bitcasts --- + %620 = vector.bitcast %596 : vector<16xi8> to vector<32xf4E2M1FN> + %621 = vector.bitcast %597 : vector<16xi8> to vector<32xf4E2M1FN> + %624 = vector.bitcast %600 : vector<1xi8> to vector<1xf8E8M0FNU> + %625 = vector.bitcast %601 : vector<1xi8> to vector<1xf8E8M0FNU> + %628 = vector.bitcast %604 : vector<16xi8> to vector<32xf4E2M1FN> + %629 = vector.bitcast %605 : vector<16xi8> to vector<32xf4E2M1FN> + %632 = vector.bitcast %608 : vector<16xi8> to vector<32xf4E2M1FN> + %633 = vector.bitcast %609 : vector<16xi8> to vector<32xf4E2M1FN> + %636 = vector.bitcast %612 : vector<1xi8> to vector<1xf8E8M0FNU> + %637 = vector.bitcast %613 : vector<1xi8> to vector<1xf8E8M0FNU> + %640 = vector.bitcast %616 : vector<1xi8> to vector<1xf8E8M0FNU> + %641 = vector.bitcast %617 : vector<1xi8> to vector<1xf8E8M0FNU> + rocdl.s.setprio 1 + // --- DEPENDENT MFMAs: M0,M1 x N2,N3,N6,N7 (cluster 1 B data) --- + %647 = vector.extract %636[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + 
%648 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%647[0] * %628) + %arg8 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %649 = vector.extract %637[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %650 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%649[0] * %629) + %arg9 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %655 = vector.extract %640[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %656 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%655[0] * %632) + %arg12 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %657 = vector.extract %641[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %658 = amdgpu.scaled_mfma 16x16x128 (%642[0] * %618) * (%657[0] * %633) + %arg13 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %662 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%647[0] * %628) + %arg16 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %663 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%649[0] * %629) + %arg17 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %666 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%655[0] * %632) + %arg20 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %667 = amdgpu.scaled_mfma 16x16x128 (%659[0] * %619) * (%657[0] * %633) + %arg21 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + // --- DEPENDENT MFMAs: M2 x all N (cluster 1 A data) --- + %668 = vector.extract %624[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %669 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%643[0] * %626) + %arg22 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %670 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%645[0] * %627) + %arg23 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %671 = amdgpu.scaled_mfma 16x16x128 
(%668[0] * %620) * (%647[0] * %628) + %arg24 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %672 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%649[0] * %629) + %arg25 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %673 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%651[0] * %630) + %arg26 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %674 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%653[0] * %631) + %arg27 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %675 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%655[0] * %632) + %arg28 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %676 = amdgpu.scaled_mfma 16x16x128 (%668[0] * %620) * (%657[0] * %633) + %arg29 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + // --- DEPENDENT MFMAs: M3 x all N (cluster 1 A data) --- + %677 = vector.extract %625[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %678 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%643[0] * %626) + %arg30 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %679 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%645[0] * %627) + %arg31 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %680 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%647[0] * %628) + %arg32 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %681 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%649[0] * %629) + %arg33 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %682 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%651[0] * %630) + %arg34 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %683 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%653[0] * %631) + %arg35 : f8E8M0FNU, 
vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %684 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%655[0] * %632) + %arg36 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %685 = amdgpu.scaled_mfma 16x16x128 (%677[0] * %621) * (%657[0] * %633) + %arg37 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + rocdl.s.setprio 0 + rocdl.sched.barrier 0 + rocdl.s.barrier + rocdl.sched.barrier 0 + rocdl.sched.barrier 0 + // --- PHASE 2 SAFE vector.loads: A(M0,M1), Ascale(M0,M1), B(N0,N1,N4,N5), Bscale(N0,N1,N4,N5) --- + %686 = vector.load %arg38[%62, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %687 = vector.load %arg38[%66, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %690 = vector.load %arg40[%62, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %691 = vector.load %arg40[%66, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %694 = vector.load %arg42[%70, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %695 = vector.load %arg42[%71, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %698 = vector.load %arg42[%74, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %699 = vector.load %arg42[%75, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %702 = vector.load %arg44[%70, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %703 = vector.load %arg44[%71, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %706 = vector.load %arg44[%74, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %707 = vector.load %arg44[%75, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + // --- PHASE 2 SAFE bitcasts --- + %710 = vector.bitcast %686 : vector<16xi8> to vector<32xf4E2M1FN> + %711 = vector.bitcast %687 : vector<16xi8> to vector<32xf4E2M1FN> + %714 = vector.bitcast %690 : vector<1xi8> to vector<1xf8E8M0FNU> + %715 = vector.bitcast %691 : vector<1xi8> to 
vector<1xf8E8M0FNU> + %718 = vector.bitcast %694 : vector<16xi8> to vector<32xf4E2M1FN> + %719 = vector.bitcast %695 : vector<16xi8> to vector<32xf4E2M1FN> + %722 = vector.bitcast %698 : vector<16xi8> to vector<32xf4E2M1FN> + %723 = vector.bitcast %699 : vector<16xi8> to vector<32xf4E2M1FN> + %726 = vector.bitcast %702 : vector<1xi8> to vector<1xf8E8M0FNU> + %727 = vector.bitcast %703 : vector<1xi8> to vector<1xf8E8M0FNU> + %730 = vector.bitcast %706 : vector<1xi8> to vector<1xf8E8M0FNU> + %731 = vector.bitcast %707 : vector<1xi8> to vector<1xf8E8M0FNU> + rocdl.sched.barrier 0 + rocdl.s.barrier + rocdl.sched.barrier 0 + rocdl.s.setprio 1 + // --- PHASE 2 SAFE MFMAs: M0,M1 x N0,N1,N4,N5 (cluster 0 data only) --- + %734 = vector.extract %714[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %735 = vector.extract %726[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %736 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%735[0] * %718) + %644 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %737 = vector.extract %727[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %738 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%737[0] * %719) + %646 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %743 = vector.extract %730[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %744 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%743[0] * %722) + %652 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %745 = vector.extract %731[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %746 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%745[0] * %723) + %654 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %751 = vector.extract %715[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %752 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%735[0] * %718) + %660 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %753 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * 
(%737[0] * %719) + %661 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %756 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%743[0] * %722) + %664 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %757 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%745[0] * %723) + %665 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + rocdl.s.setprio 0 + // --- PHASE 2 DEPENDENT vector.loads: A(M2,M3), Ascale(M2,M3), B(N2,N3,N6,N7), Bscale(N2,N3,N6,N7) --- + %688 = vector.load %arg38[%67, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %689 = vector.load %arg38[%68, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %692 = vector.load %arg40[%67, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %693 = vector.load %arg40[%68, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %696 = vector.load %arg42[%72, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %697 = vector.load %arg42[%73, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %700 = vector.load %arg42[%76, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %701 = vector.load %arg42[%77, %80] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %704 = vector.load %arg44[%72, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %705 = vector.load %arg44[%73, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %708 = vector.load %arg44[%76, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %709 = vector.load %arg44[%77, %82] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + // --- PHASE 2 DEPENDENT bitcasts --- + %712 = vector.bitcast %688 : vector<16xi8> to vector<32xf4E2M1FN> + %713 = vector.bitcast %689 : vector<16xi8> to vector<32xf4E2M1FN> + %716 = vector.bitcast %692 : vector<1xi8> to vector<1xf8E8M0FNU> + %717 = vector.bitcast %693 : vector<1xi8> to vector<1xf8E8M0FNU> + %720 = vector.bitcast 
%696 : vector<16xi8> to vector<32xf4E2M1FN> + %721 = vector.bitcast %697 : vector<16xi8> to vector<32xf4E2M1FN> + %724 = vector.bitcast %700 : vector<16xi8> to vector<32xf4E2M1FN> + %725 = vector.bitcast %701 : vector<16xi8> to vector<32xf4E2M1FN> + %728 = vector.bitcast %704 : vector<1xi8> to vector<1xf8E8M0FNU> + %729 = vector.bitcast %705 : vector<1xi8> to vector<1xf8E8M0FNU> + %732 = vector.bitcast %708 : vector<1xi8> to vector<1xf8E8M0FNU> + %733 = vector.bitcast %709 : vector<1xi8> to vector<1xf8E8M0FNU> + rocdl.s.setprio 1 + // --- PHASE 2 DEPENDENT MFMAs: M0,M1 x N2,N3,N6,N7 (cluster 1 B data) --- + %739 = vector.extract %728[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %740 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%739[0] * %720) + %648 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %741 = vector.extract %729[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %742 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%741[0] * %721) + %650 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %747 = vector.extract %732[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %748 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%747[0] * %724) + %656 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %749 = vector.extract %733[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %750 = amdgpu.scaled_mfma 16x16x128 (%734[0] * %710) * (%749[0] * %725) + %658 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %754 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%739[0] * %720) + %662 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %755 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%741[0] * %721) + %663 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %758 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%747[0] * %724) + %666 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, 
vector<32xf4E2M1FN>, vector<4xf32> + %759 = amdgpu.scaled_mfma 16x16x128 (%751[0] * %711) * (%749[0] * %725) + %667 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + // --- PHASE 2 DEPENDENT MFMAs: (cluster 1 A data) --- + %760 = vector.extract %716[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %761 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%735[0] * %718) + %669 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %762 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%737[0] * %719) + %670 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %763 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%739[0] * %720) + %671 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %764 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%741[0] * %721) + %672 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %765 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%743[0] * %722) + %673 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %766 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%745[0] * %723) + %674 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %767 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%747[0] * %724) + %675 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %768 = amdgpu.scaled_mfma 16x16x128 (%760[0] * %712) * (%749[0] * %725) + %676 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + // --- PHASE 2 DEPENDENT MFMAs: (cluster 1 A data) --- + %769 = vector.extract %717[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %770 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%735[0] * %718) + %678 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %771 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%737[0] * %719) + %679 : f8E8M0FNU, 
vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %772 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%739[0] * %720) + %680 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %773 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%741[0] * %721) + %681 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %774 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%743[0] * %722) + %682 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %775 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%745[0] * %723) + %683 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %776 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%747[0] * %724) + %684 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %777 = amdgpu.scaled_mfma 16x16x128 (%769[0] * %713) * (%749[0] * %725) + %685 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + rocdl.s.setprio 0 + rocdl.sched.barrier 0 + scf.yield %736, %738, %740, %742, %744, %746, %748, %750, %752, %753, %754, %755, %756, %757, %758, %759, %761, %762, %763, %764, %765, %766, %767, %768, %770, %771, %772, %773, %774, %775, %776, %777, %arg39, %arg38, %arg41, %arg40, %arg43, %arg42, %arg45, %arg44 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, 
memref<256x128xi8, #gpu.address_space>, memref<256x128xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space>, memref<256x8xi8, #gpu.address_space> + } + scf.if %61 { + rocdl.s.barrier + } + amdgpu.lds_barrier + %84 = affine.apply #map23()[%thread_id_x, %thread_id_y] + %85 = affine.apply #map22()[%thread_id_x] + %86 = vector.load %83#38[%84, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %87 = arith.xori %33, %c1 : index + %88 = affine.apply #map32()[%thread_id_x, %87] + %89 = vector.load %83#38[%84, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %90 = affine.apply #map24()[%thread_id_x, %thread_id_y] + %91 = vector.load %83#38[%90, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %92 = vector.load %83#38[%90, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %93 = affine.apply #map25()[%thread_id_x, %thread_id_y] + %94 = vector.load %83#38[%93, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %95 = vector.load %83#38[%93, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %96 = affine.apply #map26()[%thread_id_x, %thread_id_y] + %97 = vector.load %83#38[%96, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %98 = vector.load %83#38[%96, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %99 = affine.apply #map27()[%thread_id_x, %thread_id_y] + %100 = vector.load %83#38[%99, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %101 = vector.load %83#38[%99, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %102 = affine.apply #map28()[%thread_id_x, %thread_id_y] + %103 = vector.load %83#38[%102, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %104 = vector.load %83#38[%102, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %105 = affine.apply #map29()[%thread_id_x, %thread_id_y] + %106 = vector.load %83#38[%105, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %107 = vector.load %83#38[%105, %88] : memref<256x8xi8, 
#gpu.address_space>, vector<1xi8> + %108 = affine.apply #map30()[%thread_id_x, %thread_id_y] + %109 = vector.load %83#38[%108, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %110 = vector.load %83#38[%108, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %111 = affine.apply #map18()[%thread_id_x] + %112 = arith.xori %111, %7 : index + %113 = affine.apply #map3()[%112] + %114 = vector.load %83#36[%84, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %115 = affine.apply #map31()[%thread_id_x] + %116 = arith.xori %115, %7 : index + %117 = affine.apply #map3()[%116] + %118 = vector.load %83#36[%84, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %119 = vector.load %83#36[%90, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %120 = vector.load %83#36[%90, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %121 = vector.load %83#36[%93, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %122 = vector.load %83#36[%93, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %123 = vector.load %83#36[%96, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %124 = vector.load %83#36[%96, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %125 = vector.load %83#36[%99, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %126 = vector.load %83#36[%99, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %127 = vector.load %83#36[%102, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %128 = vector.load %83#36[%102, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %129 = vector.load %83#36[%105, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %130 = vector.load %83#36[%105, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %131 = vector.load %83#36[%108, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %132 = vector.load %83#36[%108, %117] : 
memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %133 = affine.apply #map17()[%thread_id_x] + %134 = vector.load %83#34[%133, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %135 = vector.load %83#34[%133, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %136 = affine.apply #map19()[%thread_id_x] + %137 = vector.load %83#34[%136, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %138 = vector.load %83#34[%136, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %139 = affine.apply #map20()[%thread_id_x] + %140 = vector.load %83#34[%139, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %141 = vector.load %83#34[%139, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %142 = affine.apply #map21()[%thread_id_x] + %143 = vector.load %83#34[%142, %85] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %144 = vector.load %83#34[%142, %88] : memref<256x8xi8, #gpu.address_space>, vector<1xi8> + %145 = vector.load %83#32[%133, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %146 = vector.load %83#32[%133, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %147 = vector.load %83#32[%136, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %148 = vector.load %83#32[%136, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %149 = vector.load %83#32[%139, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %150 = vector.load %83#32[%139, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %151 = vector.load %83#32[%142, %113] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %152 = vector.load %83#32[%142, %117] : memref<256x128xi8, #gpu.address_space>, vector<16xi8> + %153 = vector.bitcast %145 : vector<16xi8> to vector<32xf4E2M1FN> + %154 = vector.bitcast %146 : vector<16xi8> to vector<32xf4E2M1FN> + %155 = vector.bitcast %147 : vector<16xi8> to vector<32xf4E2M1FN> + %156 = vector.bitcast %148 : vector<16xi8> to 
vector<32xf4E2M1FN> + %157 = vector.bitcast %149 : vector<16xi8> to vector<32xf4E2M1FN> + %158 = vector.bitcast %150 : vector<16xi8> to vector<32xf4E2M1FN> + %159 = vector.bitcast %151 : vector<16xi8> to vector<32xf4E2M1FN> + %160 = vector.bitcast %152 : vector<16xi8> to vector<32xf4E2M1FN> + %161 = vector.bitcast %134 : vector<1xi8> to vector<1xf8E8M0FNU> + %162 = vector.bitcast %135 : vector<1xi8> to vector<1xf8E8M0FNU> + %163 = vector.bitcast %137 : vector<1xi8> to vector<1xf8E8M0FNU> + %164 = vector.bitcast %138 : vector<1xi8> to vector<1xf8E8M0FNU> + %165 = vector.bitcast %140 : vector<1xi8> to vector<1xf8E8M0FNU> + %166 = vector.bitcast %141 : vector<1xi8> to vector<1xf8E8M0FNU> + %167 = vector.bitcast %143 : vector<1xi8> to vector<1xf8E8M0FNU> + %168 = vector.bitcast %144 : vector<1xi8> to vector<1xf8E8M0FNU> + %169 = vector.bitcast %114 : vector<16xi8> to vector<32xf4E2M1FN> + %170 = vector.bitcast %118 : vector<16xi8> to vector<32xf4E2M1FN> + %171 = vector.bitcast %119 : vector<16xi8> to vector<32xf4E2M1FN> + %172 = vector.bitcast %120 : vector<16xi8> to vector<32xf4E2M1FN> + %173 = vector.bitcast %121 : vector<16xi8> to vector<32xf4E2M1FN> + %174 = vector.bitcast %122 : vector<16xi8> to vector<32xf4E2M1FN> + %175 = vector.bitcast %123 : vector<16xi8> to vector<32xf4E2M1FN> + %176 = vector.bitcast %124 : vector<16xi8> to vector<32xf4E2M1FN> + %177 = vector.bitcast %125 : vector<16xi8> to vector<32xf4E2M1FN> + %178 = vector.bitcast %126 : vector<16xi8> to vector<32xf4E2M1FN> + %179 = vector.bitcast %127 : vector<16xi8> to vector<32xf4E2M1FN> + %180 = vector.bitcast %128 : vector<16xi8> to vector<32xf4E2M1FN> + %181 = vector.bitcast %129 : vector<16xi8> to vector<32xf4E2M1FN> + %182 = vector.bitcast %130 : vector<16xi8> to vector<32xf4E2M1FN> + %183 = vector.bitcast %131 : vector<16xi8> to vector<32xf4E2M1FN> + %184 = vector.bitcast %132 : vector<16xi8> to vector<32xf4E2M1FN> + %185 = vector.bitcast %86 : vector<1xi8> to vector<1xf8E8M0FNU> + %186 = 
vector.bitcast %89 : vector<1xi8> to vector<1xf8E8M0FNU> + %187 = vector.bitcast %91 : vector<1xi8> to vector<1xf8E8M0FNU> + %188 = vector.bitcast %92 : vector<1xi8> to vector<1xf8E8M0FNU> + %189 = vector.bitcast %94 : vector<1xi8> to vector<1xf8E8M0FNU> + %190 = vector.bitcast %95 : vector<1xi8> to vector<1xf8E8M0FNU> + %191 = vector.bitcast %97 : vector<1xi8> to vector<1xf8E8M0FNU> + %192 = vector.bitcast %98 : vector<1xi8> to vector<1xf8E8M0FNU> + %193 = vector.bitcast %100 : vector<1xi8> to vector<1xf8E8M0FNU> + %194 = vector.bitcast %101 : vector<1xi8> to vector<1xf8E8M0FNU> + %195 = vector.bitcast %103 : vector<1xi8> to vector<1xf8E8M0FNU> + %196 = vector.bitcast %104 : vector<1xi8> to vector<1xf8E8M0FNU> + %197 = vector.bitcast %106 : vector<1xi8> to vector<1xf8E8M0FNU> + %198 = vector.bitcast %107 : vector<1xi8> to vector<1xf8E8M0FNU> + %199 = vector.bitcast %109 : vector<1xi8> to vector<1xf8E8M0FNU> + %200 = vector.bitcast %110 : vector<1xi8> to vector<1xf8E8M0FNU> + %201 = vector.extract %161[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %202 = vector.extract %185[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %203 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%202[0] * %169) + %83#0 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %204 = vector.extract %162[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %205 = vector.extract %186[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %206 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%205[0] * %170) + %203 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %207 = vector.extract %187[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %208 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%207[0] * %171) + %83#1 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %209 = vector.extract %188[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %210 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%209[0] * %172) + %208 : f8E8M0FNU, 
vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %211 = vector.extract %189[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %212 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%211[0] * %173) + %83#2 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %213 = vector.extract %190[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %214 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%213[0] * %174) + %212 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %215 = vector.extract %191[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %216 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%215[0] * %175) + %83#3 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %217 = vector.extract %192[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %218 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%217[0] * %176) + %216 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %219 = vector.extract %193[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %220 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%219[0] * %177) + %83#4 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %221 = vector.extract %194[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %222 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%221[0] * %178) + %220 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %223 = vector.extract %195[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %224 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%223[0] * %179) + %83#5 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %225 = vector.extract %196[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %226 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%225[0] * %180) + %224 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %227 = vector.extract %197[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %228 = 
amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%227[0] * %181) + %83#6 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %229 = vector.extract %198[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %230 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%229[0] * %182) + %228 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %231 = vector.extract %199[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %232 = amdgpu.scaled_mfma 16x16x128 (%201[0] * %153) * (%231[0] * %183) + %83#7 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %233 = vector.extract %200[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %234 = amdgpu.scaled_mfma 16x16x128 (%204[0] * %154) * (%233[0] * %184) + %232 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %235 = vector.extract %163[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %236 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%202[0] * %169) + %83#8 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %237 = vector.extract %164[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %238 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%205[0] * %170) + %236 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %239 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%207[0] * %171) + %83#9 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %240 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%209[0] * %172) + %239 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %241 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%211[0] * %173) + %83#10 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %242 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%213[0] * %174) + %241 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %243 = amdgpu.scaled_mfma 16x16x128 (%235[0] 
* %155) * (%215[0] * %175) + %83#11 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %244 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%217[0] * %176) + %243 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %245 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%219[0] * %177) + %83#12 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %246 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%221[0] * %178) + %245 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %247 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%223[0] * %179) + %83#13 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %248 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%225[0] * %180) + %247 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %249 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%227[0] * %181) + %83#14 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %250 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%229[0] * %182) + %249 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %251 = amdgpu.scaled_mfma 16x16x128 (%235[0] * %155) * (%231[0] * %183) + %83#15 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %252 = amdgpu.scaled_mfma 16x16x128 (%237[0] * %156) * (%233[0] * %184) + %251 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %253 = vector.extract %165[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %254 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%202[0] * %169) + %83#16 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %255 = vector.extract %166[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %256 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%205[0] * %170) + %254 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, 
vector<32xf4E2M1FN>, vector<4xf32> + %257 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%207[0] * %171) + %83#17 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %258 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%209[0] * %172) + %257 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %259 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%211[0] * %173) + %83#18 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %260 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%213[0] * %174) + %259 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %261 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%215[0] * %175) + %83#19 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %262 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%217[0] * %176) + %261 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %263 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%219[0] * %177) + %83#20 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %264 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%221[0] * %178) + %263 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %265 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%223[0] * %179) + %83#21 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %266 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%225[0] * %180) + %265 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %267 = amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%227[0] * %181) + %83#22 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %268 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%229[0] * %182) + %267 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %269 = 
amdgpu.scaled_mfma 16x16x128 (%253[0] * %157) * (%231[0] * %183) + %83#23 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %270 = amdgpu.scaled_mfma 16x16x128 (%255[0] * %158) * (%233[0] * %184) + %269 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %271 = vector.extract %167[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %272 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%202[0] * %169) + %83#24 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %273 = vector.extract %168[0] : f8E8M0FNU from vector<1xf8E8M0FNU> + %274 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%205[0] * %170) + %272 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %275 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%207[0] * %171) + %83#25 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %276 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%209[0] * %172) + %275 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %277 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%211[0] * %173) + %83#26 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %278 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%213[0] * %174) + %277 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %279 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%215[0] * %175) + %83#27 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %280 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%217[0] * %176) + %279 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %281 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%219[0] * %177) + %83#28 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %282 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%221[0] * %178) + %281 : 
f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %283 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%223[0] * %179) + %83#29 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %284 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%225[0] * %180) + %283 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %285 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%227[0] * %181) + %83#30 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %286 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%229[0] * %182) + %285 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %287 = amdgpu.scaled_mfma 16x16x128 (%271[0] * %159) * (%231[0] * %183) + %83#31 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %288 = amdgpu.scaled_mfma 16x16x128 (%273[0] * %160) * (%233[0] * %184) + %287 : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> + %289 = vector.extract_strided_slice %206 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %290 = affine.apply #map35()[%block_id_x] + %291 = affine.apply #map35()[%block_id_y] + %292 = affine.apply #map36()[%thread_id_x] + %293 = arith.muli %290, %c57344 overflow : index + %294 = arith.muli %292, %c57344 overflow : index + %295 = arith.addi %293, %291 overflow : index + %296 = arith.addi %294, %84 overflow : index + %reinterpret_cast_13 = memref.reinterpret_cast %4 to offset: [%295], sizes: [536870910], strides: [1] : memref to memref<536870910xf32, strided<[1], offset: ?>> + %cast_14 = memref.cast %reinterpret_cast_13 : memref<536870910xf32, strided<[1], offset: ?>> to memref> + %297 = amdgpu.fat_raw_buffer_cast %cast_14 validBytes(%c2147483643_i64) resetOffset : memref> to memref> + vector.store %289, %297[%296] : memref>, vector<1xf32> + %298 = vector.extract_strided_slice %206 {offsets = [1], sizes = 
[1], strides = [1]} : vector<4xf32> to vector<1xf32> + %299 = affine.apply #map37()[%thread_id_x] + %300 = arith.muli %299, %c57344 overflow : index + %301 = arith.addi %300, %84 overflow : index + vector.store %298, %297[%301] : memref>, vector<1xf32> + %302 = vector.extract_strided_slice %206 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %303 = affine.apply #map38()[%thread_id_x] + %304 = arith.muli %303, %c57344 overflow : index + %305 = arith.addi %304, %84 overflow : index + vector.store %302, %297[%305] : memref>, vector<1xf32> + %306 = vector.extract_strided_slice %206 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %307 = affine.apply #map39()[%thread_id_x] + %308 = arith.muli %307, %c57344 overflow : index + %309 = arith.addi %308, %84 overflow : index + vector.store %306, %297[%309] : memref>, vector<1xf32> + %310 = vector.extract_strided_slice %210 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %311 = arith.addi %294, %90 overflow : index + vector.store %310, %297[%311] : memref>, vector<1xf32> + %312 = vector.extract_strided_slice %210 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %313 = arith.addi %300, %90 overflow : index + vector.store %312, %297[%313] : memref>, vector<1xf32> + %314 = vector.extract_strided_slice %210 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %315 = arith.addi %304, %90 overflow : index + vector.store %314, %297[%315] : memref>, vector<1xf32> + %316 = vector.extract_strided_slice %210 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %317 = arith.addi %308, %90 overflow : index + vector.store %316, %297[%317] : memref>, vector<1xf32> + %318 = vector.extract_strided_slice %214 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %319 = arith.addi %294, %93 overflow : index + vector.store %318, %297[%319] : memref>, 
vector<1xf32> + %320 = vector.extract_strided_slice %214 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %321 = arith.addi %300, %93 overflow : index + vector.store %320, %297[%321] : memref>, vector<1xf32> + %322 = vector.extract_strided_slice %214 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %323 = arith.addi %304, %93 overflow : index + vector.store %322, %297[%323] : memref>, vector<1xf32> + %324 = vector.extract_strided_slice %214 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %325 = arith.addi %308, %93 overflow : index + vector.store %324, %297[%325] : memref>, vector<1xf32> + %326 = vector.extract_strided_slice %218 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %327 = arith.addi %294, %96 overflow : index + vector.store %326, %297[%327] : memref>, vector<1xf32> + %328 = vector.extract_strided_slice %218 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %329 = arith.addi %300, %96 overflow : index + vector.store %328, %297[%329] : memref>, vector<1xf32> + %330 = vector.extract_strided_slice %218 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %331 = arith.addi %304, %96 overflow : index + vector.store %330, %297[%331] : memref>, vector<1xf32> + %332 = vector.extract_strided_slice %218 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %333 = arith.addi %308, %96 overflow : index + vector.store %332, %297[%333] : memref>, vector<1xf32> + %334 = vector.extract_strided_slice %222 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %335 = arith.addi %294, %99 overflow : index + vector.store %334, %297[%335] : memref>, vector<1xf32> + %336 = vector.extract_strided_slice %222 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %337 = arith.addi %300, %99 overflow : index + vector.store %336, 
%297[%337] : memref>, vector<1xf32> + %338 = vector.extract_strided_slice %222 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %339 = arith.addi %304, %99 overflow : index + vector.store %338, %297[%339] : memref>, vector<1xf32> + %340 = vector.extract_strided_slice %222 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %341 = arith.addi %308, %99 overflow : index + vector.store %340, %297[%341] : memref>, vector<1xf32> + %342 = vector.extract_strided_slice %226 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %343 = arith.addi %294, %102 overflow : index + vector.store %342, %297[%343] : memref>, vector<1xf32> + %344 = vector.extract_strided_slice %226 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %345 = arith.addi %300, %102 overflow : index + vector.store %344, %297[%345] : memref>, vector<1xf32> + %346 = vector.extract_strided_slice %226 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %347 = arith.addi %304, %102 overflow : index + vector.store %346, %297[%347] : memref>, vector<1xf32> + %348 = vector.extract_strided_slice %226 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %349 = arith.addi %308, %102 overflow : index + vector.store %348, %297[%349] : memref>, vector<1xf32> + %350 = vector.extract_strided_slice %230 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %351 = arith.addi %294, %105 overflow : index + vector.store %350, %297[%351] : memref>, vector<1xf32> + %352 = vector.extract_strided_slice %230 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %353 = arith.addi %300, %105 overflow : index + vector.store %352, %297[%353] : memref>, vector<1xf32> + %354 = vector.extract_strided_slice %230 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %355 = arith.addi %304, %105 overflow : 
index + vector.store %354, %297[%355] : memref>, vector<1xf32> + %356 = vector.extract_strided_slice %230 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %357 = arith.addi %308, %105 overflow : index + vector.store %356, %297[%357] : memref>, vector<1xf32> + %358 = vector.extract_strided_slice %234 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %359 = arith.addi %294, %108 overflow : index + vector.store %358, %297[%359] : memref>, vector<1xf32> + %360 = vector.extract_strided_slice %234 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %361 = arith.addi %300, %108 overflow : index + vector.store %360, %297[%361] : memref>, vector<1xf32> + %362 = vector.extract_strided_slice %234 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %363 = arith.addi %304, %108 overflow : index + vector.store %362, %297[%363] : memref>, vector<1xf32> + %364 = vector.extract_strided_slice %234 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %365 = arith.addi %308, %108 overflow : index + vector.store %364, %297[%365] : memref>, vector<1xf32> + %366 = vector.extract_strided_slice %238 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %367 = affine.apply #map40()[%thread_id_x] + %368 = arith.muli %367, %c57344 overflow : index + %369 = arith.addi %368, %84 overflow : index + vector.store %366, %297[%369] : memref>, vector<1xf32> + %370 = vector.extract_strided_slice %238 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %371 = affine.apply #map41()[%thread_id_x] + %372 = arith.muli %371, %c57344 overflow : index + %373 = arith.addi %372, %84 overflow : index + vector.store %370, %297[%373] : memref>, vector<1xf32> + %374 = vector.extract_strided_slice %238 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %375 = affine.apply #map42()[%thread_id_x] + %376 
= arith.muli %375, %c57344 overflow : index + %377 = arith.addi %376, %84 overflow : index + vector.store %374, %297[%377] : memref>, vector<1xf32> + %378 = vector.extract_strided_slice %238 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %379 = affine.apply #map43()[%thread_id_x] + %380 = arith.muli %379, %c57344 overflow : index + %381 = arith.addi %380, %84 overflow : index + vector.store %378, %297[%381] : memref>, vector<1xf32> + %382 = vector.extract_strided_slice %240 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %383 = arith.addi %368, %90 overflow : index + vector.store %382, %297[%383] : memref>, vector<1xf32> + %384 = vector.extract_strided_slice %240 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %385 = arith.addi %372, %90 overflow : index + vector.store %384, %297[%385] : memref>, vector<1xf32> + %386 = vector.extract_strided_slice %240 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %387 = arith.addi %376, %90 overflow : index + vector.store %386, %297[%387] : memref>, vector<1xf32> + %388 = vector.extract_strided_slice %240 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %389 = arith.addi %380, %90 overflow : index + vector.store %388, %297[%389] : memref>, vector<1xf32> + %390 = vector.extract_strided_slice %242 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %391 = arith.addi %368, %93 overflow : index + vector.store %390, %297[%391] : memref>, vector<1xf32> + %392 = vector.extract_strided_slice %242 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %393 = arith.addi %372, %93 overflow : index + vector.store %392, %297[%393] : memref>, vector<1xf32> + %394 = vector.extract_strided_slice %242 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %395 = arith.addi %376, %93 overflow : index + vector.store %394, 
%297[%395] : memref>, vector<1xf32> + %396 = vector.extract_strided_slice %242 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %397 = arith.addi %380, %93 overflow : index + vector.store %396, %297[%397] : memref>, vector<1xf32> + %398 = vector.extract_strided_slice %244 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %399 = arith.addi %368, %96 overflow : index + vector.store %398, %297[%399] : memref>, vector<1xf32> + %400 = vector.extract_strided_slice %244 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %401 = arith.addi %372, %96 overflow : index + vector.store %400, %297[%401] : memref>, vector<1xf32> + %402 = vector.extract_strided_slice %244 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %403 = arith.addi %376, %96 overflow : index + vector.store %402, %297[%403] : memref>, vector<1xf32> + %404 = vector.extract_strided_slice %244 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %405 = arith.addi %380, %96 overflow : index + vector.store %404, %297[%405] : memref>, vector<1xf32> + %406 = vector.extract_strided_slice %246 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %407 = arith.addi %368, %99 overflow : index + vector.store %406, %297[%407] : memref>, vector<1xf32> + %408 = vector.extract_strided_slice %246 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %409 = arith.addi %372, %99 overflow : index + vector.store %408, %297[%409] : memref>, vector<1xf32> + %410 = vector.extract_strided_slice %246 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %411 = arith.addi %376, %99 overflow : index + vector.store %410, %297[%411] : memref>, vector<1xf32> + %412 = vector.extract_strided_slice %246 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %413 = arith.addi %380, %99 overflow : index + 
vector.store %412, %297[%413] : memref>, vector<1xf32> + %414 = vector.extract_strided_slice %248 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %415 = arith.addi %368, %102 overflow : index + vector.store %414, %297[%415] : memref>, vector<1xf32> + %416 = vector.extract_strided_slice %248 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %417 = arith.addi %372, %102 overflow : index + vector.store %416, %297[%417] : memref>, vector<1xf32> + %418 = vector.extract_strided_slice %248 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %419 = arith.addi %376, %102 overflow : index + vector.store %418, %297[%419] : memref>, vector<1xf32> + %420 = vector.extract_strided_slice %248 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %421 = arith.addi %380, %102 overflow : index + vector.store %420, %297[%421] : memref>, vector<1xf32> + %422 = vector.extract_strided_slice %250 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %423 = arith.addi %368, %105 overflow : index + vector.store %422, %297[%423] : memref>, vector<1xf32> + %424 = vector.extract_strided_slice %250 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %425 = arith.addi %372, %105 overflow : index + vector.store %424, %297[%425] : memref>, vector<1xf32> + %426 = vector.extract_strided_slice %250 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %427 = arith.addi %376, %105 overflow : index + vector.store %426, %297[%427] : memref>, vector<1xf32> + %428 = vector.extract_strided_slice %250 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %429 = arith.addi %380, %105 overflow : index + vector.store %428, %297[%429] : memref>, vector<1xf32> + %430 = vector.extract_strided_slice %252 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %431 = arith.addi 
%368, %108 overflow : index + vector.store %430, %297[%431] : memref>, vector<1xf32> + %432 = vector.extract_strided_slice %252 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %433 = arith.addi %372, %108 overflow : index + vector.store %432, %297[%433] : memref>, vector<1xf32> + %434 = vector.extract_strided_slice %252 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %435 = arith.addi %376, %108 overflow : index + vector.store %434, %297[%435] : memref>, vector<1xf32> + %436 = vector.extract_strided_slice %252 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %437 = arith.addi %380, %108 overflow : index + vector.store %436, %297[%437] : memref>, vector<1xf32> + %438 = vector.extract_strided_slice %256 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %439 = affine.apply #map44()[%thread_id_x] + %440 = arith.muli %439, %c57344 overflow : index + %441 = arith.addi %440, %84 overflow : index + vector.store %438, %297[%441] : memref>, vector<1xf32> + %442 = vector.extract_strided_slice %256 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %443 = affine.apply #map45()[%thread_id_x] + %444 = arith.muli %443, %c57344 overflow : index + %445 = arith.addi %444, %84 overflow : index + vector.store %442, %297[%445] : memref>, vector<1xf32> + %446 = vector.extract_strided_slice %256 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %447 = affine.apply #map46()[%thread_id_x] + %448 = arith.muli %447, %c57344 overflow : index + %449 = arith.addi %448, %84 overflow : index + vector.store %446, %297[%449] : memref>, vector<1xf32> + %450 = vector.extract_strided_slice %256 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %451 = affine.apply #map47()[%thread_id_x] + %452 = arith.muli %451, %c57344 overflow : index + %453 = arith.addi %452, %84 overflow : index + vector.store 
%450, %297[%453] : memref>, vector<1xf32> + %454 = vector.extract_strided_slice %258 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %455 = arith.addi %440, %90 overflow : index + vector.store %454, %297[%455] : memref>, vector<1xf32> + %456 = vector.extract_strided_slice %258 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %457 = arith.addi %444, %90 overflow : index + vector.store %456, %297[%457] : memref>, vector<1xf32> + %458 = vector.extract_strided_slice %258 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %459 = arith.addi %448, %90 overflow : index + vector.store %458, %297[%459] : memref>, vector<1xf32> + %460 = vector.extract_strided_slice %258 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %461 = arith.addi %452, %90 overflow : index + vector.store %460, %297[%461] : memref>, vector<1xf32> + %462 = vector.extract_strided_slice %260 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %463 = arith.addi %440, %93 overflow : index + vector.store %462, %297[%463] : memref>, vector<1xf32> + %464 = vector.extract_strided_slice %260 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %465 = arith.addi %444, %93 overflow : index + vector.store %464, %297[%465] : memref>, vector<1xf32> + %466 = vector.extract_strided_slice %260 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %467 = arith.addi %448, %93 overflow : index + vector.store %466, %297[%467] : memref>, vector<1xf32> + %468 = vector.extract_strided_slice %260 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %469 = arith.addi %452, %93 overflow : index + vector.store %468, %297[%469] : memref>, vector<1xf32> + %470 = vector.extract_strided_slice %262 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %471 = arith.addi %440, %96 overflow : 
index + vector.store %470, %297[%471] : memref>, vector<1xf32> + %472 = vector.extract_strided_slice %262 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %473 = arith.addi %444, %96 overflow : index + vector.store %472, %297[%473] : memref>, vector<1xf32> + %474 = vector.extract_strided_slice %262 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %475 = arith.addi %448, %96 overflow : index + vector.store %474, %297[%475] : memref>, vector<1xf32> + %476 = vector.extract_strided_slice %262 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %477 = arith.addi %452, %96 overflow : index + vector.store %476, %297[%477] : memref>, vector<1xf32> + %478 = vector.extract_strided_slice %264 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %479 = arith.addi %440, %99 overflow : index + vector.store %478, %297[%479] : memref>, vector<1xf32> + %480 = vector.extract_strided_slice %264 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %481 = arith.addi %444, %99 overflow : index + vector.store %480, %297[%481] : memref>, vector<1xf32> + %482 = vector.extract_strided_slice %264 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %483 = arith.addi %448, %99 overflow : index + vector.store %482, %297[%483] : memref>, vector<1xf32> + %484 = vector.extract_strided_slice %264 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %485 = arith.addi %452, %99 overflow : index + vector.store %484, %297[%485] : memref>, vector<1xf32> + %486 = vector.extract_strided_slice %266 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %487 = arith.addi %440, %102 overflow : index + vector.store %486, %297[%487] : memref>, vector<1xf32> + %488 = vector.extract_strided_slice %266 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %489 = arith.addi 
%444, %102 overflow : index + vector.store %488, %297[%489] : memref>, vector<1xf32> + %490 = vector.extract_strided_slice %266 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %491 = arith.addi %448, %102 overflow : index + vector.store %490, %297[%491] : memref>, vector<1xf32> + %492 = vector.extract_strided_slice %266 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %493 = arith.addi %452, %102 overflow : index + vector.store %492, %297[%493] : memref>, vector<1xf32> + %494 = vector.extract_strided_slice %268 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %495 = arith.addi %440, %105 overflow : index + vector.store %494, %297[%495] : memref>, vector<1xf32> + %496 = vector.extract_strided_slice %268 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %497 = arith.addi %444, %105 overflow : index + vector.store %496, %297[%497] : memref>, vector<1xf32> + %498 = vector.extract_strided_slice %268 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %499 = arith.addi %448, %105 overflow : index + vector.store %498, %297[%499] : memref>, vector<1xf32> + %500 = vector.extract_strided_slice %268 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %501 = arith.addi %452, %105 overflow : index + vector.store %500, %297[%501] : memref>, vector<1xf32> + %502 = vector.extract_strided_slice %270 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %503 = arith.addi %440, %108 overflow : index + vector.store %502, %297[%503] : memref>, vector<1xf32> + %504 = vector.extract_strided_slice %270 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %505 = arith.addi %444, %108 overflow : index + vector.store %504, %297[%505] : memref>, vector<1xf32> + %506 = vector.extract_strided_slice %270 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to 
vector<1xf32> + %507 = arith.addi %448, %108 overflow : index + vector.store %506, %297[%507] : memref>, vector<1xf32> + %508 = vector.extract_strided_slice %270 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %509 = arith.addi %452, %108 overflow : index + vector.store %508, %297[%509] : memref>, vector<1xf32> + %510 = vector.extract_strided_slice %274 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %511 = affine.apply #map48()[%thread_id_x] + %512 = arith.muli %511, %c57344 overflow : index + %513 = arith.addi %512, %84 overflow : index + vector.store %510, %297[%513] : memref>, vector<1xf32> + %514 = vector.extract_strided_slice %274 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %515 = affine.apply #map49()[%thread_id_x] + %516 = arith.muli %515, %c57344 overflow : index + %517 = arith.addi %516, %84 overflow : index + vector.store %514, %297[%517] : memref>, vector<1xf32> + %518 = vector.extract_strided_slice %274 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %519 = affine.apply #map50()[%thread_id_x] + %520 = arith.muli %519, %c57344 overflow : index + %521 = arith.addi %520, %84 overflow : index + vector.store %518, %297[%521] : memref>, vector<1xf32> + %522 = vector.extract_strided_slice %274 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %523 = affine.apply #map51()[%thread_id_x] + %524 = arith.muli %523, %c57344 overflow : index + %525 = arith.addi %524, %84 overflow : index + vector.store %522, %297[%525] : memref>, vector<1xf32> + %526 = vector.extract_strided_slice %276 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %527 = arith.addi %512, %90 overflow : index + vector.store %526, %297[%527] : memref>, vector<1xf32> + %528 = vector.extract_strided_slice %276 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %529 = arith.addi %516, %90 
overflow : index + vector.store %528, %297[%529] : memref>, vector<1xf32> + %530 = vector.extract_strided_slice %276 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %531 = arith.addi %520, %90 overflow : index + vector.store %530, %297[%531] : memref>, vector<1xf32> + %532 = vector.extract_strided_slice %276 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %533 = arith.addi %524, %90 overflow : index + vector.store %532, %297[%533] : memref>, vector<1xf32> + %534 = vector.extract_strided_slice %278 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %535 = arith.addi %512, %93 overflow : index + vector.store %534, %297[%535] : memref>, vector<1xf32> + %536 = vector.extract_strided_slice %278 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %537 = arith.addi %516, %93 overflow : index + vector.store %536, %297[%537] : memref>, vector<1xf32> + %538 = vector.extract_strided_slice %278 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %539 = arith.addi %520, %93 overflow : index + vector.store %538, %297[%539] : memref>, vector<1xf32> + %540 = vector.extract_strided_slice %278 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %541 = arith.addi %524, %93 overflow : index + vector.store %540, %297[%541] : memref>, vector<1xf32> + %542 = vector.extract_strided_slice %280 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %543 = arith.addi %512, %96 overflow : index + vector.store %542, %297[%543] : memref>, vector<1xf32> + %544 = vector.extract_strided_slice %280 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %545 = arith.addi %516, %96 overflow : index + vector.store %544, %297[%545] : memref>, vector<1xf32> + %546 = vector.extract_strided_slice %280 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %547 = 
arith.addi %520, %96 overflow : index + vector.store %546, %297[%547] : memref>, vector<1xf32> + %548 = vector.extract_strided_slice %280 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %549 = arith.addi %524, %96 overflow : index + vector.store %548, %297[%549] : memref>, vector<1xf32> + %550 = vector.extract_strided_slice %282 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %551 = arith.addi %512, %99 overflow : index + vector.store %550, %297[%551] : memref>, vector<1xf32> + %552 = vector.extract_strided_slice %282 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %553 = arith.addi %516, %99 overflow : index + vector.store %552, %297[%553] : memref>, vector<1xf32> + %554 = vector.extract_strided_slice %282 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %555 = arith.addi %520, %99 overflow : index + vector.store %554, %297[%555] : memref>, vector<1xf32> + %556 = vector.extract_strided_slice %282 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %557 = arith.addi %524, %99 overflow : index + vector.store %556, %297[%557] : memref>, vector<1xf32> + %558 = vector.extract_strided_slice %284 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %559 = arith.addi %512, %102 overflow : index + vector.store %558, %297[%559] : memref>, vector<1xf32> + %560 = vector.extract_strided_slice %284 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %561 = arith.addi %516, %102 overflow : index + vector.store %560, %297[%561] : memref>, vector<1xf32> + %562 = vector.extract_strided_slice %284 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %563 = arith.addi %520, %102 overflow : index + vector.store %562, %297[%563] : memref>, vector<1xf32> + %564 = vector.extract_strided_slice %284 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to 
vector<1xf32> + %565 = arith.addi %524, %102 overflow : index + vector.store %564, %297[%565] : memref>, vector<1xf32> + %566 = vector.extract_strided_slice %286 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %567 = arith.addi %512, %105 overflow : index + vector.store %566, %297[%567] : memref>, vector<1xf32> + %568 = vector.extract_strided_slice %286 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %569 = arith.addi %516, %105 overflow : index + vector.store %568, %297[%569] : memref>, vector<1xf32> + %570 = vector.extract_strided_slice %286 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %571 = arith.addi %520, %105 overflow : index + vector.store %570, %297[%571] : memref>, vector<1xf32> + %572 = vector.extract_strided_slice %286 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %573 = arith.addi %524, %105 overflow : index + vector.store %572, %297[%573] : memref>, vector<1xf32> + %574 = vector.extract_strided_slice %288 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %575 = arith.addi %512, %108 overflow : index + vector.store %574, %297[%575] : memref>, vector<1xf32> + %576 = vector.extract_strided_slice %288 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %577 = arith.addi %516, %108 overflow : index + vector.store %576, %297[%577] : memref>, vector<1xf32> + %578 = vector.extract_strided_slice %288 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %579 = arith.addi %520, %108 overflow : index + vector.store %578, %297[%579] : memref>, vector<1xf32> + %580 = vector.extract_strided_slice %288 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> + %581 = arith.addi %524, %108 overflow : index + vector.store %580, %297[%581] : memref>, vector<1xf32> + return + } + } + } + func.func @isolated_benchmark$async(%arg0: 
!hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view, %arg4: !hal.buffer_view, %arg5: !hal.fence, %arg6: !hal.fence) -> !hal.buffer_view { + %0 = hal.tensor.import wait(%arg5) => %arg0 : !hal.buffer_view -> tensor<4096x8192xi8> + %1 = hal.tensor.import wait(%arg5) => %arg1 : !hal.buffer_view -> tensor<4096x512xi8> + %2 = hal.tensor.import wait(%arg5) => %arg2 : !hal.buffer_view -> tensor<57344x8192xi8> + %3 = hal.tensor.import wait(%arg5) => %arg3 : !hal.buffer_view -> tensor<57344x512xi8> + %4 = hal.tensor.import wait(%arg5) => %arg4 : !hal.buffer_view -> tensor<4096x57344xf32> + %5 = flow.dispatch @gemm::@gemm(%0, %1, %2, %3, %4) : (tensor<4096x8192xi8>, tensor<4096x512xi8>, tensor<57344x8192xi8>, tensor<57344x512xi8>, tensor<4096x57344xf32>) -> %4 + %6 = hal.tensor.barrier join(%5 : tensor<4096x57344xf32>) => %arg6 : !hal.fence + %7 = hal.tensor.export %6 : tensor<4096x57344xf32> -> !hal.buffer_view + return %7 : !hal.buffer_view + } + } + """ gemm, options = get_tagged_mxfp4_gemm(shape, block, num_waves=8) + schedule = get_mxfp4_dbuf_schedule(use_stagger=True) + options.use_buffer_ops = True + options.specialize = True + # options.override_mlir=mlir_claude_rescheduled2 + options.print_ir_after = "all" if is_debug else [] + # options.print_ir_after = "all" options = set_default_run_config(options) gemm = wave_compile(options, gemm, schedule) + print(gemm.asm) + _run_mxfp_gemm(gemm, shape) print("MXFP GEMM double-buffer 8-wave test passed!") +def test_dbuf_8wave_mxfp_gemm_shuffle( + is_debug=False, shape=(4096, 57344, 16384), block=(256, 256, 256) +): + """Double-buffered MXFP4 GEMM, 8 waves, with stagger.""" + + gemm, options = get_preshuffle_kernel(shape, block) + + schedule = get_mxfp4_dbuf_schedule_shuffle(use_stagger=True) + + options.use_buffer_ops = True + options.specialize = True + # options.override_mlir=mlir_claude_rescheduled2 + + options.print_ir_after = "all" if is_debug else [] + # options.print_ir_after = 
"all" + options = set_default_run_config(options) + gemm = wave_compile(options, gemm, schedule) + + print(gemm.asm) + + _run_mxfp_gemm(gemm, shape, shuffle_scales=True) + print("MXFP GEMM double-buffer 8-wave test passed!") + + +def test_triplebuf_8wave_mxfp_gemm( + is_debug=False, shape=(16384, 16384, 16384), block=(256, 256, 256) +): + """Double-buffered MXFP4 GEMM, 8 waves, with stagger.""" + gemm, options = get_tagged_mxfp4_gemm(shape, block, num_waves=8) + schedule = get_mxfp4_triplebuf_schedule(use_stagger=True) + + options.print_ir_after = "all" if is_debug else [] + options = set_default_run_config(options) + gemm = wave_compile(options, gemm, schedule) + + print(gemm.asm) + + _run_mxfp_gemm(gemm, shape) + print("MXFP GEMM triple-buffer 8-wave test passed!") + + if __name__ == "__main__": args = parse_args() diff --git a/examples/python/7.2_mxfp4_gemm_preshuffle_scale.py b/examples/python/7.2_mxfp4_gemm_preshuffle_scale.py new file mode 100644 index 0000000000..abb0a4603b --- /dev/null +++ b/examples/python/7.2_mxfp4_gemm_preshuffle_scale.py @@ -0,0 +1,742 @@ +""" +MXFP4 GEMM Example: Unshuffled vs Shuffled Scales + +Tests two MXFP4 GEMM implementations: +1. Unshuffled: Scales in normal [M, K//32] or [N, K//32] layout +2. Shuffled: Scales pre-shuffled using e8m0_shuffle for hardware efficiency + +Both kernels are verified against a PyTorch reference implementation. +""" + +import torch +import argparse + +import wave_lang.kernel.wave as tkw +import wave_lang.kernel.lang as tkl +from wave_lang.kernel.lang.global_symbols import * +from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile +from wave_lang.kernel.wave.utils.run_utils import set_default_run_config +from wave_lang.kernel.wave.constraints import ScaledMMAType + +SCALE_GROUP_SIZE = 32 # Hardware constant: 1 scale per 32 data elements + + +def e8m0_shuffle(scale): + """ + Shuffle the scale tensor for e8m0 format. 
+ + This particular shuffle is taken from + https://github.com/ROCm/rocm-libraries/blob/4348901528fe100a84975b89c247eece553a2a2d/shared/mxdatagenerator/lib/include/mxDataGenerator/PreSwizzle.hpp#L403 + + The e8m0_shuffle operation transforms a matrix with shape (m, n) as follows: + 1. Pads to shape ((m+255)//256*256, (n+7)//8*8) + 2. Reshapes to (sm//32, 2, 16, sn//8, 2, 4) + 3. Permutes dimensions: (0, 3, 5, 2, 4, 1) + 4. Flattens back to (sm, sn) + + Args: + scale: A 2D tensor to be shuffled + + Returns: + Shuffled tensor with the same padded shape + """ + if scale is None: + return scale + if scale.dtype == torch.float32: + return scale + assert scale.ndim == 2, "scale must be a 2D tensor" + m, n = scale.shape + scale_padded = torch.zeros( + (m + 255) // 256 * 256, + (n + 7) // 8 * 8, + dtype=scale.dtype, + device=scale.device, + ) + + scale_padded[:m, :n] = scale + scale = scale_padded + sm, sn = scale.shape + scale = scale.view(sm // 32, 2, 16, sn // 8, 2, 4) + scale = scale.permute(0, 3, 5, 2, 4, 1).contiguous() + scale = scale.view(sm, sn) + return scale + + +def generate_mxfp4_inputs( + shape: tuple[int, int, int], device: torch.device = torch.device("cpu") +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate random MXFP4 inputs for scaled GEMM.""" + M, N, K = shape + torch.manual_seed(5) + + # Generate packed MXFP4 data (2 values per byte) + x_low = torch.randint(0, 16, (M, K // 2), dtype=torch.uint8, device=device) + x_high = torch.randint(0, 16, (M, K // 2), dtype=torch.uint8, device=device) + x = x_low | (x_high << 4) + + w_low = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device=device) + w_high = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device=device) + w = w_low | (w_high << 4) + + # Generate E8M0 scales (random values near 1.0) + x_scales = torch.randint( + 124, 128, (K // SCALE_GROUP_SIZE, M), dtype=torch.uint8, device=device + ) + w_scales = torch.randint( + 124, 128, (K // SCALE_GROUP_SIZE, N), 
dtype=torch.uint8, device=device + ) + + x_scales = x_scales.T.contiguous() # [M, K//32] + w_scales = w_scales.T.contiguous() # [N, K//32] + + return x, w, x_scales, w_scales + + +def mxfp4_to_f32(x: torch.Tensor) -> torch.Tensor: + """Convert packed MXFP4 (e2m1fn) values to float32.""" + x = x.repeat_interleave(2, dim=1) + x[:, ::2] = x[:, ::2] & 0xF # Low nibble + x[:, 1::2] = x[:, 1::2] >> 4 # High nibble + + mxfp4_lut = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ] + mxfp4_f32 = torch.tensor(mxfp4_lut, dtype=torch.float32, device=x.device) + return mxfp4_f32[x.long()] + + +def e8m0_to_f32(x: torch.Tensor) -> torch.Tensor: + """Convert E8M0 (exponent-only) scale values to float32.""" + x_f32 = 2 ** ((x.to(torch.float32) - 127)) + x_f32[x == 255] = float("nan") + return x_f32 + + +def reference_mxfp4_gemm( + x: torch.Tensor, w: torch.Tensor, x_scales: torch.Tensor, w_scales: torch.Tensor +) -> torch.Tensor: + """PyTorch reference implementation for scaled MXFP4 GEMM: C = (x * x_scales) @ (w * w_scales)^T""" + x_f32 = mxfp4_to_f32(x) + w_f32 = mxfp4_to_f32(w) + + x_scales_expanded = x_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=1) + x_scales_f32 = e8m0_to_f32(x_scales_expanded) + + w_scales_expanded = w_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=1) + w_scales_f32 = e8m0_to_f32(w_scales_expanded) + + x_scaled = x_f32 * x_scales_f32 + w_scaled = w_f32 * w_scales_f32 + + return torch.mm(x_scaled, w_scaled.T) + + +def get_vanilla_kernel(): + """Return the vanilla (unshuffled) MXFP4 GEMM kernel definition.""" + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K = tkl.sym.BLOCK_K + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + constraints: list[tkw.Constraint] = [ + tkw.WorkgroupConstraint(M, BLOCK_M, 0), + tkw.WorkgroupConstraint(N, BLOCK_N, 1), + tkw.TilingConstraint(K, BLOCK_K), + tkw.WaveConstraint(M, BLOCK_M / 2), + 
tkw.WaveConstraint(N, BLOCK_N / 2), + tkw.HardwareConstraint( + threads_per_wave=64, + mma_type=ScaledMMAType.F32_16x16x128_F8F6F4, + ), + ] + + @tkw.wave(constraints) + def mxfp4_gemm_vanilla( + a: tkl.Memory[M, K / 2, ADDRESS_SPACE, tkl.i8], + a_scale: tkl.Memory[M, K / 32, ADDRESS_SPACE, tkl.i8], + b: tkl.Memory[N, K / 2, ADDRESS_SPACE, tkl.i8], + b_scale: tkl.Memory[N, K / 32, ADDRESS_SPACE, tkl.i8], + c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + @tkw.iterate(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a) + a_reg = tkw.bitcast(a_reg, tkl.f4e2m1fn) + a_scale_reg = tkw.read(a_scale) + a_scale_reg = tkw.bitcast(a_scale_reg, tkl.f8e8m0fnu) + + b_reg = tkw.read(b) + b_reg = tkw.bitcast(b_reg, tkl.f4e2m1fn) + b_scale_reg = tkw.read(b_scale) + b_scale_reg = tkw.bitcast(b_scale_reg, tkl.f8e8m0fnu) + + acc = tkw.scaled_mma(a_reg, a_scale_reg, b_reg, b_scale_reg, acc) + return acc + + tkw.write(repeat, c) + + return mxfp4_gemm_vanilla + + +def get_preshuffle_kernel(): + """Return the pre-shuffled MXFP4 GEMM kernel definition with IndexMapping for shuffled scales.""" + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K = tkl.sym.BLOCK_K + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + K_SCALE_SHUFFLED = tkl.sym.K_SCALE_SHUFFLED + + constraints: list[tkw.Constraint] = [ + tkw.WorkgroupConstraint(M, BLOCK_M, 0), + tkw.WorkgroupConstraint(N, BLOCK_N, 1), + tkw.TilingConstraint(K, BLOCK_K), + tkw.WaveConstraint(M, BLOCK_M / 2), + tkw.WaveConstraint(N, BLOCK_N / 2), + tkw.HardwareConstraint( + threads_per_wave=64, + mma_type=ScaledMMAType.F32_16x16x128_F8F6F4, + ), + ] + + # Create IndexMapping for shuffled A scales + # The e8m0_shuffle coordinate transformation maps logical (K, M) iterators + # to physical shuffled memory layout + i = tkw.IndexMapping.iterator(0) # K iterator + j = 
tkw.IndexMapping.iterator(1) # M iterator + + a_scale_mapping = tkw.IndexMapping( + num_iterators=2, + inputs={ + M: ( + ( + (j // 32) * ((K_SCALE_SHUFFLED // 8) * 256) + + (i // 8) * 256 + + ((i % 8) % 4) * 64 + + ((j % 32) % 16) * 4 + + (((i % 8) // 4) * 2) + + ((j % 32) // 16) + ) + // K_SCALE_SHUFFLED + ), + K: ( + ( + (j // 32) * ((K_SCALE_SHUFFLED // 8) * 256) + + (i // 8) * 256 + + ((i % 8) % 4) * 64 + + ((j % 32) % 16) * 4 + + (((i % 8) // 4) * 2) + + ((j % 32) // 16) + ) + % K_SCALE_SHUFFLED + ), + }, + outputs={ + K: i, + M: j, + }, + ) + + # Create IndexMapping for shuffled B scales + k = tkw.IndexMapping.iterator(0) # K iterator + n = tkw.IndexMapping.iterator(1) # N iterator + + b_scale_mapping = tkw.IndexMapping( + num_iterators=2, + inputs={ + N: ( + ( + (n // 32) * ((K_SCALE_SHUFFLED // 8) * 256) + + (k // 8) * 256 + + ((k % 8) % 4) * 64 + + ((n % 32) % 16) * 4 + + (((k % 8) // 4) * 2) + + ((n % 32) // 16) + ) + // K_SCALE_SHUFFLED + ), + K: ( + ( + (n // 32) * ((K_SCALE_SHUFFLED // 8) * 256) + + (k // 8) * 256 + + ((k % 8) % 4) * 64 + + ((n % 32) % 16) * 4 + + (((k % 8) // 4) * 2) + + ((n % 32) // 16) + ) + % K_SCALE_SHUFFLED + ), + }, + outputs={ + K: k, + N: n, + }, + ) + + # TODO: preshuffle merge doesn't work with shared address space yet. 
+ @tkw.wave(constraints) + def mxfp4_gemm_preshuffle( + a: tkl.Memory[M, K / 2, ADDRESS_SPACE, tkl.i8], + a_scale: tkl.Memory[M, K / 32, GLOBAL_ADDRESS_SPACE, tkl.i8], + b: tkl.Memory[N, K / 2, ADDRESS_SPACE, tkl.i8], + b_scale: tkl.Memory[N, K / 32, GLOBAL_ADDRESS_SPACE, tkl.i8], + c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + @tkw.iterate(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a) + a_reg = tkw.bitcast(a_reg, tkl.f4e2m1fn) + a_scale_reg = tkw.read(a_scale, mapping=a_scale_mapping) + a_scale_reg = tkw.bitcast(a_scale_reg, tkl.f8e8m0fnu) + + b_reg = tkw.read(b) + b_reg = tkw.bitcast(b_reg, tkl.f4e2m1fn) + b_scale_reg = tkw.read(b_scale, mapping=b_scale_mapping) + b_scale_reg = tkw.bitcast(b_scale_reg, tkl.f8e8m0fnu) + + acc = tkw.scaled_mma(a_reg, a_scale_reg, b_reg, b_scale_reg, acc) + return acc + + tkw.write(repeat, c) + + return mxfp4_gemm_preshuffle + + +def run_all_tests(): + """Run both vanilla and pre-shuffled tests and compare results.""" + m, n, k = 512, 512, 2048 + block_m, block_n, block_k = 128, 128, 256 + + print("=" * 70) + print("MXFP4 GEMM COMPREHENSIVE TEST SUITE") + print("=" * 70) + print(f"Problem size: M={m}, N={n}, K={k}") + print(f"Block sizes: BLOCK_M={block_m}, BLOCK_N={block_n}, BLOCK_K={block_k}") + + # Define symbolic dimensions + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K = tkl.sym.BLOCK_K + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + # Calculate shuffled dimensions for pre-shuffle kernel + k_scale_shuffled = (((k // 32) + 7) // 8) * 8 + K_SCALE_SHUFFLED = tkl.sym.K_SCALE_SHUFFLED + + # Get kernel definitions + print("\nGetting kernel definitions...") + vanilla_kernel = get_vanilla_kernel() + preshuffle_kernel = get_preshuffle_kernel() + + # Set up hyperparameters (shared by both kernels) + hyperparams = { + ADDRESS_SPACE: 
SHARED_ADDRESS_SPACE, + BLOCK_M: block_m, + BLOCK_N: block_n, + BLOCK_K: block_k, + M: m, + N: n, + K: k, + K_SCALE_SHUFFLED: k_scale_shuffled, + } + + # Compile options (shared by both kernels) + options = WaveCompileOptions( + subs=hyperparams, + canonicalize=True, + use_global_to_shared=True, + # minimize_shared_allocs=True, + ) + options = set_default_run_config(options) + + # Compile both kernels + compiled_vanilla = wave_compile(options, vanilla_kernel) + + compiled_preshuffle = wave_compile(options, preshuffle_kernel) + print(compiled_preshuffle.asm) + + # Generate test data + x, w, x_scales, w_scales = generate_mxfp4_inputs( + (m, n, k), device=torch.device("cpu") + ) + + # Compute PyTorch reference + torch_result = reference_mxfp4_gemm(x, w, x_scales, w_scales) + + # Shuffle scales for pre-shuffle kernel + x_scales_shuffled = e8m0_shuffle(x_scales) + w_scales_shuffled = e8m0_shuffle(w_scales) + + # Move data to GPU + x_gpu = x.cuda() + w_gpu = w.cuda() + x_scales_gpu = x_scales.cuda() + w_scales_gpu = w_scales.cuda() + x_scales_shuffled_gpu = x_scales_shuffled.cuda() + w_scales_shuffled_gpu = w_scales_shuffled.cuda() + + # Run vanilla kernel + print("\n" + "=" * 60) + print("TEST 1: Vanilla MXFP4 GEMM") + print("=" * 60) + print("Running vanilla Wave kernel...") + c_vanilla_gpu = torch.zeros(m, n, dtype=torch.float32, device="cuda") + compiled_vanilla(x_gpu, x_scales_gpu, w_gpu, w_scales_gpu, c_vanilla_gpu) + wave_vanilla_result = c_vanilla_gpu.cpu() + + print("Verifying vanilla results...") + try: + torch.testing.assert_close( + torch_result, wave_vanilla_result, rtol=1e-3, atol=1e-3, check_dtype=False + ) + print("✓ VANILLA TEST PASSED! 
Results match PyTorch reference.") + except AssertionError as e: + print("✗ VANILLA TEST FAILED!") + print(f"Error: {e}") + max_diff = torch.max(torch.abs(torch_result - wave_vanilla_result)) + mean_diff = torch.mean(torch.abs(torch_result - wave_vanilla_result)) + print(f"Max difference: {max_diff}") + print(f"Mean difference: {mean_diff}") + print(f"Reference output range: [{torch_result.min()}, {torch_result.max()}]") + print( + f"Wave output range: [{wave_vanilla_result.min()}, {wave_vanilla_result.max()}]" + ) + + # Run pre-shuffle kernel + print("\n" + "=" * 60) + print("TEST 2: Pre-Shuffled MXFP4 GEMM") + print("=" * 60) + print("Running pre-shuffle Wave kernel...") + c_preshuffle_gpu = torch.zeros(m, n, dtype=torch.float32, device="cuda") + compiled_preshuffle( + x_gpu, x_scales_shuffled_gpu, w_gpu, w_scales_shuffled_gpu, c_preshuffle_gpu + ) + wave_preshuffle_result = c_preshuffle_gpu.cpu() + + print("Verifying pre-shuffle results...") + try: + torch.testing.assert_close( + torch_result, + wave_preshuffle_result, + rtol=1e-3, + atol=1e-3, + check_dtype=False, + ) + print("✓ PRE-SHUFFLE TEST PASSED! 
Results match PyTorch reference.") + except AssertionError as e: + print("✗ PRE-SHUFFLE TEST FAILED!") + print(f"Error: {e}") + max_diff = torch.max(torch.abs(torch_result - wave_preshuffle_result)) + mean_diff = torch.mean(torch.abs(torch_result - wave_preshuffle_result)) + print(f"Max difference: {max_diff}") + print(f"Mean difference: {mean_diff}") + print(f"Reference output range: [{torch_result.min()}, {torch_result.max()}]") + print( + f"Wave output range: [{wave_preshuffle_result.min()}, {wave_preshuffle_result.max()}]" + ) + + # Final comparison: verify both Wave results are identical + print("\n" + "=" * 60) + print("FINAL COMPARISON: Vanilla vs Pre-Shuffled Wave Results") + print("=" * 60) + try: + torch.testing.assert_close( + wave_vanilla_result, + wave_preshuffle_result, + rtol=1e-6, + atol=1e-6, + check_dtype=False, + ) + print("✓ Both Wave kernels produce IDENTICAL results!") + except AssertionError as e: + print("✗ Wave kernels produce DIFFERENT results!") + print(f"Error: {e}") + max_diff = torch.max(torch.abs(wave_vanilla_result - wave_preshuffle_result)) + mean_diff = torch.mean(torch.abs(wave_vanilla_result - wave_preshuffle_result)) + print(f"Max difference: {max_diff}") + print(f"Mean difference: {mean_diff}") + + +def run_benchmark( + kernel_type="vanilla", + m=512, + n=512, + k=2048, + block_m=128, + block_n=128, + block_k=256, + warmup_iters=10, + bench_iters=100, +): + """ + Benchmark a single kernel configuration. 
+ + Args: + kernel_type: Either "vanilla" or "preshuffle" + m, n, k: Problem dimensions + block_m, block_n, block_k: Block dimensions + warmup_iters: Number of warmup iterations + bench_iters: Number of benchmark iterations + + Returns: + Dictionary containing benchmark results + """ + print("=" * 70) + print(f"BENCHMARKING: {kernel_type.upper()} MXFP4 GEMM") + print("=" * 70) + print(f"Problem size: M={m}, N={n}, K={k}") + print(f"Block sizes: BLOCK_M={block_m}, BLOCK_N={block_n}, BLOCK_K={block_k}") + print(f"Warmup iterations: {warmup_iters}") + print(f"Benchmark iterations: {bench_iters}") + + # Define symbolic dimensions + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K = tkl.sym.BLOCK_K + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + # Get the appropriate kernel + if kernel_type == "vanilla": + kernel = get_vanilla_kernel() + elif kernel_type == "preshuffle": + kernel = get_preshuffle_kernel() + else: + raise ValueError(f"Unknown kernel_type: {kernel_type}") + + # Set up hyperparameters + k_scale_shuffled = (((k // 32) + 7) // 8) * 8 + K_SCALE_SHUFFLED = tkl.sym.K_SCALE_SHUFFLED + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + BLOCK_M: block_m, + BLOCK_N: block_n, + BLOCK_K: block_k, + M: m, + N: n, + K: k, + K_SCALE_SHUFFLED: k_scale_shuffled, + } + + # Compile kernel + print("\nCompiling kernel...") + options = WaveCompileOptions( + subs=hyperparams, + canonicalize=True, + use_global_to_shared=True, + ) + options = set_default_run_config(options) + compiled_kernel = wave_compile(options, kernel) + + # Generate test data + print("Generating test data...") + x, w, x_scales, w_scales = generate_mxfp4_inputs( + (m, n, k), device=torch.device("cpu") + ) + + # Shuffle scales if needed + if kernel_type == "preshuffle": + x_scales = e8m0_shuffle(x_scales) + w_scales = e8m0_shuffle(w_scales) + + # Move data to GPU + x_gpu = x.cuda() + w_gpu = w.cuda() + x_scales_gpu = x_scales.cuda() + 
w_scales_gpu = w_scales.cuda() + c_gpu = torch.zeros(m, n, dtype=torch.float32, device="cuda") + + # Warmup + print(f"\nWarming up ({warmup_iters} iterations)...") + for _ in range(warmup_iters): + compiled_kernel(x_gpu, x_scales_gpu, w_gpu, w_scales_gpu, c_gpu) + torch.cuda.synchronize() + + # Benchmark + print(f"Benchmarking ({bench_iters} iterations)...") + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(bench_iters): + compiled_kernel(x_gpu, x_scales_gpu, w_gpu, w_scales_gpu, c_gpu) + end_event.record() + torch.cuda.synchronize() + + # Calculate metrics + elapsed_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_ms / bench_iters + flops = 2 * m * n * k # GEMM FLOPs + tflops = (flops / (avg_time_ms * 1e-3)) / 1e12 + + results = { + "kernel_type": kernel_type, + "m": m, + "n": n, + "k": k, + "block_m": block_m, + "block_n": block_n, + "block_k": block_k, + "avg_time_ms": avg_time_ms, + "total_time_ms": elapsed_ms, + "tflops": tflops, + "warmup_iters": warmup_iters, + "bench_iters": bench_iters, + } + + # Print results + print("\n" + "=" * 70) + print("BENCHMARK RESULTS") + print("=" * 70) + print(f"Average time per iteration: {avg_time_ms:.4f} ms") + print(f"Total benchmark time: {elapsed_ms:.2f} ms") + print(f"Throughput: {tflops:.2f} TFLOPS") + print("=" * 70) + + return results + + +def run_all_benchmarks(): + """Run benchmarks for both vanilla and preshuffle kernels and compare.""" + # Default configuration + m, n, k = 512, 512, 2048 + block_m, block_n, block_k = 128, 128, 256 + warmup_iters = 10 + bench_iters = 100 + + vanilla_results = run_benchmark( + kernel_type="vanilla", + m=m, + n=n, + k=k, + block_m=block_m, + block_n=block_n, + block_k=block_k, + warmup_iters=warmup_iters, + bench_iters=bench_iters, + ) + + print("\n\n") + + preshuffle_results = run_benchmark( + kernel_type="preshuffle", + m=m, + n=n, + k=k, + block_m=block_m, + 
block_n=block_n, + block_k=block_k, + warmup_iters=warmup_iters, + bench_iters=bench_iters, + ) + + # Comparison + print("\n\n") + print("=" * 70) + print("COMPARISON: Vanilla vs Pre-Shuffled") + print("=" * 70) + print( + f"Vanilla avg time: {vanilla_results['avg_time_ms']:.4f} ms ({vanilla_results['tflops']:.2f} TFLOPS)" + ) + print( + f"Pre-shuffle avg time: {preshuffle_results['avg_time_ms']:.4f} ms ({preshuffle_results['tflops']:.2f} TFLOPS)" + ) + + speedup = vanilla_results["avg_time_ms"] / preshuffle_results["avg_time_ms"] + if speedup > 1.0: + print( + f"\n✓ Pre-shuffle is {speedup:.2f}x FASTER than vanilla ({(speedup-1)*100:.1f}% improvement)" + ) + elif speedup < 1.0: + print( + f"\n✗ Pre-shuffle is {1/speedup:.2f}x SLOWER than vanilla ({(1-speedup)*100:.1f}% regression)" + ) + else: + print("\n= Performance is identical") + print("=" * 70) + + return vanilla_results, preshuffle_results + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="MXFP4 GEMM: Testing and Benchmarking") + parser.add_argument( + "--mode", + type=str, + default="test", + choices=["test", "bench", "vanilla", "preshuffle"], + help="Mode: test (run correctness tests), bench (benchmark both kernels), vanilla (benchmark vanilla only), preshuffle (benchmark preshuffle only)", + ) + parser.add_argument("--m", type=int, default=512, help="M dimension") + parser.add_argument("--n", type=int, default=512, help="N dimension") + parser.add_argument("--k", type=int, default=2048, help="K dimension") + parser.add_argument("--block_m", type=int, default=128, help="Block M dimension") + parser.add_argument("--block_n", type=int, default=128, help="Block N dimension") + parser.add_argument("--block_k", type=int, default=256, help="Block K dimension") + parser.add_argument( + "--warmup", type=int, default=10, help="Number of warmup iterations" + ) + parser.add_argument( + "--iters", type=int, default=100, help="Number of benchmark iterations" + ) + + args = 
parser.parse_args() + + if args.mode == "test": + run_all_tests() + elif args.mode == "bench": + run_all_benchmarks() + elif args.mode == "vanilla": + run_benchmark( + kernel_type="vanilla", + m=args.m, + n=args.n, + k=args.k, + block_m=args.block_m, + block_n=args.block_n, + block_k=args.block_k, + warmup_iters=args.warmup, + bench_iters=args.iters, + ) + elif args.mode == "preshuffle": + run_benchmark( + kernel_type="preshuffle", + m=args.m, + n=args.n, + k=args.k, + block_m=args.block_m, + block_n=args.block_n, + block_k=args.block_k, + warmup_iters=args.warmup, + bench_iters=args.iters, + ) diff --git a/lit_tests/kernel/wave/attention/pipelined_attention.py b/lit_tests/kernel/wave/attention/pipelined_attention.py index aeb1a920ca..b342a8378a 100644 --- a/lit_tests/kernel/wave/attention/pipelined_attention.py +++ b/lit_tests/kernel/wave/attention/pipelined_attention.py @@ -372,7 +372,7 @@ def test_bshd_attention_pipelined_prefetch(): base_attention = wave_compile(options, base_attention) print(base_attention.asm) - # CHECK: func.func @base_attention + # CHECK-LABEL: func.func @base_attention # CHECK: {{.*}} = scf.for # CHECK-COUNT-16: vector.load # CHECK: arith.subf @@ -380,12 +380,10 @@ def test_bshd_attention_pipelined_prefetch(): # CHECK: math.exp2 # CHECK: arith.mulf # CHECK: arith.addf - # CHECK-COUNT-16: vector.extract - # CHECK-COUNT-16: arith.addf + # CHECK: vector.extract + # CHECK: arith.addf # CHECK: vector.broadcast # CHECK: gpu.shuffle - # CHECK: arith.addf - # CHECK: arith.addf # CHECK: arith.truncf # CHECK: arith.truncf # CHECK: vector.extract @@ -393,15 +391,14 @@ def test_bshd_attention_pipelined_prefetch(): # CHECK: arith.mulf # CHECK: arith.mulf # CHECK-COUNT-8: vector.extract_strided_slice - # CHECK-COUNT-32: amdgpu.mfma - # CHECK-COUNT-8: vector.load - # CHECK-COUNT-8: vector.extract + # CHECK: amdgpu.mfma + # CHECK: vector.load + # CHECK: vector.extract # CHECK: vector.from_elements # CHECK: vector.from_elements # CHECK: amdgpu.lds_barrier 
- # CHECK-COUNT-32: vector.load - # CHECK-COUNT-4: vector.load - # CHECK-COUNT-8: amdgpu.mfma + # CHECK: vector.load + # CHECK: amdgpu.mfma @run_test @@ -434,7 +431,7 @@ def test_bshd_attention_pipelined_prefetch_pingpong(): base_attention = wave_compile(options, base_attention) print(base_attention.asm) - # CHECK: func.func @base_attention + # CHECK-LABEL: func.func @base_attention # CHECK: scf.if # CHECK-NEXT: rocdl.s.barrier diff --git a/lit_tests/kernel/wave/merge_scale_reads.py b/lit_tests/kernel/wave/merge_scale_reads.py new file mode 100644 index 0000000000..d369317a56 --- /dev/null +++ b/lit_tests/kernel/wave/merge_scale_reads.py @@ -0,0 +1,195 @@ +# RUN: python %s | FileCheck %s + +""" +Test merge_contiguous_reads pass on pre-shuffled (e8m0_shuffle) scale +reads for MXFP4 GEMM. + +The e8m0_shuffle index mapping rearranges scale data so that each thread's +scale elements land in contiguous groups in physical memory. The merge pass +should combine the expanded scalar reads into wider vector loads: + + BLOCK_K=128 -> 4 scale elements -> 2 groups of 2 -> vector<2xi8> + BLOCK_K=256 -> 8 scale elements -> 2 groups of 4 -> vector<4xi8> + +The shuffle layout requires K/32 >= 64 (i.e. K >= 2048) for the groups to +land contiguously in the row-major [M, K/32] scale tensor. + +Also verifies that the opsel_scaled_mfma pass enables byte selection in +amdgpu.scaled_mfma, replacing scalar scale operands with vector operands +and scalesIdxA/scalesIdxB attributes for efficient hardware extraction. +""" + +import wave_lang.kernel.lang as tkl +import wave_lang.kernel.wave as tkw +from wave_lang.kernel.lang.global_symbols import * +from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile +from wave_lang.kernel.wave.constraints import ScaledMMAType +from wave_lang.kernel.wave.utils.general_utils import ( + get_default_scheduling_params, + run_test, +) + +# Symbols shared by all tests. 
+M = tkl.sym.M +N = tkl.sym.N +K = tkl.sym.K +BLOCK_M = tkl.sym.BLOCK_M +BLOCK_N = tkl.sym.BLOCK_N +BLOCK_K = tkl.sym.BLOCK_K +ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE +K_SCALE_SHUFFLED = tkl.sym.K_SCALE_SHUFFLED + + +def get_preshuffle_kernel(): + """Return the pre-shuffled MXFP4 GEMM kernel with e8m0_shuffle mappings.""" + constraints: list[tkw.Constraint] = [ + tkw.WorkgroupConstraint(M, BLOCK_M, 0), + tkw.WorkgroupConstraint(N, BLOCK_N, 1), + tkw.TilingConstraint(K, BLOCK_K), + tkw.WaveConstraint(M, BLOCK_M / 2), + tkw.WaveConstraint(N, BLOCK_N / 2), + tkw.HardwareConstraint( + threads_per_wave=64, + mma_type=ScaledMMAType.F32_16x16x128_F8F6F4, + ), + ] + + # e8m0_shuffle index mapping: logical (iter0, iter1) -> physical (row, col). + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + + shuffle_expr = ( + (j // 32) * ((K_SCALE_SHUFFLED // 8) * 256) + + (i // 8) * 256 + + ((i % 8) % 4) * 64 + + ((j % 32) % 16) * 4 + + (((i % 8) // 4) * 2) + + ((j % 32) // 16) + ) + + a_scale_mapping = tkw.IndexMapping( + num_iterators=2, + inputs={ + M: shuffle_expr // K_SCALE_SHUFFLED, + K: shuffle_expr % K_SCALE_SHUFFLED, + }, + outputs={K: i, M: j}, + ) + + k = tkw.IndexMapping.iterator(0) + n = tkw.IndexMapping.iterator(1) + + shuffle_expr_b = ( + (n // 32) * ((K_SCALE_SHUFFLED // 8) * 256) + + (k // 8) * 256 + + ((k % 8) % 4) * 64 + + ((n % 32) % 16) * 4 + + (((k % 8) // 4) * 2) + + ((n % 32) // 16) + ) + + b_scale_mapping = tkw.IndexMapping( + num_iterators=2, + inputs={ + N: shuffle_expr_b // K_SCALE_SHUFFLED, + K: shuffle_expr_b % K_SCALE_SHUFFLED, + }, + outputs={K: k, N: n}, + ) + + @tkw.wave(constraints) + def preshuffle_scaled_mma( + a: tkl.Memory[M, K / 2, ADDRESS_SPACE, tkl.i8], + a_scale: tkl.Memory[M, K / 32, GLOBAL_ADDRESS_SPACE, tkl.i8], + b: tkl.Memory[N, K / 2, ADDRESS_SPACE, tkl.i8], + b_scale: tkl.Memory[N, K / 32, GLOBAL_ADDRESS_SPACE, tkl.i8], + c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[M, N, 
tkl.f32](0.0) + + @tkw.iterate(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a) + a_reg = tkw.bitcast(a_reg, tkl.f4e2m1fn) + a_scale_reg = tkw.read(a_scale, mapping=a_scale_mapping) + a_scale_reg = tkw.bitcast(a_scale_reg, tkl.f8e8m0fnu) + b_reg = tkw.read(b) + b_reg = tkw.bitcast(b_reg, tkl.f4e2m1fn) + b_scale_reg = tkw.read(b_scale, mapping=b_scale_mapping) + b_scale_reg = tkw.bitcast(b_scale_reg, tkl.f8e8m0fnu) + acc = tkw.scaled_mma(a_reg, a_scale_reg, b_reg, b_scale_reg, acc) + return acc + + tkw.write(repeat, c) + + return preshuffle_scaled_mma + + +def compile_and_print(m, n, k, block_k): + """Compile the preshuffle kernel with given dimensions and print MLIR.""" + k_scale_shuffled = (((k // 32) + 7) // 8) * 8 + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + BLOCK_M: 128, + BLOCK_N: 128, + BLOCK_K: block_k, + M: m, + N: n, + K: k, + K_SCALE_SHUFFLED: k_scale_shuffled, + } + hyperparams.update(get_default_scheduling_params()) + + options = WaveCompileOptions( + subs=hyperparams, + canonicalize=True, + device="hip", + target="gfx950", + compile_to_mlir=True, + use_global_to_shared=True, + ) + kernel = get_preshuffle_kernel() + result = wave_compile(options, kernel) + print(result.asm) + + +@run_test +def test_preshuffle_scale_merge_block_k_128(): + # BLOCK_K=128: 4 scale elements per thread -> 2 groups of 2 -> vector<2xi8>. + compile_and_print(m=512, n=512, k=2048, block_k=128) + + # CHECK-LABEL: test_preshuffle_scale_merge_block_k_128 + + # Each scale tensor produces 2 merged vector<2xi8> loads from global. 
+ # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<2xi8> + # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<2xi8> + # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<2xi8> + # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<2xi8> + + # No unmerged scalar scale loads from global should remain. + # CHECK-NOT: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<1xi8> + + +@run_test +def test_preshuffle_scale_merge_block_k_256(): + # BLOCK_K=256: 8 scale elements per thread -> 2 groups of 4 -> vector<4xi8>. + compile_and_print(m=512, n=512, k=2048, block_k=256) + + # CHECK-LABEL: test_preshuffle_scale_merge_block_k_256 + + # Each scale tensor produces 2 merged vector<4xi8> loads from global. + # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<4xi8> + # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<4xi8> + # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<4xi8> + # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<4xi8> + + # No unmerged scalar scale loads from global should remain. 
+ # CHECK-NOT: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<1xi8> + + # Check that amdgpu.scaled_mfma uses opsel (indexed access into scale values) + # The key indicator is the [N] indexing syntax on f8E8M0FNU scale operands + # CHECK: amdgpu.scaled_mfma {{.*}} (%{{.*}}[{{[0-9]+}}] * %{{.*}}) * (%{{.*}}[{{[0-9]+}}] * %{{.*}}) + %{{.*}} : vector<4xf8E8M0FNU>, vector<{{.*}}xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<{{.*}}xf4E2M1FN>, vector<4xf32> + + # Verify that we're not using scalar scale extracts (the old pattern) + # If opsel is working, we should NOT see vector.extract before scaled_mfma + # CHECK-NOT: vector.extract %{{.*}}[0] : f8E8M0FNU diff --git a/lit_tests/kernel/wave/mma.py b/lit_tests/kernel/wave/mma.py index 8a1c0f3ee1..16804427fd 100644 --- a/lit_tests/kernel/wave/mma.py +++ b/lit_tests/kernel/wave/mma.py @@ -771,10 +771,9 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK: %[[ITER_COND:.*]] = affine.apply #[[MAP1]]()[%[[ITER_ARG]]] # CHECK: %[[COND1:.*]] = arith.cmpi eq, %[[ITER_COND]], %[[C0]] : index - # CHECK: scf.if %[[COND1]] { - # CHECK-NEXT: scf.if %[[COND0]] { + # CHECK: %[[COND_ANDED:.*]] = arith.andi %[[COND1]], %[[COND0]] : i1 + # CHECK: scf.if %[[COND_ANDED]] { # CHECK-NEXT: rocdl.s.barrier.signal id = -3 - # CHECK-NEXT: } # CHECK-NEXT: } # CHECK-COUNT-4: rocdl.wmma diff --git a/lit_tests/kernel/wave/speculative_decoding.py b/lit_tests/kernel/wave/speculative_decoding.py index f7b14e96f5..096ba9e73b 100644 --- a/lit_tests/kernel/wave/speculative_decoding.py +++ b/lit_tests/kernel/wave/speculative_decoding.py @@ -56,14 +56,10 @@ def test_speculative_decoding(): # CHECK: arith.divf # CHECK: arith.cmpf # CHECK: arith.ori - # CHECK: arith.xori # CHECK: vector.extract # CHECK: scf.if - # CHECK: scf.if # CHECK: vector.load # CHECK: vector.store - # CHECK: scf.yield - # CHECK: scf.yield # --- Reduction and arithmetic patterns: # CHECK: gpu.shuffle up diff --git 
a/wave_lang/kernel/wave/asm/handlers_memory.py b/wave_lang/kernel/wave/asm/handlers_memory.py index bd4968b832..3591007d61 100644 --- a/wave_lang/kernel/wave/asm/handlers_memory.py +++ b/wave_lang/kernel/wave/asm/handlers_memory.py @@ -320,11 +320,15 @@ def handle_mfma_op(self, operation: amdgpu_d.MFMAOp, kernel_info: KernelInfo): def handle_scaled_mfma_op( self, operation: amdgpu_d.ScaledMFMAOp, kernel_info: KernelInfo ): - """Handle amdgpu.scaled_mfma operations - emit scaled MFMA instruction for MXFP4/FP6/FP8.""" + """Handle amdgpu.scaled_mfma operations - emit scaled MFMA instruction for MXFP4/FP6/FP8. - # Scaled MFMA format: %result = amdgpu.scaled_mfma M x N x K - # (%scaleA * %dataA) * (%scaleB * %dataB) + %acc - # where scaleA and scaleB are scalar f8E8M0FNU values (often from vector.extract) + Scaled MFMA format: %result = amdgpu.scaled_mfma M x N x K + (%scaleA * %dataA) * (%scaleB * %dataB) + %acc + + Supports both scalar f8E8M0FNU and vector<4xf8E8M0FNU> scale types. + When the scale is a vector<4xf8E8M0FNU>, the scalesIdx attribute + selects which byte within the 32-bit VGPR to use (opsel). 
+ """ from .kernel_mfma import _MFMASupport ctx = self.walker.kernel_ctx @@ -342,13 +346,21 @@ def handle_scaled_mfma_op( cbsz = _MFMASupport._get_scaled_mfma_format_code(a_type_str) blgp = _MFMASupport._get_scaled_mfma_format_code(b_type_str) + # Extract opsel (byte index within scale VGPR) from attributes + scales_idx_a = int(operation.attributes["scalesIdxA"]) + scales_idx_b = int(operation.attributes["scalesIdxB"]) + # Get operands based on actual MLIR structure # Operand order: sourceA, sourceB, destC, scaleA, scaleB data_a_ssa = str(operation.operands[0]) # sourceA: vector<32xf4E2M1FN> data_b_ssa = str(operation.operands[1]) # sourceB: vector<32xf4E2M1FN> acc_ssa = str(operation.operands[2]) # destC: vector<4xf32> - scale_a_ssa = str(operation.operands[3]) # scaleA: f8E8M0FNU (scalar) - scale_b_ssa = str(operation.operands[4]) # scaleB: f8E8M0FNU (scalar) + scale_a_ssa = str( + operation.operands[3] + ) # scaleA: f8E8M0FNU or vector<4xf8E8M0FNU> + scale_b_ssa = str( + operation.operands[4] + ) # scaleB: f8E8M0FNU or vector<4xf8E8M0FNU> # Get registers from kernel context scale_a_reg = ctx.ssa_to_reg.get(scale_a_ssa) @@ -363,7 +375,9 @@ def handle_scaled_mfma_op( # For MXFP4: 32 elements of FP4 = 16 bytes = 4 VGPRs (4 bytes/VGPR) # vector<32xf4E2M1FN> bitcast from vector<16xi8> -> 4 VGPRs - # Scale registers should be single VGPRs (extracted from vector<1xf8E8M0FNU>) + # Scale register: either a single VGPR (scalar f8E8M0FNU) + # or a single VGPR containing 4 packed bytes (vector<4xf8E8M0FNU>). + # In both cases it maps to a single VGPR register. 
if isinstance(scale_a_reg, (list, tuple)): scale_a_vreg = scale_a_reg[0] if len(scale_a_reg) > 0 else None else: @@ -390,6 +404,8 @@ def handle_scaled_mfma_op( acc_regs if acc_regs and len(acc_regs) == 4 else None, cbsz=cbsz, blgp=blgp, + scales_idx_a=scales_idx_a, + scales_idx_b=scales_idx_b, ) # Track result in SSA mapping diff --git a/wave_lang/kernel/wave/asm/kernel_mfma.py b/wave_lang/kernel/wave/asm/kernel_mfma.py index 3f5605ff86..c76761d64b 100644 --- a/wave_lang/kernel/wave/asm/kernel_mfma.py +++ b/wave_lang/kernel/wave/asm/kernel_mfma.py @@ -182,6 +182,8 @@ def emit_mfma_f32_16x16x128_f8f6f4( acc_regs: Optional[Tuple[KReg, ...]] = None, cbsz: int = 4, blgp: int = 4, + scales_idx_a: int = 0, + scales_idx_b: int = 0, ) -> Tuple[KReg, ...]: """ Emit scaled MFMA instruction for MXFP4 (16x16x128 F8F6F4). @@ -194,19 +196,24 @@ def emit_mfma_f32_16x16x128_f8f6f4( Args: a_regs: Tuple of 4 VGPRs for A operand (32 x f4E2M1FN packed as i8) b_regs: Tuple of 4 VGPRs for B operand (32 x f4E2M1FN packed as i8) - a_scale_reg: Single VGPR for A scale factor (f8E8M0FNU) - b_scale_reg: Single VGPR for B scale factor (f8E8M0FNU) + a_scale_reg: Single VGPR for A scale factor (f8E8M0FNU or + 4 packed bytes with opsel byte selection) + b_scale_reg: Single VGPR for B scale factor (same as above) acc_regs: Optional tuple of 4 VGPRs for accumulator (f32x4) If None, allocates new result registers cbsz: Format code for A source data (0=FP8, 1=BF8, 2=FP6_E2M3, 3=FP6_E3M2, 4=FP4). Default 4 (FP4). blgp: Format code for B source data. Same encoding as cbsz. Default 4 (FP4). + scales_idx_a: Byte index (0-3) within the A scale VGPR. Default 0. + scales_idx_b: Byte index (0-3) within the B scale VGPR. Default 0. 
Returns: Tuple of 4 VGPRs containing the result """ modifiers = f"cbsz:{cbsz} blgp:{blgp}" + if scales_idx_a != 0 or scales_idx_b != 0: + modifiers += f" op_sel_hi:[0,0,0,{scales_idx_a},{scales_idx_b}]" # Build operand ranges - For FP4: 32 elements = 16 bytes = 4 VGPRs a_range = KRegRange(a_regs[0], len(a_regs), alignment=4) diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index 7d97bc2721..c098a9e1b5 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -78,6 +78,7 @@ from .type_inference import infer_types from .wave_schedule import WaveSchedule from .workgroup_reordering import reorder_workgroups +from .opsel_scaled_mfma import apply_opsel_scaled_mfma # Utilities. from .utils.compile_utils import canonicalize_module, apply_transform, compile_to_vmfb @@ -652,6 +653,12 @@ def compile_launchable_to_mlir( if options.canonicalize: canonicalize_module(mb.module_op) + # Replace scalar extract+bitcast scale chains on scaled_mfma ops + # with vector-level bitcast and opsel byte selection. + apply_opsel_scaled_mfma(mb.module_op) + if options.canonicalize: + canonicalize_module(mb.module_op) + return mb, trace, exe, kernel_sig, entrypoint_name diff --git a/wave_lang/kernel/wave/opsel_scaled_mfma.py b/wave_lang/kernel/wave/opsel_scaled_mfma.py new file mode 100644 index 0000000000..0d39afe8db --- /dev/null +++ b/wave_lang/kernel/wave/opsel_scaled_mfma.py @@ -0,0 +1,214 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +MLIR pass: replace extract+bitcast chains on scale operands of +amdgpu.scaled_mfma with vector-level bitcast and opsel. + +Before (per scale operand): + %vec4 = vector.load ... 
: vector<4xi8> + %slice = vector.extract_strided_slice %vec4 + {offsets=[N], sizes=[1], strides=[1]} + : vector<4xi8> to vector<1xi8> + %bc1 = vector.bitcast %slice : vector<1xi8> to vector<1xf8E8M0FNU> + %scalar = vector.extract %bc1[0] : f8E8M0FNU + amdgpu.scaled_mfma ... (%scalar[0] * ...) ... + +After: + %vec4 = vector.load ... : vector<4xi8> + %bc4 = vector.bitcast %vec4 : vector<4xi8> to vector<4xf8E8M0FNU> + amdgpu.scaled_mfma ... (%bc4[N] * ...) ... + +The extract_strided_slice, per-element bitcast and vector.extract are +dead-code eliminated by a subsequent canonicalization pass. +""" + +from iree.compiler.ir import ( + Float8E8M0FNUType, + InsertionPoint, + IntegerAttr, + IntegerType, + Location, + Module, + VectorType, +) +from iree.compiler.dialects import ( + amdgpu as amdgpu_d, + vector as vector_d, +) + +from wave_lang.support.logging import get_logger + +logger = get_logger("wave.opsel_scaled_mfma") + + +def _trace_scale_chain(scale_value): + """Trace a scaled_mfma scale operand back through the extract+bitcast chain. + + Returns (source_vector_4xi8, byte_offset) if the pattern matches, + or None if it doesn't. 
+ + Expected chain (walking backwards from the scale operand): + scale_value : f8E8M0FNU (scalar) + <- vector.extract [0] : vector<1xf8E8M0FNU> -> f8E8M0FNU + <- vector.bitcast : vector<1xi8> -> vector<1xf8E8M0FNU> + <- vector.extract_strided_slice {offsets=[N], sizes=[1], strides=[1]} + : vector<4xi8> -> vector<1xi8> + <- source : vector<4xi8> (typically a vector.load) + """ + extract_op = scale_value.owner + if not hasattr(extract_op, "name") or extract_op.name != "vector.extract": + return None + + extract_source = extract_op.operands[0] + extract_source_type = extract_source.type + if not isinstance(extract_source_type, VectorType): + return None + if extract_source_type.rank != 1 or extract_source_type.shape[0] != 1: + return None + + bitcast_op = extract_source.owner + if not hasattr(bitcast_op, "name") or bitcast_op.name != "vector.bitcast": + return None + + bitcast_source = bitcast_op.operands[0] + bitcast_source_type = bitcast_source.type + if not isinstance(bitcast_source_type, VectorType): + return None + if bitcast_source_type.rank != 1 or bitcast_source_type.shape[0] != 1: + return None + + slice_op = bitcast_source.owner + if not hasattr(slice_op, "name") or slice_op.name != "vector.extract_strided_slice": + return None + + offsets = slice_op.attributes["offsets"] + offset = IntegerAttr(offsets[0]).value + + slice_source = slice_op.operands[0] + slice_source_type = slice_source.type + if not isinstance(slice_source_type, VectorType): + return None + # Only apply opsel optimization to vector<4xi8> sources. + # The amdgpu.scaled_mfma operation requires vector<4xf8E8M0FNU> scale operands. 
+ if slice_source_type.rank != 1 or slice_source_type.shape[0] != 4: + return None + + return (slice_source, offset) + + +def _walk_operations(op): + """Recursively yield all operations nested inside op (post-order).""" + for region in op.regions: + for block in region: + for child_op in block: + yield from _walk_operations(child_op) + yield op + + +def apply_opsel_scaled_mfma(module: Module): + """Walk the MLIR module and apply the opsel optimization to scaled_mfma ops. + + For each scaled_mfma, if a scale operand traces back through: + vector.extract[0] <- vector.bitcast(1xi8->1xf8E8M0FNU) + <- vector.extract_strided_slice(Nxi8->1xi8, offset=K) + then replace the scale with a vector.bitcast(Nxi8->Nxf8E8M0FNU) + of the source and set scales_idx to K. + """ + mlir_ctx = module.operation.context + + with mlir_ctx, Location.unknown(): + f8e8m0 = Float8E8M0FNUType.get() + + scaled_mfma_ops = [] + for op in _walk_operations(module.operation): + if hasattr(op, "name") and op.name == "amdgpu.scaled_mfma": + scaled_mfma_ops.append(op.opview) + + if not scaled_mfma_ops: + return + + logger.debug(f"Found {len(scaled_mfma_ops)} scaled_mfma ops") + + replacements = [] + + for mfma_op in scaled_mfma_ops: + idx_a = int(mfma_op.scalesIdxA) + idx_b = int(mfma_op.scalesIdxB) + + new_scale_a = None + new_idx_a = idx_a + new_scale_b = None + new_idx_b = idx_b + + chain_a = _trace_scale_chain(mfma_op.scalesA) + if chain_a is not None: + new_scale_a, new_idx_a = chain_a + + chain_b = _trace_scale_chain(mfma_op.scalesB) + if chain_b is not None: + new_scale_b, new_idx_b = chain_b + + if new_scale_a is not None or new_scale_b is not None: + replacements.append( + (mfma_op, new_scale_a, new_idx_a, new_scale_b, new_idx_b) + ) + + if not replacements: + logger.debug("No opsel optimization opportunities found") + return + + logger.debug(f"Applying opsel optimization to {len(replacements)} ops") + + i32 = IntegerType.get_signless(32) + + # Cache: defining Operation -> bitcast result Value. 
+ # Using the Operation object identity ensures one bitcast per source load. + source_op_to_bitcast = {} + + def get_wide_bitcast(source_vec): + """Get or create a wide bitcast vector -> vector.""" + defining_op = source_vec.owner + if defining_op in source_op_to_bitcast: + return source_op_to_bitcast[defining_op] + + source_type = source_vec.type + n = source_type.shape[0] + result_type = VectorType.get([n], f8e8m0) + + with InsertionPoint(defining_op): + bc = vector_d.bitcast(result_type, source_vec) + bc.owner.move_after(defining_op) + + source_op_to_bitcast[defining_op] = bc + return bc + + for mfma_op, new_scale_a, new_idx_a, new_scale_b, new_idx_b in replacements: + actual_scale_a = mfma_op.scalesA + actual_scale_b = mfma_op.scalesB + + if new_scale_a is not None: + actual_scale_a = get_wide_bitcast(new_scale_a) + if new_scale_b is not None: + actual_scale_b = get_wide_bitcast(new_scale_b) + + with InsertionPoint(mfma_op): + new_mfma = amdgpu_d.scaled_mfma( + m=mfma_op.attributes["m"], + n=mfma_op.attributes["n"], + k=mfma_op.attributes["k"], + source_a=mfma_op.sourceA, + source_b=mfma_op.sourceB, + dest_c=mfma_op.destC, + scales_a=actual_scale_a, + scales_b=actual_scale_b, + scales_idx_a=IntegerAttr.get(i32, new_idx_a), + scales_idx_b=IntegerAttr.get(i32, new_idx_b), + ) + mfma_op.result.replace_all_uses_with(new_mfma) + mfma_op.operation.erase() + + logger.debug("opsel optimization applied successfully") diff --git a/wave_lang/kernel/wave/schedules/__init__.py b/wave_lang/kernel/wave/schedules/__init__.py index 62d4cc542c..e607d25a50 100644 --- a/wave_lang/kernel/wave/schedules/__init__.py +++ b/wave_lang/kernel/wave/schedules/__init__.py @@ -12,7 +12,11 @@ get_two_pp_cluster_schedule, get_async_two_pp_clusters, ) -from .gemm_mxfp4_double_buffer import get_mxfp4_dbuf_schedule +from .gemm_mxfp4_double_buffer import ( + get_mxfp4_dbuf_schedule, + get_mxfp4_dbuf_schedule_shuffle, +) +from .gemm_mxfp4_triple_buffer import get_mxfp4_triplebuf_schedule from 
.attention_prefetch import get_attention_prefetch_schedule __all__ = [ @@ -21,5 +25,7 @@ "get_two_pp_cluster_schedule", "get_async_two_pp_clusters", "get_mxfp4_dbuf_schedule", + "get_mxfp4_triplebuf_schedule", + "get_mxfp4_dbuf_schedule_shuffle", "get_attention_prefetch_schedule", ] diff --git a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py index bc8459842f..434f8ed1ca 100644 --- a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py +++ b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py @@ -177,6 +177,13 @@ def mxfp4_dbuf_schedule(): # Build cluster 0: first K-partition loads + bitcasts + GatherToLDS cluster_0_ops = [ + tkw.SchedulingBarrier([]), + # tkw.MemoryCounterWait(load=0), + tkw.WorkgroupBarrier(), + tkw.WorkgroupBarrier(), + loop_global_to_shared, + tkw.SchedulingBarrier([]), + tkw.MemoryCounterWait(load=independent_global_count), loop_shared_load_a_0, loop_shared_load_a_scale_0, loop_shared_load_b_0, @@ -186,8 +193,6 @@ def mxfp4_dbuf_schedule(): loop_bitcast_b_0, loop_bitcast_b_scale_0, tkw.SchedulingBarrier([]), - loop_global_to_shared, - tkw.SchedulingBarrier([]), ] if use_stagger: cluster_0_ops.extend( @@ -207,13 +212,14 @@ def mxfp4_dbuf_schedule(): loop_scaled_mma_0, tkw.SetWavePrio(0), tkw.SchedulingBarrier([]), - tkw.MemoryCounterWaitBarrier(load=independent_global_count), + tkw.WorkgroupBarrier(), tkw.SchedulingBarrier([]), ], ), # Cluster 2: Second K-partition shared loads/bitcasts tkw.cluster( [ + tkw.SchedulingBarrier([]), loop_shared_load_a_1, loop_shared_load_a_scale_1, loop_shared_load_b_1, @@ -223,7 +229,235 @@ def mxfp4_dbuf_schedule(): loop_bitcast_b_1, loop_bitcast_b_scale_1, tkw.SchedulingBarrier([]), - tkw.MemoryCounterWaitBarrier(load=0), + tkw.WorkgroupBarrier(), + tkw.SchedulingBarrier([]), + ], + ), + # Cluster 3: Second K-partition scaled_mma (high priority) + tkw.cluster( + [ + tkw.SetWavePrio(1), + loop_scaled_mma_1, + 
tkw.SetWavePrio(0), + tkw.SchedulingBarrier([]), + ], + ), + ] + + # Insert barriers at loop boundaries + tkw.insert_before(pipeline_loop.KERNEL, tkw.WorkgroupBarrier()) + tkw.insert_after(pipeline_loop.KERNEL, tkw.SharedMemoryBarrier()) + # tkw.insert_at_end(pipeline_loop.KERNEL, tkw.SharedMemoryBarrier()) + + # Apply the cluster-based reordering + tkw.reorder_graph(pipeline_loop.KERNEL, clusters) + + # Apply wave staggering for better overlap + if use_stagger: + tkw.stagger(pipeline_loop.KERNEL) + + return mxfp4_dbuf_schedule + + +def get_mxfp4_dbuf_schedule_shuffle(use_stagger: bool = True): + """Return a double-buffered MXFP4 schedule for wave_compile(). + + Args: + use_stagger: Enable wave staggering + WorkgroupBarrier in cluster 0. + Recommended for 8-wave configs; disable for 4-wave. + """ + K = tkl.sym.K + + @wave_schedule.wave_schedule() + def mxfp4_dbuf_schedule(): + # ===================================================================== + # Get tagged nodes from the kernel + # ===================================================================== + k_loop = tkw.get_node_by_tag("k_loop") + + # Matrix A data - GatherToLDS (global->shared) + Read (shared load) + all_read_a = tkw.get_node_by_tag("read_a") + global_to_shared_a = tkw.filter_nodes(all_read_a, node_type=tkw.GatherToLDS) + shared_load_a = tkw.filter_nodes(all_read_a, node_type=tkw.Read) + + # Matrix A scale + all_read_a_scale = tkw.get_node_by_tag("read_a_scale") + global_to_shared_a_scale = tkw.filter_nodes( + all_read_a_scale, node_type=tkw.GatherToLDS + ) + shared_load_a_scale = tkw.filter_nodes(all_read_a_scale, node_type=tkw.Read) + + # Matrix B data + all_read_b = tkw.get_node_by_tag("read_b") + global_to_shared_b = tkw.filter_nodes(all_read_b, node_type=tkw.GatherToLDS) + shared_load_b = tkw.filter_nodes(all_read_b, node_type=tkw.Read) + + # Matrix B scale + g2v_b_scale = tkw.get_node_by_tag("read_b_scale") + + # Bitcast operations (needed alongside compute) + bitcast_a = 
tkw.get_node_by_tag("bitcast_a") + bitcast_a_scale = tkw.get_node_by_tag("bitcast_a_scale") + bitcast_b = tkw.get_node_by_tag("bitcast_b") + bitcast_b_scale = tkw.get_node_by_tag("bitcast_b_scale") + + # Scaled MMA + scaled_mma = tkw.get_node_by_tag("scaled_mma") + + # ===================================================================== + # Create 2-stage pipeline (double buffering) + # ===================================================================== + pipeline_loop = tkw.pipeline(k_loop) + + with pipeline_loop as pl: + # Stage 0: Global-to-shared prefetch via GatherToLDS (no fusion) + pl.set_stage( + [ + ( + global_to_shared_a, + global_to_shared_b, + global_to_shared_a_scale, + ), + (), + (), + ], + ) + # Stage 1: Shared memory loads + bitcasts + compute + pl.set_stage( + [ + ( + g2v_b_scale, + shared_load_a_scale, + shared_load_a, + shared_load_b, + ), + (bitcast_a, bitcast_a_scale, bitcast_b, bitcast_b_scale), + (scaled_mma,), + ], + ) + + # ===================================================================== + # KERNEL: Main loop body with custom cluster ordering + # ===================================================================== + + # Filter nodes for KERNEL stage + loop_global_to_shared = ( + tkw.filter_nodes(global_to_shared_a, subgraph=pipeline_loop.KERNEL) + + tkw.filter_nodes(global_to_shared_b, subgraph=pipeline_loop.KERNEL) + + tkw.filter_nodes(global_to_shared_a_scale, subgraph=pipeline_loop.KERNEL) + ) + + loop_shared_load_a = tkw.filter_nodes( + shared_load_a, subgraph=pipeline_loop.KERNEL + ) + loop_shared_load_b = tkw.filter_nodes( + shared_load_b, subgraph=pipeline_loop.KERNEL + ) + loop_shared_load_a_scale = tkw.filter_nodes( + shared_load_a_scale, subgraph=pipeline_loop.KERNEL + ) + loop_g2v_b_scale = tkw.filter_nodes(g2v_b_scale, subgraph=pipeline_loop.KERNEL) + + loop_bitcast_a = tkw.filter_nodes(bitcast_a, subgraph=pipeline_loop.KERNEL) + loop_bitcast_a_scale = tkw.filter_nodes( + bitcast_a_scale, subgraph=pipeline_loop.KERNEL + 
) + loop_bitcast_b = tkw.filter_nodes(bitcast_b, subgraph=pipeline_loop.KERNEL) + loop_bitcast_b_scale = tkw.filter_nodes( + bitcast_b_scale, subgraph=pipeline_loop.KERNEL + ) + loop_scaled_mma = tkw.filter_nodes(scaled_mma, subgraph=pipeline_loop.KERNEL) + + # Partition by K dimension for interleaving compute with memory ops. + # NOTE: Bitcasts MUST also be partitioned by K to match their producer + # shared loads, otherwise reorder_graph fails with + # "Cannot find producer(s)" because bitcasts in an earlier cluster + # would depend on shared loads in a later cluster. + loop_scaled_mma_0, loop_scaled_mma_1 = tkw.partition_by_dim( + loop_scaled_mma, dim=K, num_partitions=2 + ) + loop_shared_load_a_0, loop_shared_load_a_1 = tkw.partition_by_dim( + loop_shared_load_a, dim=K, num_partitions=2 + ) + loop_shared_load_b_0, loop_shared_load_b_1 = tkw.partition_by_dim( + loop_shared_load_b, dim=K, num_partitions=2 + ) + loop_shared_load_a_scale_0, loop_shared_load_a_scale_1 = tkw.partition_by_dim( + loop_shared_load_a_scale, dim=K, num_partitions=2 + ) + loop_g2v_b_scale_0, loop_g2v_b_scale_1 = tkw.partition_by_dim( + loop_g2v_b_scale, dim=K, num_partitions=2 + ) + loop_bitcast_a_0, loop_bitcast_a_1 = tkw.partition_by_dim( + loop_bitcast_a, dim=K, num_partitions=2 + ) + loop_bitcast_a_scale_0, loop_bitcast_a_scale_1 = tkw.partition_by_dim( + loop_bitcast_a_scale, dim=K, num_partitions=2 + ) + loop_bitcast_b_0, loop_bitcast_b_1 = tkw.partition_by_dim( + loop_bitcast_b, dim=K, num_partitions=2 + ) + loop_bitcast_b_scale_0, loop_bitcast_b_scale_1 = tkw.partition_by_dim( + loop_bitcast_b_scale, dim=K, num_partitions=2 + ) + + independent_global_count = len(loop_global_to_shared) + + # Build cluster 0: first K-partition loads + bitcasts + GatherToLDS + cluster_0_ops = [ + tkw.SchedulingBarrier([]), + tkw.MemoryCounterWait(load=0), + tkw.WorkgroupBarrier(), + tkw.WorkgroupBarrier(), + loop_global_to_shared, + tkw.SchedulingBarrier([]), + # 
tkw.MemoryCounterWait(load=independent_global_count), + loop_g2v_b_scale_0, + loop_shared_load_a_scale_0, + loop_shared_load_a_0, + loop_shared_load_b_0, + loop_bitcast_a_0, + loop_bitcast_a_scale_0, + loop_bitcast_b_0, + loop_bitcast_b_scale_0, + tkw.SchedulingBarrier([]), + ] + if use_stagger: + cluster_0_ops.extend( + [ + tkw.WorkgroupBarrier(), + tkw.SchedulingBarrier([]), + ] + ) + + clusters = [ + # Cluster 0: First K-partition shared loads/bitcasts + async GatherToLDS + tkw.cluster(cluster_0_ops), + # Cluster 1: First K-partition scaled_mma (high priority) + tkw.cluster( + [ + tkw.SetWavePrio(1), + loop_scaled_mma_0, + tkw.SetWavePrio(0), + tkw.SchedulingBarrier([]), + tkw.WorkgroupBarrier(), + tkw.SchedulingBarrier([]), + ], + ), + # Cluster 2: Second K-partition shared loads/bitcasts + tkw.cluster( + [ + tkw.SchedulingBarrier([]), + loop_g2v_b_scale_1, + loop_shared_load_a_scale_1, + loop_shared_load_a_1, + loop_shared_load_b_1, + loop_bitcast_a_1, + loop_bitcast_a_scale_1, + loop_bitcast_b_1, + loop_bitcast_b_scale_1, + tkw.SchedulingBarrier([]), + tkw.WorkgroupBarrier(), tkw.SchedulingBarrier([]), ], ), @@ -238,9 +472,10 @@ def mxfp4_dbuf_schedule(): ), ] - # Insert shared memory barriers at loop boundaries - tkw.insert_before(pipeline_loop.KERNEL, tkw.SharedMemoryBarrier()) - tkw.insert_at_end(pipeline_loop.KERNEL, tkw.SharedMemoryBarrier()) + # Insert barriers at loop boundaries + tkw.insert_before(pipeline_loop.KERNEL, tkw.WorkgroupBarrier()) + tkw.insert_after(pipeline_loop.KERNEL, tkw.SharedMemoryBarrier()) + # tkw.insert_at_end(pipeline_loop.KERNEL, tkw.SharedMemoryBarrier()) # Apply the cluster-based reordering tkw.reorder_graph(pipeline_loop.KERNEL, clusters) diff --git a/wave_lang/kernel/wave/schedules/gemm_mxfp4_triple_buffer.py b/wave_lang/kernel/wave/schedules/gemm_mxfp4_triple_buffer.py new file mode 100644 index 0000000000..209ae185c2 --- /dev/null +++ b/wave_lang/kernel/wave/schedules/gemm_mxfp4_triple_buffer.py @@ -0,0 +1,262 @@ +# 
Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +MXFP4 Scaled GEMM Triple Buffer Schedule for CDNA4 (GFX950) + +Reusable 3-stage pipeline schedule for MXFP4 scaled GEMM on GFX950. +Handles 4 input tensors (A data, A scale, B data, B scale) with bitcasts. + +Stage 0: GatherToLDS async prefetch | Stage 1: empty spacer stage | +Stage 2: shared loads + bitcasts + MMA. +K-dimension partitioned into 2 halves for memory/compute interleaving. + +Required kernel tags: k_loop, read_a, read_a_scale, read_b, read_b_scale, +bitcast_a, bitcast_a_scale, bitcast_b, bitcast_b_scale, scaled_mma. +Requires use_global_to_shared=True and threads_per_wave=64. +""" + +import wave_lang.kernel.lang as tkl +import wave_lang.kernel.wave as tkw +import wave_lang.kernel.wave.wave_schedule as wave_schedule + + +def get_mxfp4_triplebuf_schedule(use_stagger: bool = True): + """Return a triple-buffered MXFP4 schedule for wave_compile(). + + Args: + use_stagger: Enable wave staggering + WorkgroupBarrier in cluster 0. + Recommended for 8-wave configs; disable for 4-wave. 
+ """ + K = tkl.sym.K + + @wave_schedule.wave_schedule() + def mxfp4_triple_buf_schedule(): + # ===================================================================== + # Get tagged nodes from the kernel + # ===================================================================== + k_loop = tkw.get_node_by_tag("k_loop") + + # Matrix A data - GatherToLDS (global->shared) + Read (shared load) + all_read_a = tkw.get_node_by_tag("read_a") + global_to_shared_a = tkw.filter_nodes(all_read_a, node_type=tkw.GatherToLDS) + shared_load_a = tkw.filter_nodes(all_read_a, node_type=tkw.Read) + + # Matrix A scale + all_read_a_scale = tkw.get_node_by_tag("read_a_scale") + global_to_shared_a_scale = tkw.filter_nodes( + all_read_a_scale, node_type=tkw.GatherToLDS + ) + shared_load_a_scale = tkw.filter_nodes(all_read_a_scale, node_type=tkw.Read) + + # Matrix B data + all_read_b = tkw.get_node_by_tag("read_b") + global_to_shared_b = tkw.filter_nodes(all_read_b, node_type=tkw.GatherToLDS) + shared_load_b = tkw.filter_nodes(all_read_b, node_type=tkw.Read) + + # Matrix B scale + all_read_b_scale = tkw.get_node_by_tag("read_b_scale") + global_to_shared_b_scale = tkw.filter_nodes( + all_read_b_scale, node_type=tkw.GatherToLDS + ) + shared_load_b_scale = tkw.filter_nodes(all_read_b_scale, node_type=tkw.Read) + + # Bitcast operations (needed alongside compute) + bitcast_a = tkw.get_node_by_tag("bitcast_a") + bitcast_a_scale = tkw.get_node_by_tag("bitcast_a_scale") + bitcast_b = tkw.get_node_by_tag("bitcast_b") + bitcast_b_scale = tkw.get_node_by_tag("bitcast_b_scale") + + # Scaled MMA + scaled_mma = tkw.get_node_by_tag("scaled_mma") + + # ===================================================================== + # Create 3-stage pipeline (triple buffering) + # ===================================================================== + pipeline_loop = tkw.pipeline(k_loop) + + with pipeline_loop as pl: + # Stage 0: Global-to-shared prefetch via GatherToLDS (no fusion) + pl.set_stage( + [ + (
global_to_shared_a, + global_to_shared_a_scale, + global_to_shared_b, + global_to_shared_b_scale, + ), + (), + (), + ], + ) + pl.set_stage( + [ + (), + (), + (), + ], + ) + # Stage 1: Shared memory loads + bitcasts + compute + pl.set_stage( + [ + ( + shared_load_a, + shared_load_b, + shared_load_a_scale, + shared_load_b_scale, + ), + (bitcast_a, bitcast_a_scale, bitcast_b, bitcast_b_scale), + (scaled_mma,), + ], + ) + + # ===================================================================== + # KERNEL: Main loop body with custom cluster ordering + # ===================================================================== + + # Filter nodes for KERNEL stage + loop_global_to_shared = ( + tkw.filter_nodes(global_to_shared_a, subgraph=pipeline_loop.KERNEL) + + tkw.filter_nodes(global_to_shared_a_scale, subgraph=pipeline_loop.KERNEL) + + tkw.filter_nodes(global_to_shared_b, subgraph=pipeline_loop.KERNEL) + + tkw.filter_nodes(global_to_shared_b_scale, subgraph=pipeline_loop.KERNEL) + ) + + loop_shared_load_a = tkw.filter_nodes( + shared_load_a, subgraph=pipeline_loop.KERNEL + ) + loop_shared_load_a_scale = tkw.filter_nodes( + shared_load_a_scale, subgraph=pipeline_loop.KERNEL + ) + loop_shared_load_b = tkw.filter_nodes( + shared_load_b, subgraph=pipeline_loop.KERNEL + ) + loop_shared_load_b_scale = tkw.filter_nodes( + shared_load_b_scale, subgraph=pipeline_loop.KERNEL + ) + + loop_bitcast_a = tkw.filter_nodes(bitcast_a, subgraph=pipeline_loop.KERNEL) + loop_bitcast_a_scale = tkw.filter_nodes( + bitcast_a_scale, subgraph=pipeline_loop.KERNEL + ) + loop_bitcast_b = tkw.filter_nodes(bitcast_b, subgraph=pipeline_loop.KERNEL) + loop_bitcast_b_scale = tkw.filter_nodes( + bitcast_b_scale, subgraph=pipeline_loop.KERNEL + ) + loop_scaled_mma = tkw.filter_nodes(scaled_mma, subgraph=pipeline_loop.KERNEL) + + # Partition by K dimension for interleaving compute with memory ops. 
+ # NOTE: Bitcasts MUST also be partitioned by K to match their producer + # shared loads, otherwise reorder_graph fails with + # "Cannot find producer(s)" because bitcasts in an earlier cluster + # would depend on shared loads in a later cluster. + loop_scaled_mma_0, loop_scaled_mma_1 = tkw.partition_by_dim( + loop_scaled_mma, dim=K, num_partitions=2 + ) + loop_shared_load_a_0, loop_shared_load_a_1 = tkw.partition_by_dim( + loop_shared_load_a, dim=K, num_partitions=2 + ) + loop_shared_load_a_scale_0, loop_shared_load_a_scale_1 = tkw.partition_by_dim( + loop_shared_load_a_scale, dim=K, num_partitions=2 + ) + loop_shared_load_b_0, loop_shared_load_b_1 = tkw.partition_by_dim( + loop_shared_load_b, dim=K, num_partitions=2 + ) + loop_shared_load_b_scale_0, loop_shared_load_b_scale_1 = tkw.partition_by_dim( + loop_shared_load_b_scale, dim=K, num_partitions=2 + ) + loop_bitcast_a_0, loop_bitcast_a_1 = tkw.partition_by_dim( + loop_bitcast_a, dim=K, num_partitions=2 + ) + loop_bitcast_a_scale_0, loop_bitcast_a_scale_1 = tkw.partition_by_dim( + loop_bitcast_a_scale, dim=K, num_partitions=2 + ) + loop_bitcast_b_0, loop_bitcast_b_1 = tkw.partition_by_dim( + loop_bitcast_b, dim=K, num_partitions=2 + ) + loop_bitcast_b_scale_0, loop_bitcast_b_scale_1 = tkw.partition_by_dim( + loop_bitcast_b_scale, dim=K, num_partitions=2 + ) + + independent_global_count = len(loop_global_to_shared) + + # Build cluster 0: first K-partition loads + bitcasts + GatherToLDS + cluster_0_ops = [ + tkw.MemoryCounterWaitBarrier(load=independent_global_count), + loop_shared_load_a_0, + loop_shared_load_a_scale_0, + loop_shared_load_b_0, + loop_shared_load_b_scale_0, + loop_bitcast_a_0, + loop_bitcast_a_scale_0, + loop_bitcast_b_0, + loop_bitcast_b_scale_0, + tkw.SchedulingBarrier([]), + loop_global_to_shared, + tkw.SchedulingBarrier([]), + ] + if use_stagger: + cluster_0_ops.extend( + [ + tkw.WorkgroupBarrier(), + tkw.SchedulingBarrier([]), + ] + ) + + clusters = [ + # Cluster 0: First K-partition shared 
loads/bitcasts + async GatherToLDS + tkw.cluster(cluster_0_ops), + # Cluster 1: First K-partition scaled_mma (high priority) + tkw.cluster( + [ + tkw.SetWavePrio(1), + loop_scaled_mma_0, + tkw.SetWavePrio(0), + tkw.SchedulingBarrier([]), + # tkw.MemoryCounterWaitBarrier(load=independent_global_count), + tkw.WorkgroupBarrier(), + tkw.SchedulingBarrier([]), + ], + ), + # Cluster 2: Second K-partition shared loads/bitcasts + tkw.cluster( + [ + loop_shared_load_a_1, + loop_shared_load_a_scale_1, + loop_shared_load_b_1, + loop_shared_load_b_scale_1, + loop_bitcast_a_1, + loop_bitcast_a_scale_1, + loop_bitcast_b_1, + loop_bitcast_b_scale_1, + tkw.SchedulingBarrier([]), + # tkw.MemoryCounterWaitBarrier(load=0), + tkw.WorkgroupBarrier(), + tkw.SchedulingBarrier([]), + ], + ), + # Cluster 3: Second K-partition scaled_mma (high priority) + tkw.cluster( + [ + tkw.SetWavePrio(1), + loop_scaled_mma_1, + tkw.SetWavePrio(0), + tkw.SchedulingBarrier([]), + ], + ), + ] + + # Insert shared memory barriers at loop boundaries + # tkw.insert_before(pipeline_loop.KERNEL, tkw.WorkgroupBarrier()) + # tkw.insert_at_end(pipeline_loop.KERNEL, tkw.SharedMemoryBarrier()) + + # Apply the cluster-based reordering + tkw.reorder_graph(pipeline_loop.KERNEL, clusters) + + # Apply wave staggering for better overlap + if use_stagger: + tkw.stagger(pipeline_loop.KERNEL) + + return mxfp4_triple_buf_schedule diff --git a/wave_lang/kernel/wave/templates/__init__.py b/wave_lang/kernel/wave/templates/__init__.py index 2a190b643e..f9637eb6b2 100644 --- a/wave_lang/kernel/wave/templates/__init__.py +++ b/wave_lang/kernel/wave/templates/__init__.py @@ -6,10 +6,11 @@ from .attention_common import AttentionShape from .tagged_attention import get_tagged_bshd_attention_kernel -from .tagged_mxfp4_gemm import get_tagged_mxfp4_gemm +from .tagged_mxfp4_gemm import get_tagged_mxfp4_gemm, get_preshuffle_kernel __all__ = [ "AttentionShape", "get_tagged_bshd_attention_kernel", "get_tagged_mxfp4_gemm", + 
def get_preshuffle_kernel(
    shape: tuple[int, int, int] = (1024, 1024, 8192),
    block_shape: tuple[int, int, int] = (256, 256, 256),
    mfma_variant: ScaledMMAType = ScaledMMAType.F32_16x16x128_F8F6F4,
    num_waves: int = 8,
):
    """Return the pre-shuffled MXFP4 GEMM kernel definition.

    The B-scale operand is read through an IndexMapping that undoes the
    host-side e8m0 pre-shuffle (mxDataGenerator PreSwizzle layout), so the
    kernel consumes scale tensors that were shuffled before launch.

    All ops are tagged for use with MXFP4 schedule functions
    (e.g. get_mxfp4_dbuf_schedule_shuffle).

    Args:
        shape: (M, N, K) problem dimensions.
        block_shape: (BLOCK_M, BLOCK_N, BLOCK_K) tile sizes.
        mfma_variant: Scaled MMA instruction type.
        num_waves: Waves per workgroup (4 or 8).

    Returns:
        (kernel_function, WaveCompileOptions)
    """
    M = tkl.sym.M
    N = tkl.sym.N
    K = tkl.sym.K
    BLOCK_M = tkl.sym.BLOCK_M
    BLOCK_N = tkl.sym.BLOCK_N
    BLOCK_K = tkl.sym.BLOCK_K
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
    K_SCALE_SHUFFLED = tkl.sym.K_SCALE_SHUFFLED
    # Scale columns after padding to a multiple of 8 — matches the
    # (n + 7) // 8 * 8 padding performed by the host-side e8m0_shuffle.
    k_scale_shuffled = (((shape[2] // 32) + 7) // 8) * 8

    constraints: list[tkw.Constraint] = [
        tkw.WorkgroupConstraint(M, BLOCK_M, 0),
        tkw.WorkgroupConstraint(N, BLOCK_N, 1),
        tkw.TilingConstraint(K, BLOCK_K),
    ]

    if num_waves == 8:
        # 8 waves: 4 M-tiles x 2 N-tiles
        constraints += [tkw.WaveConstraint(M, BLOCK_M / 4)]
        constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]
    else:
        # 4 waves: 2 M-tiles x 2 N-tiles
        constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
        constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]

    constraints += [
        tkw.HardwareConstraint(
            threads_per_wave=64,
            # Fix: honor the caller-selected variant. The original
            # hard-coded F32_16x16x128_F8F6F4 here, silently ignoring
            # the `mfma_variant` parameter (default value is identical,
            # so existing callers are unaffected).
            mma_type=mfma_variant,
        ),
    ]

    def _shuffled_flat_index(row, col):
        """Flat offset of logical (row, col) in the e8m0 pre-shuffled layout.

        Symbolic inverse of the host shuffle: reshape to
        (sm//32, 2, 16, sn//8, 2, 4), permute (0, 3, 5, 2, 4, 1), flatten.
        `row` is the M/N iterator, `col` the K iterator.
        """
        return (
            (row // 32) * ((K_SCALE_SHUFFLED // 8) * 256)
            + (col // 8) * 256
            + ((col % 8) % 4) * 64
            + ((row % 32) % 16) * 4
            + (((col % 8) // 4) * 2)
            + ((row % 32) // 16)
        )

    # IndexMapping for shuffled A scales.
    # NOTE(review): currently unused — the kernel below reads `a_scale`
    # without a mapping because the example driver does not pre-shuffle
    # A scales (its x_scales shuffle is commented out). Kept so it can be
    # wired in when A-scale shuffling is enabled; confirm before relying
    # on it.
    i = tkw.IndexMapping.iterator(0)  # K iterator
    j = tkw.IndexMapping.iterator(1)  # M iterator
    a_flat = _shuffled_flat_index(j, i)
    a_scale_mapping = tkw.IndexMapping(
        num_iterators=2,
        inputs={
            M: a_flat // K_SCALE_SHUFFLED,
            K: a_flat % K_SCALE_SHUFFLED,
        },
        outputs={
            K: i,
            M: j,
        },
    )

    # IndexMapping for shuffled B scales (used by read_b_scale below).
    k = tkw.IndexMapping.iterator(0)  # K iterator
    n = tkw.IndexMapping.iterator(1)  # N iterator
    b_flat = _shuffled_flat_index(n, k)
    b_scale_mapping = tkw.IndexMapping(
        num_iterators=2,
        inputs={
            N: b_flat // K_SCALE_SHUFFLED,
            K: b_flat % K_SCALE_SHUFFLED,
        },
        outputs={
            K: k,
            N: n,
        },
    )

    # TODO: preshuffle merge doesn't work with shared address space yet,
    # hence b_scale stays in GLOBAL_ADDRESS_SPACE below.
    @tkw.wave(constraints)
    def mxfp4_gemm_preshuffle(
        a: tkl.Memory[M, K / 2, ADDRESS_SPACE, tkl.i8],
        a_scale: tkl.Memory[M, K / 32, ADDRESS_SPACE, tkl.i8],
        b: tkl.Memory[N, K / 2, ADDRESS_SPACE, tkl.i8],
        b_scale: tkl.Memory[N, K / 32, GLOBAL_ADDRESS_SPACE, tkl.i8],
        c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
    ):
        c_reg = tkl.Register[M, N, tkl.f32](0.0)

        @tkw.iterate(K, init_args=[c_reg], tag="k_loop")
        def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
            a_reg = tkw.read(a, tag="read_a")
            a_reg = tkw.bitcast(a_reg, tkl.f4e2m1fn, tag="bitcast_a")
            a_scale_reg = tkw.read(a_scale, tag="read_a_scale")
            a_scale_reg = tkw.bitcast(a_scale_reg, tkl.f8e8m0fnu, tag="bitcast_a_scale")

            b_reg = tkw.read(b, tag="read_b")
            b_reg = tkw.bitcast(b_reg, tkl.f4e2m1fn, tag="bitcast_b")
            # B scales are fetched through the de-shuffling mapping.
            b_scale_reg = tkw.read(b_scale, mapping=b_scale_mapping, tag="read_b_scale")
            b_scale_reg = tkw.bitcast(b_scale_reg, tkl.f8e8m0fnu, tag="bitcast_b_scale")

            acc = tkw.scaled_mma(
                a_reg, a_scale_reg, b_reg, b_scale_reg, acc, tag="scaled_mma"
            )
            return acc

        tkw.write(repeat, c)

    hyperparams = {
        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
        BLOCK_M: block_shape[0],
        BLOCK_N: block_shape[1],
        BLOCK_K: block_shape[2],
        M: shape[0],
        N: shape[1],
        K: shape[2],
        K_SCALE_SHUFFLED: k_scale_shuffled,
    }
    hyperparams.update(get_default_scheduling_params())

    options = WaveCompileOptions(
        subs=hyperparams,
        canonicalize=True,
        schedule=SchedulingType.MANUAL,
        use_global_to_shared=True,
        minimize_shared_allocs=True,
    )

    return mxfp4_gemm_preshuffle, options