x64: brgemm matmul: support arbitrary K on AMX
ankalinin committed Jan 2, 2025
1 parent 22527fe commit 4bd35fa
Showing 3 changed files with 66 additions and 17 deletions.
34 changes: 28 additions & 6 deletions src/cpu/x64/matmul/brgemm_matmul.cpp
@@ -178,11 +178,16 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
: (is_int8 ? avx512_core_vnni
: avx512_core)))
: isa;
for_(int i_bs = 0; i_bs < 2; i_bs++)
for_(int i_init = 0; i_init < 2; i_init++)

const int i_bs_end = bgmmc_.brgemm_batch_tail_size ? 2 : 1;
const int i_init_start = bgmmc_.K_blk != bgmmc_.K ? 0 : 1;
const int i_K_end = bgmmc_.K_tail ? 2 : 1;

for_(int i_bs = 0; i_bs < i_bs_end; i_bs++)
for_(int i_init = i_init_start; i_init < 2; i_init++)
for_(int i_M = 0; i_M < max_m_ker_idx; i_M++)
for_(int i_N = 0; i_N < max_n_ker_idx; i_N++)
for (int i_K = 0; i_K < 2; i_K++) {
for (int i_K = 0; i_K < i_K_end; i_K++) {
auto vbeta = (i_init) ? beta_init : beta;
auto vM = (i_M) == 0 ? bgmmc_.M_blk
: (bgmmc_.is_runtime_M ? dynamic_m_tails[i_M - 1]
@@ -219,6 +224,8 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
brgattr.use_uker = true;
brgattr.use_interleave_stores = true;
brgattr.max_bs = bs;
brgattr.wary_A_k_tail_read = bgmmc_.extendable_k;
brgattr.extendable_k = bgmmc_.extendable_k;
// TODO: change expected sizes to local chunks wrt L2 blocking
brgattr.hint_expected_A_size = vM * vK * bs;
brgattr.hint_expected_B_size = vN * vK * bs;
@@ -249,11 +256,16 @@ status_t brgemm_matmul_t<isa>::init(engine_t *engine) {
= bgmmc.is_runtime_M ? max_num_dynamic_m_tails + 1 : 2;
const int max_n_ker_idx
= bgmmc.is_runtime_N ? max_num_dynamic_n_tails + 1 : 2;
for_(int i_bs = 0; i_bs < 2; i_bs++)

const int i_bs_end = bgmmc.brgemm_batch_tail_size ? 2 : 1;
const int i_init_start = bgmmc.K_blk != bgmmc.K ? 0 : 1;
const int i_K_end = bgmmc.K_tail ? 2 : 1;

for_(int i_bs = 0; i_bs < i_bs_end; i_bs++)
for_(int i_M = 0; i_M < max_m_ker_idx; i_M++)
for_(int i_N = 0; i_N < max_n_ker_idx; i_N++)
for_(int i_K = 0; i_K < 2; i_K++)
for (int i_init = 0; i_init < 2; i_init++) {
for_(int i_K = 0; i_K < i_K_end; i_K++)
for (int i_init = i_init_start; i_init < 2; i_init++) {
int idx = pd()->get_brg_kernel_idx(i_bs, i_init, i_M, i_N, i_K);
if (idx < 0) continue;

@@ -809,6 +821,8 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
= (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx);
ctx.current_K_start = k;
ctx.current_K_iters = nstl::min(bgmmc.K_blk, bgmmc.K);
ctx.current_K_pad = brgmm_ctx.get_current_K_pad(ctx.current_K_iters);

ctx.scales_ptr = (void *)brgmm_ctx.get_oscales_ptr(n, k);
if (bgmmc.blocked_B && !bgmmc.is_f16_with_int_wei
&& isa == avx512_core_fp16) {
@@ -827,6 +841,7 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
= (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx);
ctx.current_K_start = k;
ctx.current_K_iters = bgmmc.K % bgmmc.K_blk;
ctx.current_K_pad = brgmm_ctx.get_current_K_pad(ctx.current_K_iters);
ctx.scales_ptr = (void *)brgmm_ctx.get_oscales_ptr(n, k);
if (bgmmc.blocked_B && !bgmmc.is_f16_with_int_wei
&& isa == avx512_core_fp16) {
@@ -1727,6 +1742,13 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {

bool packed_sparse_weights() const { return bgmmc_.packed_sparse_weights; }

int get_current_K_pad(int current_K_iters) const {
return bgmmc_.extendable_k ? bgmmc_.wei_k_blk
- rnd_up(
current_K_iters % bgmmc_.wei_k_blk, vnni_factor)
: 0;
}

private:
struct tail_processing_t {
// dimension index kernel is applied to
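A standalone sketch of the padding rule behind the new ctx.current_K_pad / get_current_K_pad() code above, for readers following the diff. The block size, VNNI factor, and K-tail values below are illustrative assumptions rather than values taken from the commit; this path is only taken when extendable_k is set, i.e. when K is not a multiple of wei_k_blk.

// Minimal sketch (assumed example values): the pad is the distance from the
// VNNI-rounded K tail to a full wei_k_blk block, and zero when extendable_k
// is off -- the same arithmetic as the new get_current_K_pad() helper.
#include <cstdio>

static int rnd_up(int value, int factor) {
    return ((value + factor - 1) / factor) * factor;
}

static int current_k_pad(bool extendable_k, int wei_k_blk, int vnni_factor,
        int current_k_iters) {
    return extendable_k
            ? wei_k_blk - rnd_up(current_k_iters % wei_k_blk, vnni_factor)
            : 0;
}

int main() {
    const int wei_k_blk = 64; // assumed AMX weights K block
    const int vnni_factor = 4; // e.g. int8 packs 4 K elements per VNNI group
    const int k_tails[] = {50, 61, 3}; // partial K tails, the interesting case
    for (int k_iters : k_tails)
        std::printf("K tail = %2d -> pad = %2d\n", k_iters,
                current_k_pad(/*extendable_k=*/true, wei_k_blk, vnni_factor,
                        k_iters));
    return 0;
}

For a tail of 50, for example, the tail is rounded up to 52 for VNNI packing and 12 padding elements complete the 64-wide block.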
47 changes: 36 additions & 11 deletions src/cpu/x64/matmul/brgemm_matmul_utils.cpp
@@ -512,6 +512,8 @@ struct matmul_amx_blocking_params_t : public brgemm_matmul_conf_t {
, k_blk_(0)
, k_chunk_size_(0)
, k_chunk_elems_(0)
, use_buffer_a_(false)
, extendable_k_(false)
, current_lda_(0)
, need_buf_c_(false)
, blocking_chunk_mem_size_(0)
@@ -531,6 +533,8 @@ struct matmul_amx_blocking_params_t : public brgemm_matmul_conf_t {
, k_blk_(K_blk)
, k_chunk_size_(brgemm_batch_size)
, k_chunk_elems_(k_blk_ * k_chunk_size_)
, use_buffer_a_(use_buffer_a)
, extendable_k_(extendable_k)
, current_lda_(LDA)
, need_buf_c_(use_buffer_c)
, blocking_chunk_mem_size_(0)
@@ -553,6 +557,9 @@ struct matmul_amx_blocking_params_t : public brgemm_matmul_conf_t {
dim_t m_blk_, m_chunk_size_, m_chunk_elems_;
dim_t k_blk_, k_chunk_size_, k_chunk_elems_;

bool use_buffer_a_;
bool extendable_k_;

dim_t current_lda_;
bool need_buf_c_;
size_t blocking_chunk_mem_size_;
@@ -1146,6 +1153,7 @@ status_t compute_blocking_heuristic(brgemm_matmul_conf_t &bgmmc,
// AMX BRGEMM kernel requires (K_brgemm % 64 == 0 || K_brgemm < 64)
// for K_brgemm reduction value to avoid AMX tiles re-configuration.
// To satisfy this condition K_tail value is fixed to K % wei_k_blk here.

const bool fixed_K_tail_size
= bgmmc.K % bgmmc.wei_k_blk > 0 && bgmmc.K > bgmmc.wei_k_blk;
bgmmc.K_blk = bgmmc.K < bgmmc.wei_k_blk
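The comment in this hunk records the constraint that motivates the change: each per-call reduction size K_brgemm must be a multiple of 64 or smaller than 64, otherwise the AMX tiles would have to be re-configured between brgemm calls. A small illustration of why splitting K into full wei_k_blk pieces plus one short tail satisfies the constraint by construction (wei_k_blk = 64 and K = 200 are example values, not from the commit):

// Every piece is either exactly wei_k_blk (so K_brgemm % 64 == 0) or the
// final short tail (so K_brgemm < 64); the predicate holds for each call.
#include <cstdio>

int main() {
    const int wei_k_blk = 64;
    const int K = 200; // 3 full 64-wide pieces + a tail of 8
    for (int k = 0; k < K; k += wei_k_blk) {
        const int k_brgemm = (K - k < wei_k_blk) ? K - k : wei_k_blk;
        const bool ok = k_brgemm % wei_k_blk == 0 || k_brgemm < wei_k_blk;
        std::printf("K_brgemm = %2d -> constraint %s\n", k_brgemm,
                ok ? "holds" : "violated");
    }
    return 0;
}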
@@ -1527,10 +1535,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
if (bgmmc.is_amx && bm_conf_utils.is_int8())
prefer_copy_a &= bgmmc.N >= 256;

const bool is_copy_a_required
= (bgmmc.is_amx
&& ((bgmmc.K % bgmmc.required_k_granularity != 0)
|| bm_conf_utils.is_bf32()))
const bool is_copy_a_required = (bgmmc.is_amx && bm_conf_utils.is_bf32())
|| ((bm_conf_utils.is_f16() || bm_conf_utils.is_f16_with_int_wei())
&& isa == avx512_core_fp16)
|| (bgmmc.wei_zp_type != brgemm_broadcast_t::none
@@ -1601,7 +1606,9 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
bgmmc.M_tail = bgmmc.is_runtime_M ? 0 : bgmmc.M % bgmmc.M_blk;
bgmmc.N_tail = bgmmc.is_runtime_N ? 0 : bgmmc.N % bgmmc.N_blk;
bgmmc.K_tail = bgmmc.K > bgmmc.K_blk
? rnd_up(bgmmc.K % bgmmc.K_blk, bgmmc.required_k_granularity)
? (bgmmc.extendable_k ? bgmmc.K % bgmmc.K_blk
: rnd_up(bgmmc.K % bgmmc.K_blk,
bgmmc.required_k_granularity))
: 0;

bgmmc.LDB = bm_conf_utils.get_actual_LDB();
@@ -1979,16 +1986,32 @@ void matmul_amx_blocking_params_t::set_blocking_parameters(
: kc1;
}

k_chunk_elems_ = k_blk_ * k_chunk_size_;

const dim_t current_k_tail = K % k_blk_;
if (current_k_tail == 0 && K % (k_blk_ * k_chunk_size_) == 0) {
k_blk_ *= k_chunk_size_;

extendable_k_
= !use_buffer_a && K % wei_k_blk && k_chunk_elems_ > wei_k_blk;

if (extendable_k_) {
if (k_chunk_elems_ >= K) {
k_blk_ = K;
k_chunk_size_ = 1;
} else {
k_blk_ = k_chunk_elems_;
k_chunk_size_ = 1;
}
} else if (current_k_tail == 0 && K % (k_blk_ * k_chunk_size_) == 0) {
k_blk_ = k_chunk_elems_;
k_chunk_size_ = 1;
} else if (nthr_k_ == 1
&& K == k_blk_ * k_chunk_size_ + current_k_tail) {
k_blk_ *= k_chunk_size_;
k_blk_ = k_chunk_elems_;
k_chunk_size_ = 2;
}
}
use_buffer_a_
= use_buffer_a || (!extendable_k_ && K % required_k_granularity);

blocking_chunk_mem_size_ = calculate_chunk_memory_size();

@@ -2031,7 +2054,7 @@ float matmul_amx_blocking_params_t::get_copied_data_reusage_scores() {
const dim_t desired_M_chunk_size = is_runtime_M
? effective_m_chunk_sz
: nstl::min(M, effective_m_chunk_sz);
const dim_t effective_n_chunk_sz = 64 * (use_buffer_a ? 4 : 1);
const dim_t effective_n_chunk_sz = 64 * (use_buffer_a_ ? 4 : 1);
const dim_t desired_N_chunk_size = is_runtime_N
? effective_n_chunk_sz
: nstl::min(N, effective_n_chunk_sz);
@@ -2091,10 +2114,12 @@ void matmul_amx_blocking_params_t::update_configuration(

bgmmc.use_buffer_c = need_buf_c_;
bgmmc.LDA = current_lda_;
bgmmc.use_buffer_a = use_buffer_a_;
bgmmc.extendable_k = extendable_k_;
}

dim_t matmul_amx_blocking_params_t::get_actual_lda() {
if (!use_buffer_a)
if (!use_buffer_a_)
return treat_transposed_A_as_plain
? K
: A_strides[1 - transposed_A] / a_dt_sz;
@@ -2118,7 +2143,7 @@ size_t matmul_amx_blocking_params_t::calculate_chunk_memory_size() {
update_k_blocking_dependent_params();

size_t A_chunk_sz = a_dt_sz * k_chunk_elems_ * m_chunk_elems_;
size_t A_buf_sz = use_buffer_a
size_t A_buf_sz = use_buffer_a_
? tr_a_dt_sz * current_lda_ * k_chunk_size_ * m_chunk_elems_
: 0;
size_t B_chunk_sz = b_dt_sz * k_chunk_elems_ * n_chunk_elems_;
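The heart of the change sits in set_blocking_parameters() and the K_tail computation above: when copying A can be skipped and K is not a multiple of wei_k_blk (but the K chunk is larger than one weights block), blocking switches to the new extendable-K mode, keeps the K tail exact, and no longer forces use_buffer_a just to round the tail up. Below is a condensed, hedged sketch of that decision; the struct, function name, and example numbers are simplifications for illustration, and the real code also folds in threading and chunk-size heuristics.

// Condensed model of the new K-blocking decision from the diff above.
#include <algorithm>
#include <cstdio>

static int rnd_up(int v, int f) { return ((v + f - 1) / f) * f; }

struct k_blocking_t {
    bool extendable_k;
    bool use_buffer_a;
    int k_blk, k_chunk_size, k_tail;
};

static k_blocking_t choose_k_blocking(int K, int wei_k_blk, int k_blk,
        int k_chunk_size, bool use_buffer_a, int required_k_granularity) {
    k_blocking_t r{};
    const int k_chunk_elems = k_blk * k_chunk_size;
    // New path: no copy of A, K not a multiple of wei_k_blk, and the K chunk
    // is larger than one weights block -> pad the tail inside the kernel.
    r.extendable_k = !use_buffer_a && (K % wei_k_blk != 0)
            && k_chunk_elems > wei_k_blk;
    if (r.extendable_k) {
        // Collapse the chunk into a single K block, capped at K itself.
        r.k_blk = std::min(K, k_chunk_elems);
        r.k_chunk_size = 1;
    } else {
        r.k_blk = k_blk;
        r.k_chunk_size = k_chunk_size;
    }
    // Copy A only when the tail still has to be rounded up to the required
    // K granularity (the pre-existing path).
    r.use_buffer_a = use_buffer_a
            || (!r.extendable_k && (K % required_k_granularity != 0));
    // With extendable_k the tail stays exact; otherwise it is rounded up.
    r.k_tail = K > r.k_blk
            ? (r.extendable_k ? K % r.k_blk
                              : rnd_up(K % r.k_blk, required_k_granularity))
            : 0;
    return r;
}

int main() {
    // Example: K = 200 with a 64-wide weights block and a K chunk of 2 x 128.
    const k_blocking_t b = choose_k_blocking(/*K=*/200, /*wei_k_blk=*/64,
            /*k_blk=*/128, /*k_chunk_size=*/2, /*use_buffer_a=*/false,
            /*required_k_granularity=*/4);
    std::printf("extendable_k=%d use_buffer_a=%d k_blk=%d k_chunk=%d "
                "k_tail=%d\n",
            b.extendable_k, b.use_buffer_a, b.k_blk, b.k_chunk_size, b.k_tail);
    return 0;
}

With these inputs the sketch reports extendable_k = 1, use_buffer_a = 0, k_blk = 200, k_chunk = 1, and k_tail = 0: the whole K dimension fits in one brgemm call, and tail handling moves into the brgemm kernel (via the new brgattr.extendable_k and brgattr.wary_A_k_tail_read attributes) instead of going through the A copy buffer.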
2 changes: 2 additions & 0 deletions src/cpu/x64/matmul/brgemm_matmul_utils.hpp
@@ -201,6 +201,8 @@ struct brgemm_matmul_conf_t {
bool is_oscale_per_n = false;
bool is_oscale_per_k = false;
bool apply_scales_in_buffer_b = false;
bool extendable_k = false;

inline bool lda_big_pow2() const {
const dim_t big_stride_threshold_in_bytes = 8192;
const dim_t big_K_threshold = big_stride_threshold_in_bytes / a_dt_sz;
