@@ -520,7 +520,7 @@ TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithCatchAllKernel_t
   auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
   ASSERT_TRUE(op.has_value());
 
-  // catchAll now maps to Math which has higher precedence than Autograd
+  // catchAll now maps to CompositeImplicitAutograd which has higher precedence than Autograd
   called_nonautograd = called_autograd = false;
   op->typed<void (Tensor)>().call(dummyTensor(DispatchKey::CPU, /*requires_grad=*/true));
   EXPECT_TRUE(called_nonautograd);
@@ -1306,7 +1306,7 @@ TEST(NewOperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndCatchal
 
   called = false;
   auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello ");
-  // CatchAll now maps to Math and has higher precedence than backend fallback.
+  // CatchAll now maps to CompositeImplicitAutograd and has higher precedence than backend fallback.
   EXPECT_TRUE(called);
 }
 
@@ -1325,10 +1325,10 @@ TEST(NewOperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel
   EXPECT_FALSE(called_autograd);
 }
 
-TEST(NewOperatorRegistrationTest, dispatchWithMathKernel) {
+TEST(NewOperatorRegistrationTest, dispatchWithCompositeImplicitAutogradKernel) {
   bool math_called = false;
   auto m = MAKE_TORCH_LIBRARY(test);
-  m.def("fn", torch::dispatch(c10::DispatchKey::Math, [&](const Tensor& x) { math_called = true; return x; }));
+  m.def("fn", torch::dispatch(c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; }));
 
   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
   ASSERT_TRUE(op.has_value());
@@ -1370,17 +1370,17 @@ TEST(NewOperatorRegistrationTest, dispatchWithMathKernel) {
   }
 }
 
-TEST(NewOperatorRegistrationTest, dispatchWithMathAndAutogradKernel) {
+TEST(NewOperatorRegistrationTest, dispatchWithCompositeImplicitAutogradAndAutogradKernel) {
   bool math_called = false;
   bool autograd_called = false;
   auto m = MAKE_TORCH_LIBRARY(test);
-  m.def("fn", torch::dispatch(c10::DispatchKey::Math, [&](const Tensor& x) { math_called = true; return x; }));
+  m.def("fn", torch::dispatch(c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; }));
   m.impl("fn", c10::DispatchKey::Autograd, [&](const Tensor& x) { autograd_called = true; return x; });
 
   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
   ASSERT_TRUE(op.has_value());
 
-  // Math has higher precedence than Autograd
+  // CompositeImplicitAutograd has higher precedence than Autograd
   {
     math_called = autograd_called = false;
     callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
@@ -1396,17 +1396,17 @@ TEST(NewOperatorRegistrationTest, dispatchWithMathAndAutogradKernel) {
   }
 }
 
-TEST(NewOperatorRegistrationTest, dispatchWithMathAndCatchAllKernel) {
+TEST(NewOperatorRegistrationTest, dispatchWithCompositeImplicitAutogradAndCatchAllKernel) {
   bool math_called = false;
   bool catchall_called = false;
   auto m = MAKE_TORCH_LIBRARY(test);
-  m.def("fn", torch::dispatch(c10::DispatchKey::Math, [&](const Tensor& x) { math_called = true; return x; }));
+  m.def("fn", torch::dispatch(c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; }));
   m.impl("fn", [&](const Tensor& x) { catchall_called = true; return x; });
 
   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
   ASSERT_TRUE(op.has_value());
 
-  // catchAll now maps to Math, which means we have two registrations to Math key.
+  // catchAll now maps to CompositeImplicitAutograd, which means we have two registrations to CompositeImplicitAutograd key.
   // The last registration is used.
   {
     catchall_called = math_called = false;
@@ -1423,11 +1423,11 @@ TEST(NewOperatorRegistrationTest, dispatchWithMathAndCatchAllKernel) {
   }
 }
 
-TEST(NewOperatorRegistrationTest, AutogradBackendOverridesMathKernel) {
+TEST(NewOperatorRegistrationTest, AutogradBackendOverridesCompositeImplicitAutogradKernel) {
   bool math_called = false;
   bool autograd_called = false;
   auto m = MAKE_TORCH_LIBRARY(test);
-  m.def("fn", torch::dispatch(c10::DispatchKey::Math, [&](const Tensor& x) { math_called = true; return x; }));
+  m.def("fn", torch::dispatch(c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; }));
   m.impl("fn", c10::DispatchKey::AutogradCPU, [&](const Tensor& x) { autograd_called = true; return x; });
 
   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
@@ -1462,11 +1462,11 @@ TEST(NewOperatorRegistrationTest, AutogradBackendOverridesMathKernel) {
   }
 }
 
-TEST(NewOperatorRegistrationTest, BackendOverridesMathKernel) {
+TEST(NewOperatorRegistrationTest, BackendOverridesCompositeImplicitAutogradKernel) {
   bool math_called = false;
   bool backend_called = false;
   auto m = MAKE_TORCH_LIBRARY(test);
-  m.def("fn", torch::dispatch(c10::DispatchKey::Math, [&](const Tensor& x) { math_called = true; return x; }));
+  m.def("fn", torch::dispatch(c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; }));
   m.impl("fn", c10::DispatchKey::CPU, [&](const Tensor& x) { backend_called = true; return x; });
 
   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
@@ -1550,12 +1550,12 @@ TEST(NewOperatorRegistrationTest, dispatchWithDefaultBackendKernel) {
   }
 }
 
-TEST(NewOperatorRegistrationTest, dispatchWithDefaultBackendAndMathKernel) {
+TEST(NewOperatorRegistrationTest, dispatchWithDefaultBackendAndCompositeImplicitAutogradKernel) {
   bool backend_called = false;
   bool math_called = false;
   auto m = MAKE_TORCH_LIBRARY(test);
   m.def("fn", torch::dispatch(c10::DispatchKey::DefaultBackend, [&](const Tensor& x) { backend_called = true; return x; }));
-  m.impl("fn", c10::DispatchKey::Math, [&](const Tensor& x) { math_called = true; return x; });
+  m.impl("fn", c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; });
 
   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
   ASSERT_TRUE(op.has_value());
@@ -1735,7 +1735,7 @@ TEST(NewOperatorRegistrationTest, throwsWhenRegisterToBackendMapsToAutogradOther
   bool sparsecpu_called, math_called = false;
   auto m = MAKE_TORCH_LIBRARY(test);
   m.def("fn", torch::dispatch(c10::DispatchKey::SparseCPU, [&](const Tensor& x) { sparsecpu_called = true; return x; }));
-  m.impl("fn", c10::DispatchKey::Math, [&](const Tensor& x) { math_called = true; return x; });
+  m.impl("fn", c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; });
 
   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
   ASSERT_TRUE(op.has_value());
@@ -1748,7 +1748,7 @@ TEST(NewOperatorRegistrationTest, throwsWhenRegisterToBackendMapsToAutogradOther
   {
     expectThrows<c10::Error>([&] {
       callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU, /*requires_grad=*/true));
-    }, "test::fn has kernels registered to both Math and a backend mapped to AutogradOther.");
+    }, "test::fn has kernels registered to both CompositeImplicitAutograd and a backend mapped to AutogradOther.");
   }
 }
 
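For context on what these tests exercise, here is a minimal sketch (not part of this diff) of registering an operator under the renamed key from user extension code via the public TORCH_LIBRARY macros; the namespace `myops` and the function `my_add_one` are hypothetical names chosen for illustration.

// Minimal sketch, assuming a hypothetical "myops::my_add_one" operator:
// a kernel registered to CompositeImplicitAutograd is written purely in
// terms of existing ATen ops, so autograd support is derived automatically
// and the same kernel serves every backend.
#include <torch/library.h>
#include <ATen/ATen.h>

at::Tensor my_add_one(const at::Tensor& x) {
  return x + 1;  // composed of existing ops only
}

TORCH_LIBRARY(myops, m) {
  m.def("my_add_one(Tensor x) -> Tensor");
}

TORCH_LIBRARY_IMPL(myops, CompositeImplicitAutograd, m) {
  m.impl("my_add_one", my_add_one);
}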