
Commit e4b72e7

src: cpu: aarch64: Enable matmul static quantisation.
1 parent 3a06411 commit e4b72e7

File tree

3 files changed: +463 −3 lines
@@ -0,0 +1,249 @@
/*******************************************************************************
* Copyright 2024 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "cpu/aarch64/matmul/acl_lowp_matmul_sq.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace aarch64 {
namespace matmul {

status_t acl_lowp_matmul_sq_resource_t::configure(
        const acl_lowp_matmul_sq_conf_t &almc) {

    if (!acl_obj_) return status::out_of_memory;

    acl_obj_->src_tensor.allocator()->init(almc.src_tensor_info);
    acl_obj_->wei_tensor.allocator()->init(almc.wei_tensor_info);
    if (almc.with_bias) {
        acl_obj_->bia_tensor.allocator()->init(almc.bia_tensor_info);
    }
    acl_obj_->dst_tensor.allocator()->init(almc.dst_tensor_info);

    arm_compute::QuantizationInfo qi {1.0, 0, true};
    acl_obj_->src_tensor.info()->set_quantization_info(qi);
    acl_obj_->wei_tensor.info()->set_quantization_info(qi);
    acl_obj_->dst_tensor.info()->set_quantization_info(qi);
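    // The identity QuantizationInfo above is only a placeholder; the runtime
    // scales and zero points are applied in execute().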

    acl_obj_->gemm.configure(&acl_obj_->src_tensor, &acl_obj_->wei_tensor,
            almc.with_bias ? &acl_obj_->bia_tensor : nullptr,
            &acl_obj_->dst_tensor, almc.gemm_info);

    return status::success;
}

status_t acl_lowp_matmul_sq_t::pd_t::init(engine_t *engine) {
    VDISPATCH_MATMUL(set_default_formats(), "failed to set default formats");
    using smask_t = primitive_attr_t::skip_mask_t;
    VDISPATCH_MATMUL(
            attr()->has_default_values(smask_t::scales_runtime
                    | smask_t::zero_points_runtime | smask_t::post_ops),
            "only scale, zero point and post-ops attrs supported");

    VDISPATCH_MATMUL(attr()->scales_.get(DNNL_ARG_SRC).mask_ == 0
                    && attr()->zero_points_.get(DNNL_ARG_SRC) == 0
                    && attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ == 0
                    && attr()->zero_points_.get(DNNL_ARG_WEIGHTS) == 0
                    && attr()->scales_.get(DNNL_ARG_DST).mask_ == 0
                    && attr()->zero_points_.get(DNNL_ARG_DST) == 0,
            "common scales and zero points only");

    VDISPATCH_MATMUL(
            !has_runtime_dims_or_strides(), VERBOSE_RUNTIMEDIM_UNSUPPORTED);

    const memory_desc_wrapper src_d(src_md_);
    const memory_desc_wrapper wei_d(weights_md_);
    const memory_desc_wrapper bia_d(bias_md_);
    const memory_desc_wrapper dst_d(dst_md_);

    using namespace data_type;
    VDISPATCH_MATMUL(utils::one_of(src_d.data_type(), s8, u8)
                                    && wei_d.data_type() == s8
                                    && src_d.data_type() == s8
                            ? dst_d.data_type() == s8
                            : dst_d.data_type() == u8
                                    && utils::one_of(
                                            bia_d.data_type(), f32, undef),
            VERBOSE_UNSUPPORTED_DT_CFG);

    VDISPATCH_MATMUL(src_d.matches_tag(format_tag::ab)
                    && wei_d.matches_tag(format_tag::ab)
                    && dst_d.matches_tag(format_tag::ab),
            VERBOSE_UNSUPPORTED_TAG);

    VDISPATCH_MATMUL_SC(
            memory_desc_init_by_tag(bias_md_, bias_md_.ndims, bias_md_.dims,
                    bias_md_.data_type, format_tag::ab),
            VERBOSE_UNSUPPORTED_BIAS_CFG);

    // We set the QuantizationInfo to be dynamic because it is re-set in run()
    almc_.src_tensor_info
            = arm_compute::TensorInfo(arm_compute::TensorShape(K(), M()), 1,
                    acl_utils::get_acl_data_t(src_d.data_type(), true),
                    arm_compute::QuantizationInfo(1.0, 0, true));
    almc_.src_tensor_info.set_are_values_constant(false);

    almc_.wei_tensor_info
            = arm_compute::TensorInfo(arm_compute::TensorShape(N(), K()), 1,
                    acl_utils::get_acl_data_t(wei_d.data_type(), true),
                    arm_compute::QuantizationInfo(1.0, 0, true));
    almc_.wei_tensor_info.set_are_values_constant(false);

    almc_.dst_tensor_info
            = arm_compute::TensorInfo(arm_compute::TensorShape(N(), M()), 1,
                    acl_utils::get_acl_data_t(dst_d.data_type(), true),
                    arm_compute::QuantizationInfo(1.0, 0, true));

    almc_.bia_tensor_info = arm_compute::TensorInfo(
            arm_compute::TensorShape(), 1, arm_compute::DataType::S32);
    almc_.with_bias = bia_d.format_kind() != format_kind::undef;
    if (almc_.with_bias) {
        // This is not currently guarded in ACL
        VDISPATCH_MATMUL(bia_d.ndims() == 2 && bia_d.dims()[0] == 1
                        && bia_d.dims()[1] == N(),
                "Only 1xN bias is supported");
        almc_.bia_tensor_info.set_tensor_shape(
                arm_compute::TensorShape(bia_d.dims()[1], bia_d.dims()[0]));
    }

    arm_compute::GEMMLowpOutputStageInfo info;

    info.type = arm_compute::GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
    info.gemmlowp_multiplier = 1073741824;
    info.gemmlowp_shift = -1;
    info.gemmlowp_offset = 0;
    info.gemmlowp_min_bound = -128;
    info.gemmlowp_max_bound = 127;
    info.output_data_type = almc_.dst_tensor_info.data_type();

    almc_.gemm_info.set_gemmlowp_output_stage(info);
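    // The fixed-point output stage set above holds placeholder values;
    // execute() refreshes the quantization parameters from the runtime scales
    // and zero points via gemm.update_quantization_parameters().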

    auto scratchpad = scratchpad_registry().registrar();
    const dnnl::impl::memory_desc_t dst_md_ {desc_.dst_desc};
    arm_compute::ActivationLayerInfo act_info;
    CHECK(init_scratchpad(engine, scratchpad, acl_post_ops, attr_.post_ops_,
            act_info, dst_md_));

    almc_.gemm_info.set_activation_info(act_info);
    ACL_CHECK_VALID(arm_compute::NEGEMMLowpMatrixMultiplyCore::validate(
            &almc_.src_tensor_info, &almc_.wei_tensor_info,
            almc_.with_bias ? &almc_.bia_tensor_info : nullptr,
            &almc_.dst_tensor_info, almc_.gemm_info));

    return status::success;
}

status_t acl_lowp_matmul_sq_t::pd_t::init_scratchpad(engine_t *engine,
        memory_tracking::registrar_t &scratchpad, acl_post_ops_t &post_ops,
        dnnl::impl::post_ops_t &attr_post_ops,
        arm_compute::ActivationLayerInfo &act_info,
        const dnnl::impl::memory_desc_t &dst_md) {
    CHECK(post_ops.init(engine, attr_post_ops, dst_md, act_info));
    // ACL only accepts s32 bias for quantization and since
    // the current bias vector is f32 we need to convert.
    if (almc_.with_bias) {
        const memory_desc_wrapper bias_d(&bias_md_);
        // printf("*** init_scratchpad: bias_data_type=%d\n", bias_md_.data_type);
        scratchpad.book(memory_tracking::names::key_conv_bias_s32_convert,
                bias_d.nelems(), bias_d.data_type_size());
    }
    return status::success;
}

status_t acl_lowp_matmul_sq_t::create_resource(
        engine_t *engine, resource_mapper_t &mapper) const {

    if (mapper.has_resource(this)) return status::success;

    auto r = utils::make_unique<acl_lowp_matmul_sq_resource_t>();
    if (!r) return status::out_of_memory;

    CHECK(r->configure(pd()->almc_));

    mapper.add(this, std::move(r));

    return status::success;
}

status_t acl_lowp_matmul_sq_t::execute(const exec_ctx_t &ctx) const {
    std::lock_guard<std::mutex> _lock {this->mtx};

    bool with_bias = pd()->almc_.with_bias;

    acl_lowp_matmul_sq_obj_t &acl_obj
            = ctx.get_resource_mapper()
                      ->get<acl_lowp_matmul_sq_resource_t>(this)
                      ->get_acl_obj();

    auto src = CTX_IN_MEM(const int8_t *, DNNL_ARG_SRC);
    auto wei = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS);
    auto dst = CTX_OUT_MEM(const int8_t *, DNNL_ARG_DST);

    acl_obj.src_tensor.allocator()->import_memory(const_cast<int8_t *>(src));
    acl_obj.wei_tensor.allocator()->import_memory(const_cast<int8_t *>(wei));
    acl_obj.dst_tensor.allocator()->import_memory(const_cast<int8_t *>(dst));

    DEFINE_ARG_SCALES_BUFFER(src_scale, DNNL_ARG_SRC);
    DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scale, DNNL_ARG_WEIGHTS);
    DEFINE_ZERO_POINT_VALUE(wei_zero_point, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scale, DNNL_ARG_DST);
    DEFINE_ZERO_POINT_VALUE(dst_zero_point, DNNL_ARG_DST);

    if (with_bias) {
        const auto scratchpad = ctx.get_scratchpad_grantor();
        auto bia_s32_base = scratchpad.get<uint32_t>(
                memory_tracking::names::key_conv_bias_s32_convert);
        auto bia_f32_base = CTX_IN_MEM(const float32_t *, DNNL_ARG_BIAS);
        const float bias_scale = 1 / (*src_scale * (*wei_scale));
        const int num_elements
                = acl_obj.bia_tensor.info()->total_size() / sizeof(float32_t);
        parallel_nd(num_elements, [&](dim_t e) {
            const auto b = int32_t(std::round(bia_f32_base[e] * bias_scale));
            bia_s32_base[e] = b;
        });
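        // The loop above converts the f32 bias into the s32 accumulator
        // domain that ACL expects:
        // bias_s32 = round(bias_f32 / (src_scale * wei_scale)).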
        acl_obj.bia_tensor.allocator()->init(*acl_obj.bia_tensor.info());
        acl_obj.bia_tensor.allocator()->import_memory(bia_s32_base);
    }

    acl_obj.src_tensor.info()->set_quantization_info(
            arm_compute::QuantizationInfo(*src_scale, -src_zero_point, true));
    acl_obj.wei_tensor.info()->set_quantization_info(
            arm_compute::QuantizationInfo(*wei_scale, -wei_zero_point, true));

    // For efficiency reasons, oneDNN saves the inverse of the destination
    // scale, hence the 1.0 / dst_scale below.
    acl_obj.dst_tensor.info()->set_quantization_info(
            arm_compute::QuantizationInfo(
                    1.0 / (*dst_scale), dst_zero_point, true));

    acl_obj.gemm.update_quantization_parameters();

    acl_obj.gemm.run();

    // free() here tells ACL it may no longer use this memory; it does not
    // deallocate the buffers.
    acl_obj.src_tensor.allocator()->free();
    acl_obj.wei_tensor.allocator()->free();
    if (with_bias) { acl_obj.bia_tensor.allocator()->free(); }
    acl_obj.dst_tensor.allocator()->free();

    return status::success;
}

} // namespace matmul
} // namespace aarch64
} // namespace cpu
} // namespace impl
} // namespace dnnl
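
For context, and not part of the commit itself, a minimal user-side sketch of the statically quantised s8 matmul this implementation is meant to serve could look as follows. It assumes the oneDNN v3.x C++ API; the shapes, scale values and engine index are illustrative only, and the zero-point arguments (also supported by this kernel) are omitted for brevity.

#include "oneapi/dnnl/dnnl.hpp"

int main() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0);
    stream strm(eng);

    // Row-major (ab) s8 tensors, matching the layouts the kernel dispatches on.
    const memory::dim M = 8, K = 32, N = 16;
    memory::desc src_md({M, K}, memory::data_type::s8, memory::format_tag::ab);
    memory::desc wei_md({K, N}, memory::data_type::s8, memory::format_tag::ab);
    memory::desc dst_md({M, N}, memory::data_type::s8, memory::format_tag::ab);

    // Common (per-tensor, mask = 0) scales; zero points could be attached the
    // same way with set_zero_points_mask().
    primitive_attr attr;
    attr.set_scales_mask(DNNL_ARG_SRC, 0);
    attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
    attr.set_scales_mask(DNNL_ARG_DST, 0);

    matmul::primitive_desc pd(eng, src_md, wei_md, dst_md, attr);
    matmul prim(pd);

    // Illustrative scale values; real values come from calibration.
    float src_scale = 0.05f, wei_scale = 0.01f, dst_scale = 0.1f;
    memory::desc scale_md({1}, memory::data_type::f32, memory::format_tag::x);
    memory src_scale_m(scale_md, eng, &src_scale);
    memory wei_scale_m(scale_md, eng, &wei_scale);
    memory dst_scale_m(scale_md, eng, &dst_scale);

    memory src_m(src_md, eng), wei_m(wei_md, eng), dst_m(dst_md, eng);
    // ... fill src_m and wei_m with quantised int8 data ...

    prim.execute(strm,
            {{DNNL_ARG_SRC, src_m}, {DNNL_ARG_WEIGHTS, wei_m},
                    {DNNL_ARG_DST, dst_m},
                    {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scale_m},
                    {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scale_m},
                    {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scale_m}});
    strm.wait();
    return 0;
}

With mask-0 scales and plain ab layouts this request matches the checks in acl_lowp_matmul_sq_t::pd_t::init() above; whether this particular implementation is picked still depends on the build configuration and the usual oneDNN dispatch order.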

0 commit comments