1+ from unittest .mock import MagicMock , Mock
2+
13import pytest
2- from unittest .mock import Mock , MagicMock , patch
3- from vllm_ascend import torchair_mtp_proposer
4+ import torch
45from vllm .config import VllmConfig
6+ from vllm_ascend .torchair .torchair_mtp_proposer import TorchairMtpProposer
7+ from tests .ut .base import PytestBase
58
6- import torch
79
810def vllm_version_is (version ):
911 return version == "0.11.0"
1012
13+
1114import sys
15+
1216sys .modules [__name__ ].vllm_version_is = vllm_version_is
1317
1418
@@ -29,90 +33,91 @@ def setup_torchair_mtp_proposer(self, mocker: pytest.MockerFixture):
2933
3034 mocker .patch ("vllm_ascend.torchair_mtp_proposer.__init__" ,
3135 return_value = None )
32-
36+
3337 if vllm_version_is ("0.11.0" ):
3438 mock_set_default_dtype = mocker .patch (
3539 'vllm.model_executor.model_loader.utils.set_default_torch_dtype'
3640 )
3741 else :
3842 mock_set_default_dtype = mocker .patch (
39- 'vllm.utls.torch_utils.set_default_torch_dtype'
40- )
43+ 'vllm.utls.torch_utils.set_default_torch_dtype' )
4144 mock_set_default_dtype .return_value .__enter__ .return_value = None
4245
4346 mock_model_loader = MagicMock ()
4447 mocker .patch ("vllm.model_executor.model_loader.get_model_loader" ,
4548 return_value = mock_model_loader )
46- mock_layers = {"target_attn_layer_1" : Mock (), "draft_attn_layer_2" : Mock ()}
49+ mock_layers = {
50+ "target_attn_layer_1" : Mock (),
51+ "draft_attn_layer_2" : Mock ()
52+ }
4753 mocker .patch ("vllm.config.get_layers_from_vllm_config" ,
4854 return_value = mock_layers )
4955 mock_set_current = mocker .patch ("vllm.config.set_current_vllm_config" )
5056 mock_set_current .return_value .__enter__ .return_value = None
5157 mock_torchair_deepseek_mtp = MagicMock ()
5258 mock_torchair_deepseek_mtp .to .return_value = mock_torchair_deepseek_mtp
53- mocker .patch ("vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP" ,
54- return_value = mock_torchair_deepseek_mtp )
55- mocker .patch ("vllm.model_executor.model_loader.utils.process_weights_after_loading" )
59+ mocker .patch (
60+ "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP" ,
61+ return_value = mock_torchair_deepseek_mtp )
62+ mocker .patch (
63+ "vllm.model_executor.model_loader.utils.process_weights_after_loading"
64+ )
5665
5766 proposer = TorchairMtpProposer (vllm_config , device , runner )
5867 proposer .vllm_config = vllm_config
5968 proposer .device = device
6069 proposer .runner = runner
6170
6271 return proposer , mock_model_loader , mock_torchair_deepseek_mtp
63-
72+
6473 def test_init (self , setup_torchair_mtp_proposer ):
6574 proposer , _ , _ , = setup_torchair_mtp_proposer
6675
6776 assert isinstance (proposer , setup_torchair_mtp_proposer )
6877 assert proposer .torchair_compiled_model is None
69- assert proposer .torchair_compiled_models = {}
70- Mock .assert_called_once_with (
71- proposer .__class__ .__bases__ [0 ],
72- proposer .vllm_config ,
73- proposer .device ,
74- proposer .runner
75- )
76-
77- def test_load_model (self , setup_torchair_mtp_proposer , mocker : pytest .MockerFixture ):
78+ Mock .assert_called_once_with (proposer .__class__ .__bases__ [0 ],
79+ proposer .vllm_config , proposer .device ,
80+ proposer .runner )
81+
82+ def test_load_model (self , setup_torchair_mtp_proposer ,
83+ mocker : pytest .MockerFixture ):
7884 proposer , mock_model_loader , mock_torchair_deepseek_mtp = setup_torchair_mtp_proposer
7985 dummpy_model = Mock ()
8086
8187 proposer .load_model (dummpy_model )
8288
83- mocker .patch ("vllm.model_executor.model_loader.get_model_loader" ). assert_called_once_with (
84- proposer . vllm_config . load_config
85- )
89+ mocker .patch ("vllm.model_executor.model_loader.get_model_loader"
90+ ). assert_called_once_with (
91+ proposer . vllm_config . load_config )
8692
87- mock_get_layers = mocker .patch ("vllm.config.get_layers_from_vllm_config" )
88- assert mock_get_layers . call_count = 2
93+ mock_get_layers = mocker .patch (
94+ "vllm.config.get_layers_from_vllm_config" )
8995 mock_get_layers .assert_called_with (
9096 proposer .vllm_config ,
91- mocker .patch ("vllm.model_executor.layers.attention_layer_base.AttentionLayerBase" )
92- )
97+ mocker .patch (
98+ "vllm.model_executor.layers.attention_layer_base.AttentionLayerBase"
99+ ))
93100
94- mocker .patch ("vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP" ). assert_called_once_with (
95- vllm_config = proposer . vllm_config
96- )
101+ mocker .patch (
102+ "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP"
103+ ). assert_called_once_with ( vllm_config = proposer . vllm_config )
97104 mock_torchair_deepseek_mtp .to .assert_called_once (
98- proposer .vllm_config .device_config .device
99- )
105+ proposer .vllm_config .device_config .device )
100106
101107 assert len (proposer .attn_layer_name ) == 1
102108 mocker_layers_keys = mock_get_layers .return_value .keys ()
103109 assert proposer .attn_layer_name [0 ] in mocker_layers_keys
104110
105111 mock_model_loader .get_all_weights .assert_called_once_with (
106112 proposer .vllm_config .speculative_config .draft_model_config ,
107- mock_torchair_deepseek_mtp
108- )
113+ mock_torchair_deepseek_mtp )
109114 mock_torchair_deepseek_mtp .load_weights .assert_called_once_with (
110- mock_model_loader .get_all_weights .return_value
111- )
115+ mock_model_loader .get_all_weights .return_value )
112116
113- mock_process_weights = mocker .patch ("vllm.model_executor.model_loader.utils.process_weights_after_loading" )
117+ mock_process_weights = mocker .patch (
118+ "vllm.model_executor.model_loader.utils.process_weights_after_loading"
119+ )
114120 mock_process_weights .assert_called_once_with (
115121 mock_torchair_deepseek_mtp ,
116122 proposer .vllm_config .speculative_config .draft_model_config ,
117- proposer .vllm_config .device_config .device
118- )
123+ proposer .vllm_config .device_config .device )
0 commit comments