@@ -1559,15 +1559,15 @@ void IndexLowering::handleCpAsyncBulkStore(const LoadStoreOp* ldst) {
 }
 
 static DataType getMmaInputAType(MmaMacro macro) {
-  int warp_group_size = isHopper(macro) ? 128 : 32;
-  int size = getM(macro) * getK(macro) / warp_group_size /
-      2 /* halves per 32bit register */;
+  int64_t warp_group_size = isHopper(macro) ? 128L : 32L;
+  int64_t size = getM(macro) * getK(macro) / warp_group_size /
+      2L /* halves per 32bit register */;
   return ArrayType{std::make_shared<DataType>(DataType::UInt32), (size_t)size};
 }
 
 static DataType getMmaInputBType(MmaMacro macro) {
-  int size = getN(macro) * getK(macro) / 32 /* threads per warp */ /
-      2 /* halves per 32bit register */;
+  int64_t size = getN(macro) * getK(macro) / 32L /* threads per warp */ /
+      2L /* halves per 32bit register */;
   return ArrayType{std::make_shared<DataType>(DataType::UInt32), (size_t)size};
 }
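As a quick sanity check on the register counts these two helpers produce, the arithmetic can be spelled out directly. This is an illustrative sketch only; the m16n8k16 and m64nNk16 shapes are assumed examples of warp-level and warp-group macros, not values taken from this patch.

#include <cassert>
#include <cstdint>

int main() {
  // Two fp16 values are packed into each 32-bit register.
  constexpr int64_t halves_per_reg = 2;
  // Hopper warp-group MMA (e.g. m64nNk16): the 64x16 operand-A tile is spread
  // over 128 threads, giving 4 registers per thread.
  assert(64 * 16 / 128 / halves_per_reg == 4);
  // Warp-level MMA (e.g. m16n8k16): operands are spread over 32 threads,
  // giving 4 registers for A and 2 registers for B per thread.
  assert(16 * 16 / 32 / halves_per_reg == 4);
  assert(8 * 16 / 32 / halves_per_reg == 2);
  return 0;
}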
@@ -1842,8 +1842,8 @@ Val* hardCodedIndexGenerationForStMatrix(
 // To account for the threadIdx.y, we have to add it to the offset:
 // offset_from_tdy = threadIdx.y * tma_m * tma_n * 2 (half)
 //
-// Now, lets apply stmatrix tile to the TMA Box.
-// [NO(2), MO(4), MI(16), NIO(4), NII(16)].
+// Now, let's apply the stmatrix tile (16, 16) to the TMA Box
+// [NO(2), M(64), NI(64)] to get [NO(2), MO(4), MI(16), NIO(4), NII(16)].
 //
 // A warp group of 128 threads contains four warps. StMatrix is a warp-level
 // operation, so four StMatrix operations can be issued simultaneously by the
@@ -1865,6 +1865,7 @@ Val* hardCodedIndexGenerationForStMatrix(
 // domain is scheduled as [NO(2), M(64), NI(64)]. Therefore, we must store the
 // data in shared memory in [M(64), NI(64)] contiguous tiles.
 //
+// NOTE: This offset is skipped if the for-loop is trivial.
 // To account for the outer_index, we have to add it to the offset:
 // offset_from_outer_index = outer_index * tma_m * NI(64) * 2 (half)
 //
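The two offsets described in this comment block are plain byte arithmetic; the sketch below restates them outside the IR builder. The helper name smemBaseOffset and its tma_m/tma_n parameters are hypothetical, introduced only for illustration.

#include <cstdint>

int64_t smemBaseOffset(
    int64_t tdy, int64_t outer_index, int64_t tma_m, int64_t tma_n) {
  constexpr int64_t dtype_size = 2;  // bytes per fp16 element
  // offset_from_tdy = threadIdx.y * tma_m * tma_n * 2 (half)
  const int64_t offset_from_tdy = tdy * tma_m * tma_n * dtype_size;
  // offset_from_outer_index = outer_index * tma_m * NI(64) * 2 (half)
  const int64_t offset_from_outer_index = outer_index * tma_m * 64 * dtype_size;
  return offset_from_tdy + offset_from_outer_index;
}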
@@ -1928,8 +1929,13 @@ Val* hardCodedIndexGenerationForStMatrix(
 // with the 8 rows of the matrix to avoid bank conflicts. This swizzle pattern
 // is repeated along the rows of the TMA box.
 //
+// The number of distinct swizzle rows is the swizzle size in bytes divided by
+// the size of a megabank (16B). The number of times the swizzle pattern is
+// repeated to fill the core (8, 8) matrix is the number of swizzle rows (8)
+// divided by the number of distinct rows.
+//
 // Swizzle column
-// row_in_swizzle_pattern = row % swizzle_row_size(8)
+// row_in_swizzle_pattern = (row % swizzle_row_size(8)) / swizzle_repetitions
 // swizzle_col = column XOR row_in_swizzle_pattern
 //
 // Calculate Tile Offset
@@ -1939,7 +1945,7 @@ Val* hardCodedIndexGenerationForStMatrix(
 //
 // Get shared memory offset
 // smem_offset = offset_from_tdy + offset_from_outer_index + tile_offset
-Val* hardCodedIndexGenerationForStMatrix128BSwizzle(
+Val* hardCodedIndexGenerationForStMatrixSwizzle(
     const LoadStoreOp* ldst,
     ForLoop* loop,
     const int64_t stsm_m_tile,
@@ -1958,16 +1964,19 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle(
 
   NVF_ERROR(ldst->out()->isA<TensorView>());
   TensorView* out_tv = ldst->out()->as<TensorView>();
-  NVF_ERROR(getSwizzle(out_tv) == MmaInputSmemSwizzle::B128);
+  MmaInputSmemSwizzle swizzle = getSwizzle(out_tv);
+  int64_t swizzle_bytes = getBytesFromSwizzle(swizzle);
 
   // Constants
   constexpr int64_t dtype_size = 2;
   constexpr int64_t warp_size = 32;
   constexpr int64_t swizzle_row_size = 8;
   constexpr int64_t stsm_column_size = 8;
-  constexpr int64_t swizzle_n_tile = 64;
+  constexpr int64_t megabank_size_bytes = 16;
 
   // Derived constants
+  const int64_t swizzle_n_tile = swizzle_bytes / dtype_size;
+  const int64_t distinct_swizzle_row_size = swizzle_bytes / megabank_size_bytes;
   constexpr int64_t stsm_column_stride = stsm_column_size * dtype_size;
   const int64_t swizzle_n_iter = swizzle_n_tile / stsm_n_tile;
   const int64_t swizzle_n_tile_stride = swizzle_n_tile * dtype_size;
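For concreteness, the derived constants above take the following values for the three swizzle widths. This is a standalone check of the arithmetic, not part of the lowering code.

#include <cassert>
#include <cstdint>

int main() {
  constexpr int64_t dtype_size = 2;            // bytes per fp16 element
  constexpr int64_t megabank_size_bytes = 16;  // shared-memory megabank
  // B128: 64 half columns per swizzle period, all 8 rows distinct.
  assert(128 / dtype_size == 64 && 128 / megabank_size_bytes == 8);
  // B64: 32 half columns, 4 distinct rows (pattern repeats every 2 rows).
  assert(64 / dtype_size == 32 && 64 / megabank_size_bytes == 4);
  // B32: 16 half columns, 2 distinct rows (pattern repeats every 4 rows).
  assert(32 / dtype_size == 16 && 32 / megabank_size_bytes == 2);
  return 0;
}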
@@ -2000,8 +2009,6 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle(
   Val* warp_id = SimplifyingIrBuilder::divExpr(TDX, warp_size_val);
   Val* lane_id = SimplifyingIrBuilder::modExpr(TDX, warp_size_val);
 
-  Val* outer_index =
-      SimplifyingIrBuilder::divExpr(loop->index(), swizzle_n_iter_val);
   Val* inner_index =
       SimplifyingIrBuilder::modExpr(loop->index(), swizzle_n_iter_val);
 
@@ -2021,6 +2028,17 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle(
   // Swizzle Column
   Val* row_in_swizzle_pattern =
       SimplifyingIrBuilder::modExpr(row, swizzle_row_size_val);
+
+  // The swizzle pattern is repeated to fill the (8, 8) matrix for the 64B and
+  // 32B swizzles. swizzle_row_iter is the number of repetitions needed to fill
+  // the 8 rows with the distinct swizzle rows.
+  const int64_t swizzle_row_iter = swizzle_row_size / distinct_swizzle_row_size;
+  if (swizzle_row_iter > 1) {
+    Val* swizzle_row_iter_val =
+        IrBuilder::create<Val>(swizzle_row_iter, DataType::Index);
+    row_in_swizzle_pattern = SimplifyingIrBuilder::divExpr(
+        row_in_swizzle_pattern, swizzle_row_iter_val);
+  }
   Val* swizzle_col = bitwise_xor(col, row_in_swizzle_pattern);
 
   // Calculate Tile Offset
@@ -2031,16 +2049,22 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle(
   Val* offset = SimplifyingIrBuilder::addExpr(row_offset, col_offset);
 
   // Calculate Tile offset
-  Val* tile_offset = IrBuilder::mulExpr(outer_index, tile_stride_val);
+  // Skip the tile offset if the for-loop is trivial.
+  if (!loop->stop()->isOneInt()) {
+    Val* outer_index =
+        SimplifyingIrBuilder::divExpr(loop->index(), swizzle_n_iter_val);
+    Val* tile_offset =
+        SimplifyingIrBuilder::mulExpr(outer_index, tile_stride_val);
+    offset = SimplifyingIrBuilder::addExpr(tile_offset, offset);
+  }
 
   // Calculate TDY offset
-  Val* tdy_offset = IrBuilder::mulExpr(TDY, tdy_stride_val);
+  Val* tdy_offset = SimplifyingIrBuilder::mulExpr(TDY, tdy_stride_val);
+  offset = SimplifyingIrBuilder::addExpr(tdy_offset, offset);
 
   // Create shared memory TensorIndex
   Val* out_index = SimplifyingIrBuilder::addExpr(
-      IrBuilder::baseAddressExpr(ir_utils::getTvOutput(ldst)),
-      SimplifyingIrBuilder::addExpr(
-          tdy_offset, SimplifyingIrBuilder::addExpr(tile_offset, offset)));
+      IrBuilder::baseAddressExpr(ir_utils::getTvOutput(ldst)), offset);
   Val* out = IrBuilder::create<kir::TensorIndex>(
       dynamic_cast<TensorView*>(ldst->out()), out_index);
   return out;
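Since the swizzled-column computation above is assembled from IR nodes, the underlying arithmetic is easier to see with plain integers. Below is a minimal sketch, assuming fp16 data and the constants used in this function; the helper name swizzledColumn is made up for illustration. For the 128B swizzle every one of the 8 core-matrix rows has its own pattern, while for 64B and 32B one pattern row is shared by 2 and 4 consecutive rows respectively.

#include <cstdint>

// Plain-integer restatement of:
//   row_in_swizzle_pattern = (row % 8) / swizzle_repetitions
//   swizzle_col = col XOR row_in_swizzle_pattern
int64_t swizzledColumn(int64_t row, int64_t col, int64_t swizzle_bytes) {
  constexpr int64_t megabank_size_bytes = 16;
  constexpr int64_t swizzle_row_size = 8;
  // B128 -> 8 distinct rows, B64 -> 4, B32 -> 2.
  const int64_t distinct_rows = swizzle_bytes / megabank_size_bytes;
  // How many consecutive rows share one pattern row (1, 2, or 4).
  const int64_t repetitions = swizzle_row_size / distinct_rows;
  const int64_t row_in_pattern = (row % swizzle_row_size) / repetitions;
  return col ^ row_in_pattern;
}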
@@ -2092,11 +2116,11 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
           ldst, for_loops_[0], m_tile, n_tile, m, n);
       break;
     case MmaInputSmemSwizzle::B128:
-      out = hardCodedIndexGenerationForStMatrix128BSwizzle(
+    case MmaInputSmemSwizzle::B64:
+    case MmaInputSmemSwizzle::B32:
+      out = hardCodedIndexGenerationForStMatrixSwizzle(
           ldst, for_loops_[0], m_tile, n_tile, m, n);
       break;
-    case MmaInputSmemSwizzle::B32:
-    case MmaInputSmemSwizzle::B64:
     default:
       NVF_ERROR("Unsupported Swizzle Type for StMatrix");
   }