Hi,
At geometric_kernels, we've been having issues with people forgetting to import the modules that define backend-specific functions, for example `import geometric_kernels.torch`. Currently the error messages in that case are rather cryptic and leave people confused.
I think it would be useful to have a "catch-all" dispatcher that reminds the user to import the backend-specific code if they call a dispatched function with numeric inputs for which no implementation is loaded. It should still raise `NotFoundLookupError` if the inputs are completely off (e.g. not numeric at all).
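To make the requested behaviour concrete independently of plum and lab, here is a minimal sketch using only the standard library's `functools.singledispatch`. It is an illustration under stated assumptions, not lab's or plum's API: the `promised` decorator, the local `NotFoundLookupError` class (a stand-in for plum's exception), and the use of `numbers.Number` in place of `B.Numeric` and `int` in place of a backend type are all hypothetical.

```python
import functools
import numbers


class NotFoundLookupError(LookupError):
    """Hypothetical stand-in for plum's NotFoundLookupError."""


def promised(error_msg, promised_type):
    """Hypothetical decorator: implementations for `promised_type` are
    expected to be registered later, e.g. by a backend-specific module."""

    def wrap(f):
        @functools.singledispatch
        def dispatcher(x, *args, **kwargs):
            # Argument is not even of the promised type: a genuine lookup failure.
            raise NotFoundLookupError(f"`{f.__name__}({x!r})` could not be resolved.")

        @dispatcher.register(promised_type)
        def _fallback(x, *args, **kwargs):
            # Promised type, but no backend implementation registered yet.
            raise RuntimeError(error_msg)

        return dispatcher

    return wrap


@promised("Did you forget to do `import geometric_kernels.<backend>`?", numbers.Number)
def ff(x):
    """Add one to `x`."""


# Stand-in for what a backend module would register once it is imported.
@ff.register(int)
def _(x):
    return x + 1
```

With this sketch, `ff(3)` returns `4`, `ff(2.5)` raises the `RuntimeError` reminder (a `float` is a `Number` with no implementation registered), and `ff("3")` raises `NotFoundLookupError`. Unlike plum, `singledispatch` only looks at the type of the first argument, so the sketch does not cover multi-argument signatures.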
I came up with my own half-baked attempt at this:
```python
import lab as B
import numpy as np
import plum
import torch
from plum import dispatch

def promised_dispatch(error_msg, precedence=0):
    """
    Decorator for a "promised" function. The implementation is not given yet,
    but it will be (e.g. in a separate module).
    """
    def wrap(f):
        # Declare `f` as abstract: it has no implementation of its own.
        _f = dispatch.abstract(f)
        signature = plum.Signature.from_callable(f, precedence=precedence)

        def _fallback_f(*args, **kwargs):
            # Only reached when no more specific (backend) method matches.
            if signature.match(args):
                raise RuntimeError(error_msg)

        # Register the fallback under the promised function's own signature.
        _f.register(_fallback_f, signature, precedence)
        return _f
    return wrap

@promised_dispatch("Did you forget to do `import geometric_kernels.<backend>`?")
def ff(x: B.Numeric):
    pass

@dispatch
def ff(x: B.NPNumeric):
    return x + 1

ff(np.r_[3])  # returns array([4])
ff(torch.tensor([3]))  # raises RuntimeError: "Did you forget ..."
ff('3')  # raises NotFoundLookupError: `ff('3')` could not be resolved.
```

Do you think it'd be useful to incorporate something like this into lab? I think it would definitely reduce the amount of confusion.