Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce repeat and RepeatOp #3687

Merged
merged 4 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions csrc/device_lower/pass/fusion_simplifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,16 @@ class LoadStoreOpInserter : private kir::ExprMutator {
container, LoadStoreOpType::Set, out, in));
}

void handle(RepeatOp* op) final {
auto out = op->out();
auto in = op->in();
auto container = out->container();
registerReplaceAndPropagate(
op,
IrBuilder::createInContainer<LoadStoreOp>(
container, LoadStoreOpType::Set, out, in));
}

void handle(ViewOp* vop) final {
auto out = vop->out();
auto in = vop->in();
Expand Down
1 change: 1 addition & 0 deletions csrc/device_lower/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ bool isTvOp(const Expr* expr) {
BroadcastOp,
SqueezeOp,
ExpandOp,
RepeatOp,
ViewAsScalar,
ViewOp,
PadOp,
Expand Down
1 change: 1 addition & 0 deletions csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class Val;
f(BroadcastOp); \
f(SqueezeOp); \
f(ExpandOp); \
f(RepeatOp); \
f(ViewAsScalar); \
f(ViewOp); \
f(CatOp); \
Expand Down
2 changes: 1 addition & 1 deletion csrc/id_model/predicate_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ std::vector<IterDomain*> getPredicateDomains(
: consumer_tv->getLogicalDomain();

// Broadcast domains should not need to be predicated. Note that
// unlike indexing for TensorIndex, reduction doamins do need to be
// unlike indexing for TensorIndex, reduction domains do need to be
// indexed to guard the access to the producer tensor
predicate_domains.erase(
std::remove_if(
Expand Down
36 changes: 36 additions & 0 deletions csrc/ir/internal_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1527,6 +1527,42 @@ class ExpandOp : public Expr {
const std::vector<PolymorphicValue>& inputs) const override;
};

// Represents a repetition of broadcast IDs. Repetitions of
// non-broadcast IDs are represented using the broadcast, expand and
// reshape pattern. See the repeat op implementation in ops/alias.cpp
// as well as the TranslateRepeatToExpand preseg pass.
class RepeatOp : public Expr {
public:
using Expr::Expr;

// in: Input tensor that have broadcast logical IDs.
// out: Output tensor where some of the input broadcast logical IDs
// are converted to concrete IDs. Their extents represent the
// repetition factor of each ID.
RepeatOp(IrBuilderPasskey, TensorView* out, TensorView* in);

NVFUSER_DECLARE_CLONE_AND_CREATE

const char* getOpString() const override {
return "RepeatOp";
}

std::string toString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;

TensorView* out() const {
return output(0)->as<TensorView>();
}

TensorView* in() const {
return input(0)->as<TensorView>();
}

std::vector<PolymorphicValue> evaluate(
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const override;
};

class ViewAsScalar : public Expr {
public:
using Expr::Expr;
Expand Down
68 changes: 68 additions & 0 deletions csrc/ir/nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2135,6 +2135,74 @@ std::vector<PolymorphicValue> ExpandOp::evaluate(

NVFUSER_DEFINE_CLONE_AND_CREATE(ExpandOp)

RepeatOp::RepeatOp(IrBuilderPasskey passkey, TensorView* out, TensorView* in)
: Expr(passkey) {
auto in_domain = TensorDomain::noReductions(in->getLogicalDomain());
const auto& out_domain = out->getLogicalDomain();

NVF_ERROR(in_domain.size() == out_domain.size());

bool repetition_found = false;
for (const auto i : c10::irange(in_domain.size())) {
if (in_domain.at(i)->isBroadcast() && !out_domain.at(i)->isBroadcast()) {
NVF_ERROR(!in_domain.at(i)->hasExpandedExtent());
NVF_ERROR(in_domain.at(i)->extent()->isOneInt());
repetition_found = true;
}
}

NVF_ERROR(
repetition_found,
"No repetition dim found: ",
out->toString(),
", ",
in->toString());

addOutput(out);
addInput(in);
}

std::string RepeatOp::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << out()->toString() << " = repeat( " << in()
<< " )\n";
return ss.str();
}

std::string RepeatOp::toInlineString(int indent_size) const {
NVF_CHECK(false, "Tensor op can not be printed inline");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question for my understanding: is it correct that only IterDomain ops can be printed inline according to our convention?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like that's the case, but I just copied this from ExpandOp::toInlineString.

}

std::vector<PolymorphicValue> RepeatOp::evaluate(
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const {
NVF_ERROR(
inputs.size() == 1,
"ConcretizeOp expects exactly 1 input, but received ",
naoyam marked this conversation as resolved.
Show resolved Hide resolved
inputs.size());
auto tensor = inputs.at(0).as<at::Tensor>();
std::vector<int64_t> sizes;
sizes.reserve(out()->getLogicalDomain().size());
const auto c2p =
PairwiseLogicalDomainMap(in(), out()).mapConsumerToProducer();
for (const auto i : c10::irange(out()->getLogicalDomain().size())) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Skip reduction?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is the output, it should not have a reduction ID. I added an assertion in the constructor.

auto out_id = out()->getLogicalDomain().at(i);
auto inp_id = c2p.at(out_id);
auto out_extent = ee.evaluate(out_id->extent()).as<int64_t>();
auto inp_extent = ee.evaluate(inp_id->extent()).as<int64_t>();
NVF_ERROR(
out_extent == inp_extent || out_extent % inp_extent == 0,
naoyam marked this conversation as resolved.
Show resolved Hide resolved
"Invalid input and output extents: ",
naoyam marked this conversation as resolved.
Show resolved Hide resolved
inp_extent,
", ",
out_extent);
sizes.push_back(out_extent / inp_extent);
naoyam marked this conversation as resolved.
Show resolved Hide resolved
}
return {tensor.repeat(sizes)};
}

NVFUSER_DEFINE_CLONE_AND_CREATE(RepeatOp)

ViewAsScalar::ViewAsScalar(
IrBuilderPasskey passkey,
Val* out,
Expand Down
4 changes: 4 additions & 0 deletions csrc/logical_domain_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,10 @@ class ComputeAtLogicalDomainMapBuilder : private BackwardVisitor {
mapPointwiseLikeOp(op);
}

void handle(RepeatOp* op) override {
mapPointwiseLikeOp(op);
Copy link
Collaborator

@wujingyue wujingyue Jan 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it supposed to be mapped as pointwise? The input and the output don't even have the same extent.

}

void handle(PadOp* op) override {
// For compute-at, padded id should be mapped
mapPointwiseLikeOp(op);
Expand Down
84 changes: 84 additions & 0 deletions csrc/ops/alias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1124,4 +1124,88 @@ TensorView* expand_as(TensorView* inp, TensorView* other) {
return out_tensor;
}

TensorView* repeat(TensorView* inp_tv, std::vector<int64_t> repeat_times) {
naoyam marked this conversation as resolved.
Show resolved Hide resolved
const auto ndims =
TensorDomain::noReductions(inp_tv->getLogicalDomain()).size();

// Handle repetitions of non-broadcast IDs first. Each ID is
// individully repeated by:
//
// Step 1. Insert a broadcast ID immediately outside of the
// repeated ID
// Step 2. Expand the broadcast ID by the repetition factor
// Step 3. Flatten the expanded ID and the repeated ID

bool has_repetition_of_broadcast = false;
auto intermediate_tv = inp_tv;
for (const auto i : c10::irange(ndims)) {
wujingyue marked this conversation as resolved.
Show resolved Hide resolved
if (repeat_times.at(i) == 1) {
continue;
}

auto inp_id = intermediate_tv->getLogicalDomain().at(i);

// Broadcast is handled after this
if (inp_id->isBroadcast()) {
has_repetition_of_broadcast = true;
continue;
}

// Step 1: Insert a broadcast ID
std::vector<bool> bcast_flags(ndims + 1, false);
bcast_flags.at(i) = true;
auto broadcast_tv = broadcast(intermediate_tv, bcast_flags);

// Step 2: Expand the broadcast ID for the repetition factor
std::vector<Val*> expanded_sizes(
bcast_flags.size(), IrBuilder::create<Val>(-1L));
expanded_sizes.at(i) = IrBuilder::create<Val>(repeat_times.at(i));
auto expanded_tv = expand(broadcast_tv, expanded_sizes);

// Step 3: Reshape to merge the broadcast ID and the repeated ID
intermediate_tv = flatten(expanded_tv, (int64_t)i, (int64_t)i + 1);
}

if (!has_repetition_of_broadcast) {
return intermediate_tv;
}

// Repeat broadcast IDs. The expand approach doesn't work as reshape
// would just squeeze repeated IDs and thus there would be no
// merge. Expanded IDs would remain to be expanded broadcast IDs. To
// concretize them, use RepeatOp
std::vector<IterDomain*> new_domain;
new_domain.reserve(ndims);
std::vector<std::optional<bool>> new_contiguity;
new_contiguity.reserve(ndims);

for (const auto i : c10::irange(ndims)) {
auto inp_id = intermediate_tv->getLogicalDomain().at(i);
IterDomain* new_id = nullptr;

if (repeat_times.at(i) > 1 && inp_id->isBroadcast()) {
new_id = IterDomainBuilder(inp_id)
.extent(IrBuilder::create<Val>(
repeat_times.at(i), DataType::Index))
.iter_type(IterType::Iteration)
.build();
} else {
new_id = inp_id->cloneWithoutRFactor();
}

new_domain.push_back(new_id);
new_contiguity.push_back(
new_id->isBroadcast() ? std::optional<bool>(std::nullopt)
: std::optional<bool>(true));
}

auto out_tv = IrBuilder::create<TensorView>(
IrBuilder::create<TensorDomain>(new_domain, new_contiguity),
inp_tv->dtype());

IrBuilder::create<RepeatOp>(out_tv, intermediate_tv);

return out_tv;
}

} // namespace nvfuser
5 changes: 5 additions & 0 deletions csrc/ops/alias.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,4 +182,9 @@ NVF_API TensorView* expand(
// non broadcasted iter domain, inp will be expanded to other's size.
NVF_API TensorView* expand_as(TensorView* inp, TensorView* other);

// Repeat each dimension for a given time. The repeat_times parameter
// must have the same number of elements as the dimensionality of the
// input tensor (excluding reduction IDs).
NVF_API TensorView* repeat(TensorView* inp, std::vector<int64_t> repeat_times);

} // namespace nvfuser
48 changes: 18 additions & 30 deletions csrc/preseg_passes/translate_repeat_to_expand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,11 @@ class RepeatToExpandTranslator {
}
}

// For each detected repetition:
//
// Step 1. Insert a broadcast ID immediately outside of the
// repeated ID
// Step 2. Expand the broadcast ID by the repetition factor
// Step 3. Flatten the expanded ID and the repeated ID
// For each detected repetition, replace the output with a repeat
// output.
void translate() {
FusionGuard fg(fusion_);

const auto exprs = fusion_->exprs();
// Apply the translation in a reverse topological order. Since the
// output of the repetition is replaced, the use exprs of the
Expand All @@ -145,36 +143,26 @@ class RepeatToExpandTranslator {

const auto& info = repeat_info_map_it->second;

if (info.cat_inp_tvs.size() < 2) {
const auto num_repetitions = (int64_t)info.cat_inp_tvs.size();

if (num_repetitions < 2) {
naoyam marked this conversation as resolved.
Show resolved Hide resolved
continue;
}

auto original_out_tv = expr->output(0)->as<TensorView>();

// Step 1
auto inp_domain =
const auto inp_domain =
TensorDomain::noReductions(info.input_tv->getLogicalDomain());
std::vector<bool> bcast_flags(inp_domain.size() + 1, false);
auto repeated_id_offset = std::distance(
inp_domain.begin(),
std::find(inp_domain.begin(), inp_domain.end(), info.repeated_id));
bcast_flags.at(repeated_id_offset) = true;
auto broadcast_tv = broadcast(info.input_tv, bcast_flags);
NVF_ERROR((size_t)broadcast_tv->nDims() == inp_domain.size() + 1);

// Step 2
std::vector<Val*> expanded_sizes(
bcast_flags.size(), IrBuilder::create<Val>(-1L));
expanded_sizes.at(repeated_id_offset) =
IrBuilder::create<Val>((int64_t)info.cat_inp_tvs.size());
auto expanded_tv = expand(broadcast_tv, expanded_sizes);

// Step 3
auto flattened_tv =
flatten(expanded_tv, repeated_id_offset, repeated_id_offset + 1);

std::vector<int64_t> repeated_times(inp_domain.size(), 1);
auto repeated_id_it =
std::find(inp_domain.begin(), inp_domain.end(), info.repeated_id);
NVF_ERROR(repeated_id_it != inp_domain.end());
auto repeated_dim = std::distance(inp_domain.begin(), repeated_id_it);
repeated_times.at(repeated_dim) = num_repetitions;

TensorView* replacement_tv = repeat(info.input_tv, repeated_times);

ir_utils::replaceValInAllExprInputsAndFusionOutputs(
original_out_tv, flattened_tv);
expr->output(0), replacement_tv);
}
}

Expand Down
Loading
Loading