diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index 38e71b74de..de48a53ebf 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -197,16 +197,18 @@ kernel_bulk( const auto last_amax = ((amax_reduction_buffer != nullptr) && (amax_reduction_buffer[offset_in_buffer+count] != 0.0f)) ? amax_reduction_buffer[offset_in_buffer+count] : amax_history[0]; - for (size_t off = 0; off < length; off += bsize) { - const size_t i = off + tid; - float a = 0; - if (i < length) { - a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax; - amax = fmaxf(amax, a); - } - __syncthreads(); // Inplace roll - if (i < length) { - amax_history[i*stride] = (i > 0) ? a : 0; + if (last_amax != 0.0f) { + for (size_t off = 0; off < length; off += bsize) { + const size_t i = off + tid; + float a = 0; + if (i < length) { + a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax; + amax = fmaxf(amax, a); + } + __syncthreads(); // Inplace roll + if (i < length) { + amax_history[i*stride] = (i > 0) ? a : 0; + } } }