Skip to content

Commit

Permalink
[mlir][scf] Track replacements using a listener in TileAndFuse (#120999)
Browse files Browse the repository at this point in the history
This PR makes TileAndFuse explicitly track replacements using a listener
instead of assuming that the results always come from the outer most
tiling loop. scf::tileUsingInterface can introduce merge operations
whose results are the actual replacements to use, instead of the outer
most loop results.
  • Loading branch information
Groverkss authored Dec 24, 2024
1 parent 852feea commit 6e3631d
Showing 1 changed file with 59 additions and 21 deletions.
80 changes: 59 additions & 21 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
Expand Down Expand Up @@ -1467,6 +1468,47 @@ void SliceTrackingListener::notifyOperationReplaced(Operation *op,
ValueRange replacement) {
removeOp(op);
}

//===----------------------------------------------------------------------===//
// ReplacementListener
//===----------------------------------------------------------------------===//

/// Listener that tracks updates replacements for values which can be mutated.
/// This listener runs on top of the existing listener for the rewriter,
/// to make sure external users can still run listeners.
class ReplacementListener : public RewriterBase::ForwardingListener {
public:
ReplacementListener(DenseMap<Value, Value> &replacements,
OpBuilder::Listener *listener)
: ForwardingListener(listener), replacements(replacements) {}

void updateReplacementValues(ValueRange origValues,
ValueRange replaceValues) {
// This can probably be written better, but just iterates over the map
// and the new replacements for now.
for (auto &[key, val] : replacements) {
for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
if (val == orig) {
val = replace;
}
}
}
}

void notifyOperationReplaced(Operation *op, Operation *newOp) override {
ForwardingListener::notifyOperationReplaced(op, newOp);
updateReplacementValues(op->getResults(), newOp->getResults());
}

void notifyOperationReplaced(Operation *op, ValueRange values) override {
ForwardingListener::notifyOperationReplaced(op, values);
updateReplacementValues(op->getResults(), values);
}

private:
DenseMap<Value, Value> &replacements;
};

} // namespace

/// Implementation of tile consumer and fuse producer greedily.
Expand All @@ -1493,26 +1535,27 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
for (auto *tiledOp : tilingResult->tiledOps)
tiledAndFusedOps.insert(tiledOp);

DenseMap<Value, Value> replacements;
for (auto [origVal, replacement] : llvm::zip_equal(
consumer->getResults(), tilingResult->mergeResult.replacements)) {
replacements[origVal] = replacement;
}

// If there are no loops generated, fusion is immaterial.
auto &loops = tilingResult->loops;
if (loops.empty()) {
DenseMap<Value, Value> replacements;
for (auto [origVal, replacement] : llvm::zip_equal(
consumer->getResults(), tilingResult->mergeResult.replacements)) {
replacements[origVal] = replacement;
}
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
replacements};
}

// To keep track of replacements for now just record the map from the
// original untiled value to the result number of the for loop. Since the
// loop gets potentially replaced during fusion, keeping the value directly
// wont work.
DenseMap<Value, size_t> origValToResultNumber;
for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
origValToResultNumber[result] = index;
}
// Since the loop gets potentially replaced during fusion, we need to track
// the mutation of replacement values. To do this, we attach a listener to
// update the replacements as they happen.
OpBuilder::Listener *previousListener = rewriter.getListener();
auto resetListener =
llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
ReplacementListener replaceListener(replacements, previousListener);
rewriter.setListener(&replaceListener);

// 2. Typically, the operands of the tiled operation are slices of the
// operands of the untiled operation. These are expressed in IR using
Expand Down Expand Up @@ -1581,9 +1624,9 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
worklistCandidates.append(newSlices.value());
for (auto [index, result] :
llvm::enumerate(fusableProducerOp->getResults())) {
origValToResultNumber[result] = loops.front()->getNumResults() -
fusableProducerOp->getNumResults() +
index;
replacements[result] = loops.front()->getResult(
loops.front()->getNumResults() -
fusableProducerOp->getNumResults() + index);
}
}
if (Operation *tiledAndFusedOp =
Expand All @@ -1597,11 +1640,6 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
}
}

DenseMap<Value, Value> replacements;
for (auto [origVal, resultNumber] : origValToResultNumber) {
replacements[origVal] = loops.front()->getResult(resultNumber);
}

return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
replacements};
}
Expand Down

0 comments on commit 6e3631d

Please sign in to comment.