Count the sharding as compatible if the factors that need replication are unsharded across all operands and results.

PiperOrigin-RevId: 717871606
Google-ML-Automation authored and copybara-github committed Jan 21, 2025
1 parent bca5b40 commit 1766517
Showing 2 changed files with 72 additions and 11 deletions.
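The rule the new check implements can be stated independently of the Shardy data structures: every factor must be sharded the same way on every operand and result that uses it, and any factor the sharding rule marks as needing replication must be unsharded everywhere. Below is a minimal, self-contained C++ sketch of that rule, not the actual Shardy API; the simplified types and the reduced signature of hasCompatibleFactorShardings are illustrative stand-ins.

// Minimal sketch, not the Shardy implementation: a factor sharding is just the
// list of axis names it is sharded on, and a tensor maps factor indices to
// factor shardings.
#include <cstdint>
#include <cstdio>
#include <map>
#include <set>
#include <string>
#include <vector>

using FactorSharding = std::vector<std::string>;  // axis names
using TensorFactorShardings = std::map<int64_t, FactorSharding>;

bool hasCompatibleFactorShardings(
    const std::vector<TensorFactorShardings>& operandsAndResults,
    const std::set<int64_t>& needReplicationFactors) {
  std::map<int64_t, FactorSharding> common;
  // Seed the common sharding with an empty (unsharded) entry for every factor
  // that needs replication, so any sharded occurrence of it is a mismatch.
  for (int64_t factor : needReplicationFactors) {
    common[factor] = FactorSharding{};
  }
  for (const TensorFactorShardings& tensor : operandsAndResults) {
    for (const auto& [factor, axes] : tensor) {
      auto [it, inserted] = common.try_emplace(factor, axes);
      if (!inserted && it->second != axes) return false;
    }
  }
  return true;
}

int main() {
  // stablehlo.sort along dimension 0: factor 0 (the sorting dimension) needs
  // replication. Sharded as [{"x"}, {"y"}, {}] on both input and result.
  TensorFactorShardings sorted = {{0, {"x"}}, {1, {"y"}}, {2, {}}};
  std::printf("same sharding, sorting dim sharded: %d\n",
              hasCompatibleFactorShardings({sorted, sorted}, {0}));  // prints 0
  // Sharded as [{}, {"y"}, {}]: the sorting dimension is unsharded.
  TensorFactorShardings replicated = {{0, {}}, {1, {"y"}}, {2, {}}};
  std::printf("sorting dim unsharded: %d\n",
              hasCompatibleFactorShardings({replicated, replicated}, {0}));  // prints 1
}

These two cases correspond to the new @sort_input_and_output_shardings_are_same_on_sorting_dimension and @sort_compatible tests added in the second file below.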
36 changes: 25 additions & 11 deletions shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
@@ -62,11 +62,18 @@ bool hasOverflowAxes(const ShardingProjection& projection) {
 
 // Checks if factor sharding is compatible, that is, it satisfies:
 // 1. Factors are sharded the same way across operands and results.
+// 2. Factors that need replication are unsharded.
 //
 // Assumes factor shardings do not have overflow axes.
 // TODO(enver): Handle the case when some factor shardings have overflow axes.
-bool hasCompatibleFactorShardings(const ShardingProjection& projection) {
+bool hasCompatibleFactorShardings(const ShardingProjection& projection,
+                                  OpShardingRuleAttr shardingRule) {
   FactorIndexToSharding factorIndexToCommonSharding;
+  // Factors that need replication should be unsharded across all operands and
+  // results in order for it to have a compatible sharding.
+  for (int64_t factorIndex : shardingRule.getNeedReplicationFactors()) {
+    factorIndexToCommonSharding[factorIndex] = FactorSharding{};
+  }
   for (const TensorFactorShardings& tensorFactorSharding :
        llvm::concat<const TensorFactorShardings>(projection.getOperands(),
                                                  projection.getResults())) {
@@ -535,15 +542,12 @@ struct InsertExplicitReshardsPass
     // Return without inserting reshards for operations with special
     // dimensions.
     // TODO(enver): Insert explicit reshards if special dimensions are
-    // unsharded, or all speical dimensions need replication and annotated as
-    // such on the sharding rule.
-    if (isa<stablehlo::CholeskyOp, stablehlo::ReverseOp,
-            stablehlo::BitcastConvertOp, stablehlo::BroadcastInDimOp,
-            stablehlo::ConcatenateOp, stablehlo::DynamicSliceOp,
-            stablehlo::DynamicUpdateSliceOp, stablehlo::PadOp,
-            stablehlo::SliceOp, stablehlo::SortOp, stablehlo::TransposeOp,
-            stablehlo::TriangularSolveOp, stablehlo::FftOp,
-            stablehlo::ReduceWindowOp, stablehlo::ScatterOp,
+    // unsharded.
+    // TODO(enver): Add need replication factors to fft.
+    if (isa<stablehlo::ReverseOp, stablehlo::BroadcastInDimOp,
+            stablehlo::DynamicSliceOp, stablehlo::DynamicUpdateSliceOp,
+            stablehlo::PadOp, stablehlo::SliceOp, stablehlo::TransposeOp,
+            stablehlo::FftOp, stablehlo::ReduceWindowOp, stablehlo::ScatterOp,
             stablehlo::SelectAndScatterOp, stablehlo::GatherOp,
             stablehlo::ReshapeOp, stablehlo::ConvolutionOp,
             stablehlo::CustomCallOp, stablehlo::ReduceOp,
@@ -554,7 +558,17 @@ struct InsertExplicitReshardsPass
     }
 
     // Checks if factors are sharded the same way across operands and results.
-    if (hasCompatibleFactorShardings(shardingProjection)) {
+    if (hasCompatibleFactorShardings(shardingProjection, shardingRule)) {
       return;
     }
+
+    // Return without inserting reshards for operations with factors that need
+    // replication.
+    // TODO(enver): Insert explicit reshards also for the case that the
+    // factors that need replication are sharded.
+    if (isa<stablehlo::CholeskyOp, stablehlo::BitcastConvertOp,
+            stablehlo::ConcatenateOp, stablehlo::SortOp,
+            stablehlo::TriangularSolveOp>(op)) {
+      return;
+    }
 
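Taken together, the hunks above change the per-operation flow of the pass: the ops whose replication-needing factors are now encoded in their sharding rules (Cholesky, BitcastConvert, Concatenate, Sort, TriangularSolve) reach the compatibility check first, and the pass only returns early for them when that check fails. The following is a hedged sketch of that control flow; the Action enum and the helper names (decideForOp, isSpecialDimOp, needsReplicationOp, shardingsCompatible) are illustrative stand-ins, not Shardy identifiers.

// Hypothetical summary of the per-op decision order established above.
enum class Action { kSkipOp, kKeepAsIs, kInsertReshards };

Action decideForOp(bool isSpecialDimOp,        // first isa<> list: Reverse, Pad, Fft, ...
                   bool needsReplicationOp,    // second isa<> list: Cholesky, Sort, ...
                   bool shardingsCompatible) { // hasCompatibleFactorShardings(...)
  if (isSpecialDimOp) return Action::kSkipOp;         // still handled by an early return
  if (shardingsCompatible) return Action::kKeepAsIs;  // nothing to reshard
  if (needsReplicationOp) return Action::kSkipOp;     // TODO in the diff: reshard these too
  return Action::kInsertReshards;                     // remaining ops get explicit reshards
}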
@@ -542,6 +542,53 @@ func.func @sort_all_other_dims_size_one(%arg0: tensor<1x4x1xi32> {sdy.sharding =
   return %0 : tensor<1x4x1xi32>
 }
 
+// CHECK-LABEL: func @sort_single_input_output
+func.func @sort_single_input_output(%arg0: tensor<4x32x8xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}, {}]>}) -> (tensor<4x32x8xi32>) {
+  // CHECK-NOT: sdy.reshard
+  %0 = "stablehlo.sort"(%arg0) ({
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
+    %1 = stablehlo.compare GT, %arg2, %arg3 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    stablehlo.return %1 : tensor<i1>
+  }) {dimension = 0 : i64, is_stable = true} : (tensor<4x32x8xi32>) -> (tensor<4x32x8xi32>)
+  return %0 : tensor<4x32x8xi32>
+}
+
+// CHECK-LABEL: func @sort_compatible
+func.func @sort_compatible(%arg0: tensor<4x32x8xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}, {}]>}) -> (tensor<4x32x8xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}, {}]>}) {
+  // CHECK-NOT: sdy.reshard
+  %0 = "stablehlo.sort"(%arg0) <{dimension = 0 : i64, is_stable = true}> ({
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
+    %1 = stablehlo.compare GT, %arg2, %arg3 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    stablehlo.return %1 : tensor<i1>
+  }) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"y"}, {}]>]>} : (tensor<4x32x8xi32>) -> (tensor<4x32x8xi32>)
+  return %0 : tensor<4x32x8xi32>
+}
+
+
+// CHECK-LABEL: func @sort_input_and_output_shardings_are_same_on_sorting_dimension
+func.func @sort_input_and_output_shardings_are_same_on_sorting_dimension(%arg0: tensor<4x32x8xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}, {}]>}) -> (tensor<4x32x8xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}, {}]>}) {
+  // CHECK-NOT: sdy.reshard
+  // TODO(enver): Support cases where factors that need replication are sharded in the same way; this still requires resharding since the sorting dimension has to be fully replicated.
%0 = "stablehlo.sort"(%arg0) <{dimension = 0 : i64, is_stable = true}> ({
^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
%1 = stablehlo.compare GT, %arg2, %arg3 : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %1 : tensor<i1>
}) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}, {}]>]>} : (tensor<4x32x8xi32>) -> (tensor<4x32x8xi32>)
return %0 : tensor<4x32x8xi32>
}


// CHECK-LABEL: func @sort_input_and_output_shardings_are_different_on_sorting_dimension
func.func @sort_input_and_output_shardings_are_different_on_sorting_dimension(%arg0: tensor<4x32x8xi32> {sdy.sharding = #sdy.sharding<@mesh_xyz, [{"x"}, {"z"}, {}]>}) -> (tensor<4x32x8xi32> {sdy.sharding = #sdy.sharding<@mesh_xyz, [{"y"}, {"z"}, {}]>}) {
// CHECK-NOT: sdy.reshard
%0 = "stablehlo.sort"(%arg0) <{dimension = 0 : i64, is_stable = true}> ({
^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
%1 = stablehlo.compare GT, %arg2, %arg3 : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %1 : tensor<i1>
}) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_xyz, [{"y"}, {"z"}, {}]>]>} : (tensor<4x32x8xi32>) -> (tensor<4x32x8xi32>)
return %0 : tensor<4x32x8xi32>
}

// CHECK-LABEL: func @transpose
func.func @transpose(%arg0: tensor<256x32x64x100xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}, {}, {}]>}) -> tensor<100x32x256x64xf32> {
// CHECK-NOT: sdy.reshard