
Commit ab75b53

worker ut
Signed-off-by: CodeNine-CJ <[email protected]>
1 parent 160495a commit ab75b53

File tree

3 files changed: +76 additions, -71 deletions

Lines changed: 23 additions & 16 deletions
@@ -1,12 +1,13 @@
-import os
-import unittest
-import pytest
-from unittest.mock import Mock, MagicMock, patch
+from unittest.mock import MagicMock, Mock, patch
 
+import pytest
 import torch
-from vllm_ascend.torchair_model_runner import NPUTorchairModelRunner
+from pytest_mock import MockerFixture
 from vllm.config import VllmConfig
 
+from tests.ut.base import PytestBase
+from vllm_ascend.torchair.torchair_model_runner import NPUTorchairModelRunner
+
 
 class TestNPUTorchairModelRunner(PytestBase):
 
@@ -20,7 +21,6 @@ def setup_npu_torchair_model_runner(self, mocker: MockerFixture):
         device = torch.device("npu:0")
 
         ascend_config = MagicMock()
-        ascend_config = enable_shared_expert_dp = False
         ascend_config.max_num_batched_tokens = 2048
         ascend_config.max_model_len = 1024
         ascend_config.torchair_graph_config = MagicMock()
@@ -29,30 +29,37 @@ def setup_npu_torchair_model_runner(self, mocker: MockerFixture):
         ascend_config.torchair_graph_config.graph_batch_sizes = [1, 2, 4]
         ascend_config.torchair_graph_config.graph_batch_sizes_init = True
 
-        mocker.patch("vllm_ascend.worker.model_runner_v1.NPUModelRunner.__init__",
-                     return_value=None)
+        mocker.patch(
+            "vllm_ascend.worker.model_runner_v1.NPUModelRunner.__init__",
+            return_value=None)
 
-        mocker.patch("vllm_ascend.get_ascend_config", return_value=ascend_config)
+        mocker.patch("vllm_ascend.get_ascend_config",
+                     return_value=ascend_config)
         mocker.patch("vllm_ascend.torchair.utils.register_torchair_model")
         mocker.patch("vllm_ascend.torchair.utils.torchair_ops_patch")
-        mocker.patch("vllm_ascend.torchair.utils.torchair_quant_method_register")
-        mocker.patch("vllm_ascend.envs.VLLM_ASCEND_TRACE_RECOMPILES", return_value=False)
+        mocker.patch(
+            "vllm_ascend.torchair.utils.torchair_quant_method_register")
+        mocker.patch("vllm_ascend.envs.VLLM_ASCEND_TRACE_RECOMPILES",
+                     return_value=False)
 
         mock_attn_builder = Mock()
         mock_attn_backend = Mock()
         mock_attn_backend.get_builder_cls.return_value = lambda *args, **kwargs: mock_attn_builder
-        with patch.object(NPUTorchairModelRunner, 'attn_backend', mock_attn_backend):
-            with patch.object(NPUTorchairModelRunner, 'speculative_config', MagicMock()):
+        with patch.object(NPUTorchairModelRunner, 'attn_backend',
+                          mock_attn_backend):
+            with patch.object(NPUTorchairModelRunner, 'speculative_config',
+                              MagicMock()):
                 NPUTorchairModelRunner.decode_token_per_req = 1
                 NPUTorchairModelRunner.max_num_tokens = 10
 
                 runner = NPUTorchairModelRunner(vllm_config, device)
                 runner.vllm_config = vllm_config
                 runner.device = device
                 runner.attn_backend = mock_attn_backend
-
+
         return runner
 
-    def test_init(self, mocker: MockerFixture, setup_npu_torchair_model_runner):
+    def test_init(self, mocker: MockerFixture,
+                  setup_npu_torchair_model_runner):
         runner = setup_npu_torchair_model_runner
-        assert isinstance(runner, NPUTorchairModelRunner)
+        assert isinstance(runner, NPUTorchairModelRunner)
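
The fixture above relies on a common pytest-mock pattern: bypass an expensive constructor with mocker.patch(..., return_value=None), temporarily override class attributes with patch.object, then build the object and assign only the fields the test needs. A minimal, self-contained sketch of that pattern follows; the Engine class and test names are hypothetical illustrations, not part of vllm-ascend.

from unittest.mock import MagicMock, patch

import pytest
from pytest_mock import MockerFixture


class Engine:
    """Hypothetical stand-in for a runner whose real __init__ touches hardware."""

    backend = None

    def __init__(self, config, device):
        raise RuntimeError("real init is too heavy for a unit test")


@pytest.fixture
def engine(mocker: MockerFixture):
    # Skip the heavy constructor entirely, mirroring how the fixture above
    # patches NPUModelRunner.__init__ to return None.
    mocker.patch.object(Engine, "__init__", return_value=None)
    mock_backend = MagicMock()
    # Override a class attribute only for the duration of construction.
    with patch.object(Engine, "backend", mock_backend):
        obj = Engine(MagicMock(), "npu:0")
    # Set just the instance state the tests care about.
    obj.backend = mock_backend
    return obj


def test_engine(engine):
    assert isinstance(engine, Engine)
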
Lines changed: 44 additions & 39 deletions
@@ -1,14 +1,18 @@
+from unittest.mock import MagicMock, Mock
+
 import pytest
-from unittest.mock import Mock, MagicMock, patch
-from vllm_ascend import torchair_mtp_proposer
+import torch
 from vllm.config import VllmConfig
+from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
+from tests.ut.base import PytestBase
 
-import torch
 
 def vllm_version_is(version):
     return version == "0.11.0"
 
+
 import sys
+
 sys.modules[__name__].vllm_version_is = vllm_version_is
 
 
@@ -29,90 +33,91 @@ def setup_torchair_mtp_proposer(self, mocker: pytest.MockerFixture):
 
         mocker.patch("vllm_ascend.torchair_mtp_proposer.__init__",
                      return_value=None)
-
+
         if vllm_version_is("0.11.0"):
             mock_set_default_dtype = mocker.patch(
                 'vllm.model_executor.model_loader.utils.set_default_torch_dtype'
             )
         else:
             mock_set_default_dtype = mocker.patch(
-                'vllm.utls.torch_utils.set_default_torch_dtype'
-            )
+                'vllm.utls.torch_utils.set_default_torch_dtype')
         mock_set_default_dtype.return_value.__enter__.return_value = None
 
         mock_model_loader = MagicMock()
         mocker.patch("vllm.model_executor.model_loader.get_model_loader",
                      return_value=mock_model_loader)
-        mock_layers = {"target_attn_layer_1": Mock(), "draft_attn_layer_2": Mock()}
+        mock_layers = {
+            "target_attn_layer_1": Mock(),
+            "draft_attn_layer_2": Mock()
+        }
         mocker.patch("vllm.config.get_layers_from_vllm_config",
                      return_value=mock_layers)
         mock_set_current = mocker.patch("vllm.config.set_current_vllm_config")
         mock_set_current.return_value.__enter__.return_value = None
         mock_torchair_deepseek_mtp = MagicMock()
         mock_torchair_deepseek_mtp.to.return_value = mock_torchair_deepseek_mtp
-        mocker.patch("vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP",
-                     return_value=mock_torchair_deepseek_mtp)
-        mocker.patch("vllm.model_executor.model_loader.utils.process_weights_after_loading")
+        mocker.patch(
+            "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP",
+            return_value=mock_torchair_deepseek_mtp)
+        mocker.patch(
+            "vllm.model_executor.model_loader.utils.process_weights_after_loading"
+        )
 
         proposer = TorchairMtpProposer(vllm_config, device, runner)
         proposer.vllm_config = vllm_config
         proposer.device = device
         proposer.runner = runner
 
         return proposer, mock_model_loader, mock_torchair_deepseek_mtp
-
+
     def test_init(self, setup_torchair_mtp_proposer):
         proposer, _, _, = setup_torchair_mtp_proposer
 
         assert isinstance(proposer, setup_torchair_mtp_proposer)
         assert proposer.torchair_compiled_model is None
-        assert proposer.torchair_compiled_models = {}
-        Mock.assert_called_once_with(
-            proposer.__class__.__bases__[0],
-            proposer.vllm_config,
-            proposer.device,
-            proposer.runner
-        )
-
-    def test_load_model(self, setup_torchair_mtp_proposer, mocker: pytest.MockerFixture):
+        Mock.assert_called_once_with(proposer.__class__.__bases__[0],
+                                     proposer.vllm_config, proposer.device,
+                                     proposer.runner)
+
+    def test_load_model(self, setup_torchair_mtp_proposer,
+                        mocker: pytest.MockerFixture):
         proposer, mock_model_loader, mock_torchair_deepseek_mtp = setup_torchair_mtp_proposer
         dummpy_model = Mock()
 
         proposer.load_model(dummpy_model)
 
-        mocker.patch("vllm.model_executor.model_loader.get_model_loader").assert_called_once_with(
-            proposer.vllm_config.load_config
-        )
+        mocker.patch("vllm.model_executor.model_loader.get_model_loader"
                     ).assert_called_once_with(
                         proposer.vllm_config.load_config)
 
-        mock_get_layers = mocker.patch("vllm.config.get_layers_from_vllm_config")
-        assert mock_get_layers.call_count = 2
+        mock_get_layers = mocker.patch(
            "vllm.config.get_layers_from_vllm_config")
         mock_get_layers.assert_called_with(
             proposer.vllm_config,
-            mocker.patch("vllm.model_executor.layers.attention_layer_base.AttentionLayerBase")
-        )
+            mocker.patch(
                "vllm.model_executor.layers.attention_layer_base.AttentionLayerBase"
            ))
 
-        mocker.patch("vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP").assert_called_once_with(
-            vllm_config=proposer.vllm_config
-        )
+        mocker.patch(
            "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP"
        ).assert_called_once_with(vllm_config=proposer.vllm_config)
         mock_torchair_deepseek_mtp.to.assert_called_once(
-            proposer.vllm_config.device_config.device
-        )
+            proposer.vllm_config.device_config.device)
 
         assert len(proposer.attn_layer_name) == 1
         mocker_layers_keys = mock_get_layers.return_value.keys()
         assert proposer.attn_layer_name[0] in mocker_layers_keys
 
         mock_model_loader.get_all_weights.assert_called_once_with(
             proposer.vllm_config.speculative_config.draft_model_config,
-            mock_torchair_deepseek_mtp
-        )
+            mock_torchair_deepseek_mtp)
         mock_torchair_deepseek_mtp.load_weights.assert_called_once_with(
-            mock_model_loader.get_all_weights.return_value
-        )
+            mock_model_loader.get_all_weights.return_value)
 
-        mock_process_weights = mocker.patch("vllm.model_executor.model_loader.utils.process_weights_after_loading")
+        mock_process_weights = mocker.patch(
            "vllm.model_executor.model_loader.utils.process_weights_after_loading"
        )
         mock_process_weights.assert_called_once_with(
             mock_torchair_deepseek_mtp,
             proposer.vllm_config.speculative_config.draft_model_config,
-            proposer.vllm_config.device_config.device
-        )
+            proposer.vllm_config.device_config.device)
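
One detail in the fixture above is worth spelling out: assignments such as mock_set_default_dtype.return_value.__enter__.return_value = None configure a patched callable for use as a context manager. Calling the patched function yields .return_value, and entering the `with` block yields .__enter__.return_value. A short sketch of that chain, using hypothetical names (make_session, use_session) rather than anything from vllm or vllm-ascend:

from unittest.mock import MagicMock

# Hypothetical stand-in for a patched factory used as `with make_session(...):`.
make_session = MagicMock()
make_session.return_value.__enter__.return_value = "session-object"


def use_session():
    # Mirrors how production code would consume the (patched) context manager.
    with make_session("npu:0") as session:
        return session


assert use_session() == "session-object"
make_session.assert_called_once_with("npu:0")
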

tests/ut/torchair/test_torchair_worker.py

Lines changed: 9 additions & 16 deletions
@@ -1,5 +1,3 @@
-import os
-import unittest
 from unittest.mock import MagicMock, patch
 
 import torch
@@ -69,24 +67,24 @@ def test_init_device(self, mock_platform, mock_init_dist_env):
 
         mock_platform.empty_cache.assert_called_once()
         mock_platform.seed_everything.assert_called_once_with(42)
-        mock_platform.mem_get_info.assert_called_once(
-        )
-        mock_init_dist_env.assert_called_once(
-        )
+        mock_platform.mem_get_info.assert_called_once()
+        mock_init_dist_env.assert_called_once()
 
         self.assertEqual(str(result), "npu:1")
         self.assertEqual(worker.init_npu_memory, 1000)
-
+
     @patch(
         "vllm_ascend.worker.worker_v1.NPUWorker._init_worker_distributed_environment"
     )
     @patch("vllm_ascend.worker.worker_v1.NPUPlatform")
-    def test_init_device_torchair_worker(self, mock_platform, mock_init_dist_env):
+    def test_init_device_torchair_worker(self, mock_platform,
+                                         mock_init_dist_env):
         from vllm_ascend.torchair.torchair_worker import NPUTorchairWorker
 
         mock_platform.mem_get_info.return_value = (1000, 2000)
 
-        with patch.object(NPUTorchairWorker, "__init__", lambda x, **kwargs: None):
+        with patch.object(NPUTorchairWorker, "__init__",
+                          lambda x, **kwargs: None):
             worker = NPUTorchairWorker
             worker.local_rank = 1
             worker.model_config = MagicMock()
@@ -100,13 +98,8 @@ def test_init_device_torchair_worker(self, mock_platform, mock_init_dist_env):
 
         mock_platform.empty_cache.assert_called_once()
         mock_platform.seed_everything.assert_called_once_with(42)
-        mock_platform.mem_get_info.assert_called_once(
-        )
-        mock_init_dist_env.assert_called_once(
-        )
+        mock_platform.mem_get_info.assert_called_once()
+        mock_init_dist_env.assert_called_once()
 
         self.assertEqual(str(result), "npu:1")
         self.assertEqual(worker.init_npu_memory, 1000)
-
-
-
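
The worker tests above follow a stub-then-verify shape: set a return value on the mocked platform (mem_get_info returning a (free, total) tuple), exercise the code, then check interactions with assert_called_once()/assert_called_once_with(). A compact sketch of that flow, with hypothetical names (platform, init_device) rather than the real NPUWorker API:

from unittest.mock import MagicMock

# Hypothetical platform object standing in for the patched NPUPlatform.
platform = MagicMock()
platform.mem_get_info.return_value = (1000, 2000)


def init_device(platform, seed):
    # Mirrors the order of operations the tests above verify on the worker.
    platform.empty_cache()
    platform.seed_everything(seed)
    free, _total = platform.mem_get_info()
    return free


free = init_device(platform, 42)
assert free == 1000
platform.empty_cache.assert_called_once()
platform.seed_everything.assert_called_once_with(42)
platform.mem_get_info.assert_called_once()
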
