
Commit 5beefc0

Distributed utils package, separating out env for single GPU and multiGPU
1 parent d09103a · commit 5beefc0

8 files changed: +113 additions, -45 deletions

Lines changed: 55 additions & 0 deletions

@@ -0,0 +1,55 @@
+"""
+.. _tensor_parallel_initialize_dist:
+Tensor Parallel Initialize Distributed Environment
+==================================================
+
+This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference.
+"""
+
+import logging
+import os
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import tensorrt as trt
+import torch
+import torch.distributed as dist
+from torch.distributed._tensor.device_mesh import init_device_mesh
+
+
+def initialize_distributed_env(rank=0, world_size=1, port=29500):
+    local_rank = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
+    )
+    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
+
+    # Set up environment variable to run with mpirun
+    os.environ["RANK"] = str(local_rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(port)
+
+    # Necessary to assign a device to each rank.
+    torch.cuda.set_device(local_rank)
+
+    # We use nccl backend
+    dist.init_process_group("nccl")
+
+    # set a manual seed for reproducibility
+    torch.manual_seed(1111)
+
+    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
+    rank = device_mesh.get_rank()
+    assert rank == local_rank
+    device_id = (
+        rank % torch.cuda.device_count()
+    )  # Ensure each rank gets a unique device
+    torch.cuda.set_device(device_id)
+
+    return device_mesh, world_size, rank
+
+
+def cleanup_distributed_env():
+    """Clean up distributed process group to prevent resource leaks."""
+    if dist.is_initialized():
+        dist.destroy_process_group()

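For orientation, a minimal sketch of how an example script might drive the two helpers added above; the tensor workload and the mpirun launch line in the comments are illustrative assumptions, not part of this commit.

import torch

from tensor_parallel_initialize_dist import (
    cleanup_distributed_env,
    initialize_distributed_env,
)

# Typically launched with one process per GPU, e.g. `mpirun -n 2 python example.py`.
device_mesh, world_size, rank = initialize_distributed_env()
try:
    # Placeholder workload: each rank operates on the CUDA device assigned above.
    x = torch.randn(8, 8, device="cuda")
    print(f"rank {rank}/{world_size}: sum={x.sum().item():.3f}")
finally:
    # Destroy the NCCL process group so the process exits cleanly.
    cleanup_distributed_env()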
examples/distributed_inference/tensor_parallel_rotary_embedding.py

Lines changed: 4 additions & 2 deletions

@@ -16,11 +16,13 @@
 import torch
 import torch_tensorrt
 from rotary_embedding import RotaryAttention, parallel_rotary_block
+from tensor_parallel_initialize_dist import (
+    cleanup_distributed_env,
+    initialize_distributed_env,
+)
 from torch.distributed import dist
 from torch_tensorrt.dynamo.distributed.utils import (
-    cleanup_distributed_env,
     get_tensor_parallel_device_mesh,
-    initialize_distributed_env,
     initialize_logger,
 )

examples/distributed_inference/tensor_parallel_simple_example.py

Lines changed: 0 additions & 2 deletions

@@ -37,9 +37,7 @@
     parallelize_module,
 )
 from torch_tensorrt.dynamo.distributed.utils import (
-    cleanup_distributed_env,
     get_tensor_parallel_device_mesh,
-    initialize_distributed_env,
     initialize_logger,
 )

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+
py/torch_tensorrt/dynamo/distributed/utils.py

Lines changed: 0 additions & 32 deletions

@@ -9,7 +9,6 @@
 from typing import Optional

 import torch
-import torch.distributed as dist
 from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh
 from torch_tensorrt._version import __tensorrt_llm_version__

@@ -18,30 +17,6 @@
 logger = logging.getLogger(__name__)


-def initialize_distributed_env(
-    rank: int = 0, world_size: int = 1, port: int = 29500
-) -> None:
-    local_rank = int(
-        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
-    )
-    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
-
-    # Set up environment variable to run with mpirun
-    os.environ["RANK"] = str(local_rank)
-    os.environ["WORLD_SIZE"] = str(world_size)
-    os.environ["MASTER_ADDR"] = "127.0.0.1"
-    os.environ["MASTER_PORT"] = str(port)
-
-    # Necessary to assign a device to each rank.
-    torch.cuda.set_device(local_rank)
-
-    # We use nccl backend
-    dist.init_process_group("nccl")
-
-    # set a manual seed for reproducibility
-    torch.manual_seed(1111)
-
-
 def check_tensor_parallel_device_number(world_size: int) -> None:
     if world_size % 2 != 0:
         raise ValueError(

@@ -76,12 +51,6 @@ def initialize_logger(rank: int, logger_file_name: str) -> logging.Logger:
     return logger


-def cleanup_distributed_env() -> None:
-    """Clean up distributed process group to prevent resource leaks."""
-    if dist.is_initialized():
-        dist.destroy_process_group()
-
-
 def is_platform_supported_for_trtllm() -> bool:
     """
     Checks if the current platform supports TensorRT-LLM plugins for the NCCL backend.

@@ -127,7 +96,6 @@ def is_platform_supported_for_trtllm() -> bool:
         logger.warning(f"Failed to detect CUDA version: {e}")
         return False

-
     return True

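With initialize_distributed_env and cleanup_distributed_env moved out to the examples, the library-side module keeps only the tensor-parallel and TensorRT-LLM helpers. A minimal sketch of how a caller might combine the helpers whose signatures are visible in this diff; the WORLD_SIZE/RANK defaults and the log file name are illustrative assumptions.

import os

from torch_tensorrt.dynamo.distributed.utils import (
    check_tensor_parallel_device_number,
    initialize_logger,
    is_platform_supported_for_trtllm,
)

world_size = int(os.environ.get("WORLD_SIZE", 2))
rank = int(os.environ.get("RANK", 0))

# Raises ValueError when the world size is not usable for tensor parallelism.
check_tensor_parallel_device_number(world_size)

# Per-rank logger; the file name is an arbitrary illustrative choice.
logger = initialize_logger(rank, f"tensor_parallel_rank_{rank}.log")

if not is_platform_supported_for_trtllm():
    logger.warning("TensorRT-LLM NCCL plugins are not supported on this platform.")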
setup.py

Lines changed: 1 addition & 0 deletions

@@ -454,6 +454,7 @@ def run(self):
     "torch_tensorrt.dynamo.conversion.impl.unary",
     "torch_tensorrt.dynamo.conversion.plugins",
     "torch_tensorrt.dynamo.debug",
+    "torch_tensorrt.dynamo.distributed",
     "torch_tensorrt.dynamo.lowering",
     "torch_tensorrt.dynamo.lowering.passes",
     "torch_tensorrt.dynamo.partitioning",
Lines changed: 30 additions & 2 deletions

@@ -1,5 +1,6 @@
 import logging
 import os
+import random

 import numpy as np
 import tensorrt as trt

@@ -8,8 +9,35 @@
 from torch.distributed._tensor.device_mesh import init_device_mesh


-def set_environment_variables_pytest():
+def set_environment_variables_pytest_single_process():
+    port = 29500 + random.randint(1, 1000)
     os.environ["WORLD_SIZE"] = str(1)
     os.environ["RANK"] = str(0)
     os.environ["MASTER_ADDR"] = "127.0.0.1"
-    os.environ["MASTER_PORT"] = str(29500)
+    os.environ["MASTER_PORT"] = str(port)
+
+
+def set_environment_variables_pytest_multi_process(
+    rank: int = 0, world_size: int = 1
+) -> None:
+    port = 29500 + random.randint(1, 1000)
+    # these variables are set by mpirun -n 2
+    local_rank = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
+    )
+    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
+
+    # Set up environment variable to run with mpirun
+    os.environ["RANK"] = str(local_rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(port)
+
+    # Necessary to assign a device to each rank.
+    torch.cuda.set_device(local_rank)
+
+    # We use nccl backend
+    dist.init_process_group("nccl")
+
+    # set a manual seed for reproducibility
+    torch.manual_seed(1111)

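The randomized MASTER_PORT presumably lets repeated pytest sessions rebind without colliding with a stale port from a previous run. A minimal single-process sketch of how a test module might use the helper above; the explicit teardown at the end is illustrative and not part of this commit.

import torch.distributed as dist

from distributed_utils import set_environment_variables_pytest_single_process

# Populate RANK/WORLD_SIZE/MASTER_ADDR and pick a randomized MASTER_PORT.
set_environment_variables_pytest_single_process()

if not dist.is_initialized():
    dist.init_process_group(backend="nccl", init_method="env://")

# ... run the distributed tests ...

# Tear down the process group so the next test session starts clean.
if dist.is_initialized():
    dist.destroy_process_group()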
tests/py/dynamo/distributed/test_nccl_ops.py

Lines changed: 22 additions & 7 deletions

@@ -5,10 +5,27 @@
 import torch.distributed as dist
 import torch.nn as nn
 from conversion.harness import DispatchTestCase
-from distributed_utils import set_environment_variables_pytest
+
+# The distributed env initialization has to be before torchTRT import since it uses barrier
+from distributed_utils import (
+    set_environment_variables_pytest_multi_process,
+    set_environment_variables_pytest_single_process,
+)
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.utils import is_platform_supported_for_trtllm
+
+if "OMPI_COMM_WORLD_SIZE" in os.environ:
+    set_environment_variables_pytest_multi_process()
+else:
+    set_environment_variables_pytest_single_process()
+
+if not dist.is_initialized():
+    dist.init_process_group(
+        backend="nccl",
+        init_method="env://",
+    )
+
+from torch_tensorrt.dynamo.distributed.utils import is_platform_supported_for_trtllm


 class DistributedGatherModel(nn.Module):

@@ -48,11 +65,9 @@ class TestNcclOpsConverter(DispatchTestCase):
     )
     @classmethod
     def setUpClass(cls):
-        set_environment_variables_pytest()
-        cls.world_size = 1
-        if not dist.is_initialized():
-            dist.init_process_group(backend="nccl")
-        cls.group = dist.new_group(ranks=[0])
+        cls.world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1))
+        cls.rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))
+        cls.group = dist.new_group(ranks=list(range(cls.world_size)))
         cls.group_name = cls.group.group_name

     @classmethod
