From 017c2e3bb8a252f47a4f30d5d80db788d5fcdf9c Mon Sep 17 00:00:00 2001
From: Dimios45
Date: Mon, 20 Oct 2025 19:22:07 +0530
Subject: [PATCH] Add nested structure metrics support to compile_utils

---
 keras/src/trainers/compile_utils.py      | 222 ++++++++++++++++++++++-
 keras/src/trainers/compile_utils_test.py |  95 ++++++++++
 2 files changed, 316 insertions(+), 1 deletion(-)

diff --git a/keras/src/trainers/compile_utils.py b/keras/src/trainers/compile_utils.py
index d911aa805ca0..c4b855b38f1d 100644
--- a/keras/src/trainers/compile_utils.py
+++ b/keras/src/trainers/compile_utils.py
@@ -175,6 +175,40 @@ def variables(self):
         return vars
 
     def build(self, y_true, y_pred):
+        # Handle nested structures using tree utilities similar to CompileLoss
+        if tree.is_nested(y_true) and tree.is_nested(y_pred):
+            try:
+                # Align the metrics configuration with the structure of y_pred
+                if self._has_nested_structure(self._user_metrics) or self._has_nested_structure(self._user_weighted_metrics):
+                    self._flat_metrics = self._build_nested_metrics(
+                        self._user_metrics, y_true, y_pred, "metrics"
+                    )
+                    self._flat_weighted_metrics = self._build_nested_metrics(
+                        self._user_weighted_metrics, y_true, y_pred, "weighted_metrics"
+                    )
+                else:
+                    self._build_with_flat_structure(y_true, y_pred)
+            except (ValueError, TypeError):
+                self._build_with_flat_structure(y_true, y_pred)
+        else:
+            self._build_with_flat_structure(y_true, y_pred)
+        self.built = True
+
+    def _has_nested_structure(self, obj):
+        """Helper method to check if object has nested dict/list structure."""
+        if obj is None:
+            return False
+        if isinstance(obj, dict):
+            for value in obj.values():
+                if isinstance(value, (dict, list)):
+                    return True
+        elif isinstance(obj, list):
+            for item in obj:
+                if isinstance(item, (dict, list)):
+                    return True
+        return False
+
+    def _build_with_flat_structure(self, y_true, y_pred):
         num_outputs = 1  # default
         # Resolve output names. If y_pred is a dict, prefer its keys.
         if isinstance(y_pred, dict):
@@ -219,7 +253,193 @@ def build(self, y_true, y_pred):
             y_pred,
             argument_name="weighted_metrics",
         )
-        self.built = True
+
+    def _build_nested_metrics(self, metrics_config, y_true, y_pred, argument_name):
+        """Build metrics for nested structures following y_pred structure."""
+        if metrics_config is None:
+            # If metrics_config is None, create None placeholders for each output
+            return self._build_flat_placeholders(y_true, y_pred)
+
+        if (isinstance(metrics_config, dict) and
+                isinstance(y_pred, dict) and
+                set(metrics_config.keys()).issubset(set(y_pred.keys())) and
+                not any(tree.is_nested(v) for v in y_pred.values())):
+
+            return self._build_metrics_set_for_nested(metrics_config, y_true, y_pred, argument_name)
+
+        # Handle metrics configuration with tree structure similar to y_pred
+        def build_recursive_metrics(metrics_cfg, yt, yp, path=(), is_nested_path=False):
+            """Recursively build metrics for nested structures."""
+            if isinstance(metrics_cfg, dict) and isinstance(yp, dict):
+                # Both metrics and predictions are dicts, process recursively
+                flat_metrics = []
+                for key in yp.keys():
+                    current_path = path + (key,)
+                    if key in metrics_cfg:
+                        if isinstance(yp[key], dict) and isinstance(metrics_cfg[key], dict):
+                            flat_metrics.extend(build_recursive_metrics(metrics_cfg[key], yt[key], yp[key], current_path, True))
+                        elif isinstance(yp[key], (list, tuple)) and isinstance(metrics_cfg[key], (list, tuple)):
+                            flat_metrics.extend(build_recursive_metrics(metrics_cfg[key], yt[key], yp[key], current_path, True))
+                        else:
+                            output_name = "_".join(map(str, current_path)) if is_nested_path else None
+                            flat_metrics.append(self._build_single_output_metrics(metrics_cfg[key], yt[key], yp[key], argument_name, output_name=output_name))
+                    else:
+                        flat_metrics.append(None)
+                return flat_metrics
+            elif isinstance(metrics_cfg, (list, tuple)) and isinstance(yp, (list, tuple)):
+
+                flat_metrics = []
+                for i, (m_cfg, y_t_elem, y_p_elem) in enumerate(zip(metrics_cfg, yt, yp)):
+                    current_path = path + (i,)
+                    if isinstance(y_p_elem, (dict, list, tuple)) and isinstance(m_cfg, (dict, list, tuple)):
+                        flat_metrics.extend(build_recursive_metrics(m_cfg, y_t_elem, y_p_elem, current_path, True))
+                    else:
+                        output_name = "_".join(map(str, current_path)) if is_nested_path else None
+                        flat_metrics.append(self._build_single_output_metrics(m_cfg, y_t_elem, y_p_elem, argument_name, output_name=output_name))
+                return flat_metrics
+            else:
+                output_name = "_".join(map(str, path)) if path and is_nested_path else None
+                return [self._build_single_output_metrics(metrics_cfg, yt, yp, argument_name, output_name=output_name)]
+
+        # For truly complex nested structures, use recursive approach
+        return build_recursive_metrics(metrics_config, y_true, y_pred)
+
+    def _build_single_output_metrics(self, metric_config, y_true, y_pred, argument_name, output_name=None):
+        """Build metrics for a single output."""
+        if metric_config is None:
+            return None
+        elif not isinstance(metric_config, list):
+            metric_config = [metric_config]
+        if not all(is_function_like(m) for m in metric_config):
+            raise ValueError(
+                f"All entries in the sublists of the "
+                f"`{argument_name}` structure should be metric objects. "
" + f"Found the following with unknown types: {metric_config}" + ) + return MetricsList( + [ + get_metric(m, y_true, y_pred) + for m in metric_config + if m is not None + ], + output_name=output_name + ) + + def _build_flat_placeholders(self, y_true, y_pred): + """Create None placeholders for each output when config is None.""" + flat_y_pred = tree.flatten(y_pred) + return [None] * len(flat_y_pred) + + def _build_metrics_set_for_nested(self, metrics, y_true, y_pred, argument_name): + """Alternative method to build metrics when we detect nested structures.""" + flat_y_pred = tree.flatten(y_pred) + flat_y_true = tree.flatten(y_true) + + if isinstance(y_pred, dict): + flat_output_names = tree.flatten(y_pred) + output_names = self._flatten_dict_keys(y_pred) + else: + output_names = [None] * len(flat_y_pred) if self.output_names is None else self.output_names + + # If metrics is a flat dict that should map to the outputs + if isinstance(metrics, dict): + flat_metrics = [] + if isinstance(y_pred, dict): + # Map metrics dict to y_pred dict keys + for idx, (name, yt, yp) in enumerate(zip(y_pred.keys(), flat_y_true, flat_y_pred)): + if name in metrics: + metric_list = metrics[name] + if not isinstance(metric_list, list): + metric_list = [metric_list] + if not all(is_function_like(e) for e in metric_list): + raise ValueError( + f"All entries in the sublists of the " + f"`{argument_name}` dict should be metric objects. " + f"At key '{name}', found the following with unknown types: {metric_list}" + ) + flat_metrics.append( + MetricsList( + [ + get_metric(m, yt, yp) + for m in metric_list + if m is not None + ], + output_name=name, + ) + ) + else: + flat_metrics.append(None) + else: + return self._build_metrics_set(metrics, len(flat_y_pred), output_names, flat_y_true, flat_y_pred, argument_name) + elif isinstance(metrics, (list, tuple)): + # Handle list/tuple case for nested outputs + if len(metrics) != len(flat_y_pred): + raise ValueError( + f"For a model with multiple outputs, " + f"when providing the `{argument_name}` argument as a " + f"list, it should have as many entries as the model has " + f"outputs. Received:\n{argument_name}={metrics}\nof " + f"length {len(metrics)} whereas the model has " + f"{len(flat_y_pred)} outputs." + ) + flat_metrics = [] + for idx, (mls, yt, yp) in enumerate(zip(metrics, flat_y_true, flat_y_pred)): + if not isinstance(mls, list): + mls = [mls] + name = output_names[idx] if output_names and idx < len(output_names) else None + if not all(is_function_like(e) for e in mls): + raise ValueError( + f"All entries in the sublists of the " + f"`{argument_name}` list should be metric objects. " + f"Found the following sublist with unknown types: {mls}" + ) + flat_metrics.append( + MetricsList( + [ + get_metric(m, yt, yp) + for m in mls + if m is not None + ], + output_name=name, + ) + ) + else: + # Handle single metric applied to all outputs + flat_metrics = [] + for idx, (yt, yp) in enumerate(zip(flat_y_true, flat_y_pred)): + name = output_names[idx] if output_names and idx < len(output_names) else None + if metrics is None: + flat_metrics.append(None) + else: + if not is_function_like(metrics): + raise ValueError( + f"Expected all entries in the `{argument_name}` list " + f"to be metric objects. 
+                            f"{argument_name}={metrics}"
+                        )
+                    flat_metrics.append(
+                        MetricsList(
+                            [get_metric(metrics, yt, yp)],
+                            output_name=name,
+                        )
+                    )
+
+        return flat_metrics
+
+    def _flatten_dict_keys(self, d):
+        """Flatten dict to get key names in order."""
+        if isinstance(d, dict):
+            return list(d.keys())
+        elif isinstance(d, (list, tuple)):
+            result = []
+            for item in d:
+                if isinstance(item, dict):
+                    result.extend(list(item.keys()))
+                else:
+                    result.append(None)
+            return result
+        else:
+            return [None]
 
     def _build_metrics_set(
         self, metrics, num_outputs, output_names, y_true, y_pred, argument_name
diff --git a/keras/src/trainers/compile_utils_test.py b/keras/src/trainers/compile_utils_test.py
index d27c5292b63d..630365071f06 100644
--- a/keras/src/trainers/compile_utils_test.py
+++ b/keras/src/trainers/compile_utils_test.py
@@ -538,6 +538,101 @@ def test_struct_loss_namedtuple(self):
         value = compile_loss(y_true, y_pred)
         self.assertAllClose(value, 1.07666, atol=1e-5)
 
+    def test_nested_dict_metrics(self):
+        import numpy as np
+        from keras.src import Input
+        from keras.src import Model
+        from keras.src import layers
+        from keras.src import metrics as metrics_module
+
+        # Create test data matching the nested structure
+        y_true = {
+            'a': np.random.rand(10, 10),
+            'b': {
+                'c': np.random.rand(10, 10),
+                'd': np.random.rand(10, 10)
+            }
+        }
+        y_pred = {
+            'a': np.random.rand(10, 10),
+            'b': {
+                'c': np.random.rand(10, 10),
+                'd': np.random.rand(10, 10)
+            }
+        }
+
+        # Test compiling with nested dict metrics
+        compile_metrics = CompileMetrics(
+            metrics={
+                'a': [metrics_module.MeanSquaredError()],
+                'b': {
+                    'c': [metrics_module.MeanSquaredError(), metrics_module.MeanAbsoluteError()],
+                    'd': [metrics_module.MeanSquaredError()]
+                }
+            },
+            weighted_metrics=None,
+        )
+
+        # Build the metrics
+        compile_metrics.build(y_true, y_pred)
+
+        # Update state and get results
+        compile_metrics.update_state(y_true, y_pred, sample_weight=None)
+        result = compile_metrics.result()
+
+        # Check that expected metrics are present
+        # The actual names might be different based on how MetricsList handles output names
+        expected_metric_names = []
+        for key in result.keys():
+            if 'mean_squared_error' in key or 'mean_absolute_error' in key:
+                expected_metric_names.append(key)
+
+        # At least some metrics should be computed
+        self.assertGreater(len(expected_metric_names), 0)
+
+        # Test with symbolic tensors as well
+        y_true_symb = tree.map_structure(lambda _: backend.KerasTensor((10, 10)), y_true)
+        y_pred_symb = tree.map_structure(lambda _: backend.KerasTensor((10, 10)), y_pred)
+        compile_metrics_symbolic = CompileMetrics(
+            metrics={
+                'a': [metrics_module.MeanSquaredError()],
+                'b': {
+                    'c': [metrics_module.MeanSquaredError(), metrics_module.MeanAbsoluteError()],
+                    'd': [metrics_module.MeanSquaredError()]
+                }
+            },
+            weighted_metrics=None,
+        )
+        compile_metrics_symbolic.build(y_true_symb, y_pred_symb)
+        self.assertTrue(compile_metrics_symbolic.built)
+
+
+    def test_nested_dict_metrics_end_to_end(self):
+        import numpy as np
+        from keras.src import layers
+        from keras.src import Input
+        from keras.src import Model
+
+        X = np.random.rand(100, 32)
+        Y1 = np.random.rand(100, 10)
+        Y2 = np.random.rand(100, 10)
+        Y3 = np.random.rand(100, 10)
+
+        def create_model():
+            x = Input(shape=(32,))
+            y1 = layers.Dense(10)(x)
+            y2 = layers.Dense(10)(x)
+            y3 = layers.Dense(10)(x)
+            return Model(inputs=x, outputs={'a': y1, 'b': {'c': y2, 'd': y3}})
+
+        model = create_model()
+        model.compile(
+            optimizer='adam',
+            loss={'a': 'mse', 'b': {'c': 'mse', 'd': 'mse'}},
+            metrics={'a': ['mae'], 'b': {'c': 'mse', 'd': 'mae'}},
+        )
+        model.train_on_batch(X, {'a': Y1, 'b': {'c': Y2, 'd': Y3}})
+
     def test_struct_loss_invalid_path(self):
         y_true = {
             "a": {
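Note on the naming scheme introduced by build_recursive_metrics above: keys along a nested path are joined with underscores to form the MetricsList output_name, while non-nested top-level outputs keep output_name=None. Below is a minimal standalone sketch of that path-to-name mapping for reference only; it is not part of the patch, and the helper name metric_output_names is made up for this illustration.

# Illustrative sketch, not part of the patch: mirrors the dict branch of
# build_recursive_metrics. Nested keys are joined with "_"; flat top-level
# outputs get no explicit name (None).
def metric_output_names(y_pred, path=(), is_nested_path=False):
    names = []
    if isinstance(y_pred, dict):
        for key, value in y_pred.items():
            current_path = path + (key,)
            if isinstance(value, dict):
                names.extend(metric_output_names(value, current_path, True))
            else:
                names.append(
                    "_".join(map(str, current_path)) if is_nested_path else None
                )
    return names

print(metric_output_names({"a": 0, "b": {"c": 0, "d": 0}}))
# [None, 'b_c', 'b_d']

How these names surface in the final result keys still depends on how MetricsList and CompileMetrics.result() compose output_name with each metric's name, as the test above also notes.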