Skip to content

Commit

Permalink
OpenXLA-specific changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jax-triton-dev authored and karupayun committed Aug 28, 2024
1 parent b2de88f commit 7a5940c
Show file tree
Hide file tree
Showing 45 changed files with 2,187 additions and 114 deletions.
900 changes: 900 additions & 0 deletions BUILD

Large diffs are not rendered by default.

7 changes: 3 additions & 4 deletions include/triton/Analysis/Alias.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,9 @@ class SharedMemoryAliasAnalysis
}

/// Computes if the alloc set of the results are changed.
void
visitOperation(Operation *op,
ArrayRef<const dataflow::Lattice<AliasInfo> *> operands,
ArrayRef<dataflow::Lattice<AliasInfo> *> results) override;
LogicalResult visitOperation(
Operation *op, ArrayRef<const dataflow::Lattice<AliasInfo> *> operands,
ArrayRef<dataflow::Lattice<AliasInfo> *> results) override;
};

} // namespace mlir
Expand Down
2 changes: 1 addition & 1 deletion include/triton/Dialect/Triton/IR/TritonTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class TritonTypeDef<string name, string _mnemonic, list<Trait> traits = []>
}

// Floating-point Type
def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : RankedTensorOf<[TT_Float]>;
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;

Expand Down
6 changes: 5 additions & 1 deletion include/triton/Dialect/Triton/IR/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ template <typename Int> Int ceil(Int m, Int n) { return (m + n - 1) / n; }

/// Get the highest power of 2 divisor of an integer.
template <typename T> T highestPowOf2Divisor(T n) {
if (n == 0) {
// When n is 0 or min, return the highest power of 2. The min case is handled
// separately to avoid underflow when T is a signed integer. Technically
// in that case the correct divisor is -n, but this value is outside the
// range of possible values, so we take the next best alternative.
if (n == 0 || n == std::numeric_limits<T>::min()) {
return (static_cast<T>(1) << (sizeof(T) * 8 - 2));
}
return (n & (~(n - 1)));
Expand Down
8 changes: 5 additions & 3 deletions lib/Analysis/Alias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) {
return ret;
}

void SharedMemoryAliasAnalysis::visitOperation(
LogicalResult SharedMemoryAliasAnalysis::visitOperation(
Operation *op, ArrayRef<const dataflow::Lattice<AliasInfo> *> operands,
ArrayRef<dataflow::Lattice<AliasInfo> *> results) {
AliasInfo aliasInfo;
Expand All @@ -31,7 +31,7 @@ void SharedMemoryAliasAnalysis::visitOperation(
if (auto memdescTy = dyn_cast<triton::MemDescType>(result.getType())) {
if (!isa_and_nonnull<triton::gpu::SharedMemorySpaceAttr>(
memdescTy.getMemorySpace()))
return;
return mlir::success();
}

// Only LocalAllocOp creates a new buffer.
Expand All @@ -49,11 +49,13 @@ void SharedMemoryAliasAnalysis::visitOperation(
}

if (pessimistic) {
return setAllToEntryStates(results);
setAllToEntryStates(results);
return mlir::success();
}
// Join all lattice elements
for (auto *result : results)
propagateIfChanged(result, result->join(aliasInfo));
return mlir::success();
}

AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) {
Expand Down
15 changes: 9 additions & 6 deletions lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,9 @@ class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis<
dataflow::Lattice<AxisInfo>>::getLatticeElement;
using FuncAxisInfoMapT = DenseMap<FunctionOpInterface, AxisInfo>;

void visitOperation(Operation *op,
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
ArrayRef<dataflow::Lattice<AxisInfo> *> results) override;
LogicalResult visitOperation(
Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
ArrayRef<dataflow::Lattice<AxisInfo> *> results) override;
void
visitForOpInductionVar(scf::ForOp op,
ArrayRef<dataflow::Lattice<AxisInfo> *> argLattices);
Expand Down Expand Up @@ -1039,7 +1039,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
visitors.append<LoadOpAxisInfoVisitor>();
}

void AxisInfoAnalysis::visitOperation(
LogicalResult AxisInfoAnalysis::visitOperation(
Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
// TODO: For sure not the right way to do this
Expand All @@ -1048,8 +1048,10 @@ void AxisInfoAnalysis::visitOperation(
if (op->getValue().getRank() == 0)
setToEntryState((dataflow::Lattice<AxisInfo> *)op);
AxisInfo curr = visitors.apply(op, operands);
if (curr.getRank() == 0)
return setAllToEntryStates(results);
if (curr.getRank() == 0) {
setAllToEntryStates(results);
return mlir::success();
}
// override with hint
auto newContiguity = curr.getContiguity();
auto newDivisibility = curr.getDivisibility();
Expand All @@ -1071,6 +1073,7 @@ void AxisInfoAnalysis::visitOperation(
// join all lattice elements
for (auto *result : results)
propagateIfChanged(result, result->join(curr));
return mlir::success();
}

void AxisInfoAnalysis::visitForOpInductionVar(
Expand Down
6 changes: 4 additions & 2 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ bool supportMFMATypes(Type a, Type b) {
if (a.getIntOrFloatBitWidth() != b.getIntOrFloatBitWidth())
return false;

auto F8E4M3FN = TypeID::get<Float8E4M3FNType>();
auto F8E5M2 = TypeID::get<Float8E5M2Type>();
auto F8E4M3FNUZ = TypeID::get<Float8E4M3FNUZType>();
auto F8E5M2FNUZ = TypeID::get<Float8E5M2FNUZType>();
Expand All @@ -436,6 +437,7 @@ bool supportMFMATypes(Type a, Type b) {
{F32, F32},
{F16, F16},
{BF16, BF16},
{F8E4M3FN, F8E4M3FN},
{F8E5M2, F8E5M2},
{F8E4M3FNUZ, F8E4M3FNUZ},
{F8E4M3FNUZ, F8E5M2FNUZ},
Expand Down Expand Up @@ -495,14 +497,14 @@ bool supportMMA(triton::DotOp op, int version) {
return false;
if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
retShapePerCTA[rank - 1] % 8 == 0 &&
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ() ||
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
aElemTy.isF32()))) {
return false;
}
// We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
if (op.getMaxNumImpreciseAcc() < 32 &&
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ()) &&
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) &&
cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
return false;
}
Expand Down
3 changes: 2 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
auto ouEltTy = ouTensorTy.getElementType();
if (inBitWidth == ouBitWidth)
return values;
if (inBitWidth == 16 && ouBitWidth == 32) {
if ((inBitWidth == 16 && ouBitWidth == 32) ||
(inBitWidth == 32 && ouBitWidth == 16)) {
SmallVector<Value> ret;
for (unsigned i = 0; i < values.size(); i += 8) {
ret.push_back(values[i]);
Expand Down
3 changes: 3 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
addConversion([&](mlir::Float8E4M3FNUZType type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
addConversion([&](mlir::Float8E4M3FNType type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
addConversion([&](mlir::Float8E5M2Type type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
Expand Down
3 changes: 2 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ struct ArithConstantSplatOpConversion
// LLVM IR.
if (type::isFloat8(elemType))
elemType = rewriter.getIntegerType(8);
auto constOp = rewriter.create<LLVM::ConstantOp>(loc, elemType, val);
auto typeConverter = getTypeConverter();
auto constOp = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter->convertType(elemType), val);
auto llStruct = SplatOpConversion::convertSplatLikeOp(
elemType, op.getType(), constOp, typeConverter, rewriter, loc);
rewriter.replaceOp(op, llStruct);
Expand Down
5 changes: 5 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2721,6 +2721,11 @@ struct CanonicalizeConvertFromAlloc
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert)
return failure();
// LocalAllocOp lowering doesn't support going from DotOperandEncoding
// to SharedEncoding, so we want to keep this layout conversion.
if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
convert.getSrc().getType().getEncoding()))
return failure();
rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
op, op->getResult(0).getType(), convert.getSrc());
return mlir::success();
Expand Down
26 changes: 25 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,21 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
newLayout, SharedMemorySpace);
rewriter.setInsertionPointAfterValue(arg);

// LocalAllocOp lowering doesn't support going from DotOperandEncoding
// to SharedEncoding.
if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
argType.getEncoding())) {
// Create a layout conversion from DotOperandEncoding to BlockedEncoding
// then pass it to the LocalAllocOp.
auto newArgType = RankedTensorType::get(
argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
auto dotOperandToBlockedCvt =
rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
dotOperandToBlockedCvt);
}

return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
}

Expand All @@ -162,6 +177,15 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;

static bool bwdFilter(Operation *op) {
// Dot operand layout assignment to predicates is not currently supported
// during lowering from TritonGPU to LLVM in Triton for MMA cases. This
// condition limits visibility of the original bit-width so that predicates
// are not considered; hence, kWidth can never be 32.
if (isa<arith::UIToFPOp>(op)) {
Type srcType = getElementTypeOrSelf(op->getOperand(0));
if (srcType.isInteger(1))
return false;
}
return op->getNumOperands() == 1 &&
(isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
isPureUnaryInlineAsm(op) ||
Expand Down Expand Up @@ -357,7 +381,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
NvidiaMmaEncodingAttr mmaLayout =
dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
if (mmaLayout) {
bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ();
bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
// promote operands for sm < 89 since fp8 mma is not natively supported
// promote operands for sm >= 90 when mma is not v3
if (!isNativeFP8 ||
Expand Down
17 changes: 16 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
PatternRewriter &rewriter) const override {
// Only consider conversions to dot operand.
auto cvtTy = cast<RankedTensorType>(cvt.getType());
if (!isa<DotOperandEncodingAttr>(cvtTy.getEncoding()))
auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
if (!dotOpEnc)
return failure();

auto src = cvt.getSrc().getDefiningOp();
Expand All @@ -126,6 +127,12 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
[](Type ty) { return isa<RankedTensorType>(ty); }))
return failure();

// Quick fix for loading issues: the original bit-width computation cannot
// tell that this is a mixed-precision dot (hence kWidth = 1), yet it would
// still try to hoist through the type conversion.
if (isa<arith::ExtFOp>(src) && dotOpEnc.getKWidth() == 1)
return failure();

// Only consider custom conversions or arith ops.
// TODO(jlebar): Is this too restrictive?
if (!isa<FpToFpOp, BitcastOp>(src) && !isPureUnaryInlineAsm(src) &&
Expand All @@ -138,6 +145,14 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(src))
return failure();

// Don't hoist through u1 -> fp casts as they aren't supported in
// ElementwiseOpToLLVM::reorderValues().
if (isa<arith::UIToFPOp>(src)) {
Type srcType = getElementTypeOrSelf(src->getOperand(0));
if (srcType.isInteger(1))
return failure();
}

// Check that the conversion is transitively dependent on a load, and all
// operations between the load and the conversion are layout preserving.
//
Expand Down
17 changes: 16 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
type.getMemorySpace()),
v, offsetsVal);

// We need to assign kwidth to zero in the case where the parent layout is
// Blocked, otherwise the verifier emits a failure. The parent layout is
// Blocked only when Tensor Cores are disabled.
int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
? 0
: prefetchWidth / 8;
auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
builder.getContext(), opIdx, dotEncoding, kwidth);
Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>(
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
newSmem);
Expand Down Expand Up @@ -187,6 +193,15 @@ LogicalResult Prefetcher::initialize() {
break;
if (!op->getResult(0).hasOneUse())
break;
// Similar to issues faced in HoistLayoutConversion pattern in
// OptimizeDotOperands.cpp, we can't propagate through type casts from
// predicates as they aren't supported in Triton when encoded with dot_op
// layout.
if (isa<arith::UIToFPOp>(op)) {
Type srcType = getElementTypeOrSelf(op->getOperand(0));
if (srcType.isInteger(1))
break;
}
rets.push_back(op->getOperand(0));
if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
foundConvertFromShared = true;
Expand Down
Loading

0 comments on commit 7a5940c

Please sign in to comment.