Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 31 additions & 23 deletions clu/parameter_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,17 @@ def flatten_dict(
def _count_parameters(params: _ParamsContainer) -> int:
"""Returns the count of variables for the module or parameter dictionary."""
params = flatten_dict(params)
return sum(np.prod(v.shape) for v in params.values())
return sum(np.prod(v.shape) for v in params.values() if v is not None)


def _parameters_size(params: _ParamsContainer) -> int:
"""Returns total size (bytes) for the module or parameter dictionary."""
params = flatten_dict(params)
return sum(np.prod(v.shape) * v.dtype.itemsize for v in params.values())
return sum(
np.prod(v.shape) * v.dtype.itemsize
for v in params.values()
if v is not None
)


def count_parameters(params: _ParamsContainer) -> int:
Expand Down Expand Up @@ -127,6 +131,8 @@ def _make_row_with_sharding(name, value) -> _ParamRowWithSharding:

def _make_row_with_stats(name, value, mean, std) -> _ParamRowWithStats:
row = _make_row(name, value)
mean = mean or 0.0
std = std or 0.0
return _ParamRowWithStats(
**dataclasses.asdict(row),
mean=float(jax.device_get(mean)),
Expand Down Expand Up @@ -156,12 +162,11 @@ def _get_parameter_rows(
params: Dictionary with parameters as NumPy arrays. The dictionary can be
nested. Alternatively a `tf.Module` can be provided, in which case the
`trainable_variables` of the module will be used.
include_stats: If True, add columns with mean and std for each variable.
If the string "sharding", add column a column with the sharding of the
variable.
If the string "global", params are sharded global arrays and this
function assumes it is called on every host, i.e. can use collectives.
The sharding of the variables is also added as a column.
include_stats: If True, add columns with mean and std for each variable. If
the string "sharding", add column a column with the sharding of the
variable. If the string "global", params are sharded global arrays and
this function assumes it is called on every host, i.e. can use
collectives. The sharding of the variables is also added as a column.

Returns:
A list of `ParamRow`, or `ParamRowWithStats`, depending on the passed value
Expand All @@ -185,12 +190,14 @@ def _get_parameter_rows(
case True:
mean_and_std = _mean_std(values)
return jax.tree_util.tree_map(
_make_row_with_stats, names, values, *mean_and_std)
_make_row_with_stats, names, values, *mean_and_std
)

case "global":
mean_and_std = _mean_std_jit(values)
return jax.tree_util.tree_map(
_make_row_with_stats_and_sharding, names, values, *mean_and_std)
_make_row_with_stats_and_sharding, names, values, *mean_and_std
)

case "sharding":
return jax.tree_util.tree_map(_make_row_with_sharding, names, values)
Expand Down Expand Up @@ -256,8 +263,7 @@ def __init__(self, name, values):
column_names = [field.name for field in dataclasses.fields(rows[0])]

columns = [
Column(name, [value_formatter(getattr(row, name))
for row in rows])
Column(name, [value_formatter(getattr(row, name)) for row in rows])
for name in column_names
]

Expand Down Expand Up @@ -312,12 +318,11 @@ def get_parameter_overview(
Args:
params: Dictionary with parameters as NumPy arrays. The dictionary can be
nested.
include_stats: If True, add columns with mean and std for each variable.
If the string "sharding", add column a column with the sharding of the
variable.
If the string "global", params are sharded global arrays and this
function assumes it is called on every host, i.e. can use collectives.
The sharding of the variables is also added as a column.
include_stats: If True, add columns with mean and std for each variable. If
the string "sharding", add column a column with the sharding of the
variable. If the string "global", params are sharded global arrays and
this function assumes it is called on every host, i.e. can use
collectives. The sharding of the variables is also added as a column.
max_lines: If not `None`, the maximum number of variables to include.

Returns:
Expand Down Expand Up @@ -375,16 +380,19 @@ def log_parameter_overview(
Args:
params: Dictionary with parameters as NumPy arrays. The dictionary can be
nested.
include_stats: If True, add columns with mean and std for each variable.
If the string "global", params are sharded global arrays and this
function assumes it is called on every host, i.e. can use collectives.
include_stats: If True, add columns with mean and std for each variable. If
the string "global", params are sharded global arrays and this function
assumes it is called on every host, i.e. can use collectives.
max_lines: If not `None`, the maximum number of variables to include.
msg: Message to be logged before the overview.
jax_logging_process: Which JAX process ID should do the logging. None = all.
Use this to avoid logspam when include_stats="global".
"""

_log_parameter_overview(
params, include_stats=include_stats, max_lines=max_lines, msg=msg,
jax_logging_process=jax_logging_process
params,
include_stats=include_stats,
max_lines=max_lines,
msg=msg,
jax_logging_process=jax_logging_process,
)