Torch-xla collects outputs via this mechanism (the `ReplicateShardedData` computation).
For some reason, an output sharded on a 2x4 {batch, model} mesh as {{}, {model}} (which lowers to the GSPMD shard spec {devices=[1,4,2]<=[8] last_tile_dim_replicate}) has its collection module compiled on a 4x2 mesh. I suspect the mesh shape computed for this dummy replicate execution is derived from the sharding of the tensor being collected rather than from the original mesh.
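For reference, a minimal repro sketch, assuming the standard torch_xla SPMD API; the tensor shape is illustrative and backend setup for the Tenstorrent device is omitted:

```python
# Minimal repro sketch, assuming the torch_xla SPMD API on 8 devices.
# Tensor shape and device setup are illustrative, not taken from the source.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

# Original 2x4 {batch, model} mesh over 8 devices.
mesh = xs.Mesh(np.arange(8), (2, 4), ('batch', 'model'))

t = torch.randn(32, 32).to(xm.xla_device())
# Shard as {{}, {model}}: dim 0 replicated, dim 1 split across 'model',
# which lowers to {devices=[1,4,2]<=[8] last_tile_dim_replicate}.
xs.mark_sharding(t, mesh, (None, 'model'))

# Transferring the value back to host routes through ReplicateShardedData,
# where the unexpected 4x2 mesh shows up.
print(t.cpu())
```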
The result is this failure in MLIR:
```mlir
// -----// IR Dump Before AnalyzeMeshPass (analyze-mesh) ('builtin.module' operation: @ReplicateShardedData.6) //----- //
module @ReplicateShardedData.6 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  sdy.mesh @mesh = <["_axis_0"=4, "_axis_1"=2]>
  func.func @main(%arg0: tensor<32x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"_axis_0"}]>, ttcore.argument_type = #ttcore.argument_type<input>, ttcore.shard_status = #ttcore.shard_status<presharded>}) -> (tensor<32x32xf32> {ttcore.shard_status = #ttcore.shard_status<presharded>}) {
    %0 = sdy.sharding_constraint %arg0 <@mesh, [{}, {}]> : tensor<32x32xf32>
    return %0 : tensor<32x32xf32>
  }
}
error: Mesh is not valid
// -----// IR Dump After AnalyzeMeshPass Failed (analyze-mesh) ('builtin.module' operation: @ReplicateShardedData.6) //----- //
module @ReplicateShardedData.6 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  sdy.mesh @mesh = <["_axis_0"=4, "_axis_1"=2]>
  func.func @main(%arg0: tensor<32x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"_axis_0"}]>, ttcore.argument_type = #ttcore.argument_type<input>, ttcore.shard_status = #ttcore.shard_status<presharded>}) -> (tensor<32x32xf32> {ttcore.shard_status = #ttcore.shard_status<presharded>}) {
    %0 = sdy.sharding_constraint %arg0 <@mesh, [{}, {}]> : tensor<32x32xf32>
    return %0 : tensor<32x32xf32>
  }
}
```

This fails because the 4x2 mesh shape is not supported: the mesh validity check only accepts some specific mesh configurations, so `AnalyzeMeshPass` rejects the module with `Error: Mesh is not valid`.
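For intuition only, the rejection behaves like a whitelist. A purely hypothetical stand-in for the check (the accepted shapes below are assumptions, not taken from the actual pass):

```python
# Hypothetical stand-in for the mesh validity check in AnalyzeMeshPass;
# the set of accepted mesh configurations here is an assumption.
SUPPORTED_MESH_SHAPES = {(1, 1), (1, 2), (1, 4), (1, 8), (2, 4)}

def check_valid_mesh(shape: tuple) -> None:
    if shape not in SUPPORTED_MESH_SHAPES:
        raise ValueError("Mesh is not valid")

check_valid_mesh((2, 4))  # accepted: the user's original mesh
check_valid_mesh((4, 2))  # raises ValueError: Mesh is not valid
```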
This could be fixed in two ways:
- legalize the 4x2 mesh shape, or
- make torch_xla respect the original mesh shape in its dummy replicate execution, and investigate how the mesh shape is inferred from the sharding of the tensor being replicated (see the sketch after this list).
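To make the suspected inference concrete, here is a purely illustrative sketch (the names and logic are assumptions, not actual torch_xla internals) of how deriving the mesh from the GSPMD spec alone reproduces the 4x2 mesh seen in the dump:

```python
# Hypothetical sketch (not actual torch_xla code) of deriving a replicate-
# execution mesh from the GSPMD sharding of the collected tensor rather than
# from the user's original mesh. Names are invented for illustration.
def mesh_from_gspmd_spec(tile_assignment_dims, last_tile_dim_replicate):
    """Derive a mesh shape from a tile assignment such as [1, 4, 2]."""
    dims = list(tile_assignment_dims)
    replicated = dims.pop() if last_tile_dim_replicate else 1
    # Keep only the axes that actually shard a tensor dimension, then append
    # the replication group as its own trailing mesh axis.
    mesh = [d for d in dims if d > 1]
    if replicated > 1:
        mesh.append(replicated)
    return mesh

# {devices=[1,4,2]<=[8] last_tile_dim_replicate}, originally a 2x4 mesh:
print(mesh_from_gspmd_spec([1, 4, 2], True))  # [4, 2] -- not the original (2, 4)
```

A derivation like this would explain why the failing module carries `sdy.mesh @mesh = <["_axis_0"=4, "_axis_1"=2]>` with the tensor sharded on the size-4 "_axis_0", exactly as in the dump above.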