Skip to content

Commit e55bf99

Browse files
committed
More granularity in protecting tests requiring torch
1 parent 677cfe6 commit e55bf99

File tree

4 files changed

+42
-22
lines changed

4 files changed

+42
-22
lines changed

tests/ndarray/test_elementwise_funcs.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66

77
import blosc2
88

9-
torch = pytest.importorskip("torch", reason="torch not available")
10-
119
warnings.simplefilter("always")
1210

1311
# Functions to test (add more as needed)
@@ -312,9 +310,10 @@ def test_unary_funcs(np_func, blosc_func, dtype, shape, chunkshape):
312310
@pytest.mark.parametrize(("np_func", "blosc_func"), UNARY_FUNC_PAIRS)
313311
@pytest.mark.parametrize("dtype", STR_DTYPES)
314312
@pytest.mark.parametrize("shape", [(10,), (20, 20)])
315-
@pytest.mark.parametrize("xp", [torch])
316-
def test_unfuncs_proxy(np_func, blosc_func, dtype, shape, xp):
317-
_test_unary_func_proxy(np_func, blosc_func, dtype, shape, xp)
313+
def test_unary_funcs_torch_proxy(np_func, blosc_func, dtype, shape):
314+
"""Test unary functions with torch tensors as input (via proxy)."""
315+
torch = pytest.importorskip("torch")
316+
_test_unary_func_proxy(np_func, blosc_func, dtype, shape, torch)
318317

319318

320319
@pytest.mark.heavy
@@ -335,9 +334,10 @@ def test_binary_funcs(np_func, blosc_func, dtype, shape, chunkshape):
335334
@pytest.mark.parametrize(("np_func", "blosc_func"), BINARY_FUNC_PAIRS)
336335
@pytest.mark.parametrize("dtype", STR_DTYPES)
337336
@pytest.mark.parametrize(("shape", "chunkshape"), SHAPES_CHUNKS)
338-
@pytest.mark.parametrize("xp", [torch])
339-
def test_binfuncs_proxy(np_func, blosc_func, dtype, shape, chunkshape, xp):
340-
_test_binary_func_proxy(np_func, blosc_func, dtype, shape, chunkshape, xp)
337+
def test_binary_funcs_torch_proxy(np_func, blosc_func, dtype, shape, chunkshape):
338+
"""Test binary functions with torch tensors as input (via proxy)."""
339+
torch = pytest.importorskip("torch")
340+
_test_binary_func_proxy(np_func, blosc_func, dtype, shape, chunkshape, torch)
341341

342342

343343
@pytest.mark.heavy

tests/ndarray/test_lazyexpr.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,14 @@
1515
from blosc2.lazyexpr import ne_evaluate
1616
from blosc2.utils import get_chunks_idx, npvecdot
1717

18-
torch = pytest.importorskip("torch", reason="torch not available")
18+
# Conditionally import torch for proxy tests
19+
try:
20+
import torch
21+
22+
PROXY_TEST_XP = [torch, np]
23+
except ImportError:
24+
torch = None
25+
PROXY_TEST_XP = [np]
1926

2027
NITEMS_SMALL = 100
2128
NITEMS = 1000
@@ -1848,7 +1855,7 @@ def test_lazyexpr_2args():
18481855

18491856
@pytest.mark.parametrize(
18501857
"xp",
1851-
[torch, np],
1858+
PROXY_TEST_XP,
18521859
)
18531860
@pytest.mark.parametrize(
18541861
"dtype",

tests/ndarray/test_linalg.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,14 @@
88
from blosc2.lazyexpr import linalg_funcs
99
from blosc2.utils import npvecdot
1010

11-
torch = pytest.importorskip("torch", reason="torch not available")
11+
# Conditionally import torch for proxy tests
12+
try:
13+
import torch
14+
15+
PROXY_TEST_XP = [torch, np]
16+
except ImportError:
17+
torch = None
18+
PROXY_TEST_XP = [np]
1219

1320

1421
@pytest.mark.parametrize(
@@ -827,7 +834,7 @@ def test_diagonal(shape, chunkshape, offset):
827834

828835
@pytest.mark.parametrize(
829836
"xp",
830-
[torch, np],
837+
PROXY_TEST_XP,
831838
)
832839
@pytest.mark.parametrize(
833840
"dtype",

tests/ndarray/test_setitem.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111

1212
import blosc2
1313

14-
torch = pytest.importorskip("torch", reason="torch not available")
15-
1614
argnames = "shape, chunks, blocks, slices, dtype"
1715
argvalues = [
1816
([456], [258], [73], slice(0, 1), np.int32),
@@ -46,14 +44,6 @@ def test_setitem(shape, chunks, blocks, slices, dtype):
4644
nparray[slices] = val
4745
np.testing.assert_almost_equal(a[...], nparray)
4846

49-
# Object called via SimpleProxy
50-
slice_shape = a[slices].shape
51-
dtype_ = {np.float32: torch.float32, np.int32: torch.int32, np.float64: torch.float64}[dtype]
52-
val = torch.ones(slice_shape, dtype=dtype_)
53-
a[slices] = val
54-
nparray[slices] = val
55-
np.testing.assert_almost_equal(a[...], nparray)
56-
5747
# blosc2.NDArray
5848
if np.prod(slice_shape) == 1 or len(slice_shape) != len(blocks):
5949
chunks = None
@@ -65,6 +55,22 @@ def test_setitem(shape, chunks, blocks, slices, dtype):
6555
np.testing.assert_almost_equal(a[...], nparray)
6656

6757

58+
@pytest.mark.parametrize(argnames, argvalues)
59+
def test_setitem_torch_proxy(shape, chunks, blocks, slices, dtype):
60+
torch = pytest.importorskip("torch")
61+
size = int(np.prod(shape))
62+
nparray = np.arange(size, dtype=dtype).reshape(shape)
63+
a = blosc2.frombuffer(bytes(nparray), nparray.shape, dtype=dtype, chunks=chunks, blocks=blocks)
64+
65+
# Object called via SimpleProxy (torch tensor)
66+
slice_shape = a[slices].shape
67+
dtype_ = {np.float32: torch.float32, np.int32: torch.int32, np.float64: torch.float64}[dtype]
68+
val = torch.ones(slice_shape, dtype=dtype_)
69+
a[slices] = val
70+
nparray[slices] = val
71+
np.testing.assert_almost_equal(a[...], nparray)
72+
73+
6874
@pytest.mark.parametrize(
6975
("shape", "slices"),
7076
[

0 commit comments

Comments
 (0)