@@ -359,3 +359,126 @@ def test_random_mm_bucket_config_not_mutated(
359359 assert len (mm_data ) >= 1
360360 for it in mm_data :
361361 assert it .get ("type" ) == "image_url"
362+
363+
364+ @pytest .mark .benchmark
365+ def test_random_mm_video_sampling (hf_tokenizer : PreTrainedTokenizerBase ) -> None :
366+ """Test video sampling functionality in RandomMultiModalDataset."""
367+ ds = RandomMultiModalDataset (random_seed = 42 )
368+
369+ # Test with video bucket configuration
370+ bucket_config = {
371+ (64 , 64 , 1 ): 0.3 , # Images
372+ (64 , 64 , 8 ): 0.7 , # Videos
373+ }
374+
375+ limit_mm_per_prompt = {"image" : 2 , "video" : 2 }
376+
377+ samples = _collect_mm_samples (
378+ ds ,
379+ hf_tokenizer ,
380+ num_requests = 5 ,
381+ base_items_per_request = 1 ,
382+ num_mm_items_range_ratio = 0.0 ,
383+ limit_mm_per_prompt = limit_mm_per_prompt ,
384+ bucket_config = bucket_config ,
385+ )
386+
387+ assert len (samples ) == 5
388+
389+ # Check that we have both images and videos
390+ video_count = 0
391+ image_count = 0
392+
393+ for s in samples :
394+ mm_data = cast (list [dict [str , Any ]], s .multi_modal_data )
395+ assert len (mm_data ) == 1
396+
397+ item = mm_data [0 ]
398+ if item .get ("type" ) == "video_url" :
399+ video_count += 1
400+ # Verify video URL format
401+ url = item .get ("video_url" , {}).get ("url" , "" )
402+ assert url .startswith ("data:video/mp4;base64," )
403+ elif item .get ("type" ) == "image_url" :
404+ image_count += 1
405+ # Verify image URL format
406+ url = item .get ("image_url" , {}).get ("url" , "" )
407+ assert url .startswith ("data:image/jpeg;base64," )
408+
409+ # Should have some videos due to 0.7 probability
410+ assert video_count > 0
411+ assert image_count > 0
412+
413+
414+ @pytest .mark .benchmark
415+ def test_random_mm_video_only_sampling (hf_tokenizer : PreTrainedTokenizerBase ) -> None :
416+ """Test sampling with only video buckets."""
417+ ds = RandomMultiModalDataset (random_seed = 42 )
418+
419+ bucket_config = {
420+ (64 , 64 , 8 ): 1.0 , # Only videos
421+ }
422+
423+ limit_mm_per_prompt = {"image" : 0 , "video" : 1 }
424+
425+ samples = _collect_mm_samples (
426+ ds ,
427+ hf_tokenizer ,
428+ num_requests = 3 ,
429+ base_items_per_request = 1 ,
430+ num_mm_items_range_ratio = 0.0 ,
431+ limit_mm_per_prompt = limit_mm_per_prompt ,
432+ bucket_config = bucket_config ,
433+ )
434+
435+ assert len (samples ) == 3
436+
437+ for s in samples :
438+ mm_data = cast (list [dict [str , Any ]], s .multi_modal_data )
439+ assert len (mm_data ) == 1
440+
441+ item = mm_data [0 ]
442+ assert item .get ("type" ) == "video_url"
443+ url = item .get ("video_url" , {}).get ("url" , "" )
444+ assert url .startswith ("data:video/mp4;base64," )
445+
446+
447+ @pytest .mark .benchmark
448+ def test_random_mm_video_deterministic_sampling (
449+ hf_tokenizer : PreTrainedTokenizerBase ,
450+ ) -> None :
451+ """Test that video sampling is deterministic with same seed."""
452+ seed = 123
453+ ds_a = RandomMultiModalDataset (random_seed = seed )
454+ ds_b = RandomMultiModalDataset (random_seed = seed )
455+
456+ bucket_config = {
457+ (64 , 64 , 8 ): 1.0 , # Only videos
458+ }
459+
460+ limit_mm_per_prompt = {"image" : 0 , "video" : 1 }
461+
462+ a = _collect_mm_samples (
463+ ds_a ,
464+ hf_tokenizer ,
465+ num_requests = 3 ,
466+ base_items_per_request = 1 ,
467+ num_mm_items_range_ratio = 0.0 ,
468+ limit_mm_per_prompt = limit_mm_per_prompt ,
469+ bucket_config = bucket_config ,
470+ )
471+
472+ b = _collect_mm_samples (
473+ ds_b ,
474+ hf_tokenizer ,
475+ num_requests = 3 ,
476+ base_items_per_request = 1 ,
477+ num_mm_items_range_ratio = 0.0 ,
478+ limit_mm_per_prompt = limit_mm_per_prompt ,
479+ bucket_config = bucket_config ,
480+ )
481+
482+ fa = [_mm_fingerprint_sample (s ) for s in a ]
483+ fb = [_mm_fingerprint_sample (s ) for s in b ]
484+ assert fa == fb
0 commit comments