diff --git a/examples/eval/README.md b/examples/eval/README.md index f33bf94aa..3c692fffe 100644 --- a/examples/eval/README.md +++ b/examples/eval/README.md @@ -43,25 +43,27 @@ docker run \ ``` ## 4) Inside the Skills container + +Set openai api key: +```bash +export OPENAI_API_KEY=none +``` + Clone repos and install the Skills package: ```bash -git clone -b slime_skills https://github.com/guapisolo/slime.git /opt/slime +git clone -b main https://github.com/THUDM/slime /opt/slime git clone -b slime https://github.com/guapisolo/Skills.git /opt/Skills cd /opt/Skills pip install -e . -``` -Download/prepare datasets: -```bash +# Download/prepare datasets: cd /opt/Skills/nemo_skills/dataset python3 aime25/prepare.py python3 hle/prepare.py python3 arena-hard/prepare.py -``` -Start the skills server: -```bash +# Start the skills server: cd /opt/slime python examples/eval/nemo_skills/skills_server.py \ --host 0.0.0.0 \ diff --git a/examples/eval/scripts/multi_tasks.yaml b/examples/eval/scripts/multi_tasks.yaml index be41e7fe3..154b84f8c 100644 --- a/examples/eval/scripts/multi_tasks.yaml +++ b/examples/eval/scripts/multi_tasks.yaml @@ -14,6 +14,16 @@ eval: path: /root/ifbench/IFBench_eval.jsonl rm_type: ifbench n_samples_per_eval_prompt: 1 + - name: tau2-airline + path: /root/tau2-bench/data/tau2/airline_test_tasks.jsonl + custom_generate_function_path: examples.tau2-bench.generate_with_tau2.generate + input_key: task_id + label_key: null + apply_chat_template: False + top_k: 1 + max_response_len: 1024 + max_context_len: 40000 + n_samples_per_eval_prompt: 1 delegate: # these tasks go through delegate eval function (examples.eval.eval_delegate_rollout.generate_rollout) - name: skills # this url should align with env docker network alias @@ -24,9 +34,9 @@ eval: datasets: - name: aime25 max_response_len: 8192 - n_samples_per_eval_prompt: 8 - - name: arena-hard n_samples_per_eval_prompt: 2 + - name: arena-hard + n_samples_per_eval_prompt: 1 + max_response_len: 24576 - name: hle - max_response_len: 32768 - + max_response_len: 24576 diff --git a/examples/eval/scripts/run-qwen3-4B.sh b/examples/eval/scripts/run-qwen3-4B.sh index 7cf8a7ffa..6dd177021 100644 --- a/examples/eval/scripts/run-qwen3-4B.sh +++ b/examples/eval/scripts/run-qwen3-4B.sh @@ -122,8 +122,8 @@ MISC_ARGS=( ) export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -export CUDA_VISIBLE_DEVICES=6,7 -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 2 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 +# export CUDA_VISIBLE_DEVICES=4,5,6,7 +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 RUNTIME_ENV_JSON="{ \"env_vars\": { @@ -136,7 +136,7 @@ ray job submit --address="http://127.0.0.1:8265" \ --runtime-env-json="${RUNTIME_ENV_JSON}" \ -- python3 train.py \ --actor-num-nodes 1 \ - --actor-num-gpus-per-node 2 \ + --actor-num-gpus-per-node 8 \ --colocate \ ${MODEL_ARGS[@]} \ ${CKPT_ARGS[@]} \ diff --git a/examples/tau2-bench/README.md b/examples/tau2-bench/README.md new file mode 100644 index 000000000..12bf634a0 --- /dev/null +++ b/examples/tau2-bench/README.md @@ -0,0 +1,72 @@ +# Tau2 bench with slime + +This example mirrors `examples/tau-bench`, but plugs the newer tau2 gym environment into slime rollouts. + +## Setup + +Use the `zhuzilin/slime:latest` image and initialize the environment required for Tau2-Bench: +```bash +cd /root/ +git clone https://github.com/slimerl/slime.git +cd slime +pip install -e . 
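+
+# (optional) sanity check: the editable slime install should now be importable
+python -c "import slime"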
+# for tau2 bench +cd /root/ +git clone https://github.com/sierra-research/tau2-bench.git +cd tau2-bench +pip install -e . +``` + +Use the following script to generate mock data for slime training. + +```bash +cd /root/slime +python examples/tau2-bench/tau2_mock.py \ + --output-dir /root/tau2-bench/data/tau2 +``` +Initialize the Qwen2.5-3B-Instruct model needed for tool use: + +```bash +# hf checkpoint +huggingface-cli download Qwen/Qwen3-4B-Instruct-2507 --local-dir /root/Qwen3-4B-Instruct-2507 + +# mcore checkpoint +cd /root/slime +source scripts/models/qwen3-4B-Instruct-2507.sh +PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ + ${MODEL_ARGS[@]} \ + --hf-checkpoint /root/Qwen3-4B-Instruct-2507 \ + --save /root/Qwen3-4B-Instruct-2507_torch_dist +``` + +## Running the Script + +The custom rollout entrypoint is `examples.tau2-bench.generate_with_tau2.generate`. A sample launcher is provided in `examples/tau2-bench/run_tau2_qwen3_4B.sh`; the important CLI flags are: + +```bash +--prompt-data /root/tau2-bench/data/tau2/airline_train_tasks.jsonl +--input-key task_id +--custom-generate-function-path examples.tau2-bench.generate_with_tau2.generate +``` + +You need to configure your litellm API in `generate_with_tau2.py` for user simulation: + +```python +TAU2_CONFIGS = { + "domain": "airline", # tau2 domain: airline | retail | telecom | mock + "task_split": "train", # task split within the domain + "max_steps": 100, # safety cap on interaction steps + "user_llm": "gpt-4.1-mini", # LiteLLM model name for user simulator + "solo_mode": False, # set True to disable user simulator +} +# Replace with your actual API key for user sim +GEMINI_API_KEY = "YOUR_GEMINI_KEY" +os.environ["GEMINI_API_KEY"] = GEMINI_API_KEY +``` + +And run: + +```bash +cd /root/slime +bash examples/tau2-bench/run_tau2_qwen3_4B.sh +``` diff --git a/examples/tau2-bench/configs/strip_think.yaml b/examples/tau2-bench/configs/strip_think.yaml new file mode 100644 index 000000000..3ce5fe6fd --- /dev/null +++ b/examples/tau2-bench/configs/strip_think.yaml @@ -0,0 +1 @@ +rollout_strip_think: true diff --git a/examples/tau2-bench/generate_with_tau2.py b/examples/tau2-bench/generate_with_tau2.py new file mode 100644 index 000000000..5c069a639 --- /dev/null +++ b/examples/tau2-bench/generate_with_tau2.py @@ -0,0 +1,59 @@ +""" +Tau2-Bench integration for slime Training. + +Configure the domain/task split below, point slime at this file via +--custom-generate-function-path examples.tau2-bench.generate_with_tau2.generate +""" + +import logging +import os +from typing import Any + +from slime.utils.types import Sample + +from .trainable_agent import Tau2TrainableAgent, res_to_sample + +logger = logging.getLogger(__name__) + +# Base configuration (edit here as needed). +TAU2_CONFIGS: dict[str, Any] = { + "domain": "airline", # tau2 domain: airline | retail | telecom | mock + "task_split": "train", # task split within the domain + "max_steps": 100, # safety cap on interaction steps + # Explicit gemini provider prefix to avoid Vertex ADC path. 
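+    # Note (assumption): "user_llm" is handed to LiteLLM as-is, so any LiteLLM-routed
+    # model name should work here, provided the matching provider key is exported
+    # (e.g. OPENAI_API_KEY for OpenAI models, GEMINI_API_KEY for the gemini/ prefix).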
+ # "user_llm": "gemini/gemini-2.5-flash-lite", + "user_llm": "gpt-4.1", + "user_llm_args": {}, # will inject api_key below + "solo_mode": False, # set True to disable user simulator +} + +# Replace with your actual API key for user simulator (LiteLLM) +API_KEY = "NONE" +if API_KEY == "NONE": + API_KEY = os.getenv("OPENAI_API_KEY") +# Also pass through args to force gemini path +TAU2_CONFIGS["user_llm_args"] = {"api_key": API_KEY} + + +async def generate(args: dict[str, Any], sample: Sample, sampling_params: dict) -> Sample: + assert not args.partial_rollout, "Partial rollout is not supported for tau2." + + agent = Tau2TrainableAgent( + args=args, + sampling_params=sampling_params, + domain=TAU2_CONFIGS["domain"], + task_split=TAU2_CONFIGS["task_split"], + max_steps=TAU2_CONFIGS["max_steps"], + user_llm=TAU2_CONFIGS["user_llm"], + user_llm_args=TAU2_CONFIGS.get("user_llm_args") or {}, + solo_mode=TAU2_CONFIGS["solo_mode"], + ) + + task_id, task_index = agent._resolve_task_id(sample.prompt) # noqa: SLF001 - simple helper + logger.info("Starting tau2 rollout for task_id=%s (index=%s)", task_id, task_index) + + interaction_result = await agent.run_episode(task_id) + result_sample = res_to_sample(interaction_result, task_index) + + logger.info("Finished tau2 rollout for task_id=%s", task_id) + return result_sample diff --git a/examples/tau2-bench/run_tau2_qwen3_4B.sh b/examples/tau2-bench/run_tau2_qwen3_4B.sh new file mode 100644 index 000000000..26157fd5b --- /dev/null +++ b/examples/tau2-bench/run_tau2_qwen3_4B.sh @@ -0,0 +1,152 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "${SCRIPT_DIR}/../../scripts/models/qwen3-4B-Instruct-2507.sh" + +CKPT_ARGS=( + --hf-checkpoint /root/Qwen3-4B-Instruct-2507/ + --ref-load /root/Qwen3-4B-Instruct-2507_torch_dist/ + --load /root/Qwen3-4B-Instruct-2507_slime/ + --save /root/Qwen3-4B-Instruct-2507_slime/ + --save-interval 20 +) + +ROLLOUT_ARGS=( + --prompt-data /root/tau2-bench/data/tau2/airline_train_tasks.jsonl + --input-key task_id + --rollout-shuffle + --num-rollout 500 + --rollout-batch-size 16 + --n-samples-per-prompt 4 + --rollout-max-response-len 1024 + --rollout-temperature 0.8 + --global-batch-size 64 + --balance-data +) + +EVAL_ARGS=( + --eval-interval 5 + --eval-prompt-data airline-test /root/tau2-bench/data/tau2/airline_test_tasks.jsonl + --n-samples-per-eval-prompt 1 + --eval-max-response-len 1024 + --eval-top-k 1 + --eval-input-key task_id +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + 
--weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + --use-wandb + --wandb-project slime-tau2 + --wandb-group qwen3-4B + --wandb-key ${WANDB_KEY} +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.7 + # If gemini API reports concurrency limit error, set this parameter to reduce the concurrency + # --sglang-server-concurrency 32 +) + +MISC_ARGS=( + # default dropout in megatron is 0.1 + --attention-dropout 0.0 + --hidden-dropout 0.0 + # should be good for model performance + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + # need to comment this when using model with MLA + --attention-backend flash +) + +CUSTOM_ARGS=( + --custom-generate-function-path examples.tau2-bench.generate_with_tau2.generate +) +# launch the master node of ray in container +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} + +# If you want more or less GPUs, change this parameter +NUM_GPUS=2 +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 --temp-dir /root/shared/ray_temp + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/:${SCRIPT_DIR}:/root/tau2-bench\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node ${NUM_GPUS} \ + --rollout-num-gpus ${NUM_GPUS} \ + --colocate \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${DISTRIBUTED_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${CUSTOM_ARGS[@]} diff --git a/examples/tau2-bench/sglang_tool_parser.py b/examples/tau2-bench/sglang_tool_parser.py new file mode 100644 index 000000000..ab2dc463a --- /dev/null +++ b/examples/tau2-bench/sglang_tool_parser.py @@ -0,0 +1,31 @@ +from typing import Any, Dict, List + +from sglang.srt.function_call.function_call_parser import FunctionCallParser +from sglang.srt.managers.io_struct import Function, Tool + + +def parse_tools(response: str, tools: List[Dict[str, Any]], parser: str = "qwen25"): + """ + This function mimics the function call parser API from + https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server.py#L952 + But running locally + """ + tools_list = [ + Tool( + function=Function( + name=tool["function"]["name"], + description=tool["function"]["description"], + parameters=tool["function"]["parameters"], + ), + type=tool["type"], + ) + for tool in tools + ] + parser = FunctionCallParser(tools=tools_list, tool_call_parser=parser) + + normal_text, calls = parser.parse_non_stream(response) + + return { + "normal_text": normal_text, + "calls": [call.model_dump() for call in calls], # Convert pydantic objects to dictionaries + } diff --git a/examples/tau2-bench/tau2_mock.py b/examples/tau2-bench/tau2_mock.py new file mode 100644 index 000000000..0293b758f --- /dev/null +++ b/examples/tau2-bench/tau2_mock.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +""" +Export tau2 tasks into JSONL files for slime training (tau1_mock style). 
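+
+Usage (same command as the README):
+    python examples/tau2-bench/tau2_mock.py --output-dir /root/tau2-bench/data/tau2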
+ +Each file looks like domain_split_tasks.jsonl with lines: +{"index": 0, "task_id": "0", "task_set": "airline", "task_split": "train", "metadata": {...}} +""" + +import argparse +import json +from pathlib import Path +from typing import Iterable, Optional + + +def iter_tasks(task_set: str, task_split: Optional[str]): + from tau2.registry import registry + + task_loader = registry.get_tasks_loader(task_set) + try: + return task_loader(task_split) + except TypeError: + return task_loader() + + +def write_tasks(tasks: Iterable, task_set: str, task_split: Optional[str], output_dir: Path) -> None: + suffix = f"_{task_split}" if task_split else "" + output_path = output_dir / f"{task_set}{suffix}_tasks.jsonl" + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w") as f: + for idx, task in enumerate(tasks): + row = { + "index": idx, + "task_id": str(task.id), + "task_set": task_set, + "task_split": task_split, + "metadata": task.model_dump(), + } + f.write(json.dumps(row) + "\n") + print(f"Saved {output_path}") + + +def main(): + parser = argparse.ArgumentParser(description="tau2 mock data generator (tau1_mock style).") + parser.add_argument("--output-dir", required=True, type=Path, help="Directory to write JSONL files") + args = parser.parse_args() + + from tau2.registry import registry + + task_sets = registry.get_task_sets() + for task_set in task_sets: + split_loader = registry.get_task_splits_loader(task_set) + if split_loader is not None: + splits = list(split_loader().keys()) + else: + splits = [None] + for split in splits: + try: + tasks = iter_tasks(task_set, split) + except ValueError as e: + print(f"[skip] {task_set} split={split}: {e}") + continue + write_tasks(tasks, task_set, split, args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/examples/tau2-bench/trainable_agent.py b/examples/tau2-bench/trainable_agent.py new file mode 100644 index 000000000..9727502f9 --- /dev/null +++ b/examples/tau2-bench/trainable_agent.py @@ -0,0 +1,422 @@ +import json +import logging +import re +from dataclasses import dataclass +from enum import Enum +from typing import Any + +from tau2.agent.llm_agent import AGENT_INSTRUCTION, SYSTEM_PROMPT +from tau2.data_model.message import AssistantMessage, Message, ToolCall, ToolMessage, UserMessage +from tau2.gym.gym_agent import AgentGymEnv +from tau2.registry import registry + +from slime.rollout.sglang_rollout import GenerateState +from slime.utils.http_utils import post +from slime.utils.types import Sample + +from .sglang_tool_parser import parse_tools + +logger = logging.getLogger(__name__) + + +class Status(Enum): + COMPLETED = "completed" + TRUNCATED = "truncated" + ABORTED = "aborted" + + +@dataclass +class InteractionResult: + prompt: str + reward: float + messages: list[dict[str, Any]] + info: dict[str, Any] + response: str = "" + loss_mask: list[int] | None = None + tokens: list[int] | None = None + status: Status = Status.COMPLETED + response_length: int = 0 + + +def _tool_to_openai_schema(tool: Any) -> dict[str, Any]: + """Convert tau2 Tool object to OpenAI schema expected by chat template.""" + return tool.openai_schema + + +def _tool_call_to_openai(call: ToolCall) -> dict[str, Any]: + """Convert tau2 ToolCall to OpenAI-compatible tool call payload.""" + return { + "id": call.id or call.name, + "type": "function", + "function": { + "name": call.name, + "arguments": json.dumps(call.arguments), + }, + } + + +def _tau_message_to_chat(msg: Message) -> dict[str, Any] | None: + """Convert tau2 message 
objects to the chat format expected by transformers templates.""" + if isinstance(msg, UserMessage): + if msg.tool_calls: + tool_calls = [_tool_call_to_openai(call) for call in msg.tool_calls] + return {"role": "user", "content": None, "tool_calls": tool_calls} + return {"role": "user", "content": msg.content} + if isinstance(msg, AssistantMessage): + if msg.tool_calls: + tool_calls = [_tool_call_to_openai(call) for call in msg.tool_calls] + return {"role": "assistant", "content": None, "tool_calls": tool_calls} + return {"role": "assistant", "content": msg.content} + if isinstance(msg, ToolMessage): + # tool_call_id keeps the chain aligned; name is optional for most templates. + return {"role": "tool", "content": msg.content or "", "tool_call_id": msg.id} + logger.debug("Skipping unsupported message type %s", type(msg)) + return None + + +def res_to_sample(res: InteractionResult, task_index: Any) -> Sample: + status_mapping = { + Status.COMPLETED: Sample.Status.COMPLETED, + Status.TRUNCATED: Sample.Status.TRUNCATED, + Status.ABORTED: Sample.Status.ABORTED, + } + sample = Sample( + index=task_index, + prompt=res.prompt, + tokens=res.tokens or [], + response=res.response, + reward=res.reward, + loss_mask=res.loss_mask, + status=status_mapping.get(res.status, Sample.Status.ABORTED), + metadata=res.info, + ) + sample.response_length = res.response_length + return sample + + +class Tau2TrainableAgent: + """ + Minimal wrapper that lets slime drive a tau2 AgentGymEnv using an sglang-served model. + """ + + def __init__( + self, + args, + sampling_params: dict[str, Any], + domain: str, + task_split: str, + max_steps: int = 100, + user_llm: str | None = None, + user_llm_args: dict[str, Any] | None = None, + solo_mode: bool = False, + all_messages_as_observation: bool = True, + ): + self.args = args + self.sampling_params = dict(sampling_params or {}) + self._max_response_len = self.sampling_params.get("max_new_tokens") or args.rollout_max_response_len + self._max_context_len = self.sampling_params.get("max_context_len") or args.rollout_max_context_len + self._strip_think = getattr(args, "rollout_strip_think", False) + self.domain = domain + self.task_split = task_split + self.max_steps = max_steps + self.user_llm = user_llm + self.user_llm_args = user_llm_args or {} + self.solo_mode = solo_mode + self.all_messages_as_observation = all_messages_as_observation + + self._task_splits = self._load_task_splits() + + def _load_task_splits(self) -> dict[str, list[str]] | None: + loader = registry.get_task_splits_loader(self.domain) + if loader is None: + return None + return loader() + + def _resolve_task_id(self, prompt_value: str) -> tuple[str, int]: + """ + Convert the incoming prompt payload into a concrete task id. + Accepts raw task ids or integer indices into the configured split. 
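+        Also accepts a JSON object payload carrying "task_id" and/or "index" keys.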
+ """ + raw = str(prompt_value).strip() + try: + payload = json.loads(raw) + except Exception: + payload = raw + + if isinstance(payload, dict): + if "task_id" in payload: + return str(payload["task_id"]), int(payload.get("index", -1)) + if "index" in payload: + idx = int(payload["index"]) + return self._task_id_from_index(idx), idx + + try: + idx = int(payload) + return self._task_id_from_index(idx), idx + except Exception: + return str(payload), -1 + + def _task_id_from_index(self, idx: int) -> str: + if self._task_splits is None: + raise ValueError("Task splits not available; provide task_id directly or set up splits.") + if self.task_split not in self._task_splits: + raise ValueError(f"task_split={self.task_split} not found. Available: {list(self._task_splits)}") + split_ids = self._task_splits[self.task_split] + if idx < 0 or idx >= len(split_ids): + raise IndexError(f"Index {idx} out of range for split '{self.task_split}' with {len(split_ids)} tasks.") + return str(split_ids[idx]) + + def _build_system_message(self, policy: str) -> dict[str, str]: + system_prompt = SYSTEM_PROMPT.format(domain_policy=policy, agent_instruction=AGENT_INSTRUCTION) + return {"role": "system", "content": system_prompt} + + def _get_token_delta(self, tokenizer, messages: list[dict[str, Any]]) -> tuple[list[int], list[int]]: + """ + Compute token delta and loss mask for the newest message. + """ + curr = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False) + token_ids: list[int] = [] + loss_mask: list[int] = [] + + if messages[-1]["role"] == "assistant": + prev = tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True, tokenize=False) + new_tokens = tokenizer.encode(curr[len(prev) :], add_special_tokens=False) + token_ids += new_tokens + loss_mask += [1] * len(new_tokens) + else: + prev = tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=False, tokenize=False) + new_tokens = tokenizer.encode(curr[len(prev) :], add_special_tokens=False) + token_ids += new_tokens + loss_mask += [0] * len(new_tokens) + return token_ids, loss_mask + + async def _call_llm(self, url: str, payload: dict[str, Any]) -> dict[str, Any]: + return await post(url, payload) + + def _parse_response(self, response: str, tools_info: list[dict[str, Any]]) -> dict[str, Any]: + parsed = parse_tools(response, tools_info, parser="qwen25") + return parsed + + def _build_action_string(self, calls: list[dict[str, Any]], text_response: str) -> str: + if not calls: + return text_response + tool_call = calls[0] + try: + params = json.loads(tool_call["parameters"]) + except Exception: + params = {} + action = {"id": tool_call.get("id") or tool_call.get("name"), "name": tool_call["name"], "arguments": params} + return json.dumps(action) + + @staticmethod + def _strip_think_segments(text: str) -> str: + if "" not in text: + return text + return re.sub(r".*?", "", text, flags=re.DOTALL) + + def _append_new_messages( + self, + tokenizer, + chat_messages: list[dict[str, Any]], + env_messages: list[Message], + seen_count: int, + response_token_ids: list[int], + loss_masks: list[int], + ) -> int: + for msg in env_messages[seen_count:]: + chat_msg = _tau_message_to_chat(msg) + if chat_msg is None: + continue + chat_messages.append(chat_msg) + token_ids, loss_mask = self._get_token_delta(tokenizer, chat_messages) + response_token_ids.extend(token_ids) + loss_masks.extend(loss_mask) + return len(env_messages) + + def _build_final_result( + self, + res: InteractionResult, + total_reward: float, + 
info: dict[str, Any], + chat_messages: list[dict[str, Any]], + loss_masks: list[int], + prompt_token_ids: list[int], + response_token_ids: list[int], + ) -> InteractionResult: + res.reward = total_reward + res.info.update(info) + res.messages = chat_messages + res.loss_mask = loss_masks + res.tokens = prompt_token_ids + response_token_ids + res.response = "".join([m.get("content", "") or "" for m in chat_messages if m["role"] == "assistant"]) + res.response_length = len(loss_masks) + return res + + async def run_episode(self, task_id: str) -> InteractionResult: + state = GenerateState(self.args) + tokenizer = state.tokenizer + url = f"http://{self.args.sglang_router_ip}:{self.args.sglang_router_port}/generate" + + env = AgentGymEnv( + domain=self.domain, + task_id=task_id, + max_steps=self.max_steps, + solo_mode=self.solo_mode, + user_llm=self.user_llm, + user_llm_args=self.user_llm_args, + all_messages_as_observation=self.all_messages_as_observation, + ) + + _, env_info = env.reset() + tools_info = [_tool_to_openai_schema(t) for t in env_info["tools"]] + base_info: dict[str, Any] = { + "task_id": getattr(env_info.get("task"), "id", task_id), + "domain": self.domain, + "task_split": self.task_split, + "policy": env_info.get("policy"), + } + + chat_messages: list[dict[str, Any]] = [self._build_system_message(env_info["policy"])] + initial_env_messages = env._agent.observation if getattr(env, "_agent", None) else [] + seen_env_messages = len(initial_env_messages) + for m in initial_env_messages: + converted = _tau_message_to_chat(m) + if converted: + chat_messages.append(converted) + + prompt_text = tokenizer.apply_chat_template( + chat_messages, tokenize=False, add_generation_prompt=True, tools=tools_info + ) + prompt_token_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"] + + loss_masks: list[int] = [] + response_token_ids: list[int] = [] + total_reward = 0.0 + + res = InteractionResult( + prompt=prompt_text, + reward=0.0, + messages=[], + info=base_info.copy(), + status=Status.COMPLETED, + ) + + terminated = False + for _ in range(self.max_steps): + text_input = tokenizer.apply_chat_template( + chat_messages, tokenize=False, add_generation_prompt=True, tools=tools_info + ) + + used_tokens = len(prompt_token_ids) + len(response_token_ids) + if self._max_context_len is not None: + remaining_context = self._max_context_len - used_tokens + max_new_tokens = min(self._max_response_len, remaining_context) + else: + max_new_tokens = self._max_response_len + + if max_new_tokens <= 0: + res.status = Status.TRUNCATED + return self._build_final_result( + res, total_reward, res.info, chat_messages, loss_masks, prompt_token_ids, response_token_ids + ) + + self.sampling_params["max_new_tokens"] = max_new_tokens + + payload = {"text": text_input, "sampling_params": self.sampling_params} + output = await self._call_llm(url, payload) + + if output["meta_info"]["finish_reason"]["type"] == "abort": + res.status = Status.ABORTED + return self._build_final_result( + res, total_reward, res.info, chat_messages, loss_masks, prompt_token_ids, response_token_ids + ) + + response = output["text"] + if self._strip_think: + response = self._strip_think_segments(response) + if response.endswith("<|im_end|>"): + response = response[:-10] + + try: + parsed = self._parse_response(response, tools_info) + calls = parsed["calls"] + normal_text = parsed["normal_text"].strip() + except Exception as e: + logger.warning("Failed to parse response: %s", e) + res.status = Status.ABORTED + return 
self._build_final_result( + res, total_reward, res.info, chat_messages, loss_masks, prompt_token_ids, response_token_ids + ) + + if not calls and not (normal_text or response): + logger.warning("Empty model response; aborting rollout.") + res.status = Status.ABORTED + return self._build_final_result( + res, total_reward, res.info, chat_messages, loss_masks, prompt_token_ids, response_token_ids + ) + + if calls: + # Enforce protocol: tool call message should not contain user-facing text. + assistant_message = { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": call.get("id") or f"call_{idx}", + "type": "function", + "function": { + "name": call["name"], + "arguments": call.get("parameters", "{}"), + }, + } + for idx, call in enumerate(calls) + ], + } + else: + assistant_message = {"role": "assistant", "content": normal_text or response} + + chat_messages.append(assistant_message) + token_ids, loss_mask = self._get_token_delta(tokenizer, chat_messages) + response_token_ids.extend(token_ids) + loss_masks.extend(loss_mask) + + action_string = self._build_action_string(calls, normal_text or response) + try: + _, reward, terminated, _, step_info = env.step(action_string) + except Exception as e: + logger.warning("Environment step failed: %s", e) + res.status = Status.ABORTED + return self._build_final_result( + res, total_reward, res.info, chat_messages, loss_masks, prompt_token_ids, response_token_ids + ) + + total_reward = reward + # Update env/task metadata; keep it JSON-serializable where possible. + reward_info = step_info.get("reward_info") + if isinstance(reward_info, str): + try: + reward_info = json.loads(reward_info) + except Exception: + pass + res.info.update({"reward_info": reward_info}) + + if getattr(env, "_agent", None): + seen_env_messages = self._append_new_messages( + tokenizer, + chat_messages, + env._agent.observation, + seen_env_messages, + response_token_ids, + loss_masks, + ) + + if terminated: + res.status = Status.COMPLETED + break + + if not terminated: + res.status = Status.TRUNCATED + + return self._build_final_result( + res, total_reward, res.info, chat_messages, loss_masks, prompt_token_ids, response_token_ids + ) diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py index 6f6530d57..5934ab4ee 100644 --- a/slime/rollout/sglang_rollout.py +++ b/slime/rollout/sglang_rollout.py @@ -59,6 +59,7 @@ def __init__(self, args: Namespace) -> None: top_p=args.rollout_top_p, top_k=args.rollout_top_k, max_new_tokens=args.rollout_max_response_len, + max_context_len=args.rollout_max_context_len, stop=args.rollout_stop, stop_token_ids=args.rollout_stop_token_ids, skip_special_tokens=args.rollout_skip_special_tokens, @@ -234,8 +235,12 @@ async def generate_and_rm( sample.status = Sample.Status.ABORTED return sample - if args.custom_generate_function_path is not None: - custom_generate_func = load_function(args.custom_generate_function_path) + # sample param level > args level + override_generate_path = ( + sampling_params.get("custom_generate_function_path", None) or args.custom_generate_function_path + ) + if override_generate_path is not None: + custom_generate_func = load_function(override_generate_path) sample = await custom_generate_func(args, sample, sampling_params) else: sample = await generate(args, sample, sampling_params) @@ -477,7 +482,7 @@ async def eval_rollout_single_dataset( global EVAL_PROMPT_DATASET - cache_key = dataset_cfg.cache_key + (args.hf_checkpoint, args.apply_chat_template) + cache_key = dataset_cfg.cache_key 
+ (args.hf_checkpoint, dataset_cfg.apply_chat_template) if cache_key not in EVAL_PROMPT_DATASET: tokenizer = AutoTokenizer.from_pretrained(args.hf_checkpoint, trust_remote_code=True) EVAL_PROMPT_DATASET[cache_key] = Dataset( @@ -489,7 +494,7 @@ async def eval_rollout_single_dataset( multimodal_keys=args.multimodal_keys, metadata_key=dataset_cfg.metadata_key, tool_key=dataset_cfg.tool_key, - apply_chat_template=args.apply_chat_template, + apply_chat_template=dataset_cfg.apply_chat_template, apply_chat_template_kwargs=args.apply_chat_template_kwargs, ) dataset = EVAL_PROMPT_DATASET[cache_key] @@ -499,11 +504,13 @@ async def eval_rollout_single_dataset( top_p=dataset_cfg.top_p, top_k=dataset_cfg.top_k, max_new_tokens=dataset_cfg.max_response_len, + max_context_len=dataset_cfg.max_context_len, stop=args.rollout_stop, stop_token_ids=args.rollout_stop_token_ids, skip_special_tokens=args.rollout_skip_special_tokens, no_stop_trim=True, spaces_between_special_tokens=False, + custom_generate_function_path=dataset_cfg.custom_generate_function_path, ) tasks = [] diff --git a/slime/utils/eval_config.py b/slime/utils/eval_config.py index 4a7c1e912..836740c09 100644 --- a/slime/utils/eval_config.py +++ b/slime/utils/eval_config.py @@ -28,6 +28,11 @@ "default_keys": ("top_k",), "arg_attrs": ("eval_top_k", "rollout_top_k"), }, + "max_context_len": { + "dataset_keys": ("max_context_len",), + "default_keys": ("max_context_len",), + "arg_attrs": ("eval_max_context_len", "rollout_max_context_len"), + }, "max_response_len": { "dataset_keys": ("max_response_len",), "default_keys": ("max_response_len",), @@ -36,6 +41,11 @@ } DATASET_SAMPLE_SPECS: dict[str, dict[str, tuple[str, ...]]] = { + "apply_chat_template": { + "dataset_keys": ("apply_chat_template",), + "default_keys": ("apply_chat_template",), + "arg_attrs": ("apply_chat_template",), + }, "input_key": { "dataset_keys": ("input_key",), "default_keys": ("input_key",), @@ -56,6 +66,11 @@ "default_keys": ("metadata_key",), "arg_attrs": ("metadata_key",), }, + "custom_generate_function_path": { + "dataset_keys": ("custom_generate_function_path",), + "default_keys": ("custom_generate_function_path",), + "arg_attrs": ("custom_generate_function_path",), + }, } @@ -98,8 +113,10 @@ class EvalDatasetConfig: name: str path: str rm_type: str | None = None + custom_generate_function_path: str | None = None # Dataset-specific overrides + apply_chat_template: bool | None = None input_key: str | None = None label_key: str | None = None tool_key: str | None = None @@ -110,6 +127,7 @@ class EvalDatasetConfig: temperature: float | None = None top_p: float | None = None top_k: int | None = None + max_context_len: int | None = None max_response_len: int | None = None metadata_overrides: dict[str, Any] = field(default_factory=dict) @@ -123,6 +141,7 @@ def cache_key(self) -> tuple[Any, ...]: return ( self.name, self.path, + self.apply_chat_template, self.input_key, self.label_key, self.tool_key,