Commit f6bee66

worker ut
Signed-off-by: CodeNine-CJ <[email protected]>
1 parent 160495a commit f6bee66

3 files changed: +75 −77 lines changed

Lines changed: 23 additions & 16 deletions

@@ -1,12 +1,13 @@
-import os
-import unittest
-import pytest
-from unittest.mock import Mock, MagicMock, patch
+from unittest.mock import MagicMock, Mock, patch
 
+import pytest
 import torch
-from vllm_ascend.torchair_model_runner import NPUTorchairModelRunner
+from pytest_mock import MockerFixture
 from vllm.config import VllmConfig
 
+from tests.ut.base import PytestBase
+from vllm_ascend.torchair.torchair_model_runner import NPUTorchairModelRunner
+
 
 class TestNPUTorchairModelRunner(PytestBase):
 
@@ -20,7 +21,6 @@ def setup_npu_torchair_model_runner(self, mocker: MockerFixture):
         device = torch.device("npu:0")
 
         ascend_config = MagicMock()
-        ascend_config = enable_shared_expert_dp = False
         ascend_config.max_num_batched_tokens = 2048
         ascend_config.max_model_len = 1024
         ascend_config.torchair_graph_config = MagicMock()
@@ -29,30 +29,37 @@ def setup_npu_torchair_model_runner(self, mocker: MockerFixture):
         ascend_config.torchair_graph_config.graph_batch_sizes = [1, 2, 4]
         ascend_config.torchair_graph_config.graph_batch_sizes_init = True
 
-        mocker.patch("vllm_ascend.worker.model_runner_v1.NPUModelRunner.__init__",
-                     return_value=None)
+        mocker.patch(
+            "vllm_ascend.worker.model_runner_v1.NPUModelRunner.__init__",
+            return_value=None)
 
-        mocker.patch("vllm_ascend.get_ascend_config", return_value=ascend_config)
+        mocker.patch("vllm_ascend.get_ascend_config",
+                     return_value=ascend_config)
         mocker.patch("vllm_ascend.torchair.utils.register_torchair_model")
         mocker.patch("vllm_ascend.torchair.utils.torchair_ops_patch")
-        mocker.patch("vllm_ascend.torchair.utils.torchair_quant_method_register")
-        mocker.patch("vllm_ascend.envs.VLLM_ASCEND_TRACE_RECOMPILES", return_value=False)
+        mocker.patch(
+            "vllm_ascend.torchair.utils.torchair_quant_method_register")
+        mocker.patch("vllm_ascend.envs.VLLM_ASCEND_TRACE_RECOMPILES",
+                     return_value=False)
 
         mock_attn_builder = Mock()
         mock_attn_backend = Mock()
         mock_attn_backend.get_builder_cls.return_value = lambda *args, **kwargs: mock_attn_builder
-        with patch.object(NPUTorchairModelRunner, 'attn_backend', mock_attn_backend):
-            with patch.object(NPUTorchairModelRunner, 'speculative_config', MagicMock()):
+        with patch.object(NPUTorchairModelRunner, 'attn_backend',
+                          mock_attn_backend):
+            with patch.object(NPUTorchairModelRunner, 'speculative_config',
+                              MagicMock()):
                 NPUTorchairModelRunner.decode_token_per_req = 1
                 NPUTorchairModelRunner.max_num_tokens = 10
 
                 runner = NPUTorchairModelRunner(vllm_config, device)
                 runner.vllm_config = vllm_config
                 runner.device = device
                 runner.attn_backend = mock_attn_backend
-
+
         return runner
 
-    def test_init(self, mocker: MockerFixture, setup_npu_torchair_model_runner):
+    def test_init(self, mocker: MockerFixture,
+                  setup_npu_torchair_model_runner):
         runner = setup_npu_torchair_model_runner
-        assert isinstance(runner, NPUTorchairModelRunner)
+        assert isinstance(runner, NPUTorchairModelRunner)
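
The fixture above constructs an NPUTorchairModelRunner without running the expensive NPUModelRunner.__init__: the parent constructor is stubbed out with mocker.patch, and class attributes such as attn_backend are swapped in with patch.object for the duration of construction. A minimal, self-contained sketch of that pattern, using hypothetical Engine/TestEngine names rather than anything from vllm_ascend:

# Standalone sketch of the fixture pattern (hypothetical names, not this repo's code).
from unittest.mock import MagicMock, patch

import pytest
from pytest_mock import MockerFixture


class Engine:  # hypothetical stand-in for the runner class
    attn_backend = None

    def __init__(self, config, device):
        raise RuntimeError("expensive init that unit tests skip")


class TestEngine:

    @pytest.fixture
    def engine(self, mocker: MockerFixture):
        # Replace __init__ with a mock that returns None, so construction is cheap.
        mocker.patch.object(Engine, "__init__", return_value=None)
        fake_backend = MagicMock()
        # Temporarily swap a class attribute while the object is constructed.
        with patch.object(Engine, "attn_backend", fake_backend):
            obj = Engine(MagicMock(), "npu:0")
        obj.attn_backend = fake_backend
        return obj

    def test_init(self, engine):
        assert isinstance(engine, Engine)
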
Lines changed: 43 additions & 45 deletions

@@ -1,15 +1,12 @@
-import pytest
-from unittest.mock import Mock, MagicMock, patch
-from vllm_ascend import torchair_mtp_proposer
-from vllm.config import VllmConfig
+from unittest.mock import MagicMock, Mock
 
+import pytest
 import torch
+from vllm.config import VllmConfig
 
-def vllm_version_is(version):
-    return version == "0.11.0"
-
-import sys
-sys.modules[__name__].vllm_version_is = vllm_version_is
+from tests.ut.base import PytestBase
+from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
+from vllm_ascend.utils import vllm_version_is
 
 
 class TestTorchairMtpProposer(PytestBase):
@@ -29,90 +26,91 @@ def setup_torchair_mtp_proposer(self, mocker: pytest.MockerFixture):
 
         mocker.patch("vllm_ascend.torchair_mtp_proposer.__init__",
                      return_value=None)
-
+
         if vllm_version_is("0.11.0"):
             mock_set_default_dtype = mocker.patch(
                 'vllm.model_executor.model_loader.utils.set_default_torch_dtype'
             )
         else:
             mock_set_default_dtype = mocker.patch(
-                'vllm.utls.torch_utils.set_default_torch_dtype'
-            )
+                'vllm.utls.torch_utils.set_default_torch_dtype')
         mock_set_default_dtype.return_value.__enter__.return_value = None
 
         mock_model_loader = MagicMock()
         mocker.patch("vllm.model_executor.model_loader.get_model_loader",
                      return_value=mock_model_loader)
-        mock_layers = {"target_attn_layer_1": Mock(), "draft_attn_layer_2": Mock()}
+        mock_layers = {
+            "target_attn_layer_1": Mock(),
+            "draft_attn_layer_2": Mock()
+        }
         mocker.patch("vllm.config.get_layers_from_vllm_config",
                      return_value=mock_layers)
         mock_set_current = mocker.patch("vllm.config.set_current_vllm_config")
         mock_set_current.return_value.__enter__.return_value = None
         mock_torchair_deepseek_mtp = MagicMock()
         mock_torchair_deepseek_mtp.to.return_value = mock_torchair_deepseek_mtp
-        mocker.patch("vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP",
-                     return_value=mock_torchair_deepseek_mtp)
-        mocker.patch("vllm.model_executor.model_loader.utils.process_weights_after_loading")
+        mocker.patch(
+            "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP",
+            return_value=mock_torchair_deepseek_mtp)
+        mocker.patch(
+            "vllm.model_executor.model_loader.utils.process_weights_after_loading"
+        )
 
         proposer = TorchairMtpProposer(vllm_config, device, runner)
         proposer.vllm_config = vllm_config
         proposer.device = device
         proposer.runner = runner
 
         return proposer, mock_model_loader, mock_torchair_deepseek_mtp
-
+
     def test_init(self, setup_torchair_mtp_proposer):
         proposer, _, _, = setup_torchair_mtp_proposer
 
         assert isinstance(proposer, setup_torchair_mtp_proposer)
         assert proposer.torchair_compiled_model is None
-        assert proposer.torchair_compiled_models = {}
-        Mock.assert_called_once_with(
-            proposer.__class__.__bases__[0],
-            proposer.vllm_config,
-            proposer.device,
-            proposer.runner
-        )
-
-    def test_load_model(self, setup_torchair_mtp_proposer, mocker: pytest.MockerFixture):
+        Mock.assert_called_once_with(proposer.__class__.__bases__[0],
+                                     proposer.vllm_config, proposer.device,
+                                     proposer.runner)
+
+    def test_load_model(self, setup_torchair_mtp_proposer,
+                        mocker: pytest.MockerFixture):
         proposer, mock_model_loader, mock_torchair_deepseek_mtp = setup_torchair_mtp_proposer
         dummpy_model = Mock()
 
         proposer.load_model(dummpy_model)
 
-        mocker.patch("vllm.model_executor.model_loader.get_model_loader").assert_called_once_with(
-            proposer.vllm_config.load_config
-        )
+        mocker.patch("vllm.model_executor.model_loader.get_model_loader"
+                     ).assert_called_once_with(
+                         proposer.vllm_config.load_config)
 
-        mock_get_layers = mocker.patch("vllm.config.get_layers_from_vllm_config")
-        assert mock_get_layers.call_count = 2
+        mock_get_layers = mocker.patch(
+            "vllm.config.get_layers_from_vllm_config")
         mock_get_layers.assert_called_with(
             proposer.vllm_config,
-            mocker.patch("vllm.model_executor.layers.attention_layer_base.AttentionLayerBase")
-        )
+            mocker.patch(
+                "vllm.model_executor.layers.attention_layer_base.AttentionLayerBase"
+            ))
 
-        mocker.patch("vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP").assert_called_once_with(
-            vllm_config=proposer.vllm_config
-        )
+        mocker.patch(
+            "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP"
+        ).assert_called_once_with(vllm_config=proposer.vllm_config)
         mock_torchair_deepseek_mtp.to.assert_called_once(
-            proposer.vllm_config.device_config.device
-        )
+            proposer.vllm_config.device_config.device)
 
         assert len(proposer.attn_layer_name) == 1
         mocker_layers_keys = mock_get_layers.return_value.keys()
         assert proposer.attn_layer_name[0] in mocker_layers_keys
 
         mock_model_loader.get_all_weights.assert_called_once_with(
             proposer.vllm_config.speculative_config.draft_model_config,
-            mock_torchair_deepseek_mtp
-        )
+            mock_torchair_deepseek_mtp)
         mock_torchair_deepseek_mtp.load_weights.assert_called_once_with(
-            mock_model_loader.get_all_weights.return_value
-        )
+            mock_model_loader.get_all_weights.return_value)
 
-        mock_process_weights = mocker.patch("vllm.model_executor.model_loader.utils.process_weights_after_loading")
+        mock_process_weights = mocker.patch(
+            "vllm.model_executor.model_loader.utils.process_weights_after_loading"
+        )
         mock_process_weights.assert_called_once_with(
             mock_torchair_deepseek_mtp,
             proposer.vllm_config.speculative_config.draft_model_config,
-            proposer.vllm_config.device_config.device
-        )
+            proposer.vllm_config.device_config.device)
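
test_load_model above exercises proposer.load_model() and then verifies the patched collaborators (get_model_loader, get_layers_from_vllm_config, TorchairDeepSeekMTP, process_weights_after_loading) through mock assertions. A minimal, self-contained sketch of that patch-then-assert flow, with hypothetical get_model_loader/load_model stand-ins rather than vllm code; the handle returned by mocker.patch() is captured before the code under test runs, because a mock created afterwards starts with an empty call history:

# Standalone sketch of patching a dependency and asserting on its calls
# (hypothetical names; not part of this commit).
import sys

from pytest_mock import MockerFixture


def get_model_loader(load_config):  # hypothetical dependency
    raise RuntimeError("real loader must not run in unit tests")


def load_model(load_config, model_config):  # hypothetical code under test
    loader = get_model_loader(load_config)
    return loader.get_all_weights(model_config)


def test_load_model(mocker: MockerFixture):
    # Patch the dependency in this module and keep a handle on the mock.
    mock_get_loader = mocker.patch.object(sys.modules[__name__],
                                          "get_model_loader")
    load_config, model_config = object(), object()

    load_model(load_config, model_config)

    # Verify the interaction through the handle captured above.
    mock_get_loader.assert_called_once_with(load_config)
    mock_get_loader.return_value.get_all_weights.assert_called_once_with(
        model_config)
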

tests/ut/torchair/test_torchair_worker.py

Lines changed: 9 additions & 16 deletions

@@ -1,5 +1,3 @@
-import os
-import unittest
 from unittest.mock import MagicMock, patch
 
 import torch
@@ -69,24 +67,24 @@ def test_init_device(self, mock_platform, mock_init_dist_env):
 
         mock_platform.empty_cache.assert_called_once()
         mock_platform.seed_everything.assert_called_once_with(42)
-        mock_platform.mem_get_info.assert_called_once(
-        )
-        mock_init_dist_env.assert_called_once(
-        )
+        mock_platform.mem_get_info.assert_called_once()
+        mock_init_dist_env.assert_called_once()
 
         self.assertEqual(str(result), "npu:1")
         self.assertEqual(worker.init_npu_memory, 1000)
-
+
     @patch(
         "vllm_ascend.worker.worker_v1.NPUWorker._init_worker_distributed_environment"
     )
     @patch("vllm_ascend.worker.worker_v1.NPUPlatform")
-    def test_init_device_torchair_worker(self, mock_platform, mock_init_dist_env):
+    def test_init_device_torchair_worker(self, mock_platform,
+                                         mock_init_dist_env):
        from vllm_ascend.torchair.torchair_worker import NPUTorchairWorker
 
         mock_platform.mem_get_info.return_value = (1000, 2000)
 
-        with patch.object(NPUTorchairWorker, "__init__", lambda x, **kwargs: None):
+        with patch.object(NPUTorchairWorker, "__init__",
+                          lambda x, **kwargs: None):
             worker = NPUTorchairWorker
             worker.local_rank = 1
             worker.model_config = MagicMock()
@@ -100,13 +98,8 @@ def test_init_device_torchair_worker(self, mock_platform, mock_init_dist_env):
 
         mock_platform.empty_cache.assert_called_once()
         mock_platform.seed_everything.assert_called_once_with(42)
-        mock_platform.mem_get_info.assert_called_once(
-        )
-        mock_init_dist_env.assert_called_once(
-        )
+        mock_platform.mem_get_info.assert_called_once()
+        mock_init_dist_env.assert_called_once()
 
         self.assertEqual(str(result), "npu:1")
         self.assertEqual(worker.init_npu_memory, 1000)
-
-
-