-
Notifications
You must be signed in to change notification settings - Fork 1
Add a grid search refinement algorithm for steps, and graphics to support the idea (GEN-690) #15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
littleredcomputer
commented
Oct 22, 2024
- implements a step-and-improve algorithm, where after each step we search a grid of nearby points, some of which may better explain the sensor readings. or use a Boltzmann resample to look for improvements in a continuous probability cloud
- adds some graphics to show the improvements available at a given step
- upgrade genjax and genstudio requirements
return pose_plot(p, wing_length=0.006, **opts) | ||
|
||
def weighted_small_pose_plot(target, poses, ws): | ||
lse_ws = jnp.log(jnp.sum(jnp.exp(ws))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
consider using jax.scipy.special.logsumexp
scaled_ws = jnp.exp(ws - lse_ws) | ||
max_scaled_w: FloatArray = jnp.max(scaled_ws) | ||
scaled_ws /= max_scaled_w | ||
# the following hack "boosts" lower scores a bit, to give us more visibility into |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah it's a recurring problem, I often have to hack something similar.
nearby_poses = grid_of_nearby_poses(pose, 15, motion_settings) | ||
ws, retvals = jax.vmap(lambda p: assess_model.assess(choicemap, (p,)))(nearby_poses) | ||
return select_by_weight(k, ws, nearby_poses) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@littleredcomputer the structure is all there but I think there is probably something off with the math of the weights from which you're resampling, perhaps in both of these cases.
@littleredcomputer , I want to encourage you to merge PRs more frequently for work here, rather than waiting until the full thing is working (intermediate points are fine). It's hard to share progress from a branch imo... |
@sritchie it might be on me, I'm not sure. I requested some changes as the math was not ok, even though it was working decently well in practice. And I think trying to do these fixes revealed that the code was hard to work with. Any advise? We could merge this version I guess, but it's important to flag this is not quite doing inference in a Bayesian way yet. |
for more information, see https://pre-commit.ci
- take care of the conseqeunces of this throughout - switch from accumulate to scan with dimap (a little more verbose but consistency is improved)
- refresh jupyter version
for more information, see https://pre-commit.ci
Let's please take another look at this @mhuot. I don't think it's perfect at this point and we could select from a few visualization opportunities to work on and drop some experiments that haven't panned out |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great job Colin, and the code and explanations are really clean thanks to your simplifications and refactoring!
I noticed (and I think you mentioned it somewhere) that you removed the MCMC rejuvenation using the Boltzman rule to instead focus the narrative on going from importance sampling to SMC with resampling at every step to SMC with grid rejuvenation. It is smooth and I think totally sufficient here but I wonder if there was another reason why you removed it.
@littleredcomputer I merged a later PR that MAY have subsumed this — let me know what I goofed here, and if we toss this or if you could help resolve the diff. Thank you! |