@@ -19,9 +19,9 @@ class ChebaiBaseNet(LightningModule):
19
19
Args:
20
20
criterion (torch.nn.Module, optional): The loss criterion for the model. Defaults to None.
21
21
out_dim (int, optional): The output dimension of the model. Defaults to None.
22
- train_metrics (Dict[str, Metric] , optional): The metrics to be used during training. Defaults to None.
23
- val_metrics (Dict[str, Metric] , optional): The metrics to be used during validation. Defaults to None.
24
- test_metrics (Dict[str, Metric] , optional): The metrics to be used during testing. Defaults to None.
22
+ train_metrics (torch.nn.Module , optional): The metrics to be used during training. Defaults to None.
23
+ val_metrics (torch.nn.Module , optional): The metrics to be used during validation. Defaults to None.
24
+ test_metrics (torch.nn.Module , optional): The metrics to be used during testing. Defaults to None.
25
25
pass_loss_kwargs (bool, optional): Whether to pass loss kwargs to the criterion. Defaults to True.
26
26
optimizer_kwargs (Dict[str, Any], optional): Additional keyword arguments for the optimizer. Defaults to None.
27
27
**kwargs: Additional keyword arguments.
@@ -36,9 +36,9 @@ def __init__(
36
36
self ,
37
37
criterion : torch .nn .Module = None ,
38
38
out_dim : Optional [int ] = None ,
39
- train_metrics : Optional [Dict [ str , Metric ] ] = None ,
40
- val_metrics : Optional [Dict [ str , Metric ] ] = None ,
41
- test_metrics : Optional [Dict [ str , Metric ] ] = None ,
39
+ train_metrics : Optional [torch . nn . Module ] = None ,
40
+ val_metrics : Optional [torch . nn . Module ] = None ,
41
+ test_metrics : Optional [torch . nn . Module ] = None ,
42
42
pass_loss_kwargs : bool = True ,
43
43
optimizer_kwargs : Optional [Dict [str , Any ]] = None ,
44
44
** kwargs ,
@@ -207,7 +207,7 @@ def _execute(
207
207
self ,
208
208
batch : XYData ,
209
209
batch_idx : int ,
210
- metrics : Dict [ str , Metric ] ,
210
+ metrics : Optional [ torch . nn . Module ] = None ,
211
211
prefix : Optional [str ] = "" ,
212
212
log : Optional [bool ] = True ,
213
213
sync_dist : Optional [bool ] = False ,
@@ -218,7 +218,7 @@ def _execute(
218
218
Args:
219
219
batch (XYData): The input batch of data.
220
220
batch_idx (int): The index of the current batch.
221
- metrics (Dict[str, Metric] ): A dictionary of metrics to track.
221
+ metrics (torch.nn.Module, optional): The module holding the metrics to track. Defaults to None.
222
222
prefix (str, optional): A prefix to add to the metric names. Defaults to "".
223
223
log (bool, optional): Whether to log the metrics. Defaults to True.
224
224
sync_dist (bool, optional): Whether to synchronize distributed training. Defaults to False.
@@ -275,13 +275,13 @@ def _execute(
275
275
self ._log_metrics (prefix , metrics , len (batch ))
276
276
return d
277
277
278
- def _log_metrics (self , prefix : str , metrics : Dict [ str , Metric ] , batch_size : int ):
278
+ def _log_metrics (self , prefix : str , metrics : torch . nn . Module , batch_size : int ):
279
279
"""
280
280
Logs the metrics for the given prefix.
281
281
282
282
Args:
283
283
prefix (str): The prefix to be added to the metric names.
284
- metrics (Dict[str, Metric] ): A dictionary containing the metrics to be logged.
284
+ metrics (torch.nn.Module): The module holding the metrics to be logged.
285
285
batch_size (int): The batch size used for logging.
286
286
287
287
Returns:
0 commit comments