
Commit 18263c2

Fix bug in elu
Differential Revision: D81359053
Pull Request resolved: #13829
1 parent 66799e3 commit 18263c2

File tree

kernels/optimized/cpu/op_elu.cpp
kernels/portable/cpu/op_elu.cpp
kernels/test/op_elu_test.cpp

3 files changed: +22 -13 lines changed

kernels/optimized/cpu/op_elu.cpp

Lines changed: 3 additions & 6 deletions
@@ -28,12 +28,9 @@ void elu(
   CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
   using MathT =
       std::conditional_t<c10::is_reduced_floating_point_v<CTYPE>, float, CTYPE>;
-  MathT math_alpha = 0;
-  MathT math_scale = 0;
-  MathT math_input_scale = 0;
-  ET_EXTRACT_SCALAR(alpha, math_alpha);
-  ET_EXTRACT_SCALAR(scale, math_scale);
-  ET_EXTRACT_SCALAR(input_scale, math_input_scale);
+  const auto math_alpha = utils::scalar_to<MathT>(alpha);
+  const auto math_scale = utils::scalar_to<MathT>(scale);
+  const auto math_input_scale = utils::scalar_to<MathT>(input_scale);
   const auto scalar_func =
       at::native::get_scalar_elu_elementwise_func<CTYPE, MathT>(
           math_alpha, math_scale, math_input_scale);
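
The diff above (and the matching portable change below) replaces the ET_EXTRACT_SCALAR calls with utils::scalar_to<MathT>(). Together with the new MixedScalarTypes test, this suggests the bug involved Scalar arguments whose type differs from the kernel's math type (for example a bool alpha with a float tensor): the conversion now goes through scalar_to rather than the extract-and-check macro. A minimal, self-contained sketch of that conversion idea, using a hypothetical ScalarLike type rather than ExecuTorch's real Scalar API:

#include <cstdint>
#include <variant>

// Hypothetical stand-in for a Scalar that may hold a bool, an integer,
// or a double; ExecuTorch's real Scalar type is more involved.
using ScalarLike = std::variant<bool, int64_t, double>;

// Sketch of the scalar_to<MathT> idea: convert whatever the Scalar holds
// into the kernel's math type instead of requiring an exact type match.
// e.g. scalar_to_sketch<float>(ScalarLike{true}) == 1.0f
template <typename MathT>
MathT scalar_to_sketch(const ScalarLike& s) {
  return std::visit([](auto v) { return static_cast<MathT>(v); }, s);
}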

kernels/portable/cpu/op_elu.cpp

Lines changed: 3 additions & 6 deletions
@@ -37,12 +37,9 @@ Tensor& elu_out(
   ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&]() {
     using MathT = std::
         conditional_t<c10::is_reduced_floating_point_v<CTYPE>, float, CTYPE>;
-    MathT math_alpha = 0;
-    MathT math_scale = 0;
-    MathT math_input_scale = 0;
-    ET_EXTRACT_SCALAR(alpha, math_alpha);
-    ET_EXTRACT_SCALAR(scale, math_scale);
-    ET_EXTRACT_SCALAR(input_scale, math_input_scale);
+    const auto math_alpha = utils::scalar_to<MathT>(alpha);
+    const auto math_scale = utils::scalar_to<MathT>(scale);
+    const auto math_input_scale = utils::scalar_to<MathT>(input_scale);
     const auto negcoef = math_alpha * math_scale;
     utils::apply_unitensor_elementwise_fn<
         CTYPE,
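
For reference, the portable kernel folds alpha into the negative-branch coefficient (negcoef = math_alpha * math_scale) before applying the elementwise function. A small sketch of the per-element ELU math under that folding, assuming the standard ELU definition (the kernel's actual lambda may differ in details):

#include <cmath>

// Sketch only: per-element ELU with alpha pre-folded into negcoef.
// Positive inputs pass through scaled; negative inputs follow the
// exponential branch scaled by negcoef = alpha * scale.
template <typename MathT>
MathT elu_element(MathT x, MathT negcoef, MathT scale, MathT input_scale) {
  return x > MathT(0) ? scale * x
                      : negcoef * std::expm1(x * input_scale);
}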

kernels/test/op_elu_test.cpp

Lines changed: 16 additions & 1 deletion
@@ -43,7 +43,7 @@ class OpEluTest : public OperatorTest {
 
   Tensor out = tf.zeros(sizes);
 
-  // Run full gelu.
+  // Run full elu.
   op_elu_out(in, 1.25, 1, 1, out);
 
   // Check that it matches the expected output.
@@ -93,3 +93,18 @@ TEST_F(OpEluTest, MismatchedOutputDtypeDies) {
 
   ET_EXPECT_KERNEL_FAILURE(context_, op_elu_out(a, 1, 1, 1, out));
 }
+
+TEST_F(OpEluTest, MixedScalarTypes) {
+  TensorFactory<ScalarType::Float> tf_float;
+
+  const std::vector<int32_t> sizes = {2, 2};
+
+  Tensor in = tf_float.ones(sizes);
+  Tensor out = tf_float.zeros(sizes);
+
+  op_elu_out(in, true, 1.0, 1.0, out);
+  EXPECT_TENSOR_CLOSE(out, tf_float.ones(sizes));
+
+  op_elu_out(in, false, true, 3, out);
+  EXPECT_TENSOR_CLOSE(out, tf_float.ones(sizes));
+}
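
The new MixedScalarTypes test passes bool and int Scalars for alpha, scale, and input_scale. Since the input tensor is all ones, the positive branch of ELU applies and both calls should return the input unchanged; a quick standalone check of that arithmetic (not part of the commit):

#include <cassert>
#include <cmath>

// Reference ELU for a single element (standard definition).
float elu_ref(float x, float alpha, float scale, float input_scale) {
  return x > 0.0f ? scale * x
                  : scale * alpha * std::expm1(x * input_scale);
}

int main() {
  // op_elu_out(in, true, 1.0, 1.0, out): alpha=1, scale=1, input_scale=1.
  assert(elu_ref(1.0f, 1.0f, 1.0f, 1.0f) == 1.0f);
  // op_elu_out(in, false, true, 3, out): alpha=0, scale=1, input_scale=3.
  assert(elu_ref(1.0f, 0.0f, 1.0f, 3.0f) == 1.0f);
  return 0;
}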
