Skip to content

Commit 519ca7e

Browse files
test ut fixup
Signed-off-by: CodeNine-CJ <[email protected]>
1 parent 42a6004 commit 519ca7e

File tree

1 file changed

+46
-46
lines changed

1 file changed

+46
-46
lines changed

tests/ut/torchair/test_torchair_model_runner.py

Lines changed: 46 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -65,75 +65,76 @@ class TestNPUTorchairModelRunner(PytestBase):
6565
# return runner
6666
@pytest.fixture
6767
def setup_npu_torchair_model_runner(self, mocker: MockerFixture):
68-
# 核心配置对象
6968
vllm_config = MagicMock(spec=VllmConfig)
70-
71-
# --- 必需的配置对象 (用于初始化属性) ---
7269
vllm_config.model_config = MagicMock()
70+
cache_config = CacheConfig(block_size=16)
71+
vllm_config.cache_config = cache_config
7372
vllm_config.model_config.hf_config = MagicMock()
7473
vllm_config.model_config.hf_config.index_topk = 2
75-
vllm_config.model_config.max_model_len = 1024 # 模拟 model_config.max_model_len
76-
vllm_config.model_config.use_mla = False # 模拟 model_config.use_mla
77-
vllm_config.model_config.get_hidden_size.return_value = 512 # 模拟 get_hidden_size
78-
vllm_config.model_config.pooler_config = None # 模拟 is_pooling_model
79-
vllm_config.model_config.logits_processors = [] # 模拟 build_logitsprocs
80-
74+
vllm_config.model_config.max_model_len = 1024
75+
vllm_config.model_config.use_mla = False
76+
vllm_config.model_config.get_hidden_size.return_value = 512
77+
vllm_config.model_config.pooler_config = None
78+
vllm_config.model_config.logits_processors = []
8179
cache_config = MagicMock()
8280
cache_config.block_size = 16
83-
cache_config.cache_dtype = "auto" # 模拟 cache_config.cache_dtype
81+
cache_config.cache_dtype = "auto"
8482
vllm_config.cache_config = cache_config
85-
83+
8684
speculative_config = MagicMock()
8785
speculative_config.num_speculative_tokens = 4
8886
vllm_config.speculative_config = speculative_config
89-
87+
9088
vllm_config.compilation_config = MagicMock()
91-
vllm_config.compilation_config.cudagraph_mode = Mock() # 模拟 compilation_config
92-
# 模拟 aclgraph_batch_sizes
93-
vllm_config.compilation_config.cudagraph_capture_sizes = [1, 2, 4]
94-
89+
vllm_config.compilation_config.cudagraph_mode = Mock()
90+
vllm_config.compilation_config.cudagraph_capture_sizes = [1, 2, 4]
91+
9592
vllm_config.lora_config = MagicMock()
9693
vllm_config.parallel_config = MagicMock()
97-
vllm_config.parallel_config.data_parallel_size = 1 # 模拟 dp_size
98-
vllm_config.parallel_config.data_parallel_rank = 0 # 模拟 dp_rank
99-
vllm_config.parallel_config.cp_kv_cache_interleave_size = 1 # 模拟 cp_kv_cache_interleave_size
94+
vllm_config.parallel_config.data_parallel_size = 1
95+
vllm_config.parallel_config.data_parallel_rank = 0
96+
vllm_config.parallel_config.cp_kv_cache_interleave_size = 1
10097

10198
scheduler_config = MagicMock()
102-
scheduler_config.max_num_batched_tokens = 2048 # 模拟 max_num_tokens
103-
scheduler_config.max_num_seqs = 64 # 模拟 max_num_reqs (decode_max_num_seqs 默认为 0)
104-
scheduler_config.chunked_prefill_enabled = True # 模拟 chunked_prefill_enabled
105-
scheduler_config.async_scheduling = False # 模拟 use_async_scheduling
99+
scheduler_config.max_num_batched_tokens = 2048
100+
scheduler_config.max_num_seqs = 64
101+
scheduler_config.chunked_prefill_enabled = True
102+
scheduler_config.async_scheduling = False
106103
vllm_config.scheduler_config = scheduler_config
107104

108-
# --- 修复 'load_config' 报错 ---
109-
vllm_config.load_config = MagicMock()
105+
vllm_config.load_config = MagicMock()
110106

111-
# --- 模拟 kv_transfer_config (用于判断 kv role) ---
112107
vllm_config.kv_transfer_config = None
113108

114-
# --- 模拟其他缺失的函数/常量 ---
115-
mocker.patch("vllm_ascend.worker.model_runner_v1.is_pin_memory_available",
116-
return_value=True) # 模拟 pin_memory
109+
mocker.patch(
110+
"vllm_ascend.worker.model_runner_v1.is_pin_memory_available",
111+
return_value=True)
117112
mocker.patch("vllm_ascend.worker.model_runner_v1.cdiv",
118-
return_value=64) # 模拟 max_num_blocks_per_req
119-
mocker.patch("vllm_ascend.worker.model_runner_v1.prefill_context_parallel_enable",
120-
return_value=False) # 模拟 pcp_size/rank
121-
mocker.patch("vllm_ascend.worker.model_runner_v1.get_dcp_group").return_value.world_size = 1 # 模拟 dcp_size/rank
122-
mocker.patch("vllm_ascend.torchair.torchair_model_runner.get_attn_backend",
123-
autospec=True) # 模拟 Attention 设置
124-
mocker.patch("vllm_ascend.torchair.torchair_model_runner._set_up_drafter") # 模拟 drafter 设置
125-
mocker.patch("vllm_ascend.torchair.torchair_model_runner._may_pad_kv_consumer_num_seq") # 模拟 kv 填充
126-
127-
# NPU特定的配置
113+
return_value=64)
114+
mocker.patch(
115+
"vllm_ascend.worker.model_runner_v1.prefill_context_parallel_enable",
116+
return_value=False)
117+
mocker.patch("vllm_ascend.worker.model_runner_v1.get_dcp_group"
118+
).return_value.world_size = 1
119+
mocker.patch(
120+
"vllm_ascend.torchair.torchair_model_runner.get_attn_backend",
121+
autospec=True)
122+
mocker.patch(
123+
"vllm_ascend.torchair.torchair_model_runner._set_up_drafter")
124+
mocker.patch(
125+
"vllm_ascend.torchair.torchair_model_runner._may_pad_kv_consumer_num_seq"
126+
)
127+
128+
128129
device = torch.device("npu:0")
129130
ascend_config = MagicMock()
130-
# 确保 ascend_scheduler_config.enabled 被设置,否则 chunked_prefill_enabled 会被设置为 True
131-
ascend_config.ascend_scheduler_config.enabled = False
132-
# 其他 NPU/Ascend 特有配置
133-
ascend_config.weight_prefetch_config = Mock()
131+
132+
ascend_config.ascend_scheduler_config.enabled = False
133+
134+
ascend_config.weight_prefetch_config = Mock()
134135
ascend_config.dynamic_eplb = False
135136
ascend_config.expert_map_record_path = None
136-
137+
137138
mocker.patch("vllm_ascend.utils.get_ascend_config",
138139
return_value=ascend_config)
139140
mocker.patch("vllm_ascend.torchair.utils.register_torchair_model")
@@ -146,9 +147,8 @@ def setup_npu_torchair_model_runner(self, mocker: MockerFixture):
146147
mock_attn_backend = Mock()
147148
mock_attn_backend.get_builder_cls.return_value = lambda *args, **kwargs: mock_attn_builder
148149

149-
# 设置类属性(如果需要)
150150
NPUTorchairModelRunner.decode_token_per_req = 1
151-
NPUTorchairModelRunner.max_num_tokens = 10 # 这行在实际测试中可能被覆盖
151+
NPUTorchairModelRunner.max_num_tokens = 10
152152

153153
runner = NPUTorchairModelRunner(vllm_config, device)
154154
runner.vllm_config = vllm_config

0 commit comments

Comments
 (0)