From 9554a25f2fbdad59b18831d38d3e04f9e36722e3 Mon Sep 17 00:00:00 2001
From: Oleg Smirnov
Date: Fri, 26 May 2023 12:51:45 +0200
Subject: [PATCH 1/6] fix SGMCMC state repo

---
 .../posterior/sgmcmc/sgmcmc_posterior_state_repository.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior_state_repository.py b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior_state_repository.py
index be343127..b8b1fbbd 100644
--- a/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior_state_repository.py
+++ b/fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior_state_repository.py
@@ -119,7 +119,7 @@ def _update_state(
             return state
 
         if isinstance(state, list):
-            return [_update_state(_state, modify=modify) for _state in state]
+            return [self._update_state(_state, modify=modify) for _state in state]
 
         if modify == "add":
             state = state.replace(

From 10a9527c03d21b29b9a52fdfc1b0afe651b74e9c Mon Sep 17 00:00:00 2001
From: Oleg Smirnov
Date: Tue, 23 May 2023 12:14:09 +0200
Subject: [PATCH 2/6] add full-batch HMC implementation

---
 .../model_manager/name_to_model_manager.py    |   2 +
 fortuna/prob_model/__init__.py                |   3 +
 .../posterior/posterior_approximations.py     |   3 +
 .../cyclical_sgld/cyclical_sgld_integrator.py |   4 +-
 .../posterior/sgmcmc/hmc/__init__.py          |   1 +
 .../posterior/sgmcmc/hmc/hmc_approximator.py  |  56 +++++
 .../posterior/sgmcmc/hmc/hmc_callback.py      |  67 ++++++
 .../posterior/sgmcmc/hmc/hmc_integrator.py    | 127 ++++++++++
 .../posterior/sgmcmc/hmc/hmc_posterior.py     | 225 ++++++++++++++++++
 .../posterior/sgmcmc/hmc/hmc_state.py         |  68 ++++++
 .../posterior/sgmcmc/hmc/hmc_trainer.py       |  63 +++++
 .../sgmcmc/sghmc/sghmc_integrator.py          |   3 +-
 tests/fortuna/prob_model/test_train.py        |   6 +
 13 files changed, 626 insertions(+), 2 deletions(-)
 create mode 100644 fortuna/prob_model/posterior/sgmcmc/hmc/__init__.py
 create mode 100644 fortuna/prob_model/posterior/sgmcmc/hmc/hmc_approximator.py
 create mode 100644 fortuna/prob_model/posterior/sgmcmc/hmc/hmc_callback.py
 create mode 100644 fortuna/prob_model/posterior/sgmcmc/hmc/hmc_integrator.py
 create mode 100644 fortuna/prob_model/posterior/sgmcmc/hmc/hmc_posterior.py
 create mode 100644 fortuna/prob_model/posterior/sgmcmc/hmc/hmc_state.py
 create mode 100644 fortuna/prob_model/posterior/sgmcmc/hmc/hmc_trainer.py

diff --git a/fortuna/model/model_manager/name_to_model_manager.py b/fortuna/model/model_manager/name_to_model_manager.py
index efc305bd..077c7054 100644
--- a/fortuna/model/model_manager/name_to_model_manager.py
+++ b/fortuna/model/model_manager/name_to_model_manager.py
@@ -9,6 +9,7 @@
 from fortuna.prob_model.posterior.map import MAP_NAME
 from fortuna.prob_model.posterior.normalizing_flow.advi import ADVI_NAME
 from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld import CYCLICAL_SGLD_NAME
+from fortuna.prob_model.posterior.sgmcmc.hmc import HMC_NAME
 from fortuna.prob_model.posterior.sgmcmc.sghmc import SGHMC_NAME
 from fortuna.prob_model.posterior.sngp import SNGP_NAME
 from fortuna.prob_model.posterior.swag import SWAG_NAME
@@ -25,3 +26,4 @@ class ClassificationModelManagers(enum.Enum):
     vars()[SNGP_NAME] = SNGPClassificationModelManager
     vars()[SGHMC_NAME] = ClassificationModelManager
     vars()[CYCLICAL_SGLD_NAME] = ClassificationModelManager
+    vars()[HMC_NAME] = ClassificationModelManager
diff --git a/fortuna/prob_model/__init__.py b/fortuna/prob_model/__init__.py
index fa19fc5c..b684ef73 100644
--- a/fortuna/prob_model/__init__.py
+++ b/fortuna/prob_model/__init__.py
@@ -22,6 +22,9 @@
 from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_approximator import (
     CyclicalSGLDPosteriorApproximator,
 )
+from fortuna.prob_model.posterior.sgmcmc.hmc.hmc_approximator import (
+    HMCPosteriorApproximator,
+)
 from fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_approximator import (
     SGHMCPosteriorApproximator,
 )
diff --git a/fortuna/prob_model/posterior/posterior_approximations.py b/fortuna/prob_model/posterior/posterior_approximations.py
index f14a36a1..bd0f2a5c 100644
--- a/fortuna/prob_model/posterior/posterior_approximations.py
+++ b/fortuna/prob_model/posterior/posterior_approximations.py
@@ -16,6 +16,8 @@
 from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_posterior import (
     CyclicalSGLDPosterior,
 )
+from fortuna.prob_model.posterior.sgmcmc.hmc import HMC_NAME
+from fortuna.prob_model.posterior.sgmcmc.hmc.hmc_posterior import HMCPosterior
 from fortuna.prob_model.posterior.sgmcmc.sghmc import SGHMC_NAME
 from fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_posterior import SGHMCPosterior
 from fortuna.prob_model.posterior.sngp import SNGP_NAME
@@ -35,3 +37,4 @@ class PosteriorApproximations(enum.Enum):
     vars()[SNGP_NAME] = SNGPPosterior
     vars()[SGHMC_NAME] = SGHMCPosterior
     vars()[CYCLICAL_SGLD_NAME] = CyclicalSGLDPosterior
+    vars()[HMC_NAME] = HMCPosterior
diff --git a/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_integrator.py b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_integrator.py
index a6e525c6..9e49bcac 100644
--- a/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_integrator.py
+++ b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_integrator.py
@@ -62,7 +62,9 @@ def init_fn(params):
             sgld_state=sgld.init(params),
         )
 
-    def update_fn(gradient, state, *_):
+    def update_fn(gradient, state, params=None):
+        del params
+
         def sgd_step():
             step_size = step_schedule(state.sgld_state.count)
             preconditioner_state = preconditioner.update_preconditioner(
diff --git a/fortuna/prob_model/posterior/sgmcmc/hmc/__init__.py b/fortuna/prob_model/posterior/sgmcmc/hmc/__init__.py
new file mode 100644
index 00000000..df02f157
--- /dev/null
+++ b/fortuna/prob_model/posterior/sgmcmc/hmc/__init__.py
@@ -0,0 +1 @@
+HMC_NAME = "hmc"
diff --git a/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_approximator.py b/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_approximator.py
new file mode 100644
index 00000000..040c58fa
--- /dev/null
+++ b/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_approximator.py
@@ -0,0 +1,56 @@
+from typing import Union
+
+from fortuna.prob_model.posterior.sgmcmc.base import (
+    SGMCMCPosteriorApproximator,
+)
+from fortuna.prob_model.posterior.sgmcmc.sgmcmc_preconditioner import (
+    Preconditioner,
+    identity_preconditioner,
+)
+from fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule import (
+    StepSchedule,
+    constant_schedule,
+)
+from fortuna.prob_model.posterior.sgmcmc.hmc import HMC_NAME
+
+
+class HMCPosteriorApproximator(SGMCMCPosteriorApproximator):
+    def __init__(
+        self,
+        n_samples: int = 10,
+        n_thinning: int = 1,
+        burnin_length: int = 1000,
+        integration_steps: int = 50_000,
+        step_schedule: Union[StepSchedule, float] = 3e-5,
+    ) -> None:
+        """
+        HMC posterior approximator. It is responsible for defining how the posterior distribution is approximated.
+
+        Parameters
+        ----------
+        n_samples: int
+            The desired number of posterior samples.
+        n_thinning: int
+            If `n_thinning` > 1, keep only every `n_thinning`-th sample during the sampling phase.
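+            For example, with `n_thinning=5`, only every fifth post-burn-in draw is retained.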
+        burnin_length: int
+            Length of the initial burn-in phase, in steps.
+        integration_steps: int
+            Number of integration steps per trajectory.
+        step_schedule: Union[StepSchedule, float]
+            Either a constant `float` step size or a schedule function.
+
+        """
+        super().__init__(
+            n_samples=n_samples,
+            n_thinning=n_thinning,
+        )
+        if isinstance(step_schedule, float):
+            step_schedule = constant_schedule(step_schedule)
+        elif not callable(step_schedule):
+            raise ValueError("`step_schedule` must be a callable function.")
+        self.burnin_length = burnin_length
+        self.integration_steps = integration_steps
+        self.step_schedule = step_schedule
+
+    def __str__(self) -> str:
+        return HMC_NAME
diff --git a/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_callback.py b/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_callback.py
new file mode 100644
index 00000000..3deb1504
--- /dev/null
+++ b/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_callback.py
@@ -0,0 +1,67 @@
+from typing import Optional
+
+from fortuna.training.train_state import TrainState
+from fortuna.training.callback import Callback
+from fortuna.training.train_state_repository import TrainStateRepository
+from fortuna.training.trainer import TrainerABC
+from fortuna.prob_model.posterior.sgmcmc.sgmcmc_sampling_callback import (
+    SGMCMCSamplingCallback,
+)
+
+
+class HMCSamplingCallback(SGMCMCSamplingCallback):
+    def __init__(
+        self,
+        n_epochs: int,
+        n_training_steps: int,
+        n_samples: int,
+        n_thinning: int,
+        burnin_length: int,
+        trainer: TrainerABC,
+        state_repository: TrainStateRepository,
+        keep_top_n_checkpoints: int,
+    ):
+        """
+        Hamiltonian Monte Carlo (HMC) callback that collects samples after the initial burn-in phase.
+
+        Parameters
+        ----------
+        n_epochs: int
+            The number of epochs.
+        n_training_steps: int
+            The number of steps per epoch.
+        n_samples: int
+            The desired number of posterior samples.
+        n_thinning: int
+            Keep only every `n_thinning`-th sample during the sampling phase.
+        burnin_length: int
+            Length of the initial burn-in phase, in steps.
+        trainer: TrainerABC
+            An instance of the trainer class.
+        state_repository: TrainStateRepository
+            An instance of the state repository.
+        keep_top_n_checkpoints: int
+            Number of past checkpoint files to keep.
+        """
+        super().__init__(
+            trainer=trainer,
+            state_repository=state_repository,
+            keep_top_n_checkpoints=keep_top_n_checkpoints,
+        )
+
+        self._do_sample = (
+            lambda current_step, samples_count: samples_count < n_samples
+            and current_step > burnin_length
+            and (current_step - burnin_length) % n_thinning == 0
+        )
+
+        total_samples = sum(
+            self._do_sample(step, 0)
+            for step in range(1, n_epochs * n_training_steps + 1)
+        )
+        if total_samples < n_samples:
+            raise ValueError(
+                f"The number of desired samples `n_samples` is {n_samples}. However, only "
+                f"{total_samples} samples will be collected. Consider adjusting the burn-in "
+                "length, number of epochs, or the thinning parameter."
+            )
diff --git a/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_integrator.py b/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_integrator.py
new file mode 100644
index 00000000..8755d5c7
--- /dev/null
+++ b/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_integrator.py
@@ -0,0 +1,127 @@
+import jax.flatten_util
+import jax.numpy as jnp
+
+from fortuna.typing import Array
+from fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule import (
+    StepSchedule,
+)
+from fortuna.utils.random import generate_random_normal_like_tree
+from jax._src.prng import PRNGKeyArray
+from optax._src.base import PyTree
+from optax import GradientTransformation
+from typing import NamedTuple
+
+
+class OptaxHMCState(NamedTuple):
+    """Optax state for the HMC integrator."""
+    count: Array
+    rng_key: PRNGKeyArray
+    momentum: PyTree
+    params: PyTree
+    hamiltonian: Array
+    log_prob: Array
+
+
+def hmc_integrator(
+    integration_steps: int,
+    rng_key: PRNGKeyArray,
+    step_schedule: StepSchedule,
+) -> GradientTransformation:
+    """Optax implementation of the HMC integrator.
+
+    Parameters
+    ----------
+    integration_steps: int
+        Number of leapfrog integration steps in each trajectory.
+    rng_key: PRNGKeyArray
+        An initial random number generator key.
+    step_schedule: StepSchedule
+        A function that takes the training step as input and returns the step size.
+    """
+    def init_fn(params):
+        return OptaxHMCState(
+            count=jnp.zeros([], jnp.int32),
+            rng_key=rng_key,
+            momentum=jax.tree_util.tree_map(jnp.zeros_like, params),
+            params=params,
+            hamiltonian=jnp.array(-1e6, jnp.float32),
+            log_prob=jnp.zeros([], jnp.float32),
+        )
+
+    def update_fn(gradient, state, params):
+        step_size = step_schedule(state.count)
+
+        def leapfrog_step():
+            updates = jax.tree_map(
+                lambda m: m * step_size,
+                state.momentum,
+            )
+            momentum = jax.tree_map(
+                lambda m, g: m + g * step_size,
+                state.momentum,
+                gradient,
+            )
+            return updates, OptaxHMCState(
+                count=state.count + 1,
+                rng_key=state.rng_key,
+                momentum=momentum,
+                params=state.params,
+                hamiltonian=state.hamiltonian,
+                log_prob=state.log_prob,
+            )
+
+        def mh_correction():
+            key, new_key, uniform_key = jax.random.split(state.rng_key, 3)
+
+            momentum = jax.tree_map(
+                lambda m, g: m + g * step_size / 2,
+                state.momentum,
+                gradient,
+            )
+
+            momentum, _ = jax.flatten_util.ravel_pytree(momentum)
+            kinetic = 0.5 * jnp.dot(momentum, momentum)
+            hamiltonian = kinetic + state.log_prob
+            accept_prob = jnp.minimum(1., jnp.exp(hamiltonian - state.hamiltonian))
+
+            def _accept():
+                empty_updates = jax.tree_util.tree_map(jnp.zeros_like, params)
+                return empty_updates, params, hamiltonian
+
+            def _reject():
+                revert_updates = jax.tree_util.tree_map(
+                    lambda sp, p: sp - p,
+                    state.params,
+                    params,
+                )
+                return revert_updates, state.params, state.hamiltonian
+
+            updates, new_params, new_hamiltonian = jax.lax.cond(
+                jax.random.uniform(uniform_key) < accept_prob,
+                _accept,
+                _reject,
+            )
+
+            new_momentum = generate_random_normal_like_tree(key, gradient)
+            new_momentum = jax.tree_map(
+                lambda m, g: m + g * step_size / 2,
+                new_momentum,
+                gradient,
+            )
+
+            return updates, OptaxHMCState(
+                count=state.count + 1,
+                rng_key=new_key,
+                momentum=new_momentum,
+                params=new_params,
+                hamiltonian=new_hamiltonian,
+                log_prob=state.log_prob,
+            )
+
+        return jax.lax.cond(
+            state.count % integration_steps == 0,
+            mh_correction,
+            leapfrog_step,
+        )
+
+    return GradientTransformation(init_fn, update_fn)
diff --git a/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_posterior.py b/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_posterior.py
new file mode 100644
index 00000000..da126ba6
--- /dev/null
+++ b/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_posterior.py
@@ -0,0 +1,225 @@
+import logging
+from typing import Optional
+import pathlib
+
+from flax.core import FrozenDict
+from fortuna.utils.freeze import get_trainable_paths
+from fortuna.utils.nested_dicts import nested_set, nested_get
+from fortuna.data.loader import DataLoader
+from fortuna.prob_model.fit_config.base import FitConfig
+from fortuna.prob_model.joint.base import Joint
+from fortuna.prob_model.posterior.sgmcmc.hmc.hmc_trainer import (
+    HMCTrainer,
+    JittedHMCTrainer,
+    MultiDeviceHMCTrainer,
+)
+from fortuna.prob_model.posterior.run_preliminary_map import (
+    run_preliminary_map,
+)
+from fortuna.prob_model.posterior.map.map_posterior import MAPPosterior
+from fortuna.prob_model.posterior.map.map_state import MAPState
+from fortuna.prob_model.posterior.sgmcmc.sgmcmc_posterior_state_repository import (
+    SGMCMCPosteriorStateRepository,
+)
+from fortuna.prob_model.posterior.sgmcmc.sgmcmc_posterior import (
+    SGMCMCPosterior,
+)
+from fortuna.prob_model.posterior.sgmcmc.hmc import HMC_NAME
+from fortuna.prob_model.posterior.sgmcmc.hmc.hmc_approximator import (
+    HMCPosteriorApproximator,
+)
+from fortuna.prob_model.posterior.sgmcmc.hmc.hmc_callback import (
+    HMCSamplingCallback,
+)
+from fortuna.prob_model.posterior.sgmcmc.hmc.hmc_integrator import (
+    hmc_integrator,
+)
+from fortuna.prob_model.posterior.sgmcmc.hmc.hmc_state import HMCState
+from fortuna.typing import Status
+from fortuna.utils.device import select_trainer_given_devices
+
+logger = logging.getLogger(__name__)
+
+
+class HMCPosterior(SGMCMCPosterior):
+    def __init__(
+        self,
+        joint: Joint,
+        posterior_approximator: HMCPosteriorApproximator,
+    ):
+        """
+        Hamiltonian Monte Carlo approximate posterior class.
+
+        Parameters
+        ----------
+        joint: Joint
+            A Joint distribution object.
+        posterior_approximator: HMCPosteriorApproximator
+            An HMC posterior approximator.
+        """
+        super().__init__(joint=joint, posterior_approximator=posterior_approximator)
+
+    def __str__(self):
+        return HMC_NAME
+
+    def fit(
+        self,
+        train_data_loader: DataLoader,
+        val_data_loader: Optional[DataLoader] = None,
+        fit_config: FitConfig = FitConfig(),
+        map_fit_config: Optional[FitConfig] = None,
+        **kwargs,
+    ) -> Status:
+        super()._checks_on_fit_start(fit_config, map_fit_config)
+
+        status = {}
+
+        map_state = None
+        if map_fit_config is not None and fit_config.optimizer.freeze_fun is None:
+            logger.warning(
+                "It appears that you are trying to configure `map_fit_config`. "
+                "However, a preliminary run with MAP is supported only if "
+                "`fit_config.optimizer.freeze_fun` is given. "
+                "Since the latter was not given, `map_fit_config` will be ignored."
+            )
+        elif not super()._is_state_available_somewhere(
+            fit_config
+        ) and super()._should_run_preliminary_map(fit_config, map_fit_config):
+            map_state, status["map"] = run_preliminary_map(
+                joint=self.joint,
+                train_data_loader=train_data_loader,
+                val_data_loader=val_data_loader,
+                map_fit_config=map_fit_config,
+                rng=self.rng,
+                **kwargs,
+            )
+
+        if fit_config.optimizer.method is not None:
+            logger.info("`FitOptimizer` method in HMC is ignored.")
+
+        fit_config.optimizer.method = hmc_integrator(
+            integration_steps=self.posterior_approximator.integration_steps,
+            rng_key=self.rng.get(),
+            step_schedule=self.posterior_approximator.step_schedule,
+        )
+
+        trainer_cls = select_trainer_given_devices(
+            devices=fit_config.processor.devices,
+            base_trainer_cls=HMCTrainer,
+            jitted_trainer_cls=JittedHMCTrainer,
+            multi_device_trainer_cls=MultiDeviceHMCTrainer,
+            disable_jit=fit_config.processor.disable_jit,
+        )
+
+        save_checkpoint_dir = (
+            pathlib.Path(fit_config.checkpointer.save_checkpoint_dir) / "c"
+            if fit_config.checkpointer.save_checkpoint_dir
+            else None
+        )
+        trainer = trainer_cls(
+            predict_fn=self.joint.likelihood.prob_output_layer.predict,
+            save_checkpoint_dir=save_checkpoint_dir,
+            save_every_n_steps=fit_config.checkpointer.save_every_n_steps,
+            keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints,
+            disable_training_metrics_computation=fit_config.monitor.disable_training_metrics_computation,
+            eval_every_n_epochs=fit_config.monitor.eval_every_n_epochs,
+            early_stopping_monitor=fit_config.monitor.early_stopping_monitor,
+            early_stopping_min_delta=fit_config.monitor.early_stopping_min_delta,
+            early_stopping_patience=fit_config.monitor.early_stopping_patience,
+        )
+
+        if super()._is_state_available_somewhere(fit_config):
+            state = self._restore_state_from_somewhere(fit_config=fit_config)
+        else:
+            state = self._init_map_state(map_state, train_data_loader, fit_config)
+
+        if fit_config.optimizer.freeze_fun is not None:
+            which_params = get_trainable_paths(
+                params=state.params, freeze_fun=fit_config.optimizer.freeze_fun
+            )
+        else:
+            which_params = None
+
+        state = HMCState.convert_from_map_state(
+            map_state=state,
+            optimizer=fit_config.optimizer.method,
+            which_params=which_params,
+        )
+
+        state = super()._freeze_optimizer_in_state(state, fit_config)
+
+        self.state = SGMCMCPosteriorStateRepository(
+            size=self.posterior_approximator.n_samples,
+            checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir,
+            which_params=which_params,
+            all_params=state.params if which_params else None,
+        )
+
+        hmc_sampling_callback = HMCSamplingCallback(
+            n_epochs=fit_config.optimizer.n_epochs,
+            n_training_steps=len(train_data_loader),
+            n_samples=self.posterior_approximator.n_samples,
+            n_thinning=self.posterior_approximator.n_thinning,
+            burnin_length=self.posterior_approximator.burnin_length,
+            trainer=trainer,
+            state_repository=self.state,
+            keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints,
+        )
+
+        logger.info("Run HMC.")
+        state, status["hmc"] = trainer.train(
+            rng=self.rng.get(),
+            state=state,
+            loss_fun=self.joint._batched_log_joint_prob,
+            training_dataloader=train_data_loader,
+            training_dataset_size=train_data_loader.size,
+            n_epochs=fit_config.optimizer.n_epochs,
+            metrics=fit_config.monitor.metrics,
+            validation_dataloader=val_data_loader,
+            validation_dataset_size=val_data_loader.size
+            if val_data_loader is not None
+            else None,
+            verbose=fit_config.monitor.verbose,
+            callbacks=[hmc_sampling_callback],
+        )
+        logger.info("Fit completed.")
+
+        return status
+
+    def _init_map_state(
+        self,
+        state: Optional[MAPState],
+        data_loader: DataLoader,
+        fit_config: FitConfig,
+    ) -> MAPState:
+        if state is None or fit_config.optimizer.freeze_fun is None:
+            state = super()._init_joint_state(data_loader)
+
+            return MAPState.init(
+                params=state.params,
+                mutable=state.mutable,
+                optimizer=fit_config.optimizer.method,
+                calib_params=state.calib_params,
+                calib_mutable=state.calib_mutable,
+            )
+        else:
+            random_state = super()._init_joint_state(data_loader)
+            trainable_paths = get_trainable_paths(
+                state.params, fit_config.optimizer.freeze_fun
+            )
+            state = state.replace(
+                params=FrozenDict(
+                    nested_set(
+                        d=state.params.unfreeze(),
+                        key_paths=trainable_paths,
+                        objs=tuple(
+                            [
+                                nested_get(d=random_state.params, keys=path)
+                                for path in trainable_paths
+                            ]
+                        ),
+                    )
+                )
+            )
+
+        return state
diff --git a/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_state.py b/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_state.py
new file mode 100644
index 00000000..35c72177
--- /dev/null
+++ b/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_state.py
@@ -0,0 +1,68 @@
+from __future__ import annotations
+
+from typing import (
+    Dict,
+    List,
+    Tuple,
+    Optional,
+)
+
+import jax.numpy as jnp
+
+from fortuna.prob_model.posterior.state import PosteriorState
+from fortuna.utils.strings import (
+    convert_string_to_jnp_array,
+    encode_tuple_of_lists_of_strings_to_numpy,
+)
+from fortuna.prob_model.posterior.map.map_state import MAPState
+from fortuna.typing import (
+    AnyKey,
+    Array,
+    OptaxOptimizer,
+)
+
+
+class HMCState(PosteriorState):
+    """
+    Attributes
+    ----------
+    encoded_name: jnp.ndarray
+        HMC state name encoded as an array.
+    """
+
+    encoded_name: jnp.ndarray = convert_string_to_jnp_array("HMCState")
+    _encoded_which_params: Optional[Dict[str, List[Array]]] = None
+
+    @classmethod
+    def convert_from_map_state(
+        cls,
+        map_state: MAPState,
+        optimizer: OptaxOptimizer,
+        which_params: Tuple[List[AnyKey], ...],
+    ) -> HMCState:
+        """
+        Convert a MAP state into an HMC state.
+
+        Parameters
+        ----------
+        map_state: MAPState
+            A MAP posterior state.
+        optimizer: OptaxOptimizer
+            An Optax optimizer.
+        which_params: Tuple[List[AnyKey], ...]
+            Sequences of keys pointing to the stochastic parameters.
+
+        Returns
+        -------
+        HMCState
+            An HMC state.
+ """ + _encoded_which_params = encode_tuple_of_lists_of_strings_to_numpy(which_params) + return cls.init( + params=map_state.params, + mutable=map_state.mutable, + optimizer=optimizer, + calib_params=map_state.calib_params, + calib_mutable=map_state.calib_mutable, + _encoded_which_params=_encoded_which_params, + ) diff --git a/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_trainer.py b/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_trainer.py new file mode 100644 index 00000000..e53ac8c4 --- /dev/null +++ b/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_trainer.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from typing import ( + Any, + Callable, + Dict, + Optional, + Tuple, +) + +from flax.core import FrozenDict +from jax._src.prng import PRNGKeyArray +import jax.numpy as jnp +from optax._src.base import PyTree + +from fortuna.prob_model.posterior.map.map_trainer import MAPTrainer +from fortuna.prob_model.posterior.sgmcmc.hmc import HMC_NAME +from fortuna.prob_model.posterior.sgmcmc.hmc.hmc_state import HMCState +from fortuna.training.trainer import ( + JittedMixin, + MultiDeviceMixin, +) +from fortuna.typing import ( + Array, + Batch, +) + + +class HMCTrainer(MAPTrainer): + def training_step( + self, + state: HMCState, + batch: Batch, + loss_fun: Callable, + rng: PRNGKeyArray, + n_data: int, + unravel: Optional[Callable[[any], PyTree]] = None, + kwargs: FrozenDict[str, Any] = FrozenDict(), + ) -> Tuple[jnp.ndarray, Dict[str, Any]]: + state, aux = super().training_step( + state=state, + batch=batch, + loss_fun=loss_fun, + rng=rng, + n_data=n_data, + unravel=unravel, + **kwargs, + ) + state = state.replace( + opt_state=state.opt_state._replace(log_prob=aux["loss"]), + ) + return state, aux + + def __str__(self): + return HMC_NAME + + +class JittedHMCTrainer(JittedMixin, HMCTrainer): + pass + + +class MultiDeviceHMCTrainer(MultiDeviceMixin, HMCTrainer): + pass diff --git a/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_integrator.py b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_integrator.py index 769ad617..daa0a02f 100644 --- a/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_integrator.py +++ b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_integrator.py @@ -57,7 +57,8 @@ def init_fn(params): preconditioner_state=preconditioner.init(params), ) - def update_fn(gradient, state, *_): + def update_fn(gradient, state, params=None): + del params step_size = step_schedule(state.count) preconditioner_state = preconditioner.update_preconditioner( diff --git a/tests/fortuna/prob_model/test_train.py b/tests/fortuna/prob_model/test_train.py index 05fda2f0..a02ffb55 100755 --- a/tests/fortuna/prob_model/test_train.py +++ b/tests/fortuna/prob_model/test_train.py @@ -31,6 +31,9 @@ from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_posterior import ( CyclicalSGLDPosteriorApproximator, ) +from fortuna.prob_model.posterior.sgmcmc.hmc.hmc_posterior import ( + HMCPosteriorApproximator, +) from fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_posterior import ( SGHMCPosteriorApproximator, ) @@ -58,6 +61,9 @@ "cyclical_sgld": CyclicalSGLDPosteriorApproximator( n_samples=3, n_thinning=1, cycle_length=4 ), + "hmc": HMCPosteriorApproximator( + n_samples=3, n_thinning=1, burnin_length=1 + ), } From a96283a1c99b06e98928c2db834e542f62c6b62a Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Mon, 29 May 2023 23:27:37 +0200 Subject: [PATCH 3/6] update documentation --- docs/source/methods.rst | 10 ++++--- .../prob_model/posterior/sgmcmc.rst | 27 ++++++++++++++++++- 2 files changed, 33 
insertions(+), 4 deletions(-) diff --git a/docs/source/methods.rst b/docs/source/methods.rst index 560eaeae..95ad5962 100644 --- a/docs/source/methods.rst +++ b/docs/source/methods.rst @@ -29,12 +29,16 @@ Posterior approximation methods taken by averaging checkpoints over the stochastic optimization trajectory. The covariance is also estimated empirically along the trajectory, and it is made of a diagonal component and a low-rank non-diagonal one. +- **Hamiltonian Monte Carlo (HMC)** `[Neal, 2010] `_ + HMC approximates the posterior as a steady-state distribution of a Monte Carlo Markov chain with Hamiltonian dynamics. + After the initial "burn-in" phase, each step of the chain generates a sample from the posterior. HMC is typically applied + in the full-batch scenario. + - **Stochastic Gradient Hamiltonian Monte Carlo (SGHMC)** `[Chen et al., 2014] `_ - SGHMC approximates the posterior as a steady-state distribution of a Monte Carlo Markov chain with Hamiltonian dynamics. - After the initial "burn-in" phase, each step of the chain generates samples from the posterior. + SGHMC implements a variant of HMC algorithm that expects noisy gradient estimate computed on mini-batches of data. - **Cyclical Stochastic Gradient Langevin Dynamics (Cyclical SGLD)** `[Zhang et al., 2020] `_ - Cyclical SGLD adapts the cyclical cosine step size schedule, and alternates between *exploration* and *sampling* stages to better + Cyclical SGLD adopts the cyclical cosine step size schedule, and alternates between *exploration* and *sampling* stages to better explore the multimodal posteriors for deep neural networks. Parametric calibration methods diff --git a/docs/source/references/prob_model/posterior/sgmcmc.rst b/docs/source/references/prob_model/posterior/sgmcmc.rst index 2ab3ce31..c4391486 100644 --- a/docs/source/references/prob_model/posterior/sgmcmc.rst +++ b/docs/source/references/prob_model/posterior/sgmcmc.rst @@ -4,12 +4,37 @@ SG-MCMC procedures approximate the posterior as a steady-state distribution of a Monte Carlo Markov chain, that utilizes noisy estimates of the gradient computed on minibatches of data. +Hamiltonian Monte Carlo (HMC) +============================= + +HMC `[Neal, 2010] `_ is a MCMC sampling +algorithm that simulates a Hamiltonian dynamical system to rapidly explores +the posterior. + +.. autoclass:: fortuna.prob_model.posterior.sgmcmc.hmc.hmc_approximator.HMCPosteriorApproximator + +.. autoclass:: fortuna.prob_model.posterior.sgmcmc.hmc.hmc_posterior.HMCPosterior + :show-inheritance: + :no-inherited-members: + :exclude-members: state + :members: fit, sample, load_state, save_state + +.. autoclass:: fortuna.prob_model.posterior.sgmcmc.hmc.hmc_state.HMCState + :show-inheritance: + :no-inherited-members: + :inherited-members: init, init_from_dict + :members: convert_from_map_state + :exclude-members: params, mutable, calib_params, calib_mutable, replace, apply_gradients, encoded_name, create + :no-undoc-members: + :no-special-members: + + Stochastic Gradient Hamiltonian Monte Carlo (SGHMC) =================================================== SGHMC `[Chen T. et al., 2014] `_ is a popular MCMC algorithm that uses stochastic gradient estimates to scale -to large datasets. +HMC to large datasets. .. 
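+
+A hypothetical side-by-side construction of the two approximators documented
+on this page (the argument values below are illustrative placeholders, not
+recommendations):
+
+.. code-block:: python
+
+    from fortuna.prob_model import (
+        HMCPosteriorApproximator,
+        SGHMCPosteriorApproximator,
+    )
+
+    # Full-batch HMC: leapfrog trajectories with a Metropolis-Hastings
+    # correction at the end of each trajectory.
+    hmc = HMCPosteriorApproximator(n_samples=10, burnin_length=1000)
+    # SGHMC: mini-batch gradient estimates, no accept/reject step.
+    sghmc = SGHMCPosteriorApproximator(n_samples=10, burnin_length=1000)
+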
 .. autoclass:: fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_approximator.SGHMCPosteriorApproximator

From cf2c710db2cba07d7c8f31f3cd7913c9f088f0f1 Mon Sep 17 00:00:00 2001
From: Oleg Smirnov
Date: Mon, 29 May 2023 23:45:31 +0200
Subject: [PATCH 4/6] lint the code

---
 fortuna/prob_model/posterior/sgmcmc/hmc/hmc_integrator.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_integrator.py b/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_integrator.py
index 8755d5c7..f4c09a0f 100644
--- a/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_integrator.py
+++ b/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_integrator.py
@@ -14,6 +14,7 @@
 
 class OptaxHMCState(NamedTuple):
     """Optax state for the HMC integrator."""
+
     count: Array
     rng_key: PRNGKeyArray
     momentum: PyTree
@@ -38,6 +39,7 @@ def hmc_integrator(
     step_schedule: StepSchedule
         A function that takes the training step as input and returns the step size.
     """
+
    def init_fn(params):
        return OptaxHMCState(
            count=jnp.zeros([], jnp.int32),
@@ -82,7 +84,7 @@ def mh_correction():
             momentum, _ = jax.flatten_util.ravel_pytree(momentum)
             kinetic = 0.5 * jnp.dot(momentum, momentum)
             hamiltonian = kinetic + state.log_prob
-            accept_prob = jnp.minimum(1., jnp.exp(hamiltonian - state.hamiltonian))
+            accept_prob = jnp.minimum(1.0, jnp.exp(hamiltonian - state.hamiltonian))

From 90974f4f46b94b3ee401e6a8ceacc2bd0ffb3cab Mon Sep 17 00:00:00 2001
From: Oleg Smirnov
Date: Wed, 14 Jun 2023 21:48:54 +0200
Subject: [PATCH 5/6] Fix missing import

---
 fortuna/prob_model/posterior/name_to_posterior_state.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/fortuna/prob_model/posterior/name_to_posterior_state.py b/fortuna/prob_model/posterior/name_to_posterior_state.py
index 0a98da1a..1f35401d 100644
--- a/fortuna/prob_model/posterior/name_to_posterior_state.py
+++ b/fortuna/prob_model/posterior/name_to_posterior_state.py
@@ -1,6 +1,7 @@
 import enum
 
 from fortuna.output_calib_model.state import OutputCalibState
+from fortuna.prob_model.posterior.sgmcmc.hmc.hmc_state import HMCState
 from fortuna.prob_model.posterior.laplace.laplace_state import LaplaceState
 from fortuna.prob_model.posterior.map.map_state import MAPState
 from fortuna.prob_model.posterior.normalizing_flow.advi.advi_state import ADVIState
@@ -21,3 +22,4 @@ class NameToPosteriorState(enum.Enum):
     vars()[SWAGState.__name__] = SWAGState
     vars()[SGHMCState.__name__] = SGHMCState
     vars()[CyclicalSGLDState.__name__] = CyclicalSGLDState
+    vars()[HMCState.__name__] = HMCState

From c1973e1e8faf4f3235f055953f5da8742ec987f8 Mon Sep 17 00:00:00 2001
From: Oleg Smirnov
Date: Fri, 28 Jul 2023 21:17:57 +0200
Subject: [PATCH 6/6] Fix optimizer freeze support in HMC

---
 .../posterior/sgmcmc/hmc/hmc_trainer.py       | 14 ++++-
 fortuna/utils/freeze.py                       | 63 +++++++++++++++++++
 tests/fortuna/prob_model/test_train.py        |  2 +-
 3 files changed, 75 insertions(+), 4 deletions(-)

diff --git a/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_trainer.py b/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_trainer.py
index e53ac8c4..d90f807e 100644
--- a/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_trainer.py
+++ b/fortuna/prob_model/posterior/sgmcmc/hmc/hmc_trainer.py
@@ -24,6 +24,11 @@
 from fortuna.typing import (
     Array,
     Batch,
 )
+from fortuna.utils.freeze import (
+    has_multiple_opt_state,
+    get_trainable_opt_state,
+    update_trainable_opt_state,
+)
 
 
 class HMCTrainer(MAPTrainer):
@@ -46,9 +51,12 @@ def training_step(
             unravel=unravel,
             **kwargs,
         )
-        state = state.replace(
-            opt_state=state.opt_state._replace(log_prob=aux["loss"]),
-        )
+        if has_multiple_opt_state(state):
+            opt_state = get_trainable_opt_state(state)._replace(log_prob=aux["loss"])
+            state = update_trainable_opt_state(state, opt_state)
+        else:
+            opt_state = state.opt_state._replace(log_prob=aux["loss"])
+            state = state.replace(opt_state=opt_state)
         return state, aux
 
     def __str__(self):
diff --git a/fortuna/utils/freeze.py b/fortuna/utils/freeze.py
index a5c3c8f8..31501a87 100644
--- a/fortuna/utils/freeze.py
+++ b/fortuna/utils/freeze.py
@@ -18,7 +18,9 @@
 from optax import (
     multi_transform,
     set_to_zero,
+    MultiTransformState,
 )
+from optax._src.wrappers import MaskedState
 
 from fortuna.typing import (
     AnyKey,
@@ -27,6 +29,8 @@
     Params,
 )
 
+from fortuna.prob_model.posterior.state import PosteriorState
+
 
 def all_values_in_labels(values: Iterable, labels: Any) -> None:
     """
@@ -81,6 +85,65 @@ def freeze_optimizer(
     return multi_transform(partition_optimizers, partition_params)
 
 
+def has_multiple_opt_state(state: PosteriorState):
+    """
+    Check if a given posterior state contains multiple optimizer states.
+
+    Parameters
+    ----------
+    state: PosteriorState
+        An instance of `PosteriorState`.
+
+    Returns
+    -------
+    bool
+    """
+    return isinstance(state.opt_state, MultiTransformState)
+
+
+def get_trainable_opt_state(state: PosteriorState):
+    """
+    Get the trainable optimizer state.
+
+    Parameters
+    ----------
+    state: PosteriorState
+        An instance of `PosteriorState`.
+
+    Returns
+    -------
+    opt_state: Any
+        An instance of the trainable optimizer state.
+    """
+    return state.opt_state.inner_states["trainable"].inner_state
+
+
+def update_trainable_opt_state(state: PosteriorState, opt_state: Any):
+    """
+    Update the trainable optimizer state.
+
+    Parameters
+    ----------
+    state: PosteriorState
+        An instance of `PosteriorState`.
+    opt_state: Any
+        An instance of the trainable optimizer state.
+
+    Returns
+    -------
+    PosteriorState
+        An updated posterior state.
+    """
+    trainable_state = MaskedState(inner_state=opt_state)
+    opt_state = MultiTransformState(
+        inner_states={
+            k: (trainable_state if k == "trainable" else v)
+            for k, v in state.opt_state.inner_states.items()
+        }
+    )
+    return state.replace(opt_state=opt_state)
+
+
 def get_trainable_paths(
     params: Params,
     freeze_fun: Optional[Callable[[Tuple[AnyKey, ...], Array], str]],
diff --git a/tests/fortuna/prob_model/test_train.py b/tests/fortuna/prob_model/test_train.py
index a02ffb55..8da92759 100755
--- a/tests/fortuna/prob_model/test_train.py
+++ b/tests/fortuna/prob_model/test_train.py
@@ -386,7 +386,7 @@ def dryrun_task(task, method):
     )
     state = (
         prob_model.posterior.state.get()
-        if method not in ["deep_ensemble", "sghmc", "cyclical_sgld"]
+        if method not in ["deep_ensemble", "sghmc", "cyclical_sgld", "hmc"]
         else prob_model.posterior.state.get(-1)
     )
     model_editor_params = state.params["model_editor"]["params"].unfreeze()
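
Usage sketch: after this series, the new approximator is exported from
`fortuna.prob_model` and can be plugged into a probabilistic model. The snippet
below is illustrative and not part of the patch: `my_model` stands for any
Flax module, `train_data_loader` for a Fortuna DataLoader, and the keyword
values mirror the approximator's defaults rather than a recommendation.

    from fortuna.prob_model import HMCPosteriorApproximator, ProbClassifier

    prob_model = ProbClassifier(
        model=my_model,  # placeholder Flax module
        posterior_approximator=HMCPosteriorApproximator(
            n_samples=10,        # posterior samples to retain
            n_thinning=1,        # keep every post-burn-in draw
            burnin_length=1000,  # steps discarded before sampling starts
            step_schedule=3e-5,  # constant leapfrog step size
        ),
    )
    status = prob_model.train(train_data_loader=train_data_loader)
    # Draw one sample from the collected posterior states.
    sample = prob_model.posterior.sample()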