Skip to content

Commit 314c272

Browse files
authored
fix: Fix DCP-to-HF conversion for model-wrapped checkpoints (#1881)
Signed-off-by: ruit <ruit@nvidia.com>
1 parent 2196f40 commit 314c272

3 files changed

Lines changed: 48 additions & 14 deletions

File tree

nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1081,7 +1081,7 @@ def save_checkpoint(
10811081
optimizer=self.optimizer,
10821082
optimizer_path=optimizer_path,
10831083
scheduler=self.scheduler,
1084-
tokenizer=self.tokenizer if tokenizer_path is None else None,
1084+
tokenizer=self.tokenizer if tokenizer_path else None,
10851085
tokenizer_path=tokenizer_path,
10861086
checkpointing_cfg=checkpointing_cfg,
10871087
lora_enabled=self.lora_enabled,

nemo_rl/utils/native_checkpoint.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -236,17 +236,32 @@ def convert_dcp_to_hf(
236236
raise FileExistsError(
237237
f"HF checkpoint already exists at {hf_ckpt_path}. Delete it to run or set overwrite=True."
238238
)
239-
240239
os.makedirs(hf_ckpt_path, exist_ok=True)
240+
241+
# The ckpt path of dtensor v2 is like <ckpt_dir>/model, while v1 is like <ckpt_dir>
242+
# Choose the correct subdir based on the presence of the metadata file.
243+
metadata_path = os.path.join(dcp_ckpt_path, ".metadata")
244+
if not os.path.exists(metadata_path):
245+
model_subdir = os.path.join(dcp_ckpt_path, "model")
246+
model_metadata_path = os.path.join(model_subdir, ".metadata")
247+
if os.path.exists(model_metadata_path):
248+
dcp_ckpt_path = model_subdir
249+
print(f"Using dcp_ckpt_path of Dtensor V2: {model_subdir}")
250+
else:
251+
raise FileNotFoundError(
252+
f"No metadata file found in {dcp_ckpt_path}(Dtensor V1 ckpt path) or {model_subdir}(Dtensor V2 ckpt path)."
253+
)
254+
else:
255+
print(f"Using dcp_ckpt_path of Dtensor V1: {dcp_ckpt_path}")
256+
241257
weights_path = os.path.join(hf_ckpt_path, "pytorch_model.bin")
242258
dcp_to_torch_save(dcp_ckpt_path, weights_path)
243259

244-
# Need to reload and save b/c the state dict is scoped inside the model key {"model": actual_state_dict}
260+
# Reload and save because DCP exports wrap weights under {"model": ...} in dtensor v1
261+
# while dtensor v2 already saves a flat state_dict.
245262
state_dict = torch.load(weights_path)
246-
assert set(state_dict.keys()) == {"model"}, (
247-
f"We expect that the state dict only has the top level model key, but found: {state_dict.keys()}"
248-
)
249-
torch.save(state_dict["model"], weights_path)
263+
if set(state_dict.keys()) == {"model"}:
264+
torch.save(state_dict["model"], weights_path)
250265

251266
config = AutoConfig.from_pretrained(
252267
model_name_or_path, trust_remote_code=True, **hf_overrides

tests/unit/utils/test_native_checkpoint.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
},
5555
"dtensor_cfg": {
5656
"enabled": True,
57+
"_v2": False,
5758
"cpu_offload": False,
5859
"sequence_parallel": False,
5960
"activation_checkpointing": False,
@@ -118,12 +119,20 @@ def tokenizer():
118119

119120

120121
@pytest.fixture(scope="function")
121-
def policy(cluster, tokenizer):
122-
"""Initialize the policy."""
122+
def policy(cluster, tokenizer, request):
123+
"""Initialize the policy with dtensor v1/v2."""
124+
use_v2 = bool(getattr(request, "param", False))
125+
config = {
126+
**simple_policy_config,
127+
"dtensor_cfg": {
128+
**simple_policy_config["dtensor_cfg"],
129+
"_v2": use_v2,
130+
},
131+
}
123132
policy = Policy(
124133
cluster=cluster,
125134
tokenizer=tokenizer,
126-
config=simple_policy_config,
135+
config=config,
127136
init_optimizer=True,
128137
init_reference_model=False,
129138
)
@@ -285,7 +294,8 @@ def test_save_and_load_model_and_optimizer(mock_experiment):
285294

286295

287296
@pytest.mark.parametrize("num_gpus", [1, 2], ids=["1gpu", "2gpu"])
288-
def test_convert_dcp_to_hf(policy, num_gpus):
297+
@pytest.mark.parametrize("policy", [False, True], ids=["v1", "v2"], indirect=True)
298+
def test_convert_dcp_to_hf(policy, num_gpus, request):
289299
## warm up with a forward pass
290300
## this is needed before saving a checkpoint because FSDP does some lazy initialization
291301
input_ids = torch.randint(0, 16000, (4, 128)) # 4 sequences, each of length 128
@@ -301,21 +311,30 @@ def test_convert_dcp_to_hf(policy, num_gpus):
301311
}
302312
)
303313
policy.train(dummy_fwd_dict, SimpleLoss())
314+
policy_version_is_v2 = request.node.callspec.params["policy"]
304315

305316
with TemporaryDirectory() as tmp_dir:
306317
policy.save_checkpoint(
307318
os.path.join(tmp_dir, "test_hf_and_dcp"),
319+
checkpointing_cfg={
320+
"enabled": True,
321+
"model_save_format": "torch_save" if policy_version_is_v2 else None,
322+
},
308323
)
309324

310325
# Dynamically create the expected set of distcp files based on num_gpus
311326
expected_distcp_files = {f"__{rank}_0.distcp" for rank in range(num_gpus)}
312327
expected_files = expected_distcp_files.union({".metadata"})
313328

314-
## make sure we save both HF and DCP checkpoints
315-
assert (
316-
set(os.listdir(os.path.join(tmp_dir, "test_hf_and_dcp"))) == expected_files
329+
ckpt_path = (
330+
os.path.join(tmp_dir, "test_hf_and_dcp", "model")
331+
if policy_version_is_v2
332+
else os.path.join(tmp_dir, "test_hf_and_dcp")
317333
)
318334

335+
## make sure we save both HF and DCP checkpoints
336+
assert set(os.listdir(ckpt_path)) == expected_files
337+
319338
offline_converted_model_path = convert_dcp_to_hf(
320339
os.path.join(tmp_dir, "test_hf_and_dcp"),
321340
os.path.join(tmp_dir, "test_hf_and_dcp-hf-offline"),

0 commit comments

Comments
 (0)