Skip to content

Commit

Permalink
Avoid amax roll for non-run modules (#825)
Browse files: browse the repository at this point in the history
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
  • Loading branch information
ksivaman authored and ptrendx committed Apr 30, 2024
1 parent 9f0a4a4 commit 3c604eb
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions transformer_engine/common/recipe/delayed_scaling.cu
Original file line number | Diff line number | Diff line change
Expand Up @@ -197,16 +197,18 @@ kernel_bulk(
const auto last_amax = ((amax_reduction_buffer != nullptr)
&& (amax_reduction_buffer[offset_in_buffer+count] != 0.0f)) ?
amax_reduction_buffer[offset_in_buffer+count] : amax_history[0];
for (size_t off = 0; off < length; off += bsize) {
const size_t i = off + tid;
float a = 0;
if (i < length) {
a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax;
amax = fmaxf(amax, a);
}
__syncthreads(); // Inplace roll
if (i < length) {
amax_history[i*stride] = (i > 0) ? a : 0;
if (last_amax != 0.0f) {
for (size_t off = 0; off < length; off += bsize) {
const size_t i = off + tid;
float a = 0;
if (i < length) {
a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax;
amax = fmaxf(amax, a);
}
__syncthreads(); // Inplace roll
if (i < length) {
amax_history[i*stride] = (i > 0) ? a : 0;
}
}
}

Expand Down

0 comments on commit 3c604eb

Please sign in to comment.