Skip to content

Commit 22d910a

Browse files
authored
chore: support for agg llama4 multimodal (#3984)
Signed-off-by: ayushag <[email protected]>
1 parent f2a3c63 commit 22d910a

File tree

3 files changed

+33
-19
lines changed

3 files changed

+33
-19
lines changed

examples/multimodal/launch/agg_llama.sh renamed to components/backends/vllm/launch/agg_multimodal_llama.sh

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@ MODEL_NAME="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
1111
python -m dynamo.frontend --http-port=8000 &
1212

1313
# run processor
14-
python3 components/processor.py --model $MODEL_NAME --prompt-template "<|image|>\n<prompt>" &
15-
# LLama 4 doesn't support image embedding input, so the prefill worker will also
16-
# handle image encoding.
17-
# run EP/D workers
18-
python3 components/worker.py --model $MODEL_NAME --worker-type encode_prefill --tensor-parallel-size=8 --max-model-len=208960 &
14+
python -m dynamo.vllm --multimodal-processor --model $MODEL_NAME --mm-prompt-template "<|image|>\n<prompt>" &
15+
# Llama 4 doesn't support image embedding input, so use encode+prefill worker
16+
# that handles image encoding inline
17+
python -m dynamo.vllm --multimodal-encode-prefill-worker --model $MODEL_NAME --tensor-parallel-size=8 --max-model-len=208960 --gpu-memory-utilization 0.80 &
1918

2019
# Wait for all background processes to complete
2120
wait

components/src/dynamo/vllm/args.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ class Config:
6969
multimodal_processor: bool = False
7070
multimodal_encode_worker: bool = False
7171
multimodal_worker: bool = False
72+
multimodal_encode_prefill_worker: bool = False
7273
mm_prompt_template: str = "USER: <image>\n<prompt> ASSISTANT:"
7374
# dump config to file
7475
dump_config_to: Optional[str] = None
@@ -169,6 +170,11 @@ def parse_args() -> Config:
169170
action="store_true",
170171
help="Run as multimodal worker component for LLM inference with multimodal data",
171172
)
173+
parser.add_argument(
174+
"--multimodal-encode-prefill-worker",
175+
action="store_true",
176+
help="Run as unified encode+prefill+decode worker for models requiring integrated image encoding (e.g., Llama 4)",
177+
)
172178
parser.add_argument(
173179
"--mm-prompt-template",
174180
type=str,
@@ -212,10 +218,11 @@ def parse_args() -> Config:
212218
int(bool(args.multimodal_processor))
213219
+ int(bool(args.multimodal_encode_worker))
214220
+ int(bool(args.multimodal_worker))
221+
+ int(bool(args.multimodal_encode_prefill_worker))
215222
)
216223
if mm_flags > 1:
217224
raise ValueError(
218-
"Use only one of --multimodal-processor, --multimodal-encode-worker, or --multimodal-worker"
225+
"Use only one of --multimodal-processor, --multimodal-encode-worker, --multimodal-worker, or --multimodal-encode-prefill-worker"
219226
)
220227

221228
# Set component and endpoint based on worker type
@@ -225,6 +232,9 @@ def parse_args() -> Config:
225232
elif args.multimodal_encode_worker:
226233
config.component = "encoder"
227234
config.endpoint = "generate"
235+
elif args.multimodal_encode_prefill_worker:
236+
config.component = "encoder"
237+
config.endpoint = "generate"
228238
elif args.multimodal_worker and args.is_prefill_worker:
229239
config.component = "prefill"
230240
config.endpoint = "generate"
@@ -248,6 +258,7 @@ def parse_args() -> Config:
248258
config.multimodal_processor = args.multimodal_processor
249259
config.multimodal_encode_worker = args.multimodal_encode_worker
250260
config.multimodal_worker = args.multimodal_worker
261+
config.multimodal_encode_prefill_worker = args.multimodal_encode_prefill_worker
251262
config.mm_prompt_template = args.mm_prompt_template
252263

253264
# Validate custom Jinja template file exists if provided

components/src/dynamo/vllm/main.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def signal_handler():
106106
elif config.multimodal_encode_worker:
107107
await init_multimodal_encode_worker(runtime, config)
108108
logger.debug("init_multimodal_encode_worker completed")
109-
elif config.multimodal_worker:
109+
elif config.multimodal_worker or config.multimodal_encode_prefill_worker:
110110
await init_multimodal_worker(runtime, config)
111111
logger.debug("init_multimodal_worker completed")
112112
elif config.is_prefill_worker:
@@ -605,8 +605,15 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con
605605

606606

607607
async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
608-
"""Initialize multimodal worker component for aggregated or disaggregated mode"""
608+
"""
609+
Initialize multimodal worker component.
610+
611+
Supports two modes:
612+
1. --multimodal-worker: Receives embeddings from separate encoder
613+
2. --multimodal-encode-prefill-worker: Handles inline encoding (e.g., Llama 4)
609614
615+
Both can operate in aggregated (P+D) or disaggregated (P→D) mode.
616+
"""
610617
component = runtime.namespace(config.namespace).component(config.component)
611618
await component.create_service()
612619

@@ -615,16 +622,12 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
615622

616623
engine_client, vllm_config, default_sampling_params = setup_vllm_engine(config)
617624

618-
# TODO: Support Disaggregated mode separately
619-
client = (
620-
await runtime.namespace(config.namespace)
621-
.component("backend")
622-
.endpoint("generate")
623-
.client()
624-
)
625+
# For aggregated mode, no downstream client is needed
626+
# TODO: Implement disaggregated mode with proper decode worker client
627+
downstream_client = None
625628

626629
handler = MultimodalPDWorkerHandler(
627-
runtime, component, engine_client, config, client
630+
runtime, component, engine_client, config, downstream_client
628631
)
629632

630633
await handler.async_init(runtime)
@@ -637,14 +640,15 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
637640
handler.kv_publisher = kv_publisher
638641

639642
metrics_labels = [("model", config.model)]
640-
641643
try:
642644
await asyncio.gather(
643645
generate_endpoint.serve_endpoint(
644-
handler.generate, metrics_labels=metrics_labels
646+
handler.generate,
647+
metrics_labels=metrics_labels,
645648
),
646649
clear_endpoint.serve_endpoint(
647-
handler.clear_kv_blocks, metrics_labels=metrics_labels
650+
handler.clear_kv_blocks,
651+
metrics_labels=metrics_labels,
648652
),
649653
)
650654
except Exception as e:

0 commit comments

Comments
 (0)