Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Python bindings and tests for Triu #3637

Draft
wants to merge 1 commit into
base: pbasu_iota_experiment
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions csrc/ops/arith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,9 +500,9 @@ TensorView* triu(TensorView* tv, Val* offset) {

NVF_CHECK(
tv->nDims() >= 2,
"triu is only supported for 2+D tensors, but got ",
"input tensor for triu must have 2 or more dims, but got ",
tv->nDims(),
"D tensor");
" dims");

// Let's say we want a triu of a 2D tensor of shape [2, 4]
// We broadcast the iota of the outer dim
Expand Down Expand Up @@ -539,6 +539,10 @@ TensorView* triu(TensorView* tv, Val* offset) {
return where(mask, tv, IrBuilder::create<Val>(0, tv->dtype()));
}

// Convenience overload of triu that takes no explicit offset: it builds a
// constant Val holding 0 and forwards to triu(tv, offset).
TensorView* triu(TensorView* tv) {
  Val* zero_offset = IrBuilder::create<Val>(0);
  return triu(tv, zero_offset);
}

// UNARY OPERATIONS

#define NVFUSER_DEFINE_UNARY_OP(operator_name, operator_type) \
Expand Down
1 change: 1 addition & 0 deletions csrc/ops/arith.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ NVF_API TensorView* arange(
NVF_API TensorView* eye(Val* size, DataType dtype);
NVF_API TensorView* eye(Val* rows, Val* cols, DataType dtype);
// triu with an explicit offset. The implementation checks that tv has at
// least 2 dims and returns where(mask, tv, 0), with the 0 created in tv's
// dtype.
NVF_API TensorView* triu(TensorView* tv, Val* offset);
// triu without an offset argument: forwards to the overload above with an
// offset Val of 0.
NVF_API TensorView* triu(TensorView* tv);

// UNARY OPERATIONS
// abs
Expand Down
32 changes: 32 additions & 0 deletions csrc/python_frontend/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1693,6 +1693,38 @@ void initNvFuserPythonBindings(PyObject* module) {
NVFUSER_PYTHON_BINDING_UNARY_OP("imag", imag)
#undef NVFUSER_PYTHON_BINDING_UNARY_OP

// Python binding for `ops.triu(input, offset=None)`.
//
// Defines an output tensor with the same number of dims as `input` and
// records one of two OpRecords under the name "ops.triu":
//   - offset omitted  -> Unary_TV record bound to triu(TensorView*)
//   - offset supplied -> Binary_TV_VAL record bound to
//                        triu(TensorView*, Val*)
// Fails via NVF_CHECK if the definition has already been completed.
nvf_ops.def(
    "triu",
    [](FusionDefinition::Operators& self,
       Tensor input,
       std::optional<Scalar> offset) -> Tensor {
      FUSER_PERF_SCOPE("Operators.triu");
      NVF_CHECK(
          self.validUse(), "Attempting to add to a completed definition!");

      // Spell out the two overload signatures once so the casts below
      // unambiguously pick the intended `triu`.
      using TriuNoOffsetFn = TensorView* (*)(TensorView*);
      using TriuWithOffsetFn = TensorView* (*)(TensorView*, Val*);

      FusionDefinition* fd = self.fusion_definition;
      Tensor output = fd->defineTensor(input.dims);

      if (!offset.has_value()) {
        // No offset given: record the single-argument overload.
        fd->defineRecord(new OpRecord<TensorView*, TensorView*>(
            {fd->recordingState(input())},
            {fd->recordingState(output())},
            ("ops.triu"),
            serde::RecordType::Unary_TV,
            static_cast<TriuNoOffsetFn>(triu)));
        return output;
      }

      // Offset given: record the (tensor, scalar) overload.
      fd->defineRecord(new OpRecord<TensorView*, TensorView*, Val*>(
          {fd->recordingState(input()),
           fd->recordingState(offset.value()())},
          {fd->recordingState(output())},
          ("ops.triu"),
          serde::RecordType::Binary_TV_VAL,
          static_cast<TriuWithOffsetFn>(triu)));
      return output;
    },
    py::arg("input"),
    py::arg("offset") = std::nullopt,
    py::return_value_policy::reference);

// overload to
nvf_ops.def(
"stride_order",
Expand Down
9 changes: 9 additions & 0 deletions csrc/serde/fusion_record.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,13 @@ void RecordFunctorFactory::setupFunctionMaps() {
("ops." op_str), static_cast<TensorView* (*)(TensorView*)>(op_name)); \
unary_val.emplace(("ops." op_str), static_cast<Val* (*)(Val*)>(op_name));

// Registers `op_name` under the key "ops." op_str in two function maps:
//   - unary_tv:      the TensorView* (TensorView*) overload
//   - binary_tv_val: the TensorView* (TensorView*, Val*) overload
// `op_name` must therefore provide both overloads, since each static_cast
// selects one of them.
#define NVFUSER_UNARY_TV_ALPHA_OP(op_str, op_name) \
unary_tv.emplace( \
("ops." op_str), static_cast<TensorView* (*)(TensorView*)>(op_name)); \
binary_tv_val.emplace( \
("ops." op_str), \
static_cast<TensorView* (*)(TensorView*, Val*)>(op_name));

#define NVFUSER_BINARY_TV_ONLY_OP(op_str, op_name) \
binary_tv.emplace( \
("ops." op_str), \
Expand Down Expand Up @@ -808,6 +815,8 @@ void RecordFunctorFactory::setupFunctionMaps() {
NVFUSER_UNARY_TV_OP("real", real)
NVFUSER_UNARY_TV_OP("imag", imag)

NVFUSER_UNARY_TV_ALPHA_OP("triu", triu)

NVFUSER_BINARY_TV_ONLY_OP("matmul", matmul)
NVFUSER_BINARY_TV_ONLY_OP("linear", linear)
NVFUSER_TERNARY_TV_ONLY_OP("linear", linear)
Expand Down
36 changes: 36 additions & 0 deletions tests/python/opinfo_input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,8 @@ def elementwise_unary_generator(

# Typical inputs
for shape in shapes:
if op.name == "triu" and len(shape) < 2:
continue
yield SampleInput(make_arg(shape))
yield SampleInput(make_arg(shape, noncontiguous=True))

Expand Down Expand Up @@ -1591,3 +1593,37 @@ def div_input_generator(
denom = torch.where(denom_is_small, denom_scaled_to_minabs, denom).detach()
denom.requires_grad_(requires_grad)
yield SampleInput(numer, denom)


def triu_input_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = False):
    """Yield sample inputs for triu.

    Each sample from elementwise_unary_generator (with extremal, large and
    small value testing disabled) is yielded unchanged first, then again once
    per diagonal offset with the offset appended as an extra positional
    argument.
    """
    diagonal_offsets = (0, 1, -1, 2, 3, -3, 1024, -1024)

    base_samples = elementwise_unary_generator(
        op,
        dtype,
        requires_grad,
        enable_extremal_value_testing=False,
        enable_large_value_testing=False,
        enable_small_value_testing=False,
    )

    for sample in base_samples:
        # Tensor-only call (no offset argument).
        yield sample
        # Same tensor argument(s) combined with every offset.
        yield from (SampleInput(*sample.args, k) for k in diagonal_offsets)


def triu_error_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = False):
    """Yield (sample, exception type, expected message) triples for triu.

    Covers a non-integer offset and input tensors with fewer than two dims.
    """
    make_arg = partial(
        make_tensor, device="cuda", dtype=dtype, requires_grad=requires_grad
    )

    # A float offset on an otherwise valid 2D input.
    float_offset_sample = SampleInput(make_arg((4, 16)), 5.6)
    yield float_offset_sample, RuntimeError, "offset must have type Int"

    # 0-D and 1-D inputs.
    for shape in ((), (4,)):
        too_few_dims_sample = SampleInput(make_arg(shape))
        yield (
            too_few_dims_sample,
            RuntimeError,
            f"input tensor for triu must have 2 or more dims, but got {len(shape)} dims",
        )
15 changes: 15 additions & 0 deletions tests/python/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
matmul_input_generator,
linear_input_generator,
linear_error_generator,
triu_input_generator,
triu_error_generator,
)
from utils import (
bool_int_dtypes,
Expand Down Expand Up @@ -1218,6 +1220,18 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor):
)
linear_ops.append(linear_opinfo)

# OpInfo for fd.ops.triu, checked against torch.triu as the reference.
triu_opinfo = OpInfo(
    lambda fd: fd.ops.triu,
    "triu",
    sample_input_generator=triu_input_generator,
    error_input_generator=triu_error_generator,
    reference=torch.triu,
)

# OpInfo entries collected in tv_val_ops (currently only triu).
tv_val_ops = [triu_opinfo]

""" End Tensor Creation """

# Puts all opinfos into the "opinfos" list
Expand All @@ -1231,3 +1245,4 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor):
opinfos.extend(tensor_creation_ops)
opinfos.extend(matmul_ops)
opinfos.extend(linear_ops)
opinfos.extend(tv_val_ops)
15 changes: 15 additions & 0 deletions tests/python/test_python_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,6 +1204,21 @@ def fusion_func(fd: FusionDefinition):
self.assertEqual(eager_out2, nvf_out[1])
# self.assertEqual(eager_out3, nvf_out[2])

def test_triu(self):
    """Compare fd.ops.triu with offset -1 against torch.triu on a 4x16 fp16 CUDA tensor."""
    diagonal = -1
    inputs = [torch.randn(4, 16, device="cuda", dtype=torch.float16)]

    def fusion_func(fd: FusionDefinition) -> None:
        source = fd.from_pytorch(inputs[0])
        offset = fd.define_scalar(diagonal, dtype=DataType.Int)
        fd.add_output(fd.ops.triu(source, offset))

    nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
    expected = torch.triu(inputs[0], diagonal)
    self.assertEqual(expected, nvf_out[0])

def test_complex_rsqrt(self):
inputs = [
torch.randn(4, device="cuda", dtype=torch.complex64),
Expand Down
Loading