From 545b2e81e3696512f0189b9c9a1572058fbbe74d Mon Sep 17 00:00:00 2001
From: Bill Varcho
Date: Sat, 9 Nov 2024 23:54:18 -0800
Subject: [PATCH] [SDY] refactor propagation functions in basic_propagation.cc
 to use a parameter struct in an attempt to clean up signatures.

PiperOrigin-RevId: 695000954
---
 .../propagation/basic_propagation.cc          | 107 ++++++++++--------
 1 file changed, 58 insertions(+), 49 deletions(-)

diff --git a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc
index 2f220bd8..eaf6f5d3 100644
--- a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc
+++ b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc
@@ -120,23 +120,31 @@ void notifyShardingModified(Value value,
   notifyUsersModified(value, notifyOpModified);
 }
 
+// Struct to hold common parameters for sharding propagation.
+struct PropagationParams {
+  const ShardingGroupMap& shardingGroupMap;
+  StringRef meshName = StringRef();
+  MeshAttr mesh = nullptr;
+  std::optional<NotifyOpModifiedCallback> notifyOpModified = std::nullopt;
+};
+
 // Update the sharding of `value` to the sharding in `tensorFactorShardings`.
 //
 // Returns true if it's possible to update the sharding, i.e., if strided view
 // isn't needed and all non-minor-most factors are divisible by sharding axes.
-bool updateTensorSharding(
-    TensorShardingAttr oldTensorSharding,
-    SetTensorShardingCallback setTensorShardingCallback,
-    const TensorFactorShardings& tensorFactorShardings,
-    TensorMappingAttr tensorMapping, ArrayRef<int64_t> factorSizes,
-    StringRef meshName, MeshAttr mesh, Value modifiedValue,
-    const ShardingGroupMap& shardingGroupMap,
-    std::optional<NotifyOpModifiedCallback> notifyOpModified) {
+bool updateTensorSharding(TensorShardingAttr oldTensorSharding,
+                          SetTensorShardingCallback setTensorShardingCallback,
+                          const TensorFactorShardings& tensorFactorShardings,
+                          TensorMappingAttr tensorMapping,
+                          ArrayRef<int64_t> factorSizes,
+                          const PropagationParams& params,
+                          Value modifiedValue) {
   // We can assume `modifiedValue` exists since we are updating its sharding.
   assert(modifiedValue && "modified value should exist");
   TensorShardingAttr newSharding =
       tensorFactorShardings.createTensorShardingAttr(
-          mesh.getContext(), tensorMapping, factorSizes, meshName, mesh);
+          params.mesh.getContext(), tensorMapping, factorSizes, params.meshName,
+          params.mesh);
   // `oldTensorSharding` may be null if there is no sharding, in which case we
   // check if `newSharding` is empty.
   // TODO(tomnatan): remove this checking if the new sharding equals the old
@@ -154,19 +162,20 @@ bool updateTensorSharding(
 
   setTensorShardingCallback(newSharding);
 
-  if (notifyOpModified) {
-    notifyShardingModified(modifiedValue, *notifyOpModified);
+  if (params.notifyOpModified) {
+    notifyShardingModified(modifiedValue, *params.notifyOpModified);
   }
 
   // Set the sharding of all values in the same sharding group to be equivalent
   // (skipping the modified value which has already been updated).
-  for (Value groupValue : shardingGroupMap.getGroupMembers(modifiedValue)) {
+  for (Value groupValue :
+       params.shardingGroupMap.getGroupMembers(modifiedValue)) {
     if (groupValue == modifiedValue) {
       continue;
     }
     setSharding(groupValue, newSharding);
-    if (notifyOpModified) {
-      notifyShardingModified(groupValue, *notifyOpModified);
+    if (params.notifyOpModified) {
+      notifyShardingModified(groupValue, *params.notifyOpModified);
     }
   }
 
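The hunks above are the classic parameter-object refactor: the recurring tail of arguments (`meshName`, `mesh`, `shardingGroupMap`, `notifyOpModified`) is bundled into one struct. Below is a minimal standalone sketch of the pattern, using toy std types in place of `StringRef`/`MeshAttr`/`ShardingGroupMap`; the names are illustrative, not the SDY declarations:

```cpp
#include <iostream>
#include <optional>
#include <string>

// Toy stand-ins: the real struct holds a ShardingGroupMap&, a StringRef,
// a MeshAttr, and an optional notification callback.
struct Params {
  const std::string& groupMap;      // required; bound at construction
  std::string meshName = "";        // defaulted; filled in later
  std::optional<int> notifyToken = std::nullopt;
};

// Before: update(..., meshName, mesh, groupMap, notify) repeated everywhere;
// after: one `const Params&`, so adding a field no longer touches callers.
void updateSharding(int tensorIndex, const Params& params) {
  std::cout << "tensor " << tensorIndex << " on mesh '" << params.meshName
            << "' with groups " << params.groupMap << "\n";
}

int main() {
  std::string groupMap = "{t0, t1}";
  Params params{groupMap};     // remaining members keep their defaults
  params.meshName = "mesh_a";  // call-site-specific fields set before use
  updateSharding(0, params);
}
```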
@@ -185,16 +194,13 @@ void updateTensorShardings(
     SetShardingPerTensorCallback setTensorShardingCallback,
     ArrayRef<TensorFactorShardings> tensorFactorShardings,
     ArrayRef<TensorMappingAttr> tensorMappings, ArrayRef<int64_t> factorSizes,
-    BitVector& updateTensor, StringRef meshName, MeshAttr mesh,
-    const ShardingGroupMap& shardingGroupMap,
-    std::optional<NotifyOpModifiedCallback> notifyOpModified) {
+    BitVector& updateTensor, const PropagationParams& params) {
   for (int64_t index : updateTensor.set_bits()) {
     if (!updateTensorSharding(
             tensorShardings[index],
             std::bind(setTensorShardingCallback, std::placeholders::_1, index),
             tensorFactorShardings[index], tensorMappings[index], factorSizes,
-            meshName, mesh, getShardableValue(tensors[index]), shardingGroupMap,
-            notifyOpModified)) {
+            params, getShardableValue(tensors[index]))) {
       updateTensor.reset(index);
     }
   }
@@ -209,21 +215,16 @@ void updateTensorShardings(
     SetShardingPerTensorCallback setResultShardingCallback,
     OpShardingRuleAttr shardingRule,
     const ShardingProjection& shardingProjection, BitVector& updateOperand,
-    BitVector& updateResult, StringRef meshName, MeshAttr mesh,
-    const ShardingGroupMap& shardingGroupMap,
-    std::optional<NotifyOpModifiedCallback> notifyOpModified) {
+    BitVector& updateResult, const PropagationParams& params) {
   updateTensorShardings(operands, operandShardings, setOperandShardingCallback,
                         shardingProjection.getOperands(),
                         shardingRule.getOperandMappings(),
-                        shardingRule.getFactorSizes(), updateOperand, meshName,
-                        mesh, shardingGroupMap, notifyOpModified);
+                        shardingRule.getFactorSizes(), updateOperand, params);
   updateTensorShardings(results, resultsShardings, setResultShardingCallback,
                         shardingProjection.getResults(),
                         shardingRule.getResultMappings(),
-                        shardingRule.getFactorSizes(), updateResult, meshName,
-                        mesh, shardingGroupMap, notifyOpModified);
+                        shardingRule.getFactorSizes(), updateResult, params);
 }
-
 // Propagates tensor shardings of the given `operands` and `results` according
 // to `shardingRule`.
 //
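Note that the call site above keeps the `std::bind` adapter: the per-tensor callback takes `(sharding, index)`, while `updateTensorSharding` expects a single-argument callback, so the index is pinned per loop iteration. A toy sketch of that adaptation, with illustrative types rather than the SDY ones:

```cpp
#include <cstdint>
#include <functional>
#include <iostream>

// The per-tensor callback takes (sharding, index); updateTensorSharding only
// wants (sharding). std::bind pins the index argument, as in the hunk above.
void setShardingForTensor(int sharding, int64_t index) {
  std::cout << "tensor " << index << " <- sharding " << sharding << "\n";
}

int main() {
  int64_t index = 2;
  // What std::bind(setTensorShardingCallback, std::placeholders::_1, index)
  // does, with toy types:
  auto bound = std::bind(setShardingForTensor, std::placeholders::_1, index);
  bound(7);  // prints: tensor 2 <- sharding 7

  // An equivalent capturing lambda, often preferred in modern code:
  auto lambda = [index](int sharding) { setShardingForTensor(sharding, index); };
  lambda(7);
}
```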
@@ -237,9 +238,9 @@ LogicalResult propagateTensorShardings(
     SetShardingPerTensorCallback setOperandShardingCallback,
     SetShardingPerTensorCallback setResultShardingCallback,
     OpShardingRuleAttr shardingRule, PropagationDirection direction,
-    const FactorPropagation& factorPropagation,
-    const ShardingGroupMap& shardingGroupMap, bool conservativePropagation,
-    Operation* op, const SymbolTable& symbolTable, PatternRewriter* rewriter) {
+    const FactorPropagation& factorPropagation, bool conservativePropagation,
+    Operation* op, const SymbolTable& symbolTable, PatternRewriter* rewriter,
+    PropagationParams params) {
   std::optional<StringRef> meshName =
       getCommonMeshName(operandShardings, resultsShardings, symbolTable);
   if (!meshName.has_value()) {
@@ -276,11 +277,15 @@ LogicalResult propagateTensorShardings(
       }
     };
   }
+
+  params.meshName = meshName.value();
+  params.mesh = mesh;
+  params.notifyOpModified = notifyOpModified;
+
   updateTensorShardings(operands, results, operandShardings, resultsShardings,
                         setOperandShardingCallback, setResultShardingCallback,
                         shardingRule, shardingProjection, updateOperand,
-                        updateResult, meshName.value(), mesh, shardingGroupMap,
-                        notifyOpModified);
+                        updateResult, params);
 
   bool anyUpdated = updateOperand.any() || updateResult.any();
   if (rewriter && !anyUpdated) {
@@ -299,8 +304,7 @@ LogicalResult propagateTensorShardings(
     SetTensorShardingCallback setResultShardingCallback,
     OpShardingRuleAttr shardingRule, Operation* op,
     const SymbolTable& symbolTable, PatternRewriter* rewriter,
-    const FactorPropagation& factorPropagation,
-    const ShardingGroupMap& shardingGroupMap,
+    const FactorPropagation& factorPropagation, const PropagationParams& params,
     PropagationDirection direction = PropagationDirection::BOTH,
     bool conservativePropagation = false) {
   return propagateTensorShardings(
@@ -311,8 +315,8 @@ LogicalResult propagateTensorShardings(
       [&](TensorShardingAttr sharding, int64_t) {
         setResultShardingCallback(sharding);
       },
-      shardingRule, direction, factorPropagation, shardingGroupMap,
-      conservativePropagation, op, symbolTable, rewriter);
+      shardingRule, direction, factorPropagation, conservativePropagation, op,
+      symbolTable, rewriter, params);
 }
 
 // Same as the overload above, except the operand and result shardings are
@@ -320,8 +324,7 @@ LogicalResult propagateTensorShardings(
 LogicalResult propagateTensorShardings(
     ValueRange operands, ValueRange results, OpShardingRuleAttr shardingRule,
     Operation* op, const SymbolTable& symbolTable, PatternRewriter& rewriter,
-    const FactorPropagation& factorPropagation,
-    const ShardingGroupMap& shardingGroupMap,
+    const FactorPropagation& factorPropagation, const PropagationParams& params,
     PropagationDirection direction = PropagationDirection::BOTH,
     bool conservativePropagation = false) {
   return propagateTensorShardings(
@@ -332,8 +335,8 @@ LogicalResult propagateTensorShardings(
       [&](TensorShardingAttr sharding, int64_t index) {
         setSharding(results[index], sharding);
       },
-      shardingRule, direction, factorPropagation, shardingGroupMap,
-      conservativePropagation, op, symbolTable, &rewriter);
+      shardingRule, direction, factorPropagation, conservativePropagation, op,
+      symbolTable, &rewriter, params);
 }
 
 // Propagates the shardings between the operands of the `funcOp`'s terminator
@@ -353,6 +356,7 @@ LogicalResult propagateFuncResults(FuncOp funcOp,
     // NOTE: we void the returned `LogicalResult` since function updates aren't
     // done through a rewriter, can ignore whether operands/results were
     // updated.
+    PropagationParams params{shardingGroupMap};
     (void)propagateTensorShardings(
         // The operand/result function arguments are used to:
         // - invoke the rewriter (if specified) that a value was updated. But
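In the hunks above, the main `propagateTensorShardings` now takes `PropagationParams` by value and fills in the call-site-specific fields (`meshName`, `mesh`, `notifyOpModified`) once the common mesh is resolved; `updateTensorSharding` then guards on the optional callback. A toy sketch of that flow, with std stand-ins and hypothetical names:

```cpp
#include <functional>
#include <iostream>
#include <optional>
#include <string>

// Toy stand-in for PropagationParams: only the fields this sketch needs.
struct Params {
  std::string meshName;
  std::optional<std::function<void(const std::string&)>> notify = std::nullopt;
};

// Taken by value, like `PropagationParams params` in the patch: the callee
// fills in the call-site-specific fields on its own copy before use.
void propagate(Params params) {
  params.meshName = "mesh_a";  // resolved here, not by the caller
  if (params.notify) {         // fire the callback only if one was supplied
    (*params.notify)("updated " + params.meshName);
  }
}

int main() {
  propagate({});  // no callback: the notify branch is skipped
  propagate({"", [](const std::string& msg) { std::cout << msg << "\n"; }});
}
```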
@@ -374,7 +378,7 @@ LogicalResult propagateFuncResults(FuncOp funcOp,
         // result attrs as an identity op. Create an equivalent sharding
         // rule.
         createIdentityShardingRule(tensorType), funcOp, symbolTable,
-        /*rewriter=*/nullptr, factorPropagation, shardingGroupMap);
+        /*rewriter=*/nullptr, factorPropagation, params);
   }
   return success();
 }
@@ -429,10 +433,11 @@ class PropagateRegisteredOp : public RewritePattern {
       });
     }
 
+    PropagationParams params{shardingGroupMap};
     return propagateTensorShardings(op->getOperands(), op->getResults(),
                                     shardingRule, op, symbolTable, rewriter,
-                                    factorPropagation, shardingGroupMap,
-                                    direction, conservativePropagation);
+                                    factorPropagation, params, direction,
+                                    conservativePropagation);
   }
 
 private:
@@ -463,12 +468,13 @@ class PropagateDataFlowEdgeOp : public OpRewritePattern<DataFlowEdgeOp> {
     SmallVector<Value> sources = getDataFlowSources(dataFlowEdgeOp);
     // The sharding of `dataFlowEdgeOp.getResult()` is the sharding of all
     // targets.
+
+    PropagationParams params{shardingGroupMap};
     return propagateTensorShardings(
         sources, dataFlowEdgeOp.getResult(),
        createIdentityShardingRule(cast<ShapedType>(dataFlowEdgeOp.getType()),
                                    sources.size()),
-        dataFlowEdgeOp, symbolTable, rewriter, factorPropagation,
-        shardingGroupMap);
+        dataFlowEdgeOp, symbolTable, rewriter, factorPropagation, params);
   }
 
 private:
@@ -517,6 +523,8 @@ class PropagateManualComputationOp
                                 PatternRewriter& rewriter) const override {
     bool updated = false;
 
+    PropagationParams params{shardingGroupMap};
+
     // 1. Propagate between the operands of the `ManualComputationOp` and the
     //    block arguments (specifically the `in_shardings`, but we use the op's
     //    block arguments as an alias for them).
@@ -547,7 +555,7 @@ class PropagateManualComputationOp
                 createIdentityShardingRule(
                     cast<RankedTensorType>(operand.getType())),
                 manualComputationOp, symbolTable, &rewriter, factorPropagation,
-                shardingGroupMap)
+                params)
                 .succeeded();
     }
 
@@ -575,7 +583,7 @@ class PropagateManualComputationOp
                 createIdentityShardingRule(
                     cast<RankedTensorType>(opResult.getType())),
                 manualComputationOp, symbolTable, &rewriter, factorPropagation,
-                shardingGroupMap)
+                params)
                 .succeeded();
     }
 
@@ -604,12 +612,13 @@ class PropagatePropagationBarrier
 
   LogicalResult matchAndRewrite(PropagationBarrierOp propagationBarrierOp,
                                 PatternRewriter& rewriter) const override {
+    PropagationParams params{shardingGroupMap};
     return propagateTensorShardings(
         propagationBarrierOp.getInput(), propagationBarrierOp.getResult(),
         createIdentityShardingRule(
             cast<RankedTensorType>(propagationBarrierOp.getType())),
-        propagationBarrierOp, symbolTable, rewriter, factorPropagation,
-        shardingGroupMap, propagationBarrierOp.getAllowedDirection());
+        propagationBarrierOp, symbolTable, rewriter, factorPropagation, params,
+        propagationBarrierOp.getAllowedDirection());
   }
 
 private:
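Each rewrite pattern above now builds its parameters with `PropagationParams params{shardingGroupMap};`: aggregate initialization binds the reference member and leaves `meshName`, `mesh`, and `notifyOpModified` at their defaults until `propagateTensorShardings` fills them in. A toy sketch of that construction and its lifetime caveat, with `std::string` standing in for the map type:

```cpp
#include <iostream>
#include <string>

struct Params {
  const std::string& groupMap;  // reference member: no copy of the map
  std::string meshName = "";    // overwritten deeper in the call chain
};

int main() {
  std::string groupMap = "{t0, t1}";
  // Mirrors `PropagationParams params{shardingGroupMap};` in each pattern:
  // brace-init binds the reference; every other member keeps its default.
  Params params{groupMap};
  std::cout << params.groupMap << " / '" << params.meshName << "'\n";
  // Because `params` holds a reference, it must not outlive `groupMap`;
  // building a fresh struct inside each matchAndRewrite keeps lifetimes local.
}
```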