This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 6fb1c5e

baojun-nervana authored and diyessi committed
Fix layer_norm flatten issue (#4031)
* fix layernorm flatten issue
* update ut
* checkout output val
* fix style
* apply tolerance
1 parent: e03289a

File tree

  src/ngraph/op/fused/layer_norm.cpp
  src/ngraph/op/fused/layer_norm.hpp
  src/ngraph/runtime/plaidml/unit_test.manifest
  test/backend/layer_norm.in.cpp

4 files changed: +62 -4 lines changed

src/ngraph/op/fused/layer_norm.cpp
Lines changed: 2 additions & 2 deletions

@@ -170,7 +170,7 @@ shared_ptr<Node> op::LayerNorm::copy_with_new_args(const NodeVector& new_args) const
     }
 }
 
-void op::LayerNorm::validate_and_infer_types()
+void op::LayerNorm::pre_validate_and_infer_types()
 {
     element::Type input_element_type = get_input_element_type(0);

@@ -509,7 +509,7 @@ shared_ptr<Node> op::LayerNormBackprop::copy_with_new_args(const NodeVector& new_args) const
     }
 }
 
-void op::LayerNormBackprop::validate_and_infer_types()
+void op::LayerNormBackprop::pre_validate_and_infer_types()
 {
     element::Type input_element_type = get_input_element_type(0);

src/ngraph/op/fused/layer_norm.hpp
Lines changed: 2 additions & 2 deletions

@@ -56,7 +56,7 @@ namespace ngraph
 
             virtual NodeVector decompose_op() const override;
 
-            void validate_and_infer_types() override;
+            void pre_validate_and_infer_types() override;
 
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;

@@ -121,7 +121,7 @@ namespace ngraph
 
             virtual NodeVector decompose_op() const override;
 
-            void validate_and_infer_types() override;
+            void pre_validate_and_infer_types() override;
 
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
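
Why the rename fixes the flatten issue: LayerNorm and LayerNormBackprop derive from nGraph's FusedOp base class, whose validate_and_infer_types() infers output shapes by running the op's decomposition. Overriding validate_and_infer_types() directly skipped that decomposition-based inference; moving the input checks into pre_validate_and_infer_types() restores it. The sketch below is a simplified reading of the assumed FusedOp contract, not verbatim nGraph source, and omits the dynamic-shape handling the real class performs.

// Simplified sketch (not verbatim nGraph source) of the FusedOp
// contract this commit relies on: subclasses validate their inputs in
// pre_validate_and_infer_types(), while the base class derives output
// shapes from the decomposed subgraph.
void FusedOp::validate_and_infer_types()
{
    // Subclass hook: after this commit, LayerNorm and
    // LayerNormBackprop override this instead.
    pre_validate_and_infer_types();

    // Infer output types/shapes from the decomposition, so an input
    // that the op flattens internally (e.g. 4-D) still comes out right.
    auto subgraph_outputs = decompose_op();
    size_t i = 0;
    for (auto& output_node : subgraph_outputs)
    {
        for (size_t j = 0; j < output_node->get_output_size(); ++j, ++i)
        {
            set_output_type(i,
                            output_node->get_output_element_type(j),
                            output_node->get_output_partial_shape(j));
        }
    }
}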

src/ngraph/runtime/plaidml/unit_test.manifest
Lines changed: 1 addition & 0 deletions

@@ -316,6 +316,7 @@ random_uniform_dynamic_shapes
 layer_norm_affine_stats
 layer_norm_bprop_affine_stats
 layer_norm_bprop_affine
+layer_norm_bprop_4d_input
 
 # Another fused op decomposition pass required after the downgrade pass
 model_split_equal_parts_default

test/backend/layer_norm.in.cpp
Lines changed: 57 additions & 0 deletions

@@ -194,3 +194,60 @@ NGRAPH_TEST(${BACKEND_NAME}, layer_norm_bprop_affine)
     EXPECT_TRUE(test::all_close_f(exp_d_scale, read_vector<float>(d_scale)));
     EXPECT_TRUE(test::all_close_f(exp_d_bias, read_vector<float>(d_bias)));
 }
+
+NGRAPH_TEST(${BACKEND_NAME}, layer_norm_bprop_4d_input)
+{
+    auto p_data = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 5});
+    auto p_delta = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 5});
+    auto p_mean = make_shared<op::Parameter>(element::f32, Shape{2});
+    auto p_variance = make_shared<op::Parameter>(element::f32, Shape{2});
+    auto p_scale = make_shared<op::Parameter>(element::f32, Shape{60});
+    auto lnb = make_shared<op::LayerNormBackprop>(p_data, p_delta, p_mean, p_variance, p_scale);
+
+    auto output_data = lnb->output(0);
+    auto output_scale = lnb->output(1);
+    auto output_bias = lnb->output(2);
+
+    // flatten output_scale
+    auto output_scale_shape = output_scale.get_shape();
+    auto flattened_output_scale = make_shared<op::Reshape>(
+        output_scale, get_default_order(output_scale_shape), Shape{shape_size(output_scale_shape)});
+
+    auto f = make_shared<Function>(OutputVector{output_data, flattened_output_scale, output_bias},
+                                   ParameterVector{p_data, p_delta, p_mean, p_variance, p_scale});
+
+    auto backend = runtime::Backend::create("${BACKEND_NAME}");
+
+    // Create tensors for input
+    auto data = backend->create_tensor(element::f32, Shape{2, 3, 4, 5});
+    auto delta = backend->create_tensor(element::f32, Shape{2, 3, 4, 5});
+    auto mean = backend->create_tensor(element::f32, Shape{2});
+    auto variance = backend->create_tensor(element::f32, Shape{2});
+    auto scale = backend->create_tensor(element::f32, Shape{60});
+    // Fill in input tensors
+    vector<float> d_input(2 * 3 * 4 * 5, 1);
+    copy_data(data, d_input);
+    vector<float> dt_input(2 * 3 * 4 * 5, 1);
+    copy_data(delta, dt_input);
+    vector<float> m_input(2, 1);
+    copy_data(mean, m_input);
+    vector<float> v_input(2, 1);
+    copy_data(variance, v_input);
+    vector<float> s_input(60, 1);
+    copy_data(scale, s_input);
+    // Create tensors for output
+    auto d_data = backend->create_tensor(element::f32, Shape{2, 3, 4, 5});
+    auto d_scale = backend->create_tensor(element::f32, Shape{60});
+    auto d_bias = backend->create_tensor(element::f32, Shape{60});
+
+    auto handle = backend->compile(f);
+    handle->call_with_validate({d_data, d_scale, d_bias}, {data, delta, mean, variance, scale});
+
+    vector<float> expected_data(120, 0);
+    vector<float> expected_scale(60, 0);
+    vector<float> expected_bias(60, 2);
+
+    EXPECT_TRUE(test::all_close_f(expected_data, read_vector<float>(d_data), 1e-6f, 1e-6f));
+    EXPECT_TRUE(test::all_close_f(expected_scale, read_vector<float>(d_scale), 1e-6f, 1e-6f));
+    EXPECT_TRUE(test::all_close_f(expected_bias, read_vector<float>(d_bias), 1e-6f, 1e-6f));
+}
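
The flatten step in the new test, op::Reshape driven by get_default_order() and shape_size(), is the general nGraph idiom for collapsing a statically shaped value into one dimension. A minimal standalone sketch of the same pattern follows; the helper name flatten_to_1d is ours for illustration, not part of the commit or of nGraph's API.

#include <ngraph/ngraph.hpp>

// Hypothetical helper (the name is ours) showing the flatten idiom used
// by layer_norm_bprop_4d_input: keep the row-major axis order and
// collapse every axis into a single one.
std::shared_ptr<ngraph::Node> flatten_to_1d(const ngraph::Output<ngraph::Node>& value)
{
    const ngraph::Shape& s = value.get_shape();
    return std::make_shared<ngraph::op::Reshape>(
        value,
        ngraph::get_default_order(s),          // identity order {0, 1, ..., rank-1}
        ngraph::Shape{ngraph::shape_size(s)}); // one axis holding all elements
}

In the test, this pattern is applied to lnb->output(1) so that d_scale can be compared elementwise against a flat 60-element expected vector.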
