docs/OperatorKernels.md (4 changes: 2 additions & 2 deletions)
@@ -153,7 +153,7 @@ Do not modify directly.*
|||12|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
|||11|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|20+|**T** = tensor(float)|
-|Gemm|*in* A:**T**<br> *in* B:**T**<br> *in* C:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
+|Gemm|*in* A:**T**<br> *in* B:**T**<br> *in* C:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||[11, 12]|**T** = tensor(double), tensor(float)|
|||[9, 10]|**T** = tensor(double), tensor(float)|
|||[7, 8]|**T** = tensor(double), tensor(float)|
@@ -230,7 +230,7 @@ Do not modify directly.*
|||[18, 21]|**T** = tensor(float)|
|||[11, 17]|**T** = tensor(float)|
|||[2, 10]|**T** = tensor(float)|
-|MatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
+|MatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|||[1, 8]|**T** = tensor(double), tensor(float)|
|MatMulInteger|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *out* Y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int32)|
onnxruntime/core/providers/cpu/cpu_execution_provider.cc (4 changes: 4 additions & 0 deletions)
@@ -657,8 +657,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, string, Expand);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Gemm);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, MatMul);
@@ -2336,6 +2338,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
// opset 13
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, Cast)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Clip)>,
+  BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double,
MatMul)>,
@@ -2346,6 +2349,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Min)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Max)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Mean)>,
+  BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Sign)>,
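Aside on the registration pattern above: the forward declaration names the typed kernel class, and the BuildKernelCreateInfo entry is what makes it resolvable in the CPU provider's kernel registry at graph-partitioning time. A minimal stand-alone sketch of that lookup idea follows; every name in it is an illustrative stand-in, not the real macro expansion or ORT API.

#include <functional>
#include <map>
#include <memory>
#include <string>

// Hypothetical stand-ins for the real ORT kernel types.
struct OpKernel { virtual ~OpKernel() = default; };
struct Fp16Gemm : OpKernel {};  // plays the role of Gemm<MLFloat16>

// Key "op/opset/type" -> factory; loosely what the KernelRegistry resolves.
using KernelFactory = std::function<std::unique_ptr<OpKernel>()>;

void RegisterSketch(std::map<std::string, KernelFactory>& registry) {
  // Mirrors the new BuildKernelCreateInfo<... 13, MLFloat16, Gemm)> entry:
  // fp16 Gemm nodes now resolve to a CPU kernel instead of having no match.
  registry["Gemm/13/float16"] = [] { return std::make_unique<Fp16Gemm>(); };
}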
onnxruntime/core/providers/cpu/math/matmul.cc (58 changes: 58 additions & 0 deletions)
@@ -58,6 +58,13 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
.TypeConstraint("T", BuildKernelDefConstraints<int64_t, uint64_t>()),
MatMul<int64_t>);

+ONNX_CPU_OPERATOR_TYPED_KERNEL(
+    MatMul,
+    13,
+    MLFloat16,
+    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
+    MatMul<MLFloat16>);

ONNX_CPU_OPERATOR_TYPED_KERNEL(
MatMul,
13,
@@ -133,6 +140,57 @@ Status MatMul<T>::Compute(OpKernelContext* ctx) const {

return Status::OK();
}

+template <>
+Status MatMul<MLFloat16>::Compute(OpKernelContext* ctx) const {
+  concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
+
+  const Tensor* a = ctx->Input<Tensor>(0);
+  const Tensor* b = ctx->Input<Tensor>(1);
+  const auto& b_shape = b->Shape();
+
+  // match CUDA kernel implementation, ignore transpose for vectors
+  const bool trans_a = false;
+  const bool trans_b = false;
+
+  MatMulComputeHelper helper;
+  ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, trans_a, trans_b, false, false));
+  Tensor* y = ctx->Output(0, helper.OutputShape());
+
+  // Bail out early if the output is going to be empty
+  if (y->Shape().Size() == 0)
+    return Status::OK();
+
+  if (helper.K() == 0) {
+    // When we have (M, 0, N) then the inputs are empty, but the output should
+    // be filled out with zeros.
+    memset(y->MutableDataRaw(), 0, y->SizeInBytes());
+    return Status::OK();
+  }
+
+  const auto* a_data = a->Data<MLFloat16>();
+  const auto* b_data = b ? b->Data<MLFloat16>() : nullptr;
+  auto* y_data = y->MutableData<MLFloat16>();
+
+  const size_t max_len = helper.OutputOffsets().size();
+  const size_t M = static_cast<size_t>(helper.M());
+  const size_t N = static_cast<size_t>(helper.N());
+  const size_t K = static_cast<size_t>(helper.K());
+  const size_t lda = helper.Lda(trans_a);
+  const size_t ldb = helper.Ldb(trans_b);
+  std::vector<MLAS_HALF_GEMM_DATA_PARAMS> data(max_len);
+  for (size_t i = 0; i < max_len; i++) {
+    data[i].A = a_data + helper.LeftOffsets()[i];
+    data[i].lda = lda;
+    data[i].B = b_data + helper.RightOffsets()[i];
+    data[i].ldb = ldb;
+    data[i].C = y_data + helper.OutputOffsets()[i];
+    data[i].ldc = N;
+  }
+  MlasHalfGemmBatch(M, N, K, max_len, data.data(), thread_pool);
+  return Status::OK();
+}

#if defined(__aarch64__) && defined(__linux__)
bool GemmPackBBfloat16(AllocatorPtr& alloc,
const Tensor& tensor_b,
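The heart of the new fp16 path is the batching loop above: MatMulComputeHelper flattens any broadcast batch dimensions into per-batch offsets into A, B, and Y, and each offset triple becomes one MLAS_HALF_GEMM_DATA_PARAMS entry, so MlasHalfGemmBatch can run the whole batch in a single call. A plain-float sketch of that offset scheme, assuming row-major data and no transposes (matching trans_a = trans_b = false above; all names here are illustrative):

#include <cstddef>
#include <vector>

// For each batch i, multiply the M x K block at A + left_off[i] by the
// K x N block at B + right_off[i] into the M x N block at C + out_off[i].
// This is the work one MLAS_HALF_GEMM_DATA_PARAMS entry describes.
void BatchedMatMulSketch(const float* A, const float* B, float* C,
                         size_t M, size_t N, size_t K,
                         const std::vector<size_t>& left_off,
                         const std::vector<size_t>& right_off,
                         const std::vector<size_t>& out_off) {
  for (size_t i = 0; i < out_off.size(); ++i) {
    const float* a = A + left_off[i];   // lda = K (row-major, no transpose)
    const float* b = B + right_off[i];  // ldb = N
    float* c = C + out_off[i];          // ldc = N
    for (size_t m = 0; m < M; ++m) {
      for (size_t n = 0; n < N; ++n) {
        float acc = 0.0f;
        for (size_t k = 0; k < K; ++k) acc += a[m * K + k] * b[k * N + n];
        c[m * N + n] = acc;
      }
    }
  }
}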
onnxruntime/test/providers/cpu/math/gemm_test.cc (4 changes: 2 additions & 2 deletions)
@@ -96,7 +96,7 @@ auto get_bias_value = [](const std::vector<float>& bias_data, BiasType bias_type

} // namespace

-// Only CUDA, ROCM, CoreML and XNNPack kernels have float 16 support
+// Only CPU, CUDA, ROCM, CoreML and XNNPack kernels have float 16 support
TEST(GemmOpTest, GemmNoTrans_f16) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
@@ -196,7 +196,7 @@ TEST(GemmOpTest, GemmNoTrans_f16) {
}
}

-// Only CUDA, ROCM and CoreML kernels have float 16 support
+// Only CPU, CUDA, ROCM and CoreML kernels have float 16 support
TEST(GemmOpTest, GemmTransB_f16) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
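With the CPU kernel registered, a minimal fp16 Gemm case in the style of this file would look roughly like the sketch below. This is a sketch only: it assumes the ConvertFloatToMLFloat16 helper and OpTester conventions used elsewhere in these tests, and omits the tolerance and provider-exclusion handling the real tests carry.

// Hypothetical test name; alpha/beta are left at their Gemm defaults (1.0).
TEST(GemmOpTest, GemmNoTrans_f16_cpu_sketch) {
  OpTester test("Gemm", 13);
  test.AddAttribute("transA", static_cast<int64_t>(0));
  test.AddAttribute("transB", static_cast<int64_t>(0));
  // Y = A * I + C, so each output element is the A element plus 0.5.
  std::vector<float> a = {1.f, 2.f, 3.f, 4.f};
  std::vector<float> b = {1.f, 0.f, 0.f, 1.f};
  std::vector<float> c = {0.5f, 0.5f, 0.5f, 0.5f};
  std::vector<float> y = {1.5f, 2.5f, 3.5f, 4.5f};
  std::vector<MLFloat16> fa(4), fb(4), fc(4), fy(4);
  ConvertFloatToMLFloat16(a.data(), fa.data(), 4);
  ConvertFloatToMLFloat16(b.data(), fb.data(), 4);
  ConvertFloatToMLFloat16(c.data(), fc.data(), 4);
  ConvertFloatToMLFloat16(y.data(), fy.data(), 4);
  test.AddInput<MLFloat16>("A", {2, 2}, fa);
  test.AddInput<MLFloat16>("B", {2, 2}, fb);
  test.AddInput<MLFloat16>("C", {2, 2}, fc);
  test.AddOutput<MLFloat16>("Y", {2, 2}, fy);
  test.Run();
}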
onnxruntime/test/providers/cpu/math/matmul_test.cc (4 changes: 2 additions & 2 deletions)
@@ -446,7 +446,7 @@ TEST(MathOpTest, MatMulZeroKInt32Type) {
RunMatMulZeroKTest<int32_t>();
}

-#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_COREML) || defined(USE_XNNPACK)
+// #if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_COREML) || defined(USE_XNNPACK)
TEST(MathOpTest, MatMul_Float16) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
@@ -481,7 +481,7 @@ TEST(MathOpTest, MatMul_Float16) {
run_test(true);
run_test(false);
}
-#endif
+// #endif

#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DNNL)
TEST(MathOpTest, MatMul_bfloat16) {