Skip to content

Commit 5e72216

Browse files
BloodAxe and ywang96 authored
Feature/video support in random mm dataset (vllm-project#25963)
Signed-off-by: Eugene Khvedchenia <[email protected]> Signed-off-by: Eugene Khvedchenya <[email protected]> Co-authored-by: Roger Wang <[email protected]>
1 parent 1a33aac commit 5e72216

File tree

3 files changed

+601
-25
lines changed

3 files changed

+601
-25
lines changed

tests/benchmarks/test_random_dataset.py

Lines changed: 123 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -359,3 +359,126 @@ def test_random_mm_bucket_config_not_mutated(
359359
assert len(mm_data) >= 1
360360
for it in mm_data:
361361
assert it.get("type") == "image_url"
362+
363+
364+
@pytest.mark.benchmark
def test_random_mm_video_sampling(hf_tokenizer: PreTrainedTokenizerBase) -> None:
    """Check that a mixed bucket config yields both image and video items.

    Five requests are sampled with exactly one multimodal item each. Every
    item must be either an ``image_url`` or a ``video_url`` entry carrying a
    base64 data URL with the matching MIME prefix, and both kinds must occur
    at least once.
    """
    ds = RandomMultiModalDataset(random_seed=42)

    # One bucket per modality. The last key component is 1 for the image
    # bucket and 8 for the video bucket (presumably the frame count; the
    # first two look like height/width -- confirm against the dataset).
    bucket_config = {
        (64, 64, 1): 0.3,  # Images
        (64, 64, 8): 0.7,  # Videos
    }

    limit_mm_per_prompt = {"image": 2, "video": 2}

    samples = _collect_mm_samples(
        ds,
        hf_tokenizer,
        num_requests=5,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
    )

    assert len(samples) == 5

    # Item type -> required data-URL prefix. The type string is also the key
    # under which the item stores its ``{"url": ...}`` payload.
    expected_url_prefix = {
        "video_url": "data:video/mp4;base64,",
        "image_url": "data:image/jpeg;base64,",
    }
    counts = dict.fromkeys(expected_url_prefix, 0)

    for s in samples:
        mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
        assert len(mm_data) == 1

        item = mm_data[0]
        item_type = item.get("type")
        # Fail loudly on an unknown type instead of silently skipping the
        # item, which would leave its URL format unchecked.
        assert item_type in expected_url_prefix, (
            f"unexpected multimodal item type: {item_type!r}"
        )

        url = item.get(item_type, {}).get("url", "")
        assert url.startswith(expected_url_prefix[item_type])
        counts[item_type] += 1

    # Both modalities have non-zero probability (0.7 video / 0.3 image).
    # NOTE(review): with only 5 requests, seeing both kinds depends on the
    # fixed seed above.
    assert counts["video_url"] > 0
    assert counts["image_url"] > 0
412+
413+
414+
@pytest.mark.benchmark
def test_random_mm_video_only_sampling(hf_tokenizer: PreTrainedTokenizerBase) -> None:
    """Verify that a video-only bucket config yields one video item per request."""
    dataset = RandomMultiModalDataset(random_seed=42)

    # A single bucket with probability 1.0, and images disallowed outright.
    video_only_buckets = {(64, 64, 8): 1.0}
    modality_limits = {"image": 0, "video": 1}

    collected = _collect_mm_samples(
        dataset,
        hf_tokenizer,
        num_requests=3,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=modality_limits,
        bucket_config=video_only_buckets,
    )
    assert len(collected) == 3

    for sample in collected:
        items = cast(list[dict[str, Any]], sample.multi_modal_data)
        assert len(items) == 1

        only_item = items[0]
        assert only_item.get("type") == "video_url"

        # The payload must be an inline base64 MP4 data URL.
        payload = only_item.get("video_url", {})
        assert payload.get("url", "").startswith("data:video/mp4;base64,")
445+
446+
447+
@pytest.mark.benchmark
def test_random_mm_video_deterministic_sampling(
    hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
    """Two datasets built from the same seed must produce identical video samples."""
    seed = 123
    first_ds = RandomMultiModalDataset(random_seed=seed)
    second_ds = RandomMultiModalDataset(random_seed=seed)

    # Video-only configuration, shared by both sampling runs.
    video_only_buckets = {(64, 64, 8): 1.0}
    modality_limits = {"image": 0, "video": 1}

    def sample_from(dataset: RandomMultiModalDataset):
        """Draw three single-item requests from *dataset* with the shared config."""
        return _collect_mm_samples(
            dataset,
            hf_tokenizer,
            num_requests=3,
            base_items_per_request=1,
            num_mm_items_range_ratio=0.0,
            limit_mm_per_prompt=modality_limits,
            bucket_config=video_only_buckets,
        )

    first_run = sample_from(first_ds)
    second_run = sample_from(second_ds)

    # Compare per-sample fingerprints rather than the sample objects themselves.
    first_prints = list(map(_mm_fingerprint_sample, first_run))
    second_prints = list(map(_mm_fingerprint_sample, second_run))
    assert first_prints == second_prints

0 commit comments

Comments
 (0)