diff --git a/tests/utils/test_checkpoint.py b/tests/utils/test_checkpoint.py index fee88af6bc..695a2d0177 100644 --- a/tests/utils/test_checkpoint.py +++ b/tests/utils/test_checkpoint.py @@ -1139,6 +1139,12 @@ def test_best_checkpoint_path(self) -> None: best_path, ) + # apply sanitation + self.assertEqual( + get_best_checkpoint_path(temp_dir, "val/loss", "min"), + best_path, + ) + # handle negative values best_path_2 = os.path.join(temp_dir, "epoch_0_step_0_val_loss=-0.01") os.mkdir(best_path_2) @@ -1373,6 +1379,15 @@ def test_get_checkpoint_dirpaths(self) -> None: {path1, path2, path3}, ) + # with metric name sanitation + self.assertEqual( + { + str(x) + for x in get_checkpoint_dirpaths(temp_dir, metric_name="val/loss") + }, + {path1, path2, path3}, + ) + with tempfile.TemporaryDirectory() as temp_dir: self.assertEqual( get_checkpoint_dirpaths(temp_dir), diff --git a/torchtnt/utils/checkpoint.py b/torchtnt/utils/checkpoint.py index 53e8f95ac6..7f05f08bde 100644 --- a/torchtnt/utils/checkpoint.py +++ b/torchtnt/utils/checkpoint.py @@ -28,12 +28,27 @@ @dataclass class MetricData: """ - Representation of a metric instance. Should provide both a metric name and it's value. + Representation of a metric instance. Should provide both a metric name and its value. + + Note: The metric name is sanitized by replacing '/' with '_' to prevent potential issues + when using the name as a path or identifier. """ name: str value: float + def __init__(self, name: str, value: float) -> None: + self.name = MetricData.sanitize_metric_name(name) + self.value = value + + @classmethod + def sanitize_metric_name(cls, name: str) -> str: + """ + Sanitizes a metric name by replacing '/' with '_'. + This is done to prevent potential issues when using the name as a path or identifier. + """ + return name.replace("/", "_") + @dataclass class BestCheckpointConfig: @@ -481,9 +496,14 @@ def generate_checkpoint_path( self._best_checkpoint_config ), "Attempted to get a checkpoint with metric but best checkpoint config is not set" - assert self._best_checkpoint_config.monitored_metric == metric_data.name, ( + assert ( + MetricData.sanitize_metric_name( + self._best_checkpoint_config.monitored_metric + ) + == metric_data.name + ), ( f"Attempted to get a checkpoint with metric '{metric_data.name}', " - f"but best checkpoint config is for '{none_throws(self._best_checkpoint_config).monitored_metric}'" + f"but best checkpoint config is for '{MetricData.sanitize_metric_name(none_throws(self._best_checkpoint_config).monitored_metric)}'" ) checkpoint_path = CheckpointPath( @@ -815,7 +835,8 @@ def _retrieve_checkpoint_dirpaths( # If a metric was provided, keep only the checkpoints tracking it if metric_name and not ( - ckpt.metric_data and ckpt.metric_data.name == metric_name + ckpt.metric_data + and ckpt.metric_data.name == MetricData.sanitize_metric_name(metric_name) ): continue