Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions slime/rollout/sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
from slime.utils.eval_config import EvalDatasetConfig
from slime.utils.http_utils import get, post
from slime.utils.misc import SingletonMeta, load_function
from slime.utils.processing_utils import encode_image_for_rollout_engine, load_processor, load_tokenizer
from slime.utils.processing_utils import (
build_processor_kwargs,
encode_image_for_rollout_engine,
load_processor,
load_tokenizer,
)
from slime.utils.types import Sample

from .rm_hub import async_rm, batched_async_rm
Expand Down Expand Up @@ -112,7 +117,8 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A
), f"Sample status is {sample.status}"

if state.processor:
processor_output = state.processor(text=sample.prompt, **sample.multimodal_inputs)
processor_kwargs = build_processor_kwargs(sample.multimodal_inputs)
processor_output = state.processor(text=sample.prompt, **processor_kwargs)
prompt_ids = processor_output["input_ids"][0]
sample.multimodal_train_inputs = {
k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"]
Expand Down
31 changes: 27 additions & 4 deletions slime/utils/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,28 @@ def load_tokenizer(name_or_path: str, **kwargs):
return AutoTokenizer.from_pretrained(name_or_path, **kwargs)


def build_processor_kwargs(multimodal_inputs: dict | None = None) -> dict:

forced = {
# force return_tensors to None for input_ids
"return_tensors": None,
}
modality_forced = {"return_tensors": "pt"}

result = dict(multimodal_inputs) if multimodal_inputs else {}

result.update(forced)

# set return_tensors="pt" for modality-specific outputs
for key in ("audio_kwargs", "images_kwargs", "videos_kwargs"):
if key in result:
result[key] = {**result[key], **modality_forced}
else:
result[key] = modality_forced.copy()

return result


def load_processor(name_or_path: str, **kwargs):
try:
proc = AutoProcessor.from_pretrained(name_or_path, **kwargs)
Expand All @@ -31,15 +53,15 @@ def load_processor(name_or_path: str, **kwargs):


def process_vision_info(prompt, processor):
# temporary solution, will write image utils for slime later
from qwen_vl_utils import process_vision_info
# TODO: temporary solution, will write image utils for slime later
from qwen_vl_utils import process_vision_info as qwen_process_vision_info

if hasattr(processor.image_processor, "patch_size"):
image_patch_size = processor.image_processor.patch_size
else:
logger.info(f"Using default patch size: {DEFAULT_PATCH_SIZE}")
image_patch_size = DEFAULT_PATCH_SIZE
images, videos = process_vision_info(prompt, image_patch_size=image_patch_size)
images, videos = qwen_process_vision_info(prompt, image_patch_size=image_patch_size)
multimodal_inputs = {"images": images, "videos": videos}
return multimodal_inputs

Expand All @@ -50,4 +72,5 @@ def encode_image_for_rollout_engine(image) -> str:
if image.mode != "RGB":
image = image.convert("RGB")
image.save(buffer, format="PNG")
return base64.b64encode(buffer.getvalue()).decode("utf-8")
image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
return f"data:image/png;base64,{image_base64}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we move the f"data:image/png;base64,{image_base64}" template into sglang_rollout.py? It seems like a template that is tightly coupled to the HTTP payload.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about potential future modalities (audio, video, etc.) that may have different MIME types. Keeping the data formatting in each encode function means sglang_rollout.py doesn't need to handle a different MIME type for each modality.
(SGLang actually just matches the `data:` prefix and the `,` separator without parsing the MIME type, but including it makes the format less confusing.)