Skip to content

Commit

Permalink
Move locking logic to snapshotting.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 666040622
  • Loading branch information
liangyaning33 authored and Orbax Authors committed Aug 22, 2024
1 parent 5ae048a commit 24cfcd2
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 3 deletions.
6 changes: 3 additions & 3 deletions checkpoint/orbax/checkpoint/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from orbax.checkpoint import utils
from orbax.checkpoint.metadata import value as value_metadata
from orbax.checkpoint.path import step as step_lib

from orbax.checkpoint.path import utils as path_utils

PyTree = Any
STANDARD_ARRAY_TYPES = (int, float, np.ndarray, jax.Array)
Expand All @@ -49,7 +49,7 @@ def _lock_checkpoint(
step: int,
step_name_format: step_lib.NameFormat[step_lib.Metadata],
) -> bool:
"""Locks a checkpoint by writing a LOCKED directory."""
"""Creates a snapshot for CNS and a LOCKED file for non-CNS paths."""
logging.info('Locking step: %d before gaining control.', step)
step_dir = step_name_format.find_step(checkpoint_dir, step).path
if not step_dir.exists():
Expand All @@ -73,7 +73,7 @@ def _unlock_checkpoint(
step: int,
step_name_format: step_lib.NameFormat[step_lib.Metadata],
):
"""Removes a LOCKED directory to indicate unlocking."""
"""Deletes a snapshot for CNS and removes a LOCKED directory for non-CNS."""
if multihost.process_index() == 0:
logging.info('Unlocking existing step: %d.', step)
step_dir = step_name_format.find_step(checkpoint_dir, step).path
Expand Down
9 changes: 9 additions & 0 deletions checkpoint/orbax/checkpoint/checkpoint_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"""Tests for checkpoint_utils."""

import functools
from unittest import mock
from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
from etils import epath
Expand Down Expand Up @@ -296,6 +298,13 @@ def test_unlock_existing(self):
)
self.assertTrue(utils.is_locked(self.directory / str(0)))
self.assertFalse(utils.is_locked(self.directory / str(1)))
checkpoint_utils._unlock_checkpoint(
self.directory,
step=0,
step_name_format=step_lib.standard_name_format(
step_prefix=None, step_format_fixed_length=None
),
)

for _ in checkpoint_utils.checkpoints_iterator(self.directory):
break
Expand Down
1 change: 1 addition & 0 deletions checkpoint/orbax/checkpoint/path/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# outside Orbax (ignoring OSS).

import asyncio
import os
from typing import List, Optional, Tuple

from etils import epath
Expand Down

0 comments on commit 24cfcd2

Please sign in to comment.