[SDY] refactor propagation functions in basic_propagation.cc to utilize parameter struct in an attempt to clean up signatures.

PiperOrigin-RevId: 695000954
Varcho authored and copybara-github committed Nov 10, 2024
1 parent 36a65f3 commit 545b2e8
Showing 1 changed file with 58 additions and 49 deletions.
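The change below is a standard parameter-object refactoring: the `meshName`, `mesh`, `shardingGroupMap`, and `notifyOpModified` arguments that the propagation helpers previously threaded through every signature are bundled into a single `PropagationParams` struct, so each helper and call site passes one argument instead of four. The following standalone C++ sketch illustrates the pattern only; it uses simplified placeholder types and function names (not the real MLIR/SDY ones) and is not taken from the commit.

#include <functional>
#include <iostream>
#include <optional>
#include <string>

// Simplified stand-ins for the real SDY/MLIR types; all names here are
// illustrative, not the library's.
using NotifyOpModifiedCallback = std::function<void(const std::string&)>;
struct ShardingGroupMap {};

// Parameter object bundling the values that previously traveled through the
// propagation helpers as separate trailing arguments.
struct PropagationParams {
  const ShardingGroupMap& shardingGroupMap;
  std::string meshName;
  std::optional<NotifyOpModifiedCallback> notifyOpModified = std::nullopt;
};

// Before: every helper repeated the same tail of arguments.
bool updateShardingOld(int tensorIndex, const std::string& meshName,
                       const ShardingGroupMap& /*groups*/,
                       std::optional<NotifyOpModifiedCallback> notify) {
  if (notify) {
    (*notify)("updated tensor " + std::to_string(tensorIndex) + " on mesh " +
              meshName);
  }
  return true;
}

// After: one `params` argument carries the shared state.
bool updateShardingNew(int tensorIndex, const PropagationParams& params) {
  if (params.notifyOpModified) {
    (*params.notifyOpModified)("updated tensor " + std::to_string(tensorIndex) +
                               " on mesh " + params.meshName);
  }
  return true;
}

int main() {
  ShardingGroupMap groups;
  PropagationParams params{
      groups, "mesh_a",
      [](const std::string& msg) { std::cout << msg << "\n"; }};
  updateShardingOld(0, params.meshName, groups, params.notifyOpModified);
  updateShardingNew(1, params);
  return 0;
}

In the commit itself, `meshName`, `mesh`, and `notifyOpModified` are left defaulted when the struct is constructed and are filled in later inside `propagateTensorShardings`, once the common mesh has been resolved.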
107 changes: 58 additions & 49 deletions shardy/dialect/sdy/transforms/propagation/basic_propagation.cc
@@ -120,23 +120,31 @@ void notifyShardingModified(Value value,
notifyUsersModified(value, notifyOpModified);
}

// Struct to hold common parameters for sharding propagation.
struct PropagationParams {
const ShardingGroupMap& shardingGroupMap;
StringRef meshName = StringRef();
MeshAttr mesh = nullptr;
std::optional<NotifyOpModifiedCallback> notifyOpModified = std::nullopt;
};

// Update the sharding of `value` to the sharding in `tensorFactorShardings`.
//
// Returns true if it's possible to update the sharding, i.e., if strided view
// isn't needed and all non-minor-most factors are divisible by sharding axes.
bool updateTensorSharding(
TensorShardingAttr oldTensorSharding,
SetTensorShardingCallback setTensorShardingCallback,
const TensorFactorShardings& tensorFactorShardings,
TensorMappingAttr tensorMapping, ArrayRef<int64_t> factorSizes,
StringRef meshName, MeshAttr mesh, Value modifiedValue,
const ShardingGroupMap& shardingGroupMap,
std::optional<NotifyOpModifiedCallback> notifyOpModified) {
bool updateTensorSharding(TensorShardingAttr oldTensorSharding,
SetTensorShardingCallback setTensorShardingCallback,
const TensorFactorShardings& tensorFactorShardings,
TensorMappingAttr tensorMapping,
ArrayRef<int64_t> factorSizes,
const PropagationParams& params,
Value modifiedValue) {
// We can assume `modifiedValue` exists since we are updating its sharding.
assert(modifiedValue && "modified value should exist");
TensorShardingAttr newSharding =
tensorFactorShardings.createTensorShardingAttr(
mesh.getContext(), tensorMapping, factorSizes, meshName, mesh);
params.mesh.getContext(), tensorMapping, factorSizes, params.meshName,
params.mesh);
// `oldTensorSharding` may be null if there is no sharding, in which case we
// check if `newSharding` is empty.
// TODO(tomnatan): remove this checking if the new sharding equals the old
@@ -154,19 +162,20 @@ bool updateTensorSharding(

setTensorShardingCallback(newSharding);

if (notifyOpModified) {
notifyShardingModified(modifiedValue, *notifyOpModified);
if (params.notifyOpModified) {
notifyShardingModified(modifiedValue, *params.notifyOpModified);
}

// Set the sharding of all values in the same sharding group to be equivalent
// (skipping the modified value which has already been updated).
for (Value groupValue : shardingGroupMap.getGroupMembers(modifiedValue)) {
for (Value groupValue :
params.shardingGroupMap.getGroupMembers(modifiedValue)) {
if (groupValue == modifiedValue) {
continue;
}
setSharding(groupValue, newSharding);
if (notifyOpModified) {
notifyShardingModified(groupValue, *notifyOpModified);
if (params.notifyOpModified) {
notifyShardingModified(groupValue, *params.notifyOpModified);
}
}

@@ -185,16 +194,13 @@ void updateTensorShardings(
SetShardingPerTensorCallback setTensorShardingCallback,
ArrayRef<TensorFactorShardings> tensorFactorShardings,
ArrayRef<TensorMappingAttr> tensorMappings, ArrayRef<int64_t> factorSizes,
BitVector& updateTensor, StringRef meshName, MeshAttr mesh,
const ShardingGroupMap& shardingGroupMap,
std::optional<NotifyOpModifiedCallback> notifyOpModified) {
BitVector& updateTensor, const PropagationParams& params) {
for (int64_t index : updateTensor.set_bits()) {
if (!updateTensorSharding(
tensorShardings[index],
std::bind(setTensorShardingCallback, std::placeholders::_1, index),
tensorFactorShardings[index], tensorMappings[index], factorSizes,
meshName, mesh, getShardableValue(tensors[index]), shardingGroupMap,
notifyOpModified)) {
params, getShardableValue(tensors[index]))) {
updateTensor.reset(index);
}
}
@@ -209,21 +215,16 @@ void updateTensorShardings(
SetShardingPerTensorCallback setResultShardingCallback,
OpShardingRuleAttr shardingRule,
const ShardingProjection& shardingProjection, BitVector& updateOperand,
BitVector& updateResult, StringRef meshName, MeshAttr mesh,
const ShardingGroupMap& shardingGroupMap,
std::optional<NotifyOpModifiedCallback> notifyOpModified) {
BitVector& updateResult, const PropagationParams& params) {
updateTensorShardings(operands, operandShardings, setOperandShardingCallback,
shardingProjection.getOperands(),
shardingRule.getOperandMappings(),
shardingRule.getFactorSizes(), updateOperand, meshName,
mesh, shardingGroupMap, notifyOpModified);
shardingRule.getFactorSizes(), updateOperand, params);
updateTensorShardings(results, resultsShardings, setResultShardingCallback,
shardingProjection.getResults(),
shardingRule.getResultMappings(),
shardingRule.getFactorSizes(), updateResult, meshName,
mesh, shardingGroupMap, notifyOpModified);
shardingRule.getFactorSizes(), updateResult, params);
}

// Propagates tensor shardings of the given `operands` and `results` according
// to `shardingRule`.
//
@@ -237,9 +238,9 @@ LogicalResult propagateTensorShardings(
SetShardingPerTensorCallback setOperandShardingCallback,
SetShardingPerTensorCallback setResultShardingCallback,
OpShardingRuleAttr shardingRule, PropagationDirection direction,
const FactorPropagation& factorPropagation,
const ShardingGroupMap& shardingGroupMap, bool conservativePropagation,
Operation* op, const SymbolTable& symbolTable, PatternRewriter* rewriter) {
const FactorPropagation& factorPropagation, bool conservativePropagation,
Operation* op, const SymbolTable& symbolTable, PatternRewriter* rewriter,
PropagationParams params) {
std::optional<StringRef> meshName =
getCommonMeshName(operandShardings, resultsShardings, symbolTable);
if (!meshName.has_value()) {
@@ -276,11 +277,15 @@ LogicalResult propagateTensorShardings(
}
};
}

params.meshName = meshName.value();
params.mesh = mesh;
params.notifyOpModified = notifyOpModified;

updateTensorShardings(operands, results, operandShardings, resultsShardings,
setOperandShardingCallback, setResultShardingCallback,
shardingRule, shardingProjection, updateOperand,
updateResult, meshName.value(), mesh, shardingGroupMap,
notifyOpModified);
updateResult, params);

bool anyUpdated = updateOperand.any() || updateResult.any();
if (rewriter && !anyUpdated) {
@@ -299,8 +304,7 @@ LogicalResult propagateTensorShardings(
SetTensorShardingCallback setResultShardingCallback,
OpShardingRuleAttr shardingRule, Operation* op,
const SymbolTable& symbolTable, PatternRewriter* rewriter,
const FactorPropagation& factorPropagation,
const ShardingGroupMap& shardingGroupMap,
const FactorPropagation& factorPropagation, const PropagationParams& params,
PropagationDirection direction = PropagationDirection::BOTH,
bool conservativePropagation = false) {
return propagateTensorShardings(
@@ -311,17 +315,16 @@ LogicalResult propagateTensorShardings(
[&](TensorShardingAttr sharding, int64_t) {
setResultShardingCallback(sharding);
},
shardingRule, direction, factorPropagation, shardingGroupMap,
conservativePropagation, op, symbolTable, rewriter);
shardingRule, direction, factorPropagation, conservativePropagation, op,
symbolTable, rewriter, params);
}

// Same as the overload above, except the operand and result shardings are
// extracted using `getSharding` and set using `setSharding`.
LogicalResult propagateTensorShardings(
ValueRange operands, ValueRange results, OpShardingRuleAttr shardingRule,
Operation* op, const SymbolTable& symbolTable, PatternRewriter& rewriter,
const FactorPropagation& factorPropagation,
const ShardingGroupMap& shardingGroupMap,
const FactorPropagation& factorPropagation, const PropagationParams& params,
PropagationDirection direction = PropagationDirection::BOTH,
bool conservativePropagation = false) {
return propagateTensorShardings(
@@ -332,8 +335,8 @@ LogicalResult propagateTensorShardings(
[&](TensorShardingAttr sharding, int64_t index) {
setSharding(results[index], sharding);
},
shardingRule, direction, factorPropagation, shardingGroupMap,
conservativePropagation, op, symbolTable, &rewriter);
shardingRule, direction, factorPropagation, conservativePropagation, op,
symbolTable, &rewriter, params);
}

// Propagates the shardings between the operands of the `funcOp`'s terminator
@@ -353,6 +356,7 @@ LogicalResult propagateFuncResults(FuncOp funcOp,
// NOTE: we void the returned `LogicalResult` since function updates aren't
// done through a rewriter, can ignore whether operands/results were
// updated.
PropagationParams params{shardingGroupMap};
(void)propagateTensorShardings(
// The operand/result function arguments are used to:
// - invoke the rewriter (if specified) that a value was updated. But
@@ -374,7 +378,7 @@ class PropagateRegisteredOp : public RewritePattern {
// result attrs as an identity op. Create an equivalent sharding
// rule.
createIdentityShardingRule(tensorType), funcOp, symbolTable,
/*rewriter=*/nullptr, factorPropagation, shardingGroupMap);
/*rewriter=*/nullptr, factorPropagation, params);
}
return success();
}
@@ -429,10 +433,11 @@ class PropagateRegisteredOp : public RewritePattern {
});
}

PropagationParams params{shardingGroupMap};
return propagateTensorShardings(op->getOperands(), op->getResults(),
shardingRule, op, symbolTable, rewriter,
factorPropagation, shardingGroupMap,
direction, conservativePropagation);
factorPropagation, params, direction,
conservativePropagation);
}

private:
@@ -463,12 +468,13 @@ class PropagateDataFlowEdgeOp : public OpRewritePattern<DataFlowEdgeOp> {
SmallVector<Value> sources = getDataFlowSources(dataFlowEdgeOp);
// The sharding of `dataFlowEdgeOp.getResult()` is the sharding of all
// targets.

PropagationParams params{shardingGroupMap};
return propagateTensorShardings(
sources, dataFlowEdgeOp.getResult(),
createIdentityShardingRule(cast<ShapedType>(dataFlowEdgeOp.getType()),
sources.size()),
dataFlowEdgeOp, symbolTable, rewriter, factorPropagation,
shardingGroupMap);
dataFlowEdgeOp, symbolTable, rewriter, factorPropagation, params);
}

private:
@@ -517,6 +523,8 @@ class PropagateManualComputationOp
PatternRewriter& rewriter) const override {
bool updated = false;

PropagationParams params{shardingGroupMap};

// 1. Propagate between the operands of the `ManualComputationOp` and the
// block arguments (specifically the `in_shardings`, but we use the op's
// block arguments as an alias for them).
@@ -547,7 +555,7 @@ class PropagateManualComputationOp
createIdentityShardingRule(
cast<RankedTensorType>(operand.getType())),
manualComputationOp, symbolTable, &rewriter, factorPropagation,
shardingGroupMap)
params)
.succeeded();
}

@@ -575,7 +583,7 @@ class PropagateManualComputationOp
createIdentityShardingRule(
cast<RankedTensorType>(opResult.getType())),
manualComputationOp, symbolTable, &rewriter, factorPropagation,
shardingGroupMap)
params)
.succeeded();
}

@@ -604,12 +612,13 @@ class PropagatePropagationBarrier

LogicalResult matchAndRewrite(PropagationBarrierOp propagationBarrierOp,
PatternRewriter& rewriter) const override {
PropagationParams params{shardingGroupMap};
return propagateTensorShardings(
propagationBarrierOp.getInput(), propagationBarrierOp.getResult(),
createIdentityShardingRule(
cast<RankedTensorType>(propagationBarrierOp.getType())),
propagationBarrierOp, symbolTable, rewriter, factorPropagation,
shardingGroupMap, propagationBarrierOp.getAllowedDirection());
propagationBarrierOp, symbolTable, rewriter, factorPropagation, params,
propagationBarrierOp.getAllowedDirection());
}

private: