1
+ @testset " ChainRules" begin
2
+ using ChainRulesTestUtils
3
+
4
+ @testset " einsum" begin
5
+ @testset " unary" begin
6
+ @testset " real" begin
7
+ x = fill (1.0 )
8
+ test_frule (einsum, Char[], x, Char[])
9
+ test_rrule (einsum, Char[], x, Char[]; check_inferred= false )
10
+
11
+ x = ones (2 )
12
+ test_frule (einsum, Char[' i' ], x, Char[' i' ])
13
+ test_rrule (einsum, Char[' i' ], x, Char[' i' ]; check_inferred= false )
14
+
15
+ x = ones (2 , 3 )
16
+ test_frule (einsum, Char[' i' , ' j' ], x, Char[' i' , ' j' ])
17
+ test_rrule (einsum, Char[' i' , ' j' ], x, Char[' i' , ' j' ]; check_inferred= false )
18
+
19
+ x = ones (2 , 3 )
20
+ test_frule (einsum, Char[' j' ], x, Char[' i' , ' j' ])
21
+ test_rrule (einsum, Char[' j' ], x, Char[' i' , ' j' ]; check_inferred= false )
22
+ end
23
+
24
+ @testset " complex" begin
25
+ x = fill (1.0 + 1.0im )
26
+ test_frule (einsum, Char[], x, Char[])
27
+ test_rrule (einsum, Char[], x, Char[]; check_inferred= false )
28
+
29
+ x = fill (1.0 + 1.0im , 2 )
30
+ test_frule (einsum, Char[' i' ], x, Char[' i' ])
31
+ test_rrule (einsum, Char[' i' ], x, Char[' i' ]; check_inferred= false )
32
+
33
+ x = fill (1.0 + 1.0im , 2 , 3 )
34
+ test_frule (einsum, Char[' i' , ' j' ], x, Char[' i' , ' j' ])
35
+ test_rrule (einsum, Char[' i' , ' j' ], x, Char[' i' , ' j' ]; check_inferred= false )
36
+
37
+ x = fill (1.0 + 1.0im , 2 , 3 )
38
+ test_frule (einsum, Char[' j' ], x, Char[' i' , ' j' ])
39
+ test_rrule (einsum, Char[' j' ], x, Char[' i' , ' j' ]; check_inferred= false )
40
+ end
41
+ end
42
+
43
+ @testset " binary" begin
44
+ @testset " real" begin
45
+ # scalar-scalar product
46
+ a = ones ()
47
+ b = 2.0 * ones ()
48
+ test_frule (einsum, Char[], a, Char[], b, Char[]; check_inferred= false , testset_name= " scalar-scalar product - frule" )
49
+ test_rrule (einsum, Char[], a, Char[], b, Char[]; check_inferred= false , testset_name= " scalar-scalar product - rrule" )
50
+
51
+ # vector-vector inner product
52
+ a = ones (2 )
53
+ b = 2.0 .* ones (2 )
54
+ test_frule (einsum, Char[], a, Char[' i' ], b, Char[' i' ]; check_inferred= false , testset_name= " vector-vector inner product - frule" )
55
+ test_rrule (einsum, Char[], a, Char[' i' ], b, Char[' i' ]; check_inferred= false , testset_name= " vector-vector inner product - rrule" )
56
+
57
+ # vector-vector outer product
58
+ a = ones (2 )
59
+ b = 2.0 .* ones (3 )
60
+ test_frule (einsum, Char[' i' , ' j' ], a, Char[' i' ], b, Char[' j' ]; check_inferred= false , testset_name= " vector-vector outer product - frule" )
61
+ test_rrule (einsum, Char[' i' , ' j' ], a, Char[' i' ], b, Char[' j' ]; check_inferred= false , testset_name= " vector-vector outer product - rrule" )
62
+
63
+ # matrix-vector product
64
+ a = ones (2 , 3 )
65
+ b = 2.0 .* ones (3 )
66
+ test_frule (einsum, Char[' i' ], a, Char[' i' , ' j' ], b, Char[' j' ]; check_inferred= false , testset_name= " matrix-vector product - frule" )
67
+ test_rrule (einsum, Char[' i' ], a, Char[' i' , ' j' ], b, Char[' j' ]; check_inferred= false , testset_name= " matrix-vector product - rrule" )
68
+
69
+ # matrix-matrix product
70
+ a = ones (4 , 2 )
71
+ b = 2.0 .* ones (2 , 3 )
72
+ test_frule (einsum, Char[' i' , ' k' ], a, Char[' i' , ' j' ], b, Char[' j' , ' k' ]; check_inferred= false , testset_name= " matrix-matrix product - frule" )
73
+ test_rrule (einsum, Char[' i' , ' k' ], a, Char[' i' , ' j' ], b, Char[' j' , ' k' ]; check_inferred= false , testset_name= " matrix-matrix product - rrule" )
74
+
75
+ # matrix-matrix inner product
76
+ a = ones (3 , 4 )
77
+ b = ones (4 , 3 )
78
+ test_frule (einsum, Char[], a, Char[' i' , ' j' ], b, Char[' j' , ' i' ]; check_inferred= false , testset_name= " matrix-matrix inner product - frule" )
79
+ test_rrule (einsum, Char[], a, Char[' i' , ' j' ], b, Char[' j' , ' i' ]; check_inferred= false , testset_name= " matrix-matrix inner product - rrule" )
80
+ end
81
+
82
+ @testset " complex" begin
83
+ # scalar-scalar product
84
+ a = fill (1.0 + 1.0im )
85
+ b = 2.0 * fill (1.0 + 1.0im )
86
+ test_frule (einsum, Char[], a, Char[], b, Char[]; check_inferred= false , testset_name= " scalar-scalar product - frule" )
87
+ test_rrule (einsum, Char[], a, Char[], b, Char[]; check_inferred= false , testset_name= " scalar-scalar product - rrule" )
88
+
89
+ # vector-vector inner product
90
+ a = fill (1.0 + 1.0im , 2 )
91
+ b = 2.0 .* fill (1.0 + 1.0im , 2 )
92
+ test_frule (einsum, Char[], a, Char[' i' ], b, Char[' i' ]; check_inferred= false , testset_name= " vector-vector inner product - frule" )
93
+ test_rrule (einsum, Char[], a, Char[' i' ], b, Char[' i' ]; check_inferred= false , testset_name= " vector-vector inner product - rrule" )
94
+
95
+ # vector-vector outer product
96
+ a = fill (1.0 + 1.0im , 2 )
97
+ b = 2.0 .* fill (1.0 + 1.0im , 3 )
98
+ test_frule (einsum, Char[' i' , ' j' ], a, Char[' i' ], b, Char[' j' ]; check_inferred= false , testset_name= " vector-vector outer product - frule" )
99
+ test_rrule (einsum, Char[' i' , ' j' ], a, Char[' i' ], b, Char[' j' ]; check_inferred= false , testset_name= " vector-vector outer product - rrule" )
100
+
101
+ # matrix-vector product
102
+ a = fill (1.0 + 1.0im , 2 , 3 )
103
+ b = 2.0 .* fill (1.0 + 1.0im , 3 )
104
+ test_frule (einsum, Char[' i' ], a, Char[' i' , ' j' ], b, Char[' j' ]; check_inferred= false , testset_name= " matrix-vector product - frule" )
105
+ test_rrule (einsum, Char[' i' ], a, Char[' i' , ' j' ], b, Char[' j' ]; check_inferred= false , testset_name= " matrix-vector product - rrule" )
106
+
107
+ # matrix-matrix product
108
+ a = fill (1.0 + 1.0im , 4 , 2 )
109
+ b = 2.0 .* fill (1.0 + 1.0im , 2 , 3 )
110
+ test_frule (einsum, Char[' i' , ' k' ], a, Char[' i' , ' j' ], b, Char[' j' , ' k' ]; check_inferred= false , testset_name= " matrix-matrix product - frule" )
111
+ test_rrule (einsum, Char[' i' , ' k' ], a, Char[' i' , ' j' ], b, Char[' j' , ' k' ]; check_inferred= false , testset_name= " matrix-matrix product - rrule" )
112
+
113
+ # matrix-matrix inner product
114
+ a = fill (1.0 + 1.0im , 3 , 4 )
115
+ b = fill (1.0 + 1.0im , 4 , 3 )
116
+ test_frule (einsum, Char[], a, Char[' i' , ' j' ], b, Char[' j' , ' i' ]; check_inferred= false , testset_name= " matrix-matrix inner product - frule" )
117
+ test_rrule (einsum, Char[], a, Char[' i' , ' j' ], b, Char[' j' , ' i' ]; check_inferred= false , testset_name= " matrix-matrix inner product - rrule" )
118
+ end
119
+ end
120
+ end
121
+ end
0 commit comments