
XLA Sharded Tensor retrieval of true-sharded outputs may use a distorted mesh shape based on the output sharding rather than the global mesh #9733

@jameszianxuTT

Description

torch_xla collects sharded outputs via its ReplicateShardedData mechanism.

For an output sharded on a 2x4 {batch, model} mesh as {{}, {model}}, which yields a GSPMD shard spec of {devices=[1,4,2]<=[8] last_tile_dim_replicate}, this collection mechanism is compiled on a 4x2 mesh. The mesh shape used for the dummy replicate execution appears to be derived from the sharding of the tensor being collected rather than from the global mesh.
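
A minimal repro sketch of the setup described above, assuming torch_xla's standard SPMD API on an 8-device system; the tensor shape, variable names, and exact flow are illustrative, not taken from the original report:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

# Logical 2x4 mesh over 8 devices, axes (batch, model).
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (2, 4), ('batch', 'model'))

t = torch.randn(32, 32).to(xm.xla_device())

# Shard only the second dim on 'model', i.e. the {{}, {model}} spec,
# which lowers to {devices=[1,4,2]<=[8] last_tile_dim_replicate}.
xs.mark_sharding(t, mesh, (None, 'model'))

# Fetching the sharded output back to host goes through the
# ReplicateShardedData computation, where the 4x2 mesh appears.
t_host = t.cpu()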

The result is this failure in MLIR:

// -----// IR Dump Before AnalyzeMeshPass (analyze-mesh) ('builtin.module' operation: @ReplicateShardedData.6) //----- //
module @ReplicateShardedData.6 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  sdy.mesh @mesh = <["_axis_0"=4, "_axis_1"=2]>
  func.func @main(%arg0: tensor<32x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"_axis_0"}]>, ttcore.argument_type = #ttcore.argument_type<input>, ttcore.shard_status = #ttcore.shard_status<presharded>}) -> (tensor<32x32xf32> {ttcore.shard_status = #ttcore.shard_status<presharded>}) {
    %0 = sdy.sharding_constraint %arg0 <@mesh, [{}, {}]> : tensor<32x32xf32>
    return %0 : tensor<32x32xf32>
  }
}


error: Mesh is not valid
// -----// IR Dump After AnalyzeMeshPass Failed (analyze-mesh) ('builtin.module' operation: @ReplicateShardedData.6) //----- //
module @ReplicateShardedData.6 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  sdy.mesh @mesh = <["_axis_0"=4, "_axis_1"=2]>
  func.func @main(%arg0: tensor<32x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"_axis_0"}]>, ttcore.argument_type = #ttcore.argument_type<input>, ttcore.shard_status = #ttcore.shard_status<presharded>}) -> (tensor<32x32xf32> {ttcore.shard_status = #ttcore.shard_status<presharded>}) {
    %0 = sdy.sharding_constraint %arg0 <@mesh, [{}, {}]> : tensor<32x32xf32>
    return %0 : tensor<32x32xf32>
  }
}

The failure stems from the 4x2 mesh shape being unsupported: the Error: Mesh is not valid is raised by the mesh validity check, which only accepts certain specific mesh configurations.
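
For illustration only, a sketch of the whitelist-style check implied here; the real check lives in the tt-mlir AnalyzeMeshPass and its accepted set of shapes may differ from the assumed one below:

# Hypothetical mesh validity check; SUPPORTED_MESH_SHAPES is an assumption,
# not the actual set accepted by tt-mlir.
SUPPORTED_MESH_SHAPES = {(1, 1), (1, 2), (1, 8), (2, 4), (1, 32), (8, 4)}

def is_valid_mesh(shape):
    return tuple(shape) in SUPPORTED_MESH_SHAPES

print(is_valid_mesh((2, 4)))  # True: the user's global mesh
print(is_valid_mesh((4, 2)))  # False: the inferred mesh -> "Mesh is not valid"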

This could be fixed in two ways:

  • legalize the 4x2 mesh shape
  • make torch_xla respect the original mesh shape in its dummy replicate execution mechanism, and investigate how the mesh shape is inferred from the sharding of the tensor being replicated (see the sketch after this list)
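
For the second option, a sketch of how the 4x2 shape plausibly falls out of the shard spec rather than the global mesh; this is a guess at the inference, not torch_xla's actual code:

# The spec {devices=[1,4,2]<=[8] last_tile_dim_replicate} carries only the
# tile-assignment dims below; the original logical 2x4 mesh is not in it.
tile_assignment_dims = [1, 4, 2]

# Dropping the degenerate size-1 dim leaves a 4-way tiled dim plus the
# 2-way replicated last dim, i.e. a reconstructed 4x2 mesh.
inferred_mesh = tuple(d for d in tile_assignment_dims if d > 1)
print(inferred_mesh)  # (4, 2), not the original (2, 4)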
