Skip to content

Commit

Permalink
test fixups
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward committed Jan 17, 2025
1 parent 7c153ad commit 45e660d
Showing 1 changed file with 23 additions and 16 deletions.
39 changes: 23 additions & 16 deletions tests/pyop2/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,13 @@ def cache(self):
int_comm.Set_attr(comm_cache_keyval, _cache_collection)
return _cache_collection[default_cache_name]

def code_cache_len_equals(self, expected):
# We need to do this check because different things also get
# put into self.cache
return sum(
1 for key in self.cache if key[1] == "compile_global_kernel"
) == expected

@pytest.fixture
def a(cls, diterset):
return op2.Dat(diterset, list(range(nelems)), numpy.uint32, "a")
Expand All @@ -328,14 +335,14 @@ def test_same_args(self, iterset, iter2ind1, x, a):
a(op2.WRITE),
x(op2.READ, iter2ind1))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

op2.par_loop(op2.Kernel(kernel_cpy, "cpy"),
iterset,
a(op2.WRITE),
x(op2.READ, iter2ind1))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

def test_diff_kernel(self, iterset, iter2ind1, x, a):
self.cache.clear()
Expand All @@ -348,7 +355,7 @@ def test_diff_kernel(self, iterset, iter2ind1, x, a):
a(op2.WRITE),
x(op2.READ, iter2ind1))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

kernel_cpy = "static void cpy(unsigned int* DST, unsigned int* SRC) { *DST = *SRC; }"

Expand All @@ -357,7 +364,7 @@ def test_diff_kernel(self, iterset, iter2ind1, x, a):
a(op2.WRITE),
x(op2.READ, iter2ind1))

assert len(self.cache) == 2
assert self.code_cache_len_equals(2)

def test_invert_arg_similar_shape(self, iterset, iter2ind1, x, y):
self.cache.clear()
Expand All @@ -377,14 +384,14 @@ def test_invert_arg_similar_shape(self, iterset, iter2ind1, x, y):
x(op2.RW, iter2ind1),
y(op2.RW, iter2ind1))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

op2.par_loop(op2.Kernel(kernel_swap, "swap"),
iterset,
y(op2.RW, iter2ind1),
x(op2.RW, iter2ind1))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

def test_dloop_ignore_scalar(self, iterset, a, b):
self.cache.clear()
Expand All @@ -404,14 +411,14 @@ def test_dloop_ignore_scalar(self, iterset, a, b):
a(op2.RW),
b(op2.RW))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

op2.par_loop(op2.Kernel(kernel_swap, "swap"),
iterset,
b(op2.RW),
a(op2.RW))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

def test_vector_map(self, iterset, x2, iter2ind2):
self.cache.clear()
Expand All @@ -431,13 +438,13 @@ def test_vector_map(self, iterset, x2, iter2ind2):
iterset,
x2(op2.RW, iter2ind2))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

op2.par_loop(op2.Kernel(kernel_swap, "swap"),
iterset,
x2(op2.RW, iter2ind2))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

def test_same_iteration_space_works(self, iterset, x2, iter2ind2):
self.cache.clear()
Expand All @@ -447,12 +454,12 @@ def test_same_iteration_space_works(self, iterset, x2, iter2ind2):
op2.par_loop(k, iterset,
x2(op2.INC, iter2ind2))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

op2.par_loop(k, iterset,
x2(op2.INC, iter2ind2))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

def test_change_dat_dtype_matters(self, iterset, diterset):
d = op2.Dat(diterset, list(range(nelems)), numpy.uint32)
Expand All @@ -463,12 +470,12 @@ def test_change_dat_dtype_matters(self, iterset, diterset):

op2.par_loop(k, iterset, d(op2.WRITE))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

d = op2.Dat(diterset, list(range(nelems)), numpy.int32)
op2.par_loop(k, iterset, d(op2.WRITE))

assert len(self.cache) == 2
assert self.code_cache_len_equals(2)

def test_change_global_dtype_matters(self, iterset, diterset):
g = op2.Global(1, 0, dtype=numpy.uint32, comm=COMM_WORLD)
Expand All @@ -479,12 +486,12 @@ def test_change_global_dtype_matters(self, iterset, diterset):

op2.par_loop(k, iterset, g(op2.INC))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

g = op2.Global(1, 0, dtype=numpy.float64, comm=COMM_WORLD)
op2.par_loop(k, iterset, g(op2.INC))

assert len(self.cache) == 2
assert self.code_cache_len_equals(2)


class TestSparsityCache:
Expand Down

0 comments on commit 45e660d

Please sign in to comment.