Skip to content

Commit

Permalink
creating a C++ API for triu
Browse files Browse the repository at this point in the history
  • Loading branch information
protonu committed Dec 23, 2024
1 parent e92c575 commit 5f82c65
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 15 deletions.
20 changes: 18 additions & 2 deletions csrc/ops/arith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,14 +504,30 @@ TensorView* triu(TensorView* tv, Val* offset) {
tv->nDims(),
"D tensor");

// Let's say we want a triu of a 2D tensor of shape [2, 4].
// We broadcast the iota of the outer dim (rows):
// [0        [0, 0, 0, 0]
//  1]  ->   [1, 1, 1, 1]
// We broadcast the iota of the inner dim (columns):
// [0, 1, 2, 3]  ->  [0, 1, 2, 3]
//                   [0, 1, 2, 3]
// Using LE on the bcast tensors we get the mask:
// [0, 0, 0, 0]  LE  [0, 1, 2, 3]
// [1, 1, 1, 1]      [0, 1, 2, 3]
// Gives:
// [1, 1, 1, 1]
// [0, 1, 1, 1]
// If triu has an offset of k, we shift/subtract the iota of the columns by k
// (i.e. the column iota starts at -k) before broadcasting and comparing with
// the iota of the rows.
auto dims = tv->domain()->logical().size();
auto tv_rows = iota(
tv->domain()->logical()[1]->extent(),
tv->domain()->logical()[dims - 2]->extent(),
IrBuilder::create<Val>(0, DataType::Index),
IrBuilder::create<Val>(1, DataType::Index),
DataType::Int);

auto tv_columns = iota(
tv->domain()->logical()[2]->extent(),
tv->domain()->logical()[dims - 1]->extent(),
SimplifyingIrBuilder::mulExpr(
offset, IrBuilder::create<Val>(-1, DataType::Index)),
IrBuilder::create<Val>(1, DataType::Index),
Expand Down
33 changes: 20 additions & 13 deletions tests/cpp/test_tensor_factories.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,25 +231,32 @@ TEST_F(TensorFactoryTest, StandaloneIota) {
}

TEST_F(TensorFactoryTest, SimpleTriu) {
auto fusion = std::make_unique<Fusion>();

FusionGuard gf(fusion.get());
std::vector<std::vector<int64_t>> input_sizes = {
{64, 64}, {4, 16}, {16, 4}, {16, 8, 32}};
auto offsets = {0, 1, 2, -1, -2, 200, -200};

auto tv_to_triu_on = makeSymbolicTensor(3, DataType::Half);
fusion->addInput(tv_to_triu_on);
for (auto input_size : input_sizes) {
for (auto offset : offsets) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

int64_t k_factor = -2;
auto out = triu(tv_to_triu_on, IrBuilder::create<Val>(k_factor, DataType::Int));
fusion->addOutput(out);
auto tv_to_triu_on =
makeSymbolicTensor(input_size.size(), DataType::Half);
fusion->addInput(tv_to_triu_on);

auto out =
triu(tv_to_triu_on, IrBuilder::create<Val>(offset, DataType::Int));
fusion->addOutput(out);

auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto in_tensor = at::randn({4, 4, 8}, options);
auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto in_tensor = at::randn(input_size, options);

FusionExecutorCache executor_cache(std::move(fusion));
auto cg_outputs = executor_cache.runFusionWithInputs({in_tensor});
FusionExecutorCache executor_cache(std::move(fusion));
auto cg_outputs = executor_cache.runFusionWithInputs({in_tensor});

EXPECT_TRUE(cg_outputs[0].allclose(at::triu(in_tensor, k_factor), .001, .001));
EXPECT_TRUE(at::equal(cg_outputs[0], at::triu(in_tensor, offset)));
}
}
}

TEST_F(TensorFactoryTest, StandaloneARange) {
Expand Down

0 comments on commit 5f82c65

Please sign in to comment.