update fit to new fax api #17

Draft · wants to merge 1 commit into base: main
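In short, this PR swaps the deprecated jax.experimental.optimizers Adam triple for the optix gradient-transformation API and replaces fax's `two_phase_solver` factory with direct calls to `two_phase_solve`. A minimal sketch of the new inner update step, using only the calls that appear in the diff below (the step size and the objective `nll` are placeholders, not values fixed by this PR):

import jax
from jax.experimental import optix  # experimental optimizer module adopted by this PR

# optix.scale(-1e-2) is plain gradient descent: it rescales the gradient by -lr.
gradient_descent = optix.scale(-1e-2)

def step(param, nll):
    g = jax.grad(nll)(param)                                     # gradient of the objective
    updates, _ = gradient_descent.update(g, gradient_descent.init(param))
    return optix.apply_updates(param, updates)                   # add the scaled updates to param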
59 changes: 33 additions & 26 deletions nbs/03_fit.ipynb
@@ -38,7 +38,7 @@
"# export\n",
"import jax\n",
"from fax.implicit import twophase\n",
"import jax.experimental.optimizers as optimizers\n",
"from jax.experimental import optix\n",
"\n",
"from neos.transforms import to_bounded_vec, to_inf_vec, to_bounded, to_inf\n",
"from neos.models import *"
@@ -61,7 +61,7 @@
"):\n",
" '''\n",
" Wraps a series of functions that perform maximum likelihood fitting in the \n",
" `two_phase_solver` method found in the `fax` python module. This allows for\n",
" `two_phase_solve` method found in the `fax` python module. This allows for\n",
" the calculation of gradients of the best-fit parameters with respect to upstream\n",
" parameters that control the underlying model, i.e. the event yields (which are \n",
" then parameterized by weights or similar).\n",
@@ -74,7 +74,7 @@
" respectively. Differentiable :)\n",
" '''\n",
"\n",
" adam_init, adam_update, adam_get_params = optimizers.adam(1e-6)\n",
" gradient_descent = optix.scale(-1e-2)\n",
"\n",
" def make_model(hyper_pars):\n",
" constrained_mu, nn_pars = hyper_pars[0], hyper_pars[1]\n",
@@ -108,51 +108,58 @@
" )\n",
" return -expected_logpdf(pars)[0]\n",
"\n",
" return constrained_mu, global_fit_objective, constrained_fit_objective,bounds\n",
" return constrained_mu, global_fit_objective, constrained_fit_objective, bounds\n",
"\n",
" def global_bestfit_minimized(hyper_param):\n",
" _, nll, _ ,_ = make_model(hyper_param)\n",
"\n",
" def bestfit_via_grad_descent(i, param): # gradient descent\n",
" def bestfit_via_grad_descent(param): # gradient descent\n",
" g = jax.grad(nll)(param)\n",
" # param = param - g * learning_rate\n",
" param = adam_get_params(adam_update(i,g,adam_init(param)))\n",
" return param\n",
" updates, _ = gradient_descent.update(g, gradient_descent.init(param))\n",
" return optix.apply_updates(param, updates)\n",
"\n",
" return bestfit_via_grad_descent\n",
" \n",
"\n",
" def constrained_bestfit_minimized(hyper_param):\n",
" mu, nll, cnll,bounds = make_model(hyper_param)\n",
" mu, nll, cnll, bounds = make_model(hyper_param)\n",
"\n",
" def bestfit_via_grad_descent(i, param): # gradient descent\n",
" def bestfit_via_grad_descent(param): # gradient descent\n",
" _, np = param[0], param[1:]\n",
" g = jax.grad(cnll)(np)\n",
" np = adam_get_params(adam_update(i,g,adam_init(np)))\n",
" updates, _ = gradient_descent.update(g, gradient_descent.init(np))\n",
" np = optix.apply_updates(np, updates)\n",
" param = jax.numpy.concatenate([jax.numpy.asarray([mu]), np])\n",
" return param\n",
" \n",
"\n",
" return bestfit_via_grad_descent\n",
"\n",
" global_solve = twophase.two_phase_solver(\n",
" param_func=global_bestfit_minimized,\n",
" default_rtol=default_rtol,\n",
" default_atol=default_atol,\n",
" default_max_iter=default_max_iter\n",
" \n",
" convergence_test = twophase.default_convergence_test(\n",
" rtol=default_rtol,\n",
" atol=default_atol,\n",
" )\n",
" constrained_solver = twophase.two_phase_solver(\n",
" param_func=constrained_bestfit_minimized,\n",
" default_rtol=default_rtol,\n",
" default_atol=default_atol,\n",
" default_max_iter=default_max_iter,\n",
" global_solver = twophase.default_solver(\n",
" convergence_test=convergence_test,\n",
" max_iter=default_max_iter,\n",
" )\n",
" constrained_solver = global_solver\n",
"\n",
" def g_fitter(init, hyper_pars):\n",
" solve = global_solve(init, hyper_pars)\n",
" return solve.value\n",
" return twophase.two_phase_solve(\n",
" global_bestfit_minimized,\n",
" init,\n",
" hyper_pars,\n",
" solvers=(global_solver,),\n",
" )\n",
"\n",
" def c_fitter(init, hyper_pars):\n",
" solve = constrained_solver(init, hyper_pars)\n",
" return solve.value\n",
" return twophase.two_phase_solve(\n",
" constrained_bestfit_minimized,\n",
" init,\n",
" hyper_pars,\n",
" solvers=(constrained_solver,),\n",
" )\n",
"\n",
" return g_fitter, c_fitter"
]
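For reference, the fax pieces adopted above fit together roughly as follows (a sketch restricted to the calls visible in this diff; the tolerances and iteration cap are placeholder values):

from fax.implicit import twophase

convergence_test = twophase.default_convergence_test(rtol=1e-4, atol=1e-4)
solver = twophase.default_solver(convergence_test=convergence_test, max_iter=10_000)

def fit(param_func, init, hyper_pars):
    # param_func(hyper_pars) returns the inner update (here: one gradient-descent
    # step); two_phase_solve iterates it to a fixed point and defines implicit
    # gradients of the solution with respect to hyper_pars.
    return twophase.two_phase_solve(param_func, init, hyper_pars, solvers=(solver,))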
57 changes: 32 additions & 25 deletions neos/fit.py
@@ -5,7 +5,7 @@
# Cell
import jax
from fax.implicit import twophase
import jax.experimental.optimizers as optimizers
from jax.experimental import optix

from .transforms import to_bounded_vec, to_inf_vec, to_bounded, to_inf
from .models import *
@@ -21,7 +21,7 @@ def get_solvers(
):
'''
Wraps a series of functions that perform maximum likelihood fitting in the
`two_phase_solver` method found in the `fax` python module. This allows for
`two_phase_solve` method found in the `fax` python module. This allows for
the calculation of gradients of the best-fit parameters with respect to upstream
parameters that control the underlying model, i.e. the event yields (which are
then parameterized by weights or similar).
@@ -34,7 +34,7 @@
respectively. Differentiable :)
'''

adam_init, adam_update, adam_get_params = optimizers.adam(1e-6)
gradient_descent = optix.scale(-1e-2)

def make_model(hyper_pars):
constrained_mu, nn_pars = hyper_pars[0], hyper_pars[1]
@@ -68,50 +68,57 @@ def constrained_fit_objective(nuis_par): # NLL
)
return -expected_logpdf(pars)[0]

return constrained_mu, global_fit_objective, constrained_fit_objective,bounds
return constrained_mu, global_fit_objective, constrained_fit_objective, bounds

def global_bestfit_minimized(hyper_param):
_, nll, _ ,_ = make_model(hyper_param)

def bestfit_via_grad_descent(i, param): # gradient descent
def bestfit_via_grad_descent(param): # gradient descent
g = jax.grad(nll)(param)
# param = param - g * learning_rate
param = adam_get_params(adam_update(i,g,adam_init(param)))
return param
updates, _ = gradient_descent.update(g, gradient_descent.init(param))
return optix.apply_updates(param, updates)

return bestfit_via_grad_descent


def constrained_bestfit_minimized(hyper_param):
mu, nll, cnll,bounds = make_model(hyper_param)
mu, nll, cnll, bounds = make_model(hyper_param)

def bestfit_via_grad_descent(i, param): # gradient descent
def bestfit_via_grad_descent(param): # gradient descent
_, np = param[0], param[1:]
g = jax.grad(cnll)(np)
np = adam_get_params(adam_update(i,g,adam_init(np)))
updates, _ = gradient_descent.update(g, gradient_descent.init(np))
np = optix.apply_updates(np, updates)
param = jax.numpy.concatenate([jax.numpy.asarray([mu]), np])
return param


return bestfit_via_grad_descent

global_solve = twophase.two_phase_solver(
param_func=global_bestfit_minimized,
default_rtol=default_rtol,
default_atol=default_atol,
default_max_iter=default_max_iter
convergence_test = twophase.default_convergence_test(
rtol=default_rtol,
atol=default_atol,
)
constrained_solver = twophase.two_phase_solver(
param_func=constrained_bestfit_minimized,
default_rtol=default_rtol,
default_atol=default_atol,
default_max_iter=default_max_iter,
global_solver = twophase.default_solver(
convergence_test=convergence_test,
max_iter=default_max_iter,
)
constrained_solver = global_solver

def g_fitter(init, hyper_pars):
solve = global_solve(init, hyper_pars)
return solve.value
return twophase.two_phase_solve(
global_bestfit_minimized,
init,
hyper_pars,
solvers=(global_solver,),
)

def c_fitter(init, hyper_pars):
solve = constrained_solver(init, hyper_pars)
return solve.value
return twophase.two_phase_solve(
constrained_bestfit_minimized,
init,
hyper_pars,
solvers=(constrained_solver,),
)

return g_fitter, c_fitter
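Finally, a hypothetical end-to-end use of the rewritten module. The full get_solvers signature is collapsed in this diff, so the model-constructor argument, init_pars, and hyper_pars below are assumptions for illustration, not part of the PR:

import jax
from neos.fit import get_solvers

g_fitter, c_fitter = get_solvers(make_my_model)    # make_my_model: user-supplied model constructor (assumed)
best_fit = g_fitter(init_pars, hyper_pars)         # unconstrained (global) best fit
cond_fit = c_fitter(init_pars, hyper_pars)         # fit with the signal strength mu held fixed
grads = jax.jacrev(lambda h: g_fitter(init_pars, h))(hyper_pars)  # best-fit parameters are differentiable in hyper_pars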