Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve conflicts by recomputation #3625

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
67 changes: 52 additions & 15 deletions csrc/scheduler/resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,6 @@ bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) {

auto resize_based_tensor_ops = ir_utils::getOpsOfType<SliceOp, PadOp>(fusion);

if (auto non_exclusive_resizes = scheduler_tools::getNonExclusiveResizeInfo(
resize_based_tensor_ops, id_model.idGraph(IdMappingMode::EXACT));
!non_exclusive_resizes.empty()) {
std::stringstream msg;
msg << "Propagation of resizes would affect fusion outputs.";
for (const auto& [tv, resize_ids] : non_exclusive_resizes) {
msg << " Resize input tv: " << tv->toString()
<< ", resize input ID groups: " << nvfuser::toString(resize_ids);
}
scheduler_debug_utils::canScheduleRejectReason(schedulerType(), msg.str());
return false;
}

// Slicing of or to a broadcast ID is not allowed yet.
for (auto tensor_op : resize_based_tensor_ops) {
TensorView* out_tv = tensor_op->output(0)->as<TensorView>();
Expand Down Expand Up @@ -133,6 +120,30 @@ bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) {
return false;
}

for (auto out_tv : ir_utils::filterByType<TensorView>(fusion->outputs())) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is needed since the non-exclusivity check is dropped. It was redundant before.

if (out_tv == ref_tv) {
continue;
}
auto exprs = ValGraphBFS::getExprGroupsBetween(
broadcast_graph,
broadcast_graph.toGroups(ref_tv->getLogicalDomain()),
broadcast_graph.toGroups(out_tv->getLogicalDomain()),
/*require_all_to_visited=*/false)
.first;
for (const auto& [expr_g, dir] : exprs) {
if (expr_g->front()->isA<Resize>()) {
std::stringstream msg;
msg << "Resize between reference and output not allowed.";
msg << " Reference: " << ref_tv->toString()
<< ". Output: " << out_tv->toString()
<< ". Resize: " << expr_g->front()->toString();
scheduler_debug_utils::canScheduleRejectReason(
schedulerType(), msg.str());
return false;
}
}
}

// Disable the scheduler if there's a squeeze op. The loop option
// may also need to be enabled in that case, but that option is not
// turned on automatically yet.
Expand Down Expand Up @@ -163,6 +174,27 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
scheduler_utils::cacheInputs(fusion, true);
scheduler_utils::cacheAndForkOutputs(fusion, true);

auto resize_based_tensor_ops = ir_utils::getOpsOfType<SliceOp, PadOp>(fusion);

IdModel id_model(fusion, /*build_graphs=*/false);
const auto& exact_graph = id_model.buildExactGraph();

// Replicate resize inputs if necessary to avoid conflicting propagations
for (const auto& [out_tv, exlusivity_info] :
scheduler_tools::getNonExclusiveResizeInfo(
resize_based_tensor_ops, exact_graph)) {
auto resize_based_op = out_tv->definition();
auto inp_tv = resize_based_op->input(0)->as<TensorView>();
// Since cacheInput may skip caching if an input is used by
// slice/pad, inp_tv may be a fusion input, in which case it is
// not necessary to recompute the tensor.
if (inp_tv->isFusionInput()) {
continue;
}
auto inp_tv_copy = RecomputeTv::recompute(inp_tv);
ir_utils::replaceValInExprInputs(resize_based_op, inp_tv, inp_tv_copy);
}

for (auto expr : fusion->exprs()) {
if (!expr->isOneOf<SliceOp, PadOp>()) {
continue;
Expand All @@ -186,9 +218,14 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
ref_tv->axis(-1)->parallelize(ParallelType::TIDx);
ref_tv->axis(-2)->parallelize(ParallelType::BIDx);

// Propagate the reference to the other tensors
// Propagate the reference to the other tensors. Note that the
// update flag is enabled to work around the resize propagation
// issue. This may not work if there's a tensor that is reshaped
// from the reference tensor, but that should not be the case as the
// reference is picked by the same routine used for the pointwise
// scheduler.
scheduler_tools::scheduleLoopDomainsLike(
fusion->allTvs(), ref_tv->getLoopDomain());
fusion->allTvs(), ref_tv->getLoopDomain(), true);

inlineMost();

Expand Down
13 changes: 10 additions & 3 deletions csrc/scheduler/tools/resize_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,13 @@ void propagateResizeToInputs(Expr* resize_tensor_op) {
}
}

std::unordered_map<TensorView*, ValGroups> getNonExclusiveResizeInfo(
std::unordered_map<TensorView*, ResizeExclusivityInfo> getNonExclusiveResizeInfo(
const std::vector<Expr*>& ordered_resize_tensor_ops,
const ValGraph& exact_graph) {
NVF_ERROR(!ordered_resize_tensor_ops.empty());
Fusion* fusion = ordered_resize_tensor_ops[0]->fusion();

std::unordered_map<TensorView*, ValGroups> non_exclusive_resizes;
std::unordered_map<TensorView*, ResizeExclusivityInfo> non_exclusive_resizes;

std::unordered_set<Val*> inputs{
fusion->inputs().begin(), fusion->inputs().end()};
Expand All @@ -98,6 +98,8 @@ std::unordered_map<TensorView*, ValGroups> getNonExclusiveResizeInfo(
auto inp_tv = dynamic_cast<TensorView*>(resize_tensor_op->inputs().at(0));
auto out_tv = dynamic_cast<TensorView*>(resize_tensor_op->outputs().at(0));

ResizeExclusivityInfo info;

ValGroups resize_inp_ids = get_root_to_logical_resizes(out_tv);
NVF_ERROR(!resize_inp_ids.empty());

Expand Down Expand Up @@ -159,10 +161,15 @@ std::unordered_map<TensorView*, ValGroups> getNonExclusiveResizeInfo(
}

// This resize input ID is not exclusively used
non_exclusive_resizes[inp_tv].pushBack(resize_inp_id);
info.shared_tvs.push_back(dep_tv);
info.resized_ids.pushBack(resize_inp_id);
}
}

if (!info.shared_tvs.empty()) {
NVF_ERROR(non_exclusive_resizes.emplace(out_tv, info).second);
}

// Analysis of exclusiveness until in_tv is done. Following
// resize-based tensor ops do not need to check the same section
// of the fusion and can start from out_tv.
Expand Down
16 changes: 15 additions & 1 deletion csrc/scheduler/tools/resize_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,21 @@ void propagateResizeToInputs(Expr* resize_op);
// The function returns a map from the output tensors of non-exclusive
// ops to their exclusivity info. This map will be used to resolve the
// non-exclusiveness by recomputing the resize input tensors.
std::unordered_map<TensorView*, ValGroups> getNonExclusiveResizeInfo(
// Exclusivity info for a resize-based tensor op (slice/pad). Collected
// by getNonExclusiveResizeInfo when a resize input ID is used not only
// by the resize op itself but also by other dependent tensors.
struct ResizeExclusivityInfo {
  // Tensors that also depend on the resized input, making the resize
  // non-exclusive (one entry pushed per non-exclusive resize input ID).
  std::vector<TensorView*> shared_tvs;
  // Exact-graph groups of the resize input IDs that are not
  // exclusively used by this op.
  ValGroups resized_ids;

  bool operator==(const ResizeExclusivityInfo& other) const {
    return shared_tvs == other.shared_tvs && resized_ids == other.resized_ids;
  }

  bool operator!=(const ResizeExclusivityInfo& other) const {
    return !(*this == other);
  }
};

std::unordered_map<TensorView*, ResizeExclusivityInfo> getNonExclusiveResizeInfo(
const std::vector<Expr*>& ordered_resize_tensor_ops,
const ValGraph& exact_graph);

Expand Down
Loading
Loading