@@ -175,6 +175,40 @@ def variables(self):
175175 return vars
176176
177177 def build (self , y_true , y_pred ):
178+ # Handle nested structures using tree utilities similar to CompileLoss
179+ if tree .is_nested (y_true ) and tree .is_nested (y_pred ):
180+ try :
181+ # Align the metrics configuration with the structure of y_pred
182+ if self ._has_nested_structure (self ._user_metrics ) or self ._has_nested_structure (self ._user_weighted_metrics ):
183+ self ._flat_metrics = self ._build_nested_metrics (
184+ self ._user_metrics , y_true , y_pred , "metrics"
185+ )
186+ self ._flat_weighted_metrics = self ._build_nested_metrics (
187+ self ._user_weighted_metrics , y_true , y_pred , "weighted_metrics"
188+ )
189+ else :
190+ self ._build_with_flat_structure (y_true , y_pred )
191+ except (ValueError , TypeError ):
192+ self ._build_with_flat_structure (y_true , y_pred )
193+ else :
194+ self ._build_with_flat_structure (y_true , y_pred )
195+ self .built = True
196+
197+ def _has_nested_structure (self , obj ):
198+ """Helper method to check if object has nested dict/list structure."""
199+ if obj is None :
200+ return False
201+ if isinstance (obj , dict ):
202+ for value in obj .values ():
203+ if isinstance (value , (dict , list )):
204+ return True
205+ elif isinstance (obj , list ):
206+ for item in obj :
207+ if isinstance (item , (dict , list )):
208+ return True
209+ return False
210+
211+ def _build_with_flat_structure (self , y_true , y_pred ):
178212 num_outputs = 1 # default
179213 # Resolve output names. If y_pred is a dict, prefer its keys.
180214 if isinstance (y_pred , dict ):
@@ -219,7 +253,193 @@ def build(self, y_true, y_pred):
219253 y_pred ,
220254 argument_name = "weighted_metrics" ,
221255 )
222- self .built = True
256+
257+ def _build_nested_metrics (self , metrics_config , y_true , y_pred , argument_name ):
258+ """Build metrics for nested structures following y_pred structure."""
259+ if metrics_config is None :
260+ # If metrics_config is None, create None placeholders for each output
261+ return self ._build_flat_placeholders (y_true , y_pred )
262+
263+ if (isinstance (metrics_config , dict ) and
264+ isinstance (y_pred , dict ) and
265+ set (metrics_config .keys ()).issubset (set (y_pred .keys ())) and
266+ not any (tree .is_nested (v ) for v in y_pred .values ())):
267+
268+ return self ._build_metrics_set_for_nested (metrics_config , y_true , y_pred , argument_name )
269+
270+ # Handle metrics configuration with tree structure similar to y_pred
271+ def build_recursive_metrics (metrics_cfg , yt , yp , path = (), is_nested_path = False ):
272+ """Recursively build metrics for nested structures."""
273+ if isinstance (metrics_cfg , dict ) and isinstance (yp , dict ):
274+ # Both metrics and predictions are dicts, process recursively
275+ flat_metrics = []
276+ for key in yp .keys ():
277+ current_path = path + (key ,)
278+ if key in metrics_cfg :
279+ if isinstance (yp [key ], dict ) and isinstance (metrics_cfg [key ], dict ):
280+ flat_metrics .extend (build_recursive_metrics (metrics_cfg [key ], yt [key ], yp [key ], current_path , True ))
281+ elif isinstance (yp [key ], (list , tuple )) and isinstance (metrics_cfg [key ], (list , tuple )):
282+ flat_metrics .extend (build_recursive_metrics (metrics_cfg [key ], yt [key ], yp [key ], current_path , True ))
283+ else :
284+ output_name = "_" .join (map (str , current_path )) if is_nested_path else None
285+ flat_metrics .append (self ._build_single_output_metrics (metrics_cfg [key ], yt [key ], yp [key ], argument_name , output_name = output_name ))
286+ else :
287+ flat_metrics .append (None )
288+ return flat_metrics
289+ elif isinstance (metrics_cfg , (list , tuple )) and isinstance (yp , (list , tuple )):
290+
291+ flat_metrics = []
292+ for i , (m_cfg , y_t_elem , y_p_elem ) in enumerate (zip (metrics_cfg , yt , yp )):
293+ current_path = path + (i ,)
294+ if isinstance (y_p_elem , (dict , list , tuple )) and isinstance (m_cfg , (dict , list , tuple )):
295+ flat_metrics .extend (build_recursive_metrics (m_cfg , y_t_elem , y_p_elem , current_path , True ))
296+ else :
297+ output_name = "_" .join (map (str , current_path )) if is_nested_path else None
298+ flat_metrics .append (self ._build_single_output_metrics (m_cfg , y_t_elem , y_p_elem , argument_name , output_name = output_name ))
299+ return flat_metrics
300+ else :
301+ output_name = "_" .join (map (str , path )) if path and is_nested_path else None
302+ return [self ._build_single_output_metrics (metrics_cfg , yt , yp , argument_name , output_name = output_name )]
303+
304+ # For truly complex nested structures, use recursive approach
305+ return build_recursive_metrics (metrics_config , y_true , y_pred )
306+
307+ def _build_single_output_metrics (self , metric_config , y_true , y_pred , argument_name , output_name = None ):
308+ """Build metrics for a single output."""
309+ if metric_config is None :
310+ return None
311+ elif not isinstance (metric_config , list ):
312+ metric_config = [metric_config ]
313+ if not all (is_function_like (m ) for m in metric_config ):
314+ raise ValueError (
315+ f"All entries in the sublists of the "
316+ f"`{ argument_name } ` structure should be metric objects. "
317+ f"Found the following with unknown types: { metric_config } "
318+ )
319+ return MetricsList (
320+ [
321+ get_metric (m , y_true , y_pred )
322+ for m in metric_config
323+ if m is not None
324+ ],
325+ output_name = output_name
326+ )
327+
328+ def _build_flat_placeholders (self , y_true , y_pred ):
329+ """Create None placeholders for each output when config is None."""
330+ flat_y_pred = tree .flatten (y_pred )
331+ return [None ] * len (flat_y_pred )
332+
333+ def _build_metrics_set_for_nested (self , metrics , y_true , y_pred , argument_name ):
334+ """Alternative method to build metrics when we detect nested structures."""
335+ flat_y_pred = tree .flatten (y_pred )
336+ flat_y_true = tree .flatten (y_true )
337+
338+ if isinstance (y_pred , dict ):
339+ flat_output_names = tree .flatten (y_pred )
340+ output_names = self ._flatten_dict_keys (y_pred )
341+ else :
342+ output_names = [None ] * len (flat_y_pred ) if self .output_names is None else self .output_names
343+
344+ # If metrics is a flat dict that should map to the outputs
345+ if isinstance (metrics , dict ):
346+ flat_metrics = []
347+ if isinstance (y_pred , dict ):
348+ # Map metrics dict to y_pred dict keys
349+ for idx , (name , yt , yp ) in enumerate (zip (y_pred .keys (), flat_y_true , flat_y_pred )):
350+ if name in metrics :
351+ metric_list = metrics [name ]
352+ if not isinstance (metric_list , list ):
353+ metric_list = [metric_list ]
354+ if not all (is_function_like (e ) for e in metric_list ):
355+ raise ValueError (
356+ f"All entries in the sublists of the "
357+ f"`{ argument_name } ` dict should be metric objects. "
358+ f"At key '{ name } ', found the following with unknown types: { metric_list } "
359+ )
360+ flat_metrics .append (
361+ MetricsList (
362+ [
363+ get_metric (m , yt , yp )
364+ for m in metric_list
365+ if m is not None
366+ ],
367+ output_name = name ,
368+ )
369+ )
370+ else :
371+ flat_metrics .append (None )
372+ else :
373+ return self ._build_metrics_set (metrics , len (flat_y_pred ), output_names , flat_y_true , flat_y_pred , argument_name )
374+ elif isinstance (metrics , (list , tuple )):
375+ # Handle list/tuple case for nested outputs
376+ if len (metrics ) != len (flat_y_pred ):
377+ raise ValueError (
378+ f"For a model with multiple outputs, "
379+ f"when providing the `{ argument_name } ` argument as a "
380+ f"list, it should have as many entries as the model has "
381+ f"outputs. Received:\n { argument_name } ={ metrics } \n of "
382+ f"length { len (metrics )} whereas the model has "
383+ f"{ len (flat_y_pred )} outputs."
384+ )
385+ flat_metrics = []
386+ for idx , (mls , yt , yp ) in enumerate (zip (metrics , flat_y_true , flat_y_pred )):
387+ if not isinstance (mls , list ):
388+ mls = [mls ]
389+ name = output_names [idx ] if output_names and idx < len (output_names ) else None
390+ if not all (is_function_like (e ) for e in mls ):
391+ raise ValueError (
392+ f"All entries in the sublists of the "
393+ f"`{ argument_name } ` list should be metric objects. "
394+ f"Found the following sublist with unknown types: { mls } "
395+ )
396+ flat_metrics .append (
397+ MetricsList (
398+ [
399+ get_metric (m , yt , yp )
400+ for m in mls
401+ if m is not None
402+ ],
403+ output_name = name ,
404+ )
405+ )
406+ else :
407+ # Handle single metric applied to all outputs
408+ flat_metrics = []
409+ for idx , (yt , yp ) in enumerate (zip (flat_y_true , flat_y_pred )):
410+ name = output_names [idx ] if output_names and idx < len (output_names ) else None
411+ if metrics is None :
412+ flat_metrics .append (None )
413+ else :
414+ if not is_function_like (metrics ):
415+ raise ValueError (
416+ f"Expected all entries in the `{ argument_name } ` list "
417+ f"to be metric objects. Received instead:\n "
418+ f"{ argument_name } ={ metrics } "
419+ )
420+ flat_metrics .append (
421+ MetricsList (
422+ [get_metric (metrics , yt , yp )],
423+ output_name = name ,
424+ )
425+ )
426+
427+ return flat_metrics
428+
429+ def _flatten_dict_keys (self , d ):
430+ """Flatten dict to get key names in order."""
431+ if isinstance (d , dict ):
432+ return list (d .keys ())
433+ elif isinstance (d , (list , tuple )):
434+ result = []
435+ for item in d :
436+ if isinstance (item , dict ):
437+ result .extend (list (item .keys ()))
438+ else :
439+ result .append (None )
440+ return result
441+ else :
442+ return [None ]
223443
224444 def _build_metrics_set (
225445 self , metrics , num_outputs , output_names , y_true , y_pred , argument_name
0 commit comments