5454 },
5555 "dtensor_cfg" : {
5656 "enabled" : True ,
57+ "_v2" : False ,
5758 "cpu_offload" : False ,
5859 "sequence_parallel" : False ,
5960 "activation_checkpointing" : False ,
@@ -118,12 +119,20 @@ def tokenizer():
118119
119120
120121@pytest .fixture (scope = "function" )
121- def policy (cluster , tokenizer ):
122- """Initialize the policy."""
122+ def policy (cluster , tokenizer , request ):
123+ """Initialize the policy with dtensor v1/v2."""
124+ use_v2 = bool (getattr (request , "param" , False ))
125+ config = {
126+ ** simple_policy_config ,
127+ "dtensor_cfg" : {
128+ ** simple_policy_config ["dtensor_cfg" ],
129+ "_v2" : use_v2 ,
130+ },
131+ }
123132 policy = Policy (
124133 cluster = cluster ,
125134 tokenizer = tokenizer ,
126- config = simple_policy_config ,
135+ config = config ,
127136 init_optimizer = True ,
128137 init_reference_model = False ,
129138 )
@@ -285,7 +294,8 @@ def test_save_and_load_model_and_optimizer(mock_experiment):
285294
286295
287296@pytest .mark .parametrize ("num_gpus" , [1 , 2 ], ids = ["1gpu" , "2gpu" ])
288- def test_convert_dcp_to_hf (policy , num_gpus ):
297+ @pytest .mark .parametrize ("policy" , [False , True ], ids = ["v1" , "v2" ], indirect = True )
298+ def test_convert_dcp_to_hf (policy , num_gpus , request ):
289299 ## warm up with a forward pass
290300 ## this is needed before saving a checkpoint because FSDP does some lazy initialization
291301 input_ids = torch .randint (0 , 16000 , (4 , 128 )) # 4 sequences, each of length 128
@@ -301,21 +311,30 @@ def test_convert_dcp_to_hf(policy, num_gpus):
301311 }
302312 )
303313 policy .train (dummy_fwd_dict , SimpleLoss ())
314+ policy_version_is_v2 = request .node .callspec .params ["policy" ]
304315
305316 with TemporaryDirectory () as tmp_dir :
306317 policy .save_checkpoint (
307318 os .path .join (tmp_dir , "test_hf_and_dcp" ),
319+ checkpointing_cfg = {
320+ "enabled" : True ,
321+ "model_save_format" : "torch_save" if policy_version_is_v2 else None ,
322+ },
308323 )
309324
310325 # Dynamically create the expected set of distcp files based on num_gpus
311326 expected_distcp_files = {f"__{ rank } _0.distcp" for rank in range (num_gpus )}
312327 expected_files = expected_distcp_files .union ({".metadata" })
313328
314- ## make sure we save both HF and DCP checkpoints
315- assert (
316- set (os .listdir (os .path .join (tmp_dir , "test_hf_and_dcp" ))) == expected_files
329+ ckpt_path = (
330+ os .path .join (tmp_dir , "test_hf_and_dcp" , "model" )
331+ if policy_version_is_v2
332+ else os .path .join (tmp_dir , "test_hf_and_dcp" )
317333 )
318334
335+ ## make sure we save both HF and DCP checkpoints
336+ assert set (os .listdir (ckpt_path )) == expected_files
337+
319338 offline_converted_model_path = convert_dcp_to_hf (
320339 os .path .join (tmp_dir , "test_hf_and_dcp" ),
321340 os .path .join (tmp_dir , "test_hf_and_dcp-hf-offline" ),
0 commit comments