diff --git a/BUILD b/BUILD new file mode 100644 index 000000000000..d99d944f2790 --- /dev/null +++ b/BUILD @@ -0,0 +1,931 @@ +# This package imports OpenAI's Triton (https://github.com/openai/triton). +# +# There are two versions of Triton in google3 at the moment. The older version +# can be found at //third_party/py/triton. This is the MLIR-based version close +# to head. We expect to transition users to this version in the following +# weeks. +# +# There is no SLA associated with this package and it may get broken by LLVM +# imports at any time. + +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +# copybara:uncomment load("//tools/build_defs/license:license.bzl", "license") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = [":license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # # Add your project here if you need to depend on Triton's C++ sources. + # # Add a point of contact we can reach out to when needed in the comment. + # # + # # If you need to use the Python fronted, add your project to + # # google3/third_party/py/triton/BUILD instead. + # # + # # By adding your project here, you agree to the Triton SLA: go/triton-google3-sla + # "//third_party/py/jax:__subpackages__", # cjfj@ + # "//third_party/tensorflow/compiler/xla:__subpackages__", # bchetioui@ + # "//platforms/xla/experimental/gpu:__subpackages__", # csigg@ + # # Triton-internal visibility + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end + # TODO(csigg): fix and remove + features = [ + "-parse_headers", + "-use_header_modules", + ], +) + +# copybara:uncomment_begin +# license(name = "license") +# +# licenses(["notice"]) +# +# exports_files(["LICENSE"]) +# copybara:uncomment_end + +config_setting( + name = "compiler_is_msvc", + flag_values = { + # copybara:comment_begin + "@bazel_tools" + + # copybara:comment_end + "//tools/cpp:compiler": "msvc-cl", + }, +) + +# TODO(csigg): fix, enable error upstream, remove. +_no_unused_variable = select({ + ":compiler_is_msvc": [], + "//conditions:default": ["-Wno-unused-variable"], +}) + +td_library( + name = "td_files", + srcs = glob(["include/triton/**/*.td"]), + includes = ["include"], + deps = [ + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:CastInterfacesTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LLVMOpsTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@llvm-project//mlir:ViewLikeInterfaceTdFiles", + ], +) + +gentbl_cc_library( + name = "triton_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/Triton/IR/TritonAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/Triton/IR/TritonAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/Triton/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/Triton/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-attr-interface-decls"], + "include/triton/Dialect/Triton/IR/AttrInterfaces.h.inc", + ), + ( + ["--gen-attr-interface-defs"], + "include/triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonInterfaces.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-enum-decls"], + "include/triton/Dialect/Triton/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/Triton/IR/OpsEnums.cpp.inc", + ), + ( + ["--gen-op-decls"], + "include/triton/Dialect/Triton/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/Triton/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/Triton/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/Triton/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=Triton", + ], + "include/triton/Dialect/Triton/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_combine_inc_gen", + # The generated file is #included without relative path. + strip_include_prefix = "lib/Dialect/Triton/Transforms", + tbl_outs = [ + ( + ["--gen-rewriters"], + "lib/Dialect/Triton/Transforms/TritonCombine.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "lib/Dialect/Triton/Transforms/Combine.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/TritonGPU/IR/AttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/TritonGPU/IR/AttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/TritonGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/TritonGPU/IR/OpsEnums.cpp.inc", + ), + ( + ["--gen-attr-interface-decls"], + "include/triton/Dialect/TritonGPU/IR/AttrInterfaces.h.inc", + ), + ( + ["--gen-attr-interface-defs"], + "include/triton/Dialect/TritonGPU/IR/AttrInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/TritonGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/TritonGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-op-decls"], + "include/triton/Dialect/TritonGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/TritonGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/TritonGPU/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/TritonGPU/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_type_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-type-interface-decls"], + "include/triton/Dialect/TritonGPU/IR/TypeInterfaces.h.inc", + ), + ( + ["--gen-type-interface-defs"], + "include/triton/Dialect/TritonGPU/IR/TypeInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonGPU", + ], + "include/triton/Dialect/TritonGPU/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-op-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonNvidiaGPU", + ], + "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_to_triton_gpu_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonToTritonGPU", + ], + "include/triton/Conversion/TritonToTritonGPU/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Conversion/TritonToTritonGPU/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_target_llvmir_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonLLVMIR", + ], + "include/triton/Target/LLVMIR/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Target/LLVMIR/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_gpu_to_llvm_pass_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonGPUToLLVM", + ], + "include/triton/Conversion/TritonGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Conversion/TritonGPUToLLVM/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_op_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-op-interface-decls"], + "include/triton/Dialect/Triton/IR/OpInterfaces.h.inc", + ), + ( + ["--gen-op-interface-defs"], + "include/triton/Dialect/Triton/IR/OpInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonOpInterfaces.td", + deps = ["td_files"], +) + +cc_library( + name = "TritonAnalysis", + srcs = [ + "lib/Analysis/Alias.cpp", + "lib/Analysis/Allocation.cpp", + "lib/Analysis/Membar.cpp", + # Part of TritonDialects compilation unit to avoid circular dependencies. + # "lib/Analysis/Utility.cpp", + # "lib/Analysis/AxisInfo.cpp", + ], + hdrs = [ + "include/triton/Analysis/Alias.h", + "include/triton/Analysis/Allocation.h", + "include/triton/Analysis/Membar.h", + # Part of TritonDialects compilation unit to avoid circular dependencies. + # "include/triton/Analysis/AxisInfo.h", + # "include/triton/Analysis/Utility.h", + "include/triton/Conversion/MLIRTypes.h", + "include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h", + "include/triton/Conversion/TritonGPUToLLVM/Utility.h", + "include/triton/Dialect/TritonGPU/Transforms/Utility.h", + ], + copts = _no_unused_variable, + deps = [ + ":TritonDialects", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonDialects", + srcs = glob([ + "lib/Dialect/Triton/IR/*.cpp", + "lib/Dialect/TritonGPU/IR/*.cpp", + "lib/Dialect/TritonNvidiaGPU/IR/*.cpp", + "lib/Tools/*.cpp", + ]) + [ + "include/triton/Conversion/MLIRTypes.h", # Avoid circular dependency. + "include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h", # Avoid circular dependency. + "include/triton/Conversion/TritonGPUToLLVM/Utility.h", # Avoid circular dependency. + "include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h", # Avoid circular dependency. + "lib/Analysis/AxisInfo.cpp", # Avoid circular dependency. + "lib/Analysis/Utility.cpp", # Avoid circular dependency. + "lib/Dialect/TritonGPU/Transforms/Utility.cpp", # Avoid circular dependency. + ], + hdrs = glob([ + "include/triton/Dialect/Triton/IR/*.h", + "include/triton/Dialect/TritonGPU/IR/*.h", + "include/triton/Dialect/TritonNvidiaGPU/IR/*.h", + "include/triton/Tools/*.h", + ]) + [ + "include/triton/Analysis/AxisInfo.h", # Avoid circular dependency. + "include/triton/Analysis/Utility.h", # Avoid circular dependency. + "include/triton/Dialect/TritonGPU/Transforms/Utility.h", # Avoid circular dependency. + ], + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + "-Wno-logical-op-parentheses", + ], + }), + includes = ["include"], + deps = [ + ":triton_dialect_inc_gen", + ":triton_gpu_attr_inc_gen", + ":triton_gpu_dialect_inc_gen", + ":triton_gpu_ops_inc_gen", + ":triton_gpu_types_inc_gen", + ":triton_gpu_type_interfaces_inc_gen", + ":triton_interfaces_inc_gen", + ":triton_nvidia_gpu_attr_inc_gen", + ":triton_nvidia_gpu_dialect_inc_gen", + ":triton_nvidia_gpu_ops_inc_gen", + ":triton_nvidia_gpu_types_inc_gen", + ":triton_op_interfaces_inc_gen", + ":triton_ops_inc_gen", + ":triton_types_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:UBDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@triton//third_party/nvidia:NVGPUDialect", + # The following is added to make Utility compile + ":TritonTools", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@triton//third_party/f2reduce", + ], +) + +cc_library( + name = "TritonTransforms", + srcs = glob(["lib/Dialect/Triton/Transforms/*.cpp"]), + hdrs = glob(["include/triton/Dialect/Triton/Transforms/*.h"]), + copts = _no_unused_variable, + deps = [ + ":TritonDialects", + ":triton_combine_inc_gen", + ":triton_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFUtils", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], + alwayslink = True, # TritonDialect uses getCanonicalizationPatterns(). +) + +cc_library( + name = "TritonGPUTransforms", + srcs = glob( + [ + "lib/Dialect/TritonGPU/Transforms/*.cpp", + "lib/Dialect/TritonGPU/Transforms/*.h", + "lib/Dialect/TritonGPU/Transforms/Pipeliner/*.cpp", + "lib/Dialect/TritonGPU/Transforms/Pipeliner/*.h", + ], + exclude = ["lib/Dialect/TritonGPU/Transforms/Utility.cpp"], + ), + hdrs = glob( + [ + "include/triton/Dialect/TritonGPU/Transforms/*.h", + ], + exclude = ["include/triton/Dialect/TritonGPU/Transforms/Utility.h"], + ) + [ + "include/triton/Tools/Sys/GetEnv.hpp", + ], + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-logical-op-parentheses", + "-Wno-reorder-ctor", + "-Wno-return-type", + "-Wno-unused-variable", + ], + }), + deps = [ + ":TritonAnalysis", + ":TritonDialects", + ":TritonGPUToLLVM", + ":triton_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:SCFUtils", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonGPUToLLVM", + srcs = glob([ + "lib/Conversion/TritonGPUToLLVM/*.h", + "lib/Conversion/TritonGPUToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/triton/Tools/Sys/*.hpp", + "include/triton/Conversion/TritonGPUToLLVM/*.h", + ]), + copts = select({ + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + includes = ["include"], + deps = [ + ":TritonAnalysis", + ":TritonDialects", + ":triton_conversion_triton_gpu_to_llvm_pass_inc_gen", + ":triton_gpu_attr_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonNvidiaGPUTransforms", + srcs = glob([ + "lib/Dialect/TritonNvidiaGPU/Transforms/*.cpp", + ]), + hdrs = glob([ + "include/triton/Dialect/TritonNvidiaGPU/Transforms/*.h", + ]), + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-ctad-maybe-unsupported", + "-Wno-logical-op-parentheses", + "-Wno-non-virtual-dtor", + "-Wno-return-type", + "-Wno-unused-variable", + ], + }), + includes = ["include"], + deps = [ + ":TritonDialects", + ":TritonGPUTransforms", + ":triton_nvidia_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonToTritonGPU", + srcs = glob([ + "lib/Conversion/TritonToTritonGPU/*.h", + "lib/Conversion/TritonToTritonGPU/*.cpp", + ]), + hdrs = glob(["include/triton/Conversion/TritonToTritonGPU/*.h"]), + includes = ["include"], + deps = [ + ":TritonDialects", + ":TritonGPUTransforms", + ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonLLVMIR", + srcs = glob([ + "lib/Target/LLVMIR/*.cpp", + "lib/Target/LLVMIR/*.h", + ]), + hdrs = glob(["include/triton/Target/LLVMIR/*.h"]), + copts = _no_unused_variable, + deps = [ + ":TritonTransforms", + ":triton_target_llvmir_passes_inc_gen", + "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:BinaryFormat", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IPO", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:InstCombine", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexToLLVM", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMIRTransforms", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLToLLVMIRTranslation", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonPTX", + srcs = glob([ + "lib/Target/PTX/*.cpp", + ]), + hdrs = glob(["include/triton/Target/PTX/*.h"]), + deps = ["@llvm-project//llvm:Support"], +) + +cc_library( + name = "TritonHSACO", + srcs = glob([ + "lib/Target/HSACO/*.cpp", + ]), + hdrs = glob(["include/triton/Target/HSACO/*.h"]), + deps = [ + ":TritonLLVMIR", + ":TritonTools", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:ExecutionEngine", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Scalar", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//llvm:TransformUtils", + "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + ], +) + +cc_library( + name = "TritonTools", + hdrs = ["include/triton/Tools/Sys/GetEnv.hpp"], +) + +cc_library( + name = "AllPassesAndDialects", + srcs = [ + "include/triton/Conversion/TritonToTritonGPU/Passes.h", + "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h", + ], + hdrs = ["bin/RegisterTritonDialects.h"], + includes = ["."], # because it includes third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h + deps = [ + ":TritonDialects", + ":TritonGPUToLLVM", + ":TritonGPUTransforms", + ":TritonLLVMIR", + ":TritonNvidiaGPUTransforms", + ":TritonToTritonGPU", + ":TritonTransforms", + ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", + ":triton_nvidia_gpu_transforms_inc_gen", + "@llvm-project//mlir:AllPassesAndDialects", + "@triton//test:TritonTestAnalysis", + "@triton//third_party/amd:TritonAMDGPU", + "@triton//third_party/amd:TritonAMDGPUToLLVM", + "@triton//third_party/amd:TritonAMDGPUTransforms", + "@triton//third_party/nvidia:NVGPUDialect", + "@triton//third_party/nvidia:NVGPUToLLVM", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + "@triton//third_party/proton:ProtonIRDialect", + ], +) + +cc_binary( + name = "triton-opt", + srcs = [ + "bin/triton-opt.cpp", + ], + deps = [ + ":AllPassesAndDialects", + "@llvm-project//mlir:MlirOptLib", + ], +) + +cc_binary( + name = "triton-llvm-opt", + srcs = [ + "bin/triton-llvm-opt.cpp", + "lib/Target/LLVMIR/LLVMPasses.h", + ], + deps = [ + ":TritonLLVMIR", + "@llvm-project//llvm:CodeGen", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:Option", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + ], +) + +# See go/triton-debug for usage. +cc_binary( + name = "triton-reduce", + srcs = ["bin/triton-reduce.cpp"], + deps = [ + ":AllPassesAndDialects", + "@llvm-project//mlir:MlirReduceLib", + "@triton//third_party/amd:TritonAMDGPU", + "@triton//third_party/amd:TritonAMDGPUToLLVM", + ], +) + +cc_binary( + name = "triton-tensor-layout", + srcs = ["bin/triton-tensor-layout.cpp"], + deps = [ + ":AllPassesAndDialects", + ":TritonDialects", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + ], +) + +filegroup( + name = "metadata-file", + srcs = ["METADATA"], +) diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index fc6a2c73befc..cab66161e412 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -937,7 +937,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl { // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n lhsDivisibility = 1; } - return std::max(1, lhsDivisibility / (1 << shift)); + return std::max(1, lhsDivisibility / (int64_t(1) << shift)); } int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, @@ -1011,6 +1011,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver) CastOpAxisInfoVisitor, CastOpAxisInfoVisitor, CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, CastOpAxisInfoVisitor, CastOpAxisInfoVisitor, CastOpAxisInfoVisitor>(); diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp index 06e75ee18d59..035991a5d44f 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp @@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, ValueRange inputs, Location loc) -> Value { + // Allows partial TTIR to TTGIR conversion by materializing a conversion for + // remaining arguments that have been converted to a new type. + // We use this to rewrite triton_xla.sparse_dot in a separate pass after + // 'convert-triton-to-tritongpu'. + return builder.create(loc, tensorType, + inputs); llvm_unreachable("Argument rematerialization should not happen in Triton " "-> TritonGPU conversion"); return {}; @@ -66,6 +72,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // convert origValue to newValue addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, ValueRange inputs, Location loc) -> Value { + // Allows partial TTIR to TTGIR conversion by materializing a conversion for + // remaining uses of values that have been converted to a new type. + // We use this to rewrite triton_xla.sparse_dot in a separate pass after + // 'convert-triton-to-tritongpu'. + return builder.create(loc, tensorType, + inputs); llvm_unreachable("Source rematerialization should not happen in Triton -> " "TritonGPU Conversion"); return {}; diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td index e3588f587757..1f1de2c717bd 100644 --- a/lib/Dialect/Triton/Transforms/Combine.td +++ b/lib/Dialect/Triton/Transforms/Combine.td @@ -17,7 +17,7 @@ def CombineDotAddIPattern : Pat< [(Constraint> $c), (ConstrainthasOneUse()">, "dot result has a single use">)]>; def CombineDotAddFPattern : Pat< - (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath), + (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath, $denorm), (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), [(Constraint> $c), (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), @@ -29,7 +29,7 @@ def CombineDotAddIRevPattern : Pat< [(Constraint> $c), (ConstrainthasOneUse()">, "dot result has a single use">)]>; def CombineDotAddFRevPattern : Pat< - (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath), + (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath, $denorm), (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), [(Constraint> $c), (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index dec78c2e41a4..b8a8c1079b67 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -3127,6 +3127,11 @@ struct CanonicalizeConvertFromAlloc auto convert = op.getSrc().getDefiningOp(); if (!convert) return failure(); + // LocalAllocOp lowering doesn't support going from DotOperandEncoding + // to SharedEncoding, so we want to keep this layout conversion. + if (mlir::isa( + convert.getSrc().getType().getEncoding())) + return failure(); rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), convert.getSrc()); return mlir::success(); @@ -3189,13 +3194,13 @@ struct CanonicalizeConvertFromConvert // heuristic to accommodate fused attention. auto srcType = op.getSrc().getType(); auto dstType = op.getType(); - if (mlir::isa(dstType.getEncoding()) && - mlir::isa(srcType.getEncoding())) + if (mlir::isa_and_nonnull(dstType.getEncoding()) && + mlir::isa_and_nonnull(srcType.getEncoding())) return failure(); // for hopper MMAv3 - if (mlir::isa(dstType.getEncoding()) && - mlir::isa(srcType.getEncoding()) && + if (mlir::isa_and_nonnull(dstType.getEncoding()) && + mlir::isa_and_nonnull(srcType.getEncoding()) && llvm::any_of(op.getResult().getUsers(), [](Operation *dot) { return dot->hasTrait(); })) { diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 678d2f12a8e8..3ad108af9b76 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -20,8 +20,6 @@ namespace mlir { namespace triton { namespace gpu { -namespace { - // Get the highest version supported for the hardware and the dot. static int getMMAVersionSafe(int computeCapability, DotOp op) { // List supported mma version in order of preference. @@ -44,8 +42,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) { return 0; } -SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, - int numWarps) { +SmallVector +warpsPerTileV2(Operation *dotOp, const ArrayRef shape, int numWarps) { auto rank = shape.size(); // Early exit for batched matmul if (rank == 3) @@ -109,10 +107,10 @@ SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, } SmallVector -warpsPerTileV3(DotOp dotOp, const ArrayRef shape, int numWarps, +warpsPerTileV3(Operation *dotOp, const ArrayRef shape, int numWarps, const SmallVector &instrShape) { SetVector slices; - mlir::getForwardSlice(dotOp.getResult(), &slices); + mlir::getForwardSlice(dotOp->getResult(0), &slices); // Contains a chained dot. We prefer to assign warps to one axis // to facilitate use cases like flash attention, allowing reductions within // the same warp. @@ -167,11 +165,26 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, auto newType = MemDescType::get(argType.getShape(), argType.getElementType(), newLayout, SharedMemorySpace); rewriter.setInsertionPointAfterValue(arg); + + // LocalAllocOp lowering doesn't support going from DotOperandEncoding + // to SharedEncoding. + if (auto dotOpEnc = mlir::dyn_cast( + argType.getEncoding())) { + // Create a layout conversion from DotOperandEncoding to BlockedEncoding + // then pass it to the LocalAllocOp. + auto newArgType = RankedTensorType::get( + argType.getShape(), argType.getElementType(), dotOpEnc.getParent()); + auto dotOperandToBlockedCvt = + rewriter.create(arg.getLoc(), newArgType, arg); + return rewriter.create(arg.getLoc(), newType, + dotOperandToBlockedCvt); + } + return rewriter.create(arg.getLoc(), newType, arg); } SmallVector -getWarpsPerTile(DotOp dotOp, const ArrayRef shape, int version, +getWarpsPerTile(Operation* dotOp, const ArrayRef shape, int version, int numWarps, const SmallVector &instrShape) { switch (version) { case 2: @@ -184,11 +197,24 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef shape, int version, } } +// Move anonymous namespace down, so getWarpsPerTile is visible to the sparsity +// extension. +namespace { + class BlockedToMMA : public mlir::OpRewritePattern { int computeCapability; mutable llvm::DenseMap dotOpInstNs; static bool bwdFilter(Operation *op) { + // Dot operand layout assignment to Predicates are not currently supported + // during lowering from TritonGPU to LLVM in Triton for MMA cases. This + // condition limits visibility of the original bit-width so that predicate + // are not considered, hence, kwidth can never be = 32. + if (isa(op)) { + Type srcType = getElementTypeOrSelf(op->getOperand(0)); + if (srcType.isInteger(1)) + return false; + } return op->getNumOperands() == 1 && (isa(op) || isPureUnaryInlineAsm(op) || @@ -196,6 +222,7 @@ class BlockedToMMA : public mlir::OpRewritePattern { mlir::TypeID::get()); } +public: // Finds the first different bitwidth in the chain of shape-preserving // unary ops that x depends on. // There are two primary scenarios: @@ -806,6 +833,15 @@ class TritonGPUAccelerateMatmulPass } }; +// Expose helper functions from BlockedToMMA to be reused for sparse matmul. +int computeOrigBitWidth(Value x) { + return BlockedToMMA::computeOrigBitWidth(x); +} +Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter, + int opIdx, bool allowTranspose) { + return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose); +} + } // namespace gpu } // namespace triton } // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp index d9fb1d7e17b3..b6728d22b484 100644 --- a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp @@ -17,9 +17,11 @@ namespace { // dot(a, b, inputPrecision="tf32x3") -> // let aBig = f32ToTF32(a), aSmall = a - aBig; // let bBig = f32ToTF32(b), bSmall = b - bBig; -// dot(aSmall, bBig, inputPrecision="tf32") + -// dot(aBig, bSmall, inputPrecision="tf32") + -// dot(aBig, bBig, inputPrecision="tf32") +// let small = dot(aSmall, bBig, inputPrecision="tf32") + +// dot(aBig, bSmall, inputPrecision="tf32") +// let masked_nans = replaceNansWithZeros(small) +// let big = dot(aBig, bBig, inputPrecision="tf32") +// return big + masked_nans; class TF32x3 : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -62,6 +64,13 @@ class TF32x3 : public OpRewritePattern { InputPrecision::TF32, dotOp.getMaxNumImpreciseAcc()); }; + auto replaceNansWithZeros = [&](Value value) -> Value { + auto nans = rewriter.create( + dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value); + auto zero = zeroLike(value); + return rewriter.create(dotOp->getLoc(), nans, zero, + value); + }; auto aBig = f32ToTF32(dotOp.getA()); auto aSmall = sub(dotOp.getA(), aBig); @@ -73,7 +82,16 @@ class TF32x3 : public OpRewritePattern { auto dot1 = dot(aSmall, bBig, zero); auto dot2 = dot(aBig, bSmall, dot1); - auto dot3 = dot(aBig, bBig, dot2); + + // If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0. + // If rhs is +infinity, we will have: + // +infinity * 1.0 = +infinity + // +infinity * 0.0 = NaN + // We would get the wrong result if we sum these partial products. Instead, + // we must override any accumulated result if the last partial product is + // non-finite. + auto dot2withZeroedNans = replaceNansWithZeros(dot2); + auto dot3 = dot(aBig, bBig, dot2withZeroedNans); auto sum = add(dot3, dotOp.getC()); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index f0fe8d43f438..fa0c8540df19 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -111,6 +111,7 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc, Value zero = builder.createWithStage( forOp.getLoc(), stage, clusterId, 0, 32); + // Replace the load with insert/extract slice. builder.setInsertionPoint(loadOp); Location loc = loadOp.getLoc(); @@ -468,7 +469,8 @@ assignMemoryLayouts(scf::ForOp &forOp, } }); - loadsToPipeline.insert(&op); + // TODO: b/381421713 - Uncomment this once pipelining is fixed. + // loadsToPipeline.insert(&op); LoadInfo loadInfo; for (auto use : users) { if (use->hasTrait()) { @@ -508,6 +510,11 @@ assignMemoryLayouts(scf::ForOp &forOp, getBlockedEncoding(loadOp, axisInfoAnalysis); } } + + // TODO: b/381421713 - Remove this once pipelining is fixed. + if (!loadInfo.sharedEncoding) continue; + loadsToPipeline.insert(&op); + loadToInfo[&op] = loadInfo; } // Make sure all loads in loadsToPipeline are in loadToInfo. diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index c11f2f8e5ee7..78220cb0f4ee 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -116,7 +116,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, // opIdx: 0 => a, 1 => b auto type = cast(v.getType()); SmallVector shape{type.getShape().begin(), type.getShape().end()}; - SmallVector offset{0, 0}; + SmallVector offset(shape.size(), 0); Type elementType = type.getElementType(); // k => (prefetchWidth, k - prefetchWidth) @@ -141,8 +141,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, type.getMutableMemory(), type.getAllocShape()), v, offsetsVal); + // We need to assign kwidth to zero in the case where the parent layout is + // Blocked, otherwise the verifier emits a failure. The parent layout is + // Blocked only when Tensor Cores are disabled. + int kwidth = dyn_cast(dotEncoding) + ? 0 + : prefetchWidth / 8; auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get( - builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8); + builder.getContext(), opIdx, dotEncoding, kwidth); Value prefetchSlice = builder.create( v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc), newSmem); @@ -191,6 +197,22 @@ LogicalResult Prefetcher::initialize() { break; if (!op->getResult(0).hasOneUse()) break; + // Similar to issues faced in HoistLayoutConversion pattern in + // OptimizeDotOperands.cpp, we can't propagate through type casts from + // predicates as they aren't supported in Triton when encoded with dot_op + // layout. + if (isa(op)) { + Type srcType = getElementTypeOrSelf(op->getOperand(0)); + if (srcType.isInteger(1)) + break; + } + // Propagation through ExpandDims is currently not supported. This blindly + // replaces the encoding with dot encoding & but ExpandDims requires a + // SliceEncoding. This could be rewritten to support it somehow, but I + // don't think it's trivial & it's currently crashing. + if (isa(op)) { + break; + } rets.push_back(op->getOperand(0)); if (auto cvt = dyn_cast(op)) { foundConvertFromShared = true; diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 117c4c65121e..8c3ec2ccdaa2 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -953,18 +953,26 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { } else { if (!isa(user)) return std::nullopt; - auto dotOpEnc = dyn_cast( - cast(user->getResult(0).getType()) - .getEncoding()); - if (!dotOpEnc) + auto enc = + cast(user->getResult(0).getType()).getEncoding(); + if (isa(enc)) { + auto srcTy = cast(val.getType()); + auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); + auto order = ttg::getOrder(srcTy.getEncoding()); + unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + tempAttr = ttg::SharedEncodingAttr::get( + val.getContext(), cast(enc), + srcTy.getShape(), order, CTALayout, bitWidth, /*needTrans=*/false); + } else if (enc.getAbstractAttribute().getName().str() == + "triton.gpu.sparse_dot_meta_encoding") { + auto srcTy = cast(val.getType()); + tempAttr = ttg::SharedEncodingAttr::get( + val.getContext(), /*vec=*/1, /*perPhase=*/1, /*maxPhase=*/1, + ttg::getOrder(srcTy.getEncoding()), + ttg::getCTALayout(srcTy.getEncoding())); + } else { return std::nullopt; - auto srcTy = cast(val.getType()); - auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); - auto order = ttg::getOrder(srcTy.getEncoding()); - unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); - tempAttr = ttg::SharedEncodingAttr::get( - val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout, - bitWidth, /*needTrans=*/false); + } } // Check that the shared encodings needed by the users are compatible. if (attr != nullptr && attr != tempAttr) { diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp index fb0e7f6fdb18..62ed71c175fd 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp @@ -44,7 +44,7 @@ struct FenceInsertionPass return; ModuleOp mod = getOperation(); mod.walk([&](Operation *op) { - if (!isa(op)) + if (!op->hasTrait()) return WalkResult::advance(); OpBuilder builder(op); auto a = op->getOperand(0); diff --git a/python/BUILD b/python/BUILD new file mode 100644 index 000000000000..132ab8de68b5 --- /dev/null +++ b/python/BUILD @@ -0,0 +1,78 @@ +# NOTE: Do not depend on any targets from this directory, +# but use //third_party/py/triton instead. + +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//third_party/py/triton:__pkg__", + "@triton//python:__subpackages__", + ], +) + +cc_library( + name = "passes", + hdrs = ["src/passes.h"], + includes = ["src"], + visibility = ["@triton//third_party:__subpackages__"], +) + +pybind_extension( + name = "libtriton", + srcs = [ + "src/interpreter.cc", + "src/ir.cc", + "src/llvm.cc", + "src/main.cc", + "src/passes.cc", + ], + copts = ["-DTRITON_BACKENDS_TUPLE=(nvidia)"], + deps = [ + ":passes", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IPO", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:InstCombine", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMIRTransforms", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBDialect", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonGPUTransforms", + "//:TritonHSACO", + "//:TritonLLVMIR", + "//:TritonNvidiaGPUTransforms", + "//:TritonPTX", + "//:TritonToTritonGPU", + "//:TritonTools", + "//:TritonTransforms", + "@triton//third_party/nvidia:triton_nvidia", + ], +) + +filegroup( + name = "files", + srcs = glob( + include = ["triton/**/*.py"], + ), +) diff --git a/python/test/regression/BUILD b/python/test/regression/BUILD new file mode 100644 index 000000000000..a88f4eeae1f8 --- /dev/null +++ b/python/test/regression/BUILD @@ -0,0 +1,26 @@ +load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests") + +package( + default_applicable_licenses = ["//:license"], +) + +pytest_multi_tests( + name = "tests", + size = "large", + srcs = ["conftest.py"], + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm80", + ], + tests = glob( + include = ["test_*.py"], + exclude = [ + "test_performance.py", #TODO(b/321005767): fix failing test + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) diff --git a/python/test/unit/BUILD b/python/test/unit/BUILD new file mode 100644 index 000000000000..3140375fbec7 --- /dev/null +++ b/python/test/unit/BUILD @@ -0,0 +1,181 @@ +load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests", "pytest_test") + +package( + default_applicable_licenses = ["//:license"], +) + +_requires_gpu_sm80 = [ + "config-cuda-only", + "requires-gpu-sm80", +] + +_requires_config_cuda = select( + {"@local_config_cuda//cuda:using_blaze_config_cuda": []}, + no_match_error = "Requires --config=cuda", +) + +EXCLUDE_TESTS = [ + "language/test_reproducer.py", # this is not an actual test, but a tool for running reproducers + "language/test_subprocess.py", # TODO(b/320224484): fix failing test + "runtime/test_launch.py", # TODO(b/320226169): fix failing tests + "tools/test_aot.py", # TODO(b/320224484): fix failing test + "tools/test_disasm.py", # TODO(b/320224484): fix failing test + "hopper/test_persistent_warp_specialized_gemm.py", # TODO (b/342348738): fix failing test + "runtime/test_cublas.py", # TODO(b/346755023): fix failing test + "test_debug.py", # TODO(b/374733875): fix failing test. Also see b/374733872. +] + +# Runs all python tests on H100 +pytest_multi_tests( + name = "hopper", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + "language/test_core.py", + ], + name_suffix = "_h100", + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm90", + ], + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["**/test_*.py"], + exclude = EXCLUDE_TESTS + [ + "language/test_core.py", + "language/test_pipeliner.py", # TODO(b/362458006): fix failing test + "hopper/test_experimental_tma.py", # TODO(b/362458006): fix failing test + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +# Shard test_core more, as it is otherwise very slow to run. +pytest_test( + name = "hopper/language/test_core_h100", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + ], + shard_count = 40, + tags = [ + "config-cuda-only", + "requires-gpu-sm90", + ], + target_compatible_with = _requires_config_cuda, + tests = ["language/test_core.py"], + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "language", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + "language/test_core.py", + ], + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["language/**/test_*.py"], + exclude = EXCLUDE_TESTS + ["language/test_core.py"], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +# Shard test_core more, as it is otherwise very slow to run. +pytest_test( + name = "language/test_core", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + ], + shard_count = 40, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = ["language/test_core.py"], + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "instrumentation", + size = "large", + srcs = ["conftest.py"], + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["instrumentation/**/test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "runtime", + srcs = ["conftest.py"], + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["runtime/**/test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "tools", + size = "large", + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["tools/**/test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "unit", + size = "large", + srcs = ["conftest.py"], + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index ac22cdee4335..fbe8ec406e4e 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2186,6 +2186,8 @@ def kernel(X, Z, BLOCK: tl.constexpr): reduce_bool = [(op, 'bool', shape, axis, False) for op in ['xor_sum'] for shape in reduce2d_shapes for axis in [0, 1]] +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] >= 9, + reason='Reduction test produces wrong results on H100, b/342347027') @pytest.mark.interpreter @pytest.mark.parametrize( "op, dtype_str, shape, axis, keep_dims", reduce_configs1 + reduce_configs2 + reduce_configs3 + invalid_config + @@ -3929,6 +3931,25 @@ def _kernel(out): kernel[(1, )](out) assert torch.all(out == out_ref) +@pytest.mark.interpreter +def test_dot_on_broadcast(device): + @triton.jit + def _kernel(a, b, out): + a_offsets = tl.arange(0, 64)[:, None] * 32 + tl.arange(0, 32)[None, :] + lhs = tl.load(a + a_offsets, mask=a_offsets < 32 * 64) + rhs = tl.load(b) + rhs_bc = tl.broadcast_to(rhs, [32, 32]) + c = tl.dot(lhs, rhs_bc) + out_ptr = out + tl.arange(0, 64)[:, None] * 32 + tl.arange(0, 32)[None, :] + tl.store(out_ptr, c) + + a = torch.ones((64, 32), dtype=getattr(torch, 'float32'), device=device) + b = torch.tensor([1.0], dtype=getattr(torch, 'float32'), device=device) + out_ref = torch.matmul(a, torch.broadcast_to(b, (32, 32))) + out = torch.zeros((64, 32), dtype=getattr(torch, 'float32'), device=device) + _kernel[(1, )](a, b, out, num_stages=1, num_warps=4) + assert torch.all(out == out_ref) + # --------------- # test arange diff --git a/python/triton/_C/include b/python/triton/_C/include index b85a409837d1..8a5dba6c4b56 120000 --- a/python/triton/_C/include +++ b/python/triton/_C/include @@ -1 +1 @@ -../../../include/ \ No newline at end of file +../../../include \ No newline at end of file diff --git a/python/triton/backends/__init__.py b/python/triton/backends/__init__.py index 92ba144ba97b..f9bab523bf6c 100644 --- a/python/triton/backends/__init__.py +++ b/python/triton/backends/__init__.py @@ -46,5 +46,8 @@ def _discover_backends(): _find_concrete_subclasses(driver, DriverBase)) return backends - -backends = _discover_backends() +from triton.backends.nvidia.driver import CudaDriver +from triton.backends.nvidia.compiler import CUDABackend +backends = { + "nvidia": Backend(CUDABackend, CudaDriver) +} diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 31b19754c63c..df467ac3b81b 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -793,7 +793,7 @@ def __str__(self) -> str: @builtin def __add__(self, other, _builder=None): - return add(self, other, sanitize_overflow=True, _builder=_builder) + return add(self, other, sanitize_overflow=False, _builder=_builder) @builtin def __radd__(self, other, _builder=None): @@ -809,7 +809,7 @@ def __rsub__(self, other, _builder=None): @builtin def __mul__(self, other, _builder=None): - return mul(self, other, sanitize_overflow=True, _builder=_builder) + return mul(self, other, sanitize_overflow=False, _builder=_builder) @builtin def __rmul__(self, other, _builder=None): @@ -2154,7 +2154,7 @@ def where(condition, x, y, _builder=None): @builtin -def add(x, y, sanitize_overflow: constexpr = True, _builder=None): +def add(x, y, sanitize_overflow: constexpr = False, _builder=None): x = _unwrap_if_constexpr(x) y = _unwrap_if_constexpr(y) return semantic.add(x, y, sanitize_overflow, _builder) @@ -2168,7 +2168,7 @@ def sub(x, y, sanitize_overflow: constexpr = True, _builder=None): @builtin -def mul(x, y, sanitize_overflow: constexpr = True, _builder=None): +def mul(x, y, sanitize_overflow: constexpr = False, _builder=None): x = _unwrap_if_constexpr(x) y = _unwrap_if_constexpr(y) return semantic.mul(x, y, sanitize_overflow, _builder) diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py deleted file mode 100644 index 20da2bc25790..000000000000 --- a/python/triton/runtime/build.py +++ /dev/null @@ -1,80 +0,0 @@ -import contextlib -import sys -import io -import sysconfig -import os -import shutil -import subprocess -import setuptools - - -@contextlib.contextmanager -def quiet(): - old_stdout, old_stderr = sys.stdout, sys.stderr - sys.stdout, sys.stderr = io.StringIO(), io.StringIO() - try: - yield - finally: - sys.stdout, sys.stderr = old_stdout, old_stderr - - -def _build(name, src, srcdir, library_dirs, include_dirs, libraries): - suffix = sysconfig.get_config_var('EXT_SUFFIX') - so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) - # try to avoid setuptools if possible - cc = os.environ.get("CC") - if cc is None: - # TODO: support more things here. - clang = shutil.which("clang") - gcc = shutil.which("gcc") - cc = gcc if gcc is not None else clang - if cc is None: - raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.") - # This function was renamed and made public in Python 3.10 - if hasattr(sysconfig, 'get_default_scheme'): - scheme = sysconfig.get_default_scheme() - else: - scheme = sysconfig._get_default_scheme() - # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install - # path changes to include 'local'. This change is required to use triton with system-wide python. - if scheme == 'posix_local': - scheme = 'posix_prefix' - py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] - custom_backend_dirs = set(os.getenv(var) for var in ('TRITON_CUDACRT_PATH', 'TRITON_CUDART_PATH')) - include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs] - # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047 - cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so] - cc_cmd += [f'-l{lib}' for lib in libraries] - cc_cmd += [f"-L{dir}" for dir in library_dirs] - cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] - ret = subprocess.check_call(cc_cmd) - if ret == 0: - return so - # fallback on setuptools - extra_compile_args = [] - # extra arguments - extra_link_args = [] - # create extension module - ext = setuptools.Extension( - name=name, - language='c', - sources=[src], - include_dirs=include_dirs, - extra_compile_args=extra_compile_args + ['-O3'], - extra_link_args=extra_link_args, - library_dirs=library_dirs, - libraries=libraries, - ) - # build extension module - args = ['build_ext'] - args.append('--build-temp=' + srcdir) - args.append('--build-lib=' + srcdir) - args.append('-q') - args = dict( - name=name, - ext_modules=[ext], - script_args=args, - ) - with quiet(): - setuptools.setup(**args) - return so diff --git a/test/BUILD b/test/BUILD new file mode 100644 index 000000000000..6d0d853423f3 --- /dev/null +++ b/test/BUILD @@ -0,0 +1,63 @@ +# copybara:uncomment_begin +# load("//third_party/llvm/build_defs:lit.bzl", "glob_lit_tests") +# load("//tools/build_defs/build_test:build_test.bzl", "build_test") +# +# package( +# default_applicable_licenses = ["//:license"], +# default_compatible_with = ["//buildenv/target:non_prod"], +# default_visibility = ["//:__subpackages__"], +# ) +# +# glob_lit_tests( +# name = "all_tests", +# data = [ +# "@llvm-project//llvm:FileCheck", +# "//:triton-llvm-opt", +# "//:triton-opt", +# "//:triton-tensor-layout", +# ], +# driver = "@llvm-project//mlir:run_lit.sh", +# exclude = [ +# "Conversion/amd/dedup-by-constancy.mlir", # AMD-specific, broken +# "TritonGPU/amd/amd-instruction-sched.mlir", # AMD-specific, broken with -debug-only. +# "TritonGPU/optimize_epilogue.mlir", # TODO: b/346283526 - AMD-specific, triggering UBSAN +# ], +# test_file_exts = [ +# "mlir", +# "ll", +# ], +# ) +# +# build_test( +# name = "build_test", +# allow_empty_target = False, +# targets = [ +# "//:TritonAnalysis", +# "//:TritonDialects", +# "//:TritonGPUToLLVM", +# "//:TritonGPUTransforms", +# "//:TritonLLVMIR", +# "//:TritonPTX", +# "//:TritonToTritonGPU", +# "//:TritonTools", +# "//:TritonTransforms", +# "//:triton-opt", +# ], +# ) +# copybara:uncomment_end + +cc_library( + name = "TritonTestAnalysis", + srcs = glob(["lib/Analysis/*.cpp"]), + deps = [ + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index a97ac476cbad..fcccbfd024a7 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -2044,3 +2044,17 @@ tt.func @upcast_mxfp(%arg0: tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #m } } + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_operand = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=4}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func @f16_to_f8_dot_operand(%f16_inp: tensor<32x32xf16, #dot_operand>) { + // CHECK-LABEL: @f16_to_f8_dot_operand + + %f8 = tt.fp_to_fp %f16_inp, rounding = rtne : tensor<32x32xf16, #dot_operand> -> tensor<32x32xf8E5M2, #dot_operand> + tt.return + } +} + diff --git a/test/Conversion/tritongpu_to_llvm_ampere.mlir b/test/Conversion/tritongpu_to_llvm_ampere.mlir new file mode 100644 index 000000000000..0bda37caed17 --- /dev/null +++ b/test/Conversion/tritongpu_to_llvm_ampere.mlir @@ -0,0 +1,23 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=80 2>&1 | FileCheck %s + +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 3072 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @ampere_s8_to_fp16_conversion_opIdx1(%1 : tensor<16x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>) attributes {noinline = false} { + // CHECK-LABEL: ampere_s8_to_fp16_conversion_opIdx1 + // CHECK: llvm.sitofp %{{.*}} : i8 to f16 + %2 = arith.sitofp %1 : tensor<16x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> to tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + tt.return + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 3072 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @ampere_s8_to_fp16_conversion_opIdx0(%1 : tensor<32x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>) attributes {noinline = false} { + // CHECK-LABEL: @ampere_s8_to_fp16_conversion_opIdx0 + // CHECK: llvm.sitofp %{{.*}} : i8 to f16 + %2 = arith.sitofp %1 : tensor<32x16xi8, #ttg.dot_op<{opIdx = 0 , parent = #mma, kWidth = 4}>> to tensor<32x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + tt.return + } +} diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 17180a392440..b9207593436c 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -261,3 +261,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.return } } + +// ----- + +// CHECK-DAG: #[[$BLOCKED:.*]] = #ttg.blocked +// CHECK-DAG: #mma = #ttg.nvidia_mma<{versionMajor = 3 +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func @local_alloc_dot_operand(%in0: tensor<64x256xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> {tt.divisibility = 16 : i32}, %in1: f32, %in2: tensor<64x32xf32, #blocked>) -> (tensor<64x32xf32, #blocked>) { + // CHECK-LABEL: local_alloc_dot_operand + %splat_in1 = tt.splat %in1 : f32 -> tensor<256x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK: %[[LHS_LOCAL_ALLOC:.*]] = ttg.local_alloc + // CHECK: %[[RHS_CVT:.*]] = ttg.convert_layout {{.*}} #ttg.dot_op<{{.*}}> -> {{.*}} #[[$BLOCKED]] + // CHECK: %[[RHS_LOCAL_ALLOC:.*]] = ttg.local_alloc %[[RHS_CVT]] + // CHECK: ttng.warp_group_dot %[[LHS_LOCAL_ALLOC]], %[[RHS_LOCAL_ALLOC]] + %res = tt.dot %in0, %splat_in1, %in2, inputPrecision = tf32 : tensor<64x256xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<256x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x32xf32, #blocked> + tt.return %res : tensor<64x32xf32, #blocked> + } +} diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir index b47005b56978..5b8dc7516b3b 100644 --- a/test/TritonGPU/canonicalize.mlir +++ b/test/TritonGPU/canonicalize.mlir @@ -137,3 +137,20 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { tt.return %2 : !ttg.memdesc<16x16xf16, #shared, #smem> } } // end module + +// ----- + +// CHECK: #[[$BLOCKED:.*]] = #ttg.blocked +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func @cvt_from_dot_op_into_local_allow_not_canonicalized(%in: tensor<256x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>) -> !ttg.memdesc<256x32xf32, #shared1, #smem> { + // CHECK-LABEL: cvt_from_dot_op_into_local_allow_not_canonicalized + %cvt_in = ttg.convert_layout %in : tensor<256x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<256x32xf32, #blocked> + %alloc = ttg.local_alloc %cvt_in : (tensor<256x32xf32, #blocked>) -> !ttg.memdesc<256x32xf32, #shared1, #smem> + // CHECK: %[[ALLOC:.*]] = ttg.local_alloc {{.*}} (tensor<{{.*}}, #[[$BLOCKED]]{{.*}}>) -> + tt.return %alloc : !ttg.memdesc<256x32xf32, #shared1, #smem> + } +} // end module + diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir index 208516b3bfab..481e982cd8cb 100644 --- a/test/TritonGPU/prefetch.mlir +++ b/test/TritonGPU/prefetch.mlir @@ -244,3 +244,23 @@ tt.func @matmul_loop_mixed_amd(%lb : index, %ub : index, %step : index, %A : !tt tt.return %loop#4 : tensor<128x128xf32, #C> } } // end module + + // ----- + +// CHECK: tt.func @matmul_loop_on_blocked_layout +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func @matmul_loop_on_blocked_layout(%arg_lhs: !ttg.memdesc<16x512xf32, #shared, #smem, mutable>, %arg_rhs: !ttg.memdesc<512x32xf32, #shared, #smem, mutable>, %arg_init: tensor<16x32xf32, #blocked>, %itr_val : i32) -> (tensor<16x32xf32, #blocked>) { + %loop:3 = scf.for %itr = %itr_val to %itr_val step %itr_val iter_args(%init = %arg_init, %lhs = %arg_lhs, %rhs = %arg_rhs) -> (tensor<16x32xf32, #blocked>, !ttg.memdesc<16x512xf32, #shared, #smem, mutable>, !ttg.memdesc<512x32xf32, #shared, #smem, mutable>) : i32 { + %lhs_ll = ttg.local_load %lhs : !ttg.memdesc<16x512xf32, #shared, #smem, mutable> -> tensor<16x512xf32, #blocked> + %lhs_ll_cvt = ttg.convert_layout %lhs_ll : tensor<16x512xf32, #blocked> -> tensor<16x512xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %rhs_ll = ttg.local_load %rhs : !ttg.memdesc<512x32xf32, #shared, #smem, mutable> -> tensor<512x32xf32, #blocked> + %rhs_ll_cvt = ttg.convert_layout %rhs_ll : tensor<512x32xf32, #blocked> -> tensor<512x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %res = tt.dot %lhs_ll_cvt, %rhs_ll_cvt, %init : tensor<16x512xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<512x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x32xf32, #blocked> + scf.yield %res, %lhs, %rhs : tensor<16x32xf32, #blocked>, !ttg.memdesc<16x512xf32, #shared, #smem, mutable>, !ttg.memdesc<512x32xf32, #shared, #smem, mutable> + } + tt.return %loop#0 : tensor<16x32xf32, #blocked> + } +} // end module diff --git a/third_party/amd/BUILD b/third_party/amd/BUILD new file mode 100644 index 000000000000..5574fcafc9af --- /dev/null +++ b/third_party/amd/BUILD @@ -0,0 +1,248 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # "//third_party/tensorflow/compiler/xla/service/gpu/fusions/triton:__subpackages__", + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +# TODO(csigg): fix, enable error upstream, remove. +_no_unused_variable = select({ + "//:compiler_is_msvc": [], + "//conditions:default": ["-Wno-unused-variable"], +}) + +cc_library( + name = "TritonAMDGPUTransforms", + srcs = glob([ + "lib/TritonAMDGPUTransforms/**/*.h", + "lib/TritonAMDGPUTransforms/**/*.cpp", + ]) + [ + "include/TritonAMDGPUToLLVM/TargetUtils.h", + "lib/TritonAMDGPUToLLVM/SchedInstructions.h", + ], + hdrs = glob([ + "include/TritonAMDGPUTransforms/**/*.h", + ]), + copts = _no_unused_variable, + includes = [ + "include", + "lib/TritonAMDGPUTransforms", + ], + deps = [ + ":TritonAMDGPU", + ":triton_conversion_amdgpu_transforms_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConvertToLLVM", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonGPUTransforms", + ], +) + +cc_library( + name = "TritonAMDGPU", + srcs = glob([ + "lib/Dialect/TritonAMDGPU/**/*.h", + "lib/Dialect/TritonAMDGPU/**/*.cpp", + ]), + hdrs = glob([ + "include/Dialect/TritonAMDGPU/**/*.h", + ]), + includes = [ + "..", + "include", + ], + deps = [ + ":triton_amdgpu_attr_def_inc_gen", + ":triton_amdgpu_dialect_inc_gen", + ":triton_amdgpu_ops_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:TensorDialect", + "//:TritonDialects", + "//:TritonGPUToLLVM", + ], +) + +cc_library( + name = "TritonAMDGPUToLLVM", + srcs = glob([ + "lib/TritonAMDGPUToLLVM/**/*.h", + "lib/TritonAMDGPUToLLVM/**/*.cpp", + # TritonAMDGPUToLLVM and TritonAMDGPUDialectToLLVM have interdependencies, easiest way to + # deal with circular dependencies is to just compile both in a single unit. + "lib/TritonAMDGPUDialectToLLVM/**/*.h", + "lib/TritonAMDGPUDialectToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/TritonAMDGPUToLLVM/**/*.h", + ]), + copts = _no_unused_variable, + includes = [ + "include", + "lib/TritonAMDGPUToLLVM", + ], + deps = [ + ":TritonAMDGPU", + ":TritonAMDGPUTransforms", + ":triton_conversion_amdgpu_to_llvm_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + "@llvm-project//mlir:AMDGPUDialect", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:ConvertToLLVM", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:GPUToROCDLTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLDialect", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBToLLVM", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + ], +) + +td_library( + name = "td_files", + srcs = glob(["include/**/*.td"]), + includes = ["include"], + deps = ["//:td_files"], +) + +gentbl_cc_library( + name = "triton_amdgpu_ops_inc_gen", + tbl_outs = [ + ( + [ + "--gen-llvmir-conversions", + ], + "include/Dialect/TritonAMDGPU/IR/OpsConversions.inc", + ), + ( + [ + "--gen-op-decls", + ], + "include/Dialect/TritonAMDGPU/IR/Ops.h.inc", + ), + ( + [ + "--gen-op-defs", + ], + "include/Dialect/TritonAMDGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_amdgpu_dialect_inc_gen", + tbl_outs = [ + ( + [ + "--gen-dialect-decls", + "--dialect=amdgpu", + ], + "include/Dialect/TritonAMDGPU/IR/Dialect.h.inc", + ), + ( + [ + "--gen-dialect-defs", + "--dialect=amdgpu", + ], + "include/Dialect/TritonAMDGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_amdgpu_attr_def_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_amdgpu_to_llvm_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonAMDGPUToLLVM", + ], + "include/TritonAMDGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonAMDGPUToLLVM/Passes.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_amdgpu_transforms_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonAMDGPU", + ], + "include/TritonAMDGPUTransforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonAMDGPUTransforms/Passes.td", + deps = [":td_files"], +) diff --git a/third_party/f2reduce/BUILD b/third_party/f2reduce/BUILD new file mode 100644 index 000000000000..93829539e1b9 --- /dev/null +++ b/third_party/f2reduce/BUILD @@ -0,0 +1,31 @@ +# copybara:uncomment load("//tools/build_defs/license:license.bzl", "license") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +# copybara:uncomment_begin +# license( +# name = "license", +# license_text = "LICENCE.txt", +# ) +# +# licenses(["notice"]) +# +# exports_files(["LICENCE.txt"]) +# copybara:uncomment_end + +cc_library( + name = "f2reduce", + srcs = ["f2reduce.cpp"], + hdrs = ["f2reduce.h"], + # copybara:uncomment strip_include_prefix = "/third_party/triton", +) diff --git a/third_party/nvidia/BUILD b/third_party/nvidia/BUILD new file mode 100644 index 000000000000..8b654176d6cd --- /dev/null +++ b/third_party/nvidia/BUILD @@ -0,0 +1,308 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@pybind11_bazel//:build_defs.bzl", "pybind_library") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # "//third_party/tensorflow/compiler/xla/service/gpu:__subpackages__", + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +pybind_library( + name = "cublas_headers", + hdrs = glob([ + "include/*.h", + ]), + deps = ["@local_config_cuda//cuda:cuda_headers"], +) + +pybind_library( + name = "triton_nvidia", + srcs = [ + "triton_nvidia.cc", + ], + compatible_with = [], + # copybara:uncomment_begin + # visibility = [ + # "@triton//python:__subpackages__", + # ], + # copybara:uncomment_end + deps = [ + ":NVGPUDialect", + ":NVGPUToLLVM", + ":TritonNVIDIAGPUToLLVM", + ":cublas_headers", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonNvidiaGPUTransforms", + "@triton//python:passes", + ], +) + +cc_library( + name = "NVGPUToLLVM", + srcs = glob([ + "lib/NVGPUToLLVM/*.cpp", + ]), + hdrs = glob([ + "include/NVGPUToLLVM/*.h", + ]), + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:non_prod"], + # copybara:uncomment_end + copts = select({ + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + includes = [ + "..", + "include", + ], + deps = [ + ":NVGPUDialect", + ":TritonNVIDIAGPUToLLVM", + ":triton_conversion_nvgpu_to_llvm_passes_inc_gen", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + ], +) + +cc_library( + name = "TritonNVIDIAGPUToLLVM", + srcs = glob([ + "lib/TritonNVIDIAGPUToLLVM/*.h", + "lib/TritonNVIDIAGPUToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/TritonNVIDIAGPUToLLVM/*.h", + ]) + [ + "lib/TritonNVIDIAGPUToLLVM/TargetInfo.h", + "lib/TritonNVIDIAGPUToLLVM/Utility.h", + ], + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:non_prod"], + # copybara:uncomment_end + copts = select({ + "//conditions:default": [ + "-Wno-reorder-ctor", + "-Wno-unused-variable", + ], + }), + includes = [ + "..", + "include", + "lib/TritonNVIDIAGPUToLLVM", + ], + deps = [ + ":NVGPUDialect", + ":triton_conversion_triton_nvidia_gpu_to_llvm_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBToLLVM", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:triton_gpu_attr_inc_gen", + ], +) + +gentbl_cc_library( + name = "triton_conversion_nvgpu_to_llvm_passes_inc_gen", + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:non_prod"], + # copybara:uncomment_end + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=NVGPUToLLVM", + ], + "include/NVGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/NVGPUToLLVM/Passes.td", + deps = ["//:td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_nvidia_gpu_to_llvm_passes_inc_gen", + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:non_prod"], + # copybara:uncomment_end + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonNVIDIAGPUToLLVM", + ], + "include/TritonNVIDIAGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonNVIDIAGPUToLLVM/Passes.td", + deps = ["//:td_files"], +) + +td_library( + name = "td_files", + srcs = glob(["include/Dialect/NVGPU/IR/*.td"]), + includes = ["include"], + deps = [ + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:CastInterfacesTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LLVMOpsTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@llvm-project//mlir:ViewLikeInterfaceTdFiles", + ], +) + +gentbl_cc_library( + name = "nvgpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-llvmir-conversions"], + "include/Dialect/NVGPU/IR/OpsConversions.inc", + ), + ( + ["--gen-op-decls"], + "include/Dialect/NVGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/Dialect/NVGPU/IR/Ops.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/Dialect/NVGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/Dialect/NVGPU/IR/OpsEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/NVGPU/IR/NVGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "nvgpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/NVGPU/IR/NVGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "nvgpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/Dialect/NVGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/Dialect/NVGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/NVGPU/IR/NVGPUDialect.td", + deps = ["td_files"], +) + +cc_library( + name = "NVGPUDialect", + srcs = glob([ + "lib/Dialect/NVGPU/IR/*.cpp", + ]), + hdrs = glob([ + "include/Dialect/NVGPU/IR/*.h", + ]), + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + "-Wno-logical-op-parentheses", + ], + }), + includes = [ + "..", # because nvidia/include/Dialect/NVGPU/IR/Dialect.h.inc + "../..", # because third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h + "include", + ], + deps = [ + ":nvgpu_attr_inc_gen", + ":nvgpu_dialect_inc_gen", + ":nvgpu_ops_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + # The following is added to make Utility compile + "//:TritonTools", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/third_party/nvidia/backend/BUILD b/third_party/nvidia/backend/BUILD new file mode 100644 index 000000000000..a5b34aa5c29b --- /dev/null +++ b/third_party/nvidia/backend/BUILD @@ -0,0 +1,30 @@ +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//third_party/py/triton:__subpackages__", + ], +) + +pybind_extension( + name = "cuda_utils", + srcs = ["cuda_utils.cc"], + visibility = [ + "//learning/deepmind/jax/triton/ops:__subpackages__", + "//third_party/py/triton:__subpackages__", + ], + deps = [ + "//platforms/gpus/cuda/dynamic_libcuda", + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cuda_runtime", + "@llvm-project//llvm:Support", + ], +) + +filegroup( + name = "files", + srcs = glob( + include = ["**/*.py"], + ), +) diff --git a/third_party/nvidia/backend/cuda_utils.cc b/third_party/nvidia/backend/cuda_utils.cc new file mode 100644 index 000000000000..abbbdb44d701 --- /dev/null +++ b/third_party/nvidia/backend/cuda_utils.cc @@ -0,0 +1,896 @@ +#define PY_SSIZE_T_CLEAN +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cuda.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +namespace { + +struct UniquePyObjectDeleter { + void operator()(PyObject* obj) { Py_DECREF(obj); } +}; +// A unique_ptr for PyObjects that automatically calls Py_DECREF once it goes +// out of scope. +using UniquePyObjectPtr = std::unique_ptr; + +// Raise a python exception if the CUDA result code is not CUDA_SUCCESS. +// Can be called even on threads that do not hold Python's Global Interpreter +// Lock (GIL), as the function will acquire one if needed. +inline bool gpuAssert(CUresult code, const char* file, int line) { + if (code == CUDA_SUCCESS) + return true; + const char* error = nullptr; + cuGetErrorString(code, &error); + PyGILState_STATE gil_state = PyGILState_Ensure(); + PyErr_Format(PyExc_RuntimeError, "Triton Error [CUDA]: %s", error); + PyGILState_Release(gil_state); + return false; +} + +// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block. +#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} + +#define CUDA_CHECK_AND_RETURN_NULL(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) \ + return NULL; \ + } while (0) + +// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block. +#define CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) { \ + PyEval_RestoreThread(_save); \ + return NULL; \ + } \ + } while (0) + +// Used to check if functions exist in old CUDA driver versions. +#define INITIALIZE_FUNCTION_POINTER_IF_NULL(funcPointer, initializerFunction) \ + do { \ + if ((funcPointer) == NULL) { \ + (funcPointer) = (initializerFunction)(); \ + if ((funcPointer) == NULL) { \ + return NULL; \ + } \ + } \ + } while (0) + +using cuLaunchKernelEx_t = CUresult (*)(const CUlaunchConfig* config, + CUfunction f, void** kernelParams, + void** extra); + +// Dynamically load the handle to cuLaunchKernelEx. +cuLaunchKernelEx_t getLaunchKernelExHandle() { + // Open the shared library + void* handle = dlopen("libcuda.so.1", RTLD_LAZY); + if (!handle) { + PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so"); + return nullptr; + } + // Clear any existing error + dlerror(); + auto cuLaunchKernelExHandle = + reinterpret_cast(dlsym(handle, "cuLaunchKernelEx")); + // Check for errors + if (const char* dlsym_error = dlerror()) { + PyErr_Format(PyExc_RuntimeError, + "Failed to retrieve cuLaunchKernelEx from libcuda.so: %s", + dlsym_error); + return nullptr; + } + return cuLaunchKernelExHandle; +} + +// Configuration with all the information necessary to launch a compiled +// Triton kernel using the CUDA driver API. +struct TritonLaunchConfig { + // Represents CUDA's 3D ID structure of grids and clusters + struct Dim { + int x; + int y; + int z; + constexpr int size() const { return x * y * z; } + }; + Dim grid; // Number of clusters per grid + Dim cluster; // Number of blocks per cluster + int num_warps; // number of warps per block + int shared_memory; // Size of shared memory in bytes to allocate + CUstream stream; // CUDA Stream on which to launch the kernel + CUfunction function; // Pointer to the kernel to launch + void** params; // Parameters to pass to the kernel +}; + +// Launch a CUDA kernel with the given parameters. Raises a Python exception +// if the kernel launch fails. +PyObject* launchKernel(const TritonLaunchConfig& config) { + const auto& grid = config.grid; + const auto& cluster = config.cluster; + if (grid.size() == 0) { + Py_RETURN_NONE; + } + if (cluster.size() == 1) { + CUDA_CHECK_AND_RETURN_NULL(cuLaunchKernel( + config.function, grid.x, grid.y, grid.z, 32 * config.num_warps, 1, 1, + config.shared_memory, config.stream, config.params, 0)); + } else { + CUlaunchAttribute launchAttr[2]; + launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + launchAttr[0].value.clusterDim.x = cluster.x; + launchAttr[0].value.clusterDim.y = cluster.y; + launchAttr[0].value.clusterDim.z = cluster.z; + launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; + launchAttr[1].value.clusterSchedulingPolicyPreference = + CU_CLUSTER_SCHEDULING_POLICY_SPREAD; + CUlaunchConfig cu_config; + cu_config.gridDimX = grid.x * cluster.x; + cu_config.gridDimY = grid.y * cluster.y; + cu_config.gridDimZ = grid.z * cluster.z; + cu_config.blockDimX = 32 * config.num_warps; + cu_config.blockDimY = 1; + cu_config.blockDimZ = 1; + cu_config.sharedMemBytes = config.shared_memory; + cu_config.hStream = config.stream; + cu_config.attrs = launchAttr; + cu_config.numAttrs = 2; + // cuLaunchKernelEx was added in CUDA 12, so load it dynamically to be + // able to link on CUDA 11 and earlier. + static cuLaunchKernelEx_t cuLaunchKernelExHandle = + getLaunchKernelExHandle(); + CUDA_CHECK_AND_RETURN_NULL( + cuLaunchKernelExHandle(&cu_config, config.function, config.params, 0)); + } + Py_RETURN_NONE; +} + +// Interface used by various PyObject extractors to extract obj into a memory +// location pointed by ptr. Returns true if extraction succeeded, and false +// otherwise. +using ExtractorType = bool (*)(PyObject* obj, void* ptr); + +// Extract a CUDA device pointer from a pointer-like PyObject obj, and store +// it to the memory location pointed by ptr. +bool extractPointer(PyObject* obj, void* ptr) { + auto dev_ptr = static_cast(ptr); + if (obj == Py_None) { + *dev_ptr = static_cast(0); // valid nullptr + return true; + } + if (PyLong_Check(obj)) { + *dev_ptr = PyLong_AsUnsignedLongLong(obj); + return true; + } + UniquePyObjectPtr ret(PyObject_CallMethod(obj, "data_ptr", nullptr)); + if (!ret.get()) { + PyErr_Format(PyExc_TypeError, + "Pointer argument must be either uint64 or have data_ptr " + "method, but got %R", + obj); + return false; + } + if (!PyLong_Check(ret.get())) { + PyErr_SetString(PyExc_TypeError, + "data_ptr method of Pointer object must return 64-bit int"); + return false; + } + *dev_ptr = PyLong_AsUnsignedLongLong(ret.get()); + if (PyErr_Occurred()) { + return false; + } + if (*dev_ptr == 0) { + return true; // valid nullptr + } + CUresult status = cuPointerGetAttribute( + dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, *dev_ptr); + if (status == CUDA_ERROR_INVALID_VALUE) { + PyErr_Format(PyExc_ValueError, + "Pointer argument cannot be accessed from Triton " + "(cpu tensor?)"); + return false; + } else if (status != CUDA_SUCCESS) { + CUDA_CHECK(status); + return false; + } + return true; +} + +// For a given type T, maps to the Python API with signature `U (*)(PyObject*)` +// that can extract values of that type from a PyObject. Note that the return +// type U is not guaranteed to be the same as T, but it can be explicitly casted +// to T. +template +constexpr auto kValueFunction = nullptr; +template +constexpr auto + kValueFunction && + std::is_signed_v && sizeof(T) <= 4>> = + PyLong_AsLong; +template <> +constexpr auto kValueFunction = PyLong_AsLongLong; +template +constexpr auto kValueFunction< + T, std::enable_if_t && std::is_unsigned_v && + sizeof(T) <= 4>> = PyLong_AsUnsignedLong; +template <> +constexpr auto kValueFunction = PyLong_AsUnsignedLongLong; +template +constexpr auto + kValueFunction>> = + PyFloat_AsDouble; + +// Extract a value of type T from obj and store it into memory pointed by ptr. +// Returns true if extraction succeeded, and false otherwise. +template +bool extractValue(PyObject* obj, void* ptr) { + *static_cast(ptr) = static_cast(kValueFunction(obj)); + return PyErr_Occurred() == nullptr; +} + +// Contains information necessary for extracting a certain type from a PyObject. +struct ExtractionInfo { + // Prefixes of types reprs supported by the extractor. + llvm::SmallVector supported_type_repr_prefixes; + std::size_t size; // Size required by the extracted value. + ExtractorType extractor; // Function to call to extract the value. + + // Builds an ExtractionInfo for a given type T and a list of type reprs that + // are backed by that type. + template + static ExtractionInfo build( + std::initializer_list supported_type_reprs, + ExtractorType extractor = extractValue) { + return {supported_type_reprs, sizeof(T), extractor}; + } + + // Checks if the extractor supports extracting a given type repr. + bool supports(llvm::StringRef type_repr) const { + return llvm::any_of( + supported_type_repr_prefixes, + [&](llvm::StringRef prefix) { return type_repr.starts_with(prefix); }); + } +}; + +// Array of supported extractors +const ExtractionInfo kExtractionInfos[]{ + ExtractionInfo::build({"'i8'"}), + ExtractionInfo::build({"'i16'"}), + ExtractionInfo::build({"'i1'", "'i32'"}), + ExtractionInfo::build({"'i64'"}), + ExtractionInfo::build({"'u8'"}), + ExtractionInfo::build({"'u16'"}), + ExtractionInfo::build({"'u1'", "'u32'"}), + ExtractionInfo::build({"'u64'"}), + ExtractionInfo::build({"'fp16'", "'bf16'", "'fp32'", "'f32'"}), + ExtractionInfo::build({"'fp64'"}), + ExtractionInfo::build({"'*"}, extractPointer), + ExtractionInfo{{"None"}, 0, nullptr}, // Represent constexprs as None +}; + +// Finds an extractor that supports a given type_repr in the extractor list. +// Returns nullopt if no such extractor is found. +std::optional findExtractor(llvm::StringRef type_repr) { + constexpr std::size_t kNumExtractors = std::size(kExtractionInfos); + static_assert(kNumExtractors < std::numeric_limits::max(), + "Not enough bits in a byte to store the extractor index"); + for (const auto& [idx, info] : llvm::enumerate(kExtractionInfos)) { + if (info.supports(type_repr)) return idx; + } + return std::nullopt; +} + +PyDoc_STRVAR(buildSignatureMetadata__doc__, + R"(buildSignatureMetadata(signature_iterator) -> bytes + +Build a metadata object describing the signature of a kernel. + +This can then be passed as the signature_metadata parameter to the launch() +function. + +:param signature: list of types describing the signature of a kernel, + specialized parameters should be represented with None +:type signature: sequence or iterable +:return: an opaque metadata object which can then be passed to launch() +:rtype: bytes +)"); +PyObject* buildSignatureMetadata(PyObject* self, PyObject* args) { + PyObject* signature = nullptr; + if (!PyArg_ParseTuple(args, "O", &signature)) { + return nullptr; + } + if (!PyIter_Check(signature)) { + PyErr_Format(PyExc_TypeError, + "expected signature to be an iterable, got %R", signature); + return nullptr; + } + + llvm::SmallVector signature_metadata; + while (UniquePyObjectPtr obj_type{PyIter_Next(signature)}) { + UniquePyObjectPtr repr(PyObject_Repr(obj_type.get())); + if (!repr) { + return nullptr; + } + UniquePyObjectPtr repr_str( + PyUnicode_AsEncodedString(repr.get(), "utf-8", "~E~")); + if (!repr_str) { + return nullptr; + } + const char* repr_bytes = PyBytes_AsString(repr_str.get()); + if (!repr_bytes) { + return nullptr; + } + std::optional extractor_idx = findExtractor(repr_bytes); + if (!extractor_idx.has_value()) { + PyErr_Format(PyExc_TypeError, + "unexpected type %R in kernel signature, dir: %R", + obj_type.get(), PyObject_Dir(obj_type.get())); + return nullptr; + } + signature_metadata.push_back(extractor_idx.value()); + } + if (PyErr_Occurred()) { + return nullptr; + } + + return PyBytes_FromStringAndSize(signature_metadata.data(), + signature_metadata.size()); +} + +// Launch a Python callable hook with metadata passed as parameters. +bool launchHook(PyObject* hook, PyObject* metadata) { + if (hook == Py_None) { + return true; + } + UniquePyObjectPtr args(Py_BuildValue("(O)", metadata)); + if (!args) { + return false; + } + UniquePyObjectPtr ret(PyObject_CallObject(hook, args.get())); + return static_cast(ret); +} + +static void ensureCudaContext() { + CUcontext pctx; + CUDA_CHECK(cuCtxGetCurrent(&pctx)); + if (!pctx) { + // Ensure device context. + CUdevice device; + CUDA_CHECK(cuDeviceGet(&device, 0)); + CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); + CUDA_CHECK(cuCtxSetCurrent(pctx)); + } +} + +PyDoc_STRVAR( + launch__doc__, + R"(launch(gridDimX, gridDimY, gridDimZ, stream, kernel, packed_metadata, launch_metadata, launch_enter_hook, launch_exit_hook, kernel_arg_types, global_scratch, kernel_args) + +Launch a kernel on an Nvidia GPU. + +:param gridDimX: X dimension of the grid +:type gridDimX: signed integer +:param gridDimY: Y dimension of the grid +:type gridDimY: signed integer +:param gridDimZ: Z dimension of the grid +:type gridDimZ: signed integer +:param stream: CUDA Stream to launch on +:type stream: unsigned long integer (pointer) +:param kernel: CUDA kernel to launch +:type kernel: unsigned long integer (pointer) +:param packed_metadata: Kernel metadata, including in sequence: + number of warps, number of CTAs, required bytes of shared memory, + cluster dimensions x, y, and z +:type packed_metadata: 6-tuple +:param hook_args: arguments to pass to the enter and exit hooks +:type hook_args: object +:param launch_enter_hook: hook to call just before launching the kernel +:type launch_enter_hook: callable +:param launch_exit_hook: hook to call just after launching the kernel +:type launch_exit_hook: callable +:param signature_metadata: matadata built from build_signature_metadata +:type signature_metadata: bytes +:param global_scratch: pointer to global scratch memory +:type global_scratch: pointer +:param kernel_args: kernel parameters +:type kernel_args: tuple + +:raises RuntimeError: on kernel launch failure +)"); +PyObject* launch(PyObject* self, PyObject* args) { + ensureCudaContext(); + TritonLaunchConfig config{}; + auto& grid = config.grid; + auto& cluster = config.cluster; + // PyObject* kernel_metadata = nullptr; + PyObject* hook_args = nullptr; + PyObject* launch_enter_hook = nullptr; + PyObject* launch_exit_hook = nullptr; + PyBytesObject* signature_metadata_bytes = nullptr; + PyObject* kernel_args = nullptr; + PyObject* global_scratch = nullptr; + int num_ctas = 0; + if (!PyArg_ParseTuple(args, "iiiKK(iiiiii)OOOSOO", &grid.x, &grid.y, &grid.z, + &config.stream, &config.function, &config.num_warps, + &num_ctas, &config.shared_memory, &cluster.x, + &cluster.y, &cluster.z, &hook_args, &launch_enter_hook, + &launch_exit_hook, &signature_metadata_bytes, + &global_scratch, &kernel_args)) { + return nullptr; + } + if (num_ctas != cluster.size()) { + PyErr_Format( + PyExc_ValueError, + "Expected cluster dimensions (%d, %d, %d) to have a total size of %d", + cluster.x, cluster.y, cluster.z, num_ctas); + return nullptr; + } + llvm::ArrayRef signature_metadata( + PyBytes_AS_STRING(signature_metadata_bytes), + PyBytes_GET_SIZE(signature_metadata_bytes)); + UniquePyObjectPtr fast_kernel_args(PySequence_Fast( + kernel_args, "Expected kernel_args to be a sequence or iterable")); + if (!fast_kernel_args) { + return nullptr; + } + llvm::ArrayRef kernel_args_data( + PySequence_Fast_ITEMS(fast_kernel_args.get()), + PySequence_Fast_GET_SIZE(fast_kernel_args.get())); + + if (signature_metadata.size() != kernel_args_data.size()) { + PyErr_Format(PyExc_TypeError, + "Expected kernel to have %d parameters, but got %d", + signature_metadata.size(), kernel_args_data.size()); + return nullptr; + } + + // +1 for the global scratch pointer. + std::size_t num_params = signature_metadata.size() + 1; + // Use alloca to set up kernel parameters on the stack and avoid dynamic + // memory allocations. + config.params = static_cast(alloca(num_params * sizeof(void*))); + // This loop has to stay in the same function that owns params, since we are + // using alloca to allocate pointers to it on the stack of the function. + std::size_t params_idx = 0; + for (const auto& [converter_idx, arg] : + llvm::zip(signature_metadata, kernel_args_data)) { + if (converter_idx >= std::size(kExtractionInfos)) { + PyErr_SetString(PyExc_ValueError, "corrupted signature metadata"); + return nullptr; + } + const ExtractionInfo& extraction_info = kExtractionInfos[converter_idx]; + if (extraction_info.size == 0) { + continue; // skip adding constexpr parameters + } + config.params[params_idx] = alloca(extraction_info.size); + if (!extraction_info.extractor(arg, config.params[params_idx])) { + return nullptr; + } + ++params_idx; + } + config.params[params_idx] = alloca(sizeof(void*)); + if (!extractPointer(global_scratch, config.params[params_idx])) { + return nullptr; + } + + if (!launchHook(launch_enter_hook, hook_args)) { + return nullptr; + } + + // Launching the kernel might take a while and does not use Python APIs, so + // we can release the Global Interpreter Lock so other threads can use Python + // APIs if needed. + PyObject* result = nullptr; + Py_BEGIN_ALLOW_THREADS; + result = launchKernel(config); + Py_END_ALLOW_THREADS; + if (!result) { + return nullptr; + } + + if (!launchHook(launch_exit_hook, hook_args)) { + return nullptr; + } + + Py_RETURN_NONE; +} + +} // namespace + +static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { + int device_id; + if (!PyArg_ParseTuple(args, "i", &device_id)) + return NULL; + // Get device handle + CUdevice device; + cuDeviceGet(&device, device_id); + + // create a struct to hold device properties + int max_shared_mem; + int max_num_regs; + int multiprocessor_count; + int warp_size; + int sm_clock_rate; + int mem_clock_rate; + int mem_bus_width; + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &max_num_regs, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device)); + CUDA_CHECK_AND_RETURN_NULL( + cuDeviceGetAttribute(&warp_size, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device)); + + return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", + max_shared_mem, "max_num_regs", max_num_regs, + "multiprocessor_count", multiprocessor_count, "warpSize", + warp_size, "sm_clock_rate", sm_clock_rate, + "mem_clock_rate", mem_clock_rate, "mem_bus_width", + mem_bus_width); +} + +static PyObject *loadBinary(PyObject *self, PyObject *args) { + const char *name; + const char *data; + Py_ssize_t data_size; + int shared; + CUdevice device; + if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, + &device)) { + return NULL; + } + CUfunction fun; + CUmodule mod; + int32_t n_regs = 0; + int32_t n_spills = 0; + // create driver handles + CUcontext pctx = 0; + + Py_BEGIN_ALLOW_THREADS; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&pctx)); + if (!pctx) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGet(&device, 0)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuDevicePrimaryCtxRetain(&pctx, device)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(pctx)); + } + + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuModuleLoadData(&mod, data)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuModuleGetFunction(&fun, mod, name)); + // get allocated registers and spilled registers from the function + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); + n_spills /= 4; + // set dynamic shared memory if necessary + int shared_optin; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute( + &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + device)); + if (shared > 49152 && shared_optin > 49152) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED)); + int shared_total, shared_static; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute( + &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, + device)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute( + &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + shared_optin - shared_static)); + } + Py_END_ALLOW_THREADS; + + if (PyErr_Occurred()) { + return NULL; + } + return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, + n_spills); +} + +typedef CUresult (*cuOccupancyMaxActiveClusters_t)( + int *numClusters, CUfunction func, const CUlaunchConfig *config); + +#if CUDA_VERSION >= 12000 +typedef CUresult (*cuTensorMapEncodeTiled_t)( + CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, + const cuuint64_t *globalStrides, const cuuint32_t *boxDim, + const cuuint32_t *elementStrides, CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, + CUtensorMapFloatOOBfill oobFill); +#endif + +#define defineGetFunctionHandle(name, symbolName) \ + static symbolName##_t name() { \ + /* Open the shared library */ \ + void *libHandle = dlopen("libcuda.so.1", RTLD_LAZY); \ + if (!libHandle) { \ + PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1"); \ + return NULL; \ + } \ + /* Clear any existing error */ \ + dlerror(); \ + symbolName##_t funcHandle = (symbolName##_t)dlsym(libHandle, #symbolName); \ + /* Check for errors */ \ + const char *err = dlerror(); \ + if (err) { \ + PyErr_SetString(PyExc_RuntimeError, \ + "Failed to retrieve " #symbolName " from libcuda.so.1"); \ + dlclose(libHandle); \ + return NULL; \ + } \ + return funcHandle; \ + } + +defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle, + cuOccupancyMaxActiveClusters); + +#if CUDA_VERSION >= 12000 +defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle, + cuTensorMapEncodeTiled); +#endif + +static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { + int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1, + maxActiveClusters = -1; + int shared = 0; + CUfunction func; + + if (!PyArg_ParseTuple(args, "Kiiii", &func, &shared, &clusterDimX, + &clusterDimY, &clusterDimZ)) { + return NULL; + } + + // Let each SM have one block + int maxActiveBlocks = 1; + Py_BEGIN_ALLOW_THREADS; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute( + func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared)); + Py_END_ALLOW_THREADS; + + CUlaunchAttribute launchAttr[1]; + launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + launchAttr[0].value.clusterDim.x = clusterDimX; + launchAttr[0].value.clusterDim.y = clusterDimY; + launchAttr[0].value.clusterDim.z = clusterDimZ; + CUlaunchConfig config; + config.gridDimX = clusterDimX; + config.gridDimY = maxActiveBlocks * clusterDimY; + config.gridDimZ = clusterDimZ; + config.blockDimX = 128; + config.blockDimY = 1; + config.blockDimZ = 1; + config.sharedMemBytes = shared; + config.hStream = 0; + config.numAttrs = 1; + config.attrs = launchAttr; + + static cuOccupancyMaxActiveClusters_t cuOccupancyMaxActiveClusters = NULL; + INITIALIZE_FUNCTION_POINTER_IF_NULL(cuOccupancyMaxActiveClusters, + getCuOccupancyMaxActiveClustersHandle); + + Py_BEGIN_ALLOW_THREADS; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute( + func, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &config)); + Py_END_ALLOW_THREADS; + return PyLong_FromLong(maxActiveClusters); +} + +static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) { + long size; + if (!PyArg_ParseTuple(args, "l", &size)) { + return NULL; + } + if (size < 0) { + PyErr_SetString(PyExc_ValueError, "fifo size must be non-negative"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS; + + // Ensure we have an active context. + CUcontext ctx = NULL; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&ctx)); + if (!ctx) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuDevicePrimaryCtxRetain(&ctx, /*device=*/0)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(ctx)); + } + + // We can't set the fifo size after running a kernel that calls printf. This + // is true even if the set() call is a nop and the new size is the same as the + // old size. + // + // This is unfriendly, so check if the old size matches the new size, and skip + // the set() call if so. + size_t oldSize = 0; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuCtxGetLimit(&oldSize, CU_LIMIT_PRINTF_FIFO_SIZE)); + if (oldSize != size) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, size)); + } + + Py_END_ALLOW_THREADS; + Py_INCREF(Py_None); + return Py_None; +} + +// Simple helper to experiment creating TMA descriptors on the host. +// This is a useful to test TMA operations independently. +static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { +#if CUDA_VERSION < 12000 + return NULL; +#else + unsigned long long global_address; + uint64_t dim; + uint32_t tensorDim; + int elementSize; + unsigned long long desc_address; + if (!PyArg_ParseTuple(args, "KKiiK", &global_address, &dim, &tensorDim, + &elementSize, &desc_address)) { + return NULL; + } + uint64_t dims[1] = {dim}; + uint64_t globalStrides[1] = {dim * elementSize}; + uint32_t boxDim[1] = {tensorDim}; + uint32_t elementStrides[1] = {1}; + CUtensorMapDataType type; + switch (elementSize) { + case 1: + type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 2: + type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 4: + type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + default: + PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4"); + return NULL; + } + assert((elementSize * tensorDim) >= 32 && "block size too small."); + int rank = 1; + static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; + INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, + getCuTensorMapEncodeTiledHandle); + CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled( + (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, + globalStrides, boxDim, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + Py_INCREF(Py_None); + return Py_None; +#endif +} + +// Simple helper to experiment creating TMA descriptors on the host. +// This is a useful to test TMA operations independently. +static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { +#if CUDA_VERSION < 12000 + return NULL; +#else + unsigned long long global_address; + uint64_t dims[2]; + uint32_t tensorDims[2]; + int elementSize; + unsigned long long desc_address; + if (!PyArg_ParseTuple(args, "KKKiiiK", &global_address, &dims[1], &dims[0], + &tensorDims[1], &tensorDims[0], &elementSize, + &desc_address)) { + return NULL; + } + uint64_t globalStrides[2] = {dims[0] * elementSize, + dims[0] * dims[1] * elementSize}; + uint32_t elementStrides[2] = {1, 1}; + CUtensorMapDataType type; + switch (elementSize) { + case 1: + type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 2: + type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 4: + type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + default: + PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4"); + } + int rank = 2; + // Swizzling should be picked in codegen but since we need to set it on the + // descriptor we rely on a convention between this function and codegen. + CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + uint32_t contigDimSizeInByte = elementSize * tensorDims[0]; + if (contigDimSizeInByte >= 128) { + swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + } else if (contigDimSizeInByte >= 64) { + swizzle = CU_TENSOR_MAP_SWIZZLE_64B; + } else if (contigDimSizeInByte >= 32) { + swizzle = CU_TENSOR_MAP_SWIZZLE_32B; + } else { + assert(false && "block size too small."); + } + // The bounding box inner dimension must be less than or equal to the swizzle + // size. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 + // We clamp the block size and the codegen will emit multiple copy operations. + if (contigDimSizeInByte > 128) { + tensorDims[0] = 128 / elementSize; + } + static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; + INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, + getCuTensorMapEncodeTiledHandle); + CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled( + (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, + globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + Py_INCREF(Py_None); + return Py_None; +#endif +} + +static PyMethodDef ModuleMethods[] = { + {"load_binary", loadBinary, METH_VARARGS, + "Load provided cubin into CUDA driver"}, + {"get_device_properties", getDeviceProperties, METH_VARARGS, + "Get the properties for a given device"}, + {"cuOccupancyMaxActiveClusters", occupancyMaxActiveClusters, METH_VARARGS, + "Python interface for cuOccupancyMaxActiveClusters function"}, + {"set_printf_fifo_size", setPrintfFifoSize, METH_VARARGS, + "Python interface for cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, x), which " + "controls how many bytes can be streamed from kernels before data starts " + "being dropped. This inherits all the limitations of this call; in " + "particular it's an error to change this value after launching any kernel " + "that calls printf()."}, + {"fill_1d_tma_descriptor", fill1DTMADescriptor, METH_VARARGS, "doc"}, + {"fill_2d_tma_descriptor", fill2DTMADescriptor, METH_VARARGS, "doc"}, + {"build_signature_metadata", buildSignatureMetadata, METH_VARARGS, + buildSignatureMetadata__doc__}, + {"launch", launch, METH_VARARGS, launch__doc__}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils", + NULL, // documentation + -1, // size + ModuleMethods}; + +PyMODINIT_FUNC PyInit_cuda_utils(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if (m == NULL) { + return NULL; + } + + PyModule_AddFunctions(m, ModuleMethods); + + return m; +} diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c deleted file mode 100644 index 12deb0d1e7a3..000000000000 --- a/third_party/nvidia/backend/driver.c +++ /dev/null @@ -1,421 +0,0 @@ -#include "cuda.h" -#include -#include -#define PY_SSIZE_T_CLEAN -#include - -// Raises a Python exception and returns false if code is not CUDA_SUCCESS. -static bool gpuAssert(CUresult code, const char *file, int line) { - if (code == CUDA_SUCCESS) - return true; - - const char *prefix = "Triton Error [CUDA]: "; - const char *str; - cuGetErrorString(code, &str); - char err[1024] = {0}; - strcat(err, prefix); - strcat(err, str); - PyGILState_STATE gil_state; - gil_state = PyGILState_Ensure(); - PyErr_SetString(PyExc_RuntimeError, err); - PyGILState_Release(gil_state); - return false; -} - -// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block. -#define CUDA_CHECK_AND_RETURN_NULL(ans) \ - do { \ - if (!gpuAssert((ans), __FILE__, __LINE__)) \ - return NULL; \ - } while (0) - -// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block. -#define CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \ - do { \ - if (!gpuAssert((ans), __FILE__, __LINE__)) { \ - PyEval_RestoreThread(_save); \ - return NULL; \ - } \ - } while (0) - -// Used to check if functions exist in old CUDA driver versions. -#define INITIALIZE_FUNCTION_POINTER_IF_NULL(funcPointer, initializerFunction) \ - do { \ - if ((funcPointer) == NULL) { \ - (funcPointer) = (initializerFunction)(); \ - if ((funcPointer) == NULL) { \ - return NULL; \ - } \ - } \ - } while (0) - -static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { - int device_id; - if (!PyArg_ParseTuple(args, "i", &device_id)) - return NULL; - // Get device handle - CUdevice device; - cuDeviceGet(&device, device_id); - - // create a struct to hold device properties - int max_shared_mem; - int max_num_regs; - int multiprocessor_count; - int warp_size; - int sm_clock_rate; - int mem_clock_rate; - int mem_bus_width; - CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( - &max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, - device)); - CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( - &max_num_regs, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device)); - CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( - &multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device)); - CUDA_CHECK_AND_RETURN_NULL( - cuDeviceGetAttribute(&warp_size, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device)); - CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( - &sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device)); - CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( - &mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device)); - CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( - &mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device)); - - return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", - max_shared_mem, "max_num_regs", max_num_regs, - "multiprocessor_count", multiprocessor_count, "warpSize", - warp_size, "sm_clock_rate", sm_clock_rate, - "mem_clock_rate", mem_clock_rate, "mem_bus_width", - mem_bus_width); -} - -static PyObject *loadBinary(PyObject *self, PyObject *args) { - const char *name; - const char *data; - Py_ssize_t data_size; - int shared; - int device; - if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, - &device)) { - return NULL; - } - CUfunction fun; - CUmodule mod; - int32_t n_regs = 0; - int32_t n_spills = 0; - // create driver handles - CUcontext pctx = 0; - - Py_BEGIN_ALLOW_THREADS; - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&pctx)); - if (!pctx) { - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuDevicePrimaryCtxRetain(&pctx, device)); - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(pctx)); - } - - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuModuleLoadData(&mod, data)); - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuModuleGetFunction(&fun, mod, name)); - // get allocated registers and spilled registers from the function - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun)); - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); - n_spills /= 4; - // set dynamic shared memory if necessary - int shared_optin; - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute( - &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, - device)); - if (shared > 49152 && shared_optin > 49152) { - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED)); - int shared_total, shared_static; - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute( - &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, - device)); - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute( - &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - shared_optin - shared_static)); - } - Py_END_ALLOW_THREADS; - - if (PyErr_Occurred()) { - return NULL; - } - return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, - n_spills); -} - -typedef CUresult (*cuOccupancyMaxActiveClusters_t)( - int *numClusters, CUfunction func, const CUlaunchConfig *config); - -typedef CUresult (*cuTensorMapEncodeTiled_t)( - CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, - cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, - const cuuint64_t *globalStrides, const cuuint32_t *boxDim, - const cuuint32_t *elementStrides, CUtensorMapInterleave interleave, - CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, - CUtensorMapFloatOOBfill oobFill); - -#define defineGetFunctionHandle(name, symbolName) \ - static symbolName##_t name() { \ - /* Open the shared library */ \ - void *libHandle = dlopen("libcuda.so.1", RTLD_LAZY); \ - if (!libHandle) { \ - PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1"); \ - return NULL; \ - } \ - /* Clear any existing error */ \ - dlerror(); \ - symbolName##_t funcHandle = (symbolName##_t)dlsym(libHandle, #symbolName); \ - /* Check for errors */ \ - const char *err = dlerror(); \ - if (err) { \ - PyErr_SetString(PyExc_RuntimeError, \ - "Failed to retrieve " #symbolName " from libcuda.so.1"); \ - dlclose(libHandle); \ - return NULL; \ - } \ - return funcHandle; \ - } - -defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle, - cuOccupancyMaxActiveClusters); - -defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle, - cuTensorMapEncodeTiled); - -static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { - int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1, - maxActiveClusters = -1; - int shared = 0; - CUfunction func; - - if (!PyArg_ParseTuple(args, "Kiiii", &func, &shared, &clusterDimX, - &clusterDimY, &clusterDimZ)) { - return NULL; - } - - // Let each SM have one block - int maxActiveBlocks = 1; - Py_BEGIN_ALLOW_THREADS; - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute( - func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared)); - Py_END_ALLOW_THREADS; - - CUlaunchAttribute launchAttr[1]; - launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; - launchAttr[0].value.clusterDim.x = clusterDimX; - launchAttr[0].value.clusterDim.y = clusterDimY; - launchAttr[0].value.clusterDim.z = clusterDimZ; - CUlaunchConfig config; - config.gridDimX = clusterDimX; - config.gridDimY = maxActiveBlocks * clusterDimY; - config.gridDimZ = clusterDimZ; - config.blockDimX = 128; - config.blockDimY = 1; - config.blockDimZ = 1; - config.sharedMemBytes = shared; - config.hStream = 0; - config.numAttrs = 1; - config.attrs = launchAttr; - - static cuOccupancyMaxActiveClusters_t cuOccupancyMaxActiveClusters = NULL; - INITIALIZE_FUNCTION_POINTER_IF_NULL(cuOccupancyMaxActiveClusters, - getCuOccupancyMaxActiveClustersHandle); - - Py_BEGIN_ALLOW_THREADS; - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute( - func, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)); - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &config)); - Py_END_ALLOW_THREADS; - return PyLong_FromLong(maxActiveClusters); -} - -static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) { - long size; - if (!PyArg_ParseTuple(args, "l", &size)) { - return NULL; - } - if (size < 0) { - PyErr_SetString(PyExc_ValueError, "fifo size must be non-negative"); - return NULL; - } - - Py_BEGIN_ALLOW_THREADS; - - // Ensure we have an active context. - CUcontext ctx = NULL; - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&ctx)); - if (!ctx) { - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuDevicePrimaryCtxRetain(&ctx, /*device=*/0)); - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(ctx)); - } - - // We can't set the fifo size after running a kernel that calls printf. This - // is true even if the set() call is a nop and the new size is the same as the - // old size. - // - // This is unfriendly, so check if the old size matches the new size, and skip - // the set() call if so. - size_t oldSize = 0; - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuCtxGetLimit(&oldSize, CU_LIMIT_PRINTF_FIFO_SIZE)); - if (oldSize != size) { - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, size)); - } - - Py_END_ALLOW_THREADS; - Py_INCREF(Py_None); - return Py_None; -} - -// Simple helper to experiment creating TMA descriptors on the host. -// This is a useful to test TMA operations independently. -static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { - unsigned long long global_address; - uint64_t dim; - uint32_t tensorDim; - int elementSize; - unsigned long long desc_address; - if (!PyArg_ParseTuple(args, "KKiiK", &global_address, &dim, &tensorDim, - &elementSize, &desc_address)) { - return NULL; - } - uint64_t dims[1] = {dim}; - uint64_t globalStrides[1] = {dim * elementSize}; - uint32_t boxDim[1] = {tensorDim}; - uint32_t elementStrides[1] = {1}; - CUtensorMapDataType type; - switch (elementSize) { - case 1: - type = CU_TENSOR_MAP_DATA_TYPE_UINT8; - break; - case 2: - type = CU_TENSOR_MAP_DATA_TYPE_UINT16; - break; - case 4: - type = CU_TENSOR_MAP_DATA_TYPE_UINT32; - break; - default: - PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4"); - return NULL; - } - assert((elementSize * tensorDim) >= 32 && "block size too small."); - int rank = 1; - static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; - INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, - getCuTensorMapEncodeTiledHandle); - CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled( - (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, - globalStrides, boxDim, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, - CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, - CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); - Py_INCREF(Py_None); - return Py_None; -} - -// Simple helper to experiment creating TMA descriptors on the host. -// This is a useful to test TMA operations independently. -static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { - unsigned long long global_address; - uint64_t dims[2]; - uint32_t tensorDims[2]; - int elementSize; - unsigned long long desc_address; - if (!PyArg_ParseTuple(args, "KKKiiiK", &global_address, &dims[1], &dims[0], - &tensorDims[1], &tensorDims[0], &elementSize, - &desc_address)) { - return NULL; - } - uint64_t globalStrides[2] = {dims[0] * elementSize, - dims[0] * dims[1] * elementSize}; - uint32_t elementStrides[2] = {1, 1}; - CUtensorMapDataType type; - switch (elementSize) { - case 1: - type = CU_TENSOR_MAP_DATA_TYPE_UINT8; - break; - case 2: - type = CU_TENSOR_MAP_DATA_TYPE_UINT16; - break; - case 4: - type = CU_TENSOR_MAP_DATA_TYPE_UINT32; - break; - default: - PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4"); - } - int rank = 2; - // Swizzling should be picked in codegen but since we need to set it on the - // descriptor we rely on a convention between this function and codegen. - CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B; - uint32_t contigDimSizeInByte = elementSize * tensorDims[0]; - if (contigDimSizeInByte >= 128) { - swizzle = CU_TENSOR_MAP_SWIZZLE_128B; - } else if (contigDimSizeInByte >= 64) { - swizzle = CU_TENSOR_MAP_SWIZZLE_64B; - } else if (contigDimSizeInByte >= 32) { - swizzle = CU_TENSOR_MAP_SWIZZLE_32B; - } else { - assert(false && "block size too small."); - } - // The bounding box inner dimension must be less than or equal to the swizzle - // size. - // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 - // We clamp the block size and the codegen will emit multiple copy operations. - if (contigDimSizeInByte > 128) { - tensorDims[0] = 128 / elementSize; - } - static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; - INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, - getCuTensorMapEncodeTiledHandle); - CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled( - (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, - globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, - swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, - CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); - Py_INCREF(Py_None); - return Py_None; -} - -static PyMethodDef ModuleMethods[] = { - {"load_binary", loadBinary, METH_VARARGS, - "Load provided cubin into CUDA driver"}, - {"get_device_properties", getDeviceProperties, METH_VARARGS, - "Get the properties for a given device"}, - {"cuOccupancyMaxActiveClusters", occupancyMaxActiveClusters, METH_VARARGS, - "Python interface for cuOccupancyMaxActiveClusters function"}, - {"set_printf_fifo_size", setPrintfFifoSize, METH_VARARGS, - "Python interface for cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, x), which " - "controls how many bytes can be streamed from kernels before data starts " - "being dropped. This inherits all the limitations of this call; in " - "particular it's an error to change this value after launching any kernel " - "that calls printf()."}, - {"fill_1d_tma_descriptor", fill1DTMADescriptor, METH_VARARGS, "doc"}, - {"fill_2d_tma_descriptor", fill2DTMADescriptor, METH_VARARGS, "doc"}, - - {NULL, NULL, 0, NULL} // sentinel -}; - -static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils", - NULL, // documentation - -1, // size - ModuleMethods}; - -PyMODINIT_FUNC PyInit_cuda_utils(void) { - PyObject *m = PyModule_Create(&ModuleDef); - if (m == NULL) { - return NULL; - } - - PyModule_AddFunctions(m, ModuleMethods); - - return m; -} diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index ee440bd4f633..7a1b6a34741e 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -1,21 +1,16 @@ +from collections.abc import Callable import functools import os -import sysconfig -import hashlib import subprocess -import tempfile -from pathlib import Path -from triton.runtime.build import _build -from triton.runtime.cache import get_cache_manager from triton.runtime import _allocation from triton.backends.compiler import GPUTarget from triton.backends.driver import GPUDriver from triton._utils import parse_list_string +from ._C import cuda_utils dirname = os.path.dirname(os.path.realpath(__file__)) include_dir = [os.path.join(dirname, "include")] libdevice_dir = os.path.join(dirname, "lib") -libraries = ['cuda'] @functools.lru_cache() @@ -48,26 +43,6 @@ def library_dirs(): return [libdevice_dir, *libcuda_dirs()] -def compile_module_from_src(src, name): - key = hashlib.sha256(src.encode("utf-8")).hexdigest() - cache = get_cache_manager(key) - ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1] - cache_path = cache.get_file(f"{name}.{ext}") - if cache_path is None: - with tempfile.TemporaryDirectory() as tmpdir: - src_path = os.path.join(tmpdir, "main.c") - with open(src_path, "w") as f: - f.write(src) - so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries) - with open(so, "rb") as f: - cache_path = cache.put(f.read(), f"{name}.{ext}", binary=True) - import importlib.util - spec = importlib.util.spec_from_file_location(name, cache_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - return mod - - # ------------------------ # Utils # ------------------------ @@ -81,13 +56,12 @@ def __new__(cls): return cls.instance def __init__(self): - mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "cuda_utils") - self.load_binary = mod.load_binary - self.get_device_properties = mod.get_device_properties - self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters - self.set_printf_fifo_size = mod.set_printf_fifo_size - self.fill_1d_tma_descriptor = mod.fill_1d_tma_descriptor - self.fill_2d_tma_descriptor = mod.fill_2d_tma_descriptor + self.load_binary = cuda_utils.load_binary + self.get_device_properties = cuda_utils.get_device_properties + self.cuOccupancyMaxActiveClusters = cuda_utils.cuOccupancyMaxActiveClusters + self.set_printf_fifo_size = cuda_utils.set_printf_fifo_size + self.fill_1d_tma_descriptor = cuda_utils.fill_1d_tma_descriptor + self.fill_2d_tma_descriptor = cuda_utils.fill_2d_tma_descriptor # ------------------------ @@ -118,357 +92,41 @@ def ty_to_cpp(ty): }[ty] -def make_launcher(constants, signature, ids): - - def _extracted_type(ty): - if ty[0] == '*' or ty == "none": - return "PyObject*" - if ty == "nvTmaDesc": - return "PyObject*" - if ty[0] == '[': - if ty == "[]": - return "[]" - tys = parse_list_string(ty) - val = ','.join(map(_extracted_type, tys)) - return f"[{val}]" - return ty_to_cpp(ty) - - def format_of(ty): - if ty == "CUdeviceptr": - return "O" - if ty[0] == "[": - if ty == "[]": - return "()" - tys = parse_list_string(ty) - val = ''.join(map(format_of, tys)) - return f"({val})" - return { - "PyObject*": "O", - "float": "f", - "double": "d", - "long": "l", - "int8_t": "b", - "int16_t": "h", - "int32_t": "i", - "int64_t": "L", - "uint8_t": "B", - "uint16_t": "H", - "uint32_t": "I", - "uint64_t": "K", - }[ty] - - signature = {k: v for k, v in signature.items() if v != 'constexpr'} - args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) - format = "iiiKKOOOOO" + args_format - signature = ','.join(signature.values()).replace('[', '').replace(']', '') - signature = list(filter(bool, signature.split(','))) - signature = {i: s for i, s in enumerate(signature)} - args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' - # Record the end of regular arguments; - # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. - arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) - internal_args_list = [] - for i, ty in signature.items(): - if ty[0] == "*" or ty == "none": - internal_args_list.append(f"ptr_info{i}.dev_ptr") - elif ty == "nvTmaDesc": - # Note: we have to dereference the pointer - internal_args_list.append(f"*tma_ptr{i}") - else: - internal_args_list.append(f"_arg{i}") - params = range(len(signature)) - - # generate glue code - params = [f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none"] - params.append("&global_scratch") - src = f""" -#include \"cuda.h\" -#include -#include -#include - -static inline void gpuAssert(CUresult code, const char *file, int line) -{{ - if (code != CUDA_SUCCESS) - {{ - const char* prefix = "Triton Error [CUDA]: "; - const char* str; - cuGetErrorString(code, &str); - char err[1024] = {{0}}; - strcat(err, prefix); - strcat(err, str); - PyGILState_STATE gil_state; - gil_state = PyGILState_Ensure(); - PyErr_SetString(PyExc_RuntimeError, err); - PyGILState_Release(gil_state); - }} -}} - -#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} - -typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra); - -static cuLaunchKernelEx_t getLaunchKernelExHandle() {{ - // Open the shared library - void* handle = dlopen("libcuda.so.1", RTLD_LAZY); - if (!handle) {{ - PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1"); - return NULL; - }} - // Clear any existing error - dlerror(); - cuLaunchKernelEx_t cuLaunchKernelExHandle = (cuLaunchKernelEx_t)dlsym(handle, "cuLaunchKernelEx"); - // Check for errors - const char *dlsym_error = dlerror(); - if (dlsym_error) {{ - PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from libcuda.so.1"); - return NULL; - }} - return cuLaunchKernelExHandle; -}} - -static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ - void *params[] = {{ {', '.join(params)} }}; - if (gridX*gridY*gridZ > 0) {{ - if (num_ctas == 1) {{ - CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0)); - }} else {{ - CUlaunchAttribute launchAttr[2]; - launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; - launchAttr[0].value.clusterDim.x = clusterDimX; - launchAttr[0].value.clusterDim.y = clusterDimY; - launchAttr[0].value.clusterDim.z = clusterDimZ; - launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; - launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD; - CUlaunchConfig config; - config.gridDimX = gridX * clusterDimX; - config.gridDimY = gridY * clusterDimY; - config.gridDimZ = gridZ * clusterDimZ; - config.blockDimX = 32 * num_warps; - config.blockDimY = 1; - config.blockDimZ = 1; - config.sharedMemBytes = shared_memory; - config.hStream = stream; - config.attrs = launchAttr; - config.numAttrs = 2; - static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL; - if (cuLaunchKernelExHandle == NULL) {{ - cuLaunchKernelExHandle = getLaunchKernelExHandle(); - }} - CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0)); - }} - }} -}} - -typedef struct _DevicePtrInfo {{ - CUdeviceptr dev_ptr; - bool valid; -}} DevicePtrInfo; - -static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ - DevicePtrInfo ptr_info; - ptr_info.dev_ptr = 0; - ptr_info.valid = true; - if (PyLong_Check(obj)) {{ - ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj); - return ptr_info; - }} - if (obj == Py_None) {{ - // valid nullptr - return ptr_info; - }} - PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); - if(ptr){{ - PyObject *empty_tuple = PyTuple_New(0); - PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); - Py_DECREF(empty_tuple); - Py_DECREF(ptr); - if (!PyLong_Check(ret)) {{ - PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); - ptr_info.valid = false; - return ptr_info; - }} - ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret); - if(!ptr_info.dev_ptr) - return ptr_info; - uint64_t dev_ptr; - int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr); - if (status == CUDA_ERROR_INVALID_VALUE) {{ - PyErr_Format(PyExc_ValueError, - "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); - ptr_info.valid = false; - }} else if (status != CUDA_SUCCESS) {{ - CUDA_CHECK(status); // Catch any other cuda API errors - ptr_info.valid = false; - }} - ptr_info.dev_ptr = dev_ptr; - Py_DECREF(ret); // Thanks ChatGPT! - return ptr_info; - }} - PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); - ptr_info.valid = false; - return ptr_info; -}} - -static inline CUtensorMap* getTmaDesc(PyObject *obj) {{ - if (sizeof(CUtensorMap*) != 8) {{ - PyErr_SetString(PyExc_SystemError, "getTmaDesc() requires 64-bit compilation"); - return NULL; - }} - - PyObject *method_handle = PyObject_GetAttrString(obj, "tma_desc_cpu_ptr"); - if (!method_handle) {{ - PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() method does not exist"); - return NULL; - }} - - PyObject *empty_tuple = PyTuple_New(0); - if (!empty_tuple) {{ - Py_DECREF(method_handle); - PyErr_SetString(PyExc_SystemError, "Internal Python error!"); - return NULL; - }} - PyObject *method_ret = PyObject_Call(method_handle, empty_tuple, NULL); - Py_DECREF(empty_tuple); - Py_DECREF(method_handle); - if (!method_ret) {{ - PyErr_SetString(PyExc_SystemError, "Internal Python error!"); - return NULL; - }} - - if (!PyLong_Check(method_ret)) {{ - PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() must return 64-bit int"); - Py_DECREF(method_ret); - return NULL; - }} - - uint64_t ptr_as_uint = PyLong_AsUnsignedLongLong(method_ret); - Py_DECREF(method_ret); - if (!ptr_as_uint) {{ - PyErr_SetString(PyExc_ValueError, "received NULL ptr from tma_desc_cpu_ptr()"); - return NULL; - }} - if (ptr_as_uint % 64 != 0) {{ - PyErr_SetString(PyExc_ValueError, "tma_desc_cpu_ptr() must be 64-byte aligned"); - return NULL; - }} - - return (CUtensorMap*)(ptr_as_uint); -}} - -static void ensureCudaContext() {{ - CUcontext pctx; - CUDA_CHECK(cuCtxGetCurrent(&pctx)); - if (!pctx) {{ - // Ensure device context. - CUdevice device; - CUDA_CHECK(cuDeviceGet(&device, 0)); - CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); - CUDA_CHECK(cuCtxSetCurrent(pctx)); - }} -}} - -static PyObject* launch(PyObject* self, PyObject* args) {{ - // ensure cuda context is valid before calling any CUDA APIs, e.g. before getPointer calls cuPointerGetAttributes - ensureCudaContext(); - - int gridX, gridY, gridZ; - uint64_t _stream; - uint64_t _function; - PyObject *launch_enter_hook = NULL; - PyObject *launch_exit_hook = NULL; - PyObject *kernel_metadata = NULL; - PyObject *launch_metadata = NULL; - PyObject *global_scratch_obj = NULL; - {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} - if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, - &_stream, &_function, &global_scratch_obj, - &kernel_metadata, &launch_metadata, - &launch_enter_hook, &launch_exit_hook{args_list})) {{ - return NULL; - }} - - int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; - if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{ - PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple"); - return NULL; - }} - - // extract launch metadata - if (launch_enter_hook != Py_None){{ - PyObject* args = Py_BuildValue("(O)", launch_metadata); - PyObject* ret = PyObject_CallObject(launch_enter_hook, args); - Py_DECREF(args); - if (!ret) - return NULL; - }} - - CUdeviceptr global_scratch = 0; - if (global_scratch_obj != Py_None) {{ - DevicePtrInfo global_scratch_info = getPointer(global_scratch_obj, -1); - if (!global_scratch_info.valid) {{ - return NULL; - }} - global_scratch = global_scratch_info.dev_ptr; - }} - - // raise exception asap - {"".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" or ty == "none" else "" for i, ty in signature.items()])}; - {"".join([f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" if ty == "nvTmaDesc" else "" for i, ty in signature.items()])}; - Py_BEGIN_ALLOW_THREADS; - _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''}); - Py_END_ALLOW_THREADS; - if (PyErr_Occurred()) {{ - return NULL; - }} - - if(launch_exit_hook != Py_None){{ - PyObject* args = Py_BuildValue("(O)", launch_metadata); - PyObject* ret = PyObject_CallObject(launch_exit_hook, args); - Py_DECREF(args); - if (!ret) - return NULL; - - }} - - Py_RETURN_NONE; -}} - -static PyMethodDef ModuleMethods[] = {{ - {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, - {{NULL, NULL, 0, NULL}} // sentinel -}}; - -static struct PyModuleDef ModuleDef = {{ - PyModuleDef_HEAD_INIT, - \"__triton_launcher\", - NULL, //documentation - -1, //size - ModuleMethods -}}; - -PyMODINIT_FUNC PyInit___triton_launcher(void) {{ - PyObject *m = PyModule_Create(&ModuleDef); - if(m == NULL) {{ - return NULL; - }} - PyModule_AddFunctions(m, ModuleMethods); - return m; -}} -""" - return src +def make_launcher(constants : dict[int, str], signature : dict[int, any]) -> Callable[..., None]: + # We seem to have 3 categories of arguments: + # 1. arguments listed in signature + # 2. arguments listed in constants + # 3. those present in both signature and constants + # TODO(gflegar): why is that? + # Category (2) does not get passed to the launcher, but category (3) does. + # However, the launcher is supposed to ignore the arguments passed from + # category (3). The generic C++ launcher currently does not do that, so we + # are doing it in the python wrapper. + signature_metadata = cuda_utils.build_signature_metadata( + ty if arg_id not in constants else None + for arg_id, ty in signature.items()) + def wrapper(grid_dim_x: int, grid_dim_y: int, grid_dim_z: int, + stream: int, kernel: int, global_scratch: any, + packed_metadata: tuple[int, int, int, int, int, int], + hook_args: any, + launch_enter_hook: Callable[..., None], + launch_exit_hook: Callable[..., None], + *args: any) -> None: + cuda_utils.launch(grid_dim_x, grid_dim_y, grid_dim_z, stream, kernel, + packed_metadata, hook_args, launch_enter_hook, + launch_exit_hook, signature_metadata, global_scratch, + args) + return wrapper class CudaLauncher(object): def __init__(self, src, metadata): - ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} - constants = src.constants if hasattr(src, "constants") else dict() - constants = {idx: value for idx, value in constants.items()} - signature = {idx: value for idx, value in src.signature.items()} - src = make_launcher(constants, signature, ids) - mod = compile_module_from_src(src, "__triton_launcher") - self.launch = mod.launch + constants = getattr(src, "constants", dict()) + cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} + signature = {cst_key(key): value for key, value in src.signature.items()} + self.launch = make_launcher(constants, signature) self.global_scratch_size = metadata.global_scratch_size self.global_scratch_align = metadata.global_scratch_align @@ -506,6 +164,7 @@ def get_device_interface(self): @staticmethod def is_active(): + return True import torch return torch.cuda.is_available() and (torch.version.hip is None) diff --git a/third_party/nvidia/language/cuda/BUILD b/third_party/nvidia/language/cuda/BUILD new file mode 100644 index 000000000000..55e6ec8795c1 --- /dev/null +++ b/third_party/nvidia/language/cuda/BUILD @@ -0,0 +1,13 @@ +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//third_party/py/triton:__subpackages__", + ], +) + +filegroup( + name = "files", + srcs = glob( + include = ["**/*.py"], + ), +) diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index 8de0efefca84..637071275e39 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -291,10 +291,36 @@ class WGMMAWaitGroupOpPattern : public OpRewritePattern { Constraints getOutputConstraints(ttn::WGMMAWaitGroupOp op) const { auto outputStructType = cast(op.getType()); - uint32_t numOutputRegs = outputStructType.getBody().size(); - std::string output = - outputStructType.getBody().front().isF32() ? "=f" : "=r"; - return Constraints(numOutputRegs, output); + std::vector outputConstraints; + outputConstraints.reserve(outputStructType.getBody().size()); + for (mlir::Type type : outputStructType.getBody()) { + if (type.isF32()) { + outputConstraints.push_back("=f"); + continue; + } else if (type.isF64()) { + outputConstraints.push_back("=d"); + continue; + } + unsigned bitwidth = isa(type) ? + 64 : type.getIntOrFloatBitWidth(); + switch (bitwidth) { + case 1: + outputConstraints.push_back("=b"); + break; + case 16: + outputConstraints.push_back("=h"); + break; + case 32: + outputConstraints.push_back("=r"); + break; + case 64: + outputConstraints.push_back("=l"); + break; + default: + assert(false && "unsupported bitwidth"); + } + } + return outputConstraints; } OperandsAndConstraints diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index 85f7da2cb5b3..6b560d966927 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -89,8 +89,8 @@ int64_t getSwizzlingFromLayout(const SharedEncodingAttr &layout, return swizzlingByteWidth; } -static Value createDescriptor(ConversionPatternRewriter &rewriter, Location loc, - int64_t swizzling, uint32_t stride) { +Value createDescriptor(ConversionPatternRewriter &rewriter, Location loc, + int64_t swizzling, uint32_t stride) { // Create descriptor based on the format described in the spec: // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor union WGMMADescriptor { diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp index ef69b96fce1e..1f755929ada9 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -265,7 +265,8 @@ static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType, outVecWidthBits](Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) -> SmallVector { int numElements = v.size(); - assert(numElements == 4 || numElements == 2 && "invalid vector size"); + assert(numElements == 8 || numElements == 4 || + numElements == 2 && "invalid vector size"); auto ctx = rewriter.getContext(); int inBitwidth = inType.getIntOrFloatBitWidth(); @@ -387,10 +388,10 @@ struct FpToFpOpConversion ptx = "cvt.rz.f16.f32"; break; default: - llvm::errs() << "WARNING: unsupported rounding mode for f32->f16 " - "conversion: " - << stringifyRoundingMode(rounding) << "\n"; - llvm_unreachable(""); + llvm::report_fatal_error( + "WARNING: unsupported rounding mode for f32->f16 " + "conversion: " + stringifyRoundingMode(rounding) + + "\n"); } auto &cvt = *builder.create(ptx.str()); auto res = builder.newOperand("=h"); @@ -447,10 +448,10 @@ struct FpToFpOpConversion } if (computeCapability < 89 && (srcTy.isFloat8E4M3FN() || dstTy.isFloat8E4M3FN())) { - llvm::errs() << "Conversion from/to f8e4m3nv is only supported on " - "compute capability >= 89" - << "\n"; - llvm_unreachable(""); + llvm::report_fatal_error( + "Conversion from/to f8e4m3nv is only supported on " + "compute capability >= 89" + "\n"); } auto convDesc = srcMap.lookup(key); return {makeConverterFromPtx( @@ -475,9 +476,9 @@ struct FpToFpOpConversion // For now only RTNE is supported for conversions from fp16 to fp8 if (!srcElementType.isF32() && roundingMode.value() != RoundingMode::RTNE) { - llvm::errs() << "Unsupported rounding mode for conversion to fp8: " - << stringifyRoundingMode(roundingMode.value()) << "\n"; - llvm_unreachable(""); + llvm::report_fatal_error( + "Unsupported rounding mode for conversion to fp8: " + + stringifyRoundingMode(roundingMode.value()) + "\n"); } } @@ -671,6 +672,114 @@ struct SIToFPOpConversion : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), computeCapability(computeCapability) {} + LogicalResult matchAndRewrite( + arith::SIToFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (succeeded(matchAndRewriteInt4ToBf16Conversion(op, rewriter))) { + return success(); + } + return Base::matchAndRewrite(op, adaptor, rewriter); + } + + // Matches subgraph of convert 8xi4 to 8xbf16 and rewrites it to inline PTX. + LogicalResult matchAndRewriteInt4ToBf16Conversion( + arith::SIToFPOp op, ConversionPatternRewriter &rewriter) const { + if (computeCapability < 90) return failure(); + Type inElemTy = getElementType(op.getIn()); + Type outElemTy = getElementType(op.getOut()); + if (!inElemTy.isInteger(8) || !outElemTy.isBF16()) return failure(); + FailureOr unpack = matchInt4Unpack(op.getIn()); + if (failed(unpack)) return failure(); + + Location loc = op.getLoc(); + Value src = rewriter.getRemappedValue(unpack.value()); + auto structTy = dyn_cast(src.getType()); + if (!structTy || structTy.getBody().size() % 4 != 0) return failure(); + auto isInt8 = [](Type type) { return type.isInteger(8); }; + if (!all_of(structTy.getBody(), isInt8)) return failure(); + + const LLVMTypeConverter *typeConverter = getTypeConverter(); + assert(inElemTy == typeConverter->convertType(inElemTy)); + assert(outElemTy == typeConverter->convertType(outElemTy)); + + const std::string S4_to_Bf16_sm90 = R"({ + .reg .b32 r<4>, mi, mf; + mov.b32 mi, 0x43404340 - 0x00080008; + mov.b32 mf, 0x43404340; + // Shift 4-bit inputs to 16-bit boundary. + shr.u32 r1, $4, 4; + shr.u32 r2, $4, 8; + shr.u32 r3, $4, 12; + // Sign-extend from 4 bits is equivalent to (x ^ 0x8) - 0x8. + lop3.b32 r0, $4, 0x000f000f, 0x00080008, (0xf0 & 0xcc) ^ 0xaa; + lop3.b32 r1, r1, 0x000f000f, 0x00080008, (0xf0 & 0xcc) ^ 0xaa; + lop3.b32 r2, r2, 0x000f000f, 0x00080008, (0xf0 & 0xcc) ^ 0xaa; + lop3.b32 r3, r3, 0x000f000f, 0x00080008, (0xf0 & 0xcc) ^ 0xaa; + // Interger-add magic number (minus bias from sign-extend above). + add.s16x2 r0, r0, mi; + add.s16x2 r1, r1, mi; + add.s16x2 r2, r2, mi; + add.s16x2 r3, r3, mi; + // Float-subtract magic number. + sub.bf16x2 r0, r0, mf; + sub.bf16x2 r1, r1, mf; + sub.bf16x2 r2, r2, mf; + sub.bf16x2 r3, r3, mf; + // Shuffle results into correct order. + prmt.b32 $0, r1, r0, 0x5410; + prmt.b32 $1, r3, r2, 0x5410; + prmt.b32 $2, r1, r0, 0x7632; + prmt.b32 $3, r3, r2, 0x7632; + })"; + + SmallVector resultVals; + SmallVector unpackedVals = unpackLLElements(loc, src, rewriter); + auto cvtFunc = makeConverterFromPtx(S4_to_Bf16_sm90, inElemTy, outElemTy); + for (ValueRange operands = unpackedVals; !operands.empty(); + operands = operands.drop_front(4)) { + SmallVector inVals = { + operands[0], operands[1], operands[2], operands[3], + // Repeat operands so that cvtFunc produces 8 outputs. + operands[0], operands[1], operands[2], operands[3]}; + auto outVals = cvtFunc(loc, rewriter, inVals); + assert(inVals.size() == outVals.size()); + resultVals.append(outVals.begin(), outVals.end()); + } + + resultVals = maybeDeduplicate(op, resultVals); + Value view = + packLLElements(loc, typeConverter, resultVals, rewriter, op.getType()); + rewriter.replaceOp(op, view); + + return success(); + } + + // Returns the source if value is the result of an 2xi4 -> 2xi8 unpack + // sequence. + static FailureOr matchInt4Unpack(Value value) { + auto reshape = value.getDefiningOp(); + if (!reshape) return failure(); + auto join = reshape.getSrc().getDefiningOp(); + if (!join) return failure(); + auto shrHi = join.getLhs().getDefiningOp(); + if (!shrHi || !isConst4(shrHi.getRhs())) return failure(); + auto shrLo = join.getRhs().getDefiningOp(); + if (!shrLo || !isConst4(shrLo.getRhs())) return failure(); + auto shlLo = shrLo.getLhs().getDefiningOp(); + if (!shlLo || !isConst4(shlLo.getRhs())) return failure(); + if (shrHi.getLhs() != shlLo.getLhs()) return failure(); + return shrHi.getLhs(); + } + + // Returns true if the value is equal to 4. + static bool isConst4(Value v) { + auto constOp = v.getDefiningOp(); + if (!constOp) return false; + auto attr = mlir::dyn_cast(constOp.getValue()); + if (!attr || !attr.isSplat()) return false; + return attr.getSplatValue().getLimitedValue() == 4; + }; + SmallVector createDestOps(arith::SIToFPOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Type elemTy, MultipleOperandsRange operands, diff --git a/third_party/proton/BUILD b/third_party/proton/BUILD new file mode 100644 index 000000000000..4f85a5b62e18 --- /dev/null +++ b/third_party/proton/BUILD @@ -0,0 +1,107 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +td_library( + name = "td_files", + srcs = glob(["dialect/include/Dialect/Proton/IR/*.td"]), + includes = ["dialect/include"], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "//:td_files", + ], +) + +gentbl_cc_library( + name = "proton_ir_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "dialect/include/Dialect/Proton/IR/ProtonAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "dialect/include/Dialect/Proton/IR/ProtonAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "proton_ir_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "dialect/include/Dialect/Proton/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "dialect/include/Dialect/Proton/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "dialect/include/Dialect/Proton/IR/ProtonDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "proton_ir_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-enum-decls"], + "dialect/include/Dialect/Proton/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "dialect/include/Dialect/Proton/IR/OpsEnums.cpp.inc", + ), + ( + ["--gen-op-decls"], + "dialect/include/Dialect/Proton/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "dialect/include/Dialect/Proton/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "dialect/include/Dialect/Proton/IR/ProtonOps.td", + deps = ["td_files"], +) + +cc_library( + name = "ProtonIRDialect", + srcs = glob([ + "dialect/lib/Dialect/Proton/IR/*.cpp", + ]), + hdrs = glob([ + "dialect/include/Dialect/Proton/IR/*.h", + ]), + includes = [ + "..", # because proton/dialect/include/Dialect/Proton/IR/Dialect.h.inc + "dialect/include", + ], + deps = [ + ":proton_ir_attr_inc_gen", + ":proton_ir_dialect_inc_gen", + ":proton_ir_ops_inc_gen", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Support", + "//:TritonDialects", + ], +) diff --git a/third_party/proton/proton/_C/include b/third_party/proton/proton/_C/include index fe4f4a1aa9bd..4400934bdf78 120000 --- a/third_party/proton/proton/_C/include +++ b/third_party/proton/proton/_C/include @@ -1 +1 @@ -../../csrc/include/ \ No newline at end of file +../../csrc/include \ No newline at end of file diff --git a/unittest/BUILD b/unittest/BUILD new file mode 100644 index 000000000000..4cbadcfa4655 --- /dev/null +++ b/unittest/BUILD @@ -0,0 +1,144 @@ +load("//tools/build_defs/build_test:build_test.bzl", "build_test") + +package( + default_applicable_licenses = ["//:license"], + default_compatible_with = ["//buildenv/target:non_prod"], + default_visibility = ["//:__subpackages__"], +) + +cc_test( + name = "AnalysisTest", + srcs = glob(["Analysis/*.cpp"]), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "//:TritonDialects", + ], +) + +cc_test( + name = "DialectTestCatchAll", + srcs = glob( + [ + "Dialect/**/*.cpp", + ], + exclude = [ + "Dialect/TritonGPU/DialectTest.cpp", + "Dialect/TritonGPU/LinearLayoutConversionsTest.cpp", + "Dialect/TritonGPU/SwizzleTest.cpp", + ], + ), + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "DialectTest", + srcs = [ + "Dialect/TritonGPU/DialectTest.cpp", + ], + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "LinearLayoutConversionsTest", + srcs = [ + "Dialect/TritonGPU/LinearLayoutConversionsTest.cpp", + ], + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "SwizzleTest", + srcs = [ + "Dialect/TritonGPU/SwizzleTest.cpp", + ], + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "ConversionTest", + srcs = glob( + [ + "Conversion/**/*.cpp", + "Conversion/**/*.h", + ], + exclude = [ + "Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp", + "Conversion/TritonGPUToLLVM/DumpLayout.cpp", + "Conversion/TritonGPUToLLVM/DumpLayout.h", + ], + ), + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "//:TritonDialects", + "//:TritonNvidiaGPUTransforms", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) + +build_test( + name = "build_test", + allow_empty_target = False, + targets = [ + ":ConversionTest", + ":AnalysisTest", + ":DialectTest", + ], +)