From 153ec617ef52a72ca319da43b4dadd0320ad27dc Mon Sep 17 00:00:00 2001
From: Corentin Maravat
Date: Fri, 20 Dec 2024 17:39:20 +0100
Subject: [PATCH] Add Gradient for Atan

---
 .../orttraining/core/graph/gradient_builder.cc      | 13 +++++++++++++
 .../orttraining/core/graph/gradient_builder.h       |  2 ++
 .../core/graph/gradient_builder_registry.cc         |  1 +
 .../orttraining/test/gradient/gradient_ops_test.cc  |  2 ++
 4 files changed, 18 insertions(+)

diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc
index 76fe0ee91d4c6..9baf3e6843a6d 100755
--- a/orttraining/orttraining/core/graph/gradient_builder.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder.cc
@@ -2227,5 +2227,18 @@ IMPLEMENT_GRADIENT_BUILDER(GetResizeGradient) {
               SrcNodeAttributes())};
 }
 
+IMPLEMENT_GRADIENT_BUILDER(GetAtanGradient) {
+  // y = atan(x)  =>  dL/dx = dL/dy * (1 / (1 + x^2)), built below as Mul -> Add -> Div.
+  NodeDef one_const_node = OneConstantNode(IElemType(0));
+  ArgDef one = one_const_node.output_args[0];
+  std::vector<NodeDef> result;
+  result.push_back(one_const_node);
+  result.push_back(NodeDef("Mul", {I(0), I(0)}, {IA("Square_I0")}));  // x^2
+  result.push_back(NodeDef("Add", {IA("Square_I0"), one}, {IA("One_Plus_Square_I0")}));  // 1 + x^2
+  result.push_back(NodeDef("Div", {GO(0), IA("One_Plus_Square_I0")}, {GI(0)}));  // dL/dy / (1 + x^2)
+  return result;
+}
+
+
 }  // namespace training
 }  // namespace onnxruntime
diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h
index 92bfae9cd83a4..6928d8ade6016 100755
--- a/orttraining/orttraining/core/graph/gradient_builder.h
+++ b/orttraining/orttraining/core/graph/gradient_builder.h
@@ -93,6 +93,8 @@ DECLARE_GRADIENT_BUILDER(GetReciprocalGradient)
 DECLARE_GRADIENT_BUILDER(GetLeakyReluGradient)
 DECLARE_GRADIENT_BUILDER(GetConvTransposeGradient)
 DECLARE_GRADIENT_BUILDER(GetResizeGradient)
+DECLARE_GRADIENT_BUILDER(GetAtanGradient)
+
 
 DECLARE_GRADIENT_BUILDER(GetExternalGradient)
 
diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc
index ea56be9e6dfa3..9c9884c5d3865 100755
--- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc
@@ -125,6 +125,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
   REGISTER_GRADIENT_BUILDER("LeakyRelu", GetLeakyReluGradient);
   REGISTER_GRADIENT_BUILDER("ConvTranspose", GetConvTransposeGradient);
   REGISTER_GRADIENT_BUILDER("Resize", GetResizeGradient);
+  REGISTER_GRADIENT_BUILDER("Atan", GetAtanGradient);
 
   REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient);
 };
diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
index 94ca96c68f2ce..b81a08e23e3cf 100644
--- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc
+++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
@@ -3352,6 +3352,8 @@ TEST(GradientCheckerTest, ResizeGrad) {
 
 #endif  // USE_CUDA
 
+TEST(GradientCheckerTest, AtanGrad) { UnaryOpGradientTest("Atan"); }
+
 }  // namespace test
 }  // namespace onnxruntime
 