@@ -194,3 +194,60 @@ NGRAPH_TEST(${BACKEND_NAME}, layer_norm_bprop_affine)
194194 EXPECT_TRUE (test::all_close_f (exp_d_scale, read_vector<float >(d_scale)));
195195 EXPECT_TRUE (test::all_close_f (exp_d_bias, read_vector<float >(d_bias)));
196196}
197+
198+ NGRAPH_TEST (${BACKEND_NAME}, layer_norm_bprop_4d_input)
199+ {
200+ auto p_data = make_shared<op::Parameter>(element::f32 , Shape{2 , 3 , 4 , 5 });
201+ auto p_delta = make_shared<op::Parameter>(element::f32 , Shape{2 , 3 , 4 , 5 });
202+ auto p_mean = make_shared<op::Parameter>(element::f32 , Shape{2 });
203+ auto p_variance = make_shared<op::Parameter>(element::f32 , Shape{2 });
204+ auto p_scale = make_shared<op::Parameter>(element::f32 , Shape{60 });
205+ auto lnb = make_shared<op::LayerNormBackprop>(p_data, p_delta, p_mean, p_variance, p_scale);
206+
207+ auto output_data = lnb->output (0 );
208+ auto output_scale = lnb->output (1 );
209+ auto output_bias = lnb->output (2 );
210+
211+ // flatten output_scale
212+ auto output_scale_shape = output_scale.get_shape ();
213+ auto flattened_output_scale = make_shared<op::Reshape>(
214+ output_scale, get_default_order (output_scale_shape), Shape{shape_size (output_scale_shape)});
215+
216+ auto f = make_shared<Function>(OutputVector{output_data, flattened_output_scale, output_bias},
217+ ParameterVector{p_data, p_delta, p_mean, p_variance, p_scale});
218+
219+ auto backend = runtime::Backend::create (" ${BACKEND_NAME}" );
220+
221+ // Create tensors for input
222+ auto data = backend->create_tensor (element::f32 , Shape{2 , 3 , 4 , 5 });
223+ auto delta = backend->create_tensor (element::f32 , Shape{2 , 3 , 4 , 5 });
224+ auto mean = backend->create_tensor (element::f32 , Shape{2 });
225+ auto variance = backend->create_tensor (element::f32 , Shape{2 });
226+ auto scale = backend->create_tensor (element::f32 , Shape{60 });
227+ // Fill in input tensors
228+ vector<float > d_input (2 * 3 * 4 * 5 , 1 );
229+ copy_data (data, d_input);
230+ vector<float > dt_input (2 * 3 * 4 * 5 , 1 );
231+ copy_data (delta, dt_input);
232+ vector<float > m_input (2 , 1 );
233+ copy_data (mean, m_input);
234+ vector<float > v_input (2 , 1 );
235+ copy_data (variance, v_input);
236+ vector<float > s_input (60 , 1 );
237+ copy_data (scale, s_input);
238+ // Create tensors for output
239+ auto d_data = backend->create_tensor (element::f32 , Shape{2 , 3 , 4 , 5 });
240+ auto d_scale = backend->create_tensor (element::f32 , Shape{60 });
241+ auto d_bias = backend->create_tensor (element::f32 , Shape{60 });
242+
243+ auto handle = backend->compile (f);
244+ handle->call_with_validate ({d_data, d_scale, d_bias}, {data, delta, mean, variance, scale});
245+
246+ vector<float > expected_data (120 , 0 );
247+ vector<float > expected_scale (60 , 0 );
248+ vector<float > expected_bias (60 , 2 );
249+
250+ EXPECT_TRUE (test::all_close_f (expected_data, read_vector<float >(d_data), 1e-6f , 1e-6f ));
251+ EXPECT_TRUE (test::all_close_f (expected_scale, read_vector<float >(d_scale), 1e-6f , 1e-6f ));
252+ EXPECT_TRUE (test::all_close_f (expected_bias, read_vector<float >(d_bias), 1e-6f , 1e-6f ));
253+ }
0 commit comments