From fe7c41267fb2e7bec4534ff6c3280140d6ae65fa Mon Sep 17 00:00:00 2001 From: Victor Bourgin Date: Mon, 14 Apr 2025 19:37:23 -0700 Subject: [PATCH] Sanitize Metric Name in Checkpoints Summary: # Context Metric names may be included in checkpoint names when specifying a `best_checkpoint_config`, but no verification is done on the metric name. This may lead to nested directory structures if the metric name contains `/`, e.g.: f721785233 Here we use `top1_accuracy/evaluate` as the `monitored_metric`, which will create checkpoints in a nested directory: {F1977112918} Checkpointers won't be able to correctly restore the checkpoint with the best monitored metric, as each checkpoint will be stored in a different directory. # Proposed change In this diff, we sanitize the metric name prior to checkpoint saving, replacing `/` with `_`. Now, checkpoints are saved in the same directory: f721793003 {F1977113027} Differential Revision: D73004419 --- tests/utils/test_checkpoint.py | 15 +++++++++++++++ torchtnt/utils/checkpoint.py | 29 +++++++++++++++++++++++++---- 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/tests/utils/test_checkpoint.py b/tests/utils/test_checkpoint.py index fee88af6bc..695a2d0177 100644 --- a/tests/utils/test_checkpoint.py +++ b/tests/utils/test_checkpoint.py @@ -1139,6 +1139,12 @@ def test_best_checkpoint_path(self) -> None: best_path, ) + # apply sanitation + self.assertEqual( + get_best_checkpoint_path(temp_dir, "val/loss", "min"), + best_path, + ) + # handle negative values best_path_2 = os.path.join(temp_dir, "epoch_0_step_0_val_loss=-0.01") os.mkdir(best_path_2) @@ -1373,6 +1379,15 @@ def test_get_checkpoint_dirpaths(self) -> None: {path1, path2, path3}, ) + # with metric name sanitation + self.assertEqual( + { + str(x) + for x in get_checkpoint_dirpaths(temp_dir, metric_name="val/loss") + }, + {path1, path2, path3}, + ) + with tempfile.TemporaryDirectory() as temp_dir: self.assertEqual( get_checkpoint_dirpaths(temp_dir), diff --git 
a/torchtnt/utils/checkpoint.py b/torchtnt/utils/checkpoint.py index 53e8f95ac6..7f05f08bde 100644 --- a/torchtnt/utils/checkpoint.py +++ b/torchtnt/utils/checkpoint.py @@ -28,12 +28,27 @@ @dataclass class MetricData: """ - Representation of a metric instance. Should provide both a metric name and it's value. + Representation of a metric instance. Should provide both a metric name and its value. + + Note: The metric name is sanitized by replacing '/' with '_' to prevent potential issues + when using the name as a path or identifier. """ name: str value: float + def __init__(self, name: str, value: float) -> None: + self.name = MetricData.sanitize_metric_name(name) + self.value = value + + @classmethod + def sanitize_metric_name(cls, name: str) -> str: + """ + Sanitizes a metric name by replacing '/' with '_'. + This is done to prevent potential issues when using the name as a path or identifier. + """ + return name.replace("/", "_") + @dataclass class BestCheckpointConfig: @@ -481,9 +496,14 @@ def generate_checkpoint_path( self._best_checkpoint_config ), "Attempted to get a checkpoint with metric but best checkpoint config is not set" - assert self._best_checkpoint_config.monitored_metric == metric_data.name, ( + assert ( + MetricData.sanitize_metric_name( + self._best_checkpoint_config.monitored_metric + ) + == metric_data.name + ), ( f"Attempted to get a checkpoint with metric '{metric_data.name}', " - f"but best checkpoint config is for '{none_throws(self._best_checkpoint_config).monitored_metric}'" + f"but best checkpoint config is for '{MetricData.sanitize_metric_name(none_throws(self._best_checkpoint_config).monitored_metric)}'" ) checkpoint_path = CheckpointPath( @@ -815,7 +835,8 @@ def _retrieve_checkpoint_dirpaths( # If a metric was provided, keep only the checkpoints tracking it if metric_name and not ( - ckpt.metric_data and ckpt.metric_data.name == metric_name + ckpt.metric_data + and ckpt.metric_data.name == MetricData.sanitize_metric_name(metric_name) 
): continue