Skip to content

Commit 63e7176

Browse files
authored
[Core][Refactor] move parallel_utils into vllm/distributed (vllm-project#3950)
[WIP][Core][Refactor] move vllm/model_executor/parallel_utils into vllm/distributed and vllm/device_communicators (vllm-project#3950)
1 parent 934d366 commit 63e7176

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

52 files changed

+111
-141
lines changed

tests/conftest.py

+1-2
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,7 @@
1111

1212
from vllm import LLM, SamplingParams
1313
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
14-
from vllm.model_executor.parallel_utils.parallel_state import (
15-
destroy_model_parallel)
14+
from vllm.distributed import destroy_model_parallel
1615
from vllm.sequence import MultiModalData
1716
from vllm.transformers_utils.tokenizer import get_tokenizer
1817

tests/distributed/test_comm_ops.py

+3-3
Original file line number | Diff line number | Diff line change
@@ -8,9 +8,9 @@
88
import ray
99
import torch
1010

11-
from vllm.model_executor.parallel_utils.communication_op import (
12-
broadcast_tensor_dict, tensor_model_parallel_all_gather,
13-
tensor_model_parallel_all_reduce)
11+
from vllm.distributed import (broadcast_tensor_dict,
12+
tensor_model_parallel_all_gather,
13+
tensor_model_parallel_all_reduce)
1414
from vllm.test_utils import (init_test_distributed_environment,
1515
multi_process_tensor_parallel)
1616

tests/distributed/test_custom_all_reduce.py

+6-7
Original file line number | Diff line number | Diff line change
@@ -6,9 +6,8 @@
66
import torch
77
import torch.distributed as dist
88

9-
from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar
10-
from vllm.model_executor.parallel_utils.communication_op import (
11-
tensor_model_parallel_all_reduce)
9+
from vllm.distributed import tensor_model_parallel_all_reduce
10+
from vllm.distributed.device_communicators import custom_all_reduce
1211
from vllm.test_utils import (init_test_distributed_environment,
1312
multi_process_tensor_parallel)
1413

@@ -26,10 +25,10 @@ def graph_allreduce(world_size, rank, distributed_init_port):
2625
init_test_distributed_environment(1, world_size, rank,
2726
distributed_init_port)
2827

29-
custom_ar.init_custom_ar()
28+
custom_all_reduce.init_custom_all_reduce()
3029
for sz in test_sizes:
3130
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
32-
with custom_ar.capture():
31+
with custom_all_reduce.capture():
3332
# use integers so result matches NCCL exactly
3433
inp1 = torch.randint(1,
3534
16, (sz, ),
@@ -62,8 +61,8 @@ def eager_allreduce(world_size, rank, distributed_init_port):
6261
distributed_init_port)
6362

6463
sz = 1024
65-
custom_ar.init_custom_ar()
66-
fa = custom_ar.get_handle()
64+
custom_all_reduce.init_custom_all_reduce()
65+
fa = custom_all_reduce.get_handle()
6766
inp = torch.ones(sz, dtype=torch.float32, device=device)
6867
out = fa.all_reduce_unreg(inp)
6968
assert torch.allclose(out, inp * world_size)

tests/distributed/test_pynccl.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,8 @@
44
import pytest
55
import torch
66

7-
from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator,
8-
ncclGetUniqueId)
7+
from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
8+
ncclGetUniqueId)
99

1010

1111
def distributed_run(fn, world_size):

tests/lora/conftest.py

+1-2
Original file line number | Diff line number | Diff line change
@@ -12,15 +12,14 @@
1212

1313
import vllm
1414
from vllm.config import LoRAConfig
15+
from vllm.distributed import destroy_model_parallel, initialize_model_parallel
1516
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
1617
MergedColumnParallelLinear,
1718
RowParallelLinear)
1819
from vllm.model_executor.layers.logits_processor import LogitsProcessor
1920
from vllm.model_executor.layers.sampler import Sampler
2021
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
2122
from vllm.model_executor.model_loader import get_model
22-
from vllm.model_executor.parallel_utils.parallel_state import (
23-
destroy_model_parallel, initialize_model_parallel)
2423

2524

2625
def cleanup():

vllm/distributed/__init__.py

+3
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,3 @@
1+
from .communication_op import *
2+
from .parallel_state import *
3+
from .utils import *

vllm/model_executor/parallel_utils/communication_op.py renamed to vllm/distributed/communication_op.py

+8-6
Original file line number | Diff line number | Diff line change
@@ -4,12 +4,10 @@
44
import torch
55
from torch.distributed import ProcessGroup
66

7-
from vllm.model_executor.parallel_utils import pynccl_utils
8-
from vllm.model_executor.parallel_utils.custom_all_reduce import (
9-
custom_all_reduce)
10-
from vllm.model_executor.parallel_utils.parallel_state import (
11-
get_tensor_model_parallel_group, get_tensor_model_parallel_rank,
12-
get_tensor_model_parallel_world_size, is_pynccl_enabled_for_all_reduce)
7+
from .parallel_state import (get_tensor_model_parallel_group,
8+
get_tensor_model_parallel_rank,
9+
get_tensor_model_parallel_world_size,
10+
is_pynccl_enabled_for_all_reduce)
1311

1412

1513
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@@ -24,6 +22,10 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
2422
TLDR: always assume this function modifies its input, but use the return
2523
value as the output.
2624
"""
25+
from vllm.distributed.device_communicators import pynccl_utils
26+
from vllm.distributed.device_communicators.custom_all_reduce import (
27+
custom_all_reduce)
28+
2729
# Bypass the function if we are using only 1 GPU.
2830
if get_tensor_model_parallel_world_size() == 1:
2931
return input_

vllm/model_executor/parallel_utils/custom_all_reduce.py renamed to vllm/distributed/device_communicators/custom_all_reduce.py

+3-2
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,6 @@
55
import torch.distributed as dist
66

77
from vllm.logger import init_logger
8-
from vllm.model_executor.parallel_utils.parallel_state import (
9-
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
108

119
try:
1210
import pynvml
@@ -25,6 +23,9 @@
2523

2624

2725
def init_custom_ar() -> None:
26+
from vllm.distributed import (get_tensor_model_parallel_rank,
27+
get_tensor_model_parallel_world_size)
28+
2829
global _CA_HANDLE
2930
if _CA_HANDLE is not None:
3031
return

vllm/model_executor/parallel_utils/pynccl_utils.py renamed to vllm/distributed/device_communicators/pynccl_utils.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,8 @@
99
logger = init_logger(__name__)
1010

1111
try:
12-
from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator,
13-
ncclGetVersion)
12+
from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
13+
ncclGetVersion)
1414
except Exception as e:
1515
# in non-NVIDIA environments, we can't import the nccl module
1616
# e.g. when running on machines with AMD GPUs

vllm/model_executor/parallel_utils/parallel_state.py renamed to vllm/distributed/parallel_state.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -8,8 +8,6 @@
88

99
import torch
1010

11-
from vllm.model_executor.parallel_utils import pynccl_utils
12-
1311
# Tensor model parallel group that the current rank belongs to.
1412
_TENSOR_MODEL_PARALLEL_GROUP = None
1513
# Pipeline model parallel group that the current rank belongs to.
@@ -266,6 +264,7 @@ def destroy_model_parallel():
266264
_PIPELINE_MODEL_PARALLEL_GROUP = None
267265
global _PIPELINE_GLOBAL_RANKS
268266
_PIPELINE_GLOBAL_RANKS = None
267+
from vllm.distributed.device_communicators import pynccl_utils
269268

270269
# Destroy the pynccl states if any.
271270
pynccl_utils.destroy_process_group()
@@ -279,6 +278,7 @@ def destroy_model_parallel():
279278

280279
@contextlib.contextmanager
281280
def with_pynccl_for_all_reduce():
281+
from vllm.distributed.device_communicators import pynccl_utils
282282
"""use pynccl instead of torch.distributed for all reduce"""
283283
tp_size = get_tensor_model_parallel_world_size()
284284
if tp_size == 1:

vllm/lora/layers.py

+6-7
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,12 @@
1010
from transformers import PretrainedConfig
1111

1212
from vllm.config import LoRAConfig
13+
from vllm.distributed import (get_tensor_model_parallel_rank,
14+
get_tensor_model_parallel_world_size,
15+
split_tensor_along_last_dim,
16+
tensor_model_parallel_all_gather,
17+
tensor_model_parallel_all_reduce,
18+
tensor_model_parallel_gather)
1319
from vllm.lora.punica import add_lora, add_lora_slice, bgmv
1420
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
1521
MergedColumnParallelLinear,
@@ -18,13 +24,6 @@
1824
from vllm.model_executor.layers.logits_processor import LogitsProcessor
1925
from vllm.model_executor.layers.vocab_parallel_embedding import (
2026
ParallelLMHead, VocabParallelEmbedding)
21-
from vllm.model_executor.parallel_utils.communication_op import (
22-
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce,
23-
tensor_model_parallel_gather)
24-
from vllm.model_executor.parallel_utils.parallel_state import (
25-
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
26-
from vllm.model_executor.parallel_utils.utils import (
27-
split_tensor_along_last_dim)
2827

2928
if TYPE_CHECKING:
3029
pass

vllm/model_executor/layers/activation.py

+2-3
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,9 @@
77
import torch.nn.functional as F
88

99
from vllm._C import ops
10+
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
11+
get_tensor_model_parallel_world_size)
1012
from vllm.model_executor.layers.quantization import QuantizationConfig
11-
from vllm.model_executor.parallel_utils.parallel_state import (
12-
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
13-
from vllm.model_executor.parallel_utils.utils import divide
1413
from vllm.model_executor.utils import set_weight_attrs
1514

1615

vllm/model_executor/layers/linear.py

+5-6
Original file line number | Diff line number | Diff line change
@@ -5,13 +5,12 @@
55
import torch.nn.functional as F
66
from torch.nn.parameter import Parameter
77

8+
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
9+
get_tensor_model_parallel_world_size,
10+
split_tensor_along_last_dim,
11+
tensor_model_parallel_all_gather,
12+
tensor_model_parallel_all_reduce)
813
from vllm.logger import init_logger
9-
from vllm.model_executor.parallel_utils.communication_op import (
10-
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
11-
from vllm.model_executor.parallel_utils.parallel_state import (
12-
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
13-
from vllm.model_executor.parallel_utils.utils import (
14-
divide, split_tensor_along_last_dim)
1514
from vllm.model_executor.utils import set_weight_attrs
1615

1716
logger = init_logger(__name__)

vllm/model_executor/layers/logits_processor.py

+1-2
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,7 @@
44
import torch
55
import torch.nn as nn
66

7-
from vllm.model_executor.parallel_utils.communication_op import (
8-
tensor_model_parallel_gather)
7+
from vllm.distributed import tensor_model_parallel_gather
98
from vllm.model_executor.sampling_metadata import SamplingMetadata
109

1110

vllm/model_executor/layers/vocab_parallel_embedding.py

+3-5
Original file line number | Diff line number | Diff line change
@@ -4,11 +4,9 @@
44
import torch.nn.functional as F
55
from torch.nn.parameter import Parameter
66

7-
from vllm.model_executor.parallel_utils.communication_op import (
8-
tensor_model_parallel_all_reduce)
9-
from vllm.model_executor.parallel_utils.parallel_state import (
10-
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
11-
from vllm.model_executor.parallel_utils.utils import divide
7+
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
8+
get_tensor_model_parallel_world_size,
9+
tensor_model_parallel_all_reduce)
1210
from vllm.model_executor.utils import set_weight_attrs
1311

1412
DEFAULT_VOCAB_PADDING_SIZE = 64

vllm/model_executor/models/baichuan.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,8 @@
2727

2828
from vllm.attention import Attention, AttentionMetadata
2929
from vllm.config import LoRAConfig
30+
from vllm.distributed import (get_tensor_model_parallel_rank,
31+
get_tensor_model_parallel_world_size)
3032
from vllm.model_executor.layers.activation import SiluAndMul
3133
from vllm.model_executor.layers.layernorm import RMSNorm
3234
from vllm.model_executor.layers.linear import (LinearMethodBase,
@@ -38,8 +40,6 @@
3840
from vllm.model_executor.layers.sampler import Sampler
3941
from vllm.model_executor.layers.vocab_parallel_embedding import (
4042
ParallelLMHead, VocabParallelEmbedding)
41-
from vllm.model_executor.parallel_utils.parallel_state import (
42-
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
4343
from vllm.model_executor.sampling_metadata import SamplingMetadata
4444
from vllm.model_executor.weight_utils import (default_weight_loader,
4545
hf_model_weights_iterator)

vllm/model_executor/models/bloom.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,8 @@
2424
from transformers import BloomConfig
2525

2626
from vllm.attention import Attention, AttentionMetadata
27+
from vllm.distributed import (get_tensor_model_parallel_rank,
28+
get_tensor_model_parallel_world_size)
2729
from vllm.model_executor.layers.activation import get_act_fn
2830
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
2931
LinearMethodBase,
@@ -33,8 +35,6 @@
3335
from vllm.model_executor.layers.sampler import Sampler
3436
from vllm.model_executor.layers.vocab_parallel_embedding import (
3537
VocabParallelEmbedding)
36-
from vllm.model_executor.parallel_utils.parallel_state import (
37-
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
3838
from vllm.model_executor.sampling_metadata import SamplingMetadata
3939
from vllm.model_executor.weight_utils import (default_weight_loader,
4040
hf_model_weights_iterator)

vllm/model_executor/models/chatglm.py

+1-2
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010

1111
from vllm.attention import Attention, AttentionMetadata
1212
from vllm.config import LoRAConfig
13+
from vllm.distributed import get_tensor_model_parallel_world_size
1314
from vllm.model_executor.layers.activation import SiluAndMul
1415
from vllm.model_executor.layers.layernorm import RMSNorm
1516
from vllm.model_executor.layers.linear import (LinearMethodBase,
@@ -21,8 +22,6 @@
2122
from vllm.model_executor.layers.sampler import Sampler
2223
from vllm.model_executor.layers.vocab_parallel_embedding import (
2324
ParallelLMHead, VocabParallelEmbedding)
24-
from vllm.model_executor.parallel_utils.parallel_state import (
25-
get_tensor_model_parallel_world_size)
2625
from vllm.model_executor.sampling_metadata import SamplingMetadata
2726
from vllm.model_executor.weight_utils import (default_weight_loader,
2827
hf_model_weights_iterator)

vllm/model_executor/models/commandr.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,8 @@
2929
from transformers import CohereConfig
3030

3131
from vllm.attention import Attention, AttentionMetadata
32+
from vllm.distributed import (get_tensor_model_parallel_rank,
33+
get_tensor_model_parallel_world_size)
3234
from vllm.model_executor.layers.activation import SiluAndMul
3335
from vllm.model_executor.layers.linear import (LinearMethodBase,
3436
MergedColumnParallelLinear,
@@ -39,8 +41,6 @@
3941
from vllm.model_executor.layers.sampler import Sampler
4042
from vllm.model_executor.layers.vocab_parallel_embedding import (
4143
VocabParallelEmbedding)
42-
from vllm.model_executor.parallel_utils.parallel_state import (
43-
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
4444
from vllm.model_executor.sampling_metadata import SamplingMetadata
4545
from vllm.model_executor.utils import set_weight_attrs
4646
from vllm.model_executor.weight_utils import (default_weight_loader,

vllm/model_executor/models/dbrx.py

+3-4
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,9 @@
55
import torch.nn as nn
66

77
from vllm.attention import Attention, AttentionMetadata
8+
from vllm.distributed import (get_tensor_model_parallel_rank,
9+
get_tensor_model_parallel_world_size,
10+
tensor_model_parallel_all_reduce)
811
from vllm.model_executor.layers.fused_moe import fused_moe
912
from vllm.model_executor.layers.linear import (LinearMethodBase,
1013
QKVParallelLinear,
@@ -15,10 +18,6 @@
1518
from vllm.model_executor.layers.sampler import Sampler
1619
from vllm.model_executor.layers.vocab_parallel_embedding import (
1720
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
18-
from vllm.model_executor.parallel_utils.communication_op import (
19-
tensor_model_parallel_all_reduce)
20-
from vllm.model_executor.parallel_utils.parallel_state import (
21-
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
2221
from vllm.model_executor.sampling_metadata import SamplingMetadata
2322
from vllm.model_executor.utils import set_weight_attrs
2423
from vllm.model_executor.weight_utils import (default_weight_loader,

vllm/model_executor/models/deepseek.py

+3-4
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,9 @@
2828
from transformers import PretrainedConfig
2929

3030
from vllm.attention import Attention, AttentionMetadata
31+
from vllm.distributed import (get_tensor_model_parallel_rank,
32+
get_tensor_model_parallel_world_size,
33+
tensor_model_parallel_all_reduce)
3134
from vllm.model_executor.layers.activation import SiluAndMul
3235
from vllm.model_executor.layers.fused_moe import fused_moe
3336
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -41,10 +44,6 @@
4144
from vllm.model_executor.layers.sampler import Sampler
4245
from vllm.model_executor.layers.vocab_parallel_embedding import (
4346
ParallelLMHead, VocabParallelEmbedding)
44-
from vllm.model_executor.parallel_utils.communication_op import (
45-
tensor_model_parallel_all_reduce)
46-
from vllm.model_executor.parallel_utils.parallel_state import (
47-
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
4847
from vllm.model_executor.sampling_metadata import SamplingMetadata
4948
from vllm.model_executor.weight_utils import (default_weight_loader,
5049
hf_model_weights_iterator)

vllm/model_executor/models/falcon.py

+3-4
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,9 @@
2727
from transformers import FalconConfig as HF_FalconConfig
2828

2929
from vllm.attention import Attention, AttentionMetadata
30+
from vllm.distributed import (get_tensor_model_parallel_rank,
31+
get_tensor_model_parallel_world_size,
32+
tensor_model_parallel_all_reduce)
3033
from vllm.model_executor.layers.activation import get_act_fn
3134
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
3235
LinearMethodBase,
@@ -37,10 +40,6 @@
3740
from vllm.model_executor.layers.sampler import Sampler
3841
from vllm.model_executor.layers.vocab_parallel_embedding import (
3942
VocabParallelEmbedding)
40-
from vllm.model_executor.parallel_utils.communication_op import (
41-
tensor_model_parallel_all_reduce)
42-
from vllm.model_executor.parallel_utils.parallel_state import (
43-
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
4443
from vllm.model_executor.sampling_metadata import SamplingMetadata
4544
from vllm.model_executor.weight_utils import (default_weight_loader,
4645
hf_model_weights_iterator)

0 commit comments

Comments
 (0)