
Enable model and data sharding #96 (Draft)

gianlucadetommaso wants to merge 37 commits into main.

This view shows the changes from 1 commit.

Commits (37)
52e96ea
edit installation instructions in readme
gianlucadetommaso May 15, 2023
5e0076d
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 15, 2023
4c7fd28
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 15, 2023
6cb6581
bump up version
gianlucadetommaso May 15, 2023
1b39780
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 16, 2023
cb2b49a
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 16, 2023
14e3ca4
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 25, 2023
580067d
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 27, 2023
048ef09
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 2, 2023
ad542a4
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 12, 2023
41417c1
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 12, 2023
64be374
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 14, 2023
a2d0f34
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 14, 2023
66bba06
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 15, 2023
911aa82
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 15, 2023
01f959b
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 15, 2023
79f8dca
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 15, 2023
99a3b78
add sequence probit
gianlucadetommaso Jun 19, 2023
1c23a9e
add possibility to run sequential probit on last steps only
gianlucadetommaso Jun 20, 2023
4dea50f
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 21, 2023
915a1ea
Merge branch 'main' into seqprobit
gianlucadetommaso Jun 21, 2023
e966745
refactor sequential probit implementation
gianlucadetommaso Jun 23, 2023
529f9aa
add stop gradient flag
gianlucadetommaso Jun 24, 2023
42d2117
pre-commit
gianlucadetommaso Jun 24, 2023
734f597
add probit options in example script
gianlucadetommaso Jun 25, 2023
404840e
mesh
gianlucadetommaso Jun 25, 2023
4444907
enable model and data sharding
gianlucadetommaso Jun 25, 2023
830fbe8
make further changes after training roberta
gianlucadetommaso Jul 11, 2023
e3e1c4f
further changes
gianlucadetommaso Jul 16, 2023
6d47a47
refactoring laplace
gianlucadetommaso Jul 17, 2023
ed571de
start debugging swag
gianlucadetommaso Jul 18, 2023
1ced008
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 18, 2023
6992692
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 18, 2023
b2540c1
make small change in readme because of publish to pypi error
gianlucadetommaso Jul 18, 2023
2362998
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 18, 2023
ba52081
debug deep ensemble
gianlucadetommaso Jul 18, 2023
d2fc289
fix sghmc and sgld
gianlucadetommaso Jul 25, 2023
start debugging swag
gianlucadetommaso committed Jul 18, 2023
commit ed571de29e29e49f010953ff62df226c2afa9446
6 changes: 3 additions & 3 deletions fortuna/calib_model/base.py
@@ -30,7 +30,7 @@
     Targets,
     Uncertainties,
 )
-from pathlib import Path as _Path
+import pathlib
 from jax._src.prng import PRNGKeyArray
 from orbax.checkpoint import CheckpointManager
 from fortuna.utils.checkpoint import get_checkpoint_manager
@@ -105,7 +105,7 @@ def _calibrate(
         checkpoint_restorer = (
             get_checkpoint_manager(
                 str(
-                    _Path(config.checkpointer.restore_checkpoint_dir)
+                    pathlib.Path(config.checkpointer.restore_checkpoint_dir)
                     / config.checkpointer.checkpoint_type
                 ),
                 keep_top_n_checkpoints=config.checkpointer.keep_top_n_checkpoints,
@@ -163,7 +163,7 @@ def init_state_fn(rng):
             partition_manager=self.partition_manager,
             checkpoint_manager=get_checkpoint_manager(
                 checkpoint_dir=str(
-                    _Path(config.checkpointer.save_checkpoint_dir)
+                    pathlib.Path(config.checkpointer.save_checkpoint_dir)
                     / config.checkpointer.checkpoint_type
                 ),
                 keep_top_n_checkpoints=config.checkpointer.keep_top_n_checkpoints,
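Note: the pattern above recurs throughout this commit: path aliases such as _Path are replaced by a plain import pathlib, and checkpoint directories are composed with the / operator. A minimal sketch of the pattern, with hypothetical stand-in values for the config fields:

import pathlib

# Hypothetical stand-ins for config.checkpointer fields (illustration only).
restore_checkpoint_dir = "/tmp/checkpoints"
checkpoint_type = "last"

# pathlib's `/` operator joins path components; str() yields the plain
# string that get_checkpoint_manager expects.
checkpoint_dir = str(pathlib.Path(restore_checkpoint_dir) / checkpoint_type)
assert checkpoint_dir == "/tmp/checkpoints/last"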
38 changes: 15 additions & 23 deletions fortuna/prob_model/posterior/laplace/laplace_posterior.py
@@ -7,16 +7,9 @@
     Optional,
     Tuple,
     Union,
-    Any,
-    Callable
 )
-from fortuna.data.loader.base import ShardedPrefetchedLoader
 from flax.core import FrozenDict
-from flax.training.common_utils import (
-    shard,
-    shard_prng_key,
-)
 import jax
 from jax.sharding import PartitionSpec
 from jax.experimental.pjit import pjit
 from jax import (
@@ -34,7 +27,7 @@
 import jax.numpy as jnp
 from jax.tree_util import tree_map
 import tqdm
-
+import pathlib
 from fortuna.data.loader import (
     DataLoader,
     DeviceDimensionAugmentedLoader,
@@ -64,16 +57,15 @@
     Mutable,
     Params,
     Status,
-    Array
 )
+import pathlib
 from fortuna.utils.checkpoint import get_checkpoint_manager
 from fortuna.utils.freeze import get_trainable_paths
 from fortuna.utils.nested_dicts import (
     nested_get,
     nested_set,
     nested_unpair,
 )
-from pathlib import Path
 from fortuna.utils.random import generate_random_normal_like_tree
 from fortuna.utils.strings import decode_encoded_tuple_of_lists_of_strings_to_array
 from fortuna.partitioner.partition_manager.base import PartitionManager
@@ -248,7 +240,7 @@ def fit(
         checkpoint_restorer = (
             get_checkpoint_manager(
                 str(
-                    Path(fit_config.checkpointer.restore_checkpoint_dir)
+                    pathlib.Path(fit_config.checkpointer.restore_checkpoint_dir)
                     / fit_config.checkpointer.checkpoint_type
                 ),
                 keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints,
@@ -307,10 +299,10 @@ def fit(
         )

         self.state = PosteriorStateRepository(
-            partition_manager=self.partition_manager,
+            partition_manager=None,
             checkpoint_manager=get_checkpoint_manager(
                 checkpoint_dir=str(
-                    Path(fit_config.checkpointer.save_checkpoint_dir)
+                    pathlib.Path(fit_config.checkpointer.save_checkpoint_dir)
                     / fit_config.checkpointer.checkpoint_type
                 ),
                 keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints,
@@ -319,7 +311,7 @@ def fit(
             and fit_config.checkpointer.dump_state
             else None,
         )
-        self.state.put(state, keep=fit_config.checkpointer.keep_top_n_checkpoints)
+        self.state.replace(state, keep=fit_config.checkpointer.keep_top_n_checkpoints)
         logging.info("Fit completed.")
         if (
             val_data_loader is not None
@@ -332,7 +324,7 @@ def fit(
                 shard=fit_config.processor.devices == -1,
             )
             state = state.replace(prior_log_var=opt_prior_log_var)
-            self.state.put(state, keep=fit_config.checkpointer.keep_top_n_checkpoints)
+            self.state.replace(state, keep=fit_config.checkpointer.keep_top_n_checkpoints)
             logging.info(f"Best prior log-variance found: {opt_prior_log_var}")
         return status

@@ -343,7 +335,7 @@ def sample(
     ) -> JointState:
         if rng is None:
             rng = self.rng.get()
-        state: LaplaceState = self.state.get()
+        state = self.state.get()
         if kwargs.get("prior_log_var") is not None:
             state = state.replace(prior_log_var=kwargs.get("prior_log_var"))

@@ -352,9 +344,9 @@
             state._encoded_which_params
         )
         mean, hess_lik_diag = nested_unpair(
-            state.params.unfreeze(),
-            which_params,
-            ("mean", "hess_lik_diag"),
+            d=state.params.unfreeze(),
+            key_paths=tuple(which_params),
+            labels=("mean", "hess_lik_diag"),
         )
         std = self._compute_std(
             prior_log_var=state.prior_log_var, hess_lik_diag=hess_lik_diag
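Note: the hunk above switches nested_unpair to keyword arguments and wraps which_params in a tuple. The helper itself is not shown in this diff; the following is a rough, self-contained sketch of what a nested_unpair-style function does, inferred from the call site alone (not Fortuna's implementation): at each key path, a leaf holding both labeled values is split into two trees.

import copy

def nested_unpair_sketch(d, key_paths, labels):
    # Split `d` into two trees: at every key path the leaf is assumed to be
    # a dict keyed by the two labels; everything else is copied unchanged.
    first, second = copy.deepcopy(d), copy.deepcopy(d)
    for path in key_paths:
        a, b = first, second
        for key in path[:-1]:
            a, b = a[key], b[key]
        a[path[-1]] = a[path[-1]][labels[0]]
        b[path[-1]] = b[path[-1]][labels[1]]
    return first, second

params = {"dense": {"kernel": {"mean": 0.0, "hess_lik_diag": 1.0}}}
mean, hess_lik_diag = nested_unpair_sketch(
    d=params,
    key_paths=(("dense", "kernel"),),
    labels=("mean", "hess_lik_diag"),
)
# mean == {"dense": {"kernel": 0.0}}, hess_lik_diag == {"dense": {"kernel": 1.0}}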
@@ -363,7 +355,7 @@
         noise = generate_random_normal_like_tree(rng, std)
         params = nested_set(
             d=mean,
-            key_paths=which_params,
+            key_paths=tuple(which_params),
             objs=tuple(
                 [
                     tree_map(
@@ -418,7 +410,7 @@ def _init_map_state(
             params=FrozenDict(
                 nested_unpair(
                     d=state.params.unfreeze(),
-                    key_paths=which_params,
+                    key_paths=tuple(which_params),
                     labels=("mean", "hess_lik_diag"),
                 )[0]
             )
@@ -457,7 +449,7 @@ def _batched_log_prob(
         keys = random.split(rng, n_posterior_samples)

         def _lik_log_batched_prob(params, mutable, calib_params, calib_mutable):
-            return self.likelihood._batched_log_prob(
+            return self.joint.likelihood._batched_log_prob(
                 params,
                 batch,
                 mutable=mutable,
@@ -470,9 +462,9 @@ def _lik_log_batched_prob(params, mutable, calib_params, calib_mutable):
         _lik_log_batched_prob = pjit(
             _lik_log_batched_prob,
             in_shardings=(
                 self.partition_manager.shardings.params,
                 self.partition_manager.shardings.mutable,
-                self.partition_manager.shardings.calib_params,
+                self.partition_manager.shardings.params,
                 self.partition_manager.shardings.calib_mutable,
             ),
             out_shardings=PartitionSpec(("dp", "fsdp")),
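Note: the pjit call above pins each argument to the shardings tracked by the partition manager and shards the output across the ("dp", "fsdp") mesh axes. For readers unfamiliar with the mechanics, here is a minimal, self-contained sketch of pjit with named shardings, assuming a single "dp" mesh axis (the mesh Fortuna actually builds is not shown in this diff):

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec

# One-dimensional device mesh named "dp", mirroring one of the axis names
# used above (the real code also uses an "fsdp" axis).
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))

# Shard the leading dimension of the input and output over "dp".
double = pjit(
    lambda x: x * 2.0,
    in_shardings=PartitionSpec("dp"),
    out_shardings=PartitionSpec("dp"),
)

with mesh:
    y = double(jnp.arange(8.0))  # leading dim must divide across the devices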
6 changes: 3 additions & 3 deletions fortuna/prob_model/posterior/map/map_posterior.py
@@ -1,5 +1,5 @@
 import logging
-from pathlib import Path
+import pathlib
 from typing import Optional

 from jax import eval_shape
@@ -91,7 +91,7 @@ def fit(
         checkpoint_restorer = (
             get_checkpoint_manager(
                 str(
-                    Path(fit_config.checkpointer.restore_checkpoint_dir)
+                    pathlib.Path(fit_config.checkpointer.restore_checkpoint_dir)
                     / fit_config.checkpointer.checkpoint_type
                 ),
                 keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints,
@@ -145,7 +145,7 @@ def init_state_fn(rng):
             partition_manager=self.partition_manager,
             checkpoint_manager=get_checkpoint_manager(
                 checkpoint_dir=str(
-                    Path(fit_config.checkpointer.save_checkpoint_dir)
+                    pathlib.Path(fit_config.checkpointer.save_checkpoint_dir)
                     / fit_config.checkpointer.checkpoint_type
                 ),
                 keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints,
7 changes: 3 additions & 4 deletions (file path not captured in this view; the ADVI trainer imports below suggest the ADVI posterior module)
@@ -14,7 +14,6 @@
 from jax._src.prng import PRNGKeyArray
 from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
-import numpy as np

 from fortuna.data.loader import (
     DataLoader,
@@ -39,7 +38,7 @@
     JittedADVITrainer,
     MultiDeviceADVITrainer,
 )
-from pathlib import Path
+import pathlib
 from fortuna.partitioner.partition_manager.base import PartitionManager
 from fortuna.prob_model.posterior.posterior_state_repository import (
     PosteriorStateRepository,
@@ -105,7 +104,7 @@ def fit(
         checkpoint_restorer = (
             get_checkpoint_manager(
                 str(
-                    Path(fit_config.checkpointer.restore_checkpoint_dir)
+                    pathlib.Path(fit_config.checkpointer.restore_checkpoint_dir)
                     / fit_config.checkpointer.checkpoint_type
                 ),
                 keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints,
@@ -214,7 +213,7 @@ def fit(
             partition_manager=None,
             checkpoint_manager=get_checkpoint_manager(
                 checkpoint_dir=str(
-                    Path(fit_config.checkpointer.save_checkpoint_dir)
+                    pathlib.Path(fit_config.checkpointer.save_checkpoint_dir)
                     / fit_config.checkpointer.checkpoint_type
                 ),
                 keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints,
5 changes: 2 additions & 3 deletions (file path not captured in this view)
@@ -39,7 +39,6 @@
     nested_get,
     nested_set,
 )
-from pathlib import Path
 from fortuna.utils.checkpoint import get_checkpoint_manager

 logger = logging.getLogger(__name__)
@@ -145,7 +144,7 @@ def fit(
         checkpoint_restorer = (
             get_checkpoint_manager(
                 str(
-                    Path(fit_config.checkpointer.restore_checkpoint_dir)
+                    pathlib.Path(fit_config.checkpointer.restore_checkpoint_dir)
                     / fit_config.checkpointer.checkpoint_type
                 ),
                 keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints,
@@ -179,7 +178,7 @@ def fit(
             partition_manager=self.partition_manager,
             checkpoint_manager=get_checkpoint_manager(
                 str(
-                    Path(fit_config.checkpointer.restore_checkpoint_dir)
+                    pathlib.Path(fit_config.checkpointer.restore_checkpoint_dir)
                     / fit_config.checkpointer.checkpoint_type
                 ),
                 keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints,
5 changes: 2 additions & 3 deletions fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_posterior.py
@@ -28,7 +28,6 @@
 from fortuna.prob_model.posterior.sgmcmc.sgmcmc_posterior_state_repository import (
     SGMCMCPosteriorStateRepository,
 )
-from pathlib import Path
 from fortuna.typing import Status
 from fortuna.utils.device import select_trainer_given_devices
 from fortuna.utils.freeze import get_trainable_paths
@@ -142,7 +141,7 @@ def fit(
         checkpoint_restorer = (
             get_checkpoint_manager(
                 str(
-                    Path(fit_config.checkpointer.restore_checkpoint_dir)
+                    pathlib.Path(fit_config.checkpointer.restore_checkpoint_dir)
                     / fit_config.checkpointer.checkpoint_type
                 ),
                 keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints,
@@ -175,7 +174,7 @@ def fit(
             size=self.posterior_approximator.n_samples,
             checkpoint_manager=get_checkpoint_manager(
                 str(
-                    Path(fit_config.checkpointer.restore_checkpoint_dir)
+                    pathlib.Path(fit_config.checkpointer.restore_checkpoint_dir)
                     / fit_config.checkpointer.checkpoint_type
                 ),
                 keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints,
7 changes: 3 additions & 4 deletions fortuna/prob_model/posterior/swag/swag_posterior.py
@@ -8,12 +8,11 @@
 from jax._src.prng import PRNGKeyArray
 from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
-
+import pathlib
 from fortuna.data.loader import (
     DataLoader,
     InputsLoader,
 )
-from pathlib import Path
 from fortuna.utils.checkpoint import get_checkpoint_manager
 from fortuna.prob_model.fit_config.base import FitConfig
 from fortuna.prob_model.joint.base import Joint
@@ -105,7 +104,7 @@ def fit(
         checkpoint_restorer = (
             get_checkpoint_manager(
                 str(
-                    Path(fit_config.checkpointer.restore_checkpoint_dir)
+                    pathlib.Path(fit_config.checkpointer.restore_checkpoint_dir)
                 ),
                 keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints,
             )
@@ -200,7 +199,7 @@ def fit(
             partition_manager=self.partition_manager,
             checkpoint_manager=get_checkpoint_manager(
                 checkpoint_dir=str(
-                    Path(fit_config.checkpointer.save_checkpoint_dir)
+                    pathlib.Path(fit_config.checkpointer.save_checkpoint_dir)
                     / fit_config.checkpointer.checkpoint_type
                 ),
                 keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints,
3 changes: 2 additions & 1 deletion fortuna/prob_model/posterior/swag/swag_trainer.py
@@ -23,6 +23,7 @@
     Batch,
     Path,
 )
+import pathlib
 from fortuna.partitioner.partition_manager.base import PartitionManager
 from orbax.checkpoint import CheckpointManager
 from fortuna.utils.strings import encode_tuple_of_lists_of_strings_to_numpy
@@ -114,7 +115,7 @@ def save_checkpoint(
     def on_train_end(self, state: SWAGState) -> SWAGState:
         self.save_checkpoint(
             state,
-            save_checkpoint_dir=self.save_checkpoint_dir,
+            save_checkpoint_dir=str(pathlib.Path(self.save_checkpoint_dir) / "last"),
             keep=self.keep_top_n_checkpoints,
             force_save=True,
         )
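Note: writing the end-of-training SWAG state to a "last" subdirectory plausibly keeps it separate from the top-n checkpoints rotated in the parent directory; that rationale is inferred from the diff rather than stated in the PR.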
27 changes: 22 additions & 5 deletions fortuna/training/train_state_repository.py
@@ -64,6 +64,18 @@ def put(
         else:
             self._state = state

+    def remove(
+        self,
+        checkpoint_dir: Path = None,
+    ):
+        if checkpoint_dir or self.checkpoint_manager:
+            if checkpoint_dir is None:
+                step = self.checkpoint_manager.latest_step()
+                if step is not None:
+                    self.checkpoint_manager.delete(step)
+            else:
+                rmtree(checkpoint_dir)
+
     def pull(
         self,
         checkpoint_dir: Path = None,
@@ -73,13 +85,18 @@ def pull(
             checkpoint_dir=checkpoint_dir,
             optimizer=optimizer,
         )
-        if checkpoint_dir or self.checkpoint_manager:
-            if checkpoint_dir is None:
-                self.checkpoint_manager.delete(self.checkpoint_manager.latest_step())
-            else:
-                rmtree(checkpoint_dir)
+        self.remove(checkpoint_dir)
         return state

+    def replace(
+        self,
+        state: TrainState,
+        checkpoint_dir: Optional[Path] = None,
+        keep: int = 1,
+    ):
+        self.remove(checkpoint_dir)
+        self.put(state, checkpoint_dir, keep=keep)
+
     def update(
         self,
         variables: Dict,
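Note: read together, the new methods give the repository overwrite semantics: remove deletes the latest managed checkpoint (or an explicit directory), pull now delegates its cleanup to remove, and replace is remove-then-put, which is why the Laplace fit above calls state.replace(...) instead of state.put(...). A toy, self-contained model of these semantics (illustrative names, not Fortuna's actual class):

import shutil

class StateRepoSketch:
    """Toy model of the remove/put/replace semantics added in this hunk."""

    def __init__(self, checkpoint_manager=None):
        self.checkpoint_manager = checkpoint_manager  # e.g. an Orbax manager
        self._state = None

    def put(self, state, checkpoint_dir=None, keep=1):
        self._state = state  # the real repo would also write a checkpoint

    def remove(self, checkpoint_dir=None):
        # Mirrors the diff: with no explicit directory, delete the manager's
        # latest step; otherwise wipe the given directory.
        if checkpoint_dir or self.checkpoint_manager:
            if checkpoint_dir is None:
                step = self.checkpoint_manager.latest_step()
                if step is not None:
                    self.checkpoint_manager.delete(step)
            else:
                shutil.rmtree(checkpoint_dir)

    def replace(self, state, checkpoint_dir=None, keep=1):
        # replace == remove-then-put: overwrite rather than accumulate.
        self.remove(checkpoint_dir)
        self.put(state, checkpoint_dir, keep=keep)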