
Commit 160495a

worker ut test
Signed-off-by: CodeNine-CJ <[email protected]>
1 parent 9cc4222 commit 160495a

File tree

3 files changed: +288 -0 lines changed
Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
import pytest
from pytest_mock import MockerFixture
from unittest.mock import Mock, MagicMock, patch

import torch
from vllm.config import VllmConfig

from tests.ut.base import PytestBase
from vllm_ascend.torchair_model_runner import NPUTorchairModelRunner


class TestNPUTorchairModelRunner(PytestBase):

    @pytest.fixture
    def setup_npu_torchair_model_runner(self, mocker: MockerFixture):
        vllm_config = MagicMock(spec=VllmConfig)
        vllm_config.model_config = MagicMock()
        vllm_config.model_config.hf_config = MagicMock()
        vllm_config.model_config.hf_config.index_topk = 2

        device = torch.device("npu:0")

        ascend_config = MagicMock()
        ascend_config.enable_shared_expert_dp = False
        ascend_config.max_num_batched_tokens = 2048
        ascend_config.max_model_len = 1024
        ascend_config.torchair_graph_config = MagicMock()
        ascend_config.torchair_graph_config.use_cached_graph = True
        ascend_config.torchair_graph_config.use_cached_kv_cache_bytes = False
        ascend_config.torchair_graph_config.graph_batch_sizes = [1, 2, 4]
        ascend_config.torchair_graph_config.graph_batch_sizes_init = True

        # Stub the parent __init__ so the runner can be built without a real NPU.
        mocker.patch(
            "vllm_ascend.worker.model_runner_v1.NPUModelRunner.__init__",
            return_value=None)

        mocker.patch("vllm_ascend.get_ascend_config",
                     return_value=ascend_config)
        mocker.patch("vllm_ascend.torchair.utils.register_torchair_model")
        mocker.patch("vllm_ascend.torchair.utils.torchair_ops_patch")
        mocker.patch("vllm_ascend.torchair.utils.torchair_quant_method_register")
        mocker.patch("vllm_ascend.envs.VLLM_ASCEND_TRACE_RECOMPILES", False)

        mock_attn_builder = Mock()
        mock_attn_backend = Mock()
        mock_attn_backend.get_builder_cls.return_value = \
            lambda *args, **kwargs: mock_attn_builder
        with patch.object(NPUTorchairModelRunner, 'attn_backend',
                          mock_attn_backend):
            with patch.object(NPUTorchairModelRunner, 'speculative_config',
                              MagicMock()):
                NPUTorchairModelRunner.decode_token_per_req = 1
                NPUTorchairModelRunner.max_num_tokens = 10

                runner = NPUTorchairModelRunner(vllm_config, device)
                runner.vllm_config = vllm_config
                runner.device = device
                runner.attn_backend = mock_attn_backend

        return runner

    def test_init(self, setup_npu_torchair_model_runner):
        runner = setup_npu_torchair_model_runner
        assert isinstance(runner, NPUTorchairModelRunner)
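Note: the fixture above hinges on one pattern, stubbing the heavyweight parent `__init__` so the subclass can be constructed off-device. A minimal, self-contained sketch of that pattern, using only `unittest.mock` and illustrative `Base`/`Child` names (no vLLM or NPU dependencies):

from unittest.mock import patch

class Base:
    def __init__(self, config):
        raise RuntimeError("expensive device setup we want to skip")

class Child(Base):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

# Patch the parent __init__ so only Child's own construction logic runs.
with patch.object(Base, "__init__", return_value=None) as mock_init:
    child = Child({"max_model_len": 1024})

mock_init.assert_called_once_with({"max_model_len": 1024})
assert child.config == {"max_model_len": 1024}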
Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
import pytest
from pytest_mock import MockerFixture
from unittest.mock import Mock, MagicMock

import torch
from vllm.config import VllmConfig

from tests.ut.base import PytestBase
from vllm_ascend.torchair_mtp_proposer import TorchairMtpProposer
from vllm_ascend.utils import vllm_version_is


class TestTorchairMtpProposer(PytestBase):

    @pytest.fixture
    def setup_torchair_mtp_proposer(self, mocker: MockerFixture):
        vllm_config = MagicMock(spec=VllmConfig)
        vllm_config.device_config = MagicMock()
        vllm_config.device_config.device = torch.device("npu:0")
        vllm_config.speculative_config = MagicMock()
        vllm_config.speculative_config.draft_model_config = MagicMock()
        vllm_config.speculative_config.draft_model_config.dtype = torch.float16
        vllm_config.load_config = MagicMock()

        device = torch.device("npu:0")
        runner = MagicMock()

        # Stub the base-class __init__ so only TorchairMtpProposer's own
        # construction logic runs.
        mock_parent_init = mocker.patch.object(
            TorchairMtpProposer.__bases__[0], "__init__", return_value=None)

        # set_default_torch_dtype lives in a different module after vLLM 0.11.0.
        if vllm_version_is("0.11.0"):
            mock_set_default_dtype = mocker.patch(
                "vllm.model_executor.model_loader.utils.set_default_torch_dtype")
        else:
            mock_set_default_dtype = mocker.patch(
                "vllm.utils.torch_utils.set_default_torch_dtype")
        mock_set_default_dtype.return_value.__enter__.return_value = None

        mock_model_loader = MagicMock()
        mock_get_model_loader = mocker.patch(
            "vllm.model_executor.model_loader.get_model_loader",
            return_value=mock_model_loader)
        mock_layers = {
            "target_attn_layer_1": Mock(),
            "draft_attn_layer_2": Mock()
        }
        mock_get_layers = mocker.patch(
            "vllm.config.get_layers_from_vllm_config",
            return_value=mock_layers)
        mock_set_current = mocker.patch("vllm.config.set_current_vllm_config")
        mock_set_current.return_value.__enter__.return_value = None
        mock_torchair_deepseek_mtp = MagicMock()
        mock_torchair_deepseek_mtp.to.return_value = mock_torchair_deepseek_mtp
        mock_mtp_cls = mocker.patch(
            "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP",
            return_value=mock_torchair_deepseek_mtp)
        mock_process_weights = mocker.patch(
            "vllm.model_executor.model_loader.utils.process_weights_after_loading")

        proposer = TorchairMtpProposer(vllm_config, device, runner)
        proposer.vllm_config = vllm_config
        proposer.device = device
        proposer.runner = runner

        # Return the mock handles too, so tests assert on the patches that
        # were actually active during construction and load_model().
        return {
            "proposer": proposer,
            "parent_init": mock_parent_init,
            "get_model_loader": mock_get_model_loader,
            "model_loader": mock_model_loader,
            "get_layers": mock_get_layers,
            "mtp_cls": mock_mtp_cls,
            "mtp_model": mock_torchair_deepseek_mtp,
            "process_weights": mock_process_weights,
        }

    def test_init(self, setup_torchair_mtp_proposer):
        mocks = setup_torchair_mtp_proposer
        proposer = mocks["proposer"]

        assert isinstance(proposer, TorchairMtpProposer)
        assert proposer.torchair_compiled_model is None
        assert proposer.torchair_compiled_models == {}
        mocks["parent_init"].assert_called_once_with(
            proposer.vllm_config, proposer.device, proposer.runner)

    def test_load_model(self, setup_torchair_mtp_proposer):
        mocks = setup_torchair_mtp_proposer
        proposer = mocks["proposer"]
        dummy_model = Mock()

        proposer.load_model(dummy_model)

        mocks["get_model_loader"].assert_called_once_with(
            proposer.vllm_config.load_config)

        from vllm.model_executor.layers.attention_layer_base import \
            AttentionLayerBase
        assert mocks["get_layers"].call_count == 2
        mocks["get_layers"].assert_called_with(proposer.vllm_config,
                                               AttentionLayerBase)

        mocks["mtp_cls"].assert_called_once_with(
            vllm_config=proposer.vllm_config)
        mocks["mtp_model"].to.assert_called_once_with(
            proposer.vllm_config.device_config.device)

        assert len(proposer.attn_layer_name) == 1
        assert proposer.attn_layer_name[0] in mocks["get_layers"].return_value

        mocks["model_loader"].get_all_weights.assert_called_once_with(
            proposer.vllm_config.speculative_config.draft_model_config,
            mocks["mtp_model"])
        mocks["mtp_model"].load_weights.assert_called_once_with(
            mocks["model_loader"].get_all_weights.return_value)

        mocks["process_weights"].assert_called_once_with(
            mocks["mtp_model"],
            proposer.vllm_config.speculative_config.draft_model_config,
            proposer.vllm_config.device_config.device)
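The fixture above also has to fake a context manager: `set_default_torch_dtype` is entered in a `with` block inside `load_model`, so the mock's `return_value.__enter__` must be wired up. A dependency-free sketch of that wiring (the factory name is illustrative):

from unittest.mock import MagicMock

# A mocked "context-manager factory", standing in for set_default_torch_dtype.
mock_cm_factory = MagicMock()
mock_cm_factory.return_value.__enter__.return_value = None

# Code under test would do: with set_default_torch_dtype(dtype): ...
with mock_cm_factory("float16"):
    pass  # the body runs against the mocked context manager

mock_cm_factory.assert_called_once_with("float16")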
Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
from unittest.mock import MagicMock, patch

import torch
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig

from tests.ut.base import TestBase
from vllm_ascend.utils import vllm_version_is

init_cache_hf_modules_path = "vllm.utils.init_cached_hf_modules" if vllm_version_is(
    "0.11.0") else "vllm.utils.import_utils.init_cached_hf_modules"


class TestNPUTorchairWorker(TestBase):

    def setUp(self):
        self.cache_config_mock = MagicMock(spec=CacheConfig)
        self.cache_config_mock.cache_dtype = "auto"

        self.model_config_mock = MagicMock(spec=ModelConfig)
        self.model_config_mock.dtype = torch.float16
        self.model_config_mock.trust_remote_code = False

        self.hf_config_mock = MagicMock()
        self.hf_config_mock.model_type = "test_model"
        if hasattr(self.hf_config_mock, 'index_topk'):
            delattr(self.hf_config_mock, 'index_topk')

        self.model_config_mock.hf_config = self.hf_config_mock

        self.parallel_config_mock = MagicMock(spec=ParallelConfig)

        self.vllm_config_mock = MagicMock(spec=VllmConfig)
        self.vllm_config_mock.cache_config = self.cache_config_mock
        self.vllm_config_mock.model_config = self.model_config_mock
        self.vllm_config_mock.parallel_config = self.parallel_config_mock
        self.vllm_config_mock.additional_config = None
        self.vllm_config_mock.load_config = None
        self.vllm_config_mock.scheduler_config = None
        self.vllm_config_mock.device_config = None
        self.vllm_config_mock.compilation_config = None

        self.local_rank = 0
        self.rank = 0
        self.distributed_init_method = "tcp://localhost:12345"
        self.is_driver_worker = False

    @patch(
        "vllm_ascend.worker.worker_v1.NPUWorker._init_worker_distributed_environment"
    )
    @patch("vllm_ascend.worker.worker_v1.NPUPlatform")
    def test_init_device(self, mock_platform, mock_init_dist_env):
        from vllm_ascend.worker.worker_v1 import NPUWorker

        mock_platform.mem_get_info.return_value = (1000, 2000)

        # Bypass __init__ so the worker can be built without NPU hardware.
        with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None):
            worker = NPUWorker()
            worker.local_rank = 1
            worker.model_config = MagicMock()
            worker.model_config.seed = 42

            result = worker._init_device()

            mock_platform.set_device.assert_called_once()
            call_args = mock_platform.set_device.call_args[0][0]
            self.assertEqual(str(call_args), "npu:1")

            mock_platform.empty_cache.assert_called_once()
            mock_platform.seed_everything.assert_called_once_with(42)
            mock_platform.mem_get_info.assert_called_once()
            mock_init_dist_env.assert_called_once()

            self.assertEqual(str(result), "npu:1")
            self.assertEqual(worker.init_npu_memory, 1000)

    @patch(
        "vllm_ascend.worker.worker_v1.NPUWorker._init_worker_distributed_environment"
    )
    @patch("vllm_ascend.worker.worker_v1.NPUPlatform")
    def test_init_device_torchair_worker(self, mock_platform,
                                         mock_init_dist_env):
        from vllm_ascend.torchair.torchair_worker import NPUTorchairWorker

        mock_platform.mem_get_info.return_value = (1000, 2000)

        # NPUTorchairWorker should share NPUWorker's device-init behavior.
        with patch.object(NPUTorchairWorker, "__init__",
                          lambda x, **kwargs: None):
            worker = NPUTorchairWorker()
            worker.local_rank = 1
            worker.model_config = MagicMock()
            worker.model_config.seed = 42

            result = worker._init_device()

            mock_platform.set_device.assert_called_once()
            call_args = mock_platform.set_device.call_args[0][0]
            self.assertEqual(str(call_args), "npu:1")

            mock_platform.empty_cache.assert_called_once()
            mock_platform.seed_everything.assert_called_once_with(42)
            mock_platform.mem_get_info.assert_called_once()
            mock_init_dist_env.assert_called_once()

            self.assertEqual(str(result), "npu:1")
            self.assertEqual(worker.init_npu_memory, 1000)
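One detail worth calling out in `setUp`: `MagicMock(spec=...)` rejects attribute assignments that are not part of the spec'd class, which is why the cache config mock must use CacheConfig's real `cache_dtype` field name. A small standalone illustration (`CacheConfigLike` is a stand-in, not vLLM's class):

from unittest.mock import MagicMock

class CacheConfigLike:
    cache_dtype: str = "auto"  # class-level default, visible to the spec

cfg = MagicMock(spec=CacheConfigLike)
cfg.cache_dtype = "bfloat16"  # allowed: the attribute exists on the spec

try:
    cfg.cache_type = "auto"  # misspelled field -> AttributeError
except AttributeError as exc:
    print("rejected:", exc)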
