Skip to content

Commit e67c295

Browse files
authored
[Bugfix] fix automatic prefix args and add log info (vllm-project#3608)
1 parent 925f333 commit e67c295

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

vllm/core/block_manager.py

+5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
1010
from vllm.utils import Device
1111
from vllm.core.evictor import Evictor, EvictionPolicy, make_evictor
12+
from vllm.logger import init_logger
13+
14+
logger = init_logger(__name__)
1215

1316

1417
class BlockAllocatorBase(ABC):
@@ -241,11 +244,13 @@ def __init__(
241244
self.watermark_blocks = int(watermark * num_gpu_blocks)
242245

243246
if self.enable_caching:
247+
logger.info("enable automatic prefix caching")
244248
self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size,
245249
num_gpu_blocks)
246250
self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size,
247251
num_cpu_blocks)
248252
else:
253+
logger.info("disable automatic prefix caching")
249254
self.gpu_allocator = UncachedBlockAllocator(
250255
Device.GPU, block_size, num_gpu_blocks)
251256
self.cpu_allocator = UncachedBlockAllocator(

vllm/engine/arg_utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,8 @@ def create_engine_configs(
337337
cache_config = CacheConfig(self.block_size,
338338
self.gpu_memory_utilization,
339339
self.swap_space, self.kv_cache_dtype,
340-
model_config.get_sliding_window())
340+
model_config.get_sliding_window(),
341+
self.enable_prefix_caching)
341342
parallel_config = ParallelConfig(
342343
self.pipeline_parallel_size, self.tensor_parallel_size,
343344
self.worker_use_ray, self.max_parallel_loading_workers,

0 commit comments

Comments
 (0)