Automation refactor with Jax #217
-
Would moving to Jax mean the end of native Windows support? Not that that should be a deal breaker, but Jax currently only works on Windows via WSL, so I'm curious what that would mean.
-
The "simple" implementation is extremely satisfying, but I think in most cases we'd want a "fast" implementation anyway (I'd be curious what the performance differences are, in any case). And Jax's lack of Windows compatibility is a bummer, but I'm not sure whether that should be a deal breaker. I'm a fan of the refactor of internal vs. user-facing parameters. I'm not very familiar with Scipy's basin hopping algorithm; what's the advantage of this optimizer over what we currently have, i.e. calling ...? Lastly, I totally agree on the general-purpose sampling algorithm; it would be a neat nice-to-have!
-
I think this is generally a good idea, and I think JAX makes sense. I like that it's a simple transition from standard NumPy, and if this allows for easier development of the additional distributions, which is currently pretty high-touch for the maintainers, and brings more autonomy to the PR process, I think that would be a big win. @alejandroschuler have you done any tests to see the impact on run time from this change?
-
Alright, I think I've settled on a more-or-less stable configuration for things. The main difference is that "score implementations" (now "manifolds") are now subclasses of the distribution object, i.e. they extend both it and the parent score. Previously there was a dynamic mix-in that subclassed the score implementation class with the distribution class, but this redesign simplifies things. Again, you can compare the fast normal implementation with the equivalent simple version. Doing this also allowed me to do all of the automated "method building" in the score parent class (i.e. ...).

The graph shows that as long as ... You'll also notice that the scores and derivatives are now split between ...

The ability to automatically support censored outcome data is very nice, but it does complicate things a bit. However, developers are completely free to ignore this. If they simply name their manifold score method ...
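To make the new class relationship concrete, here's a minimal sketch of how I read the design; the class names (`LogScore`, `Normal`, `NormalLogScore`) and the `score` signature are placeholders, not the actual code in the branch.

```python
from jax.scipy.stats import norm

class LogScore:
    """Parent score class: home of the automated "method building"
    (e.g. deriving gradients and approximations via jax)."""

class Normal:
    """User-facing distribution, parametrized by loc and scale."""
    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

class NormalLogScore(Normal, LogScore):
    """A "manifold": a subclass of the distribution that also extends the
    parent score, replacing the old dynamic mix-in built at runtime."""
    def score(self, Y):
        return -norm.logpdf(Y, self.loc, self.scale)
```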
-
I think jax would be a great way to go. I had a random thought earlier which might be a workaround to avoid extra dependencies and also potentially allow sympy and jax to work without having to choose one. I imagine other people have already thought about all of this; I'm saying it just in case! When adding these extra features, just make jax an optional dependency. If you run, say, ...
I think it would be quite a nice solution, as it keeps the main bulk of ngboost functionality without requiring a larger dependency list. Poetry also has support for it.
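On the code side, one way to wire that up (a sketch only; the function names and signatures here are hypothetical, not ngboost's API) is to guard the import and fall back to hand-written methods when jax isn't installed:

```python
# Hypothetical sketch: treat jax as an optional extra at import time.
try:
    import jax
    HAS_JAX = True
except ImportError:
    HAS_JAX = False

def d_log_score(dist, params, Y):
    """Use autodiff when jax is available, otherwise fall back to a
    hand-written derivative on the distribution class (hypothetical names)."""
    if HAS_JAX:
        return jax.grad(dist.log_score)(params, Y)
    if hasattr(dist, "d_log_score"):
        return dist.d_log_score(params, Y)
    raise ImportError(
        "jax is required to auto-derive gradients; install the optional "
        "extra or implement d_log_score by hand."
    )
```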
-
I've been working on refactoring the distributions backend to use jax. The advantage of this is that adding new distributions to ngboost becomes much, much easier at the cost of a little bit of speed. The speed can always be regained by adding the necessary methods, but if the developer does not know how to implement these, their class will still work.
The idea is illustrated by comparing two implementations of the normal distribution: a "simple" implementation and a "fast" implementation. In the "simple" implementation, there is no score implementation and the distribution class has only 3 methods. ngboost uses jax and the provided `cdf` and `sample` methods from the distribution to automatically derive/approximate everything that it needs to use the distribution with the log score. In the "fast" implementation, I've added a distribution-specific implementation of the log score and some other methods to the distribution that increase speed and numerical stability.
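Roughly, a "simple" distribution might look something like the sketch below (hypothetical class and attribute names; the real interface on the branch may differ), where `cdf` and `sample` are all the developer supplies and jax derives the rest:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr  # standard normal CDF

class SimpleNormal:
    """Hypothetical "simple" normal: just enough surface area for ngboost
    to derive/approximate the log score and its gradients via jax."""
    parameter_names = ["loc", "scale"]

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def cdf(self, Y):
        return ndtr((Y - self.loc) / self.scale)

    def sample(self, key, n):
        return self.loc + self.scale * jax.random.normal(key, (n,))
```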
I've also simplified (from the user's perspective) the way that ngboost differentiates between internal parameters and user-facing parameters (e.g. log-sigma vs. sigma in the normal). The transformations from an interval to the reals and their inverses are now automatic and hidden from the user. The goal is that users can implement methods that work on the user-facing parametrization and ngboost will automatically compose these with the transformation to get methods that work with the internal parametrization (this is not yet implemented for the `d_score` method, which requires the chain rule as well as composition). The cost of this is some potential confusion when implementing fast distribution-specific scores, as one must now work with and know what the internal parametrization is. But users who are implementing scores should be savvy enough to understand this without too much difficulty, whereas users who just want to try a new distribution and get it to work are often confused by the notion, so hiding it should be, on the whole, a good thing. I've tried to make the distinction in parametrization clear in the code with some naming conventions (i.e. methods and variables starting with `_` work with or represent the internal parametrization).
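To illustrate the composition (a sketch of the idea only, not the actual transformation machinery on the branch), suppose the internal parameter for sigma is its log; a method written in terms of the user-facing `scale` can be composed with the inverse transform, and jax's autodiff then provides the gradient in the internal parametrization:

```python
import jax.numpy as jnp
from jax import grad

def _to_user(_params):
    """Inverse transform: internal (unconstrained) params -> user-facing params."""
    loc, _log_scale = _params
    return loc, jnp.exp(_log_scale)

def user_log_likelihood(params, Y):
    """Written by the user in the natural (loc, scale) parametrization."""
    loc, scale = params
    return jnp.sum(
        -0.5 * ((Y - loc) / scale) ** 2 - jnp.log(scale) - 0.5 * jnp.log(2 * jnp.pi)
    )

def _internal_log_likelihood(_params, Y):
    """What ngboost could build automatically: user method composed with
    the inverse transform."""
    return user_log_likelihood(_to_user(_params), Y)

# Gradient in the internal parametrization, chain rule handled by autodiff:
_d_score = grad(_internal_log_likelihood)
```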
I've also provided an automatic `_fit_marginal` method that uses gradient descent and basin hopping. The only issue with it at the moment is that the gradient can be numerically weird if the initial guess is off, and I don't have a good distribution-agnostic way of setting what that guess should be.
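For reference, the shape of such a fit might be roughly as follows (a sketch under assumed signatures; `neg_log_likelihood`, `n_params`, and the zero initial guess are placeholders, not the code in the branch):

```python
import numpy as np
import jax.numpy as jnp
from jax import grad, jit
from scipy.optimize import basinhopping

def _fit_marginal(neg_log_likelihood, Y, n_params, x0=None):
    """Fit internal parameters to the marginal distribution of Y using
    scipy's basin hopping with jax-supplied gradients for the local steps."""
    nll = jit(lambda p: neg_log_likelihood(p, Y))
    d_nll = jit(grad(lambda p: neg_log_likelihood(p, Y)))
    if x0 is None:
        x0 = np.zeros(n_params)  # distribution-agnostic guess that can misbehave
    result = basinhopping(
        lambda p: float(nll(jnp.asarray(p))),
        x0,
        minimizer_kwargs={"jac": lambda p: np.asarray(d_nll(jnp.asarray(p)))},
    )
    return result.x
```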
It would be excellent if we could implement a fast-enough general-purpose sampling algorithm (inverse transform sampling, maybe?) that could effectively "write" the `sample` method from the `cdf` or `pdf` methods.
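For what it's worth, a generic inverse-transform sampler could be built by inverting the `cdf` with a fixed number of bisection steps. This is just a sketch; the bounds and iteration count are arbitrary placeholders and would need care for heavy-tailed or bounded distributions:

```python
import jax
import jax.numpy as jnp

def make_sample_from_cdf(cdf, lower=-1e6, upper=1e6, n_iter=60):
    """Build a sample(key, shape) function from a vectorized cdf by solving
    cdf(y) = u for uniform u with bisection (placeholder bounds)."""
    def sample(key, shape):
        u = jax.random.uniform(key, shape)
        lo = jnp.full(shape, lower)
        hi = jnp.full(shape, upper)

        def body(_, bounds):
            lo, hi = bounds
            mid = 0.5 * (lo + hi)
            go_right = cdf(mid) < u  # the root lies to the right of mid
            return jnp.where(go_right, mid, lo), jnp.where(go_right, hi, mid)

        lo, hi = jax.lax.fori_loop(0, n_iter, body, (lo, hi))
        return 0.5 * (lo + hi)
    return sample
```

Something like `sample = make_sample_from_cdf(dist.cdf)` could then stand in for a hand-written `sample` method when speed isn't critical.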
Again, the purpose of all of this is to lower the implementation barrier for new distributions and to make the overall system flexible enough that users can add as many or as few methods as they want to get the desired speed/complexity trade-off.

@avati @tonyduan @ryan-wolbeck lmk what you think!