add vanilla HMC method #75

Draft · wants to merge 6 commits into base: main
10 changes: 7 additions & 3 deletions docs/source/methods.rst
@@ -29,12 +29,16 @@ Posterior approximation methods
taken by averaging checkpoints over the stochastic optimization trajectory. The covariance is also estimated
empirically along the trajectory, and it is made of a diagonal component and a low-rank non-diagonal one.

- **Hamiltonian Monte Carlo (HMC)** `[Neal, 2010] <https://arxiv.org/pdf/1206.1901.pdf>`_
HMC approximates the posterior as a steady-state distribution of a Monte Carlo Markov chain with Hamiltonian dynamics.
After the initial "burn-in" phase, each step of the chain generates a sample from the posterior. HMC is typically applied
in the full-batch scenario.

- **Stochastic Gradient Hamiltonian Monte Carlo (SGHMC)** `[Chen et al., 2014] <http://proceedings.mlr.press/v32/cheni14.pdf>`_
SGHMC approximates the posterior as a steady-state distribution of a Monte Carlo Markov chain with Hamiltonian dynamics.
After the initial "burn-in" phase, each step of the chain generates samples from the posterior.
SGHMC implements a variant of the HMC algorithm that expects noisy gradient estimates computed on mini-batches of data.

- **Cyclical Stochastic Gradient Langevin Dynamics (Cyclical SGLD)** `[Zhang et al., 2020] <https://openreview.net/pdf?id=rkeS1RVtPS>`_
Cyclical SGLD adapts the cyclical cosine step size schedule, and alternates between *exploration* and *sampling* stages to better
Cyclical SGLD adopts the cyclical cosine step size schedule, and alternates between *exploration* and *sampling* stages to better
explore the multimodal posteriors for deep neural networks.

Parametric calibration methods
27 changes: 26 additions & 1 deletion docs/source/references/prob_model/posterior/sgmcmc.rst
@@ -4,12 +4,37 @@ SG-MCMC procedures approximate the posterior as a steady-state distribution of
a Monte Carlo Markov chain that utilizes noisy estimates of the gradient
computed on minibatches of data.

Hamiltonian Monte Carlo (HMC)
=============================

HMC `[Neal, 2010] <https://arxiv.org/pdf/1206.1901.pdf>`_ is an MCMC sampling
algorithm that simulates a Hamiltonian dynamical system to rapidly explore
the posterior.
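
For intuition, here is a minimal single-transition sketch of these dynamics on a
flat parameter vector (a plain JAX illustration, not Fortuna's implementation;
``log_prob_fn`` stands in for the unnormalized log posterior):

.. code-block:: python

    import jax
    import jax.numpy as jnp

    def hmc_step(rng_key, position, log_prob_fn, step_size=0.1, n_leapfrog=10):
        """One HMC transition: sample momentum, integrate, accept or reject."""
        key_momentum, key_accept = jax.random.split(rng_key)
        momentum = jax.random.normal(key_momentum, position.shape)
        grad_fn = jax.grad(log_prob_fn)

        def hamiltonian(q, p):
            # Potential energy is the negative log posterior; kinetic term is Gaussian.
            return -log_prob_fn(q) + 0.5 * jnp.dot(p, p)

        # Leapfrog integration of the Hamiltonian dynamics.
        q, p = position, momentum
        p = p + 0.5 * step_size * grad_fn(q)
        for _ in range(n_leapfrog - 1):
            q = q + step_size * p
            p = p + step_size * grad_fn(q)
        q = q + step_size * p
        p = p + 0.5 * step_size * grad_fn(q)

        # Metropolis-Hastings correction on the change in total energy.
        log_accept_ratio = hamiltonian(position, momentum) - hamiltonian(q, p)
        accept = jnp.log(jax.random.uniform(key_accept)) < log_accept_ratio
        return jnp.where(accept, q, position)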

.. autoclass:: fortuna.prob_model.posterior.sgmcmc.hmc.hmc_approximator.HMCPosteriorApproximator

.. autoclass:: fortuna.prob_model.posterior.sgmcmc.hmc.hmc_posterior.HMCPosterior
:show-inheritance:
:no-inherited-members:
:exclude-members: state
:members: fit, sample, load_state, save_state

.. autoclass:: fortuna.prob_model.posterior.sgmcmc.hmc.hmc_state.HMCState
:show-inheritance:
:no-inherited-members:
:inherited-members: init, init_from_dict
:members: convert_from_map_state
:exclude-members: params, mutable, calib_params, calib_mutable, replace, apply_gradients, encoded_name, create
:no-undoc-members:
:no-special-members:


Stochastic Gradient Hamiltonian Monte Carlo (SGHMC)
===================================================

SGHMC `[Chen T. et al., 2014] <http://proceedings.mlr.press/v32/cheni14.pdf>`_
is a popular MCMC algorithm that uses stochastic gradient estimates to scale
to large datasets.
HMC to large datasets.

.. autoclass:: fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_approximator.SGHMCPosteriorApproximator

2 changes: 2 additions & 0 deletions fortuna/model/model_manager/name_to_model_manager.py
@@ -9,6 +9,7 @@
from fortuna.prob_model.posterior.map import MAP_NAME
from fortuna.prob_model.posterior.normalizing_flow.advi import ADVI_NAME
from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld import CYCLICAL_SGLD_NAME
from fortuna.prob_model.posterior.sgmcmc.hmc import HMC_NAME
from fortuna.prob_model.posterior.sgmcmc.sghmc import SGHMC_NAME
from fortuna.prob_model.posterior.sngp import SNGP_NAME
from fortuna.prob_model.posterior.swag import SWAG_NAME
@@ -25,3 +26,4 @@ class ClassificationModelManagers(enum.Enum):
vars()[SNGP_NAME] = SNGPClassificationModelManager
vars()[SGHMC_NAME] = ClassificationModelManager
vars()[CYCLICAL_SGLD_NAME] = ClassificationModelManager
vars()[HMC_NAME] = ClassificationModelManager
3 changes: 3 additions & 0 deletions fortuna/prob_model/__init__.py
@@ -22,6 +22,9 @@
from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_approximator import (
CyclicalSGLDPosteriorApproximator,
)
from fortuna.prob_model.posterior.sgmcmc.hmc.hmc_approximator import (
HMCPosteriorApproximator,
)
from fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_approximator import (
SGHMCPosteriorApproximator,
)
2 changes: 2 additions & 0 deletions fortuna/prob_model/posterior/name_to_posterior_state.py
@@ -1,6 +1,7 @@
import enum

from fortuna.output_calib_model.state import OutputCalibState
from fortuna.prob_model.posterior.sgmcmc.hmc.hmc_state import HMCState
from fortuna.prob_model.posterior.laplace.laplace_state import LaplaceState
from fortuna.prob_model.posterior.map.map_state import MAPState
from fortuna.prob_model.posterior.normalizing_flow.advi.advi_state import ADVIState
@@ -21,3 +22,4 @@ class NameToPosteriorState(enum.Enum):
vars()[SWAGState.__name__] = SWAGState
vars()[SGHMCState.__name__] = SGHMCState
vars()[CyclicalSGLDState.__name__] = CyclicalSGLDState
vars()[HMCState.__name__] = HMCState
3 changes: 3 additions & 0 deletions fortuna/prob_model/posterior/posterior_approximations.py
@@ -16,6 +16,8 @@
from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_posterior import (
CyclicalSGLDPosterior,
)
from fortuna.prob_model.posterior.sgmcmc.hmc import HMC_NAME
from fortuna.prob_model.posterior.sgmcmc.hmc.hmc_posterior import HMCPosterior
from fortuna.prob_model.posterior.sgmcmc.sghmc import SGHMC_NAME
from fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_posterior import SGHMCPosterior
from fortuna.prob_model.posterior.sngp import SNGP_NAME
@@ -35,3 +37,4 @@ class PosteriorApproximations(enum.Enum):
vars()[SNGP_NAME] = SNGPPosterior
vars()[SGHMC_NAME] = SGHMCPosterior
vars()[CYCLICAL_SGLD_NAME] = CyclicalSGLDPosterior
vars()[HMC_NAME] = HMCPosterior
@@ -62,7 +62,9 @@ def init_fn(params):
sgld_state=sgld.init(params),
)

def update_fn(gradient, state, *_):
def update_fn(gradient, state, params=None):
del params

def sgd_step():
step_size = step_schedule(state.sgld_state.count)
preconditioner_state = preconditioner.update_preconditioner(
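
For context on the `update_fn` signature change above: optax invokes a `GradientTransformation`'s update as `update(updates, state, params=None)`, so accepting `params` explicitly (and deleting it when unused) matches that interface. A minimal illustrative sketch, not taken from this PR:

import jax
import optax

def scale_by_constant(factor: float) -> optax.GradientTransformation:
    """Toy transformation showing the init/update signatures optax expects."""

    def init_fn(params):
        del params  # stateless transformation
        return optax.EmptyState()

    def update_fn(updates, state, params=None):
        del params  # unused here, but part of the optax update signature
        updates = jax.tree_util.tree_map(lambda g: factor * g, updates)
        return updates, state

    return optax.GradientTransformation(init_fn, update_fn)
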
1 change: 1 addition & 0 deletions fortuna/prob_model/posterior/sgmcmc/hmc/__init__.py
@@ -0,0 +1 @@
HMC_NAME = "hmc"
56 changes: 56 additions & 0 deletions fortuna/prob_model/posterior/sgmcmc/hmc/hmc_approximator.py
@@ -0,0 +1,56 @@
from typing import Union

from fortuna.prob_model.posterior.sgmcmc.base import (
SGMCMCPosteriorApproximator,
)
from fortuna.prob_model.posterior.sgmcmc.sgmcmc_preconditioner import (
Preconditioner,
identity_preconditioner,
)
from fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule import (
StepSchedule,
constant_schedule,
)
from fortuna.prob_model.posterior.sgmcmc.hmc import HMC_NAME


class HMCPosteriorApproximator(SGMCMCPosteriorApproximator):
def __init__(
self,
n_samples: int = 10,
n_thinning: int = 1,
burnin_length: int = 1000,
integration_steps: int = 50_000,
step_schedule: Union[StepSchedule, float] = 3e-5,
) -> None:
"""
HMC posterior approximator. It is responsible for defining how the posterior distribution is approximated.

Parameters
----------
n_samples: int
The desired number of posterior samples.
n_thinning: int
If `n_thinning` > 1, keep only every `n_thinning`-th sample during the sampling phase.
burnin_length: int
Length of the initial burn-in phase, in steps.
integration_steps: int
Number of integration steps per trajectory.
step_schedule: Union[StepSchedule, float]
Either a constant `float` step size or a schedule function.

"""
super().__init__(
n_samples=n_samples,
n_thinning=n_thinning,
)
if isinstance(step_schedule, float):
step_schedule = constant_schedule(step_schedule)
elif not callable(step_schedule):
raise ValueError("`step_schedule` must be a callable function.")
self.burnin_length = burnin_length
self.integration_steps = integration_steps
self.step_schedule = step_schedule

def __str__(self) -> str:
return HMC_NAME
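
For reference, a minimal usage sketch of the approximator (assuming Fortuna's usual `ProbClassifier` entry point and a Flax `model` defined elsewhere; the hyperparameter values are illustrative only):

from fortuna.prob_model import ProbClassifier, HMCPosteriorApproximator

prob_model = ProbClassifier(
    model=model,  # a Flax module defined elsewhere
    posterior_approximator=HMCPosteriorApproximator(
        n_samples=10,
        n_thinning=1,
        burnin_length=1000,
        integration_steps=100,
        step_schedule=1e-5,
    ),
)
# Training then follows the usual Fortuna flow, e.g.:
# status = prob_model.train(train_data_loader=train_data_loader)
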
67 changes: 67 additions & 0 deletions fortuna/prob_model/posterior/sgmcmc/hmc/hmc_callback.py
@@ -0,0 +1,67 @@
from typing import Optional

from fortuna.training.train_state import TrainState
from fortuna.training.callback import Callback
from fortuna.training.train_state_repository import TrainStateRepository
from fortuna.training.trainer import TrainerABC
from fortuna.prob_model.posterior.sgmcmc.sgmcmc_sampling_callback import (
SGMCMCSamplingCallback,
)


class HMCSamplingCallback(SGMCMCSamplingCallback):
def __init__(
self,
n_epochs: int,
n_training_steps: int,
n_samples: int,
n_thinning: int,
burnin_length: int,
trainer: TrainerABC,
state_repository: TrainStateRepository,
keep_top_n_checkpoints: int,
):
"""
Hamiltonian Monte Carlo (HMC) callback that collects samples after the initial burn-in phase.

Parameters
----------
n_epochs: int
The number of epochs.
n_training_steps: int
The number of steps per epoch.
n_samples: int
The desired number of posterior samples.
n_thinning: int
Keep only every `n_thinning`-th sample during the sampling phase.
burnin_length: int
Length of the initial burn-in phase, in steps.
trainer: TrainerABC
An instance of the trainer class.
state_repository: TrainStateRepository
An instance of the state repository.
keep_top_n_checkpoints: int
Number of past checkpoint files to keep.
"""
super().__init__(
trainer=trainer,
state_repository=state_repository,
keep_top_n_checkpoints=keep_top_n_checkpoints,
)

self._do_sample = (
lambda current_step, samples_count: samples_count < n_samples
and current_step > burnin_length
and (current_step - burnin_length) % n_thinning == 0
)

total_samples = sum(
self._do_sample(step, 0)
for step in range(1, n_epochs * n_training_steps + 1)
)
if total_samples < n_samples:
raise ValueError(
f"The number of desired samples `n_samples` is {n_samples}. However, only "
f"{total_samples} samples will be collected. Consider adjusting the burnin "
"length, number of epochs, or the thinning parameter."
)
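
To make the sampling rule concrete, a quick sketch with hypothetical numbers (not part of the PR):

burnin_length, n_thinning, n_samples = 1000, 2, 5
do_sample = (
    lambda step, count: count < n_samples
    and step > burnin_length
    and (step - burnin_length) % n_thinning == 0
)
# Samples are taken strictly after burn-in, every `n_thinning`-th step:
print([s for s in range(1, 1011) if do_sample(s, 0)])  # [1002, 1004, 1006, 1008, 1010]
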
129 changes: 129 additions & 0 deletions fortuna/prob_model/posterior/sgmcmc/hmc/hmc_integrator.py
@@ -0,0 +1,129 @@
import jax
import jax.flatten_util
import jax.numpy as jnp

from fortuna.typing import Array
from fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule import (
StepSchedule,
)
from fortuna.utils.random import generate_random_normal_like_tree
from jax._src.prng import PRNGKeyArray
from optax._src.base import PyTree
from optax import GradientTransformation
from typing import NamedTuple


class OptaxHMCState(NamedTuple):
"""Optax state for the HMC integrator."""

count: Array
rng_key: PRNGKeyArray
momentum: PyTree
params: PyTree
hamiltonian: Array
log_prob: Array


def hmc_integrator(
integration_steps: int,
rng_key: PRNGKeyArray,
step_schedule: StepSchedule,
) -> GradientTransformation:
"""Optax implementation of the HMC integrator.

Parameters
----------
integration_steps: int
Number of leapfrog integration steps in each trajectory.
rng_key: PRNGKeyArray
An initial random number generator key.
step_schedule: StepSchedule
A function that takes the training step as input and returns the step size.
"""

def init_fn(params):
return OptaxHMCState(
count=jnp.zeros([], jnp.int32),
rng_key=rng_key,
momentum=jax.tree_util.tree_map(jnp.zeros_like, params),
params=params,
hamiltonian=jnp.array(-1e6, jnp.float32),
log_prob=jnp.zeros([], jnp.float32),
)

def update_fn(gradient, state, params):
step_size = step_schedule(state.count)

def leapfrog_step():
updates = jax.tree_map(
lambda m: m * step_size,
state.momentum,
)
momentum = jax.tree_map(
lambda m, g: m + g * step_size,
state.momentum,
gradient,
)
return updates, OptaxHMCState(
count=state.count + 1,
rng_key=state.rng_key,
momentum=momentum,
params=state.params,
hamiltonian=state.hamiltonian,
log_prob=state.log_prob,
)

def mh_correction():
key, new_key, uniform_key = jax.random.split(state.rng_key, 3)

momentum = jax.tree_map(
lambda m, g: m + g * step_size / 2,
state.momentum,
gradient,
)

momentum, _ = jax.flatten_util.ravel_pytree(momentum)
kinetic = 0.5 * jnp.dot(momentum, momentum)
hamiltonian = kinetic + state.log_prob
accept_prob = jnp.minimum(1.0, jnp.exp(hamiltonian - state.hamiltonian))
Reviewer comment (Contributor):

As discussed, you can avoid the minimum and the exponential here. You can define

log_accept_ratio = hamiltonian - state.hamiltonian

See later for the accept/reject part.


def _accept():
empty_updates = jax.tree_util.tree_map(jnp.zeros_like, params)
return empty_updates, params, hamiltonian

def _reject():
revert_updates = jax.tree_util.tree_map(
lambda sp, p: sp - p,
state.params,
params,
)
return revert_updates, state.params, state.hamiltonian

updates, new_params, new_hamiltonian = jax.lax.cond(
jax.random.uniform(uniform_key) < accept_prob,
Reviewer comment (Contributor):

Following the comment above, this line should become

jnp.log(jax.random.uniform(uniform_key)) < log_accept_ratio.

This is equivalent to what you have written but with one operation fewer. Alternatively, notice that -log(U) ~ Exponential(1) if U ~ Uniform(0, 1). This means that you can also write

-jax.random.exponential(uniform_key) < log_accept_ratio.

All of these should be equivalent. Please check that the lines I wrote are correct :-)

_accept,
_reject,
)
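# Sketch of the reviewer's suggestion above, applied to this accept/reject step
# (same variables, compared in log space, so no `jnp.minimum`/`jnp.exp` needed):
#   log_accept_ratio = hamiltonian - state.hamiltonian
#   accept = jnp.log(jax.random.uniform(uniform_key)) < log_accept_ratio
#   updates, new_params, new_hamiltonian = jax.lax.cond(accept, _accept, _reject)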

new_momentum = generate_random_normal_like_tree(key, gradient)
new_momentum = jax.tree_map(
lambda m, g: m + g * step_size / 2,
new_momentum,
gradient,
)

return updates, OptaxHMCState(
count=state.count + 1,
rng_key=new_key,
momentum=new_momentum,
params=new_params,
hamiltonian=new_hamiltonian,
log_prob=state.log_prob,
)

return jax.lax.cond(
state.count % integration_steps == 0,
mh_correction,
leapfrog_step,
)

return GradientTransformation(init_fn, update_fn)
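
For context, a minimal sketch of driving the `hmc_integrator` defined above as a standard optax transformation (toy parameters and gradients; within Fortuna this is wired up by the HMC posterior and trainer):

import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.zeros(3)}
tx = hmc_integrator(
    integration_steps=10,
    rng_key=jax.random.PRNGKey(0),
    step_schedule=lambda count: 1e-3,
)
opt_state = tx.init(params)
grads = {"w": jnp.ones(3)}  # stand-in for the gradient of the log posterior
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)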