Skip to content

Commit

Permalink
#sdy Introduce transformEdgeSharding to the `ShardableDataFlowOpInt…
Browse files Browse the repository at this point in the history
…erface` and use it during propagation.

PiperOrigin-RevId: 700081638
  • Loading branch information
bartchr808 authored and copybara-github committed Nov 25, 2024
1 parent 744ebf3 commit 1e7378f
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 3 deletions.
12 changes: 12 additions & 0 deletions shardy/dialect/sdy/ir/data_flow_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,18 @@ DataFlowEdgeOp getDataFlowEdge(OpOperand& source) {
return DataFlowEdgeOp::getDataFlowEdgeUser(getDataFlowEdgeOwner(source));
}

// Applies the owning op's sharding transformation to `sharding`.
//
// Looks up the `ShardableDataFlowOpInterface` op that owns the input of
// `dataFlowEdge`. If such an op exists, the transformation is delegated to its
// `transformTargetSharding` method (passing the edge input as the target);
// otherwise `sharding` is returned unchanged.
TensorShardingAttr transformTargetSharding(
    DataFlowEdgeOp dataFlowEdge, TensorShardingAttr sharding,
    DataFlowShardingTransformType transformType) {
  Value edgeInput = dataFlowEdge.getInput();
  ShardableDataFlowOpInterface owningOp =
      getOwningShardableDataFlowOp(edgeInput);
  if (!owningOp) {
    // No shardable data flow op owns this edge: nothing to transform.
    return sharding;
  }
  return owningOp.transformTargetSharding(edgeInput, sharding, transformType);
}

SmallVector<Value> getDataFlowSources(DataFlowEdgeOp dataFlowEdge) {
Value input = dataFlowEdge.getInput();
if (ShardableDataFlowOpInterface shardableDataFlowOp =
Expand Down
7 changes: 7 additions & 0 deletions shardy/dialect/sdy/ir/data_flow_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ DataFlowEdgeOp getDataFlowEdge(Value target);
// `DataFlowEdgeOp`, otherwise returns `nullptr`.
DataFlowEdgeOp getDataFlowEdge(OpOperand& source);

// Transforms the `sharding` depending on `transformType`.
//
// If the input of `dataFlowEdge` is owned by a `ShardableDataFlowOpInterface`
// op, the transformation is delegated to that op's `transformTargetSharding`
// method; otherwise `sharding` is returned as is.
//
// See `DataFlowShardingTransformType` for more information.
TensorShardingAttr transformTargetSharding(
DataFlowEdgeOp dataFlowEdge, TensorShardingAttr sharding,
DataFlowShardingTransformType transformType);

// Returns all sources of the given `dataFlowEdge`.
SmallVector<Value> getDataFlowSources(DataFlowEdgeOp dataFlowEdge);

Expand Down
12 changes: 12 additions & 0 deletions shardy/dialect/sdy/ir/dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,18 @@ limitations under the License.
// which cannot be inlined due to cyclic dependencies on helper functions.
namespace mlir {
namespace sdy {

// Identifies the point, relative to edge propagation, at which the sharding of
// a data flow edge owner is being transformed.
enum class DataFlowShardingTransformType {
  // The transformation happens before edge propagation, i.e., when the
  // shardings are read in order to be propagated.
  kBeforeEdgePropagation,
  // The transformation happens after edge propagation, i.e., when the
  // shardings are written back to the data flow edge owner.
  kAfterEdgePropagation,
};

namespace details {

// Default implementation of the `getOpResultEdgeOwnerShardings` method of
Expand Down
16 changes: 16 additions & 0 deletions shardy/dialect/sdy/ir/op_interface.td
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,22 @@ def ShardableDataFlowOpInterface : OpInterface<"ShardableDataFlowOpInterface"> {
this->getOperation(), shardings);
}]
>,
// Interface method `transformTargetSharding(target, sharding, transformType)`.
// It has no fixed method body; ops implementing the interface may override it.
InterfaceMethod<
/*desc=*/[{
Transforms the `sharding` of the target depending on `transformType`

See `DataFlowShardingTransformType` for more information.
}],
/*retType=*/"mlir::sdy::TensorShardingAttr",
/*methodName=*/"transformTargetSharding",
/*args=*/(ins "mlir::Value":$target,
"mlir::sdy::TensorShardingAttr":$sharding,
"mlir::sdy::DataFlowShardingTransformType":$transformType),
/*methodBody=*/"",
// The default implementation is the identity: `sharding` is returned as is.
/*defaultImplementation=*/[{
return sharding;
}]
>,
InterfaceMethod<
/*desc=*/[{
Gets all block argument edge owners.
Expand Down
18 changes: 15 additions & 3 deletions shardy/dialect/sdy/transforms/propagation/basic_propagation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -470,11 +470,23 @@ class PropagateDataFlowEdgeOp : public OpRewritePattern<DataFlowEdgeOp> {
// The sharding of `dataFlowEdgeOp.getResult()` is the sharding of all
// targets.
return propagateTensorShardings(
sources, dataFlowEdgeOp.getResult(),
sources, dataFlowEdgeOp.getResult(), getShardings(sources),
transformTargetSharding(
dataFlowEdgeOp, dataFlowEdgeOp.getShardingAttr(),
DataFlowShardingTransformType::kBeforeEdgePropagation),
[&](TensorShardingAttr sharding, int64_t index) {
setSharding(sources[index], sharding);
},
[&](TensorShardingAttr sharding, int64_t _) {
dataFlowEdgeOp.setShardingAttr(transformTargetSharding(
dataFlowEdgeOp, sharding,
DataFlowShardingTransformType::kAfterEdgePropagation));
},
createIdentityShardingRule(cast<ShapedType>(dataFlowEdgeOp.getType()),
sources.size()),
dataFlowEdgeOp, symbolTable, rewriter, factorPropagation,
shardingGroupMap);
PropagationDirection::BOTH, factorPropagation, shardingGroupMap,
/*conservativePropagation=*/false, dataFlowEdgeOp, symbolTable,
&rewriter);
}

private:
Expand Down

0 comments on commit 1e7378f

Please sign in to comment.