Skip to content

Commit 553c74f

Browse files
committed
Move to Black
1 parent 8c2bbdd commit 553c74f

38 files changed

+720
-649
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
[![CI](https://github.com/wesselb/lab/workflows/CI/badge.svg?branch=master)](https://github.com/wesselb/lab/actions?query=workflow%3ACI)
44
[![Coverage Status](https://coveralls.io/repos/github/wesselb/lab/badge.svg?branch=master&service=github)](https://coveralls.io/github/wesselb/lab?branch=master)
55
[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://wesselb.github.io/lab)
6+
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
67

78
A generic interface for linear algebra backends: code it once, run it on any
89
backend

benchmark.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def f2(x):
4444
z = f1(x)
4545
us_lab = (time() - s) / its * 1e6
4646

47-
print('Overhead: {:.1f} us / {:.1f} %'
48-
''.format(us_lab - us_native,
49-
100 * (us_lab / us_native - 1)))
47+
print(
48+
"Overhead: {:.1f} us / {:.1f} %"
49+
"".format(us_lab - us_native, 100 * (us_lab / us_native - 1))
50+
)

lab/autograd/custom.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from ..util import as_tuple
55

6-
__all__ = ['autograd_register']
6+
__all__ = ["autograd_register"]
77
_dispatch = Dispatcher()
88

99

lab/autograd/linear_algebra.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,7 @@
55

66
from . import dispatch, B, Numeric
77
from .custom import autograd_register
8-
from ..custom import (
9-
toeplitz_solve, s_toeplitz_solve,
10-
expm, s_expm,
11-
logm, s_logm
12-
)
8+
from ..custom import toeplitz_solve, s_toeplitz_solve, expm, s_expm, logm, s_logm
139
from ..linear_algebra import _default_perm
1410
from ..util import batch_computation
1511

@@ -41,8 +37,7 @@ def transpose(a, perm=None):
4137
@dispatch(Numeric)
4238
def trace(a, axis1=0, axis2=1):
4339
if axis1 == axis2:
44-
raise ValueError('Keyword arguments axis1 and axis2 cannot be the '
45-
'same.')
40+
raise ValueError("Keyword arguments axis1 and axis2 cannot be the same.")
4641

4742
# AutoGrad does not support the `axis1` and `axis2` arguments...
4843

@@ -52,8 +47,9 @@ def trace(a, axis1=0, axis2=1):
5247

5348
# Bring the trace axes forward.
5449
if (axis1, axis2) != (0, 1):
55-
perm = [axis1, axis2] + \
56-
[i for i in range(B.rank(a)) if i != axis1 and i != axis2]
50+
perm = [axis1, axis2] + [
51+
i for i in range(B.rank(a)) if i != axis1 and i != axis2
52+
]
5753
a = anp.transpose(a, axes=perm)
5854

5955
return anp.trace(a)
@@ -110,10 +106,9 @@ def cholesky_solve(a, b):
110106
@dispatch(Numeric, Numeric)
111107
def triangular_solve(a, b, lower_a=True):
112108
def _triangular_solve(a_, b_):
113-
return asla.solve_triangular(a_, b_,
114-
trans='N',
115-
lower=lower_a,
116-
check_finite=False)
109+
return asla.solve_triangular(
110+
a_, b_, trans="N", lower=lower_a, check_finite=False
111+
)
117112

118113
return batch_computation(_triangular_solve, (a, b), (2, 2))
119114

lab/autograd/shaping.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def diag(a):
3030
@dispatch(Numeric)
3131
def vec_to_tril(a, offset=0):
3232
if B.rank(a) != 1:
33-
raise ValueError('Input must be rank 1.')
33+
raise ValueError("Input must be rank 1.")
3434
side, upper, perm = _vec_to_tril_shape_upper_perm(a, offset=offset)
3535
a = anp.concatenate((a, anp.zeros(upper, dtype=a.dtype)))
3636
return anp.reshape(a[perm], (side, side))
@@ -39,16 +39,16 @@ def vec_to_tril(a, offset=0):
3939
@dispatch(Numeric)
4040
def tril_to_vec(a, offset=0):
4141
if B.rank(a) != 2:
42-
raise ValueError('Input must be rank 2.')
42+
raise ValueError("Input must be rank 2.")
4343
n, m = B.shape(a)
4444
if n != m:
45-
raise ValueError('Input must be square.')
45+
raise ValueError("Input must be square.")
4646
return a[anp.tril_indices(n, k=offset)]
4747

4848

4949
@dispatch([Numeric])
5050
def stack(*elements, **kw_args):
51-
return anp.stack(elements, axis=kw_args.get('axis', 0))
51+
return anp.stack(elements, axis=kw_args.get("axis", 0))
5252

5353

5454
@dispatch(Numeric)
@@ -64,7 +64,7 @@ def reshape(a, *shape):
6464

6565
@dispatch([Numeric])
6666
def concat(*elements, **kw_args):
67-
return anp.concatenate(elements, axis=kw_args.get('axis', 0))
67+
return anp.concatenate(elements, axis=kw_args.get("axis", 0))
6868

6969

7070
@dispatch(Numeric, [Int])
@@ -75,7 +75,7 @@ def tile(a, *repeats):
7575
@dispatch(Numeric, object)
7676
def take(a, indices_or_mask, axis=0):
7777
if B.rank(indices_or_mask) != 1:
78-
raise ValueError('Indices or mask must be rank 1.')
78+
raise ValueError("Indices or mask must be rank 1.")
7979

8080
# Put axis `axis` first.
8181
if axis > 0:

lab/bvn_cdf/bvn_cdf.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ import numpy as np
77
cimport cython
88
from cython.parallel import prange
99

10-
cdef extern from 'math.h' nogil:
10+
cdef extern from "math.h" nogil:
1111
double log(double x)
1212
double exp(double x)
1313
double sqrt(double x)
1414

15-
cdef extern from './tvpack.h' nogil:
15+
cdef extern from "./tvpack.h" nogil:
1616
double phid_(double* x)
1717
double bvnd_(double* x, double* y, double* rho)
1818

lab/custom.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,16 @@
66
# noinspection PyUnresolvedReferences
77
from .bvn_cdf import bvn_cdf as bvn_cdf_, s_bvn_cdf
88

9-
__all__ = ['toeplitz_solve', 's_toeplitz_solve',
10-
'bvn_cdf', 's_bvn_cdf',
11-
'expm', 's_expm',
12-
'logm', 's_logm']
9+
__all__ = [
10+
"toeplitz_solve",
11+
"s_toeplitz_solve",
12+
"bvn_cdf",
13+
"s_bvn_cdf",
14+
"expm",
15+
"s_expm",
16+
"logm",
17+
"s_logm",
18+
]
1319

1420
log = logging.getLogger(__name__)
1521

@@ -103,5 +109,6 @@ def logm(a):
103109

104110

105111
def s_logm(a): # pragma: no cover
106-
raise NotImplementedError('The derivative for the matrix logarithm is '
107-
'current not implemented.')
112+
raise NotImplementedError(
113+
"The derivative for the matrix logarithm is current not implemented."
114+
)

lab/generic.py

Lines changed: 63 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -11,62 +11,64 @@
1111
AGNumeric,
1212
TFNumeric,
1313
TorchNumeric,
14-
JaxNumeric
14+
JaxNumeric,
1515
)
1616
from .util import abstract
1717

18-
__all__ = ['nan',
19-
'pi',
20-
'log_2_pi',
21-
'isnan',
22-
'zeros',
23-
'ones',
24-
'zero',
25-
'one',
26-
'eye',
27-
'linspace',
28-
'range',
29-
'cast',
30-
'identity',
31-
'negative',
32-
'abs',
33-
'sign',
34-
'sqrt',
35-
'exp',
36-
'log',
37-
'sin',
38-
'cos',
39-
'tan',
40-
'tanh',
41-
'erf',
42-
'sigmoid',
43-
'softplus',
44-
'relu',
45-
'add',
46-
'subtract',
47-
'multiply',
48-
'divide',
49-
'power',
50-
'minimum',
51-
'maximum',
52-
'leaky_relu',
53-
'min',
54-
'max',
55-
'sum',
56-
'mean',
57-
'std',
58-
'logsumexp',
59-
'all',
60-
'any',
61-
'lt',
62-
'le',
63-
'gt',
64-
'ge',
65-
'bvn_cdf',
66-
'scan',
67-
'sort',
68-
'argsort',
69-
'to_numpy']
18+
__all__ = [
19+
"nan",
20+
"pi",
21+
"log_2_pi",
22+
"isnan",
23+
"zeros",
24+
"ones",
25+
"zero",
26+
"one",
27+
"eye",
28+
"linspace",
29+
"range",
30+
"cast",
31+
"identity",
32+
"negative",
33+
"abs",
34+
"sign",
35+
"sqrt",
36+
"exp",
37+
"log",
38+
"sin",
39+
"cos",
40+
"tan",
41+
"tanh",
42+
"erf",
43+
"sigmoid",
44+
"softplus",
45+
"relu",
46+
"add",
47+
"subtract",
48+
"multiply",
49+
"divide",
50+
"power",
51+
"minimum",
52+
"maximum",
53+
"leaky_relu",
54+
"min",
55+
"max",
56+
"sum",
57+
"mean",
58+
"std",
59+
"logsumexp",
60+
"all",
61+
"any",
62+
"lt",
63+
"le",
64+
"gt",
65+
"ge",
66+
"bvn_cdf",
67+
"scan",
68+
"sort",
69+
"argsort",
70+
"to_numpy",
71+
]
7072

7173
_dispatch = Dispatcher()
7274

@@ -105,8 +107,7 @@ def zeros(dtype, *shape): # pragma: no cover
105107
"""
106108

107109

108-
@dispatch.multi((Int,), # Single integer is not a reference.
109-
([Int],))
110+
@dispatch.multi((Int,), ([Int],)) # Single integer is not a reference.
110111
def zeros(*shape):
111112
return zeros(B.default_dtype, *shape)
112113

@@ -133,8 +134,7 @@ def ones(dtype, *shape): # pragma: no cover
133134
"""
134135

135136

136-
@dispatch.multi((Int,), # Single integer is not a reference.
137-
([Int],))
137+
@dispatch.multi((Int,), ([Int],)) # Single integer is not a reference.
138138
def ones(*shape):
139139
return ones(B.default_dtype, *shape)
140140

@@ -297,6 +297,7 @@ def cast(dtype, a): # pragma: no cover
297297

298298
# Unary functions:
299299

300+
300301
@dispatch(Numeric)
301302
@abstract()
302303
def identity(a): # pragma: no cover
@@ -812,7 +813,7 @@ def bvn_cdf(a, b, c):
812813

813814
@_dispatch(object)
814815
def _as_tuple(x):
815-
return x,
816+
return (x,)
816817

817818

818819
@_dispatch(tuple)
@@ -849,8 +850,10 @@ def scan(f, xs, *init_state):
849850

850851
# Check that the state shape remained constant.
851852
if new_state_shape != state_shape:
852-
raise RuntimeError('Shape of state changed from {} to {}.'
853-
''.format(state_shape, new_state_shape))
853+
raise RuntimeError(
854+
"Shape of state changed from {} to {}."
855+
"".format(state_shape, new_state_shape)
856+
)
854857

855858
# Record the state, stacked over the various elements.
856859
states.append(B.stack(*state, axis=0))

lab/jax/custom.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from . import B
99
from ..util import as_tuple
1010

11-
__all__ = ['jax_register']
11+
__all__ = ["jax_register"]
1212
_dispatch = Dispatcher()
1313

1414

lab/jax/linear_algebra.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,7 @@
55

66
from . import dispatch, B, Numeric
77
from .custom import jax_register
8-
from ..custom import (
9-
toeplitz_solve, s_toeplitz_solve,
10-
expm, s_expm,
11-
logm, s_logm
12-
)
8+
from ..custom import toeplitz_solve, s_toeplitz_solve, expm, s_expm, logm, s_logm
139
from ..linear_algebra import _default_perm
1410
from ..util import batch_computation
1511

@@ -94,10 +90,9 @@ def cholesky_solve(a, b):
9490
@dispatch(Numeric, Numeric)
9591
def triangular_solve(a, b, lower_a=True):
9692
def _triangular_solve(a_, b_):
97-
return jsla.solve_triangular(a_, b_,
98-
trans='N',
99-
lower=lower_a,
100-
check_finite=False)
93+
return jsla.solve_triangular(
94+
a_, b_, trans="N", lower=lower_a, check_finite=False
95+
)
10196

10297
return batch_computation(_triangular_solve, (a, b), (2, 2))
10398

0 commit comments

Comments
 (0)