Skip to content

Commit

Permalink
Merge pull request #574 from vroulet:fix_tests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 602157024
  • Loading branch information
JAXopt authors committed Jan 28, 2024
2 parents 7b4dd31 + 48b09dc commit eb06919
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 17 deletions.
4 changes: 2 additions & 2 deletions tests/bfgs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,8 @@ def fun(x, *args, **kwargs): # Rosenbrock function.

@parameterized.product(
fun_init_and_opt=[
('rosenbrock', onp.zeros(2, dtype='float32'), 0.),
('himmelblau', onp.ones(2, dtype='float32'), 0.),
('rosenbrock', onp.zeros(2), 0.),
('himmelblau', onp.ones(2), 0.),
('matyas', onp.ones(2) * 6., 0.),
('eggholder', onp.ones(2) * 100., None),
],
Expand Down
4 changes: 2 additions & 2 deletions tests/isotonic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def test_compare_with_sklearn(self, increasing, n=10):
y_min = y_sort[2]
y_max = y_sort[n-5]
output = isotonic_l2_pav(y, y_min=y_min, y_max=y_max, increasing=increasing)
output_sklearn = jnp.array(isotonic.isotonic_regression(y, y_min=y_min,
y_max=y_max, increasing=increasing))
output_sklearn = jnp.array(isotonic.isotonic_regression(y, y_min=y_min.item(),
y_max=y_max.item(), increasing=increasing))
self.assertArraysAllClose(output, output_sklearn)

@parameterized.product(increasing=[True, False])
Expand Down
23 changes: 10 additions & 13 deletions tests/lbfgs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,20 +409,17 @@ def binary_logit_log_likelihood_jax(beta, y, x):
x0 = (onp.asarray(beta_init)), method='BFGS'
).x


#jaxopt
# using jaxopt
solver = LBFGS(fun=binary_logit_log_likelihood_jax, maxiter=100,
linesearch="zoom", maxls=10, tol=1e-12)
jaxopt_res = solver.run(beta_init, y, x).params
if jax.config.jax_enable_x64:
# NOTE(vroulet): simply testing in function values at high precision
scipy_val = binary_logit_log_likelihood(scipy_res,
onp.asarray(y),
onp.asarray(x))
jaxopt_val = binary_logit_log_likelihood(jaxopt_res, y, x)
self.assertArraysAllClose(scipy_val, jaxopt_val)
else:
self.assertArraysAllClose(scipy_res, jaxopt_res)

# comparison
scipy_val = binary_logit_log_likelihood(scipy_res,
onp.asarray(y),
onp.asarray(x))
jaxopt_val = binary_logit_log_likelihood(jaxopt_res, y, x)
self.assertArraysAllClose(scipy_val, jaxopt_val)


@parameterized.product(linesearch=['zoom', 'backtracking', 'hager-zhang'])
Expand Down Expand Up @@ -502,8 +499,8 @@ def fun(x):

@parameterized.product(
fun_init_and_opt=[
('rosenbrock', onp.zeros(2, dtype='float32'), 0.),
('himmelblau', onp.ones(2, dtype='float32'), 0.),
('rosenbrock', onp.zeros(2), 0.),
('himmelblau', onp.ones(2), 0.),
('matyas', onp.ones(2) * 6., 0.),
('eggholder', onp.ones(2) * 100., None),
('zakharov', onp.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e4]), 0.),
Expand Down

0 comments on commit eb06919

Please sign in to comment.