Add a grid search refinement algorithm for steps, and graphics to support the idea (GEN-690) #15

Open: wants to merge 14 commits into base: main

littleredcomputer
Collaborator

  • implements a step-and-improve algorithm: after each step we either search a grid of nearby points, some of which may better explain the sensor readings, or use a Boltzmann resample to look for improvements in a continuous probability cloud
  • adds some graphics to show the improvements available at a given step
  • upgrades genjax and genstudio requirements
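
The grid search step in the first bullet can be illustrated with a minimal sketch. It is not the PR's code: `grid_around` and `toy_log_score` are hypothetical names, the candidates are plain 2D points rather than poses with a heading, and the score is a stand-in for a sensor log-likelihood.

```python
import jax
import jax.numpy as jnp


def grid_around(center, half_width, n):
    """Return an (n*n, 2) array of points on a square grid centered at `center`."""
    offsets = jnp.linspace(-half_width, half_width, n)
    dx, dy = jnp.meshgrid(offsets, offsets)
    return center + jnp.stack([dx.ravel(), dy.ravel()], axis=-1)


def toy_log_score(point, target):
    # Hypothetical stand-in for "how well does this point explain the sensor readings".
    return -jnp.sum((point - target) ** 2)


center = jnp.array([0.0, 0.0])    # position after the step
target = jnp.array([0.2, -0.1])   # position the toy score prefers

candidates = grid_around(center, half_width=0.3, n=7)
# Score every candidate in one vectorized call.
scores = jax.vmap(lambda p: toy_log_score(p, target))(candidates)
best = candidates[jnp.argmax(scores)]
```

Taking the `argmax` is the greedy variant; sampling a candidate in proportion to its weight is the other option the description mentions.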

@littleredcomputer changed the title from "Add a grid search refinement algorithm for steps, and graphics to support the idea" to "Add a grid search refinement algorithm for steps, and graphics to support the idea (GEN-690)" on Oct 22, 2024

linear bot commented Oct 22, 2024

    return pose_plot(p, wing_length=0.006, **opts)

def weighted_small_pose_plot(target, poses, ws):
    lse_ws = jnp.log(jnp.sum(jnp.exp(ws)))
Collaborator

consider using jax.scipy.special.logsumexp
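
A small sketch of why this suggestion matters, assuming `ws` holds log-weights (the values below are made up for illustration). With very negative entries, `jnp.exp` underflows to zero and the hand-rolled version returns `-inf`, while `logsumexp` stays finite:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Made-up log-weights; log-densities this negative are common in practice.
ws = jnp.array([-1000.0, -1001.0, -1002.0])

# Hand-rolled version: exp(-1000) underflows to 0, so the log of the sum is -inf.
naive = jnp.log(jnp.sum(jnp.exp(ws)))

# Numerically stable version: the maximum is subtracted before exponentiating.
stable = logsumexp(ws)

# Normalized weights computed from the stable value sum to (approximately) 1.
scaled_ws = jnp.exp(ws - stable)
```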

    scaled_ws = jnp.exp(ws - lse_ws)
    max_scaled_w: FloatArray = jnp.max(scaled_ws)
    scaled_ws /= max_scaled_w
    # the following hack "boosts" lower scores a bit, to give us more visibility into
Collaborator

yeah it's a recurring problem, I often have to hack something similar.

nearby_poses = grid_of_nearby_poses(pose, 15, motion_settings)
ws, retvals = jax.vmap(lambda p: assess_model.assess(choicemap, (p,)))(nearby_poses)
return select_by_weight(k, ws, nearby_poses)

Collaborator

@littleredcomputer the structure is all there but I think there is probably something off with the math of the weights from which you're resampling, perhaps in both of these cases.
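
For reference, a minimal sketch of drawing one candidate in proportion to its weight, assuming the weights are log-weights. The name `select_by_log_weight` is hypothetical and is not the PR's `select_by_weight`; this does not address the specific issue raised above, it only shows the resampling primitive in isolation.

```python
import jax
import jax.numpy as jnp


def select_by_log_weight(key, log_ws, candidates):
    """Pick one row of `candidates` with probability proportional to exp(log_ws)."""
    # jax.random.categorical takes unnormalized log-probabilities (logits),
    # so the log-weights can be passed directly without exponentiating.
    idx = jax.random.categorical(key, log_ws)
    return candidates[idx]


key = jax.random.PRNGKey(0)
# Illustrative values: the middle candidate carries essentially all of the weight.
log_ws = jnp.array([-1e4, 0.0, -1e4])
candidates = jnp.array([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]])

picked = select_by_log_weight(key, log_ws, candidates)
```

Passing log-weights as logits avoids a separate normalization step and the underflow that comes with exponentiating them first.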

@sritchie
Collaborator

sritchie commented Nov 1, 2024

@littleredcomputer, I want to encourage you to merge PRs more frequently for work here, rather than waiting until the full thing is working (intermediate points are fine). It's hard to share progress from a branch imo...

@MathieuHuot
Collaborator

> @littleredcomputer, I want to encourage you to merge PRs more frequently for work here, rather than waiting until the full thing is working (intermediate points are fine). It's hard to share progress from a branch imo...

@sritchie it might be on me, I'm not sure. I requested some changes because the math was not OK, even though it was working decently well in practice. And I think trying to make these fixes revealed that the code was hard to work with. Any advice? We could merge this version, I guess, but it's important to flag that this is not quite doing inference in a Bayesian way yet.

@littleredcomputer
Collaborator Author

Let's please take another look at this, @mhuot. I don't think it's perfect at this point, and we could select from a few visualization opportunities to work on and drop some experiments that haven't panned out.

Collaborator

@MathieuHuot MathieuHuot left a comment

Great job Colin, and the code and explanations are really clean thanks to your simplifications and refactoring!
I noticed (and I think you mentioned it somewhere) that you removed the MCMC rejuvenation using the Boltzmann rule to instead focus the narrative on going from importance sampling, to SMC with resampling at every step, to SMC with grid rejuvenation. It is smooth and, I think, totally sufficient here, but I wonder if there was another reason why you removed it.

@sritchie
Collaborator

@littleredcomputer I merged a later PR that MAY have subsumed this. Let me know what I goofed here, and whether we should toss this or you could help resolve the diff. Thank you!

4 participants