Interface to <: Supervised MLJ models #69

Closed
2 of 3 tasks
pat-alt opened this issue Oct 13, 2022 · 2 comments
Labels: difficult (This is expected to be difficult.), enhancement (New feature or request), help wanted (Extra attention is needed)

Comments


pat-alt commented Oct 13, 2022

This is an issue reserved for the TU Delft Student Software Project '23

MLJ is a popular machine learning framework for Julia. It ships with a large suite of common machine learning models, both supervised and unsupervised. It seems natural to interface this package to MLJ, although currently differentiability is a major challenge: to be able to use any of our counterfactual generators to explain MLJ models, those models need to be differentiable with respect to features. Still, this is worth exploring.
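To make the differentiability requirement concrete, here is a minimal sketch of a gradient-based counterfactual search, written in Python purely for illustration. It assumes a hand-coded logistic model with made-up weights; none of the names below come from this package or from MLJ. The point is that the search needs the gradient of the loss with respect to the features `x`, not the parameters.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def predict_proba(x, w, b):
    # Probability of class 1 under a logistic model (illustrative stand-in).
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)

def loss_grad_x(x, w, b):
    # Gradient of the cross-entropy loss for target class 1,
    # taken with respect to the FEATURES x: (p - 1) * w.
    p = predict_proba(x, w, b)
    return [(p - 1.0) * wi for wi in w]

def generate_counterfactual(x, w, b, lr=0.5, max_steps=200, threshold=0.5):
    # Plain gradient descent in feature space until the model predicts class 1.
    x_cf = list(x)
    for _ in range(max_steps):
        if predict_proba(x_cf, w, b) > threshold:
            break
        grad = loss_grad_x(x_cf, w, b)
        x_cf = [xi - lr * gi for xi, gi in zip(x_cf, grad)]
    return x_cf

# Hypothetical weights and a factual point that is classified as class 0.
w, b = [1.5, -2.0], -0.5
x = [-1.0, 1.0]
x_cf = generate_counterfactual(x, w, b)
```

If `loss_grad_x` returned zeros everywhere, as it can for a non-differentiable model, the loop would never move `x_cf` away from `x`.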

I propose the following steps:

  1. Implement a basic interface to MLJ (essentially an AbstractFittedModel for MLJ.Supervised).
  2. From the MLJ model list, identify which ML models fulfil the differentiability criterion. Note that some models, like decision trees, may be differentiable after probability calibration. See below for a potential starting point. Start by focusing on pure Julia models before dealing with non-native models (like sklearn).
  3. Ideally, I think we would like a single MLJModel<:AbstractFittedModel type that can handle all (compatible) supervised MLJ models. To this end, we will need a mechanism to distinguish compatible from incompatible models.
  4. Thoroughly test and document your contributions.
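One possible shape for the compatibility mechanism in step 3 is a wrapper that carries a differentiability flag, with non-differentiable as the conservative default. The sketch below is in Python for illustration only; `WrappedModel`, `wrap`, `require_gradients`, and the allow-list are hypothetical names, and the two empty classes are stand-ins rather than real model implementations.

```python
from dataclasses import dataclass
from typing import Any, Optional

# Hypothetical allow-list of model type names known to be differentiable.
# The entry is illustrative, not a verified compatibility list.
DIFFERENTIABLE_MODEL_TYPES = {"NeuralNetworkClassifier"}

@dataclass
class WrappedModel:
    model: Any
    differentiable: bool = False  # conservative default

def wrap(model: Any, differentiable: Optional[bool] = None) -> WrappedModel:
    # Look up the model type unless the caller overrides the flag explicitly.
    if differentiable is None:
        differentiable = type(model).__name__ in DIFFERENTIABLE_MODEL_TYPES
    return WrappedModel(model, differentiable)

def require_gradients(wrapped: WrappedModel) -> None:
    # Gradient-based generators would call this before attempting a search.
    if not wrapped.differentiable:
        name = type(wrapped.model).__name__
        raise TypeError(f"{name} is not marked as differentiable")

# Stand-in classes used only to exercise the lookup.
class NeuralNetworkClassifier:
    pass

class DecisionTreeClassifier:
    pass

nn = wrap(NeuralNetworkClassifier())
tree = wrap(DecisionTreeClassifier())
```

The explicit `differentiable` override leaves room for models that become usable after some extra treatment, without changing the default for everything else.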

This is a challenging task and it is not critical that you succeed at everything. But we would like to aim for the following minimum achievements:

  • Add the basic interface (point 1)
  • Document your process and findings regarding point 2
  • If a complete interface turns out to be too challenging, work on a proof-of-concept for at least one particular MLJ model, ideally EvoTrees (points 3 and 4)

Previous attempts

I have tried this in the past, which might or might not be a good starting point:

  1. At this point all of the counterfactual generators need gradient access and currently leverage Zygote.jl for auto-diff. Not sure if all MLJ models can just be "auto-diffed" in that sense, but some early experiments with EvoTrees have shown that, in principle, gradient-based counterfactual generators should be applicable (see here).
  2. That being said, Zygote.jl didn't work in this case and I had to rely on ForwardDiff (see here). The problem with trees is that the counterfactual loss function is not smooth, so taking gradients just returned vectors with all elements equal to zero (at least I think the non-smoothness was the issue here). It would still be preferable to use Zygote if possible.
  3. (Non-)Differentiability of models may be a more general issue.
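The zero-gradient behaviour described in point 2 can be reproduced with a toy example. The Python sketch below uses a one-split decision stump and central finite differences as a stand-in for automatic differentiation; the sigmoid-smoothed "soft" stump is an assumption added for contrast and is not something this package or EvoTrees is known to do.

```python
import math

def hard_stump(x, threshold=0.0, left=0.2, right=0.8):
    # Piecewise-constant prediction, as produced by a single tree split.
    return left if x < threshold else right

def soft_stump(x, threshold=0.0, left=0.2, right=0.8, temperature=0.5):
    # Hypothetical smoothed variant: a sigmoid gate replaces the hard split.
    gate = 1.0 / (1.0 + math.exp(-(x - threshold) / temperature))
    return left + (right - left) * gate

def central_difference(f, x, h=1e-4):
    # Numerical derivative, used here instead of an autodiff library.
    return (f(x + h) - f(x - h)) / (2.0 * h)

x0 = -1.0  # a point away from the split threshold
hard_grad = central_difference(hard_stump, x0)
soft_grad = central_difference(soft_stump, x0)
```

Away from the threshold the hard stump is flat, so `hard_grad` is exactly zero and carries no information about which direction moves the prediction. The smoothed version has a small positive slope at the same point.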
pat-alt added the enhancement and help wanted labels on Oct 13, 2022
pat-alt added the difficult label on Nov 29, 2022
pat-alt self-assigned this on Nov 29, 2022
This was referenced Mar 20, 2023

pat-alt commented Apr 3, 2023

MLJFlux is probably the most obvious place to start for this (see related discussion here)


pat-alt commented May 19, 2024

This is in principle now implemented (#450), but by default MLJ models are assumed to be non-differentiable (the MLJBase.predict call and other functions don't play nicely with Zygote).

pat-alt closed this as completed on May 19, 2024