Catching forgotten backend-specific imports #19

@stoprightthere

Hi,

At geometric_kernels, we've been having issues with people forgetting to import the modules that define the backend-specific functions, e.g. `import geometric_kernels.torch`. In that case the error messages are currently rather cryptic and leave people confused.

I think it would be useful to have a "catch-all" fallback that reminds the user to import the backend-specific code when they call a dispatched function with numeric inputs for which no backend-specific method has been registered. Inputs that are not numeric at all should still raise `NotFoundLookupError`.
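To make the intended behaviour concrete without depending on plum, here is a rough analogue using the standard library's `functools.singledispatch`, a different (single-argument) dispatcher swapped in purely for illustration. The classes `Numeric`, `NPArray`, `TorchArray` and the exception `MissingBackendError` are hypothetical stand-ins for lab's types, not real APIs.

```python
from functools import singledispatch


class Numeric:
    """Hypothetical base class standing in for B.Numeric."""

    def __init__(self, value):
        self.value = value


class NPArray(Numeric):
    """Hypothetical stand-in for a NumPy array (B.NPNumeric)."""


class TorchArray(Numeric):
    """Hypothetical stand-in for a PyTorch tensor."""


class MissingBackendError(RuntimeError):
    """Hypothetical error for numeric inputs whose backend was never loaded."""


@singledispatch
def ff(x):
    # Non-numeric inputs: plays the role of plum's NotFoundLookupError.
    raise TypeError(f"ff({x!r}) could not be resolved.")


@ff.register
def _(x: Numeric):
    # Catch-all for numeric inputs with no more specific method registered.
    raise MissingBackendError(
        "Did you forget to do `import geometric_kernels.<backend>`?"
    )


# In practice this registration would live in the backend-specific module
# and only run once that module is imported.
@ff.register
def _(x: NPArray):
    return NPArray(x.value + 1)
```

Since `singledispatch` picks the most specific registered class in the argument's MRO, `NPArray` inputs reach the real method, `TorchArray` inputs hit the reminder, and a string falls through to the default lookup error. Plum generalises this to all arguments, which is what the proposal relies on.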

Here is my own half-baked attempt at this:

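```python
import lab as B
import numpy as np
import plum
import torch
from plum import dispatch


def promised_dispatch(error_msg, precedence=0):
    """
    Decorator for a "promised" function. The implementation is not given yet,
    but it will be (e.g. in a separate module).
    """
    def wrap(f):
        # Create the dispatched function without registering `f` as a method.
        _f = dispatch.abstract(f)
        # Reuse the type annotations of `f` as the fallback's signature.
        signature = plum.Signature.from_callable(f, precedence=precedence)

        def _fallback_f(*args, **kwargs):
            # Plum only resolves to this method if the arguments match
            # `signature` and no more specific method is registered, which
            # typically means the backend-specific module was not imported.
            raise RuntimeError(error_msg)

        _f.register(_fallback_f, signature, precedence)
        return _f
    return wrap


@promised_dispatch("Did you forget to do `import geometric_kernels.<backend>`?")
def ff(x: B.Numeric):
    pass


@dispatch
def ff(x: B.NPNumeric):
    return x + 1


ff(np.r_[3])           # returns array([4])

ff(torch.tensor([3]))  # raises RuntimeError: "Did you forget ...?"

ff('3')  # raises NotFoundLookupError: `ff('3')` could not be resolved.
```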
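One possible refinement, sketched here under my own assumptions: instead of the generic `<backend>` placeholder, the fallback could name the backend it infers from the argument types. `missing_backend_message` is a hypothetical helper, not part of lab or plum, and it assumes the top-level module of each argument's type matches the backend module name.

```python
def missing_backend_message(args, package="geometric_kernels"):
    """Hypothetical helper: build a reminder naming the inferred backend(s).

    The backend is guessed from the top-level module in which each argument's
    type is defined (e.g. ``torch.Tensor`` -> ``"torch"``); builtins are ignored.
    """
    backends = sorted(
        {type(arg).__module__.split(".")[0] for arg in args} - {"builtins"}
    )
    if not backends:
        # Nothing to infer from: fall back to the generic placeholder.
        return f"Did you forget to do `import {package}.<backend>`?"
    imports = " or ".join(f"`import {package}.{name}`" for name in backends)
    return f"Did you forget to do {imports}?"
```

Inside `_fallback_f` this could be used as `raise RuntimeError(missing_backend_message(args))`. The module of an array type does not always match the backend module name, so a small lookup table may be needed in practice.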

Do you think it would be useful to incorporate something like this into lab? I think it would considerably reduce the confusion.
