Skip to content

Commit

Permalink
use read/write_events everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Dec 28, 2022
1 parent 6ac7b5e commit ad575c8
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 58 deletions.
14 changes: 7 additions & 7 deletions pyopencl/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _extract_extra_args_types_values(extra_args):
if isinstance(val, cl.array.Array):
extra_args_types.append(VectorArg(val.dtype, name, with_offset=False))
extra_args_values.append(val)
extra_wait_for.extend(val.events)
extra_wait_for.extend(val.write_events)
elif isinstance(val, np.generic):
extra_args_types.append(ScalarArg(val.dtype, name))
extra_args_values.append(val)
Expand Down Expand Up @@ -1161,7 +1161,7 @@ def __call__(self, queue, n_objects, *args, **kwargs):
data_args.append(arg_val.base_data)
if arg_descr.with_offset:
data_args.append(arg_val.offset)
wait_for.extend(arg_val.events)
wait_for.extend(arg_val.write_events)
else:
data_args.append(arg_val)

Expand All @@ -1180,7 +1180,7 @@ def __call__(self, queue, n_objects, *args, **kwargs):
counts = cl.array.empty(queue,
(n_objects + 1), index_dtype, allocator=allocator)
counts[-1] = 0
wait_for = wait_for + counts.events
wait_for = wait_for + counts.write_events

# The scan will turn the "counts" array into the "starts" array
# in-place.
Expand Down Expand Up @@ -1233,7 +1233,7 @@ def __call__(self, queue, n_objects, *args, **kwargs):
info_record.nonempty_indices,
info_record.compressed_indices,
info_record.num_nonempty_lists,
wait_for=[count_event] + info_record.compressed_indices.events)
wait_for=[count_event] + info_record.compressed_indices.write_events)

info_record.starts = compressed_counts

Expand Down Expand Up @@ -1262,13 +1262,13 @@ def __call__(self, queue, n_objects, *args, **kwargs):
evt = scan_kernel(
starts_ary,
size=info_record.num_nonempty_lists,
wait_for=starts_ary.events)
wait_for=starts_ary.write_events)
else:
evt = scan_kernel(starts_ary, wait_for=[count_event],
size=n_objects)

starts_ary.setitem(0, 0, queue=queue, wait_for=[evt])
scan_events.extend(starts_ary.events)
scan_events.extend(starts_ary.write_events)

# retrieve count
info_record.count = int(starts_ary[-1].get())
Expand Down Expand Up @@ -1430,7 +1430,7 @@ def __call__(self, queue, keys, values, nkeys,

starts = (cl.array.empty(queue, (nkeys+1), starts_dtype, allocator=allocator)
.fill(len(values_sorted_by_key), wait_for=[evt]))
evt, = starts.events
evt, = starts.write_events

evt = knl_info.start_finder(starts, keys_sorted_by_key,
range=slice(len(keys_sorted_by_key)),
Expand Down
2 changes: 1 addition & 1 deletion pyopencl/bitonic_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __call__(self, arr, idx=None, queue=None, wait_for=None, axis=0):

if wait_for is None:
wait_for = []
wait_for = wait_for + arr.events
wait_for = wait_for + arr.write_events

last_evt = cl.enqueue_marker(queue, wait_for=wait_for)

Expand Down
50 changes: 21 additions & 29 deletions pyopencl/clmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,13 @@
THE SOFTWARE.
"""

import pyopencl.array as cl_array
import pyopencl.elementwise as elementwise
from pyopencl.array import _get_common_dtype
from pyopencl.array import elwise_kernel_runner, _get_common_dtype
import numpy as np


def _make_unary_array_func(name):
@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def knl_runner(result, arg):
if arg.dtype.kind == "c":
from pyopencl.elementwise import complex_dtype_to_name
Expand All @@ -42,8 +41,7 @@ def knl_runner(result, arg):

def f(array, queue=None):
result = array._new_like_me(queue=queue)
event1 = knl_runner(result, array, queue=queue)
result.add_event(event1)
knl_runner(result, array, queue=queue)
return result

return f
Expand All @@ -59,13 +57,13 @@ def f(array, queue=None):
asinpi = _make_unary_array_func("asinpi")


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _atan2(result, arg1, arg2):
return elementwise.get_float_binary_func_kernel(
result.context, "atan2", arg1.dtype, arg2.dtype, result.dtype)


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _atan2pi(result, arg1, arg2):
return elementwise.get_float_binary_func_kernel(
result.context, "atan2pi", arg1.dtype, arg2.dtype, result.dtype)
Expand All @@ -80,7 +78,7 @@ def atan2(y, x, queue=None):
"""
queue = queue or y.queue
result = y._new_like_me(_get_common_dtype(y, x, queue))
result.add_event(_atan2(result, y, x, queue=queue))
_atan2(result, y, x, queue=queue)
return result


Expand All @@ -94,7 +92,7 @@ def atan2pi(y, x, queue=None):
"""
queue = queue or y.queue
result = y._new_like_me(_get_common_dtype(y, x, queue))
result.add_event(_atan2pi(result, y, x, queue=queue))
_atan2pi(result, y, x, queue=queue)
return result


Expand All @@ -121,7 +119,7 @@ def atan2pi(y, x, queue=None):
# TODO: fmin


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _fmod(result, arg, mod):
return elementwise.get_fmod_kernel(result.context, result.dtype,
arg.dtype, mod.dtype)
Expand All @@ -132,13 +130,13 @@ def fmod(arg, mod, queue=None):
for each element in ``arg`` and ``mod``."""
queue = (queue or arg.queue) or mod.queue
result = arg._new_like_me(_get_common_dtype(arg, mod, queue))
result.add_event(_fmod(result, arg, mod, queue=queue))
_fmod(result, arg, mod, queue=queue)
return result

# TODO: fract


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _frexp(sig, expt, arg):
return elementwise.get_frexp_kernel(sig.context, sig.dtype,
expt.dtype, arg.dtype)
Expand All @@ -150,9 +148,7 @@ def frexp(arg, queue=None):
"""
sig = arg._new_like_me(queue=queue)
expt = arg._new_like_me(queue=queue, dtype=np.int32)
event1 = _frexp(sig, expt, arg, queue=queue)
sig.add_event(event1)
expt.add_event(event1)
_frexp(sig, expt, arg, queue=queue, noutputs=2)
return sig, expt

# TODO: hypot
Expand All @@ -161,7 +157,7 @@ def frexp(arg, queue=None):
ilogb = _make_unary_array_func("ilogb")


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _ldexp(result, sig, exp):
return elementwise.get_ldexp_kernel(result.context, result.dtype,
sig.dtype, exp.dtype)
Expand All @@ -173,7 +169,7 @@ def ldexp(significand, exponent, queue=None):
``result = significand * 2**exponent``.
"""
result = significand._new_like_me(queue=queue)
result.add_event(_ldexp(result, significand, exponent))
_ldexp(result, significand, exponent)
return result


Expand All @@ -191,7 +187,7 @@ def ldexp(significand, exponent, queue=None):
# TODO: minmag


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _modf(intpart, fracpart, arg):
return elementwise.get_modf_kernel(intpart.context, intpart.dtype,
fracpart.dtype, arg.dtype)
Expand All @@ -203,9 +199,7 @@ def modf(arg, queue=None):
"""
intpart = arg._new_like_me(queue=queue)
fracpart = arg._new_like_me(queue=queue)
event1 = _modf(intpart, fracpart, arg, queue=queue)
fracpart.add_event(event1)
intpart.add_event(event1)
_modf(intpart, fracpart, arg, queue=queue, noutputs=2)
return fracpart, intpart


Expand Down Expand Up @@ -238,19 +232,19 @@ def modf(arg, queue=None):
# TODO: table 6.10, integer functions
# TODO: table 6.12, clamp et al

@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _bessel_jn(result, n, x):
return elementwise.get_bessel_kernel(result.context, "j", result.dtype,
np.dtype(type(n)), x.dtype)


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _bessel_yn(result, n, x):
return elementwise.get_bessel_kernel(result.context, "y", result.dtype,
np.dtype(type(n)), x.dtype)


@cl_array.elwise_kernel_runner
@elwise_kernel_runner
def _hankel_01(h0, h1, x):
if h0.dtype != h1.dtype:
raise TypeError("types of h0 and h1 must match")
Expand All @@ -260,20 +254,18 @@ def _hankel_01(h0, h1, x):

def bessel_jn(n, x, queue=None):
result = x._new_like_me(queue=queue)
result.add_event(_bessel_jn(result, n, x, queue=queue))
_bessel_jn(result, n, x, queue=queue)
return result


def bessel_yn(n, x, queue=None):
result = x._new_like_me(queue=queue)
result.add_event(_bessel_yn(result, n, x, queue=queue))
_bessel_yn(result, n, x, queue=queue)
return result


def hankel_01(x, queue=None):
h0 = x._new_like_me(queue=queue)
h1 = x._new_like_me(queue=queue)
event1 = _hankel_01(h0, h1, x, queue=queue)
h0.add_event(event1)
h1.add_event(event1)
_hankel_01(h0, h1, x, queue=queue, noutputs=2)
return h0, h1
22 changes: 6 additions & 16 deletions pyopencl/clrandom.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,11 @@ def fill_uniform(self, ary, a=0, b=1, queue=None):
evt = knl(queue,
(self.num_work_items,), None,
self.state.data, ary.data, ary.size*size_multiplier,
b-a, a, wait_for=ary.events)
ary.add_event(evt)
self.state.add_event(evt)
b-a, a, wait_for=ary.write_events + ary.read_events)

ary.add_write_event(evt)
self.state.add_write_event(evt)

return ary

def uniform(self, *args, **kwargs):
Expand Down Expand Up @@ -369,23 +371,11 @@ def fill_normal(self, ary, mu=0, sigma=1, queue=None):
evt = knl(queue,
(self.num_work_items,), self.wg_size,
self.state.data, ary.data, ary.size*size_multiplier, sigma, mu,
<<<<<<< Updated upstream
wait_for=ary.events)
ary.add_event(evt)
self.state.add_event(evt)
||||||| Stash base
wait_for=ary.write_events)

ary.add_write_event(evt)
self.state.add_write_event(evt)

=======
wait_for=ary.write_events + ary.read_events)

ary.add_write_event(evt)
self.state.add_write_event(evt)

>>>>>>> Stashed changes
return evt

def normal(self, *args, **kwargs):
Expand Down Expand Up @@ -681,7 +671,7 @@ def _fill(self, distribution, ary, scale, shift, queue=None):
gsize, lsize = _splay(queue.device, ary.size)

evt = knl(queue, gsize, lsize, *args)
ary.add_event(evt)
ary.add_write_event(evt)

self.counter[0] += n * counter_multiplier
c1_incr, self.counter[0] = divmod(self.counter[0], self.counter_max)
Expand Down
2 changes: 1 addition & 1 deletion pyopencl/invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def add_buf_arg(arg_idx, typechar, expr_str):
cl_arg_idx += 1

if in_enqueue:
wait_for_parts .append(f"{arg_var}.events")
wait_for_parts.append(f"{arg_var}.write_events")

continue

Expand Down
4 changes: 2 additions & 2 deletions pyopencl/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
invocation_args.append(arg.base_data)
if arg_tp.with_offset:
invocation_args.append(arg.offset)
wait_for.extend(arg.events)
wait_for.extend(arg.write_events)
else:
invocation_args.append(arg)

Expand Down Expand Up @@ -517,7 +517,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
wait_for=wait_for)
wait_for = [last_evt]

result.add_event(last_evt)
result.add_write_event(last_evt)

if group_count == 1:
if return_event:
Expand Down
4 changes: 2 additions & 2 deletions pyopencl/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1523,7 +1523,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
data_args.append(arg_val.base_data)
if arg_descr.with_offset:
data_args.append(arg_val.offset)
wait_for.extend(arg_val.events)
wait_for.extend(arg_val.write_events)
else:
data_args.append(arg_val)

Expand Down Expand Up @@ -1750,7 +1750,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
data_args.append(arg_val.base_data)
if arg_descr.with_offset:
data_args.append(arg_val.offset)
wait_for.extend(arg_val.events)
wait_for.extend(arg_val.write_events)
else:
data_args.append(arg_val)

Expand Down

0 comments on commit ad575c8

Please sign in to comment.