Skip to content

Commit

Permalink
Merge pull request #155 from stanfordnlp/zen/fsdp
Browse files Browse the repository at this point in the history
[P0] Fixing trainer saving due to FSDP integration (#154)
  • Loading branch information
frankaging authored Feb 4, 2025
2 parents 5a36985 + 5210c3f commit 852c5a5
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions pyreft/reft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,33 @@ def make_dataloader(

class ReftTrainer(Trainer):
def save_model(self, output_dir, _internal_call=False):
if dist.get_rank() == 0:
if not os.path.exists(output_dir):
os.makedirs(output_dir)
self.model.save_intervention(
save_directory=f"{output_dir}/intervenable_model",
include_model=True
)
# Handle CPU training and non-distributed cases
try:
is_main_process = not dist.is_initialized() or dist.get_rank() == 0
except (RuntimeError, AttributeError) as e: # Catches case when torch.distributed is not available or other dist errors
logger.error(f"Error checking distributed training status: {str(e)}")
is_main_process = True

if is_main_process:
target_dir = f"{output_dir}/intervenable_model"
# Log warning if target directory exists and has content
if os.path.exists(target_dir) and os.listdir(target_dir):
logger.warning(
f"Directory {target_dir} already exists and contains files. "
"Skipping save to prevent overwriting existing model."
)
return

try:
if not os.path.exists(output_dir):
os.makedirs(output_dir)
self.model.save_intervention(
save_directory=target_dir,
include_model=True
)
except Exception as e:
logger.error(f"Error saving model to {target_dir}: {str(e)}")
raise # Re-raise the exception after logging

def _load_best_model(self):
logger.warning(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
Expand Down

0 comments on commit 852c5a5

Please sign in to comment.