@@ -5,10 +5,28 @@
 import torch.distributed as dist
 import torch.nn as nn
 from conversion.harness import DispatchTestCase
-from distributed_utils import set_environment_variables_pytest
+
+# The distributed env initialization has to be before torchTRT import since it uses barrier
+from distributed_utils import (
+    set_environment_variables_pytest,
+    set_environment_variables_pytest_multi_process,
+    set_environment_variables_pytest_single_process,
+)
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.utils import is_platform_supported_for_trtllm
+
+if "OMPI_COMM_WORLD_SIZE" in os.environ:
+    set_environment_variables_pytest_multi_process()
+else:
+    set_environment_variables_pytest_single_process()
+
+if not dist.is_initialized():
+    dist.init_process_group(
+        backend="nccl",
+        init_method="env://",
+    )
+
+from torch_tensorrt.dynamo.distributed.utils import is_platform_supported_for_trtllm
 
 
 class DistributedGatherModel(nn.Module):
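The new module-level block above initializes the process group with init_method="env://", which performs rendezvous through the MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE environment variables. The diff does not show the bodies of the distributed_utils helpers, so the following is only a sketch of what the single-process one might set; the concrete address and port values are assumptions, not part of this change.

import os

def set_environment_variables_pytest_single_process() -> None:
    # Sketch only: the real helper lives in distributed_utils.
    # init_method="env://" reads these four variables for rendezvous.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")  # assumed free port
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
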
@@ -48,11 +66,9 @@ class TestNcclOpsConverter(DispatchTestCase):
     )
     @classmethod
     def setUpClass(cls):
-        set_environment_variables_pytest()
-        cls.world_size = 1
-        if not dist.is_initialized():
-            dist.init_process_group(backend="nccl")
-        cls.group = dist.new_group(ranks=[0])
+        cls.world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1))
+        cls.rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))
+        cls.group = dist.new_group(ranks=list(range(cls.world_size)))
         cls.group_name = cls.group.group_name
 
     @classmethod
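In the multi-process path, the tests would be launched under Open MPI (for example, mpirun -n 2 python test_nccl_ops.py), which exports OMPI_COMM_WORLD_SIZE and OMPI_COMM_WORLD_RANK; the revised setUpClass reads those to build a group spanning every launched rank rather than the old fixed single-rank group. Below is a sketch of how the multi-process helper might map MPI's environment onto the env:// variables, again an assumption since its body is not part of this diff.

import os

def set_environment_variables_pytest_multi_process() -> None:
    # Sketch only: map Open MPI's rank/size onto torch.distributed's
    # env:// rendezvous variables. Address and port are assumptions.
    os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
    os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")

With that mapping in place, dist.new_group(ranks=list(range(cls.world_size))) covers every rank that mpirun started, so the same test file works unchanged in both the single- and multi-process configurations.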