@@ -106,7 +106,7 @@ def signal_handler():
106106 elif config .multimodal_encode_worker :
107107 await init_multimodal_encode_worker (runtime , config )
108108 logger .debug ("init_multimodal_encode_worker completed" )
109- elif config .multimodal_worker :
109+ elif config .multimodal_worker or config . multimodal_encode_prefill_worker :
110110 await init_multimodal_worker (runtime , config )
111111 logger .debug ("init_multimodal_worker completed" )
112112 elif config .is_prefill_worker :
@@ -605,8 +605,15 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con
605605
606606
607607async def init_multimodal_worker (runtime : DistributedRuntime , config : Config ):
608- """Initialize multimodal worker component for aggregated or disaggregated mode"""
608+ """
609+ Initialize multimodal worker component.
610+
611+ Supports two modes:
612+ 1. --multimodal-worker: Receives embeddings from a separate encoder
613+ 2. --multimodal-encode-prefill-worker: Handles inline encoding (e.g., Llama 4)
609614
615+ Both target aggregated (P+D) or disaggregated (P→D) mode; only aggregated mode is wired up so far (no downstream client is created yet).
616+ """
610617 component = runtime .namespace (config .namespace ).component (config .component )
611618 await component .create_service ()
612619
@@ -615,16 +622,12 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
615622
616623 engine_client , vllm_config , default_sampling_params = setup_vllm_engine (config )
617624
618- # TODO: Support Disaggregated mode separately
619- client = (
620- await runtime .namespace (config .namespace )
621- .component ("backend" )
622- .endpoint ("generate" )
623- .client ()
624- )
625+ # For aggregated mode, no downstream client is needed
626+ # TODO: Implement disaggregated mode with proper decode worker client
627+ downstream_client = None
625628
626629 handler = MultimodalPDWorkerHandler (
627- runtime , component , engine_client , config , client
630+ runtime , component , engine_client , config , downstream_client
628631 )
629632
630633 await handler .async_init (runtime )
@@ -637,14 +640,15 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
637640 handler .kv_publisher = kv_publisher
638641
639642 metrics_labels = [("model" , config .model )]
640-
641643 try :
642644 await asyncio .gather (
643645 generate_endpoint .serve_endpoint (
644- handler .generate , metrics_labels = metrics_labels
646+ handler .generate ,
647+ metrics_labels = metrics_labels ,
645648 ),
646649 clear_endpoint .serve_endpoint (
647- handler .clear_kv_blocks , metrics_labels = metrics_labels
650+ handler .clear_kv_blocks ,
651+ metrics_labels = metrics_labels ,
648652 ),
649653 )
650654 except Exception as e :
0 commit comments