Forked from pytorch/pytorch.
File: RegisterMkldnnOpContextClass.cpp — 122 lines (102 loc), 5.18 KB.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#include <ATen/Config.h>
#if AT_MKLDNN_ENABLED()
#include <ATen/Tensor.h>
#include <ATen/native/mkldnn/ConvPrepack.h>
#include <ATen/native/mkldnn/OpContext.h>
#include <ATen/native/mkldnn/Utils.h>
#include <torch/custom_class.h>
#include <torch/library.h>
namespace at {
namespace native {
namespace mkldnn {
using namespace internal::convolution;
// Runtime probe: does the current CPU support bfloat16 through oneDNN?
// aarch64 uses an ARM-specific capability check; every other architecture
// goes through the generic device check.
static bool is_mkldnn_bf16_supported() {
#if defined(__aarch64__)
  const bool supported = mkldnn_bf16_device_check_arm();
#else
  const bool supported = mkldnn_bf16_device_check();
#endif
  return supported;
}
// Runtime probe: does the current CPU support fp16 through oneDNN?
// Unlike the bf16 probe, a single generic check is used on all
// architectures.
static bool is_mkldnn_fp16_supported() {
  const bool supported = mkldnn_fp16_device_check();
  return supported;
}
// Compile-time probe: was this build configured with Arm Compute Library
// support? The answer is fixed by the AT_MKLDNN_ACL_ENABLED() build flag,
// hence constexpr rather than a runtime check.
constexpr bool is_mkldnn_acl_supported() {
  constexpr bool acl_enabled = AT_MKLDNN_ACL_ENABLED();
  return acl_enabled;
}
// Registration of the `mkldnn` custom-op namespace with the dispatcher /
// TorchScript runtime. This declares:
//  * the ConvOpContext custom class, with pickle hooks so a prepacked
//    convolution context survives model serialization;
//  * schemas for the fused "pointwise" linear/convolution ops and the
//    weight-reorder helpers (kernel implementations are registered in
//    other translation units);
//  * capability probes and opaque-tensor introspection helpers.
TORCH_LIBRARY(mkldnn, m) {
  m.class_<ConvOpContext>(TORCH_SELECTIVE_CLASS("ConvOpContext"))
      .def_pickle(
          // __getstate__: flatten the context back into the plain tuple of
          // prepack arguments (SerializationTypeConvPrePack) it was built
          // from.
          [](const c10::intrusive_ptr<ConvOpContext>& op_context)
              -> SerializationTypeConvPrePack { // __getstate__
            return op_context->unpack();
          },
          // __setstate__: re-run prepacking from the serialized 8-element
          // tuple; element order is fixed by SerializationTypeConvPrePack.
          [](SerializationTypeConvPrePack state)
              -> c10::intrusive_ptr<ConvOpContext> { // __setstate__
            return createConvPrePackOpContext(
                std::move(std::get<0>(state)),
                std::move(std::get<1>(state)),
                std::move(std::get<2>(state)),
                std::move(std::get<3>(state)),
                std::move(std::get<4>(state)),
                std::move(std::get<5>(state)),
                std::move(std::get<6>(state)),
                std::move(std::get<7>(state)));
          });
  // Fused linear with a unary post-op described by (attr, scalars,
  // algorithm).
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_linear_pointwise(Tensor X, Tensor W, Tensor? B, str attr, Scalar?[] scalars, str? algorithm) -> Tensor Y"));
  // Fused linear with a binary post-op taking a second tensor operand.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_linear_pointwise.binary(Tensor X, Tensor other, Tensor W, Tensor? B, str attr) -> Tensor Y"));
  // Fused convolution with a unary post-op.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_convolution_pointwise(Tensor X, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str attr, Scalar?[] scalars, str? algorithm) -> Tensor Y"));
  // Fused convolution with a binary post-op, optionally followed by a unary
  // post-op.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_convolution_pointwise.binary(Tensor X, Tensor other, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor Y"));
  // In-place variant: `other` (marked Tensor(a!)) is mutated and returned.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_convolution_pointwise_.binary(Tensor(a!) other, Tensor X, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor(a!) Y"));
  // Fused transposed convolution with a unary post-op.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_convolution_transpose_pointwise(Tensor X, Tensor W, Tensor? B, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, str attr, Scalar?[] scalars, str? algorithm) -> Tensor Y"));
  // Weight-reorder helpers: convert weights into oneDNN's blocked layout
  // ahead of time (input_size, when given, lets oneDNN pick a layout tuned
  // for that shape).
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_reorder_convolution_transpose_weight(Tensor self, int[2] padding=0, int[2] output_padding=0, int[2] stride=1, int[2] dilation=1, int groups=1, int[]? input_size=None) -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_reorder_linear_weight(Tensor self, int? batch_size=None) -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_reorder_convolution_weight(Tensor self, int[2] padding=0, int[2] stride=1, int[2] dilation=1, int groups=1, int[]? input_size=None) -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_reorder_mkldnn_rnn_layer_weight(Tensor weight0, Tensor weight1, int hidden_size, bool reverse, bool has_biases, bool batch_first, int[]? input_size=None) -> Tensor[] Y"));
  // Capability probes, exposed as torch.ops.mkldnn._is_mkldnn_*_supported.
  m.def("_is_mkldnn_bf16_supported", &is_mkldnn_bf16_supported);
  m.def("_is_mkldnn_fp16_supported", &is_mkldnn_fp16_supported);
  m.def("_is_mkldnn_acl_supported", &is_mkldnn_acl_supported);
  // Introspection helpers for opaque mkldnn tensors (schema-only here;
  // kernels registered elsewhere).
  m.def("mkldnn::data_ptr(Tensor mkldnn_tensor) -> int");
  m.def("mkldnn::_get_mkldnn_serialized_md (Tensor mkldnn_tensor) -> Tensor");
  m.def("mkldnn::_nbytes(Tensor mkldnn_tensor) -> int");
}
// Schemas for the prepacked 2-D convolution flow: conv2d_prepack packages
// weight/bias/conv parameters into a ConvOpContext custom-class instance,
// and conv2d_run applies that context to an input tensor.
TORCH_LIBRARY(mkldnn_prepacked, m) {
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn_prepacked::conv2d_prepack(Tensor W, Tensor? B, int[2] stride, int[2] padding, int[2] dilation, int groups, int[4] input_size, str attr) -> __torch__.torch.classes.mkldnn.ConvOpContext"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn_prepacked::conv2d_run(Tensor X, __torch__.torch.classes.mkldnn.ConvOpContext W_prepack) -> Tensor Y"));
}
// CPU kernel registrations backing the mkldnn_prepacked schemas:
// prepacking is handled by createConvPrePackOpContext, execution by
// conv_run (both from internal::convolution via ConvPrepack.h).
TORCH_LIBRARY_IMPL(mkldnn_prepacked, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn_prepacked::conv2d_prepack"),
      TORCH_FN(createConvPrePackOpContext));
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn_prepacked::conv2d_run"), TORCH_FN(conv_run));
}
} // namespace mkldnn
} // namespace native
} // namespace at
#endif // AT_MKLDNN_ENABLED()
#if AT_MKL_ENABLED() && AT_MKLDNN_ENABLED()
namespace at {
namespace native {
namespace mkl {
// Schemas for the MKL (as opposed to oneDNN) prepacked-linear path; this
// namespace is only compiled when both AT_MKL_ENABLED and
// AT_MKLDNN_ENABLED are set. _mkl_reorder_linear_weight prepacks a weight
// for a given batch size; _mkl_linear consumes both the packed (MKL_W) and
// original (ORI_W) weights.
TORCH_LIBRARY(mkl, m) {
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkl::_mkl_reorder_linear_weight(Tensor X, int batch_size) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkl::_mkl_linear(Tensor X, Tensor MKL_W, Tensor ORI_W, Tensor? B, int batch_size) -> Tensor"));
}
} // namespace mkl
} // namespace native
} // namespace at
#endif // AT_MKL_ENABLED && AT_MKLDNN_ENABLED