
Commit 7f84226

fixup
Signed-off-by: CodeNine-CJ <[email protected]>
1 parent: cb81376


tests/ut/torchair/test_torchair_model_runner.py

Lines changed: 24 additions & 22 deletions
@@ -3,7 +3,7 @@
 import pytest
 import torch
 from pytest_mock import MockerFixture
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import VllmConfig

 from tests.ut.base import PytestBase
 from vllm_ascend.torchair.torchair_model_runner import NPUTorchairModelRunner
@@ -66,16 +66,18 @@ class TestNPUTorchairModelRunner(PytestBase):
     @pytest.fixture
     def setup_npu_torchair_model_runner(self, mocker: MockerFixture):
         vllm_config = MagicMock(spec=VllmConfig)
+        vllm_config.device_config = MagicMock()
+        vllm_config.device_config.device = torch.device("cpu")
         vllm_config.model_config = MagicMock()
-        cache_config = CacheConfig(block_size=16)
-        vllm_config.cache_config = cache_config
         vllm_config.model_config.hf_config = MagicMock()
         vllm_config.model_config.hf_config.index_topk = 2
         vllm_config.model_config.max_model_len = 1024
         vllm_config.model_config.use_mla = False
         vllm_config.model_config.get_hidden_size.return_value = 512
         vllm_config.model_config.pooler_config = None
         vllm_config.model_config.logits_processors = []
+        vllm_config.model_config.dtype = torch.float16
+
         cache_config = MagicMock()
         cache_config.block_size = 16
         cache_config.cache_dtype = "auto"
@@ -85,25 +87,27 @@ def setup_npu_torchair_model_runner(self, mocker: MockerFixture):
         speculative_config.num_speculative_tokens = 4
         vllm_config.speculative_config = speculative_config

-        vllm_config.compilation_config = MagicMock()
-        vllm_config.compilation_config.cudagraph_mode = Mock()
-        vllm_config.compilation_config.cudagraph_capture_sizes = [1, 2, 4]
+        compilation_config = MagicMock()
+        compilation_config.cudagraph_mode = Mock()
+        compilation_config.cudagraph_capture_sizes = [1, 2, 4]
+        vllm_config.compilation_config = compilation_config

-        vllm_config.lora_config = MagicMock()
-        vllm_config.parallel_config = MagicMock()
-        vllm_config.parallel_config.data_parallel_size = 1
-        vllm_config.parallel_config.data_parallel_rank = 0
-        vllm_config.parallel_config.cp_kv_cache_interleave_size = 1
+        parallel_config = MagicMock()
+        parallel_config.data_parallel_size = 1
+        parallel_config.data_parallel_rank = 0
+        parallel_config.cp_kv_cache_interleave_size = 1
+        vllm_config.parallel_config = parallel_config

         scheduler_config = MagicMock()
         scheduler_config.max_num_batched_tokens = 2048
         scheduler_config.max_num_seqs = 64
         scheduler_config.chunked_prefill_enabled = True
         scheduler_config.async_scheduling = False
+        scheduler_config.decode_max_num_seqs = 0
         vllm_config.scheduler_config = scheduler_config

         vllm_config.load_config = MagicMock()
-
+        vllm_config.lora_config = MagicMock()
         vllm_config.kv_transfer_config = None

         mocker.patch(
@@ -116,16 +120,14 @@ def setup_npu_torchair_model_runner(self, mocker: MockerFixture):
             return_value=False)
         mocker.patch("vllm_ascend.worker.model_runner_v1.get_dcp_group"
                      ).return_value.world_size = 1
-        mocker.patch(
-            "vllm_ascend.torchair.torchair_model_runner.get_attn_backend",
-            autospec=True)
-        mocker.patch(
-            "vllm_ascend.torchair.torchair_model_runner._set_up_drafter")
-        mocker.patch(
-            "vllm_ascend.torchair.torchair_model_runner._may_pad_kv_consumer_num_seq"
-        )
-
-        device = torch.device("npu:0")
+        mocker.patch("vllm.attention.get_attn_backend", autospec=True)
+        # mocker.patch(
+        #     "vllm_ascend.torchair.torchair_model_runner._set_up_drafter")
+        # mocker.patch(
+        #     "vllm_ascend.torchair.torchair_model_runner._may_pad_kv_consumer_num_seq"
+        # )
+
+        device = torch.device("npu")
         ascend_config = MagicMock()

         ascend_config.ascend_scheduler_config.enabled = False
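
Note on the sub-config refactor above: each sub-config (compilation_config, parallel_config) is now built as a local MagicMock and attached to vllm_config in a single assignment, which keeps each config's attribute setup grouped together. A minimal standalone sketch of that pattern, assuming only the Python standard library; DummyVllmConfig and make_config are hypothetical stand-ins, not vLLM APIs:

from unittest.mock import MagicMock


class DummyVllmConfig:
    # Hypothetical stand-in for vllm.config.VllmConfig, reduced to the
    # attributes this sketch touches.
    parallel_config = None


def make_config():
    # spec= restricts reads to attributes that exist on the spec class,
    # so a misspelled attribute raises AttributeError instead of quietly
    # returning a fresh child mock.
    cfg = MagicMock(spec=DummyVllmConfig)

    # Build the sub-config locally, then attach it in one assignment,
    # mirroring the compilation_config/parallel_config refactor above.
    parallel_config = MagicMock()
    parallel_config.data_parallel_size = 1
    parallel_config.data_parallel_rank = 0
    cfg.parallel_config = parallel_config
    return cfg


if __name__ == "__main__":
    cfg = make_config()
    assert cfg.parallel_config.data_parallel_size == 1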

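The fixup also retargets the attention-backend patch from the torchair module path to vllm.attention.get_attn_backend. The general rule this relies on: mock.patch replaces a name in the namespace where the code under test looks it up, and autospec=True makes the replacement mock enforce the original function's signature. A hedged sketch of both points, using purely illustrative local functions rather than vLLM symbols:

from unittest.mock import patch


def get_attn_backend(use_mla):
    # Purely illustrative local function, not the vLLM symbol.
    return "real-backend"


def select_backend():
    # Looks get_attn_backend up in this module's namespace at call time,
    # so a patch must target the name here to take effect.
    return get_attn_backend(use_mla=False)


if __name__ == "__main__":
    with patch(f"{__name__}.get_attn_backend", autospec=True) as mocked:
        mocked.return_value = "mock-backend"
        assert select_backend() == "mock-backend"
        # autospec=True enforces the real signature: calling the mock as
        # get_attn_backend(1, 2, 3) here would raise a TypeError.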