Plans for autodiff on fitted models? #220
Comments
Yeah. I think this is a very noble goal but indeed challenging. Still, given the nobility of the goal, I think it's definitely worth scoping out where the issues lie. I'm not sure what the particular issue raised above is. The problem is that MLJ started when Flux was still in relative infancy (no Zygote) and there's a lot of mutation where Zygote just spits the dummy. When I last played with this, I ran into a rather serious obstacle for probabilistic classifiers. The implementation of ..., which is equivalent to ...
Thanks for sharing your thoughts on this, Anthony. We'll be looking at this in the coming weeks/months and I have no doubt we'll run into lots of issues related to mutation. Nonetheless, I think it's worth exploring. If it's alright, I'll keep this open for now and we may come back here with updates.
Motivation and description
Maybe this is a more general topic for `MLJ`, not only related to `Flux`. I know that autodiff has been discussed in the past and, with `MLJFlux` now being developed, I was wondering if this topic has come back into focus.

In an ideal world, it would be possible to differentiate through any `SupervisedModel` and get gradients with respect to parameters or inputs. This would, for example, greatly increase the scope of models we can explain through Counterfactual Explanations (see plans outlined here). `MLJFlux` seems like a good place to start, since the underlying models are compatible with `Zygote`. But even here we quickly run into issues: for example, it does not seem possible to differentiate through a `predict` call.

An example: both `f` and `g` below can be used to return softmax output for `x`. Autodiff only works for `g`, but not for `f`:
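Something along these lines (a minimal sketch rather than a verbatim snippet; it assumes an `MLJFlux.NeuralNetworkClassifier` fitted on the iris data, that `fitted_params(mach).chain` returns the underlying Flux chain, and that the default `finaliser = softmax` is already part of that chain):

```julia
using MLJ, MLJFlux, Zygote

X, y = @load_iris
NNC = @load NeuralNetworkClassifier pkg=MLJFlux verbosity=0
mach = machine(NNC(), X, y)
fit!(mach, verbosity=0)

# The raw Flux chain sits inside the fitted machine; with the default
# `finaliser` it already ends in a softmax layer.
chain = fitted_params(mach).chain
x = Float32[5.1, 3.5, 1.4, 0.2]     # a single observation as a plain vector

# g: call the Flux chain directly -> probability of the first class
g(x) = chain(x)[1]

# f: the same quantity, but obtained through the MLJ interface via `predict`
f(x) = pdf(predict(mach, MLJ.table(reshape(x, 1, :)))[1], levels(y)[1])

g(x), f(x)               # should agree up to floating-point differences

Zygote.gradient(g, x)    # works: plain Flux code
Zygote.gradient(f, x)    # errors: `predict` runs through code Zygote cannot differentiate
```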
A simple workaround for this specific issue is to just use the `Chain` directly to produce the softmax output, but this approach does not generalise to other `MLJ` models. I appreciate that this is a very ambitious idea (perhaps previous discussions have concluded that this is simply asking too much), but I would be curious to hear what others think.
Worth mentioning that for the plans mentioned above, I will get some support from a group of CS students soon. So if you have any plans or ongoing work in this space anyway, perhaps there's something we can help with.
Thanks!
Possible Implementation
No response