Skip to content

Commit 160d423

Browse files
indrajit96 and KrishnanPrash
authored and committed
feat: Add security flag to MM flow in vllm (#4556)
Co-authored-by: KrishnanPrash <[email protected]>
1 parent 5f8e2c0 commit 160d423

File tree

8 files changed

+84
-16
lines changed

8 files changed

+84
-16
lines changed

components/src/dynamo/vllm/args.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class Config:
6262
multimodal_encode_worker: bool = False
6363
multimodal_worker: bool = False
6464
multimodal_decode_worker: bool = False
65+
enable_multimodal: bool = False
6566
multimodal_encode_prefill_worker: bool = False
6667
mm_prompt_template: str = "USER: <image>\n<prompt> ASSISTANT:"
6768
# dump config to file
@@ -161,6 +162,11 @@ def parse_args() -> Config:
161162
action="store_true",
162163
help="Run as unified encode+prefill+decode worker for models requiring integrated image encoding (e.g., Llama 4)",
163164
)
165+
parser.add_argument(
166+
"--enable-multimodal",
167+
action="store_true",
168+
help="Enable multimodal processing. If not set, none of the multimodal components can be used.",
169+
)
164170
parser.add_argument(
165171
"--mm-prompt-template",
166172
type=str,
@@ -218,6 +224,9 @@ def parse_args() -> Config:
218224
"Use only one of --multimodal-processor, --multimodal-encode-worker, --multimodal-worker, --multimodal-decode-worker, or --multimodal-encode-prefill-worker"
219225
)
220226

227+
if mm_flags == 1 and not args.enable_multimodal:
228+
raise ValueError("Use --enable-multimodal to enable multimodal processing")
229+
221230
# Set component and endpoint based on worker type
222231
if args.multimodal_processor:
223232
config.component = "processor"
@@ -256,6 +265,7 @@ def parse_args() -> Config:
256265
config.multimodal_worker = args.multimodal_worker
257266
config.multimodal_decode_worker = args.multimodal_decode_worker
258267
config.multimodal_encode_prefill_worker = args.multimodal_encode_prefill_worker
268+
config.enable_multimodal = args.enable_multimodal
259269
config.mm_prompt_template = args.mm_prompt_template
260270
config.store_kv = args.store_kv
261271

components/src/dynamo/vllm/handlers.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,25 @@ class BaseWorkerHandler(ABC):
6565
Request handler for the generate and clear_kv_blocks endpoints.
6666
"""
6767

68-
def __init__(self, runtime, component, engine, default_sampling_params):
68+
def __init__(
69+
self,
70+
runtime,
71+
component,
72+
engine,
73+
default_sampling_params,
74+
model_max_len: int | None = None,
75+
enable_multimodal: bool = False,
76+
):
6977
self.runtime = runtime
7078
self.component = component
7179
self.engine_client = engine
7280
self.default_sampling_params = default_sampling_params
7381
self.kv_publishers: list[ZmqKvEventPublisher] | None = None
7482
self.engine_monitor = VllmEngineMonitor(runtime, engine)
7583
self.image_loader = ImageLoader()
84+
self.temp_dirs: list[tempfile.TemporaryDirectory] = []
85+
self.model_max_len = model_max_len
86+
self.enable_multimodal = enable_multimodal
7687

7788
@abstractmethod
7889
async def generate(self, request, context) -> AsyncGenerator[dict, None]:
@@ -128,6 +139,13 @@ async def _extract_multimodal_data(
128139
if "multi_modal_data" not in request or request["multi_modal_data"] is None:
129140
return None
130141

142+
# Security check: reject multimodal data if not explicitly enabled
143+
if not self.enable_multimodal:
144+
raise ValueError(
145+
"Received multimodal data but multimodal processing is not enabled. "
146+
"Use --enable-multimodal flag to enable multimodal processing."
147+
)
148+
131149
mm_map = request["multi_modal_data"]
132150
vllm_mm_data = {}
133151

@@ -212,8 +230,17 @@ def __init__(
212230
component,
213231
engine,
214232
default_sampling_params,
233+
model_max_len: int | None = None,
234+
enable_multimodal: bool = False,
215235
):
216-
super().__init__(runtime, component, engine, default_sampling_params)
236+
super().__init__(
237+
runtime,
238+
component,
239+
engine,
240+
default_sampling_params,
241+
model_max_len,
242+
enable_multimodal,
243+
)
217244

218245
async def generate(self, request, context):
219246
# Use context ID for request tracking and correlation
@@ -259,8 +286,23 @@ async def generate(self, request, context):
259286

260287

261288
class PrefillWorkerHandler(BaseWorkerHandler):
262-
def __init__(self, runtime, component, engine, default_sampling_params):
263-
super().__init__(runtime, component, engine, default_sampling_params)
289+
def __init__(
290+
self,
291+
runtime,
292+
component,
293+
engine,
294+
default_sampling_params,
295+
model_max_len: int | None = None,
296+
enable_multimodal: bool = False,
297+
):
298+
super().__init__(
299+
runtime,
300+
component,
301+
engine,
302+
default_sampling_params,
303+
model_max_len,
304+
enable_multimodal,
305+
)
264306

265307
async def generate(self, request, context):
266308
# Use context ID for request tracking and correlation with decode phase

components/src/dynamo/vllm/multimodal_handlers/worker_handler.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,13 @@ def __init__(
3838
)
3939

4040
# Call BaseWorkerHandler.__init__ with proper parameters
41-
super().__init__(runtime, component, engine_client, default_sampling_params)
41+
super().__init__(
42+
runtime,
43+
component,
44+
engine_client,
45+
default_sampling_params,
46+
enable_multimodal=config.enable_multimodal,
47+
)
4248

4349
self.config = config
4450
self.enable_disagg = config.is_prefill_worker
@@ -98,7 +104,13 @@ def __init__(
98104
)
99105

100106
# Call BaseWorkerHandler.__init__ with proper parameters
101-
super().__init__(runtime, component, engine_client, default_sampling_params)
107+
super().__init__(
108+
runtime,
109+
component,
110+
engine_client,
111+
default_sampling_params,
112+
enable_multimodal=config.enable_multimodal,
113+
)
102114

103115
self.config = config
104116
self.decode_worker_client = decode_worker_client

docs/backends/vllm/multimodal.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ Dynamo supports multimodal models with vLLM v1. In general, multimodal models ca
2222
> [!WARNING]
2323
> **LLaVA Model Limitation**: Do not use LLaVA models (e.g., `llava-hf/llava-1.5-7b-hf`) with the standard aggregated serving setup, as they contain keywords that Dynamo cannot yet parse. LLaVA models can still be used with the EPD (Encode-Prefill-Decode) setup described below.
2424
25+
> [!IMPORTANT]
26+
> **Security Requirement**: All multimodal workers require the `--enable-multimodal` flag to be explicitly set at startup. This is a security feature to prevent unintended processing of multimodal data from untrusted sources. Workers will fail at startup if multimodal flags (e.g., `--multimodal-worker`, `--multimodal-processor`) are used without `--enable-multimodal`.
27+
This flag is analogous to `--enable-mm-embeds` in `vllm serve`, but also extends it to all multimodal content (URL, embeddings, base64).
28+
2529
# Multimodal EPD Deployment Examples
2630

2731
This section provides example workflows and reference implementations for deploying a multimodal model using Dynamo and vLLM v1 with EPD(Encode-Prefill-Decode) pipeline.

examples/backends/vllm/launch/agg_multimodal.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ fi
5353
# --enforce-eager: Quick deployment (remove for production)
5454
# --connector none: No KV transfer needed for aggregated serving
5555
DYN_SYSTEM_PORT=8081 \
56-
python -m dynamo.vllm --model $MODEL_NAME --enforce-eager --connector none $EXTRA_ARGS
56+
python -m dynamo.vllm --enable-multimodal --model $MODEL_NAME --enforce-eager --connector none $EXTRA_ARGS
5757

5858
# Wait for all background processes to complete
5959
wait

examples/backends/vllm/launch/agg_multimodal_epd.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,11 @@ if [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then
7373
fi
7474

7575
# Start processor (Python-based preprocessing, handles prompt templating)
76-
python -m dynamo.vllm --multimodal-processor --model $MODEL_NAME --mm-prompt-template "$PROMPT_TEMPLATE" &
76+
python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_NAME --mm-prompt-template "$PROMPT_TEMPLATE" &
7777

7878
# run E/P/D workers
79-
CUDA_VISIBLE_DEVICES=0 python -m dynamo.vllm --multimodal-encode-worker --model $MODEL_NAME &
80-
CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-worker --model $MODEL_NAME $EXTRA_ARGS &
79+
CUDA_VISIBLE_DEVICES=0 python -m dynamo.vllm --multimodal-encode-worker --enable-multimodal --model $MODEL_NAME &
80+
CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS &
8181

8282
# Wait for all background processes to complete
8383
wait

examples/backends/vllm/launch/agg_multimodal_llama.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ MODEL_NAME="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
1111
python -m dynamo.frontend --http-port=8000 &
1212

1313
# run processor
14-
python -m dynamo.vllm --multimodal-processor --model $MODEL_NAME --mm-prompt-template "<|image|>\n<prompt>" &
14+
python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_NAME --mm-prompt-template "<|image|>\n<prompt>" &
1515
# Llama 4 doesn't support image embedding input, so use encode+prefill worker
1616
# that handles image encoding inline
17-
python -m dynamo.vllm --multimodal-encode-prefill-worker --model $MODEL_NAME --tensor-parallel-size=8 --max-model-len=208960 --gpu-memory-utilization 0.80 &
17+
python -m dynamo.vllm --multimodal-encode-prefill-worker --enable-multimodal --model $MODEL_NAME --tensor-parallel-size=8 --max-model-len=208960 --gpu-memory-utilization 0.80 &
1818

1919
# Wait for all background processes to complete
2020
wait

examples/backends/vllm/launch/disagg_multimodal_epd.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ python -m dynamo.frontend --http-port=8000 &
7676

7777
# Start processor
7878
echo "Starting processor..."
79-
python -m dynamo.vllm --multimodal-processor --model $MODEL_NAME --mm-prompt-template "$PROMPT_TEMPLATE" &
79+
python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_NAME --mm-prompt-template "$PROMPT_TEMPLATE" &
8080

8181
# Configure GPU memory optimization for specific models
8282
EXTRA_ARGS=""
@@ -86,17 +86,17 @@ fi
8686

8787
# Start encode worker
8888
echo "Starting encode worker on GPU 1..."
89-
VLLM_NIXL_SIDE_CHANNEL_PORT=20097 CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-encode-worker --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20080"}' &
89+
VLLM_NIXL_SIDE_CHANNEL_PORT=20097 CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-encode-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20080"}' &
9090

9191
# Start prefill worker
9292
echo "Starting prefill worker on GPU 2..."
9393
VLLM_NIXL_SIDE_CHANNEL_PORT=20098 \
94-
CUDA_VISIBLE_DEVICES=2 python -m dynamo.vllm --multimodal-worker --is-prefill-worker --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20081"}' &
94+
CUDA_VISIBLE_DEVICES=2 python -m dynamo.vllm --multimodal-worker --is-prefill-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20081"}' &
9595

9696
# Start decode worker
9797
echo "Starting decode worker on GPU 3..."
9898
VLLM_NIXL_SIDE_CHANNEL_PORT=20099 \
99-
CUDA_VISIBLE_DEVICES=3 python -m dynamo.vllm --multimodal-decode-worker --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' &
99+
CUDA_VISIBLE_DEVICES=3 python -m dynamo.vllm --multimodal-decode-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' &
100100

101101
echo "=================================================="
102102
echo "All components started. Waiting for initialization..."

0 commit comments

Comments
 (0)