Stopping condition 'madsen-nielsen' incorrect #575

Open
Joshuaalbert opened this issue Feb 2, 2024 · 0 comments

Joshuaalbert commented Feb 2, 2024

The documentation says something different from the code. Specifically, the `-` in the code's term `(tree_l2_norm(params) - self.xtol)` is inconsistent with the `+` in the docstring.

Docstring says:

the convergence is achieved once the
coeff update satisfies ``||dcoeffs||_2 <= xtol * (||coeffs||_2 + xtol) `` or
the gradient satisfies ``||grad(f)||_inf <= gtol``.

Code says:

```python
tree_mul_term = self.xtol * (tree_l2_norm(params) - self.xtol)
return jnp.all(jnp.array([
  tree_inf_norm(state.gradient) > self.gtol,
  tree_l2_norm(state.delta) > tree_mul_term
]))
```
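A minimal sketch of why the sign matters (the helpers `threshold_doc` and `threshold_code` are illustrative, not part of the library): whenever `||params||_2 < xtol`, the code's threshold becomes negative, so `tree_l2_norm(state.delta) > tree_mul_term` is always true and the step-size test can never trigger a stop, even for a zero-length step.

```python
import jax.numpy as jnp

xtol = 1e-3

def threshold_doc(param_norm):
  # Threshold as stated in the docstring: xtol * (||x||_2 + xtol)
  return xtol * (param_norm + xtol)

def threshold_code(param_norm):
  # Threshold as written in the code: xtol * (||x||_2 - xtol)
  return xtol * (param_norm - xtol)

param_norm = jnp.array(0.0)  # parameters sitting at the origin
step_norm = jnp.array(0.0)   # a step of exactly zero length

print(step_norm <= threshold_doc(param_norm))   # True: 0 <= 1e-6
print(step_norm <= threshold_code(param_norm))  # False: 0 <= -1e-6 fails
```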

Additionally, rather than `jnp.all(jnp.array([...]))`, you should use `jnp.bitwise_and(..., ...)` or the `|`, `&`, and `~` operators.
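For scalar boolean flags the forms are equivalent; a small sketch with arbitrary values (the variable names are illustrative):

```python
import jax.numpy as jnp

gtol = 1e-3
grad_inf_norm = jnp.array(5e-2)  # gradient still large
delta_norm = jnp.array(1e-8)     # step already tiny
threshold = jnp.array(1e-6)

# Current pattern: stack the flags into an array, then reduce with jnp.all.
keep_going_all = jnp.all(jnp.array([grad_inf_norm > gtol, delta_norm > threshold]))

# Suggested alternatives: combine the boolean scalars directly.
keep_going_and = (grad_inf_norm > gtol) & (delta_norm > threshold)
keep_going_bitwise = jnp.bitwise_and(grad_inf_norm > gtol, delta_norm > threshold)

print(keep_going_all, keep_going_and, keep_going_bitwise)  # False False False
```

The operator form avoids allocating an intermediate array and reads more directly as a boolean expression.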

My suggestion

After reading up on the Madsen-Nielsen stopping condition, it seems there is no single standard version of it. From my own optimisation work, I find it quite useful to incorporate both an absolute and a relative tolerance on parameter changes (currently the check looks like it is relative only).
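Written out (the notation is illustrative, transcribed from the code that follows), the proposal stops as soon as any one of three conditions holds, where $\Delta x$ is the parameter update and the comparisons are per component:

```math
\max_i |\Delta x_i| \le \text{atol}
\quad\lor\quad
|\Delta x_i| \le \text{rtol}\,|x_i| \;\; \forall i
\quad\lor\quad
\|\nabla f\|_\infty \le \text{gtol}
```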

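```python
import jax.numpy as jnp
from jax.tree_util import tree_leaves, tree_map

def leaves_vec(tree_x):
  # Flatten every leaf of a pytree and concatenate them into one 1-D vector.
  return jnp.concatenate(tree_leaves(tree_map(jnp.ravel, tree_x)))

# Sketch of the stopping check; the method name and signature are illustrative.
def cond_fun(self, params, state):
  delta = jnp.abs(leaves_vec(state.delta))
  # Absolute tolerance: every component of the step is at most atol.
  atol_cond = jnp.all(delta <= self.atol)
  # Relative tolerance: every component of the step is at most rtol * |param|.
  rtol_cond = jnp.all(delta <= self.rtol * jnp.abs(leaves_vec(params)))
  # Gradient tolerance: the infinity norm of the gradient is at most gtol.
  grad_cond = jnp.max(jnp.abs(leaves_vec(state.gradient))) <= self.gtol
  done = atol_cond | rtol_cond | grad_cond
  return ~done  # True while the solver should keep iterating

# Defaults
atol = 0.  # Effectively off unless the user enables it, for backward compatibility with the current behaviour.
rtol = 1e-3
gtol = 1e-3
```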
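To illustrate why the absolute tolerance helps, here is a self-contained sketch on a made-up pytree (the `State` tuple and the `keep_going` function are hypothetical stand-ins for the solver's state and stopping check). When one parameter is exactly zero, the relative test can only pass for a zero step, so with the gradient still above `gtol`, only the absolute test can stop the loop.

```python
from typing import NamedTuple

import jax.numpy as jnp
from jax.tree_util import tree_leaves, tree_map


def leaves_vec(tree_x):
  # Flatten every leaf of a pytree into one 1-D vector.
  return jnp.concatenate(tree_leaves(tree_map(jnp.ravel, tree_x)))


class State(NamedTuple):
  # Hypothetical solver state: only the fields the check reads.
  delta: dict
  gradient: dict


def keep_going(params, state, atol, rtol, gtol):
  # Same logic as the suggestion above, with tolerances passed explicitly.
  delta = jnp.abs(leaves_vec(state.delta))
  atol_cond = jnp.all(delta <= atol)
  rtol_cond = jnp.all(delta <= rtol * jnp.abs(leaves_vec(params)))
  grad_cond = jnp.max(jnp.abs(leaves_vec(state.gradient))) <= gtol
  return ~(atol_cond | rtol_cond | grad_cond)


# Parameter "b" sits exactly at zero; all steps are tiny; the gradient is not yet small.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.0)}
state = State(
    delta={"w": jnp.array([1e-9, 1e-9]), "b": jnp.array(1e-9)},
    gradient={"w": jnp.array([1e-2, 1e-2]), "b": jnp.array(1e-2)},
)

print(keep_going(params, state, atol=0.0, rtol=1e-3, gtol=1e-3))   # True: relative test blocked by b == 0
print(keep_going(params, state, atol=1e-8, rtol=1e-3, gtol=1e-3))  # False: absolute test stops the loop
```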