Skip to content

Commit e344273

Browse files
committed
Integrate ChainRulesCore
1 parent eb4f6c6 commit e344273

File tree

5 files changed

+178
-0
lines changed

5 files changed

+178
-0
lines changed

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@ OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
99
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1010

1111
[weakdeps]
12+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1213
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
1314
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
1415

1516
[extensions]
17+
MuscleChainRulesCoreExt = "ChainRulesCore"
1618
MuscleDaggerExt = "Dagger"
1719
MuscleReactantExt = "Reactant"
1820

ext/MuscleChainRulesCoreExt.jl

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
module MuscleChainRulesCoreExt
2+
3+
using Muscle
4+
using ChainRulesCore
5+
6+
function ChainRulesCore.frule((_, _, ȧ, _), ::typeof(einsum), ic, a, ia)
7+
return einsum(ic, a, ia), einsum(ic, ȧ, ia)
8+
end
9+
10+
function ChainRulesCore.frule((_, _, ȧ, _, ḃ, _), ::typeof(einsum), ic, a, ia, b, ib)
11+
c = einsum(ic, a, ia, b, ib)
12+
= einsum(ic, ȧ, ia, b, ib) + einsum(ic, a, ia, ḃ, ib)
13+
return c, ċ
14+
end
15+
16+
function ChainRulesCore.rrule(::typeof(einsum), ic, a, ia)
17+
c = einsum(ic, a, ia)
18+
proj = ProjectTo(a)
19+
20+
function einsum_pullback(c̄)
21+
c_shape_with_singletons = map(ia) do i
22+
loc = findfirst(==(i), ic)
23+
isnothing(loc) ? 1 : size(c̄, loc)
24+
end
25+
26+
dims_to_repeat = map(zip(size(a), c_shape_with_singletons .== 1)) do (dₐ, issingleton)
27+
issingleton ? dₐ : 1
28+
end
29+
::typeof(a) = proj(repeat(reshape(c̄, c_shape_with_singletons...), dims_to_repeat...))
30+
31+
return (NoTangent(), NoTangent(), ā, NoTangent())
32+
end
33+
einsum_pullback(c̄::AbstractThunk) = einsum_pullback(unthunk(c̄))
34+
35+
return c, einsum_pullback
36+
end
37+
38+
function ChainRulesCore.rrule(::typeof(einsum), ic, a, ia, b, ib)
39+
c = einsum(ic, a, ia, b, ib)
40+
proj_a = ProjectTo(a)
41+
proj_b = ProjectTo(b)
42+
43+
function einsum_pullback(c̄)
44+
= @thunk proj_a(einsum(ia, c̄, ic, conj(b), ib))
45+
= @thunk proj_b(einsum(ib, conj(a), ia, c̄, ic))
46+
return (NoTangent(), NoTangent(), ā, NoTangent(), b̄, NoTangent())
47+
end
48+
einsum_pullback(c̄::AbstractThunk) = einsum_pullback(unthunk(c̄))
49+
50+
return c, einsum_pullback
51+
end
52+
53+
end

test/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
34
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
45
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
56
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

test/integration/ChainRules_test.jl

+121
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
@testset "ChainRules" begin
2+
using ChainRulesTestUtils
3+
4+
@testset "einsum" begin
5+
@testset "unary" begin
6+
@testset "real" begin
7+
x = fill(1.0)
8+
test_frule(einsum, Char[], x, Char[])
9+
test_rrule(einsum, Char[], x, Char[]; check_inferred=false)
10+
11+
x = ones(2)
12+
test_frule(einsum, Char['i'], x, Char['i'])
13+
test_rrule(einsum, Char['i'], x, Char['i']; check_inferred=false)
14+
15+
x = ones(2, 3)
16+
test_frule(einsum, Char['i', 'j'], x, Char['i', 'j'])
17+
test_rrule(einsum, Char['i', 'j'], x, Char['i', 'j']; check_inferred=false)
18+
19+
x = ones(2, 3)
20+
test_frule(einsum, Char['j'], x, Char['i', 'j'])
21+
test_rrule(einsum, Char['j'], x, Char['i', 'j']; check_inferred=false)
22+
end
23+
24+
@testset "complex" begin
25+
x = fill(1.0 + 1.0im)
26+
test_frule(einsum, Char[], x, Char[])
27+
test_rrule(einsum, Char[], x, Char[]; check_inferred=false)
28+
29+
x = fill(1.0 + 1.0im, 2)
30+
test_frule(einsum, Char['i'], x, Char['i'])
31+
test_rrule(einsum, Char['i'], x, Char['i']; check_inferred=false)
32+
33+
x = fill(1.0 + 1.0im, 2, 3)
34+
test_frule(einsum, Char['i', 'j'], x, Char['i', 'j'])
35+
test_rrule(einsum, Char['i', 'j'], x, Char['i', 'j']; check_inferred=false)
36+
37+
x = fill(1.0 + 1.0im, 2, 3)
38+
test_frule(einsum, Char['j'], x, Char['i', 'j'])
39+
test_rrule(einsum, Char['j'], x, Char['i', 'j']; check_inferred=false)
40+
end
41+
end
42+
43+
@testset "binary" begin
44+
@testset "real" begin
45+
# scalar-scalar product
46+
a = ones()
47+
b = 2.0 * ones()
48+
test_frule(einsum, Char[], a, Char[], b, Char[]; check_inferred=false, testset_name="scalar-scalar product - frule")
49+
test_rrule(einsum, Char[], a, Char[], b, Char[]; check_inferred=false, testset_name="scalar-scalar product - rrule")
50+
51+
# vector-vector inner product
52+
a = ones(2)
53+
b = 2.0 .* ones(2)
54+
test_frule(einsum, Char[], a, Char['i'], b, Char['i']; check_inferred=false, testset_name="vector-vector inner product - frule")
55+
test_rrule(einsum, Char[], a, Char['i'], b, Char['i']; check_inferred=false, testset_name="vector-vector inner product - rrule")
56+
57+
# vector-vector outer product
58+
a = ones(2)
59+
b = 2.0 .* ones(3)
60+
test_frule(einsum, Char['i', 'j'], a, Char['i'], b, Char['j']; check_inferred=false, testset_name="vector-vector outer product - frule")
61+
test_rrule(einsum, Char['i', 'j'], a, Char['i'], b, Char['j']; check_inferred=false, testset_name="vector-vector outer product - rrule")
62+
63+
# matrix-vector product
64+
a = ones(2, 3)
65+
b = 2.0 .* ones(3)
66+
test_frule(einsum, Char['i'], a, Char['i', 'j'], b, Char['j']; check_inferred=false, testset_name="matrix-vector product - frule")
67+
test_rrule(einsum, Char['i'], a, Char['i', 'j'], b, Char['j']; check_inferred=false, testset_name="matrix-vector product - rrule")
68+
69+
# matrix-matrix product
70+
a = ones(4, 2)
71+
b = 2.0 .* ones(2, 3)
72+
test_frule(einsum, Char['i', 'k'], a, Char['i', 'j'], b, Char['j', 'k']; check_inferred=false, testset_name="matrix-matrix product - frule")
73+
test_rrule(einsum, Char['i', 'k'], a, Char['i', 'j'], b, Char['j', 'k']; check_inferred=false, testset_name="matrix-matrix product - rrule")
74+
75+
# matrix-matrix inner product
76+
a = ones(3, 4)
77+
b = ones(4, 3)
78+
test_frule(einsum, Char[], a, Char['i', 'j'], b, Char['j', 'i']; check_inferred=false, testset_name="matrix-matrix inner product - frule")
79+
test_rrule(einsum, Char[], a, Char['i', 'j'], b, Char['j', 'i']; check_inferred=false, testset_name="matrix-matrix inner product - rrule")
80+
end
81+
82+
@testset "complex" begin
83+
# scalar-scalar product
84+
a = fill(1.0 + 1.0im)
85+
b = 2.0 * fill(1.0 + 1.0im)
86+
test_frule(einsum, Char[], a, Char[], b, Char[]; check_inferred=false, testset_name="scalar-scalar product - frule")
87+
test_rrule(einsum, Char[], a, Char[], b, Char[]; check_inferred=false, testset_name="scalar-scalar product - rrule")
88+
89+
# vector-vector inner product
90+
a = fill(1.0 + 1.0im, 2)
91+
b = 2.0 .* fill(1.0 + 1.0im, 2)
92+
test_frule(einsum, Char[], a, Char['i'], b, Char['i']; check_inferred=false, testset_name="vector-vector inner product - frule")
93+
test_rrule(einsum, Char[], a, Char['i'], b, Char['i']; check_inferred=false, testset_name="vector-vector inner product - rrule")
94+
95+
# vector-vector outer product
96+
a = fill(1.0 + 1.0im, 2)
97+
b = 2.0 .* fill(1.0 + 1.0im, 3)
98+
test_frule(einsum, Char['i', 'j'], a, Char['i'], b, Char['j']; check_inferred=false, testset_name="vector-vector outer product - frule")
99+
test_rrule(einsum, Char['i', 'j'], a, Char['i'], b, Char['j']; check_inferred=false, testset_name="vector-vector outer product - rrule")
100+
101+
# matrix-vector product
102+
a = fill(1.0 + 1.0im, 2, 3)
103+
b = 2.0 .* fill(1.0 + 1.0im, 3)
104+
test_frule(einsum, Char['i'], a, Char['i', 'j'], b, Char['j']; check_inferred=false, testset_name="matrix-vector product - frule")
105+
test_rrule(einsum, Char['i'], a, Char['i', 'j'], b, Char['j']; check_inferred=false, testset_name="matrix-vector product - rrule")
106+
107+
# matrix-matrix product
108+
a = fill(1.0 + 1.0im, 4, 2)
109+
b = 2.0 .* fill(1.0 + 1.0im, 2, 3)
110+
test_frule(einsum, Char['i', 'k'], a, Char['i', 'j'], b, Char['j', 'k']; check_inferred=false, testset_name="matrix-matrix product - frule")
111+
test_rrule(einsum, Char['i', 'k'], a, Char['i', 'j'], b, Char['j', 'k']; check_inferred=false, testset_name="matrix-matrix product - rrule")
112+
113+
# matrix-matrix inner product
114+
a = fill(1.0 + 1.0im, 3, 4)
115+
b = fill(1.0 + 1.0im, 4, 3)
116+
test_frule(einsum, Char[], a, Char['i', 'j'], b, Char['j', 'i']; check_inferred=false, testset_name="matrix-matrix inner product - frule")
117+
test_rrule(einsum, Char[], a, Char['i', 'j'], b, Char['j', 'i']; check_inferred=false, testset_name="matrix-matrix inner product - rrule")
118+
end
119+
end
120+
end
121+
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using Test
77
end
88

99
@testset "Integration" verbose = true begin
10+
include("integration/ChainRules_test.jl")
1011
include("integration/Dagger_test.jl")
1112
end
1213

0 commit comments

Comments
 (0)