add vanilla HMC method #75
@@ -0,0 +1 @@
HMC_NAME = "hmc"
@@ -0,0 +1,56 @@
from typing import Union

from fortuna.prob_model.posterior.sgmcmc.base import (
    SGMCMCPosteriorApproximator,
)
from fortuna.prob_model.posterior.sgmcmc.sgmcmc_preconditioner import (
    Preconditioner,
    identity_preconditioner,
)
from fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule import (
    StepSchedule,
    constant_schedule,
)
from fortuna.prob_model.posterior.sgmcmc.hmc import HMC_NAME


class HMCPosteriorApproximator(SGMCMCPosteriorApproximator):
    def __init__(
        self,
        n_samples: int = 10,
        n_thinning: int = 1,
        burnin_length: int = 1000,
        integration_steps: int = 50_000,
        step_schedule: Union[StepSchedule, float] = 3e-5,
    ) -> None:
        """
        HMC posterior approximator. It is responsible for defining how the
        posterior distribution is approximated.

        Parameters
        ----------
        n_samples: int
            The desired number of posterior samples.
        n_thinning: int
            If `n_thinning` > 1, keep only every `n_thinning`-th sample during
            the sampling phase.
        burnin_length: int
            Length of the initial burn-in phase, in steps.
        integration_steps: int
            Number of integration steps per trajectory.
        step_schedule: Union[StepSchedule, float]
            Either a constant `float` step size or a schedule function.
        """
        super().__init__(
            n_samples=n_samples,
            n_thinning=n_thinning,
        )
        if isinstance(step_schedule, float):
            step_schedule = constant_schedule(step_schedule)
        elif not callable(step_schedule):
            raise ValueError("`step_schedule` must be a float or a callable function.")
        self.burnin_length = burnin_length
        self.integration_steps = integration_steps
        self.step_schedule = step_schedule

    def __str__(self) -> str:
        return HMC_NAME
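For reference, a minimal usage sketch of the approximator defined above; the keyword values are illustrative only, and `constant_schedule` is the helper already imported in this file.

# Illustrative usage of HMCPosteriorApproximator (values are arbitrary).
approximator = HMCPosteriorApproximator(
    n_samples=20,
    n_thinning=2,
    burnin_length=1_000,
    integration_steps=100,
    step_schedule=3e-5,  # a float is wrapped into constant_schedule(3e-5)
)

# Passing a schedule explicitly is equivalent; any callable mapping the step
# count to a step size is accepted.
approximator = HMCPosteriorApproximator(step_schedule=constant_schedule(3e-5))

print(approximator)  # prints "hmc", i.e. HMC_NAME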
@@ -0,0 +1,67 @@
from typing import Optional

from fortuna.training.train_state import TrainState
from fortuna.training.callback import Callback
from fortuna.training.train_state_repository import TrainStateRepository
from fortuna.training.trainer import TrainerABC
from fortuna.prob_model.posterior.sgmcmc.sgmcmc_sampling_callback import (
    SGMCMCSamplingCallback,
)


class HMCSamplingCallback(SGMCMCSamplingCallback):
    def __init__(
        self,
        n_epochs: int,
        n_training_steps: int,
        n_samples: int,
        n_thinning: int,
        burnin_length: int,
        trainer: TrainerABC,
        state_repository: TrainStateRepository,
        keep_top_n_checkpoints: int,
    ):
        """
        Hamiltonian Monte Carlo (HMC) callback that collects samples after the
        initial burn-in phase.

        Parameters
        ----------
        n_epochs: int
            The number of epochs.
        n_training_steps: int
            The number of steps per epoch.
        n_samples: int
            The desired number of posterior samples.
        n_thinning: int
            Keep only every `n_thinning`-th sample during the sampling phase.
        burnin_length: int
            Length of the initial burn-in phase, in steps.
        trainer: TrainerABC
            An instance of the trainer class.
        state_repository: TrainStateRepository
            An instance of the state repository.
        keep_top_n_checkpoints: int
            Number of past checkpoint files to keep.
        """
        super().__init__(
            trainer=trainer,
            state_repository=state_repository,
            keep_top_n_checkpoints=keep_top_n_checkpoints,
        )

        # A step produces a sample once burn-in is over, subject to thinning,
        # until the desired number of samples has been collected.
        self._do_sample = (
            lambda current_step, samples_count: samples_count < n_samples
            and current_step > burnin_length
            and (current_step - burnin_length) % n_thinning == 0
        )

        # Fail fast if the run is too short to ever collect `n_samples`.
        total_samples = sum(
            self._do_sample(step, 0)
            for step in range(1, n_epochs * n_training_steps + 1)
        )
        if total_samples < n_samples:
            raise ValueError(
                f"The number of desired samples `n_samples` is {n_samples}. However, only "
                f"{total_samples} samples will be collected. Consider adjusting the burn-in "
                "length, the number of epochs, or the thinning parameter."
            )
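As a quick sanity check of the feasibility test above, a small worked example; all numbers are hypothetical.

# Hypothetical numbers: 10 epochs of 150 steps give 1500 steps in total.
# With burnin_length=1000 and n_thinning=2, the eligible sampling steps are
# 1002, 1004, ..., 1500, i.e. (1500 - 1000) // 2 = 250 of them, so any
# n_samples <= 250 passes the check and larger values raise ValueError.
n_epochs, n_training_steps = 10, 150
burnin_length, n_thinning = 1000, 2

eligible = [
    step
    for step in range(1, n_epochs * n_training_steps + 1)
    if step > burnin_length and (step - burnin_length) % n_thinning == 0
]
assert len(eligible) == 250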
@@ -0,0 +1,129 @@
import jax
import jax.flatten_util
import jax.numpy as jnp

from fortuna.typing import Array
from fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule import (
    StepSchedule,
)
from fortuna.utils.random import generate_random_normal_like_tree
from jax._src.prng import PRNGKeyArray
from optax._src.base import PyTree
from optax import GradientTransformation
from typing import NamedTuple


class OptaxHMCState(NamedTuple):
    """Optax state for the HMC integrator."""

    count: Array
    rng_key: PRNGKeyArray
    momentum: PyTree
    params: PyTree
    hamiltonian: Array
    log_prob: Array


def hmc_integrator(
    integration_steps: int,
    rng_key: PRNGKeyArray,
    step_schedule: StepSchedule,
) -> GradientTransformation:
    """Optax implementation of the HMC integrator.

    Parameters
    ----------
    integration_steps: int
        Number of leapfrog integration steps in each trajectory.
    rng_key: PRNGKeyArray
        An initial random number generator.
    step_schedule: StepSchedule
        A function that takes the training step as input and returns the step size.
    """

    def init_fn(params):
        return OptaxHMCState(
            count=jnp.zeros([], jnp.int32),
            rng_key=rng_key,
            momentum=jax.tree_util.tree_map(jnp.zeros_like, params),
            params=params,
            hamiltonian=jnp.array(-1e6, jnp.float32),
            log_prob=jnp.zeros([], jnp.float32),
        )

    def update_fn(gradient, state, params):
        step_size = step_schedule(state.count)

        def leapfrog_step():
            # Inner leapfrog iteration: move the parameters along the current
            # momentum, then update the momentum with the gradient.
            updates = jax.tree_util.tree_map(
                lambda m: m * step_size,
                state.momentum,
            )
            momentum = jax.tree_util.tree_map(
                lambda m, g: m + g * step_size,
                state.momentum,
                gradient,
            )
            return updates, OptaxHMCState(
                count=state.count + 1,
                rng_key=state.rng_key,
                momentum=momentum,
                params=state.params,
                hamiltonian=state.hamiltonian,
                log_prob=state.log_prob,
            )

        def mh_correction():
            key, new_key, uniform_key = jax.random.split(state.rng_key, 3)

            # Final half-step momentum update that closes the trajectory.
            momentum = jax.tree_util.tree_map(
                lambda m, g: m + g * step_size / 2,
                state.momentum,
                gradient,
            )

            momentum, _ = jax.flatten_util.ravel_pytree(momentum)
            kinetic = 0.5 * jnp.dot(momentum, momentum)
            hamiltonian = kinetic + state.log_prob
            accept_prob = jnp.minimum(1.0, jnp.exp(hamiltonian - state.hamiltonian))

            def _accept():
                # Keep the proposed parameters: emit zero updates.
                empty_updates = jax.tree_util.tree_map(jnp.zeros_like, params)
                return empty_updates, params, hamiltonian

            def _reject():
                # Roll back to the parameters stored at the start of the trajectory.
                revert_updates = jax.tree_util.tree_map(
                    lambda sp, p: sp - p,
                    state.params,
                    params,
                )
                return revert_updates, state.params, state.hamiltonian

            updates, new_params, new_hamiltonian = jax.lax.cond(
                jax.random.uniform(uniform_key) < accept_prob,
                _accept,
                _reject,
            )

            # Resample the momentum and take the first half-step of the next
            # trajectory.
            new_momentum = generate_random_normal_like_tree(key, gradient)
            new_momentum = jax.tree_util.tree_map(
                lambda m, g: m + g * step_size / 2,
                new_momentum,
                gradient,
            )

            return updates, OptaxHMCState(
                count=state.count + 1,
                rng_key=new_key,
                momentum=new_momentum,
                params=new_params,
                hamiltonian=new_hamiltonian,
                log_prob=state.log_prob,
            )

        # Every `integration_steps` steps, apply the Metropolis-Hastings
        # correction and start a new trajectory; otherwise keep integrating.
        return jax.lax.cond(
            state.count % integration_steps == 0,
            mh_correction,
            leapfrog_step,
        )

    return GradientTransformation(init_fn, update_fn)
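The transformation can also be exercised outside Fortuna's trainer. Below is a self-contained sketch on a toy standard-normal target; the toy `log_density`, the step values, and the loop are assumptions for illustration. Note that the trainer is normally responsible for refreshing `log_prob` in the state, so in this sketch it keeps its initial value and only the update mechanics are demonstrated.

# Hypothetical standalone driver for hmc_integrator (not part of the PR).
import jax
import jax.numpy as jnp
import optax

# Toy target: standard normal, whose log-density gradient is -params.
log_density = lambda p: -0.5 * jnp.sum(p**2)

transform = hmc_integrator(
    integration_steps=50,
    rng_key=jax.random.PRNGKey(0),
    step_schedule=lambda count: 1e-2,  # constant step size
)

params = jnp.ones((3,))
state = transform.init(params)
for _ in range(500):
    grads = jax.grad(log_density)(params)  # ascent direction on the log-density
    updates, state = transform.update(grads, state, params)
    params = optax.apply_updates(params, updates)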
Review comment on `accept_prob = jnp.minimum(1.0, jnp.exp(hamiltonian - state.hamiltonian))`:

As discussed, you can avoid the minimum and the exponential here. You can define

    log_accept_ratio = hamiltonian - state.hamiltonian

See later for the accept/reject part.

Review comment on `jax.random.uniform(uniform_key) < accept_prob`:

Following the comment above, this line should become

    jnp.log(jax.random.uniform(uniform_key)) < log_accept_ratio

This is equivalent to what you have written but with one operation fewer. Alternatively, notice that the `jnp.minimum` is redundant on its own: `jax.random.uniform` draws from [0, 1), so comparing against the unclamped ratio gives the same result. All of these should be equivalent. Please check that the lines I wrote are correct :-)
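Putting the two suggestions together, the acceptance block would read roughly as follows; this is one reading of the review, not code taken from the PR.

# Sketch of the reviewer's proposal (assumed, not from the diff): compare in
# log space so that neither jnp.minimum nor jnp.exp is needed.
log_accept_ratio = hamiltonian - state.hamiltonian

updates, new_params, new_hamiltonian = jax.lax.cond(
    jnp.log(jax.random.uniform(uniform_key)) < log_accept_ratio,
    _accept,
    _reject,
)
# Clamping the ratio at 1.0 is unnecessary in log space: log(uniform) is
# never positive, so a non-negative log_accept_ratio always accepts.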