-
Notifications
You must be signed in to change notification settings - Fork 1
Add a grid search refinement algorithm for steps, and graphics to support the idea (GEN-690) #15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
ee2bb33
afb30f0
2bf8000
02ca83c
de1ddad
82ccd64
07d4abb
a99a0dc
b3f7413
4e88dc2
e81d7b0
7d1ab25
3a8957b
93c78a9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,7 +23,7 @@ | |
|
||
auth.authenticate_user() | ||
%pip install --quiet keyring keyrings.google-artifactregistry-auth # type: ignore # noqa | ||
%pip install --quiet genjax==0.5.1 genstudio==2024.7.30.1617 --extra-index-url https://us-west1-python.pkg.dev/probcomp-caliban/probcomp/simple/ # type: ignore # noqa | ||
%pip install --quiet genjax==0.7.0 genstudio==2024.9.7 --extra-index-url https://us-west1-python.pkg.dev/probcomp-caliban/probcomp/simple/ # type: ignore # noqa | ||
# %% [markdown] | ||
# # ProbComp Localization Tutorial | ||
# | ||
|
@@ -36,6 +36,9 @@ | |
import json | ||
import genstudio.plot as Plot | ||
|
||
|
||
|
||
|
||
import itertools | ||
import jax | ||
import jax.numpy as jnp | ||
|
@@ -385,8 +388,10 @@ def integrate_controls_physical(robot_inputs): | |
# %% [markdown] | ||
# ### Plot such data | ||
# %% | ||
def pose_plot(p, r=0.5, fill: str | Any = "black", **opts): | ||
WING_ANGLE, WING_LENGTH = jnp.pi/12, 0.6 | ||
def pose_plot(p, fill: str | Any = "black", **opts): | ||
r = opts.get('r', 0.5) | ||
wing_opacity = opts.get('opacity', 0.3) | ||
WING_ANGLE, WING_LENGTH = jnp.pi/12, opts.get('wing_length', 0.6) | ||
center = p.p | ||
angle = jnp.arctan2(*(center - p.step_along(-r).p)[::-1]) | ||
|
||
|
@@ -399,19 +404,19 @@ def pose_plot(p, r=0.5, fill: str | Any = "black", **opts): | |
# Draw wings | ||
wings = Plot.line( | ||
[wing_ends[0], center, wing_ends[1]], | ||
strokeWidth=2, | ||
strokeWidth=opts.get('strokeWidth', 2), | ||
stroke=fill, | ||
opacity=0.3 | ||
opacity=wing_opacity | ||
) | ||
|
||
# Draw center dot | ||
dot = Plot.ellipse([center], r=0.14, fill=fill, **opts) | ||
dot = Plot.ellipse([center], fill=fill, **opts) | ||
|
||
return wings + dot | ||
|
||
walls_plot = Plot.new( | ||
Plot.line( | ||
Plot.cache(world["wall_verts"]), | ||
world["wall_verts"], | ||
strokeWidth=2, | ||
stroke="#ccc", | ||
), | ||
|
@@ -762,7 +767,6 @@ def generate_path(key: PRNGKey) -> Pose: | |
sample_paths_v = jax.vmap(generate_path)(jax.random.split(sub_key, N_samples)) | ||
|
||
Plot.Grid([walls_plot + poses_to_plots(path) for path in sample_paths_v]) | ||
|
||
# %% | ||
# Animation showing a single path with confidence circles | ||
|
||
|
@@ -977,7 +981,6 @@ def sensor_model_one(pose, angle): | |
@ "distance" | ||
) | ||
|
||
|
||
sensor_model = sensor_model_one.vmap(in_axes=(None, 0)) | ||
|
||
|
||
|
@@ -1083,21 +1086,25 @@ def get_sensors(trace): | |
tr = default_full_model.simulate(sub_key, ()) | ||
|
||
|
||
def animate_path_and_sensors(path, readings, motion_settings, frame_key=None):
    """Build an animation with one frame per pose of `path`.

    Each frame overlays the path-with-confidence plot for that step (drawn
    using `motion_settings['p_noise']`) with the sensor plot for the pose and
    the reading at the same index in `readings`.

    `frame_key` is forwarded to `Plot.Frames` as its `key`.
    """
    frames = []
    for step, pose in enumerate(path):
        confidence = plot_path_with_confidence(path, step, motion_settings['p_noise'])
        frames.append(confidence + plot_sensors(pose, readings[step]))

    return Plot.Frames(frames, fps=2, key=frame_key)
|
||
|
||
def animate_full_trace(trace, frame_key=None):
    """Animate the path and sensor readings recorded in a full-model trace."""
    path = get_path(trace)
    readings = get_sensors(trace)
    # make_full_model curries motion_settings around the scan combinator, so it
    # is absent from the outer trace's argument list. It can still be recovered
    # from the arguments of the nested initial-pose subtrace.
    initial_pose_trace = trace.get_subtrace(('initial',)).get_subtrace(('pose',))
    motion_settings = initial_pose_trace.get_args()[1]
    return animate_path_and_sensors(path, readings, motion_settings, frame_key=frame_key)
|
||
frames = [ | ||
plot_path_with_confidence(path, step, motion_settings['p_noise']) | ||
+ plot_sensors(pose, readings[step]) | ||
for step, pose in enumerate(path) | ||
] | ||
|
||
return Plot.Frames(frames, fps=2, key=frame_key) | ||
|
||
animate_full_trace(tr) | ||
# %% [markdown] | ||
|
@@ -1291,7 +1298,7 @@ def constraint_from_path(path): | |
["High deviation", | ||
trace_path_integrated_observations_high_deviation, | ||
motion_settings_high_deviation, | ||
w_high]]]) | Plot.Slider("frame", T, fps=2) | ||
w_high]]]) | Plot.Slider("frame", 0, T, fps=2) | ||
|
||
# %% [markdown] | ||
# ...more closely resembles the density of these data back-fitted onto any other typical (random) paths of the model... | ||
|
@@ -1425,7 +1432,7 @@ def animate_path_as_line(path, **options): | |
x_coords = path.p[:, 0] | ||
y_coords = path.p[:, 1] | ||
return Plot.line({"x": x_coords, "y": y_coords}, | ||
{"curve": "cardinal-open", | ||
{"curve": "linear", | ||
**options}) | ||
# | ||
(world_plot | ||
|
@@ -1517,3 +1524,217 @@ def multi_drift(key, trace: genjax.Trace, scale, K: int, N: int): | |
# We can see some improvement in the density of the paths selected. It's possible to imagine improving the search by repeating this drift process on all of the samples returned by the original importance sample. But we must face one important fact: we have used acceleration to improve what amounts to a brute-force search. The next inference step should take advantage of the information we have about the control steps, iteratively improving the path from the starting point, combining the control step and sensor data information to refine the selection of each step as it is made. | ||
|
||
# %% | ||
# Let's approach the problem step by step instead of trying to infer the whole path. | ||
# For each given pose, we will use the sensor data to propose a refinement. | ||
|
||
@genjax.gen
def perturb_pose(pose: Pose, motion_settings):
    """Return `pose` displaced by zero-mean normal noise.

    The position offsets are drawn at addresses 'd_x' and 'd_y' with scale
    `motion_settings['p_noise']`; the heading offset is drawn at 'd_hd' with
    scale `motion_settings['hd_noise']`.
    """
    d_x = genjax.normal(0.0, motion_settings['p_noise']) @ 'd_x'
    d_y = genjax.normal(0.0, motion_settings['p_noise']) @ 'd_y'
    d_hd = genjax.normal(0.0, motion_settings['hd_noise']) @ 'd_hd'
    displacement = jnp.array((d_x, d_y))
    return Pose(pose.p + displacement, pose.hd + d_hd)
|
||
@genjax.gen
def perturb_model(pose: Pose, motion_settings):
    """Perturb `pose` (address 'pose'), then run the sensor model at the
    perturbed pose (address 'sensor'). Returns the perturbed pose."""
    perturbed = perturb_pose(pose, motion_settings) @ 'pose'
    # The sensor readings are traced for scoring only; the value is not needed here.
    sensor_model(perturbed, sensor_angles) @ 'sensor'
    return perturbed
|
||
# %% [markdown] | ||
# To get started we'll work with the initial point, and then improve it. Once that's done, | ||
# we can chain together such improved moves to hopefully get a better inference of the | ||
# actual path. | ||
|
||
# %% | ||
key, sub_key = jax.random.split(key) | ||
p0 = start_pose_prior.simulate(sub_key, (robot_inputs['start'], motion_settings_low_deviation)) | ||
key, sub_key = jax.random.split(key) | ||
tr_p0 = jax.vmap(perturb_model.simulate, in_axes=(0, None))(jax.random.split(sub_key, 100), (p0.get_retval(), motion_settings_low_deviation)) | ||
# %% [markdown] | ||
# Create a choicemap that will enforce the given sensor observation | ||
|
||
def observation_to_choicemap(observation):
    """Build a choicemap that pins each entry of `observation` to the address
    ('sensor', i, 'distance'), where i is the entry's index."""
    sensor_indices = jnp.arange(len(observation))
    return C['sensor', sensor_indices, 'distance'].set(observation)
# %% [markdown] | ||
# The first thing we'll try is a Boltzmann update: generate a cloud of nearby points | ||
# using the generative function we wrote, and select a replacement from it at random according to the weights. | ||
# First, let's generate the cloud and visualize it. | ||
# %% | ||
def boltzmann_sample(key: PRNGKey, N: int, pose: Pose, motion_settings, observations):
    """Run `perturb_model.importance` N times around `pose`, each run
    constrained to the given sensor `observations`.

    Only the PRNG key is mapped over; the constraint and the model arguments
    are shared by all N runs. Returns whatever the vectorized `importance`
    call returns (callers in this file unpack it as traces and weights).
    """
    keys = jax.random.split(key, N)
    constraint = observation_to_choicemap(observations)
    batched_importance = jax.vmap(perturb_model.importance, in_axes=(0, None, None))
    return batched_importance(keys, constraint, (pose, motion_settings))
|
||
def small_pose_plot(p: Pose, **opts):
    """Variant of pose_plot that works better when zoomed in on the vicinity of one pose.

    Supplies small defaults for `r` and `wing_length`; any keyword given by
    the caller (including those two) takes precedence and is forwarded to
    pose_plot.

    TODO: consider scaling r and wing_length based on the size of the plot domain.
    """
    # Merge the defaults first so caller-supplied options win. Passing
    # wing_length as an explicit keyword next to **opts would raise a
    # duplicate-keyword TypeError whenever the caller also sets it.
    opts = {'r': 0.001, 'wing_length': 0.006} | opts
    return pose_plot(p, **opts)
|
||
def weighted_small_pose_plot(target, poses, ws):
    """Plot a cloud of small poses shaded by weight, plus two reference poses.

    `ws` is exponentiated here, so it is treated as log-scale weights. Each
    pose in `poses` is filled according to its rescaled weight; `target` is
    drawn in red and `robot_inputs['start']` in green.
    """
    # Rescale so the best weight maps to 1. Normalizing by the log-sum-exp and
    # then dividing by the maximum is algebraically exp(ws - max(ws)), and this
    # form avoids exp() underflowing to zero (and the NaNs that follow) when
    # the log-weights are very negative.
    scaled_ws = jnp.exp(ws - jnp.max(ws))
    # the following hack "boosts" lower scores a bit, to give us more visibility into
    # the density of the nearby cloud. Aesthetically, I found too many points were
    # invisible without some adjustment, since the score distribution is concentrated
    # closely around 1.0
    scaled_ws = scaled_ws ** 0.3
    return (Plot.new([small_pose_plot(p, fill=w) for p, w in zip(poses, scaled_ws)]
                     + small_pose_plot(target, r=0.003, fill='red')
                     + small_pose_plot(robot_inputs['start'], r=0.003, fill='green'))
            + {
                "color": {"type": "linear", "scheme": "Purples"},
                "height": 400,
                "width": 400,
                "aspectRatio": 1
            })
key, sub_key = jax.random.split(key) | ||
bs = boltzmann_sample(sub_key, 1000, p0.get_retval(), motion_settings_low_deviation, observations_low_deviation[0]) | ||
weighted_small_pose_plot(p0.get_retval(), bs[0].get_retval(), bs[1]) | ||
# %% [markdown] | ||
# Develop a function which will produce a grid of evenly spaced nearby poses given | ||
# an initial pose. $n$ is the number of steps to take in each cardinal direction | ||
# (up/down, left/right and changes in heading). For example, if you say $n = 2$, there | ||
# will be a $5\times 5$ grid of positions with the original pose in the center, and 5 layers | ||
# of this type, each with different heading deltas (including zero), for a total of | ||
# $125 = 5^3$ alternate poses. | ||
# %% | ||
def grid_of_nearby_poses(p, n, motion_settings):
    """Return a Pose holding a (2n+1)^3 lattice of poses centred on `p`.

    Position offsets cover [-2 * p_noise, 2 * p_noise] on each axis and
    heading offsets cover [-2 * hd_noise, 2 * hd_noise], each in 2n+1 evenly
    spaced steps. Entries are ordered position-major: every grid position
    appears 2n+1 times in a row, once per heading offset.
    """
    steps = jnp.arange(-n, n + 1)
    side = len(steps)

    # 2-d lattice of position offsets, flattened to one (x, y) row per point.
    axis_offsets = steps * 2 * motion_settings['p_noise'] / n
    xs, ys = jnp.meshgrid(axis_offsets, axis_offsets)
    position_offsets = jnp.dstack((xs, ys)).reshape(side * side, -1)

    # 1-d lattice of heading offsets; the cartesian product with the positions
    # is formed by repeating each position and tiling the headings.
    heading_offsets = steps * 2 * motion_settings['hd_noise'] / n
    positions = jnp.repeat(p.p + position_offsets, side, axis=0)
    headings = jnp.tile(p.hd + heading_offsets, side * side)
    return Pose(positions, headings)
# %% | ||
cube_step_size = 8 | ||
pose_grid = grid_of_nearby_poses(p0.get_retval(), cube_step_size, motion_settings_low_deviation) | ||
@genjax.gen
def assess_model(p):
    """Run the sensor model at pose `p` under address 'sensor' and return `p` unchanged."""
    _ = sensor_model(p, sensor_angles) @ 'sensor'
    return p
# %% | ||
key, sub_key = jax.random.split(key) | ||
model_assess = jax.jit(assess_model.assess) | ||
assess_scores, assess_retvals = jax.vmap(lambda k, p: model_assess(k, (p,)), in_axes=(None, 0))(observation_to_choicemap(observations_low_deviation[0]), pose_grid) | ||
#assess_scores, assess_retvals = jax.vmap(lambda p: sensor_model.assess(cm, (p, sensor_angles)))(pose_grid) | ||
# %% | ||
# Our grid of nearby poses is actually a cube when we take into consideration the | ||
# heading deltas. In order to get a 2d density, we decide to flatten the cube by | ||
# taking the "best" of the headings by score at each point. | ||
def flatten_pose_cube(n, poses, scores):
    """Collapse a pose cube to a plane by keeping the best-scoring heading at
    each grid position.

    `poses` and `scores` are expected in the layout produced by
    grid_of_nearby_poses(p, n, ...): (2n+1)^2 positions, each followed by its
    2n+1 heading variants (position-major, heading-minor).

    Returns (Pose, scores) with one entry per grid position, i.e. (2n+1)^2
    entries.
    """
    d = 2 * n + 1
    n_positions = d * d
    # One row per position, one column per heading. Grouping as (d, d*d) would
    # instead mix d different positions into every row.
    pose_groups = poses.p.reshape((n_positions, d, 2))
    heading_groups = poses.hd.reshape((n_positions, d))
    score_groups = scores.reshape((n_positions, d))
    # find the best heading (by score) for each position
    best = jnp.argmax(score_groups, axis=1)
    # We want to select the best column from every row, so we need to
    # explicitly enumerate the rows we want (using : would not have the
    # same effect)
    rows = jnp.arange(n_positions)
    return (Pose(pose_groups[rows, best], heading_groups[rows, best]),
            score_groups[rows, best])
|
||
# %% | ||
#sensor_model.assess(cm, (pose_grid[0], sensor_angles)) | ||
#sensor_model.simulate(sub_key, (pose_grid[0], sensor_angles)) | ||
#sensor_model.importance(sub_key, observations_to_choicemap(observations_low_deviation, 0), (pose_grid[0], sensor_angles)) | ||
|
||
# Since the above calls work... | ||
# I think this ought to work, but doesn't! TODO: find a minimal repro and file an issue | ||
#sensor_model.assess(cm, (pose_grid[0], sensor_angles)) | ||
# %% [markdown] | ||
# Prepare a plot showing the density of nearby improvements available using the grid | ||
# search and importance sampling techniques. | ||
# %% | ||
assess_pose_plane, assess_score_plan = flatten_pose_cube(cube_step_size, assess_retvals, assess_scores) | ||
(weighted_small_pose_plot(p0.get_retval(), assess_retvals, assess_scores) & | ||
weighted_small_pose_plot(p0.get_retval(), bs[0].get_retval(), bs[1])) | ||
# %% [markdown] | ||
# Now let's try doing the whole path. We want to produce something that is ultimately | ||
# scan-compatible, so it should have the form state -> update -> new_state. The state | ||
# is obviously the pose; the update will include the sensor readings at the current | ||
# position and the control input for the next step. | ||
|
||
def select_by_weight(key: PRNGKey, weights: FloatArray, things):
    """Draw one index with `jax.random.categorical(key, weights)` and return
    the pytree `things` with every leaf indexed at that position."""
    index = jax.random.categorical(key, weights)

    def pick(leaf):
        return leaf[index]

    return jax.tree.map(pick, things)
|
||
def improved_path(key: PRNGKey, motion_settings: dict, observations: FloatArray, mode: str):
    """Infer a path step by step, refining each pose against its sensor
    observation before applying the next control.

    `mode` selects the refinement strategy: "grid" (weighted choice over a
    grid of nearby poses) or "boltzmann" (weighted choice over importance
    samples from perturb_model). Any other value raises KeyError.

    Returns the result of `jax.lax.scan`: the pose carried out of the last
    step, and the stacked refined poses (one per step).
    """

    def boltzmann_improver(k: PRNGKey, pose, observation):
        sample_key, select_key = jax.random.split(k, 2)
        traces, weights = boltzmann_sample(sample_key, 1000, pose, motion_settings, observation)
        return select_by_weight(select_key, weights, traces.get_retval())

    def grid_search_improver(k: PRNGKey, pose, observation):
        constraint = observation_to_choicemap(observation)
        candidates = grid_of_nearby_poses(pose, 15, motion_settings)
        scores, _ = jax.vmap(lambda p: assess_model.assess(constraint, (p,)))(candidates)
        return select_by_weight(k, scores, candidates)

    def improve_pose_and_step(state, update):
        observation, control, step_key = update
        improve_key, advance_key = jax.random.split(step_key)
        # improve the step where we are
        improver = {"grid": grid_search_improver, "boltzmann": boltzmann_improver}[mode]
        refined = improver(improve_key, state, observation)
        # run the step model to advance one step
        advanced = step_model.simulate(advance_key, (refined, control, motion_settings))
        return (advanced.get_retval(), refined)

    # We have one fewer control than step, since no step got us to the initial position.
    # Our scan step starts at the initial step and applies a control input each time.
    # To make things balance, we need to add a zero step to the end of the control input
    # array, so that when we arrive at the final step, no more control input is given.
    controls = robot_inputs['controls'] + Control(jnp.array([0]), jnp.array([0]))
    n_steps = len(controls)
    # One key for sampling the start pose, then one per scan step.
    sub_keys = jax.random.split(key, n_steps + 1)
    start_trace = start_pose_prior.simulate(sub_keys[0], (robot_inputs['start'], motion_settings))
    p0 = start_trace.get_retval()
    scan_inputs = (observations, controls, sub_keys[1:])
    return jax.lax.scan(improve_pose_and_step, p0, scan_inputs)
# %% | ||
# Select an importance sample via weight in both the low and high deviation settings. | ||
key, k1, k2 = jax.random.split(key, 3) | ||
low_importance = select_by_weight(k1, low_weights, low_deviation_paths) | ||
high_importance = select_by_weight(k2, high_weights, high_deviation_paths) | ||
# %% | ||
key, sub_key = jax.random.split(key) | ||
endpoint_low, improved_low = improved_path(sub_key, motion_settings_low_deviation, observations_low_deviation, "grid") | ||
# %% | ||
|
||
def path_comparison_plot(*plots):
    """Overlay up to four paths on the world plot, colour-coded by role.

    Positional paths are labelled, in order, "improved", "integrated",
    "importance" and "true"; each is drawn both as a line and as pose markers.
    """
    types = ["improved", "integrated", "importance", "true"]
    labelled_paths = list(zip(plots, types))
    plot = world_plot
    plot += [
        animate_path_as_line(path, strokeWidth=2, stroke=Plot.constantly(label))
        for path, label in labelled_paths
    ]
    plot += [
        poses_to_plots(path, fill=Plot.constantly(label))
        for path, label in labelled_paths
    ]
    palette = {"integrated": "green", "improved": "blue", "true": "black", "importance": "red"}
    return plot + Plot.color_map(palette)
|
||
# %% | ||
path_comparison_plot(improved_low, path_integrated, low_importance, path_low_deviation) | ||
# %% | ||
key, sub_key = jax.random.split(key) | ||
endpoint_high, improved_high = improved_path(sub_key, motion_settings_high_deviation, observations_high_deviation, "grid") | ||
path_comparison_plot(improved_high, path_integrated, high_importance, path_high_deviation) | ||
# %% [markdown] | ||
# To see how the grid search improves poses, we play back the grid-search path | ||
# next to an importance sample path. You can see the grid search has a better fit | ||
# of sensor data to wall position at a variety of time steps. | ||
# %% | ||
Plot.Row( | ||
animate_path_and_sensors(improved_high, observations_high_deviation, motion_settings_high_deviation, frame_key="frame"), | ||
animate_path_and_sensors(high_importance, observations_high_deviation, motion_settings_high_deviation, frame_key="frame") | ||
) | Plot.Slider("frame", 0, T, fps=2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
consider using
jax.scipy.special.logsumexp