Skip to content

Commit

Permalink
[Schema] Support TensorOrTupleTensor (#18)
Browse files Browse the repository at this point in the history
* remove

* schema

* refactor cast

* comment

Loading branch information
comaniac authored Apr 16, 2022
1 parent 0fbbeeb commit 13e787c
Show file tree
Hide file tree
Showing 13 changed files with 36 additions and 363 deletions.
116 changes: 0 additions & 116 deletions python/raf/_op/imp_arg.py

This file was deleted.

42 changes: 0 additions & 42 deletions python/raf/_op/imp_ret.py

This file was deleted.

143 changes: 0 additions & 143 deletions python/raf/_op/sym_arg.py

This file was deleted.

9 changes: 0 additions & 9 deletions python/raf/_op/sym_ret.py

This file was deleted.

32 changes: 0 additions & 32 deletions python/raf/_op/typing.py

This file was deleted.

6 changes: 1 addition & 5 deletions python/raf/_tvm_op/transform.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -439,12 +439,8 @@ def embedding_dx_compute(attrs, inputs, output_type):

@register_compute("raf.op.tvm.group_cast")
def group_cast_compute(attrs, inputs, output_type):
    """Compute rule for raf.op.tvm.group_cast.

    Casts every tensor in ``inputs`` to the target dtype.

    Parameters
    ----------
    attrs : object
        Op attributes; only ``attrs.dtype`` (target dtype string) is read.
    inputs : list of te.Tensor
        The tensors to cast.
    output_type : object
        Unused; required by the register_compute signature.

    Returns
    -------
    list of te.Tensor
        One cast tensor per input, in the same order.
    """
    # NOTE: the scraped diff showed both the old loop-based body and the new
    # comprehension concatenated; only the new one-pass comprehension is kept —
    # the old `return out` made the second return unreachable.
    dtype = attrs.dtype
    return [_topi.cast(tensor, dtype) for tensor in inputs]


_reg.register_injective_schedule("raf.op.tvm.group_cast")
2 changes: 1 addition & 1 deletion scripts/src_codegen/main_cxx_reg.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"ToDouble": "Double",
"ToIntTuple": "IntOrTupleInt",
"ToIntArray": "IntArray",
"ToTensorTuple": "TupleTensor",
"ToTensorTuple": "TensorOrTupleTensor",
}


Expand Down
10 changes: 7 additions & 3 deletions src/op/declare/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -893,8 +893,6 @@ RAF_OP_DECLARE("raf.op.group_cast", [](const CallValues& call) {
const auto* args = call->args.as<GroupCastArgs>();
CHECK(args != nullptr);

const DLTensor* first_tensor = args->tensor_list[0];

std::vector<TensorValue> ret;
std::string dtype = args->dtype;
for (int i = 0; i < args->tensor_list.size(); ++i) {
Expand All @@ -903,9 +901,15 @@ RAF_OP_DECLARE("raf.op.group_cast", [](const CallValues& call) {
ret.push_back(TensorValue::Assemble(/*dev=*/x->device,
/*dtype=*/String2DLDataType(dtype),
/*shape=*/oshape));
if (i == 0) {
call->device = x->device;
} else {
raf::Device device = x->device;
CHECK_EQ(call->device.device_type(), device.device_type()) << "Device type mismatch";
CHECK_EQ(call->device.device_id(), device.device_id()) << "Device id mismatch";
}
}
call->out = TupleValue::make(ir::Array<Value>(ret.begin(), ret.end()));
call->device = first_tensor->device;
});

RAF_OP_DECLARE("raf.op.gather", [](const CallValues& call) {
Expand Down
Loading

0 comments on commit 13e787c

Please sign in to comment.