-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce repeat and RepeatOp #3687
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2135,6 +2135,74 @@ std::vector<PolymorphicValue> ExpandOp::evaluate( | |
|
||
NVFUSER_DEFINE_CLONE_AND_CREATE(ExpandOp) | ||
|
||
RepeatOp::RepeatOp(IrBuilderPasskey passkey, TensorView* out, TensorView* in) | ||
: Expr(passkey) { | ||
auto in_domain = TensorDomain::noReductions(in->getLogicalDomain()); | ||
const auto& out_domain = out->getLogicalDomain(); | ||
|
||
NVF_ERROR(in_domain.size() == out_domain.size()); | ||
|
||
bool repetition_found = false; | ||
for (const auto i : c10::irange(in_domain.size())) { | ||
if (in_domain.at(i)->isBroadcast() && !out_domain.at(i)->isBroadcast()) { | ||
NVF_ERROR(!in_domain.at(i)->hasExpandedExtent()); | ||
NVF_ERROR(in_domain.at(i)->extent()->isOneInt()); | ||
repetition_found = true; | ||
} | ||
} | ||
|
||
NVF_ERROR( | ||
repetition_found, | ||
"No repetition dim found: ", | ||
out->toString(), | ||
", ", | ||
in->toString()); | ||
|
||
addOutput(out); | ||
addInput(in); | ||
} | ||
|
||
std::string RepeatOp::toString(int indent_size) const { | ||
std::stringstream ss; | ||
indent(ss, indent_size) << out()->toString() << " = repeat( " << in() | ||
<< " )\n"; | ||
return ss.str(); | ||
} | ||
|
||
std::string RepeatOp::toInlineString(int indent_size) const { | ||
NVF_CHECK(false, "Tensor op can not be printed inline"); | ||
} | ||
|
||
std::vector<PolymorphicValue> RepeatOp::evaluate( | ||
const ExpressionEvaluator& ee, | ||
const std::vector<PolymorphicValue>& inputs) const { | ||
NVF_ERROR( | ||
inputs.size() == 1, | ||
"ConcretizeOp expects exactly 1 input, but received ", | ||
naoyam marked this conversation as resolved.
Show resolved
Hide resolved
|
||
inputs.size()); | ||
auto tensor = inputs.at(0).as<at::Tensor>(); | ||
std::vector<int64_t> sizes; | ||
sizes.reserve(out()->getLogicalDomain().size()); | ||
const auto c2p = | ||
PairwiseLogicalDomainMap(in(), out()).mapConsumerToProducer(); | ||
for (const auto i : c10::irange(out()->getLogicalDomain().size())) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Skip reduction? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this is the output, it should not have a reduction ID. I added an assertion in the constructor. |
||
auto out_id = out()->getLogicalDomain().at(i); | ||
auto inp_id = c2p.at(out_id); | ||
auto out_extent = ee.evaluate(out_id->extent()).as<int64_t>(); | ||
auto inp_extent = ee.evaluate(inp_id->extent()).as<int64_t>(); | ||
NVF_ERROR( | ||
out_extent == inp_extent || out_extent % inp_extent == 0, | ||
naoyam marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"Invalid input and output extents: ", | ||
naoyam marked this conversation as resolved.
Show resolved
Hide resolved
|
||
inp_extent, | ||
", ", | ||
out_extent); | ||
sizes.push_back(out_extent / inp_extent); | ||
naoyam marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
return {tensor.repeat(sizes)}; | ||
} | ||
|
||
NVFUSER_DEFINE_CLONE_AND_CREATE(RepeatOp) | ||
|
||
ViewAsScalar::ViewAsScalar( | ||
IrBuilderPasskey passkey, | ||
Val* out, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -504,6 +504,10 @@ class ComputeAtLogicalDomainMapBuilder : private BackwardVisitor { | |
mapPointwiseLikeOp(op); | ||
} | ||
|
||
void handle(RepeatOp* op) override { | ||
mapPointwiseLikeOp(op); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it supposed to be mapped as pointwise? The input and the output don't even have the same extent. |
||
} | ||
|
||
void handle(PadOp* op) override { | ||
// For compute-at, padded id should be mapped | ||
mapPointwiseLikeOp(op); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Question for my understanding: is it correct that only IterDomain ops can be printed inline according to our convention?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like that's the case, but I just copied this from
ExpandOp::toInlineString
.