Updating an array with multiple indices crashes the compilation pipeline #1446

Open
positr0nium opened this issue Jan 12, 2025 · 5 comments
Labels
documentation Improvements or additions to documentation

Comments

@positr0nium

Hi,

When using the .at method with multiple indices, the MLIR pipeline crashes. Code example:

```python
from catalyst import qjit
from jax import jit
import pennylane as qml
import jax.numpy as jnp

def test_function():
    N = 10

    jax_array = jnp.zeros(N, dtype=jnp.int64)
    idx_array = jnp.arange(N, dtype=jnp.int64)
    init_val = jnp.zeros(N, dtype=jnp.int64)

    jax_array = jax_array.at[idx_array].set(init_val)
    return jax_array

jitted_test_function = jit(test_function)
# This works
print(jitted_test_function())

qml_function = qml.qnode(qml.device("lightning.qubit", wires=0))(test_function)
# This crashes
qjitted_test_function = qjit(qml_function)
```
@dime10 dime10 added the bug Something isn't working label Jan 13, 2025
@dime10
Contributor

dime10 commented Jan 13, 2025

Thanks for submitting the report @positr0nium! The case in your example, where all possible array indices are passed to the at[] method, is a known limitation. At the time, we decided it wasn't worth fixing, since in that case there is little reason to use at[] in the first place: one can simply operate on the entire array.

Could you confirm whether you find it useful to use at[] with all indices? Is it an edge case where the code normally would provide fewer indices but sometimes provides all of them, or some other scenario?

That said, the compiler shouldn't crash even in this case, so a better error message would certainly help.
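To make the whole-array case concrete, here is a minimal pure-Python model of what x.at[idx].set(vals) computes for a 1-D array. The function name `at_set` is hypothetical; it only mirrors the update semantics and is not JAX or Catalyst code. When `idx` covers every position in order (as jnp.arange(N) does), the result is just `vals`, so the indexed update can be replaced by the update values themselves.

```python
def at_set(x, idx, vals):
    """Hypothetical reference model of ``x.at[idx].set(vals)`` for 1-D lists.

    Returns a new list and leaves ``x`` unchanged, mirroring JAX's
    out-of-place update semantics. Out-of-bounds handling is not modeled.
    """
    if len(idx) != len(vals):
        raise ValueError("idx and vals must have the same length")
    out = list(x)  # copy: JAX arrays are immutable
    for i, v in zip(idx, vals):
        out[i] = v
    return out


N = 10
x = [0] * N
idx = list(range(N))  # every index, in order (like jnp.arange(N))
vals = [7] * N

# When idx covers the whole array in order, the update is just ``vals``.
print(at_set(x, idx, vals) == vals)  # True
# The input list is not modified.
print(x == [0] * N)  # True
```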

@positr0nium
Author

positr0nium commented Jan 14, 2025

Passing fewer indices (N-1 instead of N) seems to crash too :(

```python
from catalyst import qjit
from jax import jit
import pennylane as qml
import jax.numpy as jnp

def test_function():
    N = 10

    jax_array = jnp.zeros(N, dtype=jnp.int64)
    idx_array = jnp.arange(N - 1, dtype=jnp.int64)
    init_val = jnp.zeros(N - 1, dtype=jnp.int64)

    jax_array = jax_array.at[idx_array].set(init_val)
    return jax_array

jitted_test_function = jit(test_function)
# This works
print(jitted_test_function())

qml_function = qml.qnode(qml.device("lightning.qubit", wires=0))(test_function)
# This crashes
qjitted_test_function = qjit(qml_function)
```

@josh146
Member

josh146 commented Jan 14, 2025

Just to confirm @positr0nium, do you get the same CompileError I get here?

```
CompileError: Compilation failed:
test_function:28:13: error: Indices are not unique and/or not sorted, unique boolean: 0, sorted boolean :0
      %10 = "stablehlo.scatter"(%0, %8, %9) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
            ^
test_function:28:13: note: see current operation:
%15 = "mhlo.scatter"(%2, %14, %1) <{indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  "mhlo.return"(%arg1) : (tensor<i64>) -> ()
}) : (tensor<10xi64>, tensor<9x1xi32>, tensor<9xi64>) -> tensor<10xi64>
test_function:28:13: error: Indices are not unique and/or not sorted, unique boolean: 0, sorted boolean :0
      %10 = "stablehlo.scatter"(%0, %8, %9) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
            ^
test_function:28:13: note: see current operation:
%15 = "mhlo.scatter"(%2, %14, %1) <{indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  "mhlo.return"(%arg1) : (tensor<i64>) -> ()
}) : (tensor<10xi64>, tensor<9x1xi32>, tensor<9xi64>) -> tensor<10xi64>
test_function:28:13: error: failed to legalize operation 'mhlo.scatter'
      %10 = "stablehlo.scatter"(%0, %8, %9) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
            ^
test_function:28:13: note: see current operation:
%16 = "mhlo.scatter"(%1, %15, %3) <{indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  "mhlo.return"(%arg1) : (tensor<i64>) -> ()
}) : (tensor<10xi64>, tensor<9x1xi32>, tensor<9xi64>) -> tensor<10xi64>
While processing 'FinalizingBufferize' pass While processing 'mlir::detail::OpToOpPassAdaptor' pass Failed to lower MLIR module
```
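The error reports that the scatter's indices are "not unique and/or not sorted" because the op was emitted with indices_are_sorted = false and unique_indices = false. One plausible reason a lowering may insist on uniqueness (an assumption on our part, not something the log states) is that an overwriting scatter with duplicate indices is order-dependent. The pure-Python sketch below illustrates this; the helpers `scatter_set` and `possible_results` are hypothetical and do not reproduce Catalyst's actual lowering.

```python
from itertools import permutations


def scatter_set(x, idx, vals, order):
    """Hypothetical model of an overwriting scatter.

    ``order`` is a permutation of range(len(idx)) giving the sequence in
    which the individual updates are applied.
    """
    out = list(x)
    for k in order:
        out[idx[k]] = vals[k]
    return out


def possible_results(x, idx, vals):
    """Set of distinct results over every possible update order."""
    n = len(idx)
    return {tuple(scatter_set(x, idx, vals, p)) for p in permutations(range(n))}


x = [0, 0, 0, 0]

# Unique indices: every application order gives the same answer.
print(sorted(possible_results(x, [0, 1, 2], [10, 20, 30])))  # [(10, 20, 30, 0)]

# Duplicate index 1: the result depends on which update is applied last.
print(sorted(possible_results(x, [0, 1, 1], [10, 20, 30])))  # [(10, 20, 0, 0), (10, 30, 0, 0)]
```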

@positr0nium
Author

Yes, this is the message I also receive.

@dime10
Contributor

dime10 commented Jan 16, 2025

Actually, this use case is supported, but it comes with restrictions that are difficult for us to check or document, since they relate to a JAX library function. Catalyst only supports updates via the .at[] method when the indices are sorted and unique, which is the case in your example. Since this cannot be checked at compile time in general, we rely on the user to provide the guarantee, like so:

```python
from catalyst import qjit
import jax.numpy as jnp

@qjit
def test_function():
    N = 10

    jax_array = jnp.zeros(N, dtype=jnp.int64)
    idx_array = jnp.arange(N - 1, dtype=jnp.int64)
    init_val = jnp.ones(N - 1, dtype=jnp.int64)

    # Promise the compiler that idx_array is sorted and contains no duplicates.
    return jax_array.at[idx_array].set(init_val, indices_are_sorted=True, unique_indices=True)
```

```
>>> test_function()
Array([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], dtype=int64)
```
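These flags are a promise rather than something that is checked: the JAX documentation for the .at property describes indices_are_sorted and unique_indices as letting the implementation assume those properties, so passing them for indices that are actually unsorted or repeated is not guaranteed to give correct results. When the index array is concrete (for example while debugging outside qjit, since traced values inside a compiled function cannot be inspected this way), a quick host-side check can confirm the promise first. The helper `indices_sorted_and_unique` below is hypothetical and not part of JAX or Catalyst; it relies on the fact that a strictly increasing sequence is both sorted and free of duplicates.

```python
def indices_sorted_and_unique(idx):
    """Hypothetical helper: True if ``idx`` is strictly increasing.

    Strict increase implies both ascending order (indices_are_sorted) and
    no repeats (unique_indices), so both flags would be justified.
    Works on lists, ranges, or concrete (non-traced) 1-D arrays.
    """
    idx = [int(i) for i in idx]
    return all(a < b for a, b in zip(idx, idx[1:]))


N = 10
print(indices_sorted_and_unique(range(N - 1)))  # True, like jnp.arange(N - 1)
print(indices_sorted_and_unique([0, 2, 1]))     # False: not sorted
print(indices_sorted_and_unique([0, 1, 1, 2]))  # False: repeated index
```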

@josh146 Do you know how we could document this better? Maybe here? We could also try improving the error message to explicitly instruct the user to use these flags.

@dime10 dime10 added documentation Improvements or additions to documentation and removed bug Something isn't working labels Jan 16, 2025