Skip to content

Commit fe845c3

Browse files
committed
Add missing test
1 parent d157759 commit fe845c3

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

tests/test_generic.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import lab as B
1111

1212
# noinspection PyUnresolvedReferences
13+
from .util import check_lazy_shapes # noqa
1314
from .util import (
1415
Bool,
1516
BoolTensor,
@@ -22,7 +23,6 @@
2223
approx,
2324
autograd_box,
2425
check_function,
25-
check_lazy_shapes,
2626
)
2727

2828

@@ -92,6 +92,16 @@ def test_device_and_to_active_device(check_lazy_shapes):
9292
assert B.to_active_device(a) is a
9393

9494

95+
def test_device_jax_exception(check_lazy_shapes):
96+
a = Tensor(2, 2).jax()
97+
a.devices = lambda: set()
98+
with pytest.raises(
99+
RuntimeError,
100+
match="(?i)could not determine device of JAX array",
101+
):
102+
B.device(a)
103+
104+
95105
@pytest.mark.parametrize("t", [tf.float32, torch.float32, jnp.float32])
96106
@pytest.mark.parametrize(
97107
"f",

0 commit comments

Comments
 (0)