1+ import torch
2+ import pytest
3+
14from toma import explicit
25from toma .batchsize_cache import NoBatchsizeCache , GlobalBatchsizeCache , StacktraceMemoryBatchsizeCache
36
@@ -6,7 +9,23 @@ def raise_fake_oom():
69 raise RuntimeError ("CUDA out of memory." )
710
811
9- def test_fake_explicit_toma_none ():
def test_fake_explicit_batch_raise():
    """A batch callback that OOMs on every attempt must surface RuntimeError.

    ``raise_fake_oom`` raises ``RuntimeError("CUDA out of memory.")``;
    ``explicit.batch`` presumably retries with shrinking batch sizes and
    re-raises once no size succeeds — TODO confirm against toma.explicit.
    """

    def always_oom(batchsize):
        # Fail unconditionally, regardless of the batch size tried.
        raise_fake_oom()

    with pytest.raises(RuntimeError):
        explicit.batch(always_oom, 64)
18+
19+
def test_fake_explicit_range_raise():
    """A range callback that OOMs on every chunk must surface RuntimeError.

    Mirrors ``test_fake_explicit_batch_raise`` for the ``explicit.range``
    entry point: the fake CUDA OOM is expected to propagate once every
    retry fails — TODO confirm retry semantics against toma.explicit.
    """

    def always_oom(start, end):
        # Fail unconditionally for every [start, end) chunk.
        raise_fake_oom()

    with pytest.raises(RuntimeError):
        explicit.range(always_oom, 0, 64, 64)
26+
27+
28+ def test_fake_explicit_batch_none ():
1029 batchsizes = []
1130
1231 def f (batchsize ):
@@ -22,7 +41,7 @@ def f(batchsize):
2241 assert batchsizes == [64 , 32 , 16 , 64 , 32 , 16 ]
2342
2443
25- def test_fake_explicit_toma_global ():
44+ def test_fake_explicit_batch_global ():
2645 batchsizes = []
2746
2847 def f (batchsize ):
@@ -40,7 +59,7 @@ def f(batchsize):
4059 assert batchsizes == [64 , 32 , 16 , 16 , 16 , 16 ]
4160
4261
43- def test_fake_explicit_toma_sm ():
62+ def test_fake_explicit_batch_sm ():
4463 batchsizes = []
4564
4665 def f (batchsize ):
@@ -58,7 +77,7 @@ def f(batchsize):
5877 assert batchsizes == [64 , 32 , 16 , 16 , 16 , 64 , 32 , 16 ]
5978
6079
61- def test_fake_explicit_toma_mix ():
80+ def test_fake_explicit_batch_mix ():
6281 batchsizes = []
6382
6483 def f (batchsize ):
@@ -79,7 +98,7 @@ def f(batchsize):
7998 assert batchsizes == [64 , 32 , 16 , 16 , 16 , 16 , 64 , 32 , 16 , 16 ]
8099
81100
82- def test_fake_explicit_toma_range_none ():
101+ def test_fake_explicit_range_none ():
83102 batchsizes = []
84103
85104 def f (start , end ):
@@ -99,7 +118,7 @@ def f(start, end):
99118 assert batchsizes == [64 , 64 , 32 , 32 , 16 , 16 ] * 2
100119
101120
102- def test_fake_explicit_toma_range_global ():
121+ def test_fake_explicit_range_global ():
103122 batchsizes = []
104123
105124 def f (start , end ):
@@ -121,7 +140,7 @@ def f(start, end):
121140 assert batchsizes == [64 , 64 , 32 , 32 , 16 , 16 ] + [16 ] * 8 * 2
122141
123142
124- def test_fake_explicit_toma_range_sm ():
143+ def test_fake_explicit_range_sm ():
125144 batchsizes = []
126145
127146 def f (start , end ):
@@ -143,7 +162,7 @@ def f(start, end):
143162 assert batchsizes == [64 , 64 , 32 , 32 , 16 , 16 ] + [16 ] * 8 + [64 , 64 , 32 , 32 , 16 , 16 ]
144163
145164
146- def test_fake_explicit_toma_range_sm ():
165+ def test_fake_explicit_range_sm ():
147166 batchsizes = []
148167
149168 def f (start , end ):
@@ -166,3 +185,12 @@ def f(start, end):
166185 explicit .range (f , 0 , 128 , 64 , toma_cache_type = StacktraceMemoryBatchsizeCache )
167186
168187 assert batchsizes == ([64 , 64 , 32 , 32 , 16 , 16 ] + [16 ] * 8 * 2 + [64 , 64 , 32 , 32 , 16 , 16 ] + [16 ] * 8 )
188+
189+
def test_explicit_chunked():
    """``explicit.chunked`` must apply the callback to every slice of the tensor.

    The callback overwrites each chunk it receives with ones; if every row
    of the 128-row tensor is visited, the whole tensor compares equal to 1.
    """

    def fill_ones(tensor, start, end):
        # Overwrite the received chunk in place.
        tensor[:] = 1.0

    data = torch.zeros((128, 4, 4))
    explicit.chunked(fill_ones, data, 32)

    # Every element was touched exactly once iff the tensor is all ones.
    assert torch.allclose(data, torch.tensor(1.0))
0 commit comments