Commit 3f1fa7e
Distributed utils package, separating out env setup for single-GPU and multi-GPU
1 parent 7134053 commit 3f1fa7e

File tree

5 files changed: +55 -33 lines changed
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from utils import is_platform_supported_for_trtllm, load_tensorrt_llm_for_nccl
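This new module simply re-exports the TRT-LLM helpers from utils. Once the subpackage is registered in setup.py (see the change below), both names should resolve from the installed package. A minimal import check, offered as a sketch and assuming a build of this branch is installed:

# A sketch, not part of this commit: confirm the re-exported helpers resolve
# from the installed package (assumes this branch has been built and installed).
from torch_tensorrt.dynamo.distributed.utils import (
    is_platform_supported_for_trtllm,
    load_tensorrt_llm_for_nccl,
)

print(is_platform_supported_for_trtllm.__module__)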

py/torch_tensorrt/dynamo/distributed/utils.py

Lines changed: 0 additions & 24 deletions

@@ -18,30 +18,6 @@
 logger = logging.getLogger(__name__)


-def initialize_distributed_env(
-    rank: int = 0, world_size: int = 1, port: int = 29500
-) -> None:
-    local_rank = int(
-        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
-    )
-    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
-
-    # Set up environment variable to run with mpirun
-    os.environ["RANK"] = str(local_rank)
-    os.environ["WORLD_SIZE"] = str(world_size)
-    os.environ["MASTER_ADDR"] = "127.0.0.1"
-    os.environ["MASTER_PORT"] = str(port)
-
-    # Necessary to assign a device to each rank.
-    torch.cuda.set_device(local_rank)
-
-    # We use nccl backend
-    dist.init_process_group("nccl")
-
-    # set a manual seed for reproducibility
-    torch.manual_seed(1111)
-
-
 def check_tensor_parallel_device_number(world_size: int) -> None:
     if world_size % 2 != 0:
         raise ValueError(

setup.py

Lines changed: 1 addition & 0 deletions

@@ -454,6 +454,7 @@ def run(self):
         "torch_tensorrt.dynamo.conversion.impl.unary",
         "torch_tensorrt.dynamo.conversion.plugins",
         "torch_tensorrt.dynamo.debug",
+        "torch_tensorrt.dynamo.distributed",
         "torch_tensorrt.dynamo.lowering",
         "torch_tensorrt.dynamo.lowering.passes",
         "torch_tensorrt.dynamo.partitioning",

Lines changed: 30 additions & 2 deletions

@@ -1,5 +1,6 @@
 import logging
 import os
+import random

 import numpy as np
 import tensorrt as trt
@@ -8,8 +9,35 @@
 from torch.distributed._tensor.device_mesh import init_device_mesh


-def set_environment_variables_pytest():
+def set_environment_variables_pytest_single_process():
+    port = 29500 + random.randint(1, 1000)
     os.environ["WORLD_SIZE"] = str(1)
     os.environ["RANK"] = str(0)
     os.environ["MASTER_ADDR"] = "127.0.0.1"
-    os.environ["MASTER_PORT"] = str(29500)
+    os.environ["MASTER_PORT"] = str(port)
+
+
+def set_environment_variables_pytest_multi_process(
+    rank: int = 0, world_size: int = 1
+) -> None:
+    port = 29500 + random.randint(1, 1000)
+    # these variables are set by mpirun -n 2
+    local_rank = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
+    )
+    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
+
+    # Set up environment variable to run with mpirun
+    os.environ["RANK"] = str(local_rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(port)
+
+    # Necessary to assign a device to each rank.
+    torch.cuda.set_device(local_rank)
+
+    # We use nccl backend
+    dist.init_process_group("nccl")
+
+    # set a manual seed for reproducibility
+    torch.manual_seed(1111)
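
For illustration, here is a standalone sketch of the env mapping the multi-process helper performs. The OMPI values are hypothetical stand-ins for what mpirun -n 2 would export to its second process, and the sketch runs without GPUs or NCCL:

import os
import random

# Hypothetical values standing in for what mpirun -n 2 exports to its second process.
os.environ["OMPI_COMM_WORLD_LOCAL_RANK"] = "1"
os.environ["OMPI_COMM_WORLD_SIZE"] = "2"

# The same translation set_environment_variables_pytest_multi_process performs:
# OMPI variables -> the env vars that torch.distributed's env:// init reads.
local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0))
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1))
os.environ["RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
# The randomized offset mirrors the helper above: it reduces the chance of
# port collisions when test runs overlap on one machine.
os.environ["MASTER_PORT"] = str(29500 + random.randint(1, 1000))

print(os.environ["RANK"], os.environ["WORLD_SIZE"])  # prints: 1 2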

tests/py/dynamo/distributed/test_nccl_ops.py

Lines changed: 23 additions & 7 deletions

@@ -5,10 +5,28 @@
 import torch.distributed as dist
 import torch.nn as nn
 from conversion.harness import DispatchTestCase
-from distributed_utils import set_environment_variables_pytest
+
+# The distributed env initialization has to be before torchTRT import since it uses barrier
+from distributed_utils import (
+    set_environment_variables_pytest,
+    set_environment_variables_pytest_multi_process,
+    set_environment_variables_pytest_single_process,
+)
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.utils import is_platform_supported_for_trtllm
+
+if "OMPI_COMM_WORLD_SIZE" in os.environ:
+    set_environment_variables_pytest_multi_process()
+else:
+    set_environment_variables_pytest_single_process()
+
+if not dist.is_initialized():
+    dist.init_process_group(
+        backend="nccl",
+        init_method="env://",
+    )
+
+from torch_tensorrt.dynamo.distributed.utils import is_platform_supported_for_trtllm


 class DistributedGatherModel(nn.Module):
@@ -48,11 +66,9 @@ class TestNcclOpsConverter(DispatchTestCase):
     )
     @classmethod
     def setUpClass(cls):
-        set_environment_variables_pytest()
-        cls.world_size = 1
-        if not dist.is_initialized():
-            dist.init_process_group(backend="nccl")
-        cls.group = dist.new_group(ranks=[0])
+        cls.world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1))
+        cls.rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))
+        cls.group = dist.new_group(ranks=list(range(cls.world_size)))
         cls.group_name = cls.group.group_name

     @classmethod
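
Taken together, the module-level branch on OMPI_COMM_WORLD_SIZE runs before torch_tensorrt is imported, so the process group already exists when the barrier inside that import fires. Assuming a standard pytest invocation (not shown in this commit), running python -m pytest tests/py/dynamo/distributed/test_nccl_ops.py exercises the single-process path with WORLD_SIZE=1, while mpirun -n 2 python -m pytest tests/py/dynamo/distributed/test_nccl_ops.py sets OMPI_COMM_WORLD_SIZE=2 and takes the multi-process path, with the new group spanning both ranks.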
