From 1e7378ffef655516c6297f34b0e35c2208af6ecf Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Mon, 25 Nov 2024 13:24:09 -0800 Subject: [PATCH] #sdy Introduce `transformEdgeSharding` to the `ShardableDataFlowOpInterface` and use it during propagation. PiperOrigin-RevId: 700081638 --- shardy/dialect/sdy/ir/data_flow_utils.cc | 12 ++++++++++++ shardy/dialect/sdy/ir/data_flow_utils.h | 7 +++++++ shardy/dialect/sdy/ir/dialect.h | 12 ++++++++++++ shardy/dialect/sdy/ir/op_interface.td | 16 ++++++++++++++++ .../propagation/basic_propagation.cc | 18 +++++++++++++++--- 5 files changed, 62 insertions(+), 3 deletions(-) diff --git a/shardy/dialect/sdy/ir/data_flow_utils.cc b/shardy/dialect/sdy/ir/data_flow_utils.cc index 08f863e8..1245d80b 100644 --- a/shardy/dialect/sdy/ir/data_flow_utils.cc +++ b/shardy/dialect/sdy/ir/data_flow_utils.cc @@ -113,6 +113,18 @@ DataFlowEdgeOp getDataFlowEdge(OpOperand& source) { return DataFlowEdgeOp::getDataFlowEdgeUser(getDataFlowEdgeOwner(source)); } +TensorShardingAttr transformTargetSharding( + DataFlowEdgeOp dataFlowEdge, TensorShardingAttr sharding, + DataFlowShardingTransformType transformType) { + Value input = dataFlowEdge.getInput(); + if (ShardableDataFlowOpInterface shardableDataFlowOp = + getOwningShardableDataFlowOp(input)) { + return shardableDataFlowOp.transformTargetSharding(input, sharding, + transformType); + } + return sharding; +} + SmallVector getDataFlowSources(DataFlowEdgeOp dataFlowEdge) { Value input = dataFlowEdge.getInput(); if (ShardableDataFlowOpInterface shardableDataFlowOp = diff --git a/shardy/dialect/sdy/ir/data_flow_utils.h b/shardy/dialect/sdy/ir/data_flow_utils.h index 1c65529b..87a30aeb 100644 --- a/shardy/dialect/sdy/ir/data_flow_utils.h +++ b/shardy/dialect/sdy/ir/data_flow_utils.h @@ -55,6 +55,13 @@ DataFlowEdgeOp getDataFlowEdge(Value target); // `DataFlowEdgeOp`, otherwise returns `nullptr`. DataFlowEdgeOp getDataFlowEdge(OpOperand& source); +// Transforms the `sharding` depending on `transformType`. +// +// See `DataFlowShardingTransformType` for more information. +TensorShardingAttr transformTargetSharding( + DataFlowEdgeOp dataFlowEdge, TensorShardingAttr sharding, + DataFlowShardingTransformType transformType); + // Returns all sources of the given `dataFlowEdge`. SmallVector getDataFlowSources(DataFlowEdgeOp dataFlowEdge); diff --git a/shardy/dialect/sdy/ir/dialect.h b/shardy/dialect/sdy/ir/dialect.h index 31c41249..69383874 100644 --- a/shardy/dialect/sdy/ir/dialect.h +++ b/shardy/dialect/sdy/ir/dialect.h @@ -57,6 +57,18 @@ limitations under the License. // which cannot be inlined due to cyclic dependencies on helper functions. namespace mlir { namespace sdy { + +// Specifies whether the dataflow edge owner sharding is being transformed +// before or after edge propagation. +enum class DataFlowShardingTransformType { + // Before edge propagation is when the value of the shardings are inspected + // for propagation. + kBeforeEdgePropagation, + // After edge propagation is when the shardings are set back on the data flow + // edge owner. + kAfterEdgePropagation +}; + namespace details { // Default implementation of the `getOpResultEdgeOwnerShardings` method of diff --git a/shardy/dialect/sdy/ir/op_interface.td b/shardy/dialect/sdy/ir/op_interface.td index feee81cc..5bd98e8d 100644 --- a/shardy/dialect/sdy/ir/op_interface.td +++ b/shardy/dialect/sdy/ir/op_interface.td @@ -109,6 +109,22 @@ def ShardableDataFlowOpInterface : OpInterface<"ShardableDataFlowOpInterface"> { this->getOperation(), shardings); }] >, + InterfaceMethod< + /*desc=*/[{ + Transforms the `sharding` of the target depending on `transformType` + + See `DataFlowShardingTransformType` for more information. + }], + /*retType=*/"mlir::sdy::TensorShardingAttr", + /*methodName=*/"transformTargetSharding", + /*args=*/(ins "mlir::Value":$target, + "mlir::sdy::TensorShardingAttr":$sharding, + "mlir::sdy::DataFlowShardingTransformType":$transformType), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return sharding; + }] + >, InterfaceMethod< /*desc=*/[{ Gets all block argument edge owners. diff --git a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc index 563d0cb2..c7cf2dbb 100644 --- a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc +++ b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc @@ -470,11 +470,23 @@ class PropagateDataFlowEdgeOp : public OpRewritePattern { // The sharding of `dataFlowEdgeOp.getResult()` is the sharding of all // targets. return propagateTensorShardings( - sources, dataFlowEdgeOp.getResult(), + sources, dataFlowEdgeOp.getResult(), getShardings(sources), + transformTargetSharding( + dataFlowEdgeOp, dataFlowEdgeOp.getShardingAttr(), + DataFlowShardingTransformType::kBeforeEdgePropagation), + [&](TensorShardingAttr sharding, int64_t index) { + setSharding(sources[index], sharding); + }, + [&](TensorShardingAttr sharding, int64_t _) { + dataFlowEdgeOp.setShardingAttr(transformTargetSharding( + dataFlowEdgeOp, sharding, + DataFlowShardingTransformType::kAfterEdgePropagation)); + }, createIdentityShardingRule(cast(dataFlowEdgeOp.getType()), sources.size()), - dataFlowEdgeOp, symbolTable, rewriter, factorPropagation, - shardingGroupMap); + PropagationDirection::BOTH, factorPropagation, shardingGroupMap, + /*conservativePropagation=*/false, dataFlowEdgeOp, symbolTable, + &rewriter); } private: