Skip to content

Commit e367a7c

Browse files
author
sfluegel
committed
fix typehints for model metrics
1 parent 1990183 commit e367a7c

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

chebai/models/base.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ class ChebaiBaseNet(LightningModule):
1919
Args:
2020
criterion (torch.nn.Module, optional): The loss criterion for the model. Defaults to None.
2121
out_dim (int, optional): The output dimension of the model. Defaults to None.
22-
train_metrics (Dict[str, Metric], optional): The metrics to be used during training. Defaults to None.
23-
val_metrics (Dict[str, Metric], optional): The metrics to be used during validation. Defaults to None.
24-
test_metrics (Dict[str, Metric], optional): The metrics to be used during testing. Defaults to None.
22+
train_metrics (torch.nn.Module, optional): The metrics to be used during training. Defaults to None.
23+
val_metrics (torch.nn.Module, optional): The metrics to be used during validation. Defaults to None.
24+
test_metrics (torch.nn.Module, optional): The metrics to be used during testing. Defaults to None.
2525
pass_loss_kwargs (bool, optional): Whether to pass loss kwargs to the criterion. Defaults to True.
2626
optimizer_kwargs (Dict[str, Any], optional): Additional keyword arguments for the optimizer. Defaults to None.
2727
**kwargs: Additional keyword arguments.
@@ -36,9 +36,9 @@ def __init__(
3636
self,
3737
criterion: torch.nn.Module = None,
3838
out_dim: Optional[int] = None,
39-
train_metrics: Optional[Dict[str, Metric]] = None,
40-
val_metrics: Optional[Dict[str, Metric]] = None,
41-
test_metrics: Optional[Dict[str, Metric]] = None,
39+
train_metrics: Optional[torch.nn.Module] = None,
40+
val_metrics: Optional[torch.nn.Module] = None,
41+
test_metrics: Optional[torch.nn.Module] = None,
4242
pass_loss_kwargs: bool = True,
4343
optimizer_kwargs: Optional[Dict[str, Any]] = None,
4444
**kwargs,
@@ -207,7 +207,7 @@ def _execute(
207207
self,
208208
batch: XYData,
209209
batch_idx: int,
210-
metrics: Dict[str, Metric],
210+
metrics: Optional[torch.nn.Module] = None,
211211
prefix: Optional[str] = "",
212212
log: Optional[bool] = True,
213213
sync_dist: Optional[bool] = False,
@@ -218,7 +218,7 @@ def _execute(
218218
Args:
219219
batch (XYData): The input batch of data.
220220
batch_idx (int): The index of the current batch.
221-
metrics (Dict[str, Metric]): A dictionary of metrics to track.
221+
metrics (torch.nn.Module): A module (e.g. a torchmetrics MetricCollection) containing the metrics to track.
222222
prefix (str, optional): A prefix to add to the metric names. Defaults to "".
223223
log (bool, optional): Whether to log the metrics. Defaults to True.
224224
sync_dist (bool, optional): Whether to synchronize distributed training. Defaults to False.
@@ -275,13 +275,13 @@ def _execute(
275275
self._log_metrics(prefix, metrics, len(batch))
276276
return d
277277

278-
def _log_metrics(self, prefix: str, metrics: Dict[str, Metric], batch_size: int):
278+
def _log_metrics(self, prefix: str, metrics: torch.nn.Module, batch_size: int):
279279
"""
280280
Logs the metrics for the given prefix.
281281
282282
Args:
283283
prefix (str): The prefix to be added to the metric names.
284-
metrics (Dict[str, Metric]): A dictionary containing the metrics to be logged.
284+
metrics (torch.nn.Module): A module (e.g. a torchmetrics MetricCollection) containing the metrics to be logged.
285285
batch_size (int): The batch size used for logging.
286286
287287
Returns:

0 commit comments

Comments
 (0)