diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index a858d723861..093f02286b5 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -504,14 +504,30 @@ TensorView* triu(TensorView* tv, Val* offset) { tv->nDims(), "D tensor"); + // Let's say we want a triu of a 2D tensor of shape [2, 4] + // We broadcast the iota of the outer dim + // [0 [0, 0, 0, 0] + // 1] -> [1, 1, 1, 1] + // We broadcast the iota of the inner dim + // [0, 1, 2, 3] -> [0, 1, 2, 3] + // [0, 1, 2, 3] + // Using LE on the bcast tensors we get the mask + //[0, 0, 0, 0] LE [0, 1, 2, 3] + //[1, 1, 1, 1] [0, 1, 2, 3] + // Gives: + //[1, 1, 1, 1] + //[0, 1, 1, 1] + // If triu has an offset of k, we shift/subtract the iota of the columns by k + // before broadcasting and comparing with the iota of the rows. + auto dims = tv->domain()->logical().size(); auto tv_rows = iota( - tv->domain()->logical()[1]->extent(), + tv->domain()->logical()[dims - 2]->extent(), IrBuilder::create(0, DataType::Index), IrBuilder::create(1, DataType::Index), DataType::Int); auto tv_columns = iota( - tv->domain()->logical()[2]->extent(), + tv->domain()->logical()[dims - 1]->extent(), SimplifyingIrBuilder::mulExpr( offset, IrBuilder::create(-1, DataType::Index)), IrBuilder::create(1, DataType::Index), diff --git a/tests/cpp/test_tensor_factories.cpp b/tests/cpp/test_tensor_factories.cpp index b31a1506de3..170c24dd6e2 100644 --- a/tests/cpp/test_tensor_factories.cpp +++ b/tests/cpp/test_tensor_factories.cpp @@ -231,25 +231,32 @@ TEST_F(TensorFactoryTest, StandaloneIota) { } TEST_F(TensorFactoryTest, SimpleTriu) { - auto fusion = std::make_unique(); - - FusionGuard gf(fusion.get()); + std::vector> input_sizes = { + {64, 64}, {4, 16}, {16, 4}, {16, 8, 32}}; + auto offsets = {0, 1, 2, -1, -2, 200, -200}; - auto tv_to_triu_on = makeSymbolicTensor(3, DataType::Half); - fusion->addInput(tv_to_triu_on); + for (auto input_size : input_sizes) { + for (auto offset : offsets) { + auto fusion = std::make_unique(); + FusionGuard 
fg(fusion.get()); - int64_t k_factor = -2; - auto out = triu(tv_to_triu_on, IrBuilder::create(k_factor, DataType::Int)); - fusion->addOutput(out); + auto tv_to_triu_on = + makeSymbolicTensor(input_size.size(), DataType::Half); + fusion->addInput(tv_to_triu_on); + auto out = + triu(tv_to_triu_on, IrBuilder::create(offset, DataType::Int)); + fusion->addOutput(out); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); - auto in_tensor = at::randn({4, 4, 8}, options); + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); + auto in_tensor = at::randn(input_size, options); - FusionExecutorCache executor_cache(std::move(fusion)); - auto cg_outputs = executor_cache.runFusionWithInputs({in_tensor}); + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs({in_tensor}); - EXPECT_TRUE(cg_outputs[0].allclose(at::triu(in_tensor, k_factor), .001, .001)); + EXPECT_TRUE(at::equal(cg_outputs[0], at::triu(in_tensor, offset))); + } + } } TEST_F(TensorFactoryTest, StandaloneARange) {