Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions pymc_experimental/inference/smc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def sample_smc_blackjax(
num_mcmc_steps,
kernel,
diagnosis,
total_iterations,
int(total_iterations),
iterations_to_diagnose,
inner_kernel_params,
running_time,
Expand Down Expand Up @@ -198,13 +198,13 @@ def arviz_from_particles(model, particles):
-------
"""
n_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0]
by_varname = {k.name: v.squeeze()[np.newaxis, :] for k, v in zip(model.value_vars, particles)}
by_varname = {k.name: v for k, v in zip(model.value_vars, particles)}
varnames = [v.name for v in model.value_vars]
with model:
strace = NDArray(name=model.name)
strace.setup(n_particles, 0)
for particle_index in range(0, n_particles):
strace.record(point={k: by_varname[k][0][particle_index] for k in varnames})
strace.record(point={k: by_varname[k][particle_index] for k in varnames})
multitrace = MultiTrace((strace,))
return to_inference_data(multitrace, log_likelihood=False)

Expand Down Expand Up @@ -295,14 +295,7 @@ def blackjax_particles_from_pymc_population(model, pymc_population):

order_of_vars = model.value_vars

def _format(var):
variable = pymc_population[var.name]
if len(variable.shape) == 1:
return variable[:, np.newaxis]
else:
return variable

return [_format(var) for var in order_of_vars]
return [pymc_population[var.name] for var in order_of_vars]


def add_to_inference_data(
Expand Down Expand Up @@ -384,7 +377,7 @@ def get_jaxified_particles_fn(model, graph_outputs):
logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[graph_outputs])

def logp_fn_wrap(particles):
return logp_fn(*[p.squeeze() for p in particles])[0]
return logp_fn(*particles)[0]

return logp_fn_wrap

Expand Down
12 changes: 6 additions & 6 deletions pymc_experimental/tests/test_blackjax_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test_blackjax_particles_from_pymc_population_univariate():
model = fast_model()
population = {"x": np.array([2, 3, 4])}
blackjax_particles = blackjax_particles_from_pymc_population(model, population)
jax.tree.map(np.testing.assert_allclose, blackjax_particles, [np.array([[2], [3], [4]])])
jax.tree.map(np.testing.assert_allclose, blackjax_particles, [np.array([2, 3, 4])])


def test_blackjax_particles_from_pymc_population_multivariate():
Expand All @@ -147,7 +147,7 @@ def test_blackjax_particles_from_pymc_population_multivariate():
jax.tree.map(
np.testing.assert_allclose,
blackjax_particles,
[np.array([[0.34614613], [1.09163261], [-0.44526825]]), np.array([[1], [2], [3]])],
[np.array([0.34614613, 1.09163261, -0.44526825]), np.array([1, 2, 3])],
)


Expand All @@ -168,7 +168,7 @@ def test_blackjax_particles_from_pymc_population_multivariable():
population = {"x": np.array([[2, 3], [5, 6], [7, 9]]), "z": np.array([11, 12, 13])}
blackjax_particles = blackjax_particles_from_pymc_population(model, population)

jax.tree.map(
jax.tree_map(
np.testing.assert_allclose,
blackjax_particles,
[np.array([[2, 3], [5, 6], [7, 9]]), np.array([[11], [12], [13]])],
Expand All @@ -181,7 +181,7 @@ def test_arviz_from_particles():
with model:
inference_data = arviz_from_particles(model, particles)

assert inference_data.posterior.sizes == Frozen({"chain": 1, "draw": 3, "x_dim_0": 2})
assert inference_data.posterior.dims == Frozen({"chain": 1, "draw": 3, "x_dim_0": 2})
assert inference_data.posterior.data_vars.dtypes == Frozen(
{"x": dtype("float64"), "z": dtype("float64")}
)
Expand All @@ -196,7 +196,7 @@ def test_get_jaxified_logprior():
"""
logprior = get_jaxified_logprior(fast_model())
for point in [-0.5, 0.0, 0.5]:
jax.tree.map(
jax.tree_map(
np.testing.assert_allclose,
jax.vmap(logprior)([np.array([point])]),
np.log(scipy.stats.norm(0, 1).pdf(point)),
Expand All @@ -212,7 +212,7 @@ def test_get_jaxified_loglikelihood():
"""
loglikelihood = get_jaxified_loglikelihood(fast_model())
for point in [-0.5, 0.0, 0.5]:
jax.tree.map(
jax.tree_map(
np.testing.assert_allclose,
jax.vmap(loglikelihood)([np.array([point])]),
np.log(scipy.stats.norm(point, 1).pdf(0)),
Expand Down