
Commit b461f90

fehiepsi authored and neerajprad committed

Add thumbnail images to examples (#459)

1 parent 033393a · commit b461f90

29 files changed: +365 −321 lines

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@ numpyro.egg-info
 __pycache__/
 .ipynb_checkpoints/
 build
+tutorials/source/examples/

 # built / compiled
 *.pyc

docs/Makefile

Lines changed: 1 addition & 1 deletion
@@ -17,4 +17,4 @@ help:
 # Catch-all target: route all unknown targets to Sphinx using the new
 # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
 %: Makefile
-	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

examples/README.rst

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+Code Examples
+=============
+
+Examples on specifying models and doing inference in NumPyro.
+
+`View source files on github`__
+
+.. _github: https://github.com/pyro-ppl/numpyro/tree/master/examples
+
+__ github_

examples/baseball.py

Lines changed: 34 additions & 28 deletions
@@ -1,18 +1,7 @@
-import argparse
-
-import numpy as onp
-
-import jax.numpy as np
-import jax.random as random
-from jax.scipy.special import logsumexp
-
-import numpyro
-import numpyro.distributions as dist
-from numpyro.examples.datasets import BASEBALL, load_dataset
-from numpyro.infer import MCMC, NUTS, Predictive, log_likelihood
-
-
 """
+Baseball
+========
+
 Original example from Pyro:
 https://github.com/pyro-ppl/pyro/blob/dev/examples/baseball.py

@@ -29,28 +18,44 @@
 from our models.

 Three models are evaluated:
-  - Complete pooling model: The success probability of scoring a hit is shared
-    amongst all players.
-  - No pooling model: Each individual player's success probability is distinct and
-    there is no data sharing amongst players.
-  - Partial pooling model: A hierarchical model with partial data sharing.

+  - Complete pooling model: The success probability of scoring a hit is shared
+    amongst all players.
+  - No pooling model: Each individual player's success probability is distinct and
+    there is no data sharing amongst players.
+  - Partial pooling model: A hierarchical model with partial data sharing.

 We recommend Radford Neal's tutorial on HMC ([3]) to users who would like to get a
 more comprehensive understanding of HMC and its variants, and to [4] for details on
 the No U-Turn Sampler, which provides an efficient and automated way (i.e. limited
 hyper-parameters) of running HMC on different problems.

-[1] Carpenter B. (2016), ["Hierarchical Partial Pooling for Repeated Binary Trials"]
-    (http://mc-stan.org/users/documentation/case-studies/pool-binary-trials.html).
-[2] Efron B., Morris C. (1975), "Data analysis using Stein's estimator and its
-    generalizations", J. Amer. Statist. Assoc., 70, 311-319.
-[3] Neal, R. (2012), "MCMC using Hamiltonian Dynamics",
-    (https://arxiv.org/pdf/1206.1901.pdf)
-[4] Hoffman, M. D. and Gelman, A. (2014), "The No-U-turn sampler: Adaptively setting
-    path lengths in Hamiltonian Monte Carlo", (https://arxiv.org/abs/1111.4246)
+**References:**
+
+    1. Carpenter B. (2016), `"Hierarchical Partial Pooling for Repeated Binary Trials"
+       <http://mc-stan.org/users/documentation/case-studies/pool-binary-trials.html/>`_.
+    2. Efron B., Morris C. (1975), "Data analysis using Stein's estimator and its
+       generalizations", J. Amer. Statist. Assoc., 70, 311-319.
+    3. Neal, R. (2012), "MCMC using Hamiltonian Dynamics",
+       (https://arxiv.org/pdf/1206.1901.pdf)
+    4. Hoffman, M. D. and Gelman, A. (2014), "The No-U-turn sampler: Adaptively setting
+       path lengths in Hamiltonian Monte Carlo", (https://arxiv.org/abs/1111.4246)
 """

+import argparse
+import os
+
+import numpy as onp
+
+import jax.numpy as np
+import jax.random as random
+from jax.scipy.special import logsumexp
+
+import numpyro
+import numpyro.distributions as dist
+from numpyro.examples.datasets import BASEBALL, load_dataset
+from numpyro.infer import MCMC, NUTS, Predictive, log_likelihood
+

 def fully_pooled(at_bats, hits=None):
     r"""
@@ -125,7 +130,8 @@ def partially_pooled_with_logit(at_bats, hits=None):

 def run_inference(model, at_bats, hits, rng_key, args):
     kernel = NUTS(model)
-    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains)
+    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains,
+                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
     mcmc.run(rng_key, at_bats, hits)
     return mcmc.get_samples()
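
Editor's note on the recurring change in this commit: every example now passes
progress_bar=False to MCMC whenever the NUMPYRO_SPHINXBUILD environment variable
is set, so progress-bar output does not end up in the generated example gallery.
A minimal sketch of the pattern, assuming the MCMC/NUTS API of NumPyro 0.2.x
shown in the hunks above (the toy model here is illustrative):

    import os

    from jax import random
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS

    def model():
        # toy model, just so there is something to sample
        numpyro.sample("x", dist.Normal(0., 1.))

    # equivalent to the diff's `False if "NUMPYRO_SPHINXBUILD" in os.environ else True`
    mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500,
                progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(random.PRNGKey(0))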

examples/bnn.py

Lines changed: 7 additions & 2 deletions
@@ -1,9 +1,13 @@
 """
+Bayesian Neural Network
+=======================
+
 We demonstrate how to use NUTS to do inference on a simple (small)
 Bayesian neural network with two hidden layers.
 """

 import argparse
+import os
 import time

 import matplotlib
@@ -58,7 +62,8 @@ def model(X, Y, D_H):
 def run_inference(model, args, rng_key, X, Y, D_H):
     start = time.time()
     kernel = NUTS(model)
-    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains)
+    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains,
+                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
     mcmc.run(rng_key, X, Y, D_H)
     mcmc.print_summary()
     print('\nMCMC elapsed time:', time.time() - start)
@@ -124,7 +129,7 @@ def main(args):
     ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")

     plt.savefig('bnn_plot.pdf')
-    plt.close()
+    plt.tight_layout()


 if __name__ == "__main__":
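
The hunks above only show the edges of model(X, Y, D_H); for orientation, here
is a hedged sketch of what a two-hidden-layer Bayesian neural network of this
kind looks like in NumPyro. The weight names, the tanh nonlinearity, and the
Gamma prior on the observation precision are assumptions for illustration, not
necessarily the exact contents of bnn.py:

    import jax.numpy as np
    import numpyro
    import numpyro.distributions as dist

    def model(X, Y, D_H):
        D_X, D_Y = X.shape[1], 1
        # two tanh hidden layers with standard-normal priors on the weights
        w1 = numpyro.sample("w1", dist.Normal(np.zeros((D_X, D_H)), np.ones((D_X, D_H))))
        z1 = np.tanh(np.matmul(X, w1))
        w2 = numpyro.sample("w2", dist.Normal(np.zeros((D_H, D_H)), np.ones((D_H, D_H))))
        z2 = np.tanh(np.matmul(z1, w2))
        w3 = numpyro.sample("w3", dist.Normal(np.zeros((D_H, D_Y)), np.ones((D_H, D_Y))))
        z3 = np.matmul(z2, w3)
        # observation noise via a Gamma prior on the precision (assumed here)
        prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
        numpyro.sample("Y", dist.Normal(z3, 1.0 / np.sqrt(prec_obs)), obs=Y)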

examples/funnel.py

Lines changed: 33 additions & 27 deletions
@@ -1,22 +1,9 @@
-import argparse
-
-import matplotlib.pyplot as plt
-import seaborn as sns
-
-from jax import random
-import jax.numpy as np
-
-import numpyro
-import numpyro.distributions as dist
-from numpyro.distributions.transforms import AffineTransform
-from numpyro.infer import MCMC, NUTS
-
-sns.set(context='talk')
-
-
 """
+Neal's Funnel
+=============
+
 This example, which is adapted from [1], illustrates how to leverage non-centered
-parameterization using the class `~numpyro.distributions.TransformedDistribution`.
+parameterization using the class :class:`numpyro.distributions.TransformedDistribution`.
 We will examine the difference between two types of parameterizations on the
 10-dimensional Neal's funnel distribution. As we will see, HMC gets trouble at
 the neck of the funnel if centered parameterization is used. On the contrary,
@@ -29,11 +16,26 @@
 inference algorithms know to do reparameterization automatically is to declare
 the random variable as a transformed distribution.

-[1] *Stan User's Guide*, https://mc-stan.org/docs/2_19/stan-users-guide/reparameterization-section.html
-[2] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019), "Automatic
-    Reparameterisation of Probabilistic Programs", (https://arxiv.org/abs/1906.03028)
+**References:**
+
+    1. *Stan User's Guide*, https://mc-stan.org/docs/2_19/stan-users-guide/reparameterization-section.html
+    2. Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019), "Automatic
+       Reparameterisation of Probabilistic Programs", (https://arxiv.org/abs/1906.03028)
 """

+import argparse
+import os
+
+import matplotlib.pyplot as plt
+
+from jax import random
+import jax.numpy as np
+
+import numpyro
+import numpyro.distributions as dist
+from numpyro.distributions.transforms import AffineTransform
+from numpyro.infer import MCMC, NUTS
+

 def model(dim=10):
     y = numpyro.sample('y', dist.Normal(0, 3))
@@ -48,41 +50,45 @@ def reparam_model(dim=10):

 def run_inference(model, args, rng_key):
     kernel = NUTS(model)
-    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains)
+    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains,
+                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
     mcmc.run(rng_key)
+    mcmc.print_summary()
     return mcmc.get_samples()


 def main(args):
     rng_key = random.PRNGKey(0)

     # do inference with centered parameterization
+    print("============================= Centered Parameterization ==============================")
     samples = run_inference(model, args, rng_key)

     # do inference with non-centered parameterization
+    print("\n=========================== Non-centered Parameterization ============================")
     reparam_samples = run_inference(reparam_model, args, rng_key)

     # make plots
-    fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(8, 16))
+    fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(6.4, 6.4))

-    sns.scatterplot(samples['x'][:, 0], samples['y'], color='g', alpha=0.3, ax=ax1)
-    ax1.set(xlim=(-20, 20), ylim=(-9, 9), xlabel='x[0]', ylabel='y',
+    ax1.plot(samples['x'][:, 0], samples['y'], "go", alpha=0.3)
+    ax1.set(xlim=(-20, 20), ylim=(-9, 9), ylabel='y',
             title='Funnel samples with centered parameterization')

-    sns.scatterplot(reparam_samples['x'][:, 0], reparam_samples['y'], color='g', alpha=0.3, ax=ax2)
+    ax2.plot(reparam_samples['x'][:, 0], reparam_samples['y'], "go", alpha=0.3)
     ax2.set(xlim=(-20, 20), ylim=(-9, 9), xlabel='x[0]', ylabel='y',
             title='Funnel samples with non-centered parameterization')

     plt.savefig('funnel_plot.pdf')
-    plt.close()
+    plt.tight_layout()


 if __name__ == "__main__":
     assert numpyro.__version__.startswith('0.2.1')
     parser = argparse.ArgumentParser(description="Non-centered reparameterization example")
     parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
     parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)
-    parser.add_argument("--num-chains", nargs='?', default=4, type=int)
+    parser.add_argument("--num-chains", nargs='?', default=1, type=int)
     parser.add_argument("--device", default='cpu', type=str, help='use "cpu" or "gpu".')
     args = parser.parse_args()
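
To make the contrast in the hunks above concrete, here is a minimal sketch of
the two parameterizations of Neal's funnel, reconstructed from the fragments of
model and reparam_model visible in the diff (the body of reparam_model is an
assumption based on the imports and the docstring):

    import jax.numpy as np
    import numpyro
    import numpyro.distributions as dist
    from numpyro.distributions.transforms import AffineTransform

    def model(dim=10):
        # centered: the scale of x depends on y, producing the funnel's neck
        y = numpyro.sample('y', dist.Normal(0, 3))
        numpyro.sample('x', dist.Normal(np.zeros(dim - 1), np.exp(y / 2)))

    def reparam_model(dim=10):
        # non-centered: sample a standard normal, then rescale through an
        # affine transform, so HMC explores a well-conditioned geometry
        y = numpyro.sample('y', dist.Normal(0, 3))
        numpyro.sample('x', dist.TransformedDistribution(
            dist.Normal(np.zeros(dim - 1), 1.), AffineTransform(0, np.exp(y / 2))))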

examples/gp.py

Lines changed: 12 additions & 7 deletions
@@ -1,4 +1,13 @@
+"""
+Gaussian Process
+================
+
+In this example we show how to use NUTS to sample from the posterior
+over the hyperparameters of a gaussian process.
+"""
+
 import argparse
+import os
 import time

 import matplotlib
@@ -16,11 +25,6 @@

 matplotlib.use('Agg')  # noqa: E402

-"""
-In this example we show how to use NUTS to sample from the posterior
-over the hyperparameters of a gaussian process.
-"""
-

 # squared exponential kernel with diagonal noise term
 def kernel(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):
@@ -49,7 +53,8 @@ def model(X, Y):
 def run_inference(model, args, rng_key, X, Y):
     start = time.time()
     kernel = NUTS(model)
-    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains)
+    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains,
+                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
     mcmc.run(rng_key, X, Y)
     mcmc.print_summary()
     print('\nMCMC elapsed time:', time.time() - start)
@@ -117,7 +122,7 @@ def main(args):
     ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")

     plt.savefig("gp_plot.pdf")
-    plt.close()
+    plt.tight_layout()


 if __name__ == "__main__":
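
The hunk above shows only the comment and signature of the squared exponential
kernel; a sketch of its standard form, assuming one-dimensional inputs X and Z
as in the example:

    import jax.numpy as np

    # squared exponential (RBF) kernel with diagonal noise term (sketch)
    def kernel(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):
        deltaXsq = np.power((X[:, None] - Z) / length, 2.0)
        k = var * np.exp(-0.5 * deltaXsq)
        if include_noise:
            # jitter keeps the covariance matrix numerically positive definite
            k += (noise + jitter) * np.eye(X.shape[0])
        return k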

examples/hmm.py

Lines changed: 42 additions & 19 deletions
@@ -1,18 +1,7 @@
-import argparse
-import time
-
-import numpy as onp
-
-from jax import lax, random
-import jax.numpy as np
-from jax.scipy.special import logsumexp
-
-import numpyro
-import numpyro.distributions as dist
-from numpyro.infer import MCMC, NUTS
-
-
 """
+Hidden Markov Model
+===================
+
 In this example, we will follow [1] to construct a semi-supervised Hidden Markov
 Model for a generative model with observations are words and latent variables
 are categories. Instead of automatically marginalizing all discrete latent
@@ -26,12 +15,30 @@
 JAX's `lax.scan` primitive. The primitive will greatly improve compiling for the
 model.

-[1] https://mc-stan.org/docs/2_19/stan-users-guide/hmms-section.html
-[2] http://pyro.ai/examples/hmm.html
-[3] https://en.wikipedia.org/wiki/Forward_algorithm
-[4] https://discourse.pymc.io/t/how-to-marginalized-markov-chain-with-categorical/2230
+**References:**
+
+    1. https://mc-stan.org/docs/2_19/stan-users-guide/hmms-section.html
+    2. http://pyro.ai/examples/hmm.html
+    3. https://en.wikipedia.org/wiki/Forward_algorithm
+    4. https://discourse.pymc.io/t/how-to-marginalized-markov-chain-with-categorical/2230
 """

+import argparse
+import os
+import time
+
+import matplotlib.pyplot as plt
+import numpy as onp
+from scipy.stats import gaussian_kde
+
+from jax import lax, random
+import jax.numpy as np
+from jax.scipy.special import logsumexp
+
+import numpyro
+import numpyro.distributions as dist
+from numpyro.infer import MCMC, NUTS
+

 def simulate_data(rng_key, num_categories, num_words, num_supervised_data, num_unsupervised_data):
     rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)
@@ -148,13 +155,29 @@ def main(args):
     rng_key = random.PRNGKey(2)
     start = time.time()
     kernel = NUTS(semi_supervised_hmm)
-    mcmc = MCMC(kernel, args.num_warmup, args.num_samples)
+    mcmc = MCMC(kernel, args.num_warmup, args.num_samples,
+                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
     mcmc.run(rng_key, transition_prior, emission_prior, supervised_categories,
              supervised_words, unsupervised_words)
     samples = mcmc.get_samples()
     print_results(samples, transition_prob, emission_prob)
     print('\nMCMC elapsed time:', time.time() - start)

+    # make plots
+    fig, ax = plt.subplots(1, 1)
+
+    x = onp.linspace(0, 1, 101)
+    for i in range(transition_prob.shape[0]):
+        for j in range(transition_prob.shape[1]):
+            ax.plot(x, gaussian_kde(samples['transition_prob'][:, i, j])(x),
+                    label="transition_prob[{}, {}], true value = {:.2f}"
+                    .format(i, j, transition_prob[i, j]))
+    ax.set(xlabel="Probability", ylabel="Frequency",
+           title="Transition probability posterior")
+
+    plt.savefig("hmm_plot.pdf")
+    plt.tight_layout()
+

 if __name__ == '__main__':
     assert numpyro.__version__.startswith('0.2.1')
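
For readers who want the marginalization spelled out: the docstring above says
the discrete chain is summed out with the forward algorithm under lax.scan. A
hedged sketch of that computation in log space (function names are
illustrative, not necessarily those in hmm.py):

    import jax.numpy as np
    from jax import lax
    from jax.scipy.special import logsumexp

    def forward_one_step(prev_log_prob, curr_word, transition_log_prob, emission_log_prob):
        # combine alpha[t-1] with the transition matrix, add the emission
        # evidence, then marginalize the previous hidden state with logsumexp
        log_prob = np.expand_dims(prev_log_prob, axis=1) + transition_log_prob \
            + emission_log_prob[:, curr_word]
        return logsumexp(log_prob, axis=0)

    def forward_log_prob(init_log_prob, words, transition_log_prob, emission_log_prob):
        # lax.scan rolls the recursion over the word sequence in compiled code
        def scan_fn(log_prob, word):
            return forward_one_step(log_prob, word, transition_log_prob, emission_log_prob), None
        log_prob, _ = lax.scan(scan_fn, init_log_prob, words)
        return log_prob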
