Branching behaviour depending on the device used #23716
AdrienCorenflos asked this question in Q&A
Hi,
What is the preferred way for a user to branch behaviour based on the kind of device the input is located on?
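For context, a minimal sketch (assuming a recent JAX release) of the two usual ways to ask where things run. `jax.default_backend()` returns the platform JAX compiles for by default, while `.devices()` on a concrete `jax.Array` returns the devices holding it. Neither works as a switch inside `jax.jit`: `.devices()` is only meaningful on concrete arrays, not on the tracers seen during tracing, and `jax.default_backend()` says nothing about the input itself.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(4.0)

# Platform of the backend JAX uses by default, e.g. "cpu", "gpu" or "tpu".
backend = jax.default_backend()

# For a concrete (non-traced) array, .devices() returns the set of devices
# holding its data; each device exposes a .platform string.
platforms = {d.platform for d in x.devices()}

print(backend, platforms)
```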
To ground the question, say I am implementing the following trinket function:
for which I know the values `cs` and `us` are sorted. As per the documentation of `searchsorted`, `method="scan"` tends to be faster on CPU, while `method="sort"` is often faster on accelerators such as GPU and TPU.
So I'd like to automatically use `jnp.searchsorted(..., method="scan")` on CPU and `jnp.searchsorted(..., method="sort")` otherwise.
I can see (although I don't yet understand how I would use it) that, at its lower levels, JAX registers platform-specific (CPU/GPU) MLIR lowerings for its primitives. Is this also the preferred interface for end users, or are there simpler alternatives?
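One candidate alternative that stays at the user level is `jax.lax.platform_dependent`, available in recent JAX releases. Here is a minimal sketch, assuming that function is present in your installed version. The helper name `sorted_searchsorted` and the toy inputs are hypothetical, since the original function is not shown. Each branch receives the same positional arguments, and the branch is chosen when the function is lowered for a concrete platform.

```python
import jax
import jax.numpy as jnp
from jax import lax


def sorted_searchsorted(cs, us):
    """Hypothetical helper: searchsorted of `us` into sorted `cs`,
    picking the `method` per platform at lowering time."""

    def on_cpu(cs, us):
        # "scan" is documented to tend to be faster on CPU.
        return jnp.searchsorted(cs, us, method="scan")

    def elsewhere(cs, us):
        # "sort" is documented to often be faster on GPU/TPU.
        return jnp.searchsorted(cs, us, method="sort")

    # Both branches must return arrays of the same shape and dtype.
    return lax.platform_dependent(cs, us, cpu=on_cpu, default=elsewhere)


cs = jnp.array([0.0, 1.0, 2.0, 3.0])
us = jnp.array([0.5, 2.5, 3.5])
idx = jax.jit(sorted_searchsorted)(cs, us)
print(idx)  # [1 3 4]
```

A plain Python `if jax.default_backend() == "cpu":` inside the function would also work in simple single-backend setups, but it is evaluated once at trace time and reflects only the default backend. `platform_dependent` instead defers the choice to lowering, which is closer to how JAX's own per-platform MLIR lowerings behave.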
Thanks