-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce repeat and RepeatOp #3687
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2135,6 +2135,86 @@ std::vector<PolymorphicValue> ExpandOp::evaluate( | |
|
||
NVFUSER_DEFINE_CLONE_AND_CREATE(ExpandOp) | ||
|
||
// Constructs a RepeatOp: `out` is produced by repeating one or more
// broadcast dimensions of `in`. A repeated dimension is an input ID that
// is a plain size-1 broadcast mapped to a non-broadcast output ID; at
// least one such dimension must exist.
RepeatOp::RepeatOp(IrBuilderPasskey passkey, TensorView* out, TensorView* in)
    : Expr(passkey) {
  auto in_domain = TensorDomain::noReductions(in->getLogicalDomain());
  const auto& out_domain = out->getLogicalDomain();

  // Input (reductions excluded) and output must be rank-aligned.
  NVF_ERROR(in_domain.size() == out_domain.size());

  // The output of a repeat is a full tensor; reduction IDs would make the
  // consumer-to-producer mapping in evaluate() ill-defined.
  NVF_ERROR(
      std::none_of(
          out_domain.begin(),
          out_domain.end(),
          [](IterDomain* id) { return id->isReduction(); }),
      "Output should not have reduction IDs.");

  bool has_repeated_dim = false;
  for (const auto idx : c10::irange(in_domain.size())) {
    IterDomain* inp_id = in_domain.at(idx);
    IterDomain* out_id = out_domain.at(idx);
    if (!inp_id->isBroadcast() || out_id->isBroadcast()) {
      continue;
    }
    // A broadcast being materialized by repetition must be a plain,
    // unexpanded, extent-1 broadcast.
    NVF_ERROR(!inp_id->hasExpandedExtent());
    NVF_ERROR(inp_id->extent()->isOneInt());
    has_repeated_dim = true;
  }

  NVF_ERROR(
      has_repeated_dim,
      "No repetition dim found: ",
      out->toString(),
      ", ",
      in->toString());

  addOutput(out);
  addInput(in);
}
|
||
// Prints the op in the standard tensor-op form, e.g. "T1 = repeat( T0 )".
std::string RepeatOp::toString(int indent_size) const {
  std::stringstream ss;
  // Use toString() for the input operand as well, consistent with the
  // output operand (streaming the raw Statement pointer relied on
  // operator<< doing the same conversion implicitly).
  indent(ss, indent_size) << out()->toString() << " = repeat( "
                          << in()->toString() << " )\n";
  return ss.str();
}
|
||
// Tensor ops are not inline-printable (same convention as
// ExpandOp::toInlineString); this unconditionally fails via NVF_CHECK.
std::string RepeatOp::toInlineString(int indent_size) const {
  NVF_CHECK(false, "Tensor op can not be printed inline");
}
|
||
// Concretely evaluates the repeat: for each logical output dimension,
// computes the multiplier out_extent / inp_extent and delegates to
// at::Tensor::repeat.
//
// ee:     evaluator used to resolve symbolic extents to concrete int64s.
// inputs: exactly one PolymorphicValue holding the input at::Tensor.
// Returns a single-element vector holding the repeated tensor.
std::vector<PolymorphicValue> RepeatOp::evaluate(
    const ExpressionEvaluator& ee,
    const std::vector<PolymorphicValue>& inputs) const {
  NVF_ERROR(
      inputs.size() == 1,
      "RepeatOp expects exactly 1 input, but received ",
      inputs.size());
  auto tensor = inputs.at(0).as<at::Tensor>();
  std::vector<int64_t> multipliers;
  multipliers.reserve(out()->getLogicalDomain().size());
  const auto c2p =
      PairwiseLogicalDomainMap(in(), out()).mapConsumerToProducer();
  // The constructor guarantees the output has no reduction IDs, so every
  // logical output ID maps back to a producer ID through c2p.
  for (const auto i : c10::irange(out()->getLogicalDomain().size())) {
    auto out_id = out()->getLogicalDomain().at(i);
    auto inp_id = c2p.at(out_id);
    auto out_extent = ee.evaluate(out_id->extent()).as<int64_t>();
    auto inp_extent = ee.evaluate(inp_id->extent()).as<int64_t>();
    // Fixed: the original message never closed the paren opened after
    // "the output extent (".
    NVF_ERROR(
        out_extent % inp_extent == 0,
        "For dimension ",
        i,
        ", the output extent (",
        out_extent,
        ") should be a multiple of the input extent (",
        inp_extent,
        ").");
    multipliers.push_back(out_extent / inp_extent);
  }
  return {tensor.repeat(multipliers)};
}
|
||
// Generates the boilerplate clone() and IrBuilder creation support for
// RepeatOp, matching every other Expr subclass in this file.
NVFUSER_DEFINE_CLONE_AND_CREATE(RepeatOp)
|
||
ViewAsScalar::ViewAsScalar( | ||
IrBuilderPasskey passkey, | ||
Val* out, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -504,6 +504,10 @@ class ComputeAtLogicalDomainMapBuilder : private BackwardVisitor { | |
mapPointwiseLikeOp(op); | ||
} | ||
|
||
  void handle(RepeatOp* op) override {
    // NOTE(review): mapped like a pointwise op even though the repeated
    // dimension's input and output extents differ — confirm this is the
    // intended compute-at mapping for RepeatOp (raised in PR review).
    mapPointwiseLikeOp(op);
  }
|
||
void handle(PadOp* op) override { | ||
// For compute-at, padded id should be mapped | ||
mapPointwiseLikeOp(op); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Question for my understanding: is it correct that only IterDomain ops can be printed inline according to our convention?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like that's the case, but I just copied this from `ExpandOp::toInlineString`.