From 8d331ea39a89054b1a1fe5232134fbea3390a1ad Mon Sep 17 00:00:00 2001
From: protonu
Date: Fri, 20 Dec 2024 10:18:42 -0800
Subject: [PATCH 1/9] new unit test

---
 csrc/ops/arith.cpp                  | 29 +++++++++++++++++++++++++++++
 csrc/ops/arith.h                    |  1 +
 tests/cpp/test_tensor_factories.cpp | 22 ++++++++++++++++++++++
 3 files changed, 52 insertions(+)

diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp
index 7684e406fef..a858d723861 100644
--- a/csrc/ops/arith.cpp
+++ b/csrc/ops/arith.cpp
@@ -494,6 +494,35 @@ TensorView* eye(Val* size, DataType dtype) {
   return eye(size, size, dtype);
 }
 
+TensorView* triu(TensorView* tv, Val* offset) {
+  NVF_CHECK(
+      offset->getDataType() == DataType::Int, "offset must have type Int");
+
+  NVF_CHECK(
+      tv->nDims() >= 2,
+      "triu is only supported for 2+D tensors, but got ",
+      tv->nDims(),
+      "D tensor");
+
+  auto tv_rows = iota(
+      tv->domain()->logical()[1]->extent(),
+      IrBuilder::create<Val>(0, DataType::Index),
+      IrBuilder::create<Val>(1, DataType::Index),
+      DataType::Int);
+
+  auto tv_columns = iota(
+      tv->domain()->logical()[2]->extent(),
+      SimplifyingIrBuilder::mulExpr(
+          offset, IrBuilder::create<Val>(-1, DataType::Index)),
+      IrBuilder::create<Val>(1, DataType::Index),
+      DataType::Int);
+
+  auto tv_rows_b = broadcast(tv_rows, {false, true});
+  auto tv_cols_b = broadcast(tv_columns, {true, false});
+  auto mask = le(tv_rows_b, tv_cols_b);
+  return where(mask, tv, IrBuilder::create<Val>(0, tv->dtype()));
+}
+
 // UNARY OPERATIONS
 
 #define NVFUSER_DEFINE_UNARY_OP(operator_name, operator_type) \
diff --git a/csrc/ops/arith.h b/csrc/ops/arith.h
index d8ea10038ad..46134de1c95 100644
--- a/csrc/ops/arith.h
+++ b/csrc/ops/arith.h
@@ -251,6 +251,7 @@ NVF_API TensorView* arange(
     DataType dtype = DataType::Int);
 NVF_API TensorView* eye(Val* size, DataType dtype);
 NVF_API TensorView* eye(Val* rows, Val* cols, DataType dtype);
+NVF_API TensorView* triu(TensorView* tv, Val* offset);
 
 // UNARY OPERATIONS
 // abs
diff --git a/tests/cpp/test_tensor_factories.cpp b/tests/cpp/test_tensor_factories.cpp
index 2eabde38b3b..b31a1506de3 100644
--- a/tests/cpp/test_tensor_factories.cpp
+++ b/tests/cpp/test_tensor_factories.cpp
@@ -230,6 +230,28 @@ TEST_F(TensorFactoryTest, StandaloneIota) {
   }
 }
 
+TEST_F(TensorFactoryTest, SimpleTriu) {
+  auto fusion = std::make_unique<Fusion>();
+
+  FusionGuard gf(fusion.get());
+
+  auto tv_to_triu_on = makeSymbolicTensor(3, DataType::Half);
+  fusion->addInput(tv_to_triu_on);
+
+  int64_t k_factor = -2;
+  auto out = triu(tv_to_triu_on, IrBuilder::create<Val>(k_factor, DataType::Int));
+  fusion->addOutput(out);
+
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
+  auto in_tensor = at::randn({4, 4, 8}, options);
+
+  FusionExecutorCache executor_cache(std::move(fusion));
+  auto cg_outputs = executor_cache.runFusionWithInputs({in_tensor});
+
+  EXPECT_TRUE(cg_outputs[0].allclose(at::triu(in_tensor, k_factor), .001, .001));
+}
+
 TEST_F(TensorFactoryTest, StandaloneARange) {
   auto starts_ends = {-1., 0., 10.3, 1024. * 256};
   auto steps = {-1.5, 1., 2.};

From 741e167bb9dd64b8f11de6ba8ef29ca6ec12acf0 Mon Sep 17 00:00:00 2001
From: protonu
Date: Fri, 20 Dec 2024 22:35:46 -0800
Subject: [PATCH 2/9] creating a C++ API for triu

---
 csrc/ops/arith.cpp                  | 20 ++++++++++++++++++--
 tests/cpp/test_tensor_factories.cpp | 33 +++++++++++++++++------------
 2 files changed, 38 insertions(+), 15 deletions(-)

diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp
index a858d723861..093f02286b5 100644
--- a/csrc/ops/arith.cpp
+++ b/csrc/ops/arith.cpp
@@ -504,14 +504,30 @@ TensorView* triu(TensorView* tv, Val* offset) {
       tv->nDims(),
       "D tensor");
 
+  // Let's say we want a triu of a 2D tensor of shape [2, 4]
+  // We broadcast the iota of the outer dim
+  // [0     [0, 0, 0, 0]
+  //  1] -> [1, 1, 1, 1]
+  // We broadcast the iota of the inner dim
+  // [0, 1, 2, 3] -> [0, 1, 2, 3]
+  //                 [0, 1, 2, 3]
+  // Using LE on the bcast tensors we get the mask
+  //[0, 0, 0, 0] LE [0, 1, 2, 3]
+  //[1, 1, 1, 1]    [0, 1, 2, 3]
+  // Gives:
+  //[1, 0, 0, 0]
+  //[0, 1, 0, 0]
+  // If triu has an offset of k, we shift/subtract the iota of the columns by k
+  // before broadcasting and comparing with the iota of the rows.
+  auto dims = tv->domain()->logical().size();
   auto tv_rows = iota(
-      tv->domain()->logical()[1]->extent(),
+      tv->domain()->logical()[dims - 2]->extent(),
       IrBuilder::create<Val>(0, DataType::Index),
       IrBuilder::create<Val>(1, DataType::Index),
       DataType::Int);
 
   auto tv_columns = iota(
-      tv->domain()->logical()[2]->extent(),
+      tv->domain()->logical()[dims - 1]->extent(),
       SimplifyingIrBuilder::mulExpr(
           offset, IrBuilder::create<Val>(-1, DataType::Index)),
       IrBuilder::create<Val>(1, DataType::Index),
       DataType::Int);
diff --git a/tests/cpp/test_tensor_factories.cpp b/tests/cpp/test_tensor_factories.cpp
index b31a1506de3..170c24dd6e2 100644
--- a/tests/cpp/test_tensor_factories.cpp
+++ b/tests/cpp/test_tensor_factories.cpp
@@ -231,25 +231,32 @@ TEST_F(TensorFactoryTest, StandaloneIota) {
 }
 
 TEST_F(TensorFactoryTest, SimpleTriu) {
-  auto fusion = std::make_unique<Fusion>();
-
-  FusionGuard gf(fusion.get());
+  std::vector<std::vector<int64_t>> input_sizes = {
+      {64, 64}, {4, 16}, {16, 4}, {16, 8, 32}};
+  auto offsets = {0, 1, 2, -1, -2, 200, -200};
 
-  auto tv_to_triu_on = makeSymbolicTensor(3, DataType::Half);
-  fusion->addInput(tv_to_triu_on);
+  for (auto input_size : input_sizes) {
+    for (auto offset : offsets) {
+      auto fusion = std::make_unique<Fusion>();
+      FusionGuard fg(fusion.get());
 
-  int64_t k_factor = -2;
-  auto out = triu(tv_to_triu_on, IrBuilder::create<Val>(k_factor, DataType::Int));
-  fusion->addOutput(out);
+      auto tv_to_triu_on =
+          makeSymbolicTensor(input_size.size(), DataType::Half);
+      fusion->addInput(tv_to_triu_on);
 
+      auto out =
+          triu(tv_to_triu_on, IrBuilder::create<Val>(offset, DataType::Int));
+      fusion->addOutput(out);
 
-  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
-  auto in_tensor = at::randn({4, 4, 8}, options);
+      auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
+      auto in_tensor = at::randn(input_size, options);
 
-  FusionExecutorCache executor_cache(std::move(fusion));
-  auto cg_outputs = executor_cache.runFusionWithInputs({in_tensor});
+      FusionExecutorCache executor_cache(std::move(fusion));
+      auto cg_outputs = executor_cache.runFusionWithInputs({in_tensor});
 
-  EXPECT_TRUE(cg_outputs[0].allclose(at::triu(in_tensor, k_factor), .001, .001));
+      EXPECT_TRUE(at::equal(cg_outputs[0], at::triu(in_tensor, offset)));
+    }
+  }
 }
 
 TEST_F(TensorFactoryTest, StandaloneARange) {

From b734bf6c18541eafdf71866d5c565814a20ad815 Mon Sep 17 00:00:00 2001
From: Protonu
Date: Mon, 23 Dec 2024 16:49:05 -0500
Subject: [PATCH 3/9] correct comment

---
 csrc/ops/arith.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp
index 093f02286b5..6fa7447d611 100644
--- a/csrc/ops/arith.cpp
+++ b/csrc/ops/arith.cpp
@@ -515,8 +515,8 @@ TensorView* triu(TensorView* tv, Val* offset) {
   //[0, 0, 0, 0] LE [0, 1, 2, 3]
   //[1, 1, 1, 1]    [0, 1, 2, 3]
   // Gives:
-  //[1, 0, 0, 0]
-  //[0, 1, 0, 0]
+  //[1, 1, 1, 1]
+  //[0, 1, 1, 1]
   // If triu has an offset of k, we shift/subtract the iota of the columns by k
   // before broadcasting and comparing with the iota of the rows.
   auto dims = tv->domain()->logical().size();

From fdc9c5293a0297d97b75dd20c578c6722bc25ec3 Mon Sep 17 00:00:00 2001
From: protonu
Date: Thu, 2 Jan 2025 12:55:02 -0800
Subject: [PATCH 4/9] moving from arith.cpp to composite.cpp

---
 csrc/ops/arith.cpp     | 45 ------------------------------------------
 csrc/ops/arith.h       |  1 -
 csrc/ops/composite.cpp | 45 ++++++++++++++++++++++++++++++++++++++++++
 csrc/ops/composite.h   |  2 ++
 4 files changed, 47 insertions(+), 46 deletions(-)

diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp
index 6fa7447d611..7684e406fef 100644
--- a/csrc/ops/arith.cpp
+++ b/csrc/ops/arith.cpp
@@ -494,51 +494,6 @@ TensorView* eye(Val* size, DataType dtype) {
   return eye(size, size, dtype);
 }
 
-TensorView* triu(TensorView* tv, Val* offset) {
-  NVF_CHECK(
-      offset->getDataType() == DataType::Int, "offset must have type Int");
-
-  NVF_CHECK(
-      tv->nDims() >= 2,
-      "triu is only supported for 2+D tensors, but got ",
-      tv->nDims(),
-      "D tensor");
-
-  // Let's say we want a triu of a 2D tensor of shape [2, 4]
-  // We broadcast the iota of the outer dim
-  // [0     [0, 0, 0, 0]
-  //  1] -> [1, 1, 1, 1]
-  // We broadcast the iota of the inner dim
-  // [0, 1, 2, 3] -> [0, 1, 2, 3]
-  //                 [0, 1, 2, 3]
-  // Using LE on the bcast tensors we get the mask
-  //[0, 0, 0, 0] LE [0, 1, 2, 3]
-  //[1, 1, 1, 1]    [0, 1, 2, 3]
-  // Gives:
-  //[1, 1, 1, 1]
-  //[0, 1, 1, 1]
-  // If triu has an offset of k, we shift/subtract the iota of the columns by k
-  // before broadcasting and comparing with the iota of the rows.
-  auto dims = tv->domain()->logical().size();
-  auto tv_rows = iota(
-      tv->domain()->logical()[dims - 2]->extent(),
-      IrBuilder::create<Val>(0, DataType::Index),
-      IrBuilder::create<Val>(1, DataType::Index),
-      DataType::Int);
-
-  auto tv_columns = iota(
-      tv->domain()->logical()[dims - 1]->extent(),
-      SimplifyingIrBuilder::mulExpr(
-          offset, IrBuilder::create<Val>(-1, DataType::Index)),
-      IrBuilder::create<Val>(1, DataType::Index),
-      DataType::Int);
-
-  auto tv_rows_b = broadcast(tv_rows, {false, true});
-  auto tv_cols_b = broadcast(tv_columns, {true, false});
-  auto mask = le(tv_rows_b, tv_cols_b);
-  return where(mask, tv, IrBuilder::create<Val>(0, tv->dtype()));
-}
-
 // UNARY OPERATIONS
 
 #define NVFUSER_DEFINE_UNARY_OP(operator_name, operator_type) \
diff --git a/csrc/ops/arith.h b/csrc/ops/arith.h
index 46134de1c95..d8ea10038ad 100644
--- a/csrc/ops/arith.h
+++ b/csrc/ops/arith.h
@@ -251,7 +251,6 @@ NVF_API TensorView* arange(
     DataType dtype = DataType::Int);
 NVF_API TensorView* eye(Val* size, DataType dtype);
 NVF_API TensorView* eye(Val* rows, Val* cols, DataType dtype);
-NVF_API TensorView* triu(TensorView* tv, Val* offset);
 
 // UNARY OPERATIONS
 // abs
diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp
index 1db959115a2..a6f2c343d49 100644
--- a/csrc/ops/composite.cpp
+++ b/csrc/ops/composite.cpp
@@ -56,6 +56,51 @@ TensorView* dropout_backward(TensorView* dy, TensorView* mask, Val* scale) {
   return dx;
 }
 
+TensorView* triu(TensorView* tv, Val* offset) {
+  NVF_CHECK(
+      offset->getDataType() == DataType::Int, "offset must have type Int");
+
+  NVF_CHECK(
+      tv->nDims() >= 2,
+      "triu is only supported for 2+D tensors, but got ",
+      tv->nDims(),
+      "D tensor");
+
+  // Let's say we want a triu of a 2D tensor of shape [2, 4]
+  // We broadcast the iota of the outer dim
+  // [0     [0, 0, 0, 0]
+  //  1] -> [1, 1, 1, 1]
+  // We broadcast the iota of the inner dim
+  // [0, 1, 2, 3] -> [0, 1, 2, 3]
+  //                 [0, 1, 2, 3]
+  // Using LE on the bcast tensors we get the mask
+  //[0, 0, 0, 0] LE [0, 1, 2, 3]
+  //[1, 1, 1, 1]    [0, 1, 2, 3]
+  // Gives:
+  //[1, 1, 1, 1]
+  //[0, 1, 1, 1]
+  // If triu has an offset of k, we shift/subtract the iota of the columns by k
+  // before broadcasting and comparing with the iota of the rows.
+  auto dims = tv->domain()->logical().size();
+  auto tv_rows = iota(
+      tv->domain()->logical()[dims - 2]->extent(),
+      IrBuilder::create<Val>(0, DataType::Index),
+      IrBuilder::create<Val>(1, DataType::Index),
+      DataType::Int);
+
+  auto tv_columns = iota(
+      tv->domain()->logical()[dims - 1]->extent(),
+      SimplifyingIrBuilder::mulExpr(
+          offset, IrBuilder::create<Val>(-1, DataType::Index)),
+      IrBuilder::create<Val>(1, DataType::Index),
+      DataType::Int);
+
+  auto tv_rows_b = broadcast(tv_rows, {false, true});
+  auto tv_cols_b = broadcast(tv_columns, {true, false});
+  auto mask = le(tv_rows_b, tv_cols_b);
+  return where(mask, tv, IrBuilder::create<Val>(0, tv->dtype()));
+}
+
 namespace {
 
 TensorView* newForLinear(
diff --git a/csrc/ops/composite.h b/csrc/ops/composite.h
index ecbbb89b5a3..b67015b994d 100644
--- a/csrc/ops/composite.h
+++ b/csrc/ops/composite.h
@@ -35,6 +35,8 @@ NVF_API TensorView* dropout_backward(
     TensorView* mask,
     Val* scale);
 
+NVF_API TensorView* triu(TensorView* tv, Val* offset);
+
 struct LstmResult {
   TensorView* cell = nullptr;
   TensorView* hidden = nullptr;

From ad6021cf52e02ec65479c8cbbd1738f5063eaf1a Mon Sep 17 00:00:00 2001
From: protonu
Date: Thu, 2 Jan 2025 13:19:06 -0800
Subject: [PATCH 5/9] fixes based on reviewer comments

---
 csrc/ops/composite.cpp              | 8 ++++----
 tests/cpp/test_tensor_factories.cpp | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp
index a6f2c343d49..6a11aceb893 100644
--- a/csrc/ops/composite.cpp
+++ b/csrc/ops/composite.cpp
@@ -58,7 +58,7 @@ TensorView* dropout_backward(TensorView* dy, TensorView* mask, Val* scale) {
 
 TensorView* triu(TensorView* tv, Val* offset) {
   NVF_CHECK(
-      offset->getDataType() == DataType::Int, "offset must have type Int");
+      offset->getDataType() == DataType::Index, "offset must have type Index");
 
   NVF_CHECK(
       tv->nDims() >= 2,
@@ -81,19 +81,19 @@ TensorView* triu(TensorView* tv, Val* offset) {
   //[0, 1, 1, 1]
   // If triu has an offset of k, we shift/subtract the iota of the columns by k
   // before broadcasting and comparing with the iota of the rows.
-  auto dims = tv->domain()->logical().size();
+  auto dims = TensorDomain::noReductions(tv->getLogicalDomain()).size();
   auto tv_rows = iota(
       tv->domain()->logical()[dims - 2]->extent(),
       IrBuilder::create<Val>(0, DataType::Index),
       IrBuilder::create<Val>(1, DataType::Index),
-      DataType::Int);
+      DataType::Index);
 
   auto tv_columns = iota(
       tv->domain()->logical()[dims - 1]->extent(),
       SimplifyingIrBuilder::mulExpr(
           offset, IrBuilder::create<Val>(-1, DataType::Index)),
       IrBuilder::create<Val>(1, DataType::Index),
-      DataType::Int);
+      DataType::Index);
 
   auto tv_rows_b = broadcast(tv_rows, {false, true});
   auto tv_cols_b = broadcast(tv_columns, {true, false});
diff --git a/tests/cpp/test_tensor_factories.cpp b/tests/cpp/test_tensor_factories.cpp
index 170c24dd6e2..110bc91cc73 100644
--- a/tests/cpp/test_tensor_factories.cpp
+++ b/tests/cpp/test_tensor_factories.cpp
@@ -245,7 +245,7 @@ TEST_F(TensorFactoryTest, SimpleTriu) {
       fusion->addInput(tv_to_triu_on);
 
       auto out =
-          triu(tv_to_triu_on, IrBuilder::create<Val>(offset, DataType::Int));
+          triu(tv_to_triu_on, IrBuilder::create<Val>(offset, DataType::Index));
       fusion->addOutput(out);
 
       auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);

From c90b21cd7c5c42b0ceb0e2347205627f678f58eb Mon Sep 17 00:00:00 2001
From: protonu
Date: Fri, 3 Jan 2025 08:50:18 -0800
Subject: [PATCH 6/9] reviewer comments

---
 csrc/ops/composite.cpp              | 20 +++++++-----
 tests/cpp/test_tensor_factories.cpp | 47 ++++++++++++++++++-----------
 2 files changed, 42 insertions(+), 25 deletions(-)

diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp
index 6a11aceb893..6044013fa7f 100644
--- a/csrc/ops/composite.cpp
+++ b/csrc/ops/composite.cpp
@@ -79,19 +79,25 @@ TensorView* triu(TensorView* tv, Val* offset) {
   // Gives:
   //[1, 1, 1, 1]
   //[0, 1, 1, 1]
-  // If triu has an offset of k, we shift/subtract the iota of the columns by k
-  // before broadcasting and comparing with the iota of the rows.
-  auto dims = TensorDomain::noReductions(tv->getLogicalDomain()).size();
+  auto tv_logical_no_reductions =
+      TensorDomain::noReductions(tv->getLogicalDomain());
+  auto dims = tv_logical_no_reductions.size();
+
   auto tv_rows = iota(
-      tv->domain()->logical()[dims - 2]->extent(),
+      tv_logical_no_reductions[dims - 2]->extent(),
       IrBuilder::create<Val>(0, DataType::Index),
       IrBuilder::create<Val>(1, DataType::Index),
       DataType::Index);
 
+  // If triu has an offset of k, we shift/subtract the iota of the columns by k
+  // before broadcasting and comparing with the iota of the rows.
+  // So when building an iota op, instead of starting from 0 with a step of 1
+  // we start from -offset (== -k) with a step of 1.
+  auto start_shifted_by_offset = SimplifyingIrBuilder::mulExpr(
+      offset, IrBuilder::create<Val>(-1, DataType::Index));
   auto tv_columns = iota(
-      tv->domain()->logical()[dims - 1]->extent(),
-      SimplifyingIrBuilder::mulExpr(
-          offset, IrBuilder::create<Val>(-1, DataType::Index)),
+      tv_logical_no_reductions[dims - 1]->extent(),
+      start_shifted_by_offset,
       IrBuilder::create<Val>(1, DataType::Index),
       DataType::Index);
diff --git a/tests/cpp/test_tensor_factories.cpp b/tests/cpp/test_tensor_factories.cpp
index 110bc91cc73..99096b39a2c 100644
--- a/tests/cpp/test_tensor_factories.cpp
+++ b/tests/cpp/test_tensor_factories.cpp
@@ -231,30 +231,41 @@ TEST_F(TensorFactoryTest, StandaloneIota) {
 }
 
 TEST_F(TensorFactoryTest, SimpleTriu) {
-  std::vector<std::vector<int64_t>> input_sizes = {
-      {64, 64}, {4, 16}, {16, 4}, {16, 8, 32}};
+  std::vector<std::vector<int64_t>> input_sizes_2d = {
+      {64, 64}, {4, 16}, {16, 4}};
+  std::vector<std::vector<int64_t>> input_sizes_3d = {{16, 8, 32}};
   auto offsets = {0, 1, 2, -1, -2, 200, -200};
 
-  for (auto input_size : input_sizes) {
-    for (auto offset : offsets) {
-      auto fusion = std::make_unique<Fusion>();
-      FusionGuard fg(fusion.get());
-
-      auto tv_to_triu_on =
-          makeSymbolicTensor(input_size.size(), DataType::Half);
-      fusion->addInput(tv_to_triu_on);
+  for (auto in : {input_sizes_2d, input_sizes_3d}) {
+    auto fusion = std::make_unique<Fusion>();
+    FusionGuard fg(fusion.get());
 
-      auto out =
-          triu(tv_to_triu_on, IrBuilder::create<Val>(offset, DataType::Index));
-      fusion->addOutput(out);
+    auto tv_to_triu_on = makeSymbolicTensor(in.at(0).size(), DataType::Half);
+    auto input_offset = IrBuilder::create<Val>(DataType::Index);
+    auto out = triu(tv_to_triu_on, input_offset);
 
-      auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
-      auto in_tensor = at::randn(input_size, options);
+    fusion->addInput(tv_to_triu_on);
+    fusion->addInput(input_offset);
+    fusion->addOutput(out);
 
-      FusionExecutorCache executor_cache(std::move(fusion));
-      auto cg_outputs = executor_cache.runFusionWithInputs({in_tensor});
+    FusionExecutorCache executor_cache(std::move(fusion));
 
-      EXPECT_TRUE(at::equal(cg_outputs[0], at::triu(in_tensor, offset)));
+    for (auto input_size : in) {
+      for (auto offset : offsets) {
+        auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
+        auto in_tensor = at::randn(input_size, options);
+
+        auto cg_outputs =
+            executor_cache.runFusionWithInputs({in_tensor, offset});
+
+        testValidate(
+            executor_cache.fusion(),
+            cg_outputs,
+            {in_tensor, offset},
+            {at::triu(in_tensor, offset)},
+            __LINE__,
+            __FILE__);
+      }
     }
   }
 }

From e40db84f48065ca745b425ae8554967c54b11d45 Mon Sep 17 00:00:00 2001
From: protonu
Date: Fri, 3 Jan 2025 09:33:44 -0800
Subject: [PATCH 7/9] offset should be dtype int

---
 csrc/ops/composite.cpp              | 2 +-
 tests/cpp/test_tensor_factories.cpp | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp
index 6044013fa7f..eaeb9dabc9b 100644
--- a/csrc/ops/composite.cpp
+++ b/csrc/ops/composite.cpp
@@ -58,7 +58,7 @@ TensorView* dropout_backward(TensorView* dy, TensorView* mask, Val* scale) {
 
 TensorView* triu(TensorView* tv, Val* offset) {
   NVF_CHECK(
-      offset->getDataType() == DataType::Index, "offset must have type Index");
+      offset->getDataType() == DataType::Int, "offset must have type Int");
 
   NVF_CHECK(
       tv->nDims() >= 2,
diff --git a/tests/cpp/test_tensor_factories.cpp b/tests/cpp/test_tensor_factories.cpp
index 99096b39a2c..3d95ad7d3c4 100644
--- a/tests/cpp/test_tensor_factories.cpp
+++ b/tests/cpp/test_tensor_factories.cpp
@@ -241,7 +241,7 @@ TEST_F(TensorFactoryTest, SimpleTriu) {
     FusionGuard fg(fusion.get());
 
     auto tv_to_triu_on = makeSymbolicTensor(in.at(0).size(), DataType::Half);
-    auto input_offset = IrBuilder::create<Val>(DataType::Index);
+    auto input_offset = IrBuilder::create<Val>(DataType::Int);
     auto out = triu(tv_to_triu_on, input_offset);
 
     fusion->addInput(tv_to_triu_on);

From c7fe0707d881b98eed695f26b14f603376b3ad1c Mon Sep 17 00:00:00 2001
From: protonu
Date: Sat, 4 Jan 2025 07:22:00 -0800
Subject: [PATCH 8/9] modifying checks

---
 csrc/ops/composite.cpp | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp
index eaeb9dabc9b..716ab7a8c2e 100644
--- a/csrc/ops/composite.cpp
+++ b/csrc/ops/composite.cpp
@@ -58,13 +58,8 @@ TensorView* dropout_backward(TensorView* dy, TensorView* mask, Val* scale) {
 
 TensorView* triu(TensorView* tv, Val* offset) {
   NVF_CHECK(
-      offset->getDataType() == DataType::Int, "offset must have type Int");
-
-  NVF_CHECK(
-      tv->nDims() >= 2,
-      "triu is only supported for 2+D tensors, but got ",
-      tv->nDims(),
-      "D tensor");
+      isIntegralType(offset->getDataType().value()),
+      "offset must have integral type");
 
   // Let's say we want a triu of a 2D tensor of shape [2, 4]
   // We broadcast the iota of the outer dim
@@ -83,6 +78,12 @@ TensorView* triu(TensorView* tv, Val* offset) {
       TensorDomain::noReductions(tv->getLogicalDomain());
   auto dims = tv_logical_no_reductions.size();
 
+  NVF_CHECK(
+      dims >= 2,
+      "triu is only supported for 2+D tensors, but got ",
+      dims,
+      "D tensor");
+
   auto tv_rows = iota(
       tv_logical_no_reductions[dims - 2]->extent(),
       IrBuilder::create<Val>(0, DataType::Index),

From d94b35c90391dcc3090162c69896accb3313032e Mon Sep 17 00:00:00 2001
From: protonu
Date: Mon, 6 Jan 2025 06:42:11 -0800
Subject: [PATCH 9/9] reviewer comments

---
 csrc/ops/composite.cpp | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp
index 716ab7a8c2e..d2f0d9277d2 100644
--- a/csrc/ops/composite.cpp
+++ b/csrc/ops/composite.cpp
@@ -84,28 +84,29 @@ TensorView* triu(TensorView* tv, Val* offset) {
       dims,
       "D tensor");
 
+  auto fusion = tv->fusion();
+
   auto tv_rows = iota(
       tv_logical_no_reductions[dims - 2]->extent(),
-      IrBuilder::create<Val>(0, DataType::Index),
-      IrBuilder::create<Val>(1, DataType::Index),
+      fusion->zeroVal(DataType::Index),
+      fusion->oneVal(DataType::Index),
       DataType::Index);
 
   // If triu has an offset of k, we shift/subtract the iota of the columns by k
   // before broadcasting and comparing with the iota of the rows.
   // So when building an iota op, instead of starting from 0 with a step of 1
   // we start from -offset (== -k) with a step of 1.
-  auto start_shifted_by_offset = SimplifyingIrBuilder::mulExpr(
-      offset, IrBuilder::create<Val>(-1, DataType::Index));
+  auto start_shifted_by_offset = SimplifyingIrBuilder::negExpr(offset);
   auto tv_columns = iota(
       tv_logical_no_reductions[dims - 1]->extent(),
       start_shifted_by_offset,
-      IrBuilder::create<Val>(1, DataType::Index),
+      fusion->oneVal(DataType::Index),
       DataType::Index);
 
   auto tv_rows_b = broadcast(tv_rows, {false, true});
   auto tv_cols_b = broadcast(tv_columns, {true, false});
   auto mask = le(tv_rows_b, tv_cols_b);
-  return where(mask, tv, IrBuilder::create<Val>(0, tv->dtype()));
+  return where(mask, tv, fusion->zeroVal(DataType::Index));
 }
 
 namespace {
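As a quick standalone illustration of the masking rule the patch comments describe (keep the element at row i, column j when i <= j - k, i.e. j >= i + k; zero everything else), here is a minimal C++ sketch. It is not nvFuser code and not part of the patches: the `triu_mask` helper and the flat row-major layout are assumptions made for this example only; the patches build the same comparison out of iota, broadcast, le, and where on TensorViews.

#include <cstdint>
#include <iostream>
#include <vector>

// Zero out everything below the k-th diagonal of a rows x cols row-major
// matrix, mirroring the mask the patch builds from le(tv_rows_b, tv_cols_b)
// with the column iota shifted by -k.
std::vector<float> triu_mask(
    const std::vector<float>& in,
    int64_t rows,
    int64_t cols,
    int64_t k) {
  std::vector<float> out(in.size(), 0.0f);
  for (int64_t i = 0; i < rows; ++i) {
    for (int64_t j = 0; j < cols; ++j) {
      if (i <= j - k) { // keep the element when j >= i + k
        out[i * cols + j] = in[i * cols + j];
      }
    }
  }
  return out;
}

int main() {
  // The 2x4 example from the comment block, with offset k = 0.
  std::vector<float> in = {1, 2, 3, 4, 5, 6, 7, 8};
  auto out = triu_mask(in, 2, 4, 0);
  for (int64_t i = 0; i < 2; ++i) {
    for (int64_t j = 0; j < 4; ++j) {
      std::cout << out[i * 4 + j] << (j == 3 ? "\n" : " ");
    }
  }
  return 0;
}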