diff --git a/open_instruct/utils.py b/open_instruct/utils.py index 95a17410e..50a1008bc 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -688,7 +688,7 @@ def clean_last_n_checkpoints(output_dir: str, keep_last_n_checkpoints: int) -> N folders = [f for f in os.listdir(output_dir) if is_checkpoint_folder(output_dir, f)] # find the checkpoint with the largest step checkpoints = sorted(folders, key=lambda x: int(x.split("_")[-1])) - if len(checkpoints) > keep_last_n_checkpoints: + if keep_last_n_checkpoints != -1 and len(checkpoints) > keep_last_n_checkpoints: for checkpoint in checkpoints[: len(checkpoints) - keep_last_n_checkpoints]: logger.info(f"Removing checkpoint {checkpoint}") shutil.rmtree(os.path.join(output_dir, checkpoint))