Skip to content

Commit

Permalink
use read/write_events everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Oct 3, 2022
1 parent 6c99917 commit 957efa6
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 49 deletions.
14 changes: 7 additions & 7 deletions pyopencl/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _extract_extra_args_types_values(extra_args):
if isinstance(val, cl.array.Array):
extra_args_types.append(VectorArg(val.dtype, name, with_offset=False))
extra_args_values.append(val)
extra_wait_for.extend(val.events)
extra_wait_for.extend(val.write_events)
elif isinstance(val, np.generic):
extra_args_types.append(ScalarArg(val.dtype, name))
extra_args_values.append(val)
Expand Down Expand Up @@ -1163,7 +1163,7 @@ def __call__(self, queue, n_objects, *args, **kwargs):
data_args.append(arg_val.base_data)
if arg_descr.with_offset:
data_args.append(arg_val.offset)
wait_for.extend(arg_val.events)
wait_for.extend(arg_val.write_events)
else:
data_args.append(arg_val)

Expand All @@ -1182,7 +1182,7 @@ def __call__(self, queue, n_objects, *args, **kwargs):
counts = cl.array.empty(queue,
(n_objects + 1), index_dtype, allocator=allocator)
counts[-1] = 0
wait_for = wait_for + counts.events
wait_for = wait_for + counts.write_events

# The scan will turn the "counts" array into the "starts" array
# in-place.
Expand Down Expand Up @@ -1235,7 +1235,7 @@ def __call__(self, queue, n_objects, *args, **kwargs):
info_record.nonempty_indices,
info_record.compressed_indices,
info_record.num_nonempty_lists,
wait_for=[count_event] + info_record.compressed_indices.events)
wait_for=[count_event] + info_record.compressed_indices.write_events)

info_record.starts = compressed_counts

Expand Down Expand Up @@ -1264,13 +1264,13 @@ def __call__(self, queue, n_objects, *args, **kwargs):
evt = scan_kernel(
starts_ary,
size=info_record.num_nonempty_lists,
wait_for=starts_ary.events)
wait_for=starts_ary.write_events)
else:
evt = scan_kernel(starts_ary, wait_for=[count_event],
size=n_objects)

starts_ary.setitem(0, 0, queue=queue, wait_for=[evt])
scan_events.extend(starts_ary.events)
scan_events.extend(starts_ary.write_events)

# retrieve count
info_record.count = int(starts_ary[-1].get())
Expand Down Expand Up @@ -1432,7 +1432,7 @@ def __call__(self, queue, keys, values, nkeys,

starts = (cl.array.empty(queue, (nkeys+1), starts_dtype, allocator=allocator)
.fill(len(values_sorted_by_key), wait_for=[evt]))
evt, = starts.events
evt, = starts.write_events

evt = knl_info.start_finder(starts, keys_sorted_by_key,
range=slice(len(keys_sorted_by_key)),
Expand Down
2 changes: 1 addition & 1 deletion pyopencl/bitonic_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __call__(self, arr, idx=None, queue=None, wait_for=None, axis=0):

if wait_for is None:
wait_for = []
wait_for = wait_for + arr.events
wait_for = wait_for + arr.write_events

last_evt = cl.enqueue_marker(queue, wait_for=wait_for)

Expand Down
101 changes: 72 additions & 29 deletions pyopencl/clmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,13 @@
THE SOFTWARE.
"""

import pyopencl.array as cl_array
import pyopencl.elementwise as elementwise
from pyopencl.array import _get_common_dtype
from pyopencl.array import elwise_kernel_runner, _get_common_dtype
import numpy as np


def _make_unary_array_func(name):
@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def knl_runner(result, arg):
if arg.dtype.kind == "c":
from pyopencl.elementwise import complex_dtype_to_name
Expand All @@ -42,8 +41,11 @@ def knl_runner(result, arg):

def f(array, queue=None):
result = array._new_like_me(queue=queue)
event1 = knl_runner(result, array, queue=queue)
result.add_event(event1)
evt = knl_runner(result, array, queue=queue)

array.add_read_event(evt)
result.add_write_event(evt)

return result

return f
Expand All @@ -59,13 +61,13 @@ def f(array, queue=None):
asinpi = _make_unary_array_func("asinpi")


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _atan2(result, arg1, arg2):
return elementwise.get_float_binary_func_kernel(
result.context, "atan2", arg1.dtype, arg2.dtype, result.dtype)


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _atan2pi(result, arg1, arg2):
return elementwise.get_float_binary_func_kernel(
result.context, "atan2pi", arg1.dtype, arg2.dtype, result.dtype)
Expand All @@ -80,7 +82,13 @@ def atan2(y, x, queue=None):
"""
queue = queue or y.queue
result = y._new_like_me(_get_common_dtype(y, x, queue))
result.add_event(_atan2(result, y, x, queue=queue))
evt = _atan2(result, y, x, queue=queue)

x.add_read_event(evt)
if y is not x:
y.add_read_event(evt)
result.add_write_event(evt)

return result


Expand All @@ -94,7 +102,13 @@ def atan2pi(y, x, queue=None):
"""
queue = queue or y.queue
result = y._new_like_me(_get_common_dtype(y, x, queue))
result.add_event(_atan2pi(result, y, x, queue=queue))
evt = _atan2pi(result, y, x, queue=queue)

x.add_read_event(evt)
if y is not x:
y.add_read_event(evt)
result.add_write_event(evt)

return result


Expand All @@ -121,7 +135,7 @@ def atan2pi(y, x, queue=None):
# TODO: fmin


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _fmod(result, arg, mod):
return elementwise.get_fmod_kernel(result.context, result.dtype,
arg.dtype, mod.dtype)
Expand All @@ -132,13 +146,19 @@ def fmod(arg, mod, queue=None):
for each element in `arg` and `mod`."""
queue = (queue or arg.queue) or mod.queue
result = arg._new_like_me(_get_common_dtype(arg, mod, queue))
result.add_event(_fmod(result, arg, mod, queue=queue))
evt = _fmod(result, arg, mod, queue=queue)

arg.add_read_event(evt)
if mod is not arg:
mod.add_read_event(evt)
result.add_write_event(evt)

return result

# TODO: fract


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _frexp(sig, expt, arg):
return elementwise.get_frexp_kernel(sig.context, sig.dtype,
expt.dtype, arg.dtype)
Expand All @@ -150,9 +170,12 @@ def frexp(arg, queue=None):
"""
sig = arg._new_like_me(queue=queue)
expt = arg._new_like_me(queue=queue, dtype=np.int32)
event1 = _frexp(sig, expt, arg, queue=queue)
sig.add_event(event1)
expt.add_event(event1)
evt = _frexp(sig, expt, arg, queue=queue)

arg.add_read_event(evt)
sig.add_write_event(evt)
expt.add_write_event(evt)

return sig, expt

# TODO: hypot
Expand All @@ -161,7 +184,7 @@ def frexp(arg, queue=None):
ilogb = _make_unary_array_func("ilogb")


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _ldexp(result, sig, exp):
return elementwise.get_ldexp_kernel(result.context, result.dtype,
sig.dtype, exp.dtype)
Expand All @@ -173,7 +196,13 @@ def ldexp(significand, exponent, queue=None):
`result = significand * 2**exponent`.
"""
result = significand._new_like_me(queue=queue)
result.add_event(_ldexp(result, significand, exponent))
evt = _ldexp(result, significand, exponent)

significand.add_read_event(evt)
if exponent is not significand:
exponent.add_read_event(evt)
result.add_write_event(evt)

return result


Expand All @@ -191,7 +220,7 @@ def ldexp(significand, exponent, queue=None):
# TODO: minmag


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _modf(intpart, fracpart, arg):
return elementwise.get_modf_kernel(intpart.context, intpart.dtype,
fracpart.dtype, arg.dtype)
Expand All @@ -203,9 +232,12 @@ def modf(arg, queue=None):
"""
intpart = arg._new_like_me(queue=queue)
fracpart = arg._new_like_me(queue=queue)
event1 = _modf(intpart, fracpart, arg, queue=queue)
fracpart.add_event(event1)
intpart.add_event(event1)
evt = _modf(intpart, fracpart, arg, queue=queue)

arg.add_read_event(evt)
fracpart.add_write_event(evt)
intpart.add_write_event(evt)

return fracpart, intpart


Expand Down Expand Up @@ -238,19 +270,19 @@ def modf(arg, queue=None):
# TODO: table 6.10, integer functions
# TODO: table 6.12, clamp et al

@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _bessel_jn(result, n, x):
return elementwise.get_bessel_kernel(result.context, "j", result.dtype,
np.dtype(type(n)), x.dtype)


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _bessel_yn(result, n, x):
return elementwise.get_bessel_kernel(result.context, "y", result.dtype,
np.dtype(type(n)), x.dtype)


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _hankel_01(h0, h1, x):
if h0.dtype != h1.dtype:
raise TypeError("types of h0 and h1 must match")
Expand All @@ -260,20 +292,31 @@ def _hankel_01(h0, h1, x):

def bessel_jn(n, x, queue=None):
result = x._new_like_me(queue=queue)
result.add_event(_bessel_jn(result, n, x, queue=queue))
evt = _bessel_jn(result, n, x, queue=queue)

x.add_read_event(evt)
result.add_write_event(evt)

return result


def bessel_yn(n, x, queue=None):
result = x._new_like_me(queue=queue)
result.add_event(_bessel_yn(result, n, x, queue=queue))
evt = _bessel_yn(result, n, x, queue=queue)

x.add_read_event(evt)
result.add_write_event(evt)

return result


def hankel_01(x, queue=None):
h0 = x._new_like_me(queue=queue)
h1 = x._new_like_me(queue=queue)
event1 = _hankel_01(h0, h1, x, queue=queue)
h0.add_event(event1)
h1.add_event(event1)
evt = _hankel_01(h0, h1, x, queue=queue)

x.add_read_event(evt)
h0.add_write_event(evt)
h1.add_write_event(evt)

return h0, h1
18 changes: 11 additions & 7 deletions pyopencl/clrandom.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,11 @@ def fill_uniform(self, ary, a=0, b=1, queue=None):
evt = knl(queue,
(self.num_work_items,), None,
self.state.data, ary.data, ary.size*size_multiplier,
b-a, a, wait_for=ary.events)
ary.add_event(evt)
self.state.add_event(evt)
b-a, a, wait_for=ary.write_events + ary.read_events)

ary.add_write_event(evt)
self.state.add_write_event(evt)

return ary

def uniform(self, *args, **kwargs):
Expand Down Expand Up @@ -368,9 +370,11 @@ def fill_normal(self, ary, mu=0, sigma=1, queue=None):
evt = knl(queue,
(self.num_work_items,), self.wg_size,
self.state.data, ary.data, ary.size*size_multiplier, sigma, mu,
wait_for=ary.events)
ary.add_event(evt)
self.state.add_event(evt)
wait_for=ary.write_events)

ary.add_write_event(evt)
self.state.add_write_event(evt)

return evt

def normal(self, *args, **kwargs):
Expand Down Expand Up @@ -666,7 +670,7 @@ def _fill(self, distribution, ary, scale, shift, queue=None):
gsize, lsize = _splay(queue.device, ary.size)

evt = knl(queue, gsize, lsize, *args)
ary.add_event(evt)
ary.add_write_event(evt)

self.counter[0] += n * counter_multiplier
c1_incr, self.counter[0] = divmod(self.counter[0], self.counter_max)
Expand Down
2 changes: 1 addition & 1 deletion pyopencl/invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def add_buf_arg(arg_idx, typechar, expr_str):
cl_arg_idx += 1

if in_enqueue:
wait_for_parts .append(f"{arg_var}.events")
wait_for_parts.append(f"{arg_var}.write_events")

continue

Expand Down
4 changes: 2 additions & 2 deletions pyopencl/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def __call__(self, *args, **kwargs):
invocation_args.append(arg.base_data)
if arg_tp.with_offset:
invocation_args.append(arg.offset)
wait_for.extend(arg.events)
wait_for.extend(arg.write_events)
else:
invocation_args.append(arg)

Expand Down Expand Up @@ -440,7 +440,7 @@ def __call__(self, *args, **kwargs):
wait_for=wait_for)
wait_for = [last_evt]

result.add_event(last_evt)
result.add_write_event(last_evt)

if group_count == 1:
if return_event:
Expand Down
4 changes: 2 additions & 2 deletions pyopencl/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1474,7 +1474,7 @@ def __call__(self, *args, **kwargs):
data_args.append(arg_val.base_data)
if arg_descr.with_offset:
data_args.append(arg_val.offset)
wait_for.extend(arg_val.events)
wait_for.extend(arg_val.write_events)
else:
data_args.append(arg_val)

Expand Down Expand Up @@ -1691,7 +1691,7 @@ def __call__(self, *args, **kwargs):
data_args.append(arg_val.base_data)
if arg_descr.with_offset:
data_args.append(arg_val.offset)
wait_for.extend(arg_val.events)
wait_for.extend(arg_val.write_events)
else:
data_args.append(arg_val)

Expand Down

0 comments on commit 957efa6

Please sign in to comment.