Add a grid search refinement algorithm for steps, and graphics to support the idea (GEN-690) #15

Open · wants to merge 14 commits into base: main · Changes from 3 commits
257 changes: 239 additions & 18 deletions genjax-localization-tutorial/probcomp-localization-tutorial.py
@@ -23,7 +23,7 @@

auth.authenticate_user()
%pip install --quiet keyring keyrings.google-artifactregistry-auth # type: ignore # noqa
%pip install --quiet genjax==0.5.1 genstudio==2024.7.30.1617 --extra-index-url https://us-west1-python.pkg.dev/probcomp-caliban/probcomp/simple/ # type: ignore # noqa
%pip install --quiet genjax==0.7.0 genstudio==2024.9.7 --extra-index-url https://us-west1-python.pkg.dev/probcomp-caliban/probcomp/simple/ # type: ignore # noqa
# %% [markdown]
# # ProbComp Localization Tutorial
#
@@ -36,6 +36,9 @@
import json
import genstudio.plot as Plot




import itertools
import jax
import jax.numpy as jnp
@@ -385,8 +388,10 @@ def integrate_controls_physical(robot_inputs):
# %% [markdown]
# ### Plot such data
# %%
def pose_plot(p, r=0.5, fill: str | Any = "black", **opts):
WING_ANGLE, WING_LENGTH = jnp.pi/12, 0.6
def pose_plot(p, fill: str | Any = "black", **opts):
r = opts.get('r', 0.5)
wing_opacity = opts.get('opacity', 0.3)
WING_ANGLE, WING_LENGTH = jnp.pi/12, opts.get('wing_length', 0.6)
center = p.p
angle = jnp.arctan2(*(center - p.step_along(-r).p)[::-1])

@@ -399,19 +404,19 @@ def pose_plot(p, r=0.5, fill: str | Any = "black", **opts):
# Draw wings
wings = Plot.line(
[wing_ends[0], center, wing_ends[1]],
strokeWidth=2,
strokeWidth=opts.get('strokeWidth', 2),
stroke=fill,
opacity=0.3
opacity=wing_opacity
)

# Draw center dot
dot = Plot.ellipse([center], r=0.14, fill=fill, **opts)
dot = Plot.ellipse([center], fill=fill, **opts)

return wings + dot

walls_plot = Plot.new(
Plot.line(
Plot.cache(world["wall_verts"]),
world["wall_verts"],
strokeWidth=2,
stroke="#ccc",
),
@@ -762,7 +767,6 @@ def generate_path(key: PRNGKey) -> Pose:
sample_paths_v = jax.vmap(generate_path)(jax.random.split(sub_key, N_samples))

Plot.Grid([walls_plot + poses_to_plots(path) for path in sample_paths_v])

# %%
# Animation showing a single path with confidence circles

@@ -977,7 +981,6 @@ def sensor_model_one(pose, angle):
@ "distance"
)


sensor_model = sensor_model_one.vmap(in_axes=(None, 0))


@@ -1083,21 +1086,25 @@ def get_sensors(trace):
tr = default_full_model.simulate(sub_key, ())


def animate_path_and_sensors(path, readings, motion_settings, frame_key=None):
frames = [
plot_path_with_confidence(path, step, motion_settings['p_noise'])
+ plot_sensors(pose, readings[step])
for step, pose in enumerate(path)
]

return Plot.Frames(frames, fps=2, key=frame_key)


def animate_full_trace(trace, frame_key=None):
path = get_path(trace)
readings = get_sensors(trace)
# since we use make_full_model to curry motion_settings around the scan combinator,
# that object will not be in the outer trace's argument list; but we can be a little
# crafty and find it at a lower level.
motion_settings = trace.get_subtrace(('initial',)).get_subtrace(('pose',)).get_args()[1]
return animate_path_and_sensors(path, readings, motion_settings, frame_key=frame_key)

frames = [
plot_path_with_confidence(path, step, motion_settings['p_noise'])
+ plot_sensors(pose, readings[step])
for step, pose in enumerate(path)
]

return Plot.Frames(frames, fps=2, key=frame_key)

animate_full_trace(tr)
# %% [markdown]
@@ -1291,7 +1298,7 @@ def constraint_from_path(path):
["High deviation",
trace_path_integrated_observations_high_deviation,
motion_settings_high_deviation,
w_high]]]) | Plot.Slider("frame", T, fps=2)
w_high]]]) | Plot.Slider("frame", 0, T, fps=2)

# %% [markdown]
# ...more closely resembles the density of these data back-fitted onto any other typical (random) paths of the model...
@@ -1425,7 +1432,7 @@ def animate_path_as_line(path, **options):
x_coords = path.p[:, 0]
y_coords = path.p[:, 1]
return Plot.line({"x": x_coords, "y": y_coords},
{"curve": "cardinal-open",
{"curve": "linear",
**options})
#
(world_plot
@@ -1517,3 +1524,217 @@ def multi_drift(key, trace: genjax.Trace, scale, K: int, N: int):
# We can see some improvement in the density of the paths selected. It's possible to imagine improving the search by repeating this drift process on all of the samples returned by the original importance sample. But we must face one important fact: we have used acceleration to improve what amounts to a brute-force search. The next inference step should take advantage of the information we have about the control steps, iteratively improving the path from the starting point, combining the control step and sensor data information to refine the selection of each step as it is made.

# %%
# Let's approach the problem step by step instead of trying to infer the whole path.
# For each given pose, we will use the sensor data to propose a refinement.

@genjax.gen
def perturb_pose(pose: Pose, motion_settings):
d_p = jnp.array((
genjax.normal(0.0, motion_settings['p_noise']) @ 'd_x',
genjax.normal(0.0, motion_settings['p_noise']) @ 'd_y'
))
d_hd = genjax.normal(0.0, motion_settings['hd_noise']) @ 'd_hd'
return Pose(pose.p + d_p, pose.hd + d_hd)

@genjax.gen
def perturb_model(pose: Pose, motion_settings):
p1 = perturb_pose(pose, motion_settings) @ 'pose'
_ = sensor_model(p1, sensor_angles) @ 'sensor'
return p1

# %% [markdown]
# To get started we'll work with the initial point, and then improve it. Once that's done,
# we can chain together such improved moves to hopefully get a better inference of the
# actual path.

# %%
key, sub_key = jax.random.split(key)
p0 = start_pose_prior.simulate(sub_key, (robot_inputs['start'], motion_settings_low_deviation))
key, sub_key = jax.random.split(key)
tr_p0 = jax.vmap(perturb_model.simulate, in_axes=(0, None))(jax.random.split(sub_key, 100), (p0.get_retval(), motion_settings_low_deviation))
# %% [markdown]
# Create a choicemap that will enforce the given sensor observation

# %%
def observation_to_choicemap(observation):
return C['sensor', jnp.arange(len(observation)), 'distance'].set(observation)
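# %% [markdown]
# For example (illustrative only; `observations_low_deviation` comes from
# earlier in the tutorial), the choicemap for the first reading constrains
# every `('sensor', i, 'distance')` address at once:
# %%
observation_to_choicemap(observations_low_deviation[0])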
# %% [markdown]
# The first thing we'll try is a Boltzmann update: generate a cloud of nearby points
# using the generative function we wrote, and select a replacement from it in
# proportion to its weight.
# First, let's generate the cloud and visualize it.
# %%
def boltzmann_sample(key: PRNGKey, N: int, pose: Pose, motion_settings, observations):
return jax.vmap(perturb_model.importance, in_axes=(0, None, None))(
jax.random.split(key, N),
observation_to_choicemap(observations),
(pose, motion_settings)
)

def small_pose_plot(p: Pose, **opts):
"""This variant of pose_plot will is better when we're zoomed in on the vicinity of one pose.
TODO: consider scaling r and wing_length based on the size of the plot domain."""
opts = {'r': 0.001} | opts
return pose_plot(p, wing_length=0.006, **opts)

def weighted_small_pose_plot(target, poses, ws):
lse_ws = jnp.log(jnp.sum(jnp.exp(ws)))
[Contributor comment] consider using jax.scipy.special.logsumexp
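# (A hedged sketch of that suggestion: jax.scipy.special.logsumexp performs
# the same reduction in a numerically stable way, e.g.
#   lse_ws = jax.scipy.special.logsumexp(ws)
# assuming ws is a 1-D array of log-weights.)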

scaled_ws = jnp.exp(ws - lse_ws)
max_scaled_w: FloatArray = jnp.max(scaled_ws)
scaled_ws /= max_scaled_w
# the following hack "boosts" lower scores a bit, to give us more visibility into
[Contributor comment] yeah it's a recurring problem, I often have to hack something similar.

# the density of the nearby cloud. Aesthetically, I found too many points were
# invisible without some adjustment, since the score distribution is concentrated
# closely around 1.0
scaled_ws = scaled_ws ** 0.3
return (Plot.new([small_pose_plot(p, fill=w) for p, w in zip(poses, scaled_ws)]
+ small_pose_plot(target, r=0.003, fill='red')
+ small_pose_plot(robot_inputs['start'], r=0.003, fill='green'))
+ {
"color": {"type":"linear", "scheme":"Purples"},
"height": 400,
"width": 400,
"aspectRatio": 1
})
key, sub_key = jax.random.split(key)
bs = boltzmann_sample(sub_key, 1000, p0.get_retval(), motion_settings_low_deviation, observations_low_deviation[0])
weighted_small_pose_plot(p0.get_retval(), bs[0].get_retval(), bs[1])
# %% [markdown]
# Develop a function which will produce a grid of evenly spaced nearby poses given
# an initial pose. $n$ is the number of steps to take in each cardinal direction
# (up/down, left/right and changes in heading). For example, if you say $n = 2$, there
# will be a $5\times 5$ grid of positions with the original pose in the center, and 5 layers
# of this type, each with different heading deltas (including zero), for a total of
# $125 = 5^3$ alternate poses.
# %%
def grid_of_nearby_poses(p, n, motion_settings):
indices = jnp.arange(-n, n+1)
n_indices = len(indices)
grid_ax = indices * 2 * motion_settings['p_noise'] / n
grid = jnp.dstack(jnp.meshgrid(grid_ax, grid_ax)).reshape(n_indices * n_indices, -1)
# That's the position grid. We will now make a 1-d grid for the heading deltas,
# and then form the linear cartesian product.
headings = indices * 2 * motion_settings['hd_noise'] / n
return Pose(jnp.repeat(p.p + grid, n_indices, axis=0), jnp.tile(p.hd + headings, n_indices * n_indices))
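# %% [markdown]
# A quick sanity check of the sizes promised above (a sketch; it reuses
# `robot_inputs` and `motion_settings_low_deviation` from earlier in the
# tutorial): with $n = 2$ we expect $5^3 = 125$ poses.
# %%
grid_check = grid_of_nearby_poses(robot_inputs['start'], 2, motion_settings_low_deviation)
assert grid_check.p.shape == (125, 2)  # 5 x 5 positions, each repeated for 5 headings
assert grid_check.hd.shape == (125,)   # heading deltas cycle fastest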
# %%
cube_step_size = 8
pose_grid = grid_of_nearby_poses(p0.get_retval(), cube_step_size, motion_settings_low_deviation)
@genjax.gen
def assess_model(p):
sensor_model(p, sensor_angles) @ 'sensor'
return p
# %%
key, sub_key = jax.random.split(key)
model_assess = jax.jit(assess_model.assess)
assess_scores, assess_retvals = jax.vmap(lambda cm, p: model_assess(cm, (p,)), in_axes=(None, 0))(observation_to_choicemap(observations_low_deviation[0]), pose_grid)
#assess_scores, assess_retvals = jax.vmap(lambda p: sensor_model.assess(cm, (p, sensor_angles)))(pose_grid)
# %%
# Our grid of nearby poses is actually a cube once we take the heading deltas
# into account. To get a 2-d density, we flatten the cube by keeping the
# best-scoring heading at each grid position.
def flatten_pose_cube(n, poses, scores):
    d = 2 * n + 1
    # grid_of_nearby_poses lays the cube out with the heading delta cycling
    # fastest, so each consecutive run of d entries shares one position.
    pose_groups = poses.p.reshape((d * d, d, 2))
    heading_groups = poses.hd.reshape((d * d, d))
    score_groups = scores.reshape((d * d, d))
    # find the best-scoring heading in each position group
    best = jnp.argmax(score_groups, axis=1)
    # We want to select the best column from every row, so we need to
    # explicitly enumerate the rows we want (using : would not have the
    # same effect)
    return (Pose(pose_groups[jnp.arange(len(pose_groups)), best],
                 heading_groups[jnp.arange(len(heading_groups)), best]),
            score_groups[jnp.arange(len(score_groups)), best])
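# %% [markdown]
# A quick shape check (a sketch; `assess_retvals` and `assess_scores` come from
# the grid evaluation above): with $n = 8$, the $17^3 = 4913$ scored poses
# collapse to $17^2 = 289$ best-heading poses.
# %%
flat_poses_check, flat_scores_check = flatten_pose_cube(cube_step_size, assess_retvals, assess_scores)
assert flat_poses_check.p.shape == (289, 2)
assert flat_scores_check.shape == (289,)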

# %%
#sensor_model.assess(cm, (pose_grid[0], sensor_angles))
#sensor_model.simulate(sub_key, (pose_grid[0], sensor_angles))
#sensor_model.importance(sub_key, observations_to_choicemap(observations_low_deviation, 0), (pose_grid[0], sensor_angles))

# Since the above calls work...
# I think this ought to work, but doesn't! TODO: find a minimal repro and file an issue
#sensor_model.assess(cm, (pose_grid[0], sensor_angles))
# %% [markdown]
# Prepare a plot showing the density of nearby improvements available using the grid
# search and importance sampling techniques.
# %%
assess_pose_plane, assess_score_plane = flatten_pose_cube(cube_step_size, assess_retvals, assess_scores)
(weighted_small_pose_plot(p0.get_retval(), assess_retvals, assess_scores) &
weighted_small_pose_plot(p0.get_retval(), bs[0].get_retval(), bs[1]))
# %% [markdown]
# Now let's try doing the whole path. We want to produce something that is ultimately
# scan-compatible, so it should have the form state -> update -> new_state. The state
# is obviously the pose; the update will include the sensor readings at the current
# position and the control input for the next step.

# %%
def select_by_weight(key: PRNGKey, weights: FloatArray, things):
chosen = jax.random.categorical(key, weights)
return jax.tree.map(lambda v: v[chosen], things)
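# %% [markdown]
# `select_by_weight` is categorical resampling: it draws one index with
# probability proportional to the exponentiated log-weights and gathers that
# slice out of an arbitrary pytree. A tiny sketch (the toy values below are
# illustrative only, not tutorial data):
# %%
_toy_key = jax.random.PRNGKey(0)
_toy_things = {'a': jnp.arange(3.0), 'b': jnp.arange(3.0) * 10}
select_by_weight(_toy_key, jnp.log(jnp.array([0.1, 0.1, 0.8])), _toy_things)  # usually picks slice 2
# %%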

def improved_path(key: PRNGKey, motion_settings: dict, observations: FloatArray, mode: str):

def boltzmann_improver(k: PRNGKey, pose, observation):
k1, k2 = jax.random.split(k, 2)
trs, ws = boltzmann_sample(k1, 1000, pose, motion_settings, observation)
return select_by_weight(k2, ws, trs.get_retval())

def grid_search_improver(k: PRNGKey, pose, observation):
choicemap = observation_to_choicemap(observation)
nearby_poses = grid_of_nearby_poses(pose, 15, motion_settings)
ws, retvals = jax.vmap(lambda p: assess_model.assess(choicemap, (p,)))(nearby_poses)
return select_by_weight(k, ws, nearby_poses)

[Contributor comment] @littleredcomputer the structure is all there but I think there is probably something off with the math of the weights from which you're resampling, perhaps in both of these cases.

def improve_pose_and_step(state, update):
pose = state
observation, control, key = update
k1, k2 = jax.random.split(key)
# improve the step where we are
improver = {"grid": grid_search_improver, "boltzmann": boltzmann_improver}[mode]
p1 = improver(k1, pose, observation)
# run the step model to advance one step
p2 = step_model.simulate(k2, (p1, control, motion_settings))
return (p2.get_retval(), p1)

# We have one fewer control than steps, since no step got us to the initial position.
# Our scan starts at the initial pose and applies a control input at each step.
# To make things balance, we append a zero control to the end of the control input
# array, so that when we arrive at the final step, no further control input is applied.
controls = robot_inputs['controls'] + Control(jnp.array([0]), jnp.array([0]))
n_steps = len(controls)
sub_keys = jax.random.split(key, n_steps + 1)
p0 = start_pose_prior.simulate(sub_keys[0], (robot_inputs['start'], motion_settings)).get_retval()
return jax.lax.scan(improve_pose_and_step, p0, (
observations,
controls,
sub_keys[1:]
))
# %%
# Select an importance sample via weight in both the low and high deviation settings.
key, k1, k2 = jax.random.split(key, 3)
low_importance = select_by_weight(k1, low_weights, low_deviation_paths)
high_importance = select_by_weight(k2, high_weights, high_deviation_paths)
# %%
key, sub_key = jax.random.split(key)
endpoint_low, improved_low = improved_path(sub_key, motion_settings_low_deviation, observations_low_deviation, "grid")
# %%

def path_comparison_plot(*plots):
types = ["improved", "integrated", "importance", "true"]
plot = world_plot
plot += [animate_path_as_line(p, strokeWidth=2, stroke=Plot.constantly(t)) for p, t in zip(plots, types)]
plot += [poses_to_plots(p, fill=Plot.constantly(t)) for p, t in zip(plots, types)]
return plot + Plot.color_map({"integrated": "green", "improved": "blue", "true": "black", "importance": "red"})

# %%
path_comparison_plot(improved_low, path_integrated, low_importance, path_low_deviation)
# %%
key, sub_key = jax.random.split(key)
endpoint_high, improved_high = improved_path(sub_key, motion_settings_high_deviation, observations_high_deviation, "grid")
path_comparison_plot(improved_high, path_integrated, high_importance, path_high_deviation)
# %% [markdown]
# To see how the grid search improves poses, we play back the grid-search path
# next to an importance sample path. You can see the grid search has a better fit
# of sensor data to wall position at a variety of time steps.
# %%
Plot.Row(
animate_path_and_sensors(improved_high, observations_high_deviation, motion_settings_high_deviation, frame_key="frame"),
animate_path_and_sensors(high_importance, observations_high_deviation, motion_settings_high_deviation, frame_key="frame")
) | Plot.Slider("frame", 0, T, fps=2)