Definition of pullback for logpdf is overly optimistic
#121
One could keep an explicit list of supported and tested distributions in DistributionsAD, and only define it for those. However, I think a proper fix would be to change the implementation of … In general, I try not to look too carefully at the implementations in DistributionsAD: there's type piracy all over the place and, as the example shows, it can lead to all kinds of problems.
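The allow-list idea can be sketched as a toy model in plain Julia, with no packages. Under these assumptions, `toy_logpdf`, `toy_rule`, `TestedNormal`, `UserDist` and `TestedDists` are hypothetical names, and `toy_rule` only mimics rule lookup through multiple dispatch: it returns a derivative, or `nothing` for "no rule". It is not the ChainRules or DistributionsAD API.

```julia
# Toy sketch (hypothetical names, not DistributionsAD's real code): attach a
# hand-written derivative rule only to an explicit Union of tested types,
# instead of to the whole abstract type.

abstract type AbstractDist end

struct TestedNormal <: AbstractDist
    mu::Float64
end

struct UserDist <: AbstractDist        # some other package's type
    mu::Float64
    s::Float64
end

# Forward methods.
toy_logpdf(d::TestedNormal, x) = -abs2(x - d.mu) / 2
toy_logpdf(d::UserDist, x) = -abs2(x - d.mu) / (2 * d.s^2)

# Explicit allow-list of types whose rule has been checked
# (further tested types would be added to this Union).
const TestedDists = Union{TestedNormal}

# `toy_rule` stands in for rule lookup: it returns d/dx of toy_logpdf,
# or `nothing` when no hand-written rule applies.
toy_rule(f, args...) = nothing
toy_rule(::typeof(toy_logpdf), d::TestedDists, x) = -(x - d.mu)
```

Because the rule is attached to the Union rather than to `AbstractDist`, `UserDist` gets no hand-written rule, so an AD would differentiate its own method instead.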
This seems reasonable to me. Is this something that the …
I don't know, but I assume they might be fine with it. I mean, it seems reasonable to me even without Zygote 🤷
This is a hard one. Without making Distributions.jl compatible with Zygote, we need a catch-all adjoint here to serve as the adjoint of Distributions.jl's catch-all method. In your case, a workaround would be to define an adjoint for your method that calls pullback on another function name that has no adjoint.
Thinking about the bigger problem linked in that issue (I didn't read the whole issue, so I'm not sure whether this has been discussed), I think we can essentially formalise the workaround used here by adding an additional dispatch layer that lets you modify the "method-rrule matching rule". IMO, every method should have its own adjoint: if a more specific method was important enough to have in the forward pass, then it makes sense that we may need to special-case the reverse pass. But sometimes we don't need that, and a sufficiently generic reverse pass can serve as the adjoint of many forward methods.
IMO this problem can be mostly solved by giving more control to developers and perhaps changing defaults. When I define a new Julia function, I could tell ChainRules: please don't match my function using Julia's multiple-dispatch criteria, but treat it as its own thing. Under the hood, this could literally be implemented with the workaround proposed here, i.e. by defining a "bridge rule" that calls another function with no adjoint methods. We could also have an option at the rule definition site telling ChainRules not to match the rule to any forward method with a more specific signature, and instead to apply this adjoint only to the specific signatures provided. This two-pronged approach would let us mix and match "method-rrule matching rules": sometimes using normal multiple dispatch where it doesn't seem to break things, and other times matching the exact signature. More importantly, the approach proposed here lets the user opt out of the rules defined in ChainRules at the forward method's definition site.
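The bridge-rule workaround can be sketched with a similar toy model. Assumptions: all names are hypothetical, `toy_rule` mimics rule lookup by dispatch, and `toy_grad` is a stand-in "AD" that uses a matching rule if there is one and otherwise takes a central finite difference of the method that actually runs. In Zygote, the bridge would instead call pullback on the rule-free function.

```julia
# Toy sketch of the "bridge rule" workaround (all names hypothetical; this
# mimics rule lookup with plain multiple dispatch, not Zygote's internals).

abstract type AbstractDist end

struct MyDist <: AbstractDist
    mu::Float64
    s::Float64
end

toy_logpdf(d::AbstractDist, x) = -abs2(x - d.mu) / 2
toy_logpdf(d::MyDist, x) = _mylogpdf(d, x)       # forward to a fresh name
_mylogpdf(d::MyDist, x) = -abs2(x - d.mu) / (2 * d.s^2)

# Rule lookup: a catch-all "no rule", plus an optimistic rule on the
# abstract type that is only correct for the generic method.
toy_rule(f, args...) = nothing
toy_rule(::typeof(toy_logpdf), d::AbstractDist, x) = -(x - d.mu)

# Toy "AD": use a matching rule if there is one; otherwise differentiate the
# method that runs (a central difference stands in for real AD here).
function toy_grad(f, d, x; h = 1e-6)
    r = toy_rule(f, d, x)
    r === nothing || return r
    return (f(d, x + h) - f(d, x - h)) / (2h)
end

# Bridge rule: for MyDist, defer to the "AD" of `_mylogpdf`, which has no
# rule, so its body is differentiated instead of the abstract-type rule.
toy_rule(::typeof(toy_logpdf), d::MyDist, x) = toy_grad(_mylogpdf, d, x)
```

With the bridge in place, `toy_grad(toy_logpdf, MyDist(0.0, 2.0), 1.0)` gives roughly -0.25, the derivative of `MyDist`'s own method, rather than the -1.0 that the abstract-type rule would return.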
So, when I define a new function and suspect that ChainRules is hurting my performance or correctness, I can either define a correct and performant rule for my method or opt out of ChainRules for that method. Defining a new type is more complicated, because we automatically "sign in" to a few methods (and their rules) in the forward pass. So perhaps we can also provide an opt-out mechanism based on types, not just methods.
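A type-based opt-out could look roughly like the following toy sketch. The `opts_out` trait is a hypothetical mechanism invented for illustration, not an existing ChainRules feature, and `toy_rule` again only mimics rule lookup by dispatch.

```julia
# Toy sketch of a type-based opt-out (hypothetical trait, not a ChainRules
# API): rule lookup first asks the argument's type whether it opts out of
# rules written for its abstract supertype.

abstract type AbstractDist end

struct PackageDist <: AbstractDist
    mu::Float64
end

struct MyDist <: AbstractDist
    mu::Float64
    s::Float64
end

toy_logpdf(d::AbstractDist, x) = -abs2(x - d.mu) / 2
toy_logpdf(d::MyDist, x) = -abs2(x - d.mu) / (2 * d.s^2)

# Trait: by default a type accepts inherited rules; MyDist opts out.
opts_out(::Type) = false
opts_out(::Type{MyDist}) = true

# The rule as written for the abstract type.
_abstract_rule(d::AbstractDist, x) = -(x - d.mu)

# Lookup consults the trait before applying the abstract-type rule;
# `nothing` means "no rule, let the AD differentiate the method itself".
toy_rule(f, args...) = nothing
toy_rule(::typeof(toy_logpdf), d::AbstractDist, x) =
    opts_out(typeof(d)) ? nothing : _abstract_rule(d, x)
```

Keying the decision on the type leaves it with the type's author, who can opt out without writing a rule for every method the type picks up from its supertype.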
Thanks for your thoughts, @mohamed82008. I think we're on the same page with regard to the problem.
Could you elaborate with some pseudocode or example code? I'm struggling to understand what you're proposing, but I'd be keen to understand it better. I would generally be much more in favour of an opt-in mechanism. My reasoning is that we should view an inability of AD to automatically derive a rule as the norm, rather than the exception.
Defining a correct or performant rule is easy: just use ChainRules. Opting out can be done by overloading …
An opt-in mechanism (opt-out by default) would be hard to implement, though. This is because, when we check for a rule, we check the concrete types to see if there is a method in …
I think what you are really advocating for here is rule definition for "narrow" abstract types, e.g. …
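Attaching a rule to a narrow abstract type can be sketched as follows. `AbstractPackageDist`, `PkgNormal`, `UserGP` and `toy_rule` are hypothetical names, and `toy_rule` is a toy stand-in for rule lookup by dispatch, not the ChainRules API.

```julia
# Toy sketch (hypothetical type names): the package attaches its rule to a
# narrow abstract type that only its own, tested distributions subtype,
# rather than to the broad AbstractDist that other packages extend.

abstract type AbstractDist end
abstract type AbstractPackageDist <: AbstractDist end   # "narrow" type

struct PkgNormal <: AbstractPackageDist
    mu::Float64
end

struct UserGP <: AbstractDist                           # another package's type
    mu::Float64
end

toy_logpdf(d::AbstractDist, x) = -abs2(x - d.mu) / 2

# Rule lookup: catch-all "no rule", plus a rule for the narrow type only.
toy_rule(f, args...) = nothing
toy_rule(::typeof(toy_logpdf), d::AbstractPackageDist, x) = -(x - d.mu)
```

`UserGP` still uses the generic forward method but no longer matches the rule, so whatever specialised methods its author adds are differentiated on their own terms.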
Hmm, yes. This could also be done by writing an … I mean, the best way to implement an opt-in mechanism is just not to define … More generally, the symptom of what … One way of reasoning about this as type piracy is to consider that, when I wrote my code, I also implicitly "wrote" a method of … I will grant you that you could construe this either as a problem with the way that …
The DistributionsAD approach breaks all the Julia rules, and it needs to go. But this package was born out of the need to "fix" differentiation of most distributions with all of the AD packages, which meant different workarounds for different packages. Some of those "workarounds" made it back into ReverseDiff or were rewritten to use ChainRules, while others remained. In a way, defining rrules on any type we don't own is type piracy, but doing so on abstract types is especially bad, for the reason you outline. So, in summary, I am in favour of removing the method in question if removing it doesn't break anything, or if you have a better implementation.
This definition is very optimistic about what it thinks it can handle.
In particular, it hijacks control from this method in Stheno and causes AD to do something entirely inappropriate: if this rule didn't exist, my code would work just fine. It causes problems similar to type piracy; see this well-known ChainRules issue, which explains the core of the problem.
TL;DR: defining rules for abstract types causes problems. Since we need to be able to work with abstract types at the moment, you have to be really careful about which abstract types you implement rules for.
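The failure mode can be reproduced in a self-contained toy model. Assumptions: `MyDist` plays the role of the Stheno type, all names are hypothetical, `toy_rule` mimics rule lookup by dispatch, and `toy_grad` is a stand-in "AD" that uses a matching rule if one exists and otherwise takes a central finite difference. The abstract-type rule is correct for the generic method but not for the specialised one.

```julia
# Toy model (hypothetical, not Zygote/ChainRules internals): rule lookup
# uses the same multiple dispatch as the forward pass, so a rule written for
# the abstract type also matches a subtype with its own specialised method.

abstract type AbstractDist end

struct MyDist <: AbstractDist   # stands in for a Stheno-style type
    mu::Float64
    s::Float64
end

# Forward pass: generic method, plus a specialised method for MyDist.
toy_logpdf(d::AbstractDist, x) = -abs2(x - d.mu) / 2
toy_logpdf(d::MyDist, x) = -abs2(x - d.mu) / (2 * d.s^2)

# Rule lookup: catch-all "no rule", plus the optimistic abstract-type rule,
# which is only correct for the generic method above.
toy_rule(f, args...) = nothing
toy_rule(::typeof(toy_logpdf), d::AbstractDist, x) = -(x - d.mu)

# Toy "AD": use a matching rule if there is one; otherwise differentiate the
# method that actually runs (a central difference stands in for AD here).
function toy_grad(f, d, x; h = 1e-6)
    r = toy_rule(f, d, x)
    r === nothing || return r
    return (f(d, x + h) - f(d, x - h)) / (2h)
end

d, x = MyDist(0.0, 2.0), 1.0
hijacked = toy_grad(toy_logpdf, d, x)   # abstract-type rule wins: -1.0
correct = (toy_logpdf(d, x + 1e-6) - toy_logpdf(d, x - 1e-6)) / 2e-6   # about -0.25
```

Deleting the abstract-type rule would make `toy_grad` fall through to the method that actually runs, which is the "my code would work just fine" situation described above.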
@mohamed82008, any thoughts on how this implementation could be made less aggressive? It's currently blocking Stheno-Turing integration, and it is related to this issue.