"""
.. _tensor_parallel_initialize_dist:

Tensor Parallel Initialize Distributed Environment
==================================================

This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference.
"""

import logging
import os
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
import torch
import torch.distributed as dist
from torch.distributed._tensor.device_mesh import init_device_mesh


def initialize_distributed_env(rank=0, world_size=1, port=29500):
    """Initialize the torch.distributed process group and build a 1-D device mesh.

    The local rank and world size are read from the OMPI_COMM_WORLD_* environment
    variables when launched with mpirun, falling back to the given arguments
    otherwise. Returns (device_mesh, world_size, rank).
    """
    local_rank = int(
        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
    )
    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))

    # Set up the environment variables expected by torch.distributed when run with mpirun.
    os.environ["RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)

    # Necessary to assign a device to each rank.
    torch.cuda.set_device(local_rank)

    # Use the NCCL backend for GPU communication.
    dist.init_process_group("nccl")

    # Set a manual seed for reproducibility.
    torch.manual_seed(1111)

    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
    rank = device_mesh.get_rank()
    assert rank == local_rank
    device_id = (
        rank % torch.cuda.device_count()
    )  # Ensure each rank gets a unique device
    torch.cuda.set_device(device_id)

    return device_mesh, world_size, rank


def cleanup_distributed_env():
    """Clean up distributed process group to prevent resource leaks."""
    if dist.is_initialized():
        dist.destroy_process_group()
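

# Minimal usage sketch (an illustration, not part of the original module): launch with
# `mpirun -n <num_gpus> python <this_script>.py`, or as a single process for a world
# size of 1. The all-reduce below is only a hypothetical connectivity check.
if __name__ == "__main__":
    device_mesh, world_size, rank = initialize_distributed_env()
    try:
        # Each rank contributes a ones tensor; after all_reduce (SUM) the value
        # should equal world_size on every rank.
        t = torch.ones(1, device="cuda")
        dist.all_reduce(t)
        print(f"rank {rank}/{world_size}: all_reduce -> {t.item()}")
    finally:
        cleanup_distributed_env()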