File tree Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -408,7 +408,9 @@ def load_optim_state_dict(self, state_dict: dict) -> None:
408408 ``optimizer.state_dict()``
409409 """
410410 optim_state_dict = FSDP .optim_state_dict_to_load (
411- state_dict , self .model , self .optim_wrapper .optimizer )
411+ optim_state_dict = state_dict ,
412+ model = self .model ,
413+ optim = self .optim_wrapper .optimizer )
412414 self .optim_wrapper .load_state_dict (optim_state_dict )
413415
414416 def _init_state_dict_cfg (self , state_dict_cfg : Union [str , dict ]) -> None :
@@ -539,7 +541,9 @@ def build_optim_wrapper(
539541 # Force to load the converted optim_state_dict in full mode.
540542 with FSDP .state_dict_type (model , StateDictType .FULL_STATE_DICT ):
541543 optim_state_dict = FSDP .optim_state_dict_to_load (
542- optim_state_dict , model , new_optimizer )
544+ optim_state_dict = optim_state_dict ,
545+ model = model ,
546+ optim = new_optimizer )
543547 new_optimizer .load_state_dict (optim_state_dict )
544548 optim_wrapper .optimizer = new_optimizer
545549 return optim_wrapper
You can’t perform that action at this time.
0 commit comments