diff --git a/shardy/dialect/sdy/transforms/import/import_pipeline.cc b/shardy/dialect/sdy/transforms/import/import_pipeline.cc index 9c8420be..4bf1f923 100644 --- a/shardy/dialect/sdy/transforms/import/import_pipeline.cc +++ b/shardy/dialect/sdy/transforms/import/import_pipeline.cc @@ -25,6 +25,21 @@ limitations under the License. namespace mlir { namespace sdy { +namespace { + +GreedyRewriteConfig getCanonicalizerConfig(bool enableRegionSimplification) { + GreedyRewriteConfig config; + config.useTopDownTraversal = true; + config.enableRegionSimplification = enableRegionSimplification + ? GreedySimplifyRegionLevel::Normal + : GreedySimplifyRegionLevel::Disabled; + config.fold = false; + config.cseConstants = false; + return config; +} + +} // namespace + void addImportPipeline(OpPassManager& pm, StringRef dumpDirectory, bool skipInline) { pm.addPass(mlir::sdy::createSaveModuleOpPass(dumpDirectory, @@ -33,7 +48,10 @@ void addImportPipeline(OpPassManager& pm, StringRef dumpDirectory, // otherwise we would need to propagate shardings between call ops and callee // functions. if (!skipInline) { - pm.addPass(createInlinerPass()); + pm.addPass(createInlinerPass({}, [&](OpPassManager& pm) { + pm.addPass(createCanonicalizerPass( + getCanonicalizerConfig(/*enableRegionSimplification=*/true))); + })); } pm.addPass(createSymbolDCEPass()); pm.addPass(createLiftInlinedMeshesPass()); @@ -45,12 +63,9 @@ void addImportPipeline(OpPassManager& pm, StringRef dumpDirectory, // constraints. This ensures we can detect sharding conflicts between group // members which have pre-propagation shardings due to sharding constraints. pm.addPass(createShardingGroupImportPass()); - - GreedyRewriteConfig config; - config.useTopDownTraversal = true; - config.enableRegionSimplification = GreedySimplifyRegionLevel::Disabled; pm.addPass(createCanonicalizerPass( - /*config=*/config, /*disabledPatterns=*/{}, + getCanonicalizerConfig(/*enableRegionSimplification=*/false), + /*disabledPatterns=*/{}, /*enabledPatterns=*/{"DedupShardingGroupPattern"})); pm.addPass(mlir::sdy::createSaveModuleOpPass(dumpDirectory, "sdy_module_after_sdy_import")); diff --git a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc index 37423495..146f7a2b 100644 --- a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc +++ b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc @@ -592,8 +592,14 @@ LogicalResult BasicPropagationPassImpl::propagate( GreedyRewriteConfig config; config.useTopDownTraversal = true; config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled; - if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns), - config))) { + config.fold = false; + config.cseConstants = false; + if (failed(applyPatternsGreedily(moduleOp, std::move(patterns), config))) { + // We should always converge in 2 iterations, if we don't, something is + // wrong. + moduleOp->emitError("Failed to converge after ") + << config.maxIterations + << " iterations. please contact the Shardy team."; return failure(); }