feat: don't unroll Recurrence #1209

Merged: avik-pal merged 5 commits into main from ap/loop_rnn_reactant on Feb 11, 2025

Conversation

avik-pal (Member)

Needs EnzymeAD/Reactant.jl#565.

AD doesn't seem to work: XLA can't handle the `enzyme.init` ops in the lowered module. cc @wsmoses
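
For context, a minimal sketch of the kind of setup that produces the module below. It is illustrative, not the PR's test verbatim: `sumabs2` matches the function name in the IR, but the shapes, axis ordering, and exact call signatures are assumptions based on typical Lux/Reactant usage.

```julia
# Hypothetical reproducer: reverse-mode AD over a non-unrolled Recurrence
# compiled through Reactant. All API usage here is a hedged sketch.
using Lux, Reactant, Enzyme, Random

sumabs2(model, ps, st, x) = sum(abs2, first(model(x, ps, st)))

model = Recurrence(RNNCell(3 => 3))   # now lowers to stablehlo.while instead of unrolling
ps, st = Lux.setup(Random.default_rng(), model)
x = rand(Float32, 3, 2, 6)            # chosen to match tensor<6x2x3xf32>; axis ordering is a guess

ps_ra = Reactant.to_rarray(ps)
st_ra = Reactant.to_rarray(st)
x_ra = Reactant.to_rarray(x)

# Compiling the gradient emits the module below; its `enzyme.init` cache ops
# have no XLA lowering, which is where compilation fails.
grad_fn = @compile Enzyme.gradient(Reverse, sumabs2, Const(model), ps_ra,
                                   Const(st_ra), Const(x_ra))
```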

```mlir
module {
  func.func private @"diffeConst{typeof(sumabs2)}(Main.sumabs2)_autodiff"(%arg0: tensor<6x2x3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<3x3xf32>, %arg3: tensor<3xf32>, %arg4: tensor<3xf32>, %arg5: tensor<2xui64>, %arg6: tensor<f32>, %arg7: tensor<2xui64>, %arg8: tensor<3x3xf32>, %arg9: tensor<3x3xf32>, %arg10: tensor<3xf32>, %arg11: tensor<3xf32>) -> (tensor<6x2x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>) {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<3x2xf32>
    %c = stablehlo.constant dense<0> : tensor<i64>
    %c_0 = stablehlo.constant dense<6> : tensor<i64>
    %c_1 = stablehlo.constant dense<2> : tensor<i64>
    %c_2 = stablehlo.constant dense<1> : tensor<i64>
    %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<3x2xf32>
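    // The four enzyme.init ops below create tape caches that the forward loop
    // fills via enzyme.push and the reverse loop drains via enzyme.pop; XLA has
    // no lowering for them, hence the failure described above.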
    %0 = "enzyme.init"() : () -> !enzyme.Cache<tensor<3x2xf32>>
    %1 = "enzyme.init"() : () -> !enzyme.Cache<tensor<3x3xf32>>
    %2 = "enzyme.init"() : () -> !enzyme.Cache<tensor<2x3xf32>>
    %3 = "enzyme.init"() : () -> !enzyme.Cache<tensor<3x2xf32>>
    %4 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<6x2x3xf32>) -> tensor<3x2x6xf32>
    %5 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
    %6 = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
    %7 = stablehlo.slice %4 [0:3, 0:2, 0:1] : (tensor<3x2x6xf32>) -> tensor<3x2x1xf32>
    %8 = stablehlo.transpose %7, dims = [2, 1, 0] : (tensor<3x2x1xf32>) -> tensor<1x2x3xf32>
    %9 = stablehlo.reshape %8 : (tensor<1x2x3xf32>) -> tensor<2x3xf32>
    %10 = stablehlo.broadcast_in_dim %arg4, dims = [0] : (tensor<3xf32>) -> tensor<3x2xf32>
    %11 = stablehlo.dot_general %arg1, %9, contracting_dims = [0] x [1] : (tensor<3x3xf32>, tensor<2x3xf32>) -> tensor<3x2xf32>
    %12 = stablehlo.broadcast_in_dim %arg3, dims = [0] : (tensor<3xf32>) -> tensor<3x2xf32>
    %13 = stablehlo.add %11, %12 : tensor<3x2xf32>
    %14 = stablehlo.add %10, %13 : tensor<3x2xf32>
    %15 = stablehlo.tanh %14 : tensor<3x2xf32>
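    // Forward sweep: the first step was computed above (%15); the while runs the
    // remaining steps, pushing per-iteration values onto the enzyme caches.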
    %16:10 = stablehlo.while(%iterArg = %c, %iterArg_5 = %5, %iterArg_6 = %6, %iterArg_7 = %arg3, %iterArg_8 = %arg4, %iterArg_9 = %arg5, %iterArg_10 = %c_0, %iterArg_11 = %15, %iterArg_12 = %4, %iterArg_13 = %c) : tensor<i64>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<i64>, tensor<3x2xf32>, tensor<3x2x6xf32>, tensor<i64>
     cond {
      %41 = stablehlo.subtract %iterArg_10, %c_1 : tensor<i64>
      %42 = stablehlo.divide %41, %c_2 : tensor<i64>
      %43 = stablehlo.add %42, %c_2 : tensor<i64>
      %44 = stablehlo.compare  LT, %iterArg, %43 : (tensor<i64>, tensor<i64>) -> tensor<i1>
      stablehlo.return %44 : tensor<i1>
    } do {
      %41 = stablehlo.add %iterArg_13, %c_2 : tensor<i64>
      %42 = stablehlo.add %c_1, %iterArg : tensor<i64>
      %43 = stablehlo.subtract %42, %c_2 : tensor<i64>
      %44 = stablehlo.dynamic_slice %iterArg_12, %c, %c, %43, sizes = [3, 2, 1] : (tensor<3x2x6xf32>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<3x2x1xf32>
      %45 = stablehlo.transpose %44, dims = [2, 1, 0] : (tensor<3x2x1xf32>) -> tensor<1x2x3xf32>
      %46 = stablehlo.reshape %45 : (tensor<1x2x3xf32>) -> tensor<2x3xf32>
      "enzyme.push"(%1, %iterArg_6) : (!enzyme.Cache<tensor<3x3xf32>>, tensor<3x3xf32>) -> ()
      "enzyme.push"(%0, %iterArg_11) : (!enzyme.Cache<tensor<3x2xf32>>, tensor<3x2xf32>) -> ()
      %47 = stablehlo.dot_general %iterArg_6, %iterArg_11, contracting_dims = [1] x [0] : (tensor<3x3xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
      %48 = stablehlo.broadcast_in_dim %iterArg_8, dims = [0] : (tensor<3xf32>) -> tensor<3x2xf32>
      %49 = stablehlo.add %47, %48 : tensor<3x2xf32>
      "enzyme.push"(%2, %46) : (!enzyme.Cache<tensor<2x3xf32>>, tensor<2x3xf32>) -> ()
      %50 = stablehlo.dot_general %iterArg_5, %46, contracting_dims = [1] x [1] : (tensor<3x3xf32>, tensor<2x3xf32>) -> tensor<3x2xf32>
      %51 = stablehlo.broadcast_in_dim %iterArg_7, dims = [0] : (tensor<3xf32>) -> tensor<3x2xf32>
      %52 = stablehlo.add %50, %51 : tensor<3x2xf32>
      %53 = stablehlo.add %49, %52 : tensor<3x2xf32>
      "enzyme.push"(%3, %53) : (!enzyme.Cache<tensor<3x2xf32>>, tensor<3x2xf32>) -> ()
      %54 = stablehlo.tanh %53 : tensor<3x2xf32>
      %55 = stablehlo.add %iterArg, %c_2 : tensor<i64>
      stablehlo.return %55, %iterArg_5, %iterArg_6, %iterArg_7, %iterArg_8, %iterArg_9, %iterArg_10, %54, %iterArg_12, %41 : tensor<i64>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<i64>, tensor<3x2xf32>, tensor<3x2x6xf32>, tensor<i64>
    }
    %17 = stablehlo.abs %16#7 : tensor<3x2xf32>
    %18 = stablehlo.transpose %16#8, dims = [2, 1, 0] : (tensor<3x2x6xf32>) -> tensor<6x2x3xf32>
    %19 = stablehlo.transpose %16#1, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
    %20 = stablehlo.transpose %16#2, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
    %21 = stablehlo.transpose %arg9, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
    %22 = stablehlo.transpose %arg8, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
    %23 = stablehlo.broadcast_in_dim %arg6, dims = [] : (tensor<f32>) -> tensor<3x2xf32>
    %24 = stablehlo.multiply %23, %17 : tensor<3x2xf32>
    %25 = stablehlo.add %24, %24 : tensor<3x2xf32>
    %26 = stablehlo.compare  GE, %16#7, %cst_4 : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xi1>
    %27 = stablehlo.negate %25 : tensor<3x2xf32>
    %28 = stablehlo.select %26, %25, %27 : tensor<3x2xi1>, tensor<3x2xf32>
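    // Reverse sweep: pops the cached activations and accumulates gradients,
    // carrying the hidden-state cotangent (%iterArg_9) backwards through time.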
    %29:6 = stablehlo.while(%iterArg = %c, %iterArg_5 = %22, %iterArg_6 = %21, %iterArg_7 = %arg10, %iterArg_8 = %arg11, %iterArg_9 = %28) : tensor<i64>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3x2xf32>
     cond {
      %41 = stablehlo.compare  LT, %iterArg, %16#9 : (tensor<i64>, tensor<i64>) -> tensor<i1>
      stablehlo.return %41 : tensor<i1>
    } do {
      %41 = stablehlo.add %iterArg, %c_2 : tensor<i64>
      %42 = "enzyme.pop"(%3) : (!enzyme.Cache<tensor<3x2xf32>>) -> tensor<3x2xf32>
      %43 = stablehlo.tanh %42 : tensor<3x2xf32>
      %44 = stablehlo.multiply %43, %43 : tensor<3x2xf32>
      %45 = stablehlo.subtract %cst, %44 : tensor<3x2xf32>
      %46 = stablehlo.multiply %iterArg_9, %45 : tensor<3x2xf32>
      %47 = stablehlo.reduce(%46 init: %cst_3) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor<f32>) -> tensor<3xf32>
      %48 = stablehlo.add %iterArg_7, %47 : tensor<3xf32>
      %49 = "enzyme.pop"(%2) : (!enzyme.Cache<tensor<2x3xf32>>) -> tensor<2x3xf32>
      %50 = stablehlo.dot_general %46, %49, contracting_dims = [1] x [0] : (tensor<3x2xf32>, tensor<2x3xf32>) -> tensor<3x3xf32>
      %51 = stablehlo.add %iterArg_5, %50 : tensor<3x3xf32>
      %52 = stablehlo.add %iterArg_8, %47 : tensor<3xf32>
      %53 = "enzyme.pop"(%1) : (!enzyme.Cache<tensor<3x3xf32>>) -> tensor<3x3xf32>
      %54 = "enzyme.pop"(%0) : (!enzyme.Cache<tensor<3x2xf32>>) -> tensor<3x2xf32>
      %55 = stablehlo.dot_general %46, %54, contracting_dims = [1] x [1] : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x3xf32>
      %56 = stablehlo.add %iterArg_6, %55 : tensor<3x3xf32>
      %57 = stablehlo.dot_general %53, %46, contracting_dims = [0] x [0] : (tensor<3x3xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
      stablehlo.return %41, %51, %56, %48, %52, %57 : tensor<i64>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3x2xf32>
    }
    %30 = stablehlo.multiply %15, %15 : tensor<3x2xf32>
    %31 = stablehlo.subtract %cst, %30 : tensor<3x2xf32>
    %32 = stablehlo.multiply %29#5, %31 : tensor<3x2xf32>
    %33 = stablehlo.reduce(%32 init: %cst_3) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor<f32>) -> tensor<3xf32>
    %34 = stablehlo.add %29#3, %33 : tensor<3xf32>
    %35 = stablehlo.dot_general %32, %9, contracting_dims = [1] x [0] : (tensor<3x2xf32>, tensor<2x3xf32>) -> tensor<3x3xf32>
    %36 = stablehlo.reduce(%32 init: %cst_3) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor<f32>) -> tensor<3xf32>
    %37 = stablehlo.add %29#4, %36 : tensor<3xf32>
    %38 = stablehlo.transpose %29#2, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
    %39 = stablehlo.add %35, %29#1 : tensor<3x3xf32>
    %40 = stablehlo.transpose %39, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
    return %18, %19, %20, %16#3, %16#4, %arg5, %40, %38, %34, %37 : tensor<6x2x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>
  }
  func.func @main(%arg0: tensor<6x2x3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<3x3xf32>, %arg3: tensor<3xf32>, %arg4: tensor<3xf32>, %arg5: tensor<2xui64>) -> (tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<6x2x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>) {
    %c = stablehlo.constant dense<1> : tensor<2xui64>
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<3x3xf32>
    %0:10 = call @"diffeConst{typeof(sumabs2)}(Main.sumabs2)_autodiff"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %cst, %c, %cst_1, %cst_1, %cst_0, %cst_0) : (tensor<6x2x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<f32>, tensor<2xui64>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<6x2x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>)
    return %0#6, %0#7, %0#8, %0#9, %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<6x2x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>
  }
}
```

avik-pal force-pushed the ap/loop_rnn_reactant branch from 8fdb7ef to 4ae54ef on January 18, 2025 04:38
github-actions bot (Contributor) commented Jan 18, 2025

Benchmark Results (ASV)

| Benchmark | main | 7e7f10a... | main/7e7f10a1e4a295... |
|:---|:---|:---|:---|
| basics/overhead | 0.124 ± 0.0013 μs | 0.137 ± 0.0011 μs | 0.904 |
| time_to_load | 0.909 ± 0.011 s | 0.904 ± 0.0077 s | 1.01 |

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact to the workflow run for this PR.
Go to "Actions" -> "Benchmark a pull request" -> [the most recent run] -> "Artifacts" (at the bottom).

avik-pal marked this pull request as draft on January 18, 2025 05:01
wsmoses (Contributor) commented Jan 18, 2025

cc @Pangoraw if an IR test case is helpful!

Pangoraw commented Jan 18, 2025

Thank you for the sample MLIR. Our current loop analysis cannot figure out the static number of iterations from this condition:

```mlir
    %16:10 = stablehlo.while(%iterArg = %c, %iterArg_5 = %5, %iterArg_6 = %6, %iterArg_7 = %arg3, %iterArg_8 = %arg4, %iterArg_9 = %arg5, %iterArg_10 = %c_0, %iterArg_11 = %15, %iterArg_12 = %4, %iterArg_13 = %c) : tensor<i64>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<i64>, tensor<3x2xf32>, tensor<3x2x6xf32>, tensor<i64>
     cond {
      %41 = stablehlo.subtract %iterArg_10, %c_1 : tensor<i64>
      %42 = stablehlo.divide %41, %c_2 : tensor<i64>
      %43 = stablehlo.add %42, %c_2 : tensor<i64>
      %44 = stablehlo.compare  LT, %iterArg, %43 : (tensor<i64>, tensor<i64>) -> tensor<i1>
      stablehlo.return %44 : tensor<i1>
    } do {
```

So we should probably either update the codegen from `@trace for` as well, or add something like EnzymeAD/Enzyme-JAX#173.
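
For concreteness, the trip count the cond block computes can be folded by hand from the constants in the dump (`%c_0 = 6`, `%c_1 = 2`, `%c_2 = 1`); a sketch:

```julia
# Hand-fold of the cond block above. The bound arrives as a loop-carried SSA
# value (%iterArg_10), so the analysis cannot do this fold syntactically.
ub, a, s = 6, 2, 1             # %c_0, %c_1, %c_2 from the module above
niters = div(ub - a, s) + 1    # (6 - 2) ÷ 1 + 1 == 5 body iterations
# 5 while iterations + 1 peeled first step == 6 sequence steps total.
```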

wsmoses (Contributor) commented Jan 18, 2025

EnzymeAD/Enzyme-JAX#173

I'm currently fixing fires on weird execution stuff. @Pangoraw, if you have cycles to take over/finish up the while dead-code-elimination PR, be my guest! It would be super helpful (especially for differentiation).

avik-pal force-pushed the ap/loop_rnn_reactant branch from 0bbc7d9 to 6413484 on February 11, 2025 17:11
avik-pal (Member, Author)

Tested this locally; it works with the latest round of JLL changes. We can merge once EnzymeAD/Reactant.jl#713 lands.

avik-pal marked this pull request as ready for review on February 11, 2025 17:12
avik-pal force-pushed the ap/loop_rnn_reactant branch from 6413484 to 3cdd728 on February 11, 2025 17:17
avik-pal merged commit 637a9cf into main on Feb 11, 2025 (42 of 69 checks passed)
avik-pal deleted the ap/loop_rnn_reactant branch on February 11, 2025 23:34