From 43acbbda92562e671279d153d0b191d6345d29c5 Mon Sep 17 00:00:00 2001 From: CLU Authors Date: Thu, 30 Jan 2025 05:35:39 -0800 Subject: [PATCH] Fix when variable is None for include_stats is True or False PiperOrigin-RevId: 721354960 --- clu/parameter_overview.py | 54 ++++++++++++++++++++++----------------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/clu/parameter_overview.py b/clu/parameter_overview.py index 3dc9337..8cdb921 100644 --- a/clu/parameter_overview.py +++ b/clu/parameter_overview.py @@ -82,13 +82,17 @@ def flatten_dict( def _count_parameters(params: _ParamsContainer) -> int: """Returns the count of variables for the module or parameter dictionary.""" params = flatten_dict(params) - return sum(np.prod(v.shape) for v in params.values()) + return sum(np.prod(v.shape) for v in params.values() if v is not None) def _parameters_size(params: _ParamsContainer) -> int: """Returns total size (bytes) for the module or parameter dictionary.""" params = flatten_dict(params) - return sum(np.prod(v.shape) * v.dtype.itemsize for v in params.values()) + return sum( + np.prod(v.shape) * v.dtype.itemsize + for v in params.values() + if v is not None + ) def count_parameters(params: _ParamsContainer) -> int: @@ -127,6 +131,8 @@ def _make_row_with_sharding(name, value) -> _ParamRowWithSharding: def _make_row_with_stats(name, value, mean, std) -> _ParamRowWithStats: row = _make_row(name, value) + mean = mean or 0.0 + std = std or 0.0 return _ParamRowWithStats( **dataclasses.asdict(row), mean=float(jax.device_get(mean)), @@ -156,12 +162,11 @@ def _get_parameter_rows( params: Dictionary with parameters as NumPy arrays. The dictionary can be nested. Alternatively a `tf.Module` can be provided, in which case the `trainable_variables` of the module will be used. - include_stats: If True, add columns with mean and std for each variable. - If the string "sharding", add column a column with the sharding of the - variable. - If the string "global", params are sharded global arrays and this - function assumes it is called on every host, i.e. can use collectives. - The sharding of the variables is also added as a column. + include_stats: If True, add columns with mean and std for each variable. If + the string "sharding", add column a column with the sharding of the + variable. If the string "global", params are sharded global arrays and + this function assumes it is called on every host, i.e. can use + collectives. The sharding of the variables is also added as a column. Returns: A list of `ParamRow`, or `ParamRowWithStats`, depending on the passed value @@ -185,12 +190,14 @@ def _get_parameter_rows( case True: mean_and_std = _mean_std(values) return jax.tree_util.tree_map( - _make_row_with_stats, names, values, *mean_and_std) + _make_row_with_stats, names, values, *mean_and_std + ) case "global": mean_and_std = _mean_std_jit(values) return jax.tree_util.tree_map( - _make_row_with_stats_and_sharding, names, values, *mean_and_std) + _make_row_with_stats_and_sharding, names, values, *mean_and_std + ) case "sharding": return jax.tree_util.tree_map(_make_row_with_sharding, names, values) @@ -256,8 +263,7 @@ def __init__(self, name, values): column_names = [field.name for field in dataclasses.fields(rows[0])] columns = [ - Column(name, [value_formatter(getattr(row, name)) - for row in rows]) + Column(name, [value_formatter(getattr(row, name)) for row in rows]) for name in column_names ] @@ -312,12 +318,11 @@ def get_parameter_overview( Args: params: Dictionary with parameters as NumPy arrays. The dictionary can be nested. - include_stats: If True, add columns with mean and std for each variable. - If the string "sharding", add column a column with the sharding of the - variable. - If the string "global", params are sharded global arrays and this - function assumes it is called on every host, i.e. can use collectives. - The sharding of the variables is also added as a column. + include_stats: If True, add columns with mean and std for each variable. If + the string "sharding", add column a column with the sharding of the + variable. If the string "global", params are sharded global arrays and + this function assumes it is called on every host, i.e. can use + collectives. The sharding of the variables is also added as a column. max_lines: If not `None`, the maximum number of variables to include. Returns: @@ -375,9 +380,9 @@ def log_parameter_overview( Args: params: Dictionary with parameters as NumPy arrays. The dictionary can be nested. - include_stats: If True, add columns with mean and std for each variable. - If the string "global", params are sharded global arrays and this - function assumes it is called on every host, i.e. can use collectives. + include_stats: If True, add columns with mean and std for each variable. If + the string "global", params are sharded global arrays and this function + assumes it is called on every host, i.e. can use collectives. max_lines: If not `None`, the maximum number of variables to include. msg: Message to be logged before the overview. jax_logging_process: Which JAX process ID should do the logging. None = all. @@ -385,6 +390,9 @@ def log_parameter_overview( """ _log_parameter_overview( - params, include_stats=include_stats, max_lines=max_lines, msg=msg, - jax_logging_process=jax_logging_process + params, + include_stats=include_stats, + max_lines=max_lines, + msg=msg, + jax_logging_process=jax_logging_process, )