Add support for loss functions with auxiliary data to linesearch #1177
base: main
Conversation
This change adds support for loss functions that return auxiliary data alongside their primary value, like `(loss_value, extra_data)`. This pattern is commonly used with `jax.value_and_grad(fn, has_aux=True)`.

The approach:

1. Added a `value_fn_has_aux` flag to `zoom_linesearch` and `scale_by_zoom_linesearch`.
2. Modified value handling to properly unpack auxiliary data when needed, using a new `_unpack_value` helper that extracts just the loss value.
3. Updated value storage in the state to keep the full value+aux tuple when needed.
4. Added a `has_aux` parameter to `value_and_grad_from_state` to properly handle auxiliary data when reusing cached values.

This allows the linesearch algorithms to work with loss functions that return auxiliary data while optimizing over just the primary loss value.
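For context, a rough end-to-end sketch of how this might be used, following optax's documented gradient-descent-with-zoom-linesearch pattern. The `value_fn_has_aux` and `has_aux` arguments are the flags proposed in this PR, and the return convention of `value_and_grad_from_state` with `has_aux` (mirroring `jax.value_and_grad`) is an assumption, not settled API:

```python
import jax.numpy as jnp
import optax

def loss_fn(params):
    # Scalar loss plus an arbitrary pytree of auxiliary data, the pattern
    # used with jax.value_and_grad(loss_fn, has_aux=True).
    loss = jnp.sum(params ** 2)
    aux = {"residual_norm": jnp.sqrt(loss)}
    return loss, aux

# Gradient descent scaled by a zoom linesearch, as in the optax docs;
# value_fn_has_aux is the flag this PR proposes.
opt = optax.chain(
    optax.sgd(learning_rate=1.0),
    optax.scale_by_zoom_linesearch(max_linesearch_steps=15, value_fn_has_aux=True),
)
value_and_grad = optax.value_and_grad_from_state(loss_fn, has_aux=True)

params = jnp.array([1.0, 2.0, 3.0])
state = opt.init(params)
for _ in range(10):
    # Assumed to mirror jax.value_and_grad(..., has_aux=True): the value
    # part is the (loss, aux) pair, reused from the state when cached.
    (value, aux), grad = value_and_grad(params, state=state)
    updates, state = opt.update(
        grad, state, params, value=value, grad=grad, value_fn=loss_fn
    )
    params = optax.apply_updates(params, updates)
```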
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up-to-date status, view the checks section at the bottom of the pull request.

Hey, thanks for continuing to work on this! Can you sign the CLA, please?

@ro0mquy thanks for this PR, I'm keen to see this feature added!

I'm still figuring out the CLA thing with my company, but it's just a matter of time. I also accidentally pushed another fix ("fix slope calculation in zoom_linesearch") to the same branch. I can revert that if needed.

I signed the CLA and updated my email. Is there a way to retrigger the check?
Hello @ro0mquy, out of curiosity
But your fix seems pretty good!
Thanks for looking into this!
Hey, thank you for looking at this. I don't have strong opinions about the exact API design. Feel free to make suggestions.
Input needed: How to initialize `opt_state`?

The linesearch algorithm stores `value` and `grad` in the optimizer state to enable reuse of function evaluations. When using auxiliary data, JAX compilation needs to know the structure of this data upfront. Currently, I'm initializing it like this:
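A minimal sketch of this kind of initialization, assuming the gradient-descent-with-linesearch setup from the example above; `loss_fn` and `init_params` are illustrative names, not the PR's actual code:

```python
import jax
import jax.numpy as jnp
import optax

def loss_fn(params):
    # Scalar loss plus an arbitrary pytree of auxiliary data.
    return jnp.sum(params ** 2), {"residual": params}

init_params = jnp.array([1.0, 2.0, 3.0])
opt = optax.chain(
    optax.sgd(learning_rate=1.0),
    optax.scale_by_zoom_linesearch(max_linesearch_steps=15),
)

# The extra evaluation in question: its only purpose is to obtain the
# (value, aux) pair so the cached slot in the optimizer state has a
# fixed pytree structure from step 0.
value0, aux0 = loss_fn(init_params)
state = opt.init(init_params)

# A cheaper alternative: jax.eval_shape traces loss_fn abstractly and
# returns ShapeDtypeStructs, giving the same structure without running
# the actual computation; zeros of those shapes could seed the state.
value_struct, aux_struct = jax.eval_shape(loss_fn, init_params)
aux0_zeros = jax.tree_util.tree_map(
    lambda s: jnp.zeros(s.shape, s.dtype), aux_struct
)
```

`jax.eval_shape` avoids the runtime cost of the extra evaluation, though how the resulting structure should seed the state is still the open design question here.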
This feels a bit hacky since it requires an extra function evaluation just to get the structure. Is there a better way to handle this initialization?
The challenge is that the auxiliary data structure is determined by the loss function and could be arbitrary (e.g., dictionaries, nested structures, etc.).
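For illustration, the unpacking described in the summary could reduce to a small helper along these lines (a sketch of the idea; the actual `_unpack_value` in this PR may differ):

```python
from typing import Any, Tuple, Union
import jax

def _unpack_value(
    value: Union[jax.Array, Tuple[jax.Array, Any]],
    value_fn_has_aux: bool,
) -> jax.Array:
    # With value_fn_has_aux, the loss function returns (loss, aux); all
    # linesearch decisions must compare only the scalar loss, whatever
    # structure aux happens to have (dict, nested pytree, etc.).
    return value[0] if value_fn_has_aux else value
```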
ToDos

- Settle how to initialize `opt_state` (see the open question above)