support concurrency for multigpu testing

winglian committed Nov 14, 2024
1 parent 09e97a0 · commit 7595420
Showing 4 changed files with 29 additions and 2 deletions.
2 changes: 1 addition & 1 deletion cicd/multigpu.sh
@@ -2,4 +2,4 @@
 set -e
 
 # only run one test at a time so as not to OOM the GPU
-pytest -n1 /workspace/axolotl/tests/e2e/multigpu/
+pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/
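With `-n2`, pytest-xdist now runs two multi-GPU tests concurrently instead of one at a time (`-v` adds verbose output). Each test spawns its own `accelerate launch` job, and two jobs starting at once would otherwise race for the same default torch.distributed rendezvous port; that is why every launch command in the diffs below gains a `--main_process_port` derived from `get_torch_dist_unique_port()`. Below is a minimal sketch of the idea behind that helper, assuming it offsets the default port by the pytest-xdist worker id; `unique_dist_port` is a hypothetical stand-in, not the transformers implementation.

import os

# Sketch only: offset the default torch.distributed port (29500) by the
# pytest-xdist worker id so concurrent test workers never share a port.
def unique_dist_port(base_port: int = 29500) -> int:
    # pytest-xdist names its workers "gw0", "gw1", ...; outside xdist the
    # variable is unset and every run falls back to the base port.
    worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0")
    return base_port + int(worker.lstrip("gw") or 0)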
5 changes: 5 additions & 0 deletions tests/e2e/multigpu/test_eval.py
@@ -8,6 +8,7 @@
 
 import yaml
 from accelerate.test_utils import execute_subprocess_async
+from transformers.testing_utils import get_torch_dist_unique_port
 
 from axolotl.utils.dict import DictDefault
 
@@ -83,6 +84,8 @@ def test_eval_sample_packing(self, temp_dir):
                 "launch",
                 "--num-processes",
                 "2",
+                "--main_process_port",
+                f"{get_torch_dist_unique_port()}",
                 "-m",
                 "axolotl.cli.train",
                 str(Path(temp_dir) / "config.yaml"),
@@ -148,6 +151,8 @@ def test_eval(self, temp_dir):
                 "launch",
                 "--num-processes",
                 "2",
+                "--main_process_port",
+                f"{get_torch_dist_unique_port()}",
                 "-m",
                 "axolotl.cli.train",
                 str(Path(temp_dir) / "config.yaml"),
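Pieced together from the hunks above, each test now launches training roughly as follows. The surrounding `execute_subprocess_async` call, the leading `"accelerate"` list element, and the `env=os.environ` argument are assumptions about the unchanged context the diff does not show:

import os
from pathlib import Path

from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port

# temp_dir comes from the test fixture and holds the written config.yaml.
execute_subprocess_async(
    [
        "accelerate",
        "launch",
        "--num-processes",
        "2",
        # unique per pytest-xdist worker, so two tests can launch at once
        "--main_process_port",
        f"{get_torch_dist_unique_port()}",
        "-m",
        "axolotl.cli.train",
        str(Path(temp_dir) / "config.yaml"),
    ],
    env=os.environ,
)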
19 changes: 19 additions & 0 deletions tests/e2e/multigpu/test_llama.py
@@ -11,6 +11,7 @@
 import yaml
 from accelerate.test_utils import execute_subprocess_async
 from huggingface_hub import snapshot_download
+from transformers.testing_utils import get_torch_dist_unique_port
 
 from axolotl.utils.dict import DictDefault
 
@@ -78,6 +79,8 @@ def test_lora_ddp(self, temp_dir):
                 "launch",
                 "--num-processes",
                 "2",
+                "--main_process_port",
+                f"{get_torch_dist_unique_port()}",
                 "-m",
                 "axolotl.cli.train",
                 str(Path(temp_dir) / "config.yaml"),
@@ -137,6 +140,8 @@ def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps=1):
                 "launch",
                 "--num-processes",
                 "2",
+                "--main_process_port",
+                f"{get_torch_dist_unique_port()}",
                 "-m",
                 "axolotl.cli.train",
                 str(Path(temp_dir) / "config.yaml"),
@@ -209,6 +214,8 @@ def test_dpo_lora_ddp(self, temp_dir):
                 "launch",
                 "--num-processes",
                 "2",
+                "--main_process_port",
+                f"{get_torch_dist_unique_port()}",
                 "-m",
                 "axolotl.cli.train",
                 str(Path(temp_dir) / "config.yaml"),
@@ -277,6 +284,8 @@ def test_dpo_qlora_ddp(self, temp_dir):
                 "launch",
                 "--num-processes",
                 "2",
+                "--main_process_port",
+                f"{get_torch_dist_unique_port()}",
                 "-m",
                 "axolotl.cli.train",
                 str(Path(temp_dir) / "config.yaml"),
@@ -341,6 +350,8 @@ def test_fsdp(self, temp_dir, gradient_accumulation_steps=1):
                 "launch",
                 "--num-processes",
                 "2",
+                "--main_process_port",
+                f"{get_torch_dist_unique_port()}",
                 "-m",
                 "axolotl.cli.train",
                 str(Path(temp_dir) / "config.yaml"),
@@ -407,6 +418,8 @@ def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type=None):
                 "launch",
                 "--num-processes",
                 "2",
+                "--main_process_port",
+                f"{get_torch_dist_unique_port()}",
                 "-m",
                 "axolotl.cli.train",
                 str(Path(temp_dir) / "config.yaml"),
@@ -483,6 +496,8 @@ def test_fsdp_qlora_prequant_packed(self, temp_dir):
                 "launch",
                 "--num-processes",
                 "2",
+                "--main_process_port",
+                f"{get_torch_dist_unique_port()}",
                 "-m",
                 "axolotl.cli.train",
                 str(Path(temp_dir) / "config.yaml"),
@@ -536,6 +551,8 @@ def test_ds_zero3_packed(self, temp_dir, gradient_accumulation_steps=1):
                 "launch",
                 "--num-processes",
                 "2",
+                "--main_process_port",
+                f"{get_torch_dist_unique_port()}",
                 "-m",
                 "axolotl.cli.train",
                 str(Path(temp_dir) / "config.yaml"),
@@ -592,6 +609,8 @@ def test_ds_zero3_qlora_packed(self, temp_dir):
                 "launch",
                 "--num-processes",
                 "2",
+                "--main_process_port",
+                f"{get_torch_dist_unique_port()}",
                 "-m",
                 "axolotl.cli.train",
                 str(Path(temp_dir) / "config.yaml"),
5 changes: 4 additions & 1 deletion tests/e2e/multigpu/test_qwen2.py
@@ -9,6 +9,7 @@
 
 import yaml
 from accelerate.test_utils import execute_subprocess_async
+from transformers.testing_utils import get_torch_dist_unique_port
 
 from axolotl.utils.dict import DictDefault
 
@@ -28,7 +29,7 @@ def test_qlora_fsdp_dpo(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "Qwen/Qwen2-1.5B",
+                "base_model": "Qwen/Qwen2-0.5B",
                 "load_in_4bit": True,
                 "rl": "dpo",
                 "chat_template": "chatml",
@@ -91,6 +92,8 @@ def test_qlora_fsdp_dpo(self, temp_dir):
                 "launch",
                 "--num-processes",
                 "2",
+                "--main_process_port",
+                f"{get_torch_dist_unique_port()}",
                 "-m",
                 "axolotl.cli.train",
                 str(Path(temp_dir) / "config.yaml"),
