1- import pytest
2- from unittest .mock import Mock , MagicMock , patch
3- from vllm_ascend import torchair_mtp_proposer
4- from vllm .config import VllmConfig
1+ from unittest .mock import MagicMock , Mock
52
3+ import pytest
64import torch
5+ from vllm .config import VllmConfig
76
8- def vllm_version_is (version ):
9- return version == "0.11.0"
10-
11- import sys
12- sys .modules [__name__ ].vllm_version_is = vllm_version_is
7+ from tests .ut .base import PytestBase
8+ from vllm_ascend .torchair .torchair_mtp_proposer import TorchairMtpProposer
9+ from vllm_ascend .utils import vllm_version_is
1310
1411
1512class TestTorchairMtpProposer (PytestBase ):
@@ -29,90 +26,91 @@ def setup_torchair_mtp_proposer(self, mocker: pytest.MockerFixture):
2926
3027 mocker .patch ("vllm_ascend.torchair_mtp_proposer.__init__" ,
3128 return_value = None )
32-
29+
3330 if vllm_version_is ("0.11.0" ):
3431 mock_set_default_dtype = mocker .patch (
3532 'vllm.model_executor.model_loader.utils.set_default_torch_dtype'
3633 )
3734 else :
3835 mock_set_default_dtype = mocker .patch (
39- 'vllm.utls.torch_utils.set_default_torch_dtype'
40- )
36+ 'vllm.utls.torch_utils.set_default_torch_dtype' )
4137 mock_set_default_dtype .return_value .__enter__ .return_value = None
4238
4339 mock_model_loader = MagicMock ()
4440 mocker .patch ("vllm.model_executor.model_loader.get_model_loader" ,
4541 return_value = mock_model_loader )
46- mock_layers = {"target_attn_layer_1" : Mock (), "draft_attn_layer_2" : Mock ()}
42+ mock_layers = {
43+ "target_attn_layer_1" : Mock (),
44+ "draft_attn_layer_2" : Mock ()
45+ }
4746 mocker .patch ("vllm.config.get_layers_from_vllm_config" ,
4847 return_value = mock_layers )
4948 mock_set_current = mocker .patch ("vllm.config.set_current_vllm_config" )
5049 mock_set_current .return_value .__enter__ .return_value = None
5150 mock_torchair_deepseek_mtp = MagicMock ()
5251 mock_torchair_deepseek_mtp .to .return_value = mock_torchair_deepseek_mtp
53- mocker .patch ("vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP" ,
54- return_value = mock_torchair_deepseek_mtp )
55- mocker .patch ("vllm.model_executor.model_loader.utils.process_weights_after_loading" )
52+ mocker .patch (
53+ "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP" ,
54+ return_value = mock_torchair_deepseek_mtp )
55+ mocker .patch (
56+ "vllm.model_executor.model_loader.utils.process_weights_after_loading"
57+ )
5658
5759 proposer = TorchairMtpProposer (vllm_config , device , runner )
5860 proposer .vllm_config = vllm_config
5961 proposer .device = device
6062 proposer .runner = runner
6163
6264 return proposer , mock_model_loader , mock_torchair_deepseek_mtp
63-
65+
6466 def test_init (self , setup_torchair_mtp_proposer ):
6567 proposer , _ , _ , = setup_torchair_mtp_proposer
6668
6769 assert isinstance (proposer , setup_torchair_mtp_proposer )
6870 assert proposer .torchair_compiled_model is None
69- assert proposer .torchair_compiled_models = {}
70- Mock .assert_called_once_with (
71- proposer .__class__ .__bases__ [0 ],
72- proposer .vllm_config ,
73- proposer .device ,
74- proposer .runner
75- )
76-
77- def test_load_model (self , setup_torchair_mtp_proposer , mocker : pytest .MockerFixture ):
71+ Mock .assert_called_once_with (proposer .__class__ .__bases__ [0 ],
72+ proposer .vllm_config , proposer .device ,
73+ proposer .runner )
74+
75+ def test_load_model (self , setup_torchair_mtp_proposer ,
76+ mocker : pytest .MockerFixture ):
7877 proposer , mock_model_loader , mock_torchair_deepseek_mtp = setup_torchair_mtp_proposer
7978 dummpy_model = Mock ()
8079
8180 proposer .load_model (dummpy_model )
8281
83- mocker .patch ("vllm.model_executor.model_loader.get_model_loader" ). assert_called_once_with (
84- proposer . vllm_config . load_config
85- )
82+ mocker .patch ("vllm.model_executor.model_loader.get_model_loader"
83+ ). assert_called_once_with (
84+ proposer . vllm_config . load_config )
8685
87- mock_get_layers = mocker .patch ("vllm.config.get_layers_from_vllm_config" )
88- assert mock_get_layers . call_count = 2
86+ mock_get_layers = mocker .patch (
87+ "vllm.config.get_layers_from_vllm_config" )
8988 mock_get_layers .assert_called_with (
9089 proposer .vllm_config ,
91- mocker .patch ("vllm.model_executor.layers.attention_layer_base.AttentionLayerBase" )
92- )
90+ mocker .patch (
91+ "vllm.model_executor.layers.attention_layer_base.AttentionLayerBase"
92+ ))
9393
94- mocker .patch ("vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP" ). assert_called_once_with (
95- vllm_config = proposer . vllm_config
96- )
94+ mocker .patch (
95+ "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP"
96+ ). assert_called_once_with ( vllm_config = proposer . vllm_config )
9797 mock_torchair_deepseek_mtp .to .assert_called_once (
98- proposer .vllm_config .device_config .device
99- )
98+ proposer .vllm_config .device_config .device )
10099
101100 assert len (proposer .attn_layer_name ) == 1
102101 mocker_layers_keys = mock_get_layers .return_value .keys ()
103102 assert proposer .attn_layer_name [0 ] in mocker_layers_keys
104103
105104 mock_model_loader .get_all_weights .assert_called_once_with (
106105 proposer .vllm_config .speculative_config .draft_model_config ,
107- mock_torchair_deepseek_mtp
108- )
106+ mock_torchair_deepseek_mtp )
109107 mock_torchair_deepseek_mtp .load_weights .assert_called_once_with (
110- mock_model_loader .get_all_weights .return_value
111- )
108+ mock_model_loader .get_all_weights .return_value )
112109
113- mock_process_weights = mocker .patch ("vllm.model_executor.model_loader.utils.process_weights_after_loading" )
110+ mock_process_weights = mocker .patch (
111+ "vllm.model_executor.model_loader.utils.process_weights_after_loading"
112+ )
114113 mock_process_weights .assert_called_once_with (
115114 mock_torchair_deepseek_mtp ,
116115 proposer .vllm_config .speculative_config .draft_model_config ,
117- proposer .vllm_config .device_config .device
118- )
116+ proposer .vllm_config .device_config .device )
0 commit comments