Add support for loss functions with auxiliary data to linesearch #1177

Draft · wants to merge 2 commits into base: main

Conversation

ro0mquy commented Jan 16, 2025

Summary

This change adds support for loss functions that return auxiliary data alongside their primary value, like (loss_value, extra_data). This pattern is commonly used with jax.value_and_grad(fn, has_aux=True).

The approach:

  1. Added value_fn_has_aux flag to zoom_linesearch and scale_by_zoom_linesearch
  2. Modified value handling to properly unpack auxiliary data when needed using a new _unpack_value helper that extracts just the loss value
  3. Updated value storage in state to keep the full value+aux tuple when needed
  4. Added has_aux parameter to value_and_grad_from_state to properly handle auxiliary data when reusing cached values

This allows the linesearch algorithms to work with loss functions that return auxiliary data while maintaining the optimization over just the primary loss value.
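
As a rough sketch of how these pieces are meant to fit together (value_fn_has_aux and has_aux are the argument names proposed in this draft and may still change; the ((value, aux), grad) return convention and the exact way the value is passed to update are assumptions here, mirroring jax.value_and_grad):

import jax.numpy as jnp
import optax

def loss(params):
  # An aux-returning objective: (loss_value, extra_data).
  value = jnp.sum(params ** 2)
  return value, {'residual_norm': jnp.sqrt(value)}

# value_fn_has_aux / has_aux are the flags proposed in this draft, not released optax.
opt = optax.lbfgs(
    linesearch=optax.scale_by_zoom_linesearch(
        max_linesearch_steps=15, value_fn_has_aux=True
    )
)
params = jnp.ones((3,))
state = opt.init(params)  # plus the aux-structure initialization discussed below

value_and_grad = optax.value_and_grad_from_state(loss, has_aux=True)
(value, aux), grad = value_and_grad(params, state=state)
updates, state = opt.update(
    grad, state, params, value=value, grad=grad, value_fn=loss
)
params = optax.apply_updates(params, updates)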

Input needed: How to initialize opt_state?

The linesearch algorithm stores value and grad in the optimizer state to enable reuse of function evaluations. When using auxiliary data, JAX compilation needs to know the structure of this data upfront.

Currently, I'm initializing it like this:

opt_state = optimizer.init(params)
# Run loss function once to get auxiliary data structure
_, aux = loss(params)
# Set value to infinity (to force recalculation) but keep aux structure
value = (jnp.asarray(jnp.inf), aux)
opt_state = optax.tree_utils.tree_set(opt_state, value=value)

This feels a bit hacky since it requires an extra function evaluation just to get the structure. Is there a better way to handle this initialization?

The challenge is that the auxiliary data structure is determined by the loss function and could be arbitrary (e.g., dictionaries, nested structures, etc.).
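
A possible variant (just a sketch, not part of this change; it assumes the loss can be traced abstractly and reuses the names from the snippet above) would be to derive the aux structure from jax.eval_shape instead of a real evaluation:

import jax
import jax.numpy as jnp
import optax

# Abstract evaluation: gives the shapes/dtypes of (value, aux) without running the loss.
_, aux_struct = jax.eval_shape(loss, params)
aux = jax.tree_util.tree_map(lambda s: jnp.zeros(s.shape, s.dtype), aux_struct)
value = (jnp.asarray(jnp.inf), aux)
opt_state = optax.tree_utils.tree_set(opt_state, value=value)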

ToDos

  • Add support to backtracking linesearch
  • Add documentation and doc strings
  • Add tests
  • Improve handling of initial opt_state


google-cla bot commented Jan 16, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up-to-date status, view the checks section at the bottom of the pull request.

rdyro (Collaborator) commented Feb 3, 2025

Hey, thanks for continuing to work on this! Can you sign the CLA please?

@theo-brown

@ro0mquy thanks for this PR, I'm keen to see this feature added!

ro0mquy (Author) commented Feb 19, 2025

I'm still figuring out the CLA thing with my company, but it's just a matter of time. I also accidentally pushed another fix (the "fix slope calculation in zoom_linesearch" commit) to the same branch; I can revert it if needed.

ro0mquy (Author) commented Mar 9, 2025

I signed the CLA and updated my email. Is there a way to retrigger the check?

vroulet (Collaborator) commented Mar 11, 2025

Hello @ro0mquy,
Any commit will retrigger the tests, so you can merge with main, for example.

Out of curiosity:

  1. Why did you need the aux value returned? (Concrete use cases would help us understand.)

  2. Why not just wrap the function that returns an aux in a function that does not? The linesearch only needs the value. At the end of the linesearch, the value and grad computed at the accepted point may be recycled by the value_and_grad_from_state function, but they don't have to be: you can always recompute the value, grad (and potentially aux) yourself. True, that would not be optimal (you would compute the value and grad twice), but it would work.
    Namely, when wrapping LBFGS into a solver (see the notebook), consider simply defining

import jax
import optax
import optax.tree_utils as otu

def run_opt(init_params, fun_with_aux, opt, max_iter, tol):
  # The linesearch only needs the loss value, so strip the aux for value_fn.
  fun_without_aux = lambda *a, **kw: fun_with_aux(*a, **kw)[0]

  def step(carry):
    params, state, _ = carry
    (value, aux), grad = jax.value_and_grad(fun_with_aux, has_aux=True)(params)
    updates, state = opt.update(
        grad, state, params, value=value, grad=grad, value_fn=fun_without_aux
    )
    params = optax.apply_updates(params, updates)
    return params, state, aux

  def continuing_criterion(carry):
    _, state, _ = carry
    iter_num = otu.tree_get(state, 'count')
    grad = otu.tree_get(state, 'grad')
    err = otu.tree_l2_norm(grad)
    return (iter_num == 0) | ((iter_num < max_iter) & (err >= tol))

  # Evaluate once to obtain the aux structure for the while_loop carry.
  (_, init_aux), _ = jax.value_and_grad(fun_with_aux, has_aux=True)(init_params)
  init_carry = (init_params, opt.init(init_params), init_aux)
  final_params, final_state, final_aux = jax.lax.while_loop(
      continuing_criterion, step, init_carry
  )
  return final_params, final_state, final_aux

But your fix seems pretty good!

  1. I would like the "has_aux" argument to rather live in the update function too, if possible (it's not clear that it is). The reason is to keep all the logic pertaining to the actual function considered out of the method's signature (a rough sketch of the idea follows after this list).
  2. One needs to be careful at initialization: the shape of the aux may not be known by the optimizer's init function, as you noticed. Your hack is quite OK though, as long as it is documented (ideally we would change the base API so that the init function could accept additional arguments, but that's a much deeper revamp).
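
Purely for illustration of point 1, this is roughly the shape of the idea (not an existing or agreed-upon optax signature, just hypothetical):

# Hypothetical: has_aux passed at the update call instead of at construction time.
updates, state = opt.update(
    grad, state, params,
    value=value, grad=grad, value_fn=fun_with_aux, has_aux=True,
)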

Thanks for looking into this!

ro0mquy (Author) commented Mar 12, 2025

Hey, thank you for looking at this.

  1. In my use case, the aux data contains the individual loss terms that are added together to form the final loss value (a minimal sketch follows after this list).
  2. Your comment "that would not be optimal" is exactly why I didn't use the wrapper: I would like to reuse the function value and gradient because evaluating my loss function is quite expensive.
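
A minimal sketch of that kind of loss (the names are made up): the aux carries the individual terms while only the total is optimized.

import jax.numpy as jnp

def loss(params):
  data_term = jnp.sum((params['w'] - 1.0) ** 2)
  reg_term = 1e-2 * jnp.sum(params['w'] ** 2)
  total = data_term + reg_term
  # The individual terms go into aux so they can be logged without re-evaluating.
  return total, {'data_term': data_term, 'reg_term': reg_term}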

I don't have strong opinions about the exact API design. Feel free to make suggestions.
