diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py
index a0cc52fb6..a9f6e094b 100644
--- a/slime/rollout/sglang_rollout.py
+++ b/slime/rollout/sglang_rollout.py
@@ -20,7 +20,12 @@
 from slime.utils.eval_config import EvalDatasetConfig
 from slime.utils.http_utils import get, post
 from slime.utils.misc import SingletonMeta, load_function
-from slime.utils.processing_utils import encode_image_for_rollout_engine, load_processor, load_tokenizer
+from slime.utils.processing_utils import (
+    build_processor_kwargs,
+    encode_image_for_rollout_engine,
+    load_processor,
+    load_tokenizer,
+)
 from slime.utils.types import Sample
 
 from .rm_hub import async_rm, batched_async_rm
@@ -112,7 +117,8 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A
     ), f"Sample status is {sample.status}"
 
     if state.processor:
-        processor_output = state.processor(text=sample.prompt, **sample.multimodal_inputs)
+        processor_kwargs = build_processor_kwargs(sample.multimodal_inputs)
+        processor_output = state.processor(text=sample.prompt, **processor_kwargs)
         prompt_ids = processor_output["input_ids"][0]
         sample.multimodal_train_inputs = {
             k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"]
diff --git a/slime/utils/processing_utils.py b/slime/utils/processing_utils.py
index e48d46571..18aa27997 100644
--- a/slime/utils/processing_utils.py
+++ b/slime/utils/processing_utils.py
@@ -16,6 +16,28 @@ def load_tokenizer(name_or_path: str, **kwargs):
     return AutoTokenizer.from_pretrained(name_or_path, **kwargs)
 
 
+def build_processor_kwargs(multimodal_inputs: dict | None = None) -> dict:
+
+    forced = {
+        # force return_tensors to None for input_ids
+        "return_tensors": None,
+    }
+    modality_forced = {"return_tensors": "pt"}
+
+    result = dict(multimodal_inputs) if multimodal_inputs else {}
+
+    result.update(forced)
+
+    # set return_tensors="pt" for modality-specific outputs
+    for key in ("audio_kwargs", "images_kwargs", "videos_kwargs"):
+        if key in result:
+            result[key] = {**result[key], **modality_forced}
+        else:
+            result[key] = modality_forced.copy()
+
+    return result
+
+
 def load_processor(name_or_path: str, **kwargs):
     try:
         proc = AutoProcessor.from_pretrained(name_or_path, **kwargs)
@@ -31,15 +53,15 @@ def load_processor(name_or_path: str, **kwargs):
 
 
 def process_vision_info(prompt, processor):
-    # temporary solution, will write image utils for slime later
-    from qwen_vl_utils import process_vision_info
+    # TODO: temporary solution, will write image utils for slime later
+    from qwen_vl_utils import process_vision_info as qwen_process_vision_info
 
     if hasattr(processor.image_processor, "patch_size"):
         image_patch_size = processor.image_processor.patch_size
     else:
         logger.info(f"Using default patch size: {DEFAULT_PATCH_SIZE}")
         image_patch_size = DEFAULT_PATCH_SIZE
-    images, videos = process_vision_info(prompt, image_patch_size=image_patch_size)
+    images, videos = qwen_process_vision_info(prompt, image_patch_size=image_patch_size)
     multimodal_inputs = {"images": images, "videos": videos}
     return multimodal_inputs
 
@@ -50,4 +72,5 @@ def encode_image_for_rollout_engine(image) -> str:
     if image.mode != "RGB":
         image = image.convert("RGB")
     image.save(buffer, format="PNG")
-    return base64.b64encode(buffer.getvalue()).decode("utf-8")
+    image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
+    return f"data:image/png;base64,{image_base64}"