Namespace-aware xarray.ufuncs
#9776
Conversation
```python
elif isinstance(obj, array_type("dask")):
    _walk_array_namespaces(obj._meta, namespaces)
else:
    namespace = getattr(obj, "__array_namespace__", None)
```
Should we ever prioritize dispatching with `np.func` via `__array_ufunc__` (if it exists) over the library's `__array_namespace__().func`?
`__array_ufunc__` is a more generic protocol, intended to support arbitrary new ufuncs without requiring an array library to be aware of them.

In practice there are very few examples of ufuncs defined outside of NumPy itself, and we wouldn't need to support them here because we are explicitly listing supported ufuncs in this module. I guess the one example that comes to mind would be the rare cases where NumPy adds a new ufunc and an array-wrapping library like Dask hasn't written a wrapper yet.

At this point, I think going "all in" on `__array_namespace__` is the right call.
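The two protocols being weighed here can be illustrated with a minimal sketch. `DuckArray` and its contents are invented for illustration only; a real duck array would return its own namespace rather than plain numpy:

```python
import numpy as np

class DuckArray:
    def __init__(self, data):
        self.data = np.asarray(data)

    # Array API standard: hand back the namespace of array functions.
    def __array_namespace__(self, api_version=None):
        return np

    # NumPy's generic protocol: intercept any ufunc applied to this object.
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        arrays = [x.data if isinstance(x, DuckArray) else x for x in inputs]
        return DuckArray(getattr(ufunc, method)(*arrays, **kwargs))

x = DuckArray([0.0, np.pi / 2])

# Dispatching "all in" on __array_namespace__, as suggested above:
xp = x.__array_namespace__()
via_namespace = xp.sin(x.data)

# Dispatching via __array_ufunc__ (np.sin finds the method on DuckArray):
via_ufunc = np.sin(x)
```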
I guess my other comment would be the main reason to consider `__array_ufunc__`. Some duck arrays don't implement all ufuncs, so either of these approaches would solve the same problem.
xarray/ufuncs.py
Outdated
```python
)
func = getattr(np, self._name)

return xr.apply_ufunc(func, *args, dask="parallelized", **kwargs)
```
Is there ever a reason to use dask's ufuncs with `dask="allowed"` instead of the appropriate `_meta` array's namespace and `dask="parallelized"`? With `jax` for example, which doesn't have `__array_ufunc__`, this ends up converting to `numpy`, so it would have to be special-cased.
In user code using `xr.apply_ufunc` there is: `dask='allowed'` can be used to rechunk along a core dimension, e.g. by applying a dask reduction ufunc along that dimension. Not sure if that's relevant here though.
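For reference, the two `dask=` modes under discussion can be sketched as follows (this assumes dask is installed; `np.sin` stands in for any elementwise ufunc):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.linspace(0.0, 1.0, 6), dims="x").chunk({"x": 3})

# dask="parallelized": the numpy function is mapped over each chunk.
parallelized = xr.apply_ufunc(
    np.sin, da, dask="parallelized", output_dtypes=[da.dtype]
)

# dask="allowed": the function receives the whole dask array and must
# handle it itself; np.sin works here because dask arrays implement
# __array_ufunc__.
allowed = xr.apply_ufunc(np.sin, da, dask="allowed")

expected = np.sin(np.linspace(0.0, 1.0, 6))
```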
These are all elementwise, so there are no core dimensions.
> Is there ever a reason to use dask's ufuncs with `dask="allowed"` instead of the appropriate `_meta` array's namespace and `dask="parallelized"`?

Yes, this feels like a cleaner solution to me.

> With `jax` for example, which doesn't have `__array_ufunc__`, this ends up converting to `numpy`. So it would have to be special-cased.

Is the concern here Dask wrapping JAX?

Generally, I think it's best for Xarray to avoid introspecting into wrapped array types, and to leave nested wrapping up to other libraries, which can better understand their own implementation details. So Dask wrapping JAX should be fixed in Dask.
> So Dask wrapping JAX should be fixed in Dask.

Totally fair. It looks like basically the same effort would be required in dask then, because dask's ufuncs are all simple wrappers around the numpy versions, so they aren't aware of the namespace.
xarray/ufuncs.py
Outdated
```python
# Auto generate from the public numpy ufuncs
np_ufuncs = {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)}
```
Can hard code these if preferred?
Yes, I would suggest hard coding these if possible, ideally as something like:

```python
sin = _unary_ufunc('sin')
add = _binary_ufunc('add')
...
```

The reason to prefer this is that static typing does not evaluate loops or dynamically defined functions, so otherwise `xarray.ufuncs.sin()` will not be recognized as valid by tools like mypy.

Ideally (could be done in a follow-up PR), these functions could be annotated to follow the appropriate type casting rules, so type checkers would know that `xarray.ufuncs.sin(dataset)` returns another `Dataset`.

In some cases, we use a script to generate all these special methods. I don't think that should be necessary here, but it still may be worth a look to understand the type casting rules:
https://github.com/pydata/xarray/blob/d5f84dd1ef4c023cf2ea0a38866c9d9cd50487e7/xarray/util/generate_ops.py
https://github.com/pydata/xarray/blob/d5f84dd1ef4c023cf2ea0a38866c9d9cd50487e7/xarray/core/_typed_ops.py
xarray/ufuncs.py
Outdated
```python
# Auto generate from the public numpy ufuncs
np_ufuncs = {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)}
excluded_ufuncs = {"divmod", "frexp", "isnat", "matmul", "modf", "vecdot"}
```
These are the ones that didn't immediately work. There are also other ufunc-like things that aren't technically `np.ufunc` instances that we could add. I saw `angle` and `iscomplex` were special-cased before.
Maybe worth noting that the reason `matmul` and `vecdot` don't work is that they are "generalized ufuncs" that use core dimensions. `divmod`, `frexp` and `modf` don't work because they return multiple arrays.

I'm not sure why `isnat` didn't work for you. Did you test it with datetime dtypes?
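The distinctions described above can be checked directly from the ufunc attributes (`vecdot` is omitted here since it only exists in newer NumPy releases):

```python
import numpy as np

# matmul is a "generalized ufunc": it carries a core-dimension signature
# instead of operating elementwise.
assert np.matmul.signature == "(n?,k),(k,m?)->(n?,m?)"

# divmod, frexp and modf return multiple arrays (nout > 1).
assert np.divmod.nout == 2
assert np.frexp.nout == 2
assert np.modf.nout == 2

# A plain elementwise ufunc for comparison:
assert np.sin.nout == 1 and np.sin.signature is None

# isnat itself is a normal elementwise ufunc, but it only accepts
# datetime/timedelta dtypes:
assert bool(np.isnat(np.datetime64("NaT")))
```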
> Did you test it with datetime dtypes?

No, this was just a real quick initial pass. Will add this with a special test case.

Do you have an opinion about adding any of the ones with multiple return values? Seems low priority to me. Same question for the oddballs like `angle`, `iscomplex`, `isreal`, etc.?
xarray/ufuncs.py
Outdated
```python
if func is None:
    warnings.warn(
        f"Function {self._name} not found in {xp.__name__}, falling back to numpy",
        stacklevel=2,
    )
    func = getattr(np, self._name)
```
I would lean towards skipping this fall-back, unless there are particularly motivating cases.
Motivation here would be duck arrays that implement `__array_ufunc__` but don't implement the full suite of numpy ufuncs. I ran into this with sparse. Not sure of the full delta list, but I see they don't have sin/cos for example. In this case, `np.cos(x_sparse)` works but `xp.cos(x_sparse)` fails, which is a little weird. Not the most elegant solution though, I agree.
I think it's intentional that `xp.cos(x_sparse)` fails, because `cos(0) != 0`, so the result is no longer sparse.
Actually I was wrong, there is a `sparse.cos` and a bunch of others, although they don't appear in the API docs. It seems sparse's general approach to these is to compute elementwise on the valid data, and then modify the `fill_value` as required for the empty regions.

There are still 40-some functions that fail without this fallback, although they are generally more niche. With the fallback, all of them work and output a sparse array (no auto-densification):

absolute, arccos, arccosh, arcsin, arcsinh, arctan, arctan2, arctanh, bitwise_count, cbrt, conj, conjugate, copysign, deg2rad, degrees, exp2, expm1, fabs, float_power, fmax, fmin, fmod, gcd, heaviside, hypot, invert, isreal, lcm, ldexp, left_shift, logaddexp2, maximum, minimum, mod, nextafter, power, rad2deg, radians, reciprocal, right_shift, rint, signbit, spacing, true_divide
xarray/ufuncs.py
Outdated
```python
elif isinstance(obj, array_type("dask")):
    _walk_array_namespaces(obj._meta, namespaces)
```
Why not use Dask's `__array_namespace__` instead? That feels a little cleaner than special-case logic for dask.array.
xarray/ufuncs.py
Outdated
```python
def __init__(self, name):
    self._name = name

def __call__(self, x, **kwargs):
```
@shoyer lmk if this is what you had in mind for separating unary from binary funcs. I assume this is for typing purposes.
Close! I added a more detailed note below.
xarray/ufuncs.py
Outdated
```python
def _create_op(name):
    if not hasattr(np, name):
        # handle older numpy versions with missing array api standard aliases
        if np.lib.NumpyVersion(np.__version__) < "2.0.0":
            return _UnavailableUfunc(name)
        raise ValueError(f"'{name}' is not a valid numpy function")

    np_func = getattr(np, name)
    if hasattr(np_func, "nin") and np_func.nin == 2:
        func = _BinaryUfunc(name)
    else:
        func = _UnaryUfunc(name)

    func.__name__ = name
    doc = getattr(np, name).__doc__

    doc = _remove_unused_reference_labels(_skip_signature(_dedent(doc), name))

    func.__doc__ = (
        f"xarray specific variant of numpy.{name}. Handles "
        "xarray objects by dispatching to the appropriate "
        "function for the underlying array type.\n\n"
        f"Documentation from numpy:\n\n{doc}"
    )
    return func
```
Type checkers can't evaluate code at runtime, so you really want to write something that prescribes the type signature statically, like:

```python
abs = _UnaryUfunc('abs')
```

or

```python
def _unary_ufunc(name: str) -> _UnaryUfunc:
    func = _UnaryUfunc(name)
    func.__doc__ = ...
    return func

abs = _unary_ufunc('abs')
```

When you write `abs = _create_op('abs')`, type checkers think the type of `abs` could be any of `_UnavailableUfunc` or `_BinaryUfunc` or `_UnaryUfunc`. (You can check this with `reveal_type(abs)` if you're curious.)
This looks great! Could you please also add a brief note to `whats-new.rst`?
Yes will do. Shall we add these back to the main api docs as well, same as before minus the deprecation notice?
Yes, please!
* main: (24 commits)
  - Bump minimum versions (#9796)
  - Namespace-aware `xarray.ufuncs` (#9776)
  - Add prettier and pygrep hooks to pre-commit hooks (#9644)
  - `rolling.construct`: Add `sliding_window_kwargs` to pipe arguments down to `sliding_window_view` (#9720)
  - Bump codecov/codecov-action from 4.6.0 to 5.0.2 in the actions group (#9793)
  - Buffer types (#9787)
  - Add download stats badges (#9786)
  - Fix open_mfdataset for list of fsspec files (#9785)
  - add 'User-Agent'-header to pooch.retrieve (#9782)
  - Optimize `ffill`, `bfill` with dask when `limit` is specified (#9771)
  - fix cf decoding of grid_mapping (#9765)
  - Allow wrapping `np.ndarray` subclasses (#9760)
  - Optimize polyfit (#9766)
  - Use `map_overlap` for rolling reductions with Dask (#9770)
  - fix html repr indexes section (#9768)
  - Bump pypa/gh-action-pypi-publish from 1.11.0 to 1.12.2 in the actions group (#9763)
  - unpin array-api-strict, as issues are resolved upstream (#9762)
  - rewrite the `min_deps_check` script (#9754)
  - CI runs ruff instead of pep8speaks (#9759)
  - Specify copyright holders in main license file (#9756)
  - ...
whats-new.rst
api.rst

Re-implement the old `xarray.ufuncs` module to allow generic ufunc handling for array types that don't implement `__array_ufunc__`.