diff --git a/BUILD b/BUILD
new file mode 100644
index 000000000000..f74ef119cd09
--- /dev/null
+++ b/BUILD
@@ -0,0 +1,900 @@
+# This package imports OpenAI's Triton (https://github.com/openai/triton).
+#
+# There are two versions of Triton in google3 at the moment. The older version
+# can be found at //third_party/py/triton. This package is the MLIR-based
+# version close to head. We expect to transition users to this version in the
+# coming weeks.
+#
+# There is no SLA associated with this package and it may get broken by LLVM
+# imports at any time.
+
+load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
+# copybara:uncomment load("//tools/build_defs/license:license.bzl", "license")
+
+package(
+    # copybara:uncomment_begin
+    # default_applicable_licenses = [":license"],
+    # default_compatible_with = ["//buildenv/target:gce"],
+    # default_visibility = [
+    #     # Add your project here if you need to depend on Triton's C++ sources.
+    #     # Add a point of contact we can reach out to when needed in the comment.
+    #     #
+    #     # If you need to use the Python frontend, add your project to
+    #     # google3/third_party/py/triton/BUILD instead.
+    #     #
+    #     # By adding your project here, you agree to the Triton SLA: go/triton-google3-sla
+    #     "//third_party/py/jax:__subpackages__",  # cjfj@
+    #     "//third_party/tensorflow/compiler/xla:__subpackages__",  # bchetioui@
+    #     "//platforms/xla/experimental/gpu:__subpackages__",  # csigg@
+    #     # Triton-internal visibility
+    #     "//:__subpackages__",
+    # ],
+    # copybara:uncomment_end_and_comment_begin
+    default_visibility = ["//visibility:public"],
+    # copybara:comment_end
+    # TODO(csigg): fix and remove
+    features = [
+        "-parse_headers",
+        "-use_header_modules",
+    ],
+)
+
+# copybara:uncomment_begin
+# license(name = "license")
+#
+# licenses(["notice"])
+#
+# exports_files(["LICENSE"])
+# copybara:uncomment_end
+
+config_setting(
+    name = "compiler_is_msvc",
+    flag_values = {
+        # copybara:comment_begin
+        "@bazel_tools" +
+        # copybara:comment_end
+        "//tools/cpp:compiler": "msvc-cl",
+    },
+)
+
+# TODO(csigg): fix, enable error upstream, remove.
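+# The `select()` below keys the warning flags on the configured C++ toolchain:
+# MSVC does not understand GCC/Clang-style `-Wno-*` flags, so the suppression
+# is only applied for non-MSVC compilers (assumption: every other toolchain
+# configured for this package accepts GNU-style warning flags).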
+_no_unused_variable = select({ + ":compiler_is_msvc": [], + "//conditions:default": ["-Wno-unused-variable"], +}) + +td_library( + name = "td_files", + srcs = glob(["include/triton/**/*.td"]), + includes = ["include"], + deps = [ + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:CastInterfacesTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LLVMOpsTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@llvm-project//mlir:ViewLikeInterfaceTdFiles", + ], +) + +gentbl_cc_library( + name = "triton_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/Triton/IR/TritonAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/Triton/IR/TritonAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/Triton/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/Triton/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-attr-interface-decls"], + "include/triton/Dialect/Triton/IR/AttrInterfaces.h.inc", + ), + ( + ["--gen-attr-interface-defs"], + "include/triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonInterfaces.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-enum-decls"], + "include/triton/Dialect/Triton/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/Triton/IR/OpsEnums.cpp.inc", + ), + ( + ["--gen-op-decls"], + "include/triton/Dialect/Triton/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/Triton/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/Triton/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/Triton/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=Triton", + ], + "include/triton/Dialect/Triton/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_combine_inc_gen", + # The generated file is #included without relative path. 
+ strip_include_prefix = "lib/Dialect/Triton/Transforms", + tbl_outs = [ + ( + ["--gen-rewriters"], + "lib/Dialect/Triton/Transforms/TritonCombine.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "lib/Dialect/Triton/Transforms/Combine.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/TritonGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/TritonGPU/IR/OpsEnums.cpp.inc", + ), + ( + ["--gen-attr-interface-decls"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.h.inc", + ), + ( + ["--gen-attr-interface-defs"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/TritonGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/TritonGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-op-decls"], + "include/triton/Dialect/TritonGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/TritonGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/TritonGPU/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/TritonGPU/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonGPU", + ], + "include/triton/Dialect/TritonGPU/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + 
"include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-op-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonNvidiaGPU", + ], + "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_to_triton_gpu_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonToTritonGPU", + ], + "include/triton/Conversion/TritonToTritonGPU/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Conversion/TritonToTritonGPU/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_target_llvmir_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonLLVMIR", + ], + "include/triton/Target/LLVMIR/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Target/LLVMIR/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_gpu_to_llvm_pass_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonGPUToLLVM", + ], + "include/triton/Conversion/TritonGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Conversion/TritonGPUToLLVM/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_type_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-type-interface-decls"], + "include/triton/Dialect/Triton/IR/TritonTypeInterfaces.h.inc", + ), + ( + ["--gen-type-interface-defs"], + "include/triton/Dialect/Triton/IR/TritonTypeInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td", + deps = ["td_files"], +) + +cc_library( + name = "TritonAnalysis", + srcs = [ + "lib/Analysis/Alias.cpp", + "lib/Analysis/Allocation.cpp", + "lib/Analysis/Membar.cpp", + # Part of TritonDialects compilation unit to avoid circular dependencies. + # "lib/Analysis/Utility.cpp", + # "lib/Analysis/AxisInfo.cpp", + ], + hdrs = [ + "include/triton/Analysis/Alias.h", + "include/triton/Analysis/Allocation.h", + "include/triton/Analysis/Membar.h", + # Part of TritonDialects compilation unit to avoid circular dependencies. 
+ # "include/triton/Analysis/AxisInfo.h", + # "include/triton/Analysis/Utility.h", + "include/triton/Conversion/MLIRTypes.h", + "include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h", + "include/triton/Conversion/TritonGPUToLLVM/Utility.h", + "include/triton/Dialect/TritonGPU/Transforms/Utility.h", + ], + copts = _no_unused_variable, + deps = [ + ":TritonDialects", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonDialects", + srcs = glob([ + "lib/Dialect/Triton/IR/*.cpp", + "lib/Dialect/TritonGPU/IR/*.cpp", + "lib/Dialect/TritonNvidiaGPU/IR/*.cpp", + "lib/Tools/*.cpp", + ]) + [ + "include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h", # Avoid circular dependency. + "lib/Analysis/AxisInfo.cpp", # Avoid circular dependency. + "lib/Analysis/Utility.cpp", # Avoid circular dependency. + "lib/Dialect/TritonGPU/Transforms/Utility.cpp", # Avoid circular dependency. + ], + hdrs = glob([ + "include/triton/Dialect/Triton/IR/*.h", + "include/triton/Dialect/TritonGPU/IR/*.h", + "include/triton/Dialect/TritonNvidiaGPU/IR/*.h", + "include/triton/Tools/*.h", + ]) + [ + "include/triton/Analysis/AxisInfo.h", # Avoid circular dependency. + "include/triton/Analysis/Utility.h", # Avoid circular dependency. + "include/triton/Dialect/TritonGPU/Transforms/Utility.h", # Avoid circular dependency. + ], + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + "-Wno-logical-op-parentheses", + ], + }), + includes = ["include"], + deps = [ + ":triton_dialect_inc_gen", + ":triton_gpu_attr_inc_gen", + ":triton_gpu_dialect_inc_gen", + ":triton_gpu_ops_inc_gen", + ":triton_gpu_types_inc_gen", + ":triton_interfaces_inc_gen", + ":triton_nvidia_gpu_attr_inc_gen", + ":triton_nvidia_gpu_dialect_inc_gen", + ":triton_nvidia_gpu_ops_inc_gen", + ":triton_nvidia_gpu_types_inc_gen", + ":triton_ops_inc_gen", + ":triton_types_inc_gen", + ":triton_type_interfaces_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@triton//third_party/nvidia:NVGPUDialect", + # The following is added to make Utility compile + ":TritonTools", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@triton//third_party/f2reduce", + ], +) + +cc_library( + name = "TritonTransforms", + srcs = glob(["lib/Dialect/Triton/Transforms/*.cpp"]), + hdrs = glob(["include/triton/Dialect/Triton/Transforms/*.h"]), + copts = _no_unused_variable, + deps = [ + ":TritonDialects", + ":triton_combine_inc_gen", + ":triton_transforms_inc_gen", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + 
"@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], + alwayslink = True, # TritonDialect uses getCanonicalizationPatterns(). +) + +cc_library( + name = "TritonGPUTransforms", + srcs = glob( + [ + "lib/Dialect/TritonGPU/Transforms/*.cpp", + "lib/Dialect/TritonGPU/Transforms/*.h", + "lib/Dialect/TritonGPU/Transforms/Pipeliner/*.cpp", + "lib/Dialect/TritonGPU/Transforms/Pipeliner/*.h", + ], + exclude = ["lib/Dialect/TritonGPU/Transforms/Utility.cpp"], + ), + hdrs = glob( + [ + "include/triton/Dialect/TritonGPU/Transforms/*.h", + ], + exclude = ["include/triton/Dialect/TritonGPU/Transforms/Utility.h"], + ) + [ + "include/triton/Tools/Sys/GetEnv.hpp", + ], + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-reorder-ctor", + "-Wno-return-type", + "-Wno-unused-variable", + ], + }), + deps = [ + ":TritonAnalysis", + ":TritonDialects", + ":triton_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:SCFUtils", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonGPUToLLVM", + srcs = glob([ + "lib/Conversion/TritonGPUToLLVM/*.h", + "lib/Conversion/TritonGPUToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/triton/Tools/Sys/*.hpp", + "include/triton/Conversion/TritonGPUToLLVM/*.h", + ]), + copts = select({ + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + includes = ["include"], + deps = [ + ":TritonAnalysis", + ":TritonDialects", + ":triton_conversion_triton_gpu_to_llvm_pass_inc_gen", + ":triton_gpu_attr_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonNvidiaGPUTransforms", + srcs = glob([ + "lib/Dialect/TritonNvidiaGPU/Transforms/*.cpp", + ]), + hdrs = glob([ + "include/triton/Dialect/TritonNvidiaGPU/Transforms/*.h", + ]), + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-ctad-maybe-unsupported", + "-Wno-logical-op-parentheses", + "-Wno-non-virtual-dtor", + "-Wno-return-type", + "-Wno-unused-variable", + ], + }), + includes = ["include"], + deps = [ + ":TritonDialects", + ":TritonGPUTransforms", + ":triton_nvidia_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonToTritonGPU", + srcs = glob([ + "lib/Conversion/TritonToTritonGPU/*.h", + "lib/Conversion/TritonToTritonGPU/*.cpp", + ]), + hdrs = glob(["include/triton/Conversion/TritonToTritonGPU/*.h"]), + includes = ["include"], + deps = [ + ":TritonDialects", + ":TritonGPUTransforms", + ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", + 
"@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonLLVMIR", + srcs = glob([ + "lib/Target/LLVMIR/*.cpp", + "lib/Target/LLVMIR/*.h", + ]), + hdrs = glob(["include/triton/Target/LLVMIR/*.h"]), + copts = _no_unused_variable, + deps = [ + ":TritonTransforms", + ":triton_target_llvmir_passes_inc_gen", + "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:BinaryFormat", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IPO", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:InstCombine", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexToLLVM", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMIRTransforms", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLToLLVMIRTranslation", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:Transforms", + # copybara:uncomment "//third_party/py/triton/google:find_cuda", + ], +) + +cc_library( + name = "TritonPTX", + srcs = glob([ + "lib/Target/PTX/*.cpp", + ]), + hdrs = glob(["include/triton/Target/PTX/*.h"]), + deps = ["@llvm-project//llvm:Support"], +) + +cc_library( + name = "TritonHSACO", + srcs = glob([ + "lib/Target/HSACO/*.cpp", + ]), + hdrs = glob(["include/triton/Target/HSACO/*.h"]), + deps = [ + ":TritonLLVMIR", + ":TritonTools", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:ExecutionEngine", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Scalar", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//llvm:TransformUtils", + "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + ], +) + +cc_library( + name = "TritonTools", + hdrs = ["include/triton/Tools/Sys/GetEnv.hpp"], +) + +cc_library( + name = "AllPassesAndDialects", + srcs = [ + "include/triton/Conversion/TritonToTritonGPU/Passes.h", + "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h", + ], + hdrs = ["bin/RegisterTritonDialects.h"], + includes = ["."], # because it includes third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h + deps = [ + ":TritonDialects", + ":TritonGPUToLLVM", + ":TritonGPUTransforms", + ":TritonLLVMIR", + ":TritonNvidiaGPUTransforms", + ":TritonToTritonGPU", + ":TritonTransforms", + ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", + ":triton_nvidia_gpu_transforms_inc_gen", + "@llvm-project//mlir:AllPassesAndDialects", + "@triton//test:TritonTestAnalysis", + "@triton//third_party/amd:TritonAMDGPUToLLVM", + 
"@triton//third_party/amd:TritonAMDGPUTransforms", + "@triton//third_party/nvidia:NVGPUDialect", + "@triton//third_party/nvidia:NVGPUToLLVM", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) + +cc_binary( + name = "triton-opt", + srcs = [ + "bin/triton-opt.cpp", + ], + deps = [ + ":AllPassesAndDialects", + "@llvm-project//mlir:MlirOptLib", + ], +) + +cc_binary( + name = "triton-llvm-opt", + srcs = [ + "bin/triton-llvm-opt.cpp", + "lib/Target/LLVMIR/LLVMPasses.h", + ], + deps = [ + ":TritonLLVMIR", + "@llvm-project//llvm:CodeGen", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:Option", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + ], +) + +# See go/triton-debug for usage. +cc_binary( + name = "triton-reduce", + srcs = ["bin/triton-reduce.cpp"], + deps = [ + ":AllPassesAndDialects", + "@llvm-project//mlir:MlirReduceLib", + ], +) + +cc_binary( + name = "triton-tensor-layout", + srcs = ["bin/triton-tensor-layout.cpp"], + deps = [ + ":AllPassesAndDialects", + ":TritonDialects", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + ], +) + +filegroup( + name = "metadata-file", + srcs = ["METADATA"], +) diff --git a/include/triton/Analysis/Alias.h b/include/triton/Analysis/Alias.h index a06df5ae2106..f92d2200acba 100644 --- a/include/triton/Analysis/Alias.h +++ b/include/triton/Analysis/Alias.h @@ -85,10 +85,9 @@ class SharedMemoryAliasAnalysis } /// Computes if the alloc set of the results are changed. - void - visitOperation(Operation *op, - ArrayRef *> operands, - ArrayRef *> results) override; + LogicalResult visitOperation( + Operation *op, ArrayRef *> operands, + ArrayRef *> results) override; }; } // namespace mlir diff --git a/include/triton/Dialect/Triton/IR/TritonTypes.td b/include/triton/Dialect/Triton/IR/TritonTypes.td index 6ceb4bc47665..4c709cd4420b 100644 --- a/include/triton/Dialect/Triton/IR/TritonTypes.td +++ b/include/triton/Dialect/Triton/IR/TritonTypes.td @@ -15,7 +15,7 @@ class TritonTypeDef traits = []> } // Floating-point Type -def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">; +def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">; def TT_FloatTensor : RankedTensorOf<[TT_Float]>; def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>; diff --git a/include/triton/Dialect/Triton/IR/Utility.h b/include/triton/Dialect/Triton/IR/Utility.h index 0ef59714733d..1ff63697ec0d 100644 --- a/include/triton/Dialect/Triton/IR/Utility.h +++ b/include/triton/Dialect/Triton/IR/Utility.h @@ -31,7 +31,11 @@ template Int ceil(Int m, Int n) { return (m + n - 1) / n; } /// Get the highest power of 2 divisor of an integer. template T highestPowOf2Divisor(T n) { - if (n == 0) { + // When n is 0 or min, return the highest power of 2. The min case is handled + // separately to avoid underflow when T is a signed integer. Technically + // in that case the correct divisor is -n, but this value is outside the + // range of possible values, so we take the next best alternative. 
+  if (n == 0 || n == std::numeric_limits<T>::min()) {
     return (static_cast<T>(1) << (sizeof(T) * 8 - 2));
   }
   return (n & (~(n - 1)));
diff --git a/lib/Analysis/Alias.cpp b/lib/Analysis/Alias.cpp
index 52082ddf7021..1d20ab780cef 100644
--- a/lib/Analysis/Alias.cpp
+++ b/lib/Analysis/Alias.cpp
@@ -21,7 +21,7 @@ AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) {
   return ret;
 }
 
-void SharedMemoryAliasAnalysis::visitOperation(
+LogicalResult SharedMemoryAliasAnalysis::visitOperation(
     Operation *op, ArrayRef<const dataflow::Lattice<AliasInfo> *> operands,
     ArrayRef<dataflow::Lattice<AliasInfo> *> results) {
   AliasInfo aliasInfo;
@@ -31,7 +31,7 @@ void SharedMemoryAliasAnalysis::visitOperation(
     if (auto memdescTy = dyn_cast<triton::MemDescType>(result.getType())) {
       if (!isa_and_nonnull<triton::gpu::SharedMemorySpaceAttr>(
               memdescTy.getMemorySpace()))
-        return;
+        return mlir::success();
     }
 
     // Only LocalAllocOp creates a new buffer.
@@ -49,11 +49,13 @@ void SharedMemoryAliasAnalysis::visitOperation(
   }
 
   if (pessimistic) {
-    return setAllToEntryStates(results);
+    setAllToEntryStates(results);
+    return mlir::success();
   }
   // Join all lattice elements
   for (auto *result : results)
     propagateIfChanged(result, result->join(aliasInfo));
+  return mlir::success();
 }
 
 AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) {
diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp
index 1f3368d0d481..a63792d980b6 100644
--- a/lib/Analysis/AxisInfo.cpp
+++ b/lib/Analysis/AxisInfo.cpp
@@ -195,9 +195,9 @@ class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis<
       dataflow::Lattice<AxisInfo>>::getLatticeElement;
   using FuncAxisInfoMapT = DenseMap<FunctionOpInterface, AxisInfo>;
 
-  void visitOperation(Operation *op,
-                      ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
-                      ArrayRef<dataflow::Lattice<AxisInfo> *> results) override;
+  LogicalResult visitOperation(
+      Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
+      ArrayRef<dataflow::Lattice<AxisInfo> *> results) override;
   void visitForOpInductionVar(scf::ForOp op,
                               ArrayRef<dataflow::Lattice<AxisInfo> *> argLattices);
@@ -1039,7 +1039,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
   visitors.append();
 }
 
-void AxisInfoAnalysis::visitOperation(
+LogicalResult AxisInfoAnalysis::visitOperation(
     Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
     ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
   // TODO: For sure not the right way to do this
@@ -1048,8 +1048,10 @@
     if (op->getValue().getRank() == 0)
       setToEntryState((dataflow::Lattice<AxisInfo> *)op);
   AxisInfo curr = visitors.apply(op, operands);
-  if (curr.getRank() == 0)
-    return setAllToEntryStates(results);
+  if (curr.getRank() == 0) {
+    setAllToEntryStates(results);
+    return mlir::success();
+  }
   // override with hint
   auto newContiguity = curr.getContiguity();
   auto newDivisibility = curr.getDivisibility();
@@ -1071,6 +1073,7 @@
   // join all lattice elements
   for (auto *result : results)
     propagateIfChanged(result, result->join(curr));
+  return mlir::success();
 }
 
 void AxisInfoAnalysis::visitForOpInductionVar(
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 933f062d8191..0d3d9fd35b81 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -425,6 +425,7 @@ bool supportMFMATypes(Type a, Type b) {
   if (a.getIntOrFloatBitWidth() != b.getIntOrFloatBitWidth())
     return false;
 
+  auto F8E4M3FN = TypeID::get<Float8E4M3FNType>();
   auto F8E5M2 = TypeID::get<Float8E5M2Type>();
   auto F8E4M3FNUZ = TypeID::get<Float8E4M3FNUZType>();
   auto F8E5M2FNUZ = TypeID::get<Float8E5M2FNUZType>();
@@ -436,6 +437,7 @@ bool supportMFMATypes(Type a, Type b) {
       {F32, F32},
       {F16, F16},
       {BF16, BF16},
+      {F8E4M3FN, F8E4M3FN},
       {F8E5M2, F8E5M2},
       {F8E4M3FNUZ, F8E4M3FNUZ},
       {F8E4M3FNUZ, F8E5M2FNUZ},
@@ -495,14 +497,14 @@ bool supportMMA(triton::DotOp op, int version) {
       return false;
     if (!(numWarps % 4 == 0 &&
          retShapePerCTA[rank - 2] % 64 == 0 &&
          retShapePerCTA[rank - 1] % 8 == 0 &&
-         (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ() ||
+         (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
           aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
           aElemTy.isF32()))) {
      return false;
    }
    // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
    if (op.getMaxNumImpreciseAcc() < 32 &&
-       (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ()) &&
+       (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) &&
        cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
      return false;
    }
diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
index 0287207be51a..7aca67d97e90 100644
--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -40,7 +40,8 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
   auto ouEltTy = ouTensorTy.getElementType();
   if (inBitWidth == ouBitWidth)
     return values;
-  if (inBitWidth == 16 && ouBitWidth == 32) {
+  if ((inBitWidth == 16 && ouBitWidth == 32) ||
+      (inBitWidth == 32 && ouBitWidth == 16)) {
     SmallVector<Value> ret;
     for (unsigned i = 0; i < values.size(); i += 8) {
       ret.push_back(values[i]);
diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
index 53705c3b78b9..c0371cfe1d1b 100644
--- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
@@ -34,6 +34,9 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
   addConversion([&](mlir::Float8E4M3FNUZType type) -> std::optional<Type> {
     return IntegerType::get(type.getContext(), 8);
   });
+  addConversion([&](mlir::Float8E4M3FNType type) -> std::optional<Type> {
+    return IntegerType::get(type.getContext(), 8);
+  });
   addConversion([&](mlir::Float8E5M2Type type) -> std::optional<Type> {
     return IntegerType::get(type.getContext(), 8);
   });
diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
index d0988be3bc95..dab9f3fdf068 100644
--- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
@@ -87,8 +87,9 @@ struct ArithConstantSplatOpConversion
     // LLVM IR.
     if (type::isFloat8(elemType))
       elemType = rewriter.getIntegerType(8);
-    auto constOp = rewriter.create<LLVM::ConstantOp>(loc, elemType, val);
     auto typeConverter = getTypeConverter();
+    auto constOp = rewriter.create<LLVM::ConstantOp>(
+        loc, typeConverter->convertType(elemType), val);
     auto llStruct = SplatOpConversion::convertSplatLikeOp(
         elemType, op.getType(), constOp, typeConverter, rewriter, loc);
     rewriter.replaceOp(op, llStruct);
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index d5b5d459a910..6620e92d3c9b 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2721,6 +2721,11 @@ struct CanonicalizeConvertFromAlloc
     auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
     if (!convert)
       return failure();
+    // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+    // to SharedEncoding, so we want to keep this layout conversion.
+    if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
+            convert.getSrc().getType().getEncoding()))
+      return failure();
     rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
         op, op->getResult(0).getType(), convert.getSrc());
     return mlir::success();
diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index 39c043695bc6..72f73682bd00 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -153,6 +153,21 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
+
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding.
+  if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+          argType.getEncoding())) {
+    // Create a layout conversion from DotOperandEncoding to BlockedEncoding
+    // then pass it to the LocalAllocOp.
+    auto newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+    auto dotOperandToBlockedCvt =
+        rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+    return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                         dotOperandToBlockedCvt);
+  }
+
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }
 
@@ -162,6 +177,15 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
   mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;
 
   static bool bwdFilter(Operation *op) {
+    // Dot operand layout assignment for predicates (i1) is not currently
+    // supported during lowering from TritonGPU to LLVM in Triton for MMA
+    // cases. This condition hides the original bit-width so that predicates
+    // are not considered, and hence kWidth can never be 32.
+    if (isa<arith::UIToFPOp>(op)) {
+      Type srcType = getElementTypeOrSelf(op->getOperand(0));
+      if (srcType.isInteger(1))
+        return false;
+    }
     return op->getNumOperands() == 1 &&
            (isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
             isPureUnaryInlineAsm(op) ||
@@ -357,7 +381,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
   NvidiaMmaEncodingAttr mmaLayout =
       dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
   if (mmaLayout) {
-    bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ();
+    bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
     // promote operands for sm < 89 since fp8 mma is not natively supported
     // promote operands for sm >= 90 when mma is not v3
     if (!isNativeFP8 ||
diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
index 6d8279795209..e6e0ec8d7cef 100644
--- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -111,7 +111,8 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
                                 PatternRewriter &rewriter) const override {
     // Only consider conversions to dot operand.
     auto cvtTy = cast<RankedTensorType>(cvt.getType());
-    if (!isa<DotOperandEncodingAttr>(cvtTy.getEncoding()))
+    auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
+    if (!dotOpEnc)
       return failure();
 
     auto src = cvt.getSrc().getDefiningOp();
@@ -126,6 +127,12 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
                 [](Type ty) { return isa<RankedTensorType>(ty); }))
       return failure();
 
+    // Quick handling to fix loading issues: the computation of the original
+    // bit-width cannot realize that there is a mixed-precision dot (hence
+    // kWidth = 1) but still wants to hoist through the type conversion.
+    if (isa<arith::ExtFOp>(src) && dotOpEnc.getKWidth() == 1)
+      return failure();
+
     // Only consider custom conversions or arith ops.
     // TODO(jlebar): Is this too restrictive?
     if (!isa<FpToFpOp, BitcastOp>(src) && !isPureUnaryInlineAsm(src) &&
@@ -138,6 +145,14 @@
     if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(src))
       return failure();
 
+    // Don't hoist through u1 -> fp casts as they aren't supported in
+    // ElementwiseOpToLLVM::reorderValues().
+    if (isa<arith::UIToFPOp>(src)) {
+      Type srcType = getElementTypeOrSelf(src->getOperand(0));
+      if (srcType.isInteger(1))
+        return failure();
+    }
+
     // Check that the conversion is transitively dependent on a load, and all
     // operations between the load and the conversion are layout preserving.
     //
diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
index 02994e1ac059..cd6fc806928d 100644
--- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
@@ -140,8 +140,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
                                    type.getMemorySpace()),
       v, offsetsVal);
 
+  // We need to assign kwidth to zero in the case where the parent layout is
+  // Blocked, otherwise the verifier emits a failure. The parent layout is
+  // Blocked only when Tensor Cores are disabled.
+  int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
+                   ? 0
+                   : prefetchWidth / 8;
   auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
-      builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
+      builder.getContext(), opIdx, dotEncoding, kwidth);
   Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>(
       v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
       newSmem);
@@ -187,6 +193,15 @@ LogicalResult Prefetcher::initialize() {
       break;
     if (!op->getResult(0).hasOneUse())
       break;
+    // Similar to the issues faced in the HoistLayoutConversion pattern in
+    // OptimizeDotOperands.cpp, we can't propagate through type casts from
+    // predicates, as they aren't supported in Triton when encoded with the
+    // dot_op layout.
+    if (isa<arith::UIToFPOp>(op)) {
+      Type srcType = getElementTypeOrSelf(op->getOperand(0));
+      if (srcType.isInteger(1))
+        break;
+    }
     rets.push_back(op->getOperand(0));
     if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
       foundConvertFromShared = true;
diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
index b6d855a05388..984b8e6eb058 100644
--- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
@@ -138,11 +138,6 @@ class LayoutRematerialization {
                        ConvertLayoutOp convertOp);
 
 private:
-  void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
-  // Existing tuples of (value, layout) that needs to be updated when recreating
-  // scf ops. This prevents keeping track of Values that have been delete when
-  // rewriting slices.
-  DenseMap<Value, Attribute> mappedValues;
   // map of the values remat based on encoding.
   DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
   // DenseMap<std::pair<Operation*, Attribute>, Operation*>
@@ -154,7 +149,6 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
                                             Value newV) {
   LDBG("addRematValue " << old << " encoding " << encoding << " " << newV);
   rematMapping[{old, encoding}] = newV;
-  mappedValues[old] = encoding;
 }
 
 // Remove unneeded values now that we are done with the rematMapping.
@@ -813,31 +807,6 @@ bool canBeRemat(Operation *op) {
   return true;
 }
 
-void LayoutRematerialization::updateRematMapping(
-    SmallVector<std::tuple<Value, Value>> &values) {
-  for (auto [old, newV] : values) {
-    auto it = mappedValues.find(old);
-    if (it != mappedValues.end()) {
-      Attribute encoding = it->second;
-      auto rematIt = rematMapping.find({old, it->second});
-      assert(rematIt != rematMapping.end());
-      Value replacedValue = rematIt->second;
-      rematMapping.erase(rematIt);
-      mappedValues.erase(it);
-      // Loop through the replacement value to find the new version of remat
-      // value. This should be okay as the number of values should be small.
-      for (auto [before, after] : values) {
-        if (before == replacedValue) {
-          replacedValue = after;
-          break;
-        }
-      }
-      rematMapping[{newV, encoding}] = replacedValue;
-      mappedValues[newV] = encoding;
-    }
-  }
-}
-
 void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
                                            DenseMap<Value, Attribute> &layout,
                                            ConvertLayoutOp convertOp,
@@ -850,14 +819,6 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
   // for/yield to fall out of sync
   SetVector<Value> valuesWithExistingRemat;
   for (Value v : slice) {
-    auto layoutIt = layout.find(v);
-    assert(layoutIt != layout.end());
-    // If we already have a remat value for this value, use it.
-    if (hasRematValue(v, layoutIt->second)) {
-      mapping.map(v, getRematValue(v, layoutIt->second));
-      valuesWithExistingRemat.insert(v);
-      continue;
-    }
     if (v.getDefiningOp()) {
       opsToRewrite.insert(v.getDefiningOp());
       if (auto ifOp = v.getDefiningOp<scf::IfOp>()) {
@@ -947,8 +908,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
       if (slice.count(res)) {
         // Why can't we use res instead of ifOp.getResult(oldIdx)?
         mapping.map(ifOp.getResult(oldIdx), newIfOp.getResult(newIdx));
-        addRematValue(ifOp.getResult(oldIdx), layout[res],
-                      newIfOp.getResult(newIdx));
+        addRematValue(res, layout[res], newIfOp.getResult(newIdx));
         ++newIdx;
       }
       ++oldIdx;
@@ -979,8 +939,6 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
       auto cvt = builder.create<ConvertLayoutOp>(op->getLoc(), newType,
                                                  newOp->getResult(0));
       mapping.map(op->getResult(0), cvt.getResult());
-      addRematValue(op->getResult(0), layout[op->getResult(0)],
-                    cvt.getResult());
       continue;
     }
     Operation *newOp = builder.clone(*op, mapping);
@@ -992,14 +950,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
           cast<RankedTensorType>(old.getType()).getShape(),
           cast<RankedTensorType>(old.getType()).getElementType(), it->second);
       newV.setType(newType);
-      addRematValue(old, it->second, newV);
     }
   }
   // Check mapping and see if there are existing convertOps on the old Argument
   convertOp.replaceAllUsesWith(mapping.lookup(convertOp.getSrc()));
   opToDelete.insert(convertOp);
-  updateRematMapping(replacements);
   for (auto &kv : replacements) {
     builder.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv));
   }
diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
index 57e41e55ff4f..6f4480d6fe93 100644
--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -45,8 +45,9 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
   SmallVector<unsigned> validN;
 
   // MMAv3 with larger instruction shape is preferred.
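+  // validN enumerates the N dimensions accepted by wgmma (every multiple of 8
+  // from 256 down to 8); keeping the list sorted largest-first lets the
+  // selection below prefer the widest instruction shape.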
- if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FNUZ() || - eltType.isF16() || eltType.isBF16() || eltType.isF32()) { + if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FN() || + eltType.isFloat8E4M3FNUZ() || eltType.isF16() || eltType.isBF16() || + eltType.isF32()) { validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, 168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8}); diff --git a/python/BUILD b/python/BUILD new file mode 100644 index 000000000000..334dd4aec41a --- /dev/null +++ b/python/BUILD @@ -0,0 +1,77 @@ +# NOTE: Do not depend on any targets from this directory, +# but use //third_party/py/triton instead. + +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//third_party/py/triton:__pkg__", + "@triton//python:__subpackages__", + ], +) + +cc_library( + name = "passes", + hdrs = ["src/passes.h"], + includes = ["src"], + visibility = ["@triton//third_party:__subpackages__"], +) + +pybind_extension( + name = "libtriton", + srcs = [ + "src/interpreter.cc", + "src/ir.cc", + "src/llvm.cc", + "src/main.cc", + "src/passes.cc", + ], + copts = ["-DTRITON_BACKENDS_TUPLE=(nvidia)"], + deps = [ + ":passes", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IPO", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:InstCombine", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMIRTransforms", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonGPUTransforms", + "//:TritonHSACO", + "//:TritonLLVMIR", + "//:TritonNvidiaGPUTransforms", + "//:TritonPTX", + "//:TritonToTritonGPU", + "//:TritonTools", + "//:TritonTransforms", + "@triton//third_party/nvidia:triton_nvidia", + ], +) + +filegroup( + name = "files", + srcs = glob( + include = ["triton/**/*.py"], + ), +) diff --git a/python/src/ir.cc b/python/src/ir.cc index f018d3d7bb52..7575cf87a5de 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -747,10 +747,8 @@ void init_triton_ir(py::module &&m) { return self.getBuilder().getI64Type(); }) .def("get_fp8e4nv_ty", - // TODO: fp8e4nv is using Float8E4M3FNUZType, which - // does not seem right. 
It should use FloatE4M3FNType
        [](TritonOpBuilder &self) -> Type {
-         return self.getBuilder().getType<mlir::Float8E4M3FNUZType>();
+         return self.getBuilder().getType<mlir::Float8E4M3FNType>();
        })
       .def("get_fp8e4b8_ty",
            [](TritonOpBuilder &self) -> Type {
diff --git a/python/test/regression/BUILD b/python/test/regression/BUILD
new file mode 100644
index 000000000000..a88f4eeae1f8
--- /dev/null
+++ b/python/test/regression/BUILD
@@ -0,0 +1,26 @@
+load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests")
+
+package(
+    default_applicable_licenses = ["//:license"],
+)
+
+pytest_multi_tests(
+    name = "tests",
+    size = "large",
+    srcs = ["conftest.py"],
+    shard_count = 10,
+    tags = [
+        "config-cuda-only",
+        "requires-gpu-sm80",
+    ],
+    tests = glob(
+        include = ["test_*.py"],
+        exclude = [
+            "test_performance.py",  # TODO(b/321005767): fix failing test
+        ],
+    ),
+    deps = [
+        "//third_party/py/torch:pytorch",
+        "//third_party/py/triton",
+    ],
+)
diff --git a/python/test/regression/conftest.py b/python/test/regression/conftest.py
new file mode 100644
index 000000000000..7a02d322b49f
--- /dev/null
+++ b/python/test/regression/conftest.py
@@ -0,0 +1,12 @@
+# content of conftest.py
+
+import pytest
+
+
+def pytest_addoption(parser):
+    parser.addoption("--device", action="store", default='cuda')
+
+
+@pytest.fixture
+def device(request):
+    return request.config.getoption("--device")
diff --git a/python/test/unit/BUILD b/python/test/unit/BUILD
new file mode 100644
index 000000000000..6997c23e0635
--- /dev/null
+++ b/python/test/unit/BUILD
@@ -0,0 +1,179 @@
+load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests", "pytest_test")
+
+package(
+    default_applicable_licenses = ["//:license"],
+)
+
+_requires_gpu_sm80 = [
+    "config-cuda-only",
+    "requires-gpu-sm80",
+]
+
+_requires_config_cuda = select(
+    {"@local_config_cuda//cuda:using_clang_allow_exec": []},
+    no_match_error = "Requires --config=cuda",
+)
+
+EXCLUDE_TESTS = [
+    "language/test_pipeliner_h100.py",  # TODO(b/362458006): fix failing test
+    "language/test_reproducer.py",  # this is not an actual test, but a tool for running reproducers
+    "language/test_subprocess.py",  # TODO(b/320224484): fix failing test
+    "runtime/test_launch.py",  # TODO(b/320226169): fix failing tests
+    "tools/test_aot.py",  # TODO(b/320224484): fix failing test
+    "tools/test_disasm.py",  # TODO(b/320224484): fix failing test
+    "hopper/test_experimental_tma_h100.py",  # TODO(b/362458006): fix failing test
+    "hopper/test_persistent_warp_specialized_gemm.py",  # TODO(b/342348738): fix failing test
+    "hopper/test_tma_descriptor.py",  # TODO(b/358060133): fix failing test
+    "runtime/test_cublas.py",  # TODO(b/346755023): fix failing test
+]
+
+# Runs all python tests on H100
+pytest_multi_tests(
+    name = "hopper",
+    size = "large",
+    srcs = [
+        "conftest.py",
+        "language/conftest.py",
+        "language/test_core.py",
+    ],
+    name_suffix = "_h100",
+    shard_count = 10,
+    tags = [
+        "config-cuda-only",
+        "requires-gpu-sm90",
+    ],
+    target_compatible_with = _requires_config_cuda,
+    tests = glob(
+        include = ["**/test_*.py"],
+        exclude = EXCLUDE_TESTS + ["language/test_core.py"],
+    ),
+    deps = [
+        "//third_party/py/torch:pytorch",
+        "//third_party/py/triton",
+    ],
+)
+
+# Shard test_core more, as it is otherwise very slow to run.
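+# Bazel runs each shard as a separate test action with TEST_SHARD_INDEX and
+# TEST_TOTAL_SHARDS set in the environment; the pytest wrapper is expected to
+# partition the collected tests across shards, so a higher shard_count trades
+# per-shard collection overhead for wall-clock parallelism.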
+pytest_test( + name = "hopper/language/test_core_h100", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + ], + shard_count = 40, + tags = [ + "config-cuda-only", + "requires-gpu-sm90", + ], + target_compatible_with = _requires_config_cuda, + tests = ["language/test_core.py"], + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "language", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + "language/test_core.py", + ], + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["language/**/test_*.py"], + exclude = EXCLUDE_TESTS + ["language/test_core.py"], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +# Shard test_core more, as it is otherwise very slow to run. +pytest_test( + name = "language/test_core", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + ], + shard_count = 40, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = ["language/test_core.py"], + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "instrumentation", + size = "large", + srcs = ["conftest.py"], + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["instrumentation/**/test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "runtime", + srcs = ["conftest.py"], + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["runtime/**/test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "tools", + size = "large", + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["tools/**/test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "unit", + size = "large", + srcs = ["conftest.py"], + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 9e5ff8a2ce37..60ad26871a53 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2144,6 +2144,8 @@ def kernel(X, Z, BLOCK: tl.constexpr): reduce_bool = [(op, 'bool', shape, axis, False) for op in ['xor_sum'] for shape in reduce2d_shapes for axis in [0, 1]] +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] >= 9, + reason='Reduction test produces wrong results on H100, b/342347027') @pytest.mark.interpreter @pytest.mark.parametrize( "op, dtype_str, shape, axis, keep_dims", reduce_configs1 + reduce_configs2 + reduce_configs3 + invalid_config + @@ -3637,6 +3639,25 @@ def _kernel(out): kernel[(1, )](out) assert torch.all(out == out_ref) +@pytest.mark.interpreter +def test_dot_on_broadcast(device): + @triton.jit + def _kernel(a, b, out): + a_offsets = tl.arange(0, 64)[:, None] * 32 + tl.arange(0, 
32)[None, :] + lhs = tl.load(a + a_offsets, mask=a_offsets < 32 * 64) + rhs = tl.load(b) + rhs_bc = tl.broadcast_to(rhs, [32, 32]) + c = tl.dot(lhs, rhs_bc) + out_ptr = out + tl.arange(0, 64)[:, None] * 32 + tl.arange(0, 32)[None, :] + tl.store(out_ptr, c) + + a = torch.ones((64, 32), dtype=getattr(torch, 'float32'), device=device) + b = torch.tensor([1.0], dtype=getattr(torch, 'float32'), device=device) + out_ref = torch.matmul(a, torch.broadcast_to(b, (32, 32))) + out = torch.zeros((64, 32), dtype=getattr(torch, 'float32'), device=device) + _kernel[(1, )](a, b, out, num_stages=1, num_warps=4) + assert torch.all(out == out_ref) + # --------------- # test arange diff --git a/python/triton/_C/include b/python/triton/_C/include index b85a409837d1..8a5dba6c4b56 120000 --- a/python/triton/_C/include +++ b/python/triton/_C/include @@ -1 +1 @@ -../../../include/ \ No newline at end of file +../../../include \ No newline at end of file diff --git a/python/triton/backends/__init__.py b/python/triton/backends/__init__.py index 92ba144ba97b..f9bab523bf6c 100644 --- a/python/triton/backends/__init__.py +++ b/python/triton/backends/__init__.py @@ -46,5 +46,8 @@ def _discover_backends(): _find_concrete_subclasses(driver, DriverBase)) return backends - -backends = _discover_backends() +from triton.backends.nvidia.driver import CudaDriver +from triton.backends.nvidia.compiler import CUDABackend +backends = { + "nvidia": Backend(CUDABackend, CudaDriver) +} diff --git a/test/BUILD b/test/BUILD new file mode 100644 index 000000000000..ca98d32da7b0 --- /dev/null +++ b/test/BUILD @@ -0,0 +1,63 @@ +# copybara:uncomment_begin +# load("//third_party/llvm/build_defs:lit.bzl", "glob_lit_tests") +# load("//tools/build_defs/build_test:build_test.bzl", "build_test") +# +# package( +# default_applicable_licenses = ["//:license"], +# default_compatible_with = ["//buildenv/target:gce"], +# default_visibility = ["//:__subpackages__"], +# ) +# +# glob_lit_tests( +# name = "all_tests", +# data = [ +# "@llvm-project//llvm:FileCheck", +# "//:triton-llvm-opt", +# "//:triton-opt", +# "//:triton-tensor-layout", +# ], +# driver = "@llvm-project//mlir:run_lit.sh", +# exclude = [ +# "Conversion/amd/dedup-by-constancy.mlir", # AMD-specific, broken +# "TritonGPU/combine.mlir", # TODO: b/338346821 - needs cse or something. 
+# "TritonGPU/dot-operands.mlir", # TODO: b/283035396 - broken by cl536931041.patch +# "TritonGPU/optimize_epilogue.mlir", # TODO: b/346283526 - AMD-specific, triggering UBSAN +# ], +# test_file_exts = [ +# "mlir", +# "ll", +# ], +# ) +# +# build_test( +# name = "build_test", +# allow_empty_target = False, +# targets = [ +# "//:TritonAnalysis", +# "//:TritonDialects", +# "//:TritonGPUToLLVM", +# "//:TritonGPUTransforms", +# "//:TritonLLVMIR", +# "//:TritonPTX", +# "//:TritonToTritonGPU", +# "//:TritonTools", +# "//:TritonTransforms", +# "//:triton-opt", +# ], +# ) +# copybara:uncomment_end + +cc_library( + name = "TritonTestAnalysis", + srcs = glob(["lib/Analysis/*.cpp"]), + deps = [ + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + ], +) diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 7ecee2eba11b..3c16ea0260ef 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -129,24 +129,24 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: test_fp8_to_f16_conversion tt.func @test_fp8_to_f16_conversion( - %in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FNUZ, #blocked>, + %in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FN, #blocked>, %in2: tensor<128xf16, #blocked>, %in3: tensor<128xf32, #blocked>) { // CHECK-COUNT-2: cvt.rn.f16x2.e5m2x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16> %out0 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xf16, #blocked> // CHECK-COUNT-2: cvt.rn.f16x2.e4m3x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16> - %out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FNUZ, #blocked> -> tensor<128xf16, #blocked> + %out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FN, #blocked> -> tensor<128xf16, #blocked> // CHECK-COUNT-2: mul.rn.bf16x2 %out2 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xbf16, #blocked> // CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8> %out3 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E5M2, #blocked> // CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8> - %out4 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked> + %out4 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FN, #blocked> // CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8> %out5 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E5M2, #blocked> // CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8> - %out6 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked> + %out6 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FN, #blocked> tt.return } } diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 923324616940..50f39c1bf970 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ 
b/test/TritonGPU/accelerate-matmul.mlir @@ -142,3 +142,21 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : tt.return } } + +// ----- + +// CHECK-DAG: #[[$BLOCKED:.*]] = #triton_gpu.blocked +// CHECK-DAG: #mma = #triton_gpu.nvidia_mma<{versionMajor = 3 +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @local_alloc_dot_operand(%in0: tensor<64x256xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> {tt.divisibility = 16 : i32}, %in1: f32, %in2: tensor<64x32xf32, #blocked>) -> (tensor<64x32xf32, #blocked>) { + // CHECK-LABEL: local_alloc_dot_operand + %splat_in1 = tt.splat %in1 : f32 -> tensor<256x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK: %[[LHS_LOCAL_ALLOC:.*]] = triton_gpu.local_alloc + // CHECK: %[[RHS_CVT:.*]] = triton_gpu.convert_layout {{.*}} #triton_gpu.dot_op<{{.*}}> -> {{.*}} #[[$BLOCKED]] + // CHECK: %[[RHS_LOCAL_ALLOC:.*]] = triton_gpu.local_alloc %[[RHS_CVT]] + // CHECK: triton_nvidia_gpu.warp_group_dot %[[LHS_LOCAL_ALLOC]], %[[RHS_LOCAL_ALLOC]] + %res = tt.dot %in0, %splat_in1, %in2, inputPrecision = tf32 : tensor<64x256xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<256x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x32xf32, #blocked> + tt.return %res : tensor<64x32xf32, #blocked> + } +} diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir index ecee359cb19a..f015f9651065 100644 --- a/test/TritonGPU/canonicalize.mlir +++ b/test/TritonGPU/canonicalize.mlir @@ -133,3 +133,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : tt.return %2 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> } } // end module + +// ----- + +// CHECK: #[[$BLOCKED:.*]] = #triton_gpu.blocked +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @cvt_from_dot_op_into_local_allow_not_canonicalized(%in: tensor<256x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> !tt.memdesc<256x32xf32, #shared1> { + // CHECK-LABEL: cvt_from_dot_op_into_local_allow_not_canonicalized + %cvt_in = triton_gpu.convert_layout %in : tensor<256x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<256x32xf32, #blocked> + %alloc = triton_gpu.local_alloc %cvt_in : (tensor<256x32xf32, #blocked>) -> !tt.memdesc<256x32xf32, #shared1> + // CHECK: %[[ALLOC:.*]] = triton_gpu.local_alloc {{.*}} (tensor<{{.*}}, #[[$BLOCKED]]{{.*}}>) -> + tt.return %alloc : !tt.memdesc<256x32xf32, #shared1> + } +} // end module + diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir index cf93c37b840d..16550b98259c 100644 --- a/test/TritonGPU/coalesce.mlir +++ b/test/TritonGPU/coalesce.mlir @@ -131,3 +131,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war tt.return } } + +// ----- + + +module { + tt.func @int_min_does_not_underflow_in_analysis() -> i64 { + %int_min = arith.constant -9223372036854775808 : i64 + 
tt.return %int_min : i64 + } +} diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 1e47e1449f0f..04281fe64d14 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -47,11 +47,14 @@ tt.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> { %5 = triton_gpu.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> %6 = arith.addi %3, %5 : tensor<1024xi32, #layout1> tt.return %6: tensor<1024xi32, #layout1> - // CHECK: %[[A:.+]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]> - // CHECK: %[[B:.+]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]> - // CHECK: %[[C:.+]] = arith.muli %[[A]], %[[B]] : tensor<1024xi32, [[$target_layout]]> - // CHECK: %[[D:.+]] = arith.addi %[[C]], %[[C]] : tensor<1024xi32, [[$target_layout]]> - // CHECK: tt.return %[[D]] : tensor<1024xi32, [[$target_layout]]> + // CHECK: %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]> + // CHECK: %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]> + // CHECK: %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]> + // CHECK: %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]> + // CHECK: %4 = arith.muli %0, %2 : tensor<1024xi32, [[$target_layout]]> + // CHECK: %5 = arith.muli %1, %3 : tensor<1024xi32, [[$target_layout]]> + // CHECK: %6 = arith.addi %4, %5 : tensor<1024xi32, [[$target_layout]]> + // CHECK: tt.return %6 : tensor<1024xi32, [[$target_layout]]> } // Always rematerialize single value loads diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir index 3a9e80c04b0a..848e2b88f342 100644 --- a/test/TritonGPU/prefetch.mlir +++ b/test/TritonGPU/prefetch.mlir @@ -173,3 +173,22 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +// CHECK: tt.func @matmul_loop_on_blocked_layout +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @matmul_loop_on_blocked_layout(%arg_lhs: !tt.memdesc<16x512xf32, #shared, mutable>, %arg_rhs: !tt.memdesc<512x32xf32, #shared, mutable>, %arg_init: tensor<16x32xf32, #blocked>, %itr_val : i32) -> (tensor<16x32xf32, #blocked>) { + %loop:3 = scf.for %itr = %itr_val to %itr_val step %itr_val iter_args(%init = %arg_init, %lhs = %arg_lhs, %rhs = %arg_rhs) -> (tensor<16x32xf32, #blocked>, !tt.memdesc<16x512xf32, #shared, mutable>, !tt.memdesc<512x32xf32, #shared, mutable>) : i32 { + %lhs_ll = triton_gpu.local_load %lhs : !tt.memdesc<16x512xf32, #shared, mutable> -> tensor<16x512xf32, #blocked> + %lhs_ll_cvt = triton_gpu.convert_layout %lhs_ll : tensor<16x512xf32, #blocked> -> tensor<16x512xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %rhs_ll = triton_gpu.local_load %rhs : !tt.memdesc<512x32xf32, #shared, mutable> -> tensor<512x32xf32, #blocked> + %rhs_ll_cvt = triton_gpu.convert_layout %rhs_ll : tensor<512x32xf32, #blocked> -> tensor<512x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + %res = tt.dot %lhs_ll_cvt, %rhs_ll_cvt, %init : 
tensor<16x512xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<512x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x32xf32, #blocked> + scf.yield %res, %lhs, %rhs : tensor<16x32xf32, #blocked>, !tt.memdesc<16x512xf32, #shared, mutable>, !tt.memdesc<512x32xf32, #shared, mutable> + } + tt.return %loop#0 : tensor<16x32xf32, #blocked> + } +} // end module diff --git a/third_party/amd/BUILD b/third_party/amd/BUILD new file mode 100644 index 000000000000..d32c944f37ed --- /dev/null +++ b/third_party/amd/BUILD @@ -0,0 +1,142 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:gce"], + # default_visibility = [ + # "//third_party/tensorflow/compiler/xla/service/gpu/fusions/triton:__subpackages__", + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +# TODO(csigg): fix, enable error upstream, remove. +_no_unused_variable = select({ + "//:compiler_is_msvc": [], + "//conditions:default": ["-Wno-unused-variable"], +}) + +cc_library( + name = "TritonAMDGPUTransforms", + srcs = glob([ + "lib/TritonAMDGPUTransforms/**/*.h", + "lib/TritonAMDGPUTransforms/**/*.cpp", + ]) + ["include/TritonAMDGPUToLLVM/TargetUtils.h"], + hdrs = glob([ + "include/TritonAMDGPUTransforms/**/*.h", + ]), + copts = _no_unused_variable, + includes = [ + "include", + "lib/TritonAMDGPUTransforms", + ], + deps = [ + ":triton_conversion_amdgpu_transforms_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ConvertToLLVM", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonGPUTransforms", + ], +) + +cc_library( + name = "TritonAMDGPUToLLVM", + srcs = glob([ + "lib/TritonAMDGPUToLLVM/**/*.h", + "lib/TritonAMDGPUToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/TritonAMDGPUToLLVM/**/*.h", + ]), + copts = _no_unused_variable, + includes = [ + "include", + "lib/TritonAMDGPUToLLVM", + ], + deps = [ + ":TritonAMDGPUTransforms", + ":triton_conversion_amdgpu_to_llvm_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:ConvertToLLVM", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:GPUToROCDLTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLDialect", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + 
"//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + ], +) + +td_library( + name = "td_files", + srcs = glob(["include/**/*.td"]), + includes = ["include"], + deps = ["//:td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_amdgpu_to_llvm_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonAMDGPUToLLVM", + ], + "include/TritonAMDGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonAMDGPUToLLVM/Passes.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_amdgpu_transforms_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonAMDGPU", + ], + "include/TritonAMDGPUTransforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonAMDGPUTransforms/Passes.td", + deps = [":td_files"], +) diff --git a/third_party/f2reduce/BUILD b/third_party/f2reduce/BUILD new file mode 100644 index 000000000000..a1a4f3a7ca02 --- /dev/null +++ b/third_party/f2reduce/BUILD @@ -0,0 +1,31 @@ +# copybara:uncomment load("//tools/build_defs/license:license.bzl", "license") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:gce"], + # default_visibility = [ + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +# copybara:uncomment_begin +# license( +# name = "license", +# license_text = "LICENCE.txt", +# ) +# +# licenses(["notice"]) +# +# exports_files(["LICENCE.txt"]) +# copybara:uncomment_end + +cc_library( + name = "f2reduce", + srcs = ["f2reduce.cpp"], + hdrs = ["f2reduce.h"], + # copybara:uncomment strip_include_prefix = "/third_party/triton", +) diff --git a/third_party/nvidia/BUILD b/third_party/nvidia/BUILD new file mode 100644 index 000000000000..6af127c11ec6 --- /dev/null +++ b/third_party/nvidia/BUILD @@ -0,0 +1,306 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@pybind11_bazel//:build_defs.bzl", "pybind_library") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:gce"], + # default_visibility = [ + # "//third_party/tensorflow/compiler/xla/service/gpu:__subpackages__", + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +pybind_library( + name = "cublas_headers", + hdrs = glob([ + "include/*.h", + ]), + deps = ["@local_config_cuda//cuda:cuda_headers"], +) + +pybind_library( + name = "triton_nvidia", + srcs = [ + "triton_nvidia.cc", + ], + compatible_with = [], + # copybara:uncomment_begin + # visibility = [ + # "@triton//python:__subpackages__", + # ], + # copybara:uncomment_end + deps = [ + ":NVGPUDialect", + ":NVGPUToLLVM", + ":TritonNVIDIAGPUToLLVM", + ":cublas_headers", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonNvidiaGPUTransforms", + "@triton//python:passes", + ], +) + +cc_library( + name = "NVGPUToLLVM", + srcs = glob([ + "lib/NVGPUToLLVM/*.cpp", + ]), + hdrs = glob([ + "include/NVGPUToLLVM/*.h", + ]), + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:gce"], + # 
copybara:uncomment_end + copts = select({ + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + includes = [ + "..", + "include", + ], + deps = [ + ":NVGPUDialect", + ":TritonNVIDIAGPUToLLVM", + ":triton_conversion_nvgpu_to_llvm_passes_inc_gen", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + ], +) + +cc_library( + name = "TritonNVIDIAGPUToLLVM", + srcs = glob([ + "lib/TritonNVIDIAGPUToLLVM/*.h", + "lib/TritonNVIDIAGPUToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/TritonNVIDIAGPUToLLVM/*.h", + ]) + [ + "lib/TritonNVIDIAGPUToLLVM/Utility.h", + ], + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:gce"], + # copybara:uncomment_end + copts = select({ + "//conditions:default": [ + "-Wno-reorder-ctor", + "-Wno-unused-variable", + ], + }), + includes = [ + "..", + "include", + "lib/TritonNVIDIAGPUToLLVM", + ], + deps = [ + ":NVGPUDialect", + ":triton_conversion_triton_nvidia_gpu_to_llvm_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:triton_gpu_attr_inc_gen", + ], +) + +gentbl_cc_library( + name = "triton_conversion_nvgpu_to_llvm_passes_inc_gen", + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:gce"], + # copybara:uncomment_end + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=NVGPUToLLVM", + ], + "include/NVGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/NVGPUToLLVM/Passes.td", + deps = ["//:td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_nvidia_gpu_to_llvm_passes_inc_gen", + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:gce"], + # copybara:uncomment_end + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonNVIDIAGPUToLLVM", + ], + "include/TritonNVIDIAGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonNVIDIAGPUToLLVM/Passes.td", + deps = ["//:td_files"], +) + +td_library( + name = "td_files", + srcs = glob(["include/Dialect/NVGPU/IR/*.td"]), + includes = ["include"], + deps = [ + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:CastInterfacesTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LLVMOpsTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@llvm-project//mlir:ViewLikeInterfaceTdFiles", + ], +) + +gentbl_cc_library( + name = 
"nvgpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-llvmir-conversions"], + "include/Dialect/NVGPU/IR/OpsConversions.inc", + ), + ( + ["--gen-op-decls"], + "include/Dialect/NVGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/Dialect/NVGPU/IR/Ops.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/Dialect/NVGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/Dialect/NVGPU/IR/OpsEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/NVGPU/IR/NVGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "nvgpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/NVGPU/IR/NVGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "nvgpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/Dialect/NVGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/Dialect/NVGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/NVGPU/IR/NVGPUDialect.td", + deps = ["td_files"], +) + +cc_library( + name = "NVGPUDialect", + srcs = glob([ + "lib/Dialect/NVGPU/IR/*.cpp", + ]), + hdrs = glob([ + "include/Dialect/NVGPU/IR/*.h", + ]), + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + "-Wno-logical-op-parentheses", + ], + }), + includes = [ + "..", # because nvidia/include/Dialect/NVGPU/IR/Dialect.h.inc + "../..", # because third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h + "include", + ], + deps = [ + ":nvgpu_attr_inc_gen", + ":nvgpu_dialect_inc_gen", + ":nvgpu_ops_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + # The following is added to make Utility compile + "//:TritonTools", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/third_party/nvidia/backend/BUILD b/third_party/nvidia/backend/BUILD new file mode 100644 index 000000000000..a5b34aa5c29b --- /dev/null +++ b/third_party/nvidia/backend/BUILD @@ -0,0 +1,30 @@ +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//third_party/py/triton:__subpackages__", + ], +) + +pybind_extension( + name = "cuda_utils", + srcs = ["cuda_utils.cc"], + visibility = [ + "//learning/deepmind/jax/triton/ops:__subpackages__", + "//third_party/py/triton:__subpackages__", + ], + deps = [ + "//platforms/gpus/cuda/dynamic_libcuda", + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cuda_runtime", + "@llvm-project//llvm:Support", + ], +) + +filegroup( + name = "files", + srcs = glob( + include = ["**/*.py"], + ), +) diff --git a/third_party/nvidia/backend/driver.c 
b/third_party/nvidia/backend/driver.c index f476f5bffd73..ee36757bef2e 100644 --- a/third_party/nvidia/backend/driver.c +++ b/third_party/nvidia/backend/driver.c @@ -154,6 +154,7 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) { typedef CUresult (*cuOccupancyMaxActiveClusters_t)( int *numClusters, CUfunction func, const CUlaunchConfig *config); +#if CUDA_VERSION >= 12000 typedef CUresult (*cuTensorMapEncodeTiled_t)( CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, @@ -161,6 +162,7 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)( const cuuint32_t *elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill); +#endif #define defineGetFunctionHandle(name, symbolName) \ static symbolName##_t name() { \ @@ -187,8 +189,10 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)( defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle, cuOccupancyMaxActiveClusters); +#if CUDA_VERSION >= 12000 defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle, cuTensorMapEncodeTiled); +#endif static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1, @@ -280,6 +284,9 @@ static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) { // Simple helper to experiment creating TMA descriptors on the host. // This is a useful to test TMA operations independently. static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { +#if CUDA_VERSION < 12000 + return NULL; +#else unsigned long long global_address; uint64_t dim; uint32_t tensorDim; @@ -318,11 +325,15 @@ static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); return Py_None; +#endif } // Simple helper to experiment creating TMA descriptors on the host. // This is a useful to test TMA operations independently. static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { +#if CUDA_VERSION < 12000 + return NULL; +#else unsigned long long global_address; uint64_t dims[2]; uint32_t tensorDims[2]; @@ -380,6 +391,7 @@ static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); return Py_None; +#endif } static PyMethodDef ModuleMethods[] = { diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index b075ca31a407..b474996f11c2 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -291,10 +291,36 @@ class WGMMAWaitGroupOpPattern : public OpRewritePattern { Constraints getOutputConstraints(ttn::WGMMAWaitGroupOp op) const { auto outputStructType = cast(op.getType()); - uint32_t numOutputRegs = outputStructType.getBody().size(); - std::string output = - outputStructType.getBody().front().isF32() ? "=f" : "=r"; - return Constraints(numOutputRegs, output); + std::vector outputConstraints; + outputConstraints.reserve(outputStructType.getBody().size()); + for (mlir::Type type : outputStructType.getBody()) { + if (type.isF32()) { + outputConstraints.push_back("=f"); + continue; + } else if (type.isF64()) { + outputConstraints.push_back("=d"); + continue; + } + unsigned bitwidth = isa(type) ? 
+          64 : type.getIntOrFloatBitWidth();
+      switch (bitwidth) {
+      case 1:
+        outputConstraints.push_back("=b");
+        break;
+      case 16:
+        outputConstraints.push_back("=h");
+        break;
+      case 32:
+        outputConstraints.push_back("=r");
+        break;
+      case 64:
+        outputConstraints.push_back("=l");
+        break;
+      default:
+        assert(false && "unsupported bitwidth");
+      }
+    }
+    return outputConstraints;
   }
   OperandsAndConstraints
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
index d1086c189d33..86aa5c3c5ee6 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
@@ -496,6 +496,8 @@ Type getSharedMemTy(Type argType) {
     return type::f32Ty(ctx);
   else if (argType.getIntOrFloatBitWidth() == 8)
     return type::i8Ty(ctx);
+  else if (argType.getIntOrFloatBitWidth() == 16)
+    return type::i16Ty(ctx);
   else
     llvm::report_fatal_error("mma16816 data type not supported");
 }
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
index 2d16dc19b3b3..af897ef546dd 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
@@ -81,9 +81,9 @@ enum class TensorCoreType : uint8_t {
   FP32_TF32_TF32_FP32,
   FP16_FP16_FP16_FP16,
   FP32_FP8E5M2_FP8E5M2_FP32,
-  FP32_FP8E5M2_FP8E4M3FNUZ_FP32,
-  FP32_FP8E4M3FNUZ_FP8E5M2_FP32,
-  FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32,
+  FP32_FP8E5M2_FP8E4M3FN_FP32,
+  FP32_FP8E4M3FN_FP8E5M2_FP32,
+  FP32_FP8E4M3FN_FP8E4M3FN_FP32,
   // integer tensor core instr
   INT32_INT1_INT1_INT32, // Not implemented
   INT32_INT4_INT4_INT32, // Not implemented
@@ -112,9 +112,9 @@ Type getMmaRetType(TensorCoreType mmaType, MLIRContext *ctx) {
   case TensorCoreType::FP16_FP16_FP16_FP16:
     return fp16x2Pack2Ty;
   case TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32:
-  case TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32:
-  case TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32:
-  case TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32:
+  case TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32:
+  case TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32:
+  case TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32:
     return fp32x4Ty;
   case TensorCoreType::INT32_INT8_INT8_INT32:
     return i32x4Ty;
@@ -140,14 +140,14 @@ TensorCoreType getMmaType(triton::DotOp op) {
       bTy.getElementType().isFloat8E5M2())
     return TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32;
   if (aTy.getElementType().isFloat8E5M2() &&
-      bTy.getElementType().isFloat8E4M3FNUZ())
-    return TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32;
-  if (aTy.getElementType().isFloat8E4M3FNUZ() &&
+      bTy.getElementType().isFloat8E4M3FN())
+    return TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32;
+  if (aTy.getElementType().isFloat8E4M3FN() &&
       bTy.getElementType().isFloat8E5M2())
-    return TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32;
-  if (aTy.getElementType().isFloat8E4M3FNUZ() &&
-      bTy.getElementType().isFloat8E4M3FNUZ())
-    return TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32;
+    return TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32;
+  if (aTy.getElementType().isFloat8E4M3FN() &&
+      bTy.getElementType().isFloat8E4M3FN())
+    return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32;
   if (aTy.getElementType().isF32() && bTy.getElementType().isF32() &&
      op.getInputPrecision() == InputPrecision::TF32)
    return TensorCoreType::FP32_TF32_TF32_FP32;
@@ -193,11 +193,11 @@ inline static const std::map<TensorCoreType, std::string> mmaInstrPtxAmpere = {
     {TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32,
     "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32"},
-    {TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32,
+    {TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32,
     "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e4m3.f32"},
-    {TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32,
+    {TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32,
     "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32"},
-    {TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32,
+    {TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32,
     "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32"},
 };
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
index baed96a29704..41e36503f593 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
@@ -58,7 +58,7 @@ triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) {
     return triton::nvgpu::WGMMAEltType::s8;
   } else if (aTy.isFloat8E5M2()) {
     return triton::nvgpu::WGMMAEltType::e5m2;
-  } else if (aTy.isFloat8E4M3FNUZ()) {
+  } else if (aTy.isFloat8E4M3FN()) {
     return triton::nvgpu::WGMMAEltType::e4m3;
   } else {
     llvm::report_fatal_error("Unsupported mma operand type found");
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
index 2fabb598e99d..80072edf3bdf 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -386,7 +386,7 @@ struct FpToFpOpConversion
   std::pair<ConverterT, size_t>
   getConversionFunc(Type srcTy, Type dstTy,
                     std::optional<RoundingMode> roundingMode) const {
-    auto F8E4M3TyID = TypeID::get<mlir::Float8E4M3FNUZType>();
+    auto F8E4M3TyID = TypeID::get<mlir::Float8E4M3FNType>();
     auto F8E5M2TyID = TypeID::get<mlir::Float8E5M2Type>();
     auto F16TyID = TypeID::get<mlir::Float16Type>();
     auto BF16TyID = TypeID::get<mlir::BFloat16Type>();
@@ -430,7 +430,7 @@ struct FpToFpOpConversion
       llvm::report_fatal_error("Unsupported rounding mode for conversion.");
     }
     if (computeCapability < 89 &&
-        (srcTy.isFloat8E4M3FNUZ() || dstTy.isFloat8E4M3FNUZ())) {
+        (srcTy.isFloat8E4M3FN() || dstTy.isFloat8E4M3FN())) {
       llvm::errs() << "Conversion from/to f8e4m3nv is only supported on "
                       "compute capability >= 89"
                    << "\n";
@@ -452,7 +452,7 @@ struct FpToFpOpConversion
     auto dstElementType = getElementType(op.getResult());
     auto roundingMode = op.getRounding();
-    if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FNUZ()) {
+    if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) {
       assert(roundingMode.has_value() &&
             "Rounding mode must be specified for conversions to fp8");
@@ -489,7 +489,7 @@ struct FpToFpOpConversion
     bool useFP16IntermediateSrc =
         srcElementType.isF32() &&
-        (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FNUZ() ||
+        (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FN() ||
                                        dstElementType.isFloat8E5M2())) ||
         roundingMode.value() == RoundingMode::RTZ);
     bool isDstFP32 = dstElementType.isF32();
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
index 9ee532992d01..cc700e7186f4 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -272,6 +272,12 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
           ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b");
       if (other) {
+        if (otherIsSplatConstInt) {
+          for (size_t s = valueElemNBits; s < movWidth; s += valueElemNBits) {
+            splatVal |= splatVal << valueElemNBits;
+          }
+        }
+
        for (size_t ii = 0; ii < nWords; ++ii) {
          // PTX doesn't support mov.u8, so we need to use mov.u16
          PTXInstr &mov =
@@ -292,8 +298,6 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
          PTXInstr::Operand *opr{};
          if (otherIsSplatConstInt) {
-            for (size_t s = 0; s < 32; s += valueElemNBits)
-              splatVal |= splatVal << valueElemNBits;
            opr = ptxBuilder.newConstantOperand(splatVal);
          } else
            opr = ptxBuilder.newOperand(v, readConstraint);
diff --git a/third_party/nvidia/triton_nvidia.cc b/third_party/nvidia/triton_nvidia.cc
index 1269dcda00aa..3cccc5fb6a1c 100644
--- a/third_party/nvidia/triton_nvidia.cc
+++ b/third_party/nvidia/triton_nvidia.cc
@@ -1,4 +1,4 @@
-#include "Dialect/NVGPU/IR/Dialect.h"
+#include "Dialect/NVGPU/IR/Dialect.h"
 #include "NVGPUToLLVM/NVGPUToLLVMPass.h"
 #include "TritonNVIDIAGPUToLLVM/Passes.h"
 #include "cublas_instance.h"
diff --git a/third_party/proton/proton/_C/include b/third_party/proton/proton/_C/include
index fe4f4a1aa9bd..4400934bdf78 120000
--- a/third_party/proton/proton/_C/include
+++ b/third_party/proton/proton/_C/include
@@ -1 +1 @@
-../../csrc/include/
\ No newline at end of file
+../../csrc/include
\ No newline at end of file
diff --git a/unittest/BUILD b/unittest/BUILD
new file mode 100644
index 000000000000..cb885459e19c
--- /dev/null
+++ b/unittest/BUILD
@@ -0,0 +1,144 @@
+load("//tools/build_defs/build_test:build_test.bzl", "build_test")
+
+package(
+    default_applicable_licenses = ["//:license"],
+    default_compatible_with = ["//buildenv/target:gce"],
+    default_visibility = ["//:__subpackages__"],
+)
+
+cc_test(
+    name = "AnalysisTest",
+    srcs = glob(["Analysis/*.cpp"]),
+    deps = [
+        "//testing/base/public:gunit_main",
+        "@llvm-project//llvm:Support",
+        "//:TritonDialects",
+    ],
+)
+
+cc_test(
+    name = "DialectTestCatchAll",
+    srcs = glob(
+        [
+            "Dialect/**/*.cpp",
+        ],
+        exclude = [
+            "Dialect/TritonGPU/DialectTest.cpp",
+            "Dialect/TritonGPU/LinearLayoutConversionsTest.cpp",
+            "Dialect/TritonGPU/SwizzleTest.cpp",
+        ],
+    ),
+    copts = select({
+        "//:compiler_is_msvc": [],
+        "//conditions:default": [
+            "-Wno-unused-variable",
+        ],
+    }),
+    deps = [
+        "//testing/base/public:gunit_main",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AsmParser",
+        "@llvm-project//mlir:IR",
+        "//:TritonDialects",
+    ],
+)
+
+cc_test(
+    name = "DialectTest",
+    srcs = [
+        "Dialect/TritonGPU/DialectTest.cpp",
+    ],
+    copts = select({
+        "//:compiler_is_msvc": [],
+        "//conditions:default": [
+            "-Wno-unused-variable",
+        ],
+    }),
+    deps = [
+        "//testing/base/public:gunit_main",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AsmParser",
+        "@llvm-project//mlir:IR",
+        "//:TritonDialects",
+    ],
+)
+
+cc_test(
+    name = "LinearLayoutConversionsTest",
+    srcs = [
+        "Dialect/TritonGPU/LinearLayoutConversionsTest.cpp",
+    ],
+    copts = select({
+        "//:compiler_is_msvc": [],
+        "//conditions:default": [
+            "-Wno-unused-variable",
+        ],
+    }),
+    deps = [
+        "//testing/base/public:gunit_main",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AsmParser",
+        "@llvm-project//mlir:IR",
+        "//:TritonDialects",
+    ],
+)
+
+cc_test(
+    name = "SwizzleTest",
+    srcs = [
+        "Dialect/TritonGPU/SwizzleTest.cpp",
+    ],
+    copts = select({
+        "//:compiler_is_msvc": [],
+        "//conditions:default": [
+            "-Wno-unused-variable",
+        ],
+    }),
+    deps = [
+        "//testing/base/public:gunit_main",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AsmParser",
+        "@llvm-project//mlir:IR",
+        "//:TritonDialects",
+    ],
+)
+
+cc_test(
+    name = "ConversionTest",
+    srcs = glob(
+        [
+            "Conversion/**/*.cpp",
+            "Conversion/**/*.h",
+        ],
+        exclude = [
+            "Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp",
+            "Conversion/TritonGPUToLLVM/DumpLayout.cpp",
+            "Conversion/TritonGPUToLLVM/DumpLayout.h",
+        ],
+    ),
+    copts = select({
+        "//:compiler_is_msvc": [],
+        "//conditions:default": [
+            "-Wno-unused-variable",
+        ],
+    }),
+    deps = [
+        "//testing/base/public:gunit_main",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:IR",
+        "//:TritonDialects",
+        "//:TritonNvidiaGPUTransforms",
+        "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM",
+    ],
+)
+
+build_test(
+    name = "build_test",
+    allow_empty_target = False,
+    targets = [
+        ":ConversionTest",
+        ":AnalysisTest",
+        ":DialectTest",
+    ],
+)
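Note on the LoadStoreOpToLLVM.cpp hunk above: the moved loop replicates an n-bit splat constant only up to the width of the PTX mov that materializes the `other` value, whereas the removed loop always widened to 32 bits (and, starting at s = 0, ran one extra doubling). A minimal standalone C++ sketch of that replication step, under the assumption that it mirrors the patched loop; `splatToMovWidth` is an illustrative name, not part of the patch:

#include <cassert>
#include <cstdint>

// Replicate the low `elemBits` of `v` across a `movBits`-wide immediate,
// mirroring the patched loop
//   for (size_t s = valueElemNBits; s < movWidth; s += valueElemNBits)
//     splatVal |= splatVal << valueElemNBits;
static uint64_t splatToMovWidth(uint64_t v, unsigned elemBits, unsigned movBits) {
  uint64_t splat = v & ((elemBits >= 64) ? ~0ull : ((1ull << elemBits) - 1));
  for (unsigned s = elemBits; s < movBits; s += elemBits)
    splat |= splat << elemBits;
  return splat;
}

int main() {
  // An 8-bit 0xAB materialized with mov.u16 (PTX has no mov.u8): 0xABAB.
  assert(splatToMovWidth(0xAB, 8, 16) == 0xABABull);
  // The same value materialized with a 32-bit mov: 0xABABABAB.
  assert(splatToMovWidth(0xAB, 8, 32) == 0xABABABABull);
  return 0;
}

Bounding the loop by the mov width keeps the immediate inside the instruction's operand; the old fixed bound of 32 produced an oversized constant whenever the value was materialized with a 16-bit mov.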