fix monkeypatch reloading in tests and deepspeed patch
winglian committed Dec 9, 2024
1 parent 6577087 commit 70099c5
Showing 4 changed files with 33 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/axolotl/monkeypatch/trainer_grad_accum.py
@@ -219,7 +219,7 @@ def patch_forward_for_ga():
PATCHED_TRAINER_CODE = """
disable_deepspeed_no_sync = (
self.accelerator.distributed_type == DistributedType.DEEPSPEED
- and self.accelerator.deepspeed_engine_wrapped.engine.zero_optimization_partition_gradients()
+ # and self.accelerator.deepspeed_engine_wrapped.engine.zero_optimization_partition_gradients()
)
context = (
functools.partial(self.accelerator.no_sync, model=model)
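
Note: the hunk above is the DeepSpeed half of this fix. With the partition-gradients check commented out, Accelerate's no_sync is skipped for every DeepSpeed run, not only for ZeRO configurations that partition gradients. For readers without the full file, here is a minimal sketch of the context selection this condition feeds into, assuming an HF-Trainer-style accumulation loop; the helper name make_accumulation_context and its bare accelerator/model arguments are illustrative, not the patched Trainer's actual code.

import contextlib
import functools

from accelerate.utils import DistributedType


def make_accumulation_context(accelerator, model):
    # Skip Accelerate's no_sync whenever DeepSpeed is active; the engine
    # schedules its own gradient reduction across micro-batches.
    disable_deepspeed_no_sync = (
        accelerator.distributed_type == DistributedType.DEEPSPEED
    )
    context = (
        contextlib.nullcontext  # DeepSpeed run: do not defer gradient sync
        if disable_deepspeed_no_sync
        else functools.partial(accelerator.no_sync, model=model)
    )
    return context()  # caller wraps each accumulation micro-step in this
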
2 changes: 1 addition & 1 deletion src/axolotl/utils/models.py
@@ -386,7 +386,7 @@ def apply_patches(self) -> None:
)

patch_training_loop_for_fsdp()
- elif self.cfg.deepspeed:
+ elif self.cfg.deepspeed and self.cfg.gradient_accumulation_steps > 1:
from axolotl.monkeypatch.trainer_grad_accum import (
patch_training_loop_for_deepspeed_0_16_x,
)
11 changes: 9 additions & 2 deletions tests/conftest.py
@@ -138,13 +138,20 @@ def cleanup_monkeypatches():

# Reset other known monkeypatches
modules_to_reset: list[tuple[str, list[str]]] = [
("transformers",),
("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
("transformers.trainer", ["Trainer"]),
("transformers.trainer",),
("transformers", ["Trainer"]),
("transformers.loss.loss_utils",),
]
for module_name_tuple in modules_to_reset:
module_name = module_name_tuple[0]

+ spec = importlib.util.spec_from_file_location(
+ module_name, sys.modules[module_name].__file__
+ )
+ sys.modules[module_name] = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(sys.modules[module_name])

module = importlib.import_module(module_name)
sys.modules[module_name] = module
importlib.reload(sys.modules[module_name])
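
Note: this hunk is the test-side half of the fix named in the commit message. Instead of relying on importlib.reload alone, each module is re-executed from its source file, which drops any attributes earlier tests monkeypatched onto it. A standalone sketch of that pattern, assuming the target module is already imported and backed by a plain .py file; the helper name reset_module is illustrative, not part of the project.

import importlib
import importlib.util
import sys


def reset_module(module_name: str) -> None:
    # Rebuild the module from its file so monkeypatched attributes disappear.
    source = sys.modules[module_name].__file__
    spec = importlib.util.spec_from_file_location(module_name, source)
    fresh = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = fresh          # publish the clean copy first
    spec.loader.exec_module(fresh)            # then execute its body from disk
    importlib.reload(sys.modules[module_name])  # refresh once more, as above


# e.g. reset_module("transformers.trainer")

Caveat: code that already did "from transformers.trainer import Trainer" keeps its old reference, which is presumably why some entries in the list above also name specific attributes such as Trainer.
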
44 changes: 22 additions & 22 deletions tests/e2e/multigpu/test_llama.py
@@ -54,7 +54,7 @@ def test_lora_ddp(self, temp_dir):
},
],
"num_epochs": 1,
"max_steps": 5,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -91,7 +91,7 @@ def test_lora_ddp(self, temp_dir):

@pytest.mark.parametrize(
"gradient_accumulation_steps",
- [1, 4],
+ [1, 2],
)
def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code
@@ -118,8 +118,8 @@ def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 4,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
@@ -191,7 +191,7 @@ def test_dpo_lora_ddp(self, temp_dir):
},
],
"num_epochs": 1,
"max_steps": 5,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -265,8 +265,8 @@ def test_dpo_qlora_ddp(self, temp_dir):
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 4,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"warmup_steps": 0,
@@ -303,7 +303,7 @@ def test_dpo_qlora_ddp(self, temp_dir):

@pytest.mark.parametrize(
"gradient_accumulation_steps",
- [1, 4],
+ [1, 2],
)
def test_fsdp(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code
@@ -322,8 +322,8 @@ def test_fsdp(self, temp_dir, gradient_accumulation_steps):
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 4,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
@@ -394,7 +394,7 @@ def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type):
},
],
"num_epochs": 1,
"max_steps": 5,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -475,7 +475,7 @@ def test_fsdp_qlora_prequant_packed(self, temp_dir):
},
],
"num_epochs": 1,
"max_steps": 5,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -526,14 +526,14 @@ def test_fsdp_qlora_prequant_packed(self, temp_dir):

@pytest.mark.parametrize(
"gradient_accumulation_steps",
- [1, 4],
+ [1, 2],
)
@pytest.mark.parametrize(
"deepspeed",
[
"deepspeed_configs/zero3_bf16.json",
"deepspeed_configs/zero3_bf16_cpuoffload_all.json",
"deepspeed_configs/zero3_bf16_cpuoffload_params.json",
# "deepspeed_configs/zero3_bf16_cpuoffload_params.json",
],
)
@pytest.mark.parametrize(
@@ -572,8 +572,8 @@ def test_ds_zero3_packed(
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 2,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
@@ -611,7 +611,7 @@ def test_ds_zero3_packed(

@pytest.mark.parametrize(
"gradient_accumulation_steps",
- [1, 4],
+ [1, 2],
)
@pytest.mark.parametrize(
"qlora",
@@ -647,8 +647,8 @@ def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps, qlora):
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 2,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
Expand Down Expand Up @@ -686,7 +686,7 @@ def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps, qlora):

@pytest.mark.parametrize(
"gradient_accumulation_steps",
- [1, 4],
+ [1, 2],
)
@pytest.mark.parametrize(
"qlora",
@@ -722,8 +722,8 @@ def test_ds_zero1_packed(self, temp_dir, gradient_accumulation_steps, qlora):
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 2,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
