Skip to content

Commit e92c575

Browse files
committed
new unit test
1 parent 49b0862 commit e92c575

File tree

3 files changed

+52
-0
lines changed

3 files changed

+52
-0
lines changed

csrc/ops/arith.cpp

+29
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,35 @@ TensorView* eye(Val* size, DataType dtype) {
494494
return eye(size, size, dtype);
495495
}
496496

497+
// Returns a tensor equal to `tv` where the mask `row <= column - offset` holds
// and zero elsewhere, i.e. the upper-triangular part of the two innermost
// logical dimensions relative to the diagonal selected by `offset`.
//
// tv:     input tensor; its logical domain must have at least two dimensions.
// offset: scalar of type DataType::Int selecting the diagonal. The column
//         indices are shifted by -offset before the comparison below.
//
// Fails an NVF_CHECK if `offset` is not of type Int or if `tv` has fewer than
// two logical dimensions.
TensorView* triu(TensorView* tv, Val* offset) {
  NVF_CHECK(
      offset->getDataType() == DataType::Int, "offset must have type Int");

  // Validate the rank of the same domain that is indexed below, so that the
  // `num_dims - 2` / `num_dims - 1` accesses can never go out of range.
  const auto& logical_domain = tv->domain()->logical();
  const auto num_dims = logical_domain.size();
  NVF_CHECK(
      num_dims >= 2,
      "triu is only supported for 2+D tensors, but got ",
      num_dims,
      "D tensor");

  // Row indices 0, 1, ..., rows - 1. The matrix is formed by the LAST two
  // dimensions; any leading dimensions are treated as batch dimensions.
  // (Previously axes 1 and 2 were hard-coded, which read past the end of the
  // domain for 2-D inputs and picked the wrong axes for 4+D inputs.)
  auto tv_rows = iota(
      logical_domain[num_dims - 2]->extent(),
      IrBuilder::create<Val>(0, DataType::Index),
      IrBuilder::create<Val>(1, DataType::Index),
      DataType::Int);

  // Column indices shifted by the diagonal: -offset, -offset + 1, ...
  // Starting the iota at -offset folds the `column - offset` subtraction into
  // the index generation itself.
  auto tv_columns = iota(
      logical_domain[num_dims - 1]->extent(),
      SimplifyingIrBuilder::mulExpr(
          offset, IrBuilder::create<Val>(-1, DataType::Index)),
      IrBuilder::create<Val>(1, DataType::Index),
      DataType::Int);

  // Expand both index vectors to [rows, columns] so they can be compared
  // element-wise: rows vary along the outer axis, columns along the inner one.
  auto tv_rows_b = broadcast(tv_rows, {false, true});
  auto tv_cols_b = broadcast(tv_columns, {true, false});

  // Keep an element iff row <= column - offset.
  auto mask = le(tv_rows_b, tv_cols_b);
  return where(mask, tv, IrBuilder::create<Val>(0, tv->dtype()));
}
525+
497526
// UNARY OPERATIONS
498527

499528
#define NVFUSER_DEFINE_UNARY_OP(operator_name, operator_type) \

csrc/ops/arith.h

+1
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ NVF_API TensorView* arange(
251251
DataType dtype = DataType::Int);
252252
// Single-extent overload of eye: forwards to eye(size, size, dtype), so the
// same extent is used for both rows and cols.
NVF_API TensorView* eye(Val* size, DataType dtype);
253253
// Two-extent overload of eye; the single-extent overload delegates here with
// rows == cols. NOTE(review): presumably builds a rows x cols identity-like
// tensor of type dtype -- confirm against the definition in arith.cpp.
NVF_API TensorView* eye(Val* rows, Val* cols, DataType dtype);
254+
// Upper-triangular masking: returns a tensor holding the elements of tv where
// row <= column - offset and zero everywhere else. offset must be a scalar of
// type DataType::Int and tv must have at least two dimensions; both
// conditions are enforced with NVF_CHECK.
NVF_API TensorView* triu(TensorView* tv, Val* offset);
254255

255256
// UNARY OPERATIONS
256257
// abs

tests/cpp/test_tensor_factories.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,28 @@ TEST_F(TensorFactoryTest, StandaloneIota) {
230230
}
231231
}
232232

233+
// Builds a fusion that applies triu with a negative diagonal offset to a
// symbolic 3-D half-precision input, runs it on a {4, 4, 8} random tensor and
// compares the result with ATen's at::triu.
TEST_F(TensorFactoryTest, SimpleTriu) {
  auto fusion_ptr = std::make_unique<Fusion>();
  FusionGuard fusion_guard(fusion_ptr.get());

  // Diagonal offset shared by the fusion definition and the ATen reference.
  int64_t diagonal = -2;

  // Fusion definition: out = triu(in, diagonal).
  auto input_tv = makeSymbolicTensor(3, DataType::Half);
  fusion_ptr->addInput(input_tv);
  auto diagonal_val = IrBuilder::create<Val>(diagonal, DataType::Int);
  auto output_tv = triu(input_tv, diagonal_val);
  fusion_ptr->addOutput(output_tv);

  // Concrete half-precision input on the GPU.
  auto tensor_options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
  auto input_tensor = at::randn({4, 4, 8}, tensor_options);

  FusionExecutorCache executor_cache(std::move(fusion_ptr));
  auto outputs = executor_cache.runFusionWithInputs({input_tensor});

  // Same value is used for the relative and the absolute tolerance.
  const double tolerance = .001;
  auto expected = at::triu(input_tensor, diagonal);
  EXPECT_TRUE(outputs[0].allclose(expected, tolerance, tolerance));
}
254+
233255
TEST_F(TensorFactoryTest, StandaloneARange) {
234256
auto starts_ends = {-1., 0., 10.3, 1024. * 256};
235257
auto steps = {-1.5, 1., 2.};

0 commit comments

Comments
 (0)