Copies supersede OptimizationBarrier #20440

Open
stephen-huan opened this issue Dec 11, 2024 · 1 comment
stephen-huan commented Dec 11, 2024

Consider the JAX function

@partial(jit, donate_argnums=0)
def f(x: Array) -> tuple[Array, Array]:
    y = x[0, 0]
    x = x.at[0, 0].add(1)
    return x, y

Since XLA controls scheduling, it should schedule the slice first and the in-place update second, which avoids an unnecessary copy. However, specifically on the CPU backend it copies twice instead, generating

ENTRY %main.13 (Arg_0.1: f32[10000,10000]) -> (f32[10000,10000], f32[]) {
  %Arg_0.1 = f32[10000,10000]{1,0} parameter(0), metadata={op_name="x"}
  %copy.1 = f32[10000,10000]{1,0} copy(f32[10000,10000]{1,0} %Arg_0.1)
  %copy = f32[10000,10000]{1,0} copy(f32[10000,10000]{1,0} %copy.1)
  %add_dynamic-update-slice_fusion = f32[10000,10000]{1,0} fusion(f32[10000,10000]{1,0} %copy), kind=kLoop, calls=%fused_computation.1, metadata={op_name="jit(g)/jit(main)/scatter-add" source_file="..." source_line=30}
  %slice_bitcast_fusion = f32[] fusion(f32[10000,10000]{1,0} %copy.1), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(g)/jit(main)/squeeze" source_file="..." source_line=29}
  ROOT %tuple.4 = (f32[10000,10000]{1,0}, f32[]) tuple(f32[10000,10000]{1,0} %add_dynamic-update-slice_fusion, f32[] %slice_bitcast_fusion)
}

(I'm not sure why it needs to make two copies here instead of just one, but the important part is that it copies at all.)
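To make "it copies at all" easy to check mechanically, here is a minimal sketch that lists the `copy` instructions in an HLO text dump. The helper name `find_copies` and the regular expression are illustrative assumptions about the printed text format, not part of any XLA or JAX interface, and the sample is an abbreviated version of the dump above.

```python
import re

# Matches lines such as `%copy.1 = f32[4,4]{1,0} copy(...)` and captures the
# instruction name. This pattern is an assumption about the printed HLO text
# format, not a stable XLA interface.
COPY_RE = re.compile(
    r"^\s*(?:ROOT\s+)?(%?[\w.\-]+)\s*=\s*\S+\s+copy\(", re.MULTILINE
)


def find_copies(hlo_text: str) -> list[str]:
    """Return the names of all instructions whose opcode is `copy`."""
    return COPY_RE.findall(hlo_text)


# Abbreviated version of the CPU dump above (metadata and fusions omitted).
SAMPLE = """\
ENTRY %main.13 (Arg_0.1: f32[10000,10000]) -> (f32[10000,10000], f32[]) {
  %Arg_0.1 = f32[10000,10000]{1,0} parameter(0)
  %copy.1 = f32[10000,10000]{1,0} copy(f32[10000,10000]{1,0} %Arg_0.1)
  %copy = f32[10000,10000]{1,0} copy(f32[10000,10000]{1,0} %copy.1)
  ROOT %tuple.4 = (f32[10000,10000]{1,0}, f32[]) tuple(...)
}
"""

print(find_copies(SAMPLE))  # ['%copy.1', '%copy']
```

The same helper can be applied to the string returned by `f.lower(x).compile().as_text()` to compare backends or flag settings.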

By the semantics of lax.optimization_barrier, I would expect that introducing an explicit dependency of x on y would force the slice to happen first, and that the liveness analysis would then kick in and remove the copies.

@partial(jit, donate_argnums=0)
def f(x: Array) -> tuple[Array, Array]:
    y = x[0, 0]
    x, y = lax.optimization_barrier((x, y))
    x = x.at[0, 0].add(1)
    return x, y

However, XLA still introduces the copies and reorders the operations, so the generated code is the same as that shown above. This seems to violate the scheduling control one expects from optimization_barrier.

Note that for this particular example, setting the XLA flag --xla_cpu_copy_insertion_use_region_analysis=true removes the copy and generates

ENTRY %main.13 (Arg_0.1: f32[10000,10000]) -> (f32[10000,10000], f32[]) {
  %Arg_0.1 = f32[10000,10000]{1,0} parameter(0), sharding={replicated}, metadata={op_name="x"}
  %slice_bitcast_fusion = f32[] fusion(f32[10000,10000]{1,0} %Arg_0.1), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(g)/jit(main)/squeeze" source_file="..." source_line=28}
  %add_dynamic-update-slice_fusion = f32[10000,10000]{1,0} fusion(f32[10000,10000]{1,0} %Arg_0.1), kind=kLoop, calls=%fused_computation.1, control-predecessors={%slice_bitcast_fusion}, metadata={op_name="jit(g)/jit(main)/scatter-add" source_file="..." source_line=30}
  ROOT %tuple.4 = (f32[10000,10000]{1,0}, f32[]) tuple(f32[10000,10000]{1,0} %add_dynamic-update-slice_fusion, f32[] %slice_bitcast_fusion)
}

as expected, with or without optimization_barrier. Also, using a GPU device generates the copyless

ENTRY %main.13 (Arg_0.1.0: f32[10000,10000]) -> (f32[10000,10000], f32[]) {
  %Arg_0.1.0 = f32[10000,10000]{1,0} parameter(0), metadata={op_name="x"}
  %wrapped_slice = f32[1,1]{1,0} fusion(f32[10000,10000]{1,0} %Arg_0.1.0), kind=kLoop, calls=%wrapped_slice_computation
  %bitcast.43.0 = f32[] bitcast(f32[1,1]{1,0} %wrapped_slice)
  %loop_dynamic_update_slice_fusion = f32[10000,10000]{1,0} fusion(f32[10000,10000]{1,0} %Arg_0.1.0), kind=kLoop, calls=%fused_dynamic_update_slice, control-predecessors={%wrapped_slice}, metadata={op_name="jit(g)/jit(main)/scatter-add" source_file="..." source_line=30}
  ROOT %tuple.5 = (f32[10000,10000]{1,0}, f32[]) tuple(f32[10000,10000]{1,0} %loop_dynamic_update_slice_fusion, f32[] %bitcast.43.0)
}

also with or without optimization_barrier. Finally, the reverse explicit schedule

@partial(jit, donate_argnums=0)
def f(x: Array) -> tuple[Array, Array]:
    z = x.at[0, 0].add(1)
    z, x = lax.optimization_barrier((z, x))
    y = x[0, 0]
    return x, y

which should introduce a copy, does not introduce one with --xla_cpu_copy_insertion_use_region_analysis=true.
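For anyone trying the flag, a minimal sketch of setting it from Python without clobbering flags that are already present. The helper `add_xla_flag` is hypothetical, and the sketch assumes that XLA_FLAGS has to be set before jax is imported in order to take effect.

```python
import os


def add_xla_flag(flag: str) -> str:
    """Append `flag` to XLA_FLAGS unless it is already present."""
    current = os.environ.get("XLA_FLAGS", "").split()
    if flag not in current:
        current.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(current)
    return os.environ["XLA_FLAGS"]


# Assumption: this runs before the first `import jax` in the process.
add_xla_flag("--xla_cpu_copy_insertion_use_region_analysis=true")
print(os.environ["XLA_FLAGS"])
```

Calling the helper twice with the same flag leaves the variable unchanged, so it is safe to put at the top of a script that may be re-imported.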

I'm a bit confused about why the flag workaround works now, since region analysis was introduced more than 3 years ago in 92292d1. The core logic of RemoveUnnecessaryCopies and TryElideCopy doesn't seem to have changed much in that time either. Rather, what has changed recently is that the flag xla_cpu_copy_insertion_use_region_analysis was added to CPU (disabled by default) (#18521) and region analysis was disabled on GPU (#14680). Is there some context I'm missing?

(Originally reported in the discussion jax-ml/jax#19165 and JAX issue jax-ml/jax#25399.)

@stephen-huan
Author

For an example which is not fixed with --xla_cpu_copy_insertion_use_region_analysis=true, consider the toy implementation of jax.numpy.roll(a, shift=1) shown below.

import os
from functools import partial

os.environ["XLA_FLAGS"] = "--xla_cpu_copy_insertion_use_region_analysis=true"

import jax
import jax.numpy as jnp
from jax import Array, jit, lax


@partial(jit, donate_argnums=0)
def roll1(a: Array) -> Array:
    """Roll with shift=1."""
    n = a.size
    x = a[n - 1]
    a, x = lax.optimization_barrier((a, x))
    a = lax.fori_loop(1, n, lambda i, a: a.at[n - i].set(a[n - 1 - i]), a)
    a = a.at[0].set(x)
    return a


if __name__ == "__main__":
    x = jnp.arange(100)
    print(jax.make_jaxpr(roll1)(x))
    lowered = roll1.lower(x)
    compiled = lowered.compile()
    print(compiled.as_text())
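The update order in roll1 can be mirrored in plain Python, which makes it easier to see why the last element has to be read before the loop overwrites it. This is only a reference sketch of the same steps on a list; `roll1_reference` is a hypothetical name, and the empty-input guard is an addition that roll1 itself does not have.

```python
def roll1_reference(a: list[int]) -> list[int]:
    """Plain-Python mirror of roll1: shift every element right by one."""
    n = len(a)
    if n == 0:  # guard added for the sketch only
        return []
    a = list(a)  # work on a copy so the caller's list is untouched
    x = a[n - 1]  # read the last element before it is overwritten
    # Same index pattern as the fori_loop body, walking from the end.
    for i in range(1, n):
        a[n - i] = a[n - 1 - i]
    a[0] = x  # the saved last element wraps around to the front
    return a


print(roll1_reference([0, 1, 2, 3]))  # [3, 0, 1, 2]
```

Because each step only reads a slot that has not yet been written, the loop can in principle run in place on the donated buffer, which is the behaviour the barrier in roll1 is meant to encourage.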

Running roll1 on CPU copies the input twice (%copy.2 and %copy.9 below).

ENTRY %main.62 (Arg_0.1: s32[100]) -> s32[100] {
  %Arg_0.1 = s32[100]{0} parameter(0), metadata={op_name="a"}
  %copy.2 = s32[100]{0} copy(s32[100]{0} %Arg_0.1)
  %constant.3 = s32[] constant(0)
  %copy.10 = s32[] copy(s32[] %constant.3)
  %tuple.28 = (s32[], s32[100]{0}) tuple(s32[] %copy.10, s32[100]{0} %copy.2)
  %while.4 = (s32[], s32[100]{0}) while((s32[], s32[100]{0}) %tuple.28), condition=%region_2.47.clone.clone, body=%region_0.36.clone.clone.clone, metadata={op_name="jit(roll1)/jit(main)/while" source_file="..." source_line=17}, backend_config={"known_trip_count":{"n":"99"}}
  %get-tuple-element.85 = s32[100]{0} get-tuple-element((s32[], s32[100]{0}) %while.4), index=1, metadata={op_name="jit(roll1)/jit(main)/while" source_file="..." source_line=17}
  %bitcast_dynamic-update-slice_fusion = s32[100]{0} fusion(s32[100]{0} %get-tuple-element.85, s32[100]{0} %Arg_0.1), kind=kLoop, calls=%fused_computation.2, metadata={op_name="jit(roll1)/jit(main)/scatter" source_file="..." source_line=18}
  ROOT %copy.9 = s32[100]{0} copy(s32[100]{0} %bitcast_dynamic-update-slice_fusion)
}

Running on GPU does not copy the input.

ENTRY %main.62 (Arg_0.1.0: s32[100]) -> s32[100] {
  %constant_3_0 = s32[] constant(0)
  %Arg_0.1.0 = s32[100]{0} parameter(0), metadata={op_name="a"}
  %copy.11 = s32[] copy(s32[] %constant_3_0)
  %wrapped_slice = s32[1]{0} fusion(s32[100]{0} %Arg_0.1.0), kind=kLoop, calls=%wrapped_slice_computation, metadata={op_name="jit(roll1)/jit(main)/slice" source_file="..." source_line=15}
  %bitcast.214.0 = s32[] bitcast(s32[1]{0} %wrapped_slice)
  %tuple.7.0 = (s32[100]{0}, s32[]) tuple(s32[100]{0} %Arg_0.1.0, s32[] %bitcast.214.0), metadata={op_name="jit(roll1)/jit(main)/optimization_barrier" source_file="..." source_line=16}
  %get-tuple-element.10.0 = s32[] get-tuple-element((s32[100]{0}, s32[]) %tuple.7.0), index=1, metadata={op_name="jit(roll1)/jit(main)/optimization_barrier" source_file="..." source_line=16}
  %get-tuple-element.91 = s32[100]{0} get-tuple-element((s32[100]{0}, s32[]) %tuple.7.0), index=0, metadata={op_name="jit(roll1)/jit(main)/optimization_barrier" source_file="..." source_line=16}
  %tuple.30 = (s32[], s32[100]{0}) tuple(s32[] %copy.11, s32[100]{0} %get-tuple-element.91)
  %while.3.0 = (s32[], s32[100]{0}) while((s32[], s32[100]{0}) %tuple.30), condition=%region_2.47.clone.clone, body=%region_0.36.clone.clone.sunk.clone, metadata={op_name="jit(roll1)/jit(main)/while" source_file="..." source_line=17}, backend_config={"known_trip_count":{"n":"99"}}
  %get-tuple-element.74.0 = s32[100]{0} get-tuple-element((s32[], s32[100]{0}) %while.3.0), index=1, metadata={op_name="jit(roll1)/jit(main)/while" source_file="..." source_line=17}
  ROOT %loop_dynamic_update_slice_fusion.1 = s32[100]{0} fusion(s32[100]{0} %get-tuple-element.74.0, s32[] %get-tuple-element.10.0), kind=kLoop, calls=%fused_dynamic_update_slice.1, metadata={op_name="jit(roll1)/jit(main)/scatter" source_file="..." source_line=18}
}

@NaiyerRizz NaiyerRizz self-assigned this Dec 17, 2024