 
 import pytest
 import torch
-from vllm.config import VllmConfig
+from pytest_mock import MockerFixture
+from vllm.config import CacheConfig, VllmConfig
 
 from tests.ut.base import PytestBase
 from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
@@ -15,25 +16,48 @@ class TestTorchairMtpProposer(PytestBase):
     def setup_torchair_mtp_proposer(self, mocker: MockerFixture):
         vllm_config = MagicMock(spec=VllmConfig)
         vllm_config.device_config = MagicMock()
-        vllm_config.device_config.device = torch.device("npu:0")
+        vllm_config.device_config.device = torch.device("cpu")
         vllm_config.speculative_config = MagicMock()
         vllm_config.speculative_config.draft_model_config = MagicMock()
         vllm_config.speculative_config.draft_model_config.dtype = torch.float16
+        # vllm_config.speculative_config.draft_model_config.get_hidden_size = lambda: 4096
+        vllm_config.speculative_config.method = "deepseek_mtp"
+        vllm_config.speculative_config.num_speculative_tokens = 5
+
+        # vllm_config.model_config = MagicMock(
+        #     dtype=torch.float16,
+        #     max_model_len=2048,
+        #     uses_mrope=False,
+        #     hf_config=MagicMock(index_topk=2)
+        # )
         vllm_config.load_config = MagicMock()
-
-        device = torch.device("npu:0")
+        cache_config = CacheConfig(block_size=16)
+        vllm_config.cache_config = cache_config
+        vllm_config.scheduler_config = MagicMock(max_num_batched_tokens=1024,
+                                                 max_num_seqs=64)
+        # vllm_config.compilation_config = MagicMock()
+        # vllm_config.compilation_config.cudagraph_mode = None
+
+        device = torch.device("cpu")
         runner = MagicMock()
+        runner.pcp_size = 1
+        runner.dcp_size = 1
+        runner.pcp_rank = 0
+        runner.max_num_tokens = 1024
+        runner.max_num_reqs = 10
+        runner._use_aclgraph.return_value = True
 
-        mocker.patch("vllm_ascend.torchair_mtp_proposer.__init__",
-                     return_value=None)
+        mocker.patch(
+            "vllm_ascend.torchair.torchair_mtp_proposer.MtpProposer.__init__",
+            return_value=None)
 
         if vllm_version_is("0.11.0"):
             mock_set_default_dtype = mocker.patch(
                 'vllm.model_executor.model_loader.utils.set_default_torch_dtype'
             )
         else:
             mock_set_default_dtype = mocker.patch(
-                'vllm.utls.torch_utils.set_default_torch_dtype')
+                'vllm.utils.torch_utils.set_default_torch_dtype')
         mock_set_default_dtype.return_value.__enter__.return_value = None
 
         mock_model_loader = MagicMock()
@@ -60,17 +84,15 @@ def setup_torchair_mtp_proposer(self, mocker: MockerFixture):
         proposer.vllm_config = vllm_config
         proposer.device = device
         proposer.runner = runner
+        proposer.speculative_config = vllm_config.speculative_config
+        proposer.draft_model_config = vllm_config.speculative_config.draft_model_config
+        proposer.method = vllm_config.speculative_config.method
 
         return proposer, mock_model_loader, mock_torchair_deepseek_mtp
 
     def test_init(self, setup_torchair_mtp_proposer):
         proposer, _, _, = setup_torchair_mtp_proposer
-
-        assert isinstance(proposer, setup_torchair_mtp_proposer)
-        assert proposer.torchair_compiled_model is None
-        Mock.assert_called_once_with(proposer.__class__.__bases__[0],
-                                     proposer.vllm_config, proposer.device,
-                                     proposer.runner)
+        assert isinstance(proposer, TorchairMtpProposer)
 
     def test_load_model(self, setup_torchair_mtp_proposer,
                         mocker: MockerFixture):