55
66from . import dispatch , B , Numeric
77from .custom import autograd_register
8- from ..custom import (
9- toeplitz_solve , s_toeplitz_solve ,
10- expm , s_expm ,
11- logm , s_logm
12- )
8+ from ..custom import toeplitz_solve , s_toeplitz_solve , expm , s_expm , logm , s_logm
139from ..linear_algebra import _default_perm
1410from ..util import batch_computation
1511
@@ -41,8 +37,7 @@ def transpose(a, perm=None):
4137@dispatch (Numeric )
4238def trace (a , axis1 = 0 , axis2 = 1 ):
4339 if axis1 == axis2 :
44- raise ValueError ('Keyword arguments axis1 and axis2 cannot be the '
45- 'same.' )
40+ raise ValueError ("Keyword arguments axis1 and axis2 cannot be the same." )
4641
4742 # AutoGrad does not support the `axis1` and `axis2` arguments...
4843
@@ -52,8 +47,9 @@ def trace(a, axis1=0, axis2=1):
5247
5348 # Bring the trace axes forward.
5449 if (axis1 , axis2 ) != (0 , 1 ):
55- perm = [axis1 , axis2 ] + \
56- [i for i in range (B .rank (a )) if i != axis1 and i != axis2 ]
50+ perm = [axis1 , axis2 ] + [
51+ i for i in range (B .rank (a )) if i != axis1 and i != axis2
52+ ]
5753 a = anp .transpose (a , axes = perm )
5854
5955 return anp .trace (a )
@@ -110,10 +106,9 @@ def cholesky_solve(a, b):
110106@dispatch (Numeric , Numeric )
111107def triangular_solve (a , b , lower_a = True ):
112108 def _triangular_solve (a_ , b_ ):
113- return asla .solve_triangular (a_ , b_ ,
114- trans = 'N' ,
115- lower = lower_a ,
116- check_finite = False )
109+ return asla .solve_triangular (
110+ a_ , b_ , trans = "N" , lower = lower_a , check_finite = False
111+ )
117112
118113 return batch_computation (_triangular_solve , (a , b ), (2 , 2 ))
119114
0 commit comments