Skip to content

Commit 1eb561d

Browse files
committed
Add checkpoint manager for resume capability
1 parent 280d6ce commit 1eb561d

File tree

1 file changed

+84
-0
lines changed

1 file changed

+84
-0
lines changed

checkpoint_manager.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""
2+
Checkpoint and resume management for training runs.
3+
"""
4+
5+
import json
6+
from datetime import datetime
7+
from pathlib import Path
8+
from typing import Any, Dict, Optional
9+
10+
11+
class CheckpointManager:
    """Manage training checkpoints and resume state.

    The resume state is persisted as JSON in ``run_state.json`` inside
    the run directory.
    """

    def __init__(self, run_dir: Path):
        """
        Initialize checkpoint manager.

        Args:
            run_dir: Directory to store checkpoints and state. A plain
                string path is accepted and converted to ``Path``.
        """
        # Coerce so the "/" join below also works when a str is passed.
        self.run_dir = Path(run_dir)
        self.state_file = self.run_dir / "run_state.json"

    def save_run_state(
        self,
        round_idx: int,
        global_step: int,
        learning_rate: float,
        checkpoint_uri: str,
        config: Dict[str, Any],
    ) -> None:
        """
        Save current run state for resumption.

        The state is written to a temporary sibling file and then renamed
        over ``run_state.json``, so an interrupted save does not leave a
        truncated state file behind.

        Args:
            round_idx: Current training round.
            global_step: Current global step count.
            learning_rate: Current learning rate.
            checkpoint_uri: URI of the latest checkpoint.
            config: Full configuration dict. Must be JSON-serializable.

        Raises:
            TypeError: If the state cannot be serialized to JSON. The
                previously saved state file is left untouched.
            OSError: If the run directory or state file cannot be written.
        """
        state = {
            "round_idx": round_idx,
            "global_step": global_step,
            "learning_rate": learning_rate,
            "checkpoint_uri": checkpoint_uri,
            "config": config,
            "timestamp": datetime.now().isoformat(),
        }

        # Serialize before touching the filesystem so a non-serializable
        # config cannot clobber the last good state.
        payload = json.dumps(state, indent=2)

        # The run directory may not exist yet on the first save.
        self.run_dir.mkdir(parents=True, exist_ok=True)

        # Write-then-rename: readers see either the old or the new file,
        # never a half-written one.
        tmp_file = self.state_file.with_name(self.state_file.name + ".tmp")
        tmp_file.write_text(payload, encoding="utf-8")
        tmp_file.replace(self.state_file)

    def load_run_state(self) -> Optional[Dict[str, Any]]:
        """
        Load saved run state.

        Returns:
            Saved state dict, or None if no state file exists.

        Raises:
            json.JSONDecodeError: If the state file is not valid JSON.
        """
        # Open directly instead of checking exists() first, so a file
        # removed between the check and the open cannot cause a crash.
        try:
            with open(self.state_file, "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            return None

    def has_saved_state(self) -> bool:
        """Check if a saved state exists."""
        return self.state_file.exists()
70+
71+
72+
def find_latest_run(runs_dir: Path = Path("runs")) -> Optional[Path]:
    """
    Find the most recent run directory.

    "Most recent" is decided by path ordering: the subdirectory whose
    path sorts last is returned. Run directory names therefore need to
    sort chronologically (for example, timestamp-based names) for the
    result to be the newest run. Modification times are not consulted.

    Args:
        runs_dir: Directory containing one subdirectory per run.
            Defaults to ``runs`` relative to the current working
            directory. A plain string path is also accepted.

    Returns:
        Path to latest run directory, or None if no runs exist (the
        directory is missing, is not a directory, or has no
        subdirectories).
    """
    runs_dir = Path(runs_dir)

    # is_dir() covers both "missing" and "exists but is a regular file";
    # the latter would make iterdir() raise.
    if not runs_dir.is_dir():
        return None

    # Only the maximum is needed, so skip the full sort.
    return max((d for d in runs_dir.iterdir() if d.is_dir()), default=None)

0 commit comments

Comments
 (0)