diff --git a/README.md b/README.md index d02ca1976..980b23efe 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,8 @@ For a comprehensive quick start guide covering environment setup, data preparati We also provide examples for some use cases not covered in the quick start guide; please check [examples](examples/). +Python 3.10+ is required (uses `zip(..., strict=True)` in core utilities). + ## Projects Built upon slime slime has powered several novel research projects and production systems. Here are some notable examples: diff --git a/examples/tau-bench/.gitignore b/examples/tau-bench/.gitignore new file mode 100644 index 000000000..c6bb7d0e8 --- /dev/null +++ b/examples/tau-bench/.gitignore @@ -0,0 +1,13 @@ +/outputs/ + +# Local secrets (template is tracked) +/tau2/.env + +# Python caches +**/__pycache__/ +*.pyc + +# Logs / experiment trackers +wandb/ +weave/ +*.log diff --git a/examples/tau-bench/README.md b/examples/tau-bench/README.md index 8f846eb18..eec1904b7 100644 --- a/examples/tau-bench/README.md +++ b/examples/tau-bench/README.md @@ -1,68 +1,19 @@ -# Tau bench -This example shows slime training in an agentic multi-turn tool use environment. +# Tau-Bench: Multi-Turn Tool-Use Training +This folder provides two benchmark entrypoints with parallel conventions. The canonical documentation lives in `examples/tau-bench/training_cookbook.md`; other docs link into it without duplication. -## Environment Setup -Use the `zhuzilin/slime:latest` image and initialize the environment required for Search-R1: +| Benchmark | Repo | Domains | Dual-control | Primary metric | Folder | +|----------|------|---------|--------------|----------------|--------| +| Tau1 | https://github.com/sierra-research/tau-bench | airline, retail | no | pass@1 | `examples/tau-bench/tau1/` | +| Tau2 | https://github.com/sierra-research/tau2-bench | airline, retail, telecom | yes (telecom user-only tools) | pass@4 + pass@1 | `examples/tau-bench/tau2/` | -```bash -cd /root/ -git clone https://github.com/THUDM/slime.git -cd slime -pip install -e . -# for tau bench -cd /root/ -git clone https://github.com/JD-ETH/tau-bench.git -cd tau-bench -git checkout feature/litellm-retry -pip install -e . -``` +### Quick Links +- Training cookbook: `examples/tau-bench/training_cookbook.md`. +- Tau1 README: `examples/tau-bench/tau1/README.md`. +- Tau2 implementation: `examples/tau-bench/tau2/README.md` -Use the following script to generate mock data for slime training. +Note: Tau1 includes a small offline stub for debug/CI without external API keys. -```bash -cd /root/slime/examples/tau-bench -python tau1_mock.py --local_dir /root/tau-bench/ -``` +### Outputs -Initialize the Qwen2.5-3B-Instruct model needed for tool use: - -```bash -# hf checkpoint -huggingface-cli download Qwen/Qwen3-4B-Instruct-2507 --local-dir /root/Qwen3-4B-Instruct-2507 - -# mcore checkpoint -cd /root/slime -source scripts/models/qwen3-4B-Instruct-2507.sh -PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ - ${MODEL_ARGS[@]} \ - --hf-checkpoint /root/Qwen3-4B-Instruct-2507 \ - --save /root/Qwen3-4B-Instruct-2507_torch_dist -``` - -## Running the Script - -You need to configure your litellm API in `generate_with_tau.py` for user simulation: - -```python -TAU_CONFIGS = { - "env": "retail", # Select between ["retail", "airline"] - "agent": "tool-calling", # Select between ["tool-calling", "act", "react", "few-shot"], only tool-calling implemented for now - "user_model": "gemini-2.0-flash-lite", # Cheap Model for user simulator - "user_model_provider": "gemini", - "task_split": "train", # Select between ["train", "test", "dev"] for retail, ["test"] for airline - "user_strategy": "llm", # Select between ["llm", "react", "verify", "reflection"] - "model_provider": "auto_router", # Unused, required - "model": "qwen3-4b", # Unused, reqired -} -# Replace with your actual API key for user sim -GEMINI_API_KEY = "YOUR KEY" -``` - -And run: - - -```bash -cd /root/slime -bash examples/tau-bench/run_qwen3_4B.sh -``` \ No newline at end of file +All generated artifacts are written under `TAU_BENCH_OUT_DIR` (default: `examples/tau-bench/outputs`) and are gitignored. The cookbook assumes the `slimerl/slime:latest` container baseline. diff --git a/examples/tau-bench/public/performance-chart.jpeg b/examples/tau-bench/public/performance-chart.jpeg new file mode 100644 index 000000000..e330a81b4 Binary files /dev/null and b/examples/tau-bench/public/performance-chart.jpeg differ diff --git a/examples/tau-bench/public/slime-pipeline-tau2.jpeg b/examples/tau-bench/public/slime-pipeline-tau2.jpeg new file mode 100644 index 000000000..2451dbf8d Binary files /dev/null and b/examples/tau-bench/public/slime-pipeline-tau2.jpeg differ diff --git a/examples/tau-bench/sglang_tool_parser.py b/examples/tau-bench/sglang_tool_parser.py deleted file mode 100644 index ea4f380ba..000000000 --- a/examples/tau-bench/sglang_tool_parser.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Any - -from sglang.srt.function_call.function_call_parser import FunctionCallParser -from sglang.srt.managers.io_struct import Function, Tool - - -def parse_tools(response: str, tools: list[dict[str, Any]], parser: str = "qwen25"): - """ - This function mimics the function call parser API from - https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server.py#L952 - But running locally - """ - tools_list = [ - Tool( - function=Function( - name=tool["function"]["name"], - description=tool["function"]["description"], - parameters=tool["function"]["parameters"], - ), - type=tool["type"], - ) - for tool in tools - ] - parser = FunctionCallParser(tools=tools_list, tool_call_parser=parser) - - normal_text, calls = parser.parse_non_stream(response) - - return { - "normal_text": normal_text, - "calls": [call.model_dump() for call in calls], # Convert pydantic objects to dictionaries - } diff --git a/examples/tau-bench/tau1/README.md b/examples/tau-bench/tau1/README.md new file mode 100644 index 000000000..9bdbddeba --- /dev/null +++ b/examples/tau-bench/tau1/README.md @@ -0,0 +1,74 @@ +# Tau1 Bench (tau-bench) + +This example shows slime training in an agentic multi-turn tool use environment. + +## Environment Setup +Use the `slimerl/slime:latest` image and initialize the environment required for Tau1: + +```bash +cd /root/ +git clone https://github.com/THUDM/slime.git +cd slime +pip install -e . +# for tau bench +cd /root/ +git clone https://github.com/JD-ETH/tau-bench.git +cd tau-bench +git checkout feature/litellm-retry +pip install -e . +``` + +Use the following script to generate mock data for slime training. + +```bash +cd /root/slime/examples/tau-bench/tau1 +python tau1_mock.py --local_dir /root/tau-bench/ +``` + +Initialize the Qwen3-4B-Instruct model needed for tool use: + +```bash +# hf checkpoint +huggingface-cli download Qwen/Qwen3-4B-Instruct-2507 --local-dir /root/Qwen3-4B-Instruct-2507 + +# mcore checkpoint +cd /root/slime +source scripts/models/qwen3-4B-Instruct-2507.sh +PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ + ${MODEL_ARGS[@]} \ + --hf-checkpoint /root/Qwen3-4B-Instruct-2507 \ + --save /root/Qwen3-4B-Instruct-2507_torch_dist +``` + +## Running the Script + +You need to configure your litellm API in `generate_with_tau.py` for user simulation: + +```python +TAU_CONFIGS = { + "env": "retail", # Select between ["retail", "airline"] + "agent": "tool-calling", # Select between ["tool-calling", "act", "react", "few-shot"], only tool-calling implemented for now + "user_model": "gemini-2.0-flash-lite", # Cheap Model for user simulator + "user_model_provider": "gemini", + "task_split": "train", # Select between ["train", "test", "dev"] for retail, ["test"] for airline + "user_strategy": "llm", # Select between ["llm", "react", "verify", "reflection"] + "model_provider": "auto_router", # Unused, required + "model": "qwen3-4b", # Unused, required +} +# Replace with your actual API key for user sim +GEMINI_API_KEY = "YOUR KEY" +``` + +And run: + +```bash +cd /root/slime +bash examples/tau-bench/tau1/run_qwen3_4B.sh +``` + +## Known gotchas +- If you use an OpenAI-compatible server (e.g., sglang), set `OPENAI_API_BASE` and run tau-bench with `--model-provider openai` (not `openai_like`). +- For tau-bench CLI runs, use a slashless `--model` name (e.g., `Qwen3-4B-Instruct-2507`) to avoid log path errors. + +## Debugging and CI notes +- For offline or CPU-only debugging, you can set `user_model_provider="stub"` in `generate_with_tau.py` to bypass external API calls while preserving episode logging. diff --git a/examples/tau-bench/tau1/episode_logger.py b/examples/tau-bench/tau1/episode_logger.py new file mode 100644 index 000000000..e1f0062fc --- /dev/null +++ b/examples/tau-bench/tau1/episode_logger.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import hashlib +import json +import os +import time +from dataclasses import dataclass +from typing import Any + + +def _truncate(value: str | None, max_chars: int = 8000) -> str | None: + if value is None: + return None + if len(value) <= max_chars: + return value + return value[:max_chars] + f"\n...[truncated {len(value) - max_chars} chars]" + + +def _sha256_text(value: str) -> str: + return hashlib.sha256(value.encode("utf-8", errors="ignore")).hexdigest() + + +@dataclass +class EpisodeLogger: + log_dir: str + run_meta: dict[str, Any] + + def __post_init__(self) -> None: + os.makedirs(self.log_dir, exist_ok=True) + self._jsonl_path = os.path.join(self.log_dir, "episode.jsonl") + self._summary_path = os.path.join(self.log_dir, "summary.json") + self._run_meta_path = os.path.join(self.log_dir, "run_meta.json") + + with open(self._run_meta_path, "w", encoding="utf-8") as f: + json.dump(self.run_meta, f, ensure_ascii=False, indent=2) + + def log_step(self, record: dict[str, Any]) -> None: + # Control field size: truncate long strings and append hashes. + for key in [ + "assistant_raw", + "assistant", + "user_text", + "observation", + "tool_result", + "env_state", + "normal_text", + "tool_parse_error", + "error", + ]: + if key in record and isinstance(record[key], str): + record[f"{key}_hash"] = _sha256_text(record[key]) + record[key] = _truncate(record[key]) + + record["ts"] = time.time() + with open(self._jsonl_path, "a", encoding="utf-8") as f: + f.write(json.dumps(record, ensure_ascii=False) + "\n") + + def finalize(self, summary: dict[str, Any]) -> None: + with open(self._summary_path, "w", encoding="utf-8") as f: + json.dump(summary, f, ensure_ascii=False, indent=2) diff --git a/examples/tau-bench/generate_with_tau.py b/examples/tau-bench/tau1/generate_with_tau.py similarity index 59% rename from examples/tau-bench/generate_with_tau.py rename to examples/tau-bench/tau1/generate_with_tau.py index ce80c8d99..24f20e214 100644 --- a/examples/tau-bench/generate_with_tau.py +++ b/examples/tau-bench/tau1/generate_with_tau.py @@ -1,25 +1,20 @@ -""" -Tau-Bench Integration for slime Training - -This module provides the main interface for training agents in tau-bench environments -using the slime framework. It handles agent-environment interactions and converts -results to the format expected by slime's training pipeline. -""" +"""Tau-bench integration for slime training.""" import logging import os +import time +import uuid from typing import Any +from episode_logger import EpisodeLogger from tau_bench.envs import get_env from tau_bench.types import RunConfig from trainable_agents import InteractionResult, Status, agent_factory from slime.utils.types import Sample -# Set up logger for this module logger = logging.getLogger(__name__) -# Tau-bench configuration TAU_CONFIGS = { "env": "retail", # Select between ["retail", "airline"] "agent": "tool-calling", # Select between ["tool-calling", "act", "react", "few-shot"] @@ -29,29 +24,15 @@ "model_provider": "auto_router", # Unused, required "model": "qwen3-4b", # Unused, required "user_model_provider": "gemini", + # "user_model_provider": "stub", } -# Replace with your actual API key for user sim GEMINI_API_KEY = "NONE" os.environ["GEMINI_API_KEY"] = GEMINI_API_KEY tau_config = RunConfig(**TAU_CONFIGS) def res_to_sample(res: InteractionResult, task_index: int) -> Sample: - """ - Convert InteractionResult to Sample format for slime training. - - This function transforms the tau-bench interaction result into the format - expected by slime's training pipeline, handling status mapping and response - length calculation. - - Args: - res: InteractionResult from tau-bench agent - task_index: Index of the task being processed - - Returns: - Sample object for slime training - """ - # Map tau-bench status to slime status + """Convert InteractionResult to a slime Sample.""" status_mapping = { Status.COMPLETED: "completed", Status.TRUNCATED: "truncated", @@ -59,7 +40,6 @@ def res_to_sample(res: InteractionResult, task_index: int) -> Sample: } status = status_mapping.get(res.status) - # Debug logging for response tracking logger.debug( f"res_to_sample: response_length=" f"{res.response_length if hasattr(res, 'response_length') else 'None'}, " @@ -67,7 +47,6 @@ def res_to_sample(res: InteractionResult, task_index: int) -> Sample: f"tokens_len={len(res.tokens) if res.tokens else 'None'}" ) - # Create sample with basic information sample = Sample( index=task_index, prompt=res.prompt, @@ -79,16 +58,12 @@ def res_to_sample(res: InteractionResult, task_index: int) -> Sample: metadata=res.info, ) - # Ensure response_length is set correctly if hasattr(res, "response_length"): sample.response_length = res.response_length else: - # Fallback: calculate from loss_mask if available if res.loss_mask: - # loss_mask only contains response part, so length equals response_length sample.response_length = len(res.loss_mask) elif res.tokens: - # If no loss_mask available, use total tokens as fallback sample.response_length = len(res.tokens) else: sample.response_length = 0 @@ -97,33 +72,35 @@ def res_to_sample(res: InteractionResult, task_index: int) -> Sample: return sample -async def generate(args: dict[str, Any], sample: Sample, sampling_params: dict) -> Sample: - """ - Generate a complete agent-environment interaction trajectory for tau-bench. - - This is the main entry point for slime training. It creates a tau-bench - environment, initializes a trainable agent, and executes a full interaction - trajectory. The result is converted to slime's Sample format for training. - - Args: - args: Rollout arguments from slime training pipeline - sample: Sample containing task index in prompt field - sampling_params: LLM sampling parameters +def _default_run_root() -> str: + return os.environ.get("TAU_RUN_DIR", os.path.join(os.getcwd(), "runs", "tau1")) - Returns: - Sample object containing the complete interaction trajectory - Raises: - AssertionError: If partial rollout is requested (not supported) - """ - # Validate arguments +async def generate(args: dict[str, Any], sample: Sample, sampling_params: dict) -> Sample: + """Run a single tau-bench interaction trajectory.""" assert not args.partial_rollout, "Partial rollout is not supported for tau-bench interactions." - # Extract task index from sample prompt task_index = int(sample.prompt) + run_root = _default_run_root() + run_id = time.strftime("%Y%m%d_%H%M%S") + f"_{os.getpid()}_{uuid.uuid4().hex[:8]}" + episode_dir = os.path.join(run_root, run_id, f"task_{task_index:06d}") + os.makedirs(episode_dir, exist_ok=True) + + run_meta = { + "run_id": run_id, + "task_index": task_index, + "tau_config": TAU_CONFIGS, + "sampling_params": { + k: sampling_params.get(k) + for k in ["temperature", "top_p", "top_k", "max_new_tokens"] + if k in sampling_params + }, + "pid": os.getpid(), + } + ep_logger = EpisodeLogger(log_dir=episode_dir, run_meta=run_meta) + logger.info(f"Starting agent-environment interaction for task {task_index}") - # Initialize tau-bench environment env = get_env( env_name=tau_config.env, user_strategy=tau_config.user_strategy, @@ -133,21 +110,26 @@ async def generate(args: dict[str, Any], sample: Sample, sampling_params: dict) task_index=task_index, ) - # Create trainable agent agent = agent_factory( tools_info=env.tools_info, wiki=env.wiki, config=tau_config, rollout_args=args, sampling_params=sampling_params, + episode_logger=ep_logger, ) - # Execute agent-environment interaction - # Note: The sample.prompt field contains the task index for repeatability interaction_result = await agent.asolve(env, agent.rollout_args, agent.sampling_params, task_index) - # Convert to slime Sample format result_sample = res_to_sample(interaction_result, task_index) + ep_logger.finalize( + { + "status": str(interaction_result.status), + "reward": interaction_result.reward, + "response_length": getattr(interaction_result, "response_length", None), + } + ) + logger.info(f"Finished agent-environment interaction for task {task_index}") return result_sample diff --git a/examples/tau-bench/openai_tool_adapter.py b/examples/tau-bench/tau1/openai_tool_adapter.py similarity index 82% rename from examples/tau-bench/openai_tool_adapter.py rename to examples/tau-bench/tau1/openai_tool_adapter.py index b580f28b2..b7192c339 100644 --- a/examples/tau-bench/openai_tool_adapter.py +++ b/examples/tau-bench/tau1/openai_tool_adapter.py @@ -11,6 +11,14 @@ logger = logging.getLogger(__name__) +def _parse_tools_compat(parse_tools_fn, response: str, tools_info, parser_type: str | None): + """Compatibility wrapper for parse_tools() across versions.""" + try: + return parse_tools_fn(response, tools_info, parser_type) + except TypeError: + return parse_tools_fn(response, tools_info) + + @dataclass class OpenAIToolCall: """OpenAI format tool call structure""" @@ -62,20 +70,30 @@ def parse_response_to_openai_format(self, response: str) -> dict[str, Any]: Exception: Thrown when parsing fails """ try: - # Use existing parser to parse tool calls - parsed = parse_tools(response, self.tools_info, self.parser_type) + parsed = _parse_tools_compat(parse_tools, response, self.tools_info, self.parser_type) - # Extract parsing results - normal_text = parsed["normal_text"] - calls = parsed["calls"] + if isinstance(parsed, dict): + normal_text = parsed.get("normal_text") + calls = parsed.get("calls") + elif isinstance(parsed, (list, tuple)) and len(parsed) >= 2: + normal_text, calls = parsed[0], parsed[1] + else: + raise TypeError(f"Unexpected parse_tools result: {type(parsed)}") + + if not isinstance(normal_text, str): + normal_text = response + if not isinstance(calls, list): + calls = [] # Convert to OpenAI format openai_message = self._convert_to_openai_message(normal_text, calls) - return {"openai_message": openai_message, "parsed_result": parsed, "success": True} + parsed_result = {"normal_text": normal_text, "calls": calls} + + return {"openai_message": openai_message, "parsed_result": parsed_result, "success": True} except Exception as e: - logger.warning(f"Parsing failed with error: {str(e)}") + logger.warning(f"Parsing failed with error: {type(e).__name__}: {str(e)}") return {"openai_message": None, "parsed_result": None, "success": False, "error": str(e)} def _convert_to_openai_message(self, normal_text: str, calls: list[dict[str, Any]]) -> OpenAIAssistantMessage: diff --git a/examples/tau-bench/run_qwen3_4B.sh b/examples/tau-bench/tau1/run_qwen3_4B.sh similarity index 93% rename from examples/tau-bench/run_qwen3_4B.sh rename to examples/tau-bench/tau1/run_qwen3_4B.sh index e981ab3ac..e4036374b 100644 --- a/examples/tau-bench/run_qwen3_4B.sh +++ b/examples/tau-bench/tau1/run_qwen3_4B.sh @@ -24,7 +24,7 @@ fi echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -source "${SCRIPT_DIR}/../../scripts/models/qwen3-4B-Instruct-2507.sh" +source "${SCRIPT_DIR}/../../../scripts/models/qwen3-4B-Instruct-2507.sh" CKPT_ARGS=( --hf-checkpoint /root/Qwen3-4B-Instruct-2507/ @@ -119,10 +119,11 @@ CUSTOM_ARGS=( ) # launch the master node of ray in container export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +RAY_DASHBOARD_HOST="${RAY_DASHBOARD_HOST:-127.0.0.1}" # If you want more or less GPUs, change this parameter NUM_GPUS=2 -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 --temp-dir /root/shared/ray_temp +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host="${RAY_DASHBOARD_HOST}" --dashboard-port=8265 --temp-dir /root/shared/ray_temp RUNTIME_ENV_JSON="{ \"env_vars\": { diff --git a/examples/tau-bench/tau1/sglang_tool_parser.py b/examples/tau-bench/tau1/sglang_tool_parser.py new file mode 100644 index 000000000..00db3845e --- /dev/null +++ b/examples/tau-bench/tau1/sglang_tool_parser.py @@ -0,0 +1,45 @@ +import logging +from typing import Any + +logger = logging.getLogger(__name__) + +_SGLANG_AVAILABLE = False +_FunctionCallParser = None +_Function = None +_Tool = None + +try: + from sglang.srt.function_call.function_call_parser import FunctionCallParser as _FunctionCallParser + from sglang.srt.managers.io_struct import Function as _Function + from sglang.srt.managers.io_struct import Tool as _Tool + + _SGLANG_AVAILABLE = True +except Exception as exc: + logger.warning(f"sglang tool parser unavailable (optional). Falling back to no-tool parsing. Error: {exc}") + + +def parse_tools(response: str, tools: list[dict[str, Any]], parser: str = "qwen25") -> dict[str, Any]: + if not _SGLANG_AVAILABLE: + return {"normal_text": response, "calls": []} + + tools_list = [ + _Tool( + function=_Function( + name=tool["function"]["name"], + description=tool["function"]["description"], + parameters=tool["function"]["parameters"], + ), + type=tool.get("type", "function"), + ) + for tool in tools + ] + + parser_obj = _FunctionCallParser(tools=tools_list, tool_call_parser=parser) + try: + normal_text, calls = parser_obj.parse_non_stream(response) + except Exception as exc: + logger.warning(f"sglang tool parser failed, falling back to no-tool parsing. Error: {exc}") + return {"normal_text": response, "calls": []} + + calls = [call.model_dump() if hasattr(call, "model_dump") else call for call in calls] + return {"normal_text": normal_text, "calls": calls} diff --git a/examples/tau-bench/tau1_mock.py b/examples/tau-bench/tau1/tau1_mock.py similarity index 100% rename from examples/tau-bench/tau1_mock.py rename to examples/tau-bench/tau1/tau1_mock.py diff --git a/examples/tau-bench/trainable_agents.py b/examples/tau-bench/tau1/trainable_agents.py similarity index 72% rename from examples/tau-bench/trainable_agents.py rename to examples/tau-bench/tau1/trainable_agents.py index 5cfc3f43d..a28fb36a5 100644 --- a/examples/tau-bench/trainable_agents.py +++ b/examples/tau-bench/tau1/trainable_agents.py @@ -42,16 +42,40 @@ def call_to_action_sglang(calls: list[Any], text_response: str) -> Action: """ # Default action if no action was found. action = Action(name=RESPOND_ACTION_NAME, kwargs={"content": text_response}) - if calls: - if len(calls) > 1: - logger.debug("Multiple tool calls identified, only taking first.") - tool_call = calls[0] - params = json.loads(tool_call["parameters"]) - if not isinstance(params, dict): - logger.warning(f"{params} does not follow dict structure for action") + if not calls: + return action + if len(calls) > 1: + logger.debug("Multiple tool calls identified, only taking first.") + + tool_call = calls[0] + tool_name = tool_call.get("name") + + raw_params = tool_call.get("parameters") + if raw_params is None: + raw_params = tool_call.get("arguments") + + params: Any = {} + try: + if isinstance(raw_params, dict): + params = raw_params + elif isinstance(raw_params, str): + raw_params_str = raw_params.strip() + params = json.loads(raw_params_str) if raw_params_str else {} else: - action = Action(name=tool_call["name"], kwargs=params) - return action + params = {} + except Exception as exc: + logger.warning(f"Failed to parse tool params: {exc}; raw_params={raw_params!r}") + return action + + if not isinstance(params, dict): + logger.warning(f"Tool params is not a dict: {params!r}") + return action + + if not tool_name: + logger.warning(f"Tool call missing name: {tool_call!r}") + return action + + return Action(name=tool_name, kwargs=params) TOOL_INSTRUCTION = ( @@ -201,6 +225,7 @@ async def asolve( # Initialize environment and state state = GenerateState(rollout_args) url = f"http://{rollout_args.sglang_router_ip}:" f"{rollout_args.sglang_router_port}/generate" + step_id = 0 # Get initial environment state obs, info = self._initialize_environment(env, task_index) @@ -210,12 +235,13 @@ async def asolve( prompt_text, prompt_token_ids = self._prepare_prompt_tokens(state, messages) # Initialize tracking variables - loss_masks = [] - response_token_ids = [] + loss_masks: list[int] = [] + response_token_ids: list[int] = [] total_reward = 0.0 # Initialize result res = InteractionResult(prompt=prompt_text, reward=0, messages=[], info={}) + env_response = None # Multi-turn interaction loop for _ in range(max_num_steps): @@ -229,44 +255,97 @@ async def asolve( # Send request to sglang server output = await self._call_llm(url, payload) + finish_reason = output.get("meta_info", {}).get("finish_reason", {}) + finish_type = finish_reason.get("type") # Check for abort - if output["meta_info"]["finish_reason"]["type"] == "abort": + if finish_type == "abort": + if getattr(self, "episode_logger", None): + self.episode_logger.log_step( + { + "step_id": step_id, + "phase": "llm_output", + "finish_type": finish_type, + "assistant_raw": output.get("text", ""), + "messages_len": len(messages), + } + ) res.status = Status.ABORTED return self._build_final_result( res, total_reward, info, messages, loss_masks, prompt_token_ids, response_token_ids ) - response = output["text"] + raw_response = output.get("text", "") + response = raw_response # Remove end of conversation token if present if response.endswith("<|im_end|>"): response = response[:-10] + if getattr(self, "episode_logger", None): + self.episode_logger.log_step( + { + "step_id": step_id, + "phase": "llm_output", + "finish_type": finish_type, + "assistant_raw": raw_response, + "assistant": response, + "text_input_chars": len(text_input) if isinstance(text_input, str) else None, + "messages_len": len(messages), + } + ) + # Parse tool calls using OpenAI adapter - logger.debug(f"Using OpenAI adapter to parse response: {response[:100]}...") try: openai_result = self._parse_tool(response) - logger.debug(f"OpenAI adapter result: success={openai_result['success']}") - - if not openai_result["success"]: - logger.warning(f"OpenAI adapter failed: {openai_result['error']}") - logger.warning( - f"rollout response: {response} can not be parsed into " f"tool calls {openai_result['error']}" - ) + parse_ok = bool(openai_result.get("success", False)) + + if not parse_ok: + if getattr(self, "episode_logger", None): + self.episode_logger.log_step( + { + "step_id": step_id, + "phase": "tool_parse", + "tool_parse_ok": False, + "tool_parse_error": openai_result.get("error"), + "assistant": response, + } + ) res.status = Status.ABORTED return self._build_final_result( res, total_reward, info, messages, loss_masks, prompt_token_ids, response_token_ids ) - # Extract parsed results - parsed = openai_result["parsed_result"] + parsed = openai_result.get("parsed_result") or {} + if getattr(self, "episode_logger", None): + self.episode_logger.log_step( + { + "step_id": step_id, + "phase": "tool_parse", + "tool_parse_ok": True, + "normal_text": parsed.get("normal_text"), + "tool_calls": parsed.get("calls"), + } + ) + logger.debug( - f"Successfully parsed - normal_text: '{parsed['normal_text']}', " f"calls: {parsed['calls']}" + "Successfully parsed - normal_text=%r calls=%r", + parsed.get("normal_text"), + parsed.get("calls"), ) except Exception as e: logger.warning(f"Exception in OpenAI adapter: {e}") - logger.warning(f"rollout response: {response} can not be parsed into " f"tool calls {e}") + logger.warning("rollout response: can not be parsed into tool calls") + if getattr(self, "episode_logger", None): + self.episode_logger.log_step( + { + "step_id": step_id, + "phase": "tool_parse", + "tool_parse_ok": False, + "tool_parse_error": repr(e), + "assistant_raw": response, + } + ) res.status = Status.ABORTED return self._build_final_result( res, total_reward, info, messages, loss_masks, prompt_token_ids, response_token_ids @@ -280,33 +359,50 @@ async def asolve( # Execute action in environment agent_content, calls = parsed["normal_text"], parsed["calls"] - logger.debug(f"Creating action from - content: '{agent_content}', " f"calls: {calls}") action = call_to_action_sglang(calls, agent_content) - logger.debug(f"Created action: {action}") try: env_response = await self._execute_tool(env, action) + if getattr(self, "episode_logger", None): + info_dict = None + try: + info_dict = env_response.info.model_dump() + except Exception: + info_dict = None + self.episode_logger.log_step( + { + "step_id": step_id, + "phase": "env_step", + "env_step_ok": True, + "action_name": action.name, + "action_kwargs": getattr(action, "kwargs", None), + "reward": env_response.reward, + "done": env_response.done, + "observation": env_response.observation, + "info_keys": list(info_dict.keys()) if isinstance(info_dict, dict) else None, + } + ) except Exception as e: - logger.warning("Environment step failed, this is usually related to " "the User simulation call.") - logger.warning(f"Error: {e}") + if getattr(self, "episode_logger", None): + self.episode_logger.log_step( + { + "step_id": step_id, + "phase": "env_step", + "env_step_ok": False, + "error": repr(e), + "action_name": action.name, + "action_kwargs": getattr(action, "kwargs", None), + } + ) res.status = Status.ABORTED return self._build_final_result( res, total_reward, info, messages, loss_masks, prompt_token_ids, response_token_ids ) - logger.debug(f"Environment response: reward={env_response.reward}, " f"done={env_response.done}") - # Update message history based on action type if action.name != RESPOND_ACTION_NAME: - messages.append( - { - "role": "tool", - "name": action.name, - "content": env_response.observation, - } - ) + messages.append({"role": "tool", "name": action.name, "content": env_response.observation}) else: - # Direct response from user messages.append({"role": "user", "content": env_response.observation}) # Update token tracking @@ -321,10 +417,15 @@ async def asolve( # Check if done if env_response.done: res.status = Status.COMPLETED + step_id += 1 break + step_id += 1 + # Handle truncation - if not env_response.done: + if env_response is None: + res.status = Status.ABORTED + elif not env_response.done: res.status = Status.TRUNCATED return self._build_final_result( @@ -427,6 +528,7 @@ def __init__( temperature: float = 0.0, rollout_args: dict[str, Any] | None = None, sampling_params: dict[str, Any] | None = None, + episode_logger=None, ): # Initialize the parent ToolCallingAgent super().__init__( @@ -451,6 +553,7 @@ def __init__( } # Initialize OpenAI adapter self.openai_adapter = create_openai_adapter(tools_info=self.tools_info, parser_type="qwen25") + self.episode_logger = episode_logger def agent_factory( @@ -459,6 +562,7 @@ def agent_factory( config: RunConfig, rollout_args: dict[str, Any] | None = None, sampling_params: dict[str, Any] | None = None, + episode_logger=None, ) -> Agent: if config.agent_strategy == "tool-calling": return TrainableToolCallingAgent( @@ -469,6 +573,7 @@ def agent_factory( temperature=config.temperature, rollout_args=rollout_args, sampling_params=sampling_params, + episode_logger=episode_logger, ) else: raise NotImplementedError(f"Unsupported agent strategy: {config.agent_strategy}") diff --git a/examples/tau-bench/tau2/.env.template b/examples/tau-bench/tau2/.env.template new file mode 100644 index 000000000..b72fa6ba5 --- /dev/null +++ b/examples/tau-bench/tau2/.env.template @@ -0,0 +1,22 @@ +# Tau2 Pipeline Environment Variables +# Copy to `.env` and fill in your keys. + +# User simulator for evaluation (GPT-4.1-2025-04-14 via OpenAI API). +OPENAI_API_KEY=your_openai_api_key_here + +# User simulator for training rollouts (Gemini, optional if using local model). +GEMINI_API_KEY=your_gemini_api_key_here + +# Optional: local user simulator (no external API keys). +# TAU2_USER_API_BASE=http://127.0.0.1:30001/v1 +# TAU2_USER_MODEL=openai/Qwen/Qwen3-4B-Instruct-2507 + +# tau2-bench data directory (optional but recommended for reproducibility). +# Example (editable install): /root/tau2-bench/data +TAU2_DATA_DIR=/path/to/tau2-bench/data + +# Optional: WandB logging (used by the training scripts). +WANDB_API_KEY=your_wandb_api_key_here + +# Optional: Hugging Face auth for model downloads/uploads. +HF_TOKEN=your_hf_token_here diff --git a/examples/tau-bench/tau2/README.md b/examples/tau-bench/tau2/README.md new file mode 100644 index 000000000..a589c6241 --- /dev/null +++ b/examples/tau-bench/tau2/README.md @@ -0,0 +1,19 @@ +# Tau2 Pipeline (tau2-bench) + +This folder contains the tau2-bench integration used by the canonical cookbook in `examples/tau-bench/training_cookbook.md`. + +Start here: `examples/tau-bench/training_cookbook.md`. + +## What's in this folder + +- `run_sft.sh`, `run_grpo.sh`: convenience scripts for SFT and GRPO (write under `TAU_BENCH_OUT_DIR`) +- `start_user_sim_server.sh`: starts a local user simulator server for GRPO/eval (port 30001) +- `rollout.py`: GRPO rollout entrypoint (`--custom-generate-function-path rollout.generate`) +- `tasks.py`: generates `tau2_{split}_all_tasks.jsonl` task index files under `TAU_BENCH_OUT_DIR` +- `eval.py`: unified evaluation harness (supports Pass@1 with `--num-samples=1 --temperature=0.0` for greedy, and Pass@K with `--num-samples=K` for multi-sampling; Pass@4 is the headline metric; pass@k = any success among k attempts; defaults to GPT-4.1-mini user sim) +- `prompting.py`, `actions.py`, `env.py`, `reward.py`: utilities used by rollouts/eval + +## Pre-generated dataset + +The cookbook uses a pinned pre-generated SFT dataset by default: +`Jarrodbarnes/tau2-sft-seed-v3` on Hugging Face. diff --git a/examples/tau-bench/tau2/actions.py b/examples/tau-bench/tau2/actions.py new file mode 100644 index 000000000..c4ceed3a2 --- /dev/null +++ b/examples/tau-bench/tau2/actions.py @@ -0,0 +1,164 @@ +"""Tau2 action parsing and observation formatting. + +We standardize on Qwen3 native function calling: + + {"name": "...", "arguments": {...}} +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from typing import Any + + +_TOOL_CALL_START = "" +_TOOL_CALL_END = "" + + +@dataclass(frozen=True, slots=True) +class ParsedAction: + name: str + arguments: dict[str, Any] + raw_action_call: str + + +def _to_py_literal(value: Any) -> str: + if value is None: + return "None" + if value is True: + return "True" + if value is False: + return "False" + if isinstance(value, (int, float)): + return repr(value) + if isinstance(value, str): + return repr(value) + if isinstance(value, list): + return "[" + ", ".join(_to_py_literal(v) for v in value) + "]" + if isinstance(value, dict): + items = [] + for k, v in value.items(): + items.append(f"{_to_py_literal(k)}: {_to_py_literal(v)}") + return "{" + ", ".join(items) + "}" + return repr(value) + + +def _to_functional_call(name: str, arguments: dict[str, Any]) -> str: + if not arguments: + return f"{name}()" + parts = [f"{k}={_to_py_literal(v)}" for k, v in arguments.items()] + return f"{name}({', '.join(parts)})" + + +def _find_tool_call_blocks(text: str) -> list[tuple[str, int, int]]: + blocks: list[tuple[str, int, int]] = [] + cursor = 0 + while True: + start = text.find(_TOOL_CALL_START, cursor) + if start == -1: + break + end = text.find(_TOOL_CALL_END, start + len(_TOOL_CALL_START)) + if end == -1: + raise ValueError("Missing for block") + content = text[start + len(_TOOL_CALL_START) : end].strip() + blocks.append((content, start, end + len(_TOOL_CALL_END))) + cursor = end + len(_TOOL_CALL_END) + return blocks + + +def parse_action(text: str) -> ParsedAction: + blocks = _find_tool_call_blocks(text) + if not blocks: + raise ValueError("Missing ... block") + if len(blocks) > 1: + raise ValueError("Multiple blocks found; expected exactly one") + + content, start, end = blocks[0] + prefix = text[:start].strip() + suffix = text[end:].strip() + if prefix or suffix: + raise ValueError("Unexpected text outside block") + + data = json.loads(content) + name = data.get("name") + arguments = data.get("arguments") or {} + + if not isinstance(name, str) or not name: + raise ValueError("Tool call missing non-empty 'name'") + + if isinstance(arguments, str): + arguments = json.loads(arguments) if arguments.strip() else {} + if not isinstance(arguments, dict): + raise ValueError("Tool call 'arguments' must be an object") + + return ParsedAction(name=name, arguments=arguments, raw_action_call=_to_functional_call(name, arguments)) + + +def _strip_role_prefix(line: str) -> tuple[str | None, str]: + line = line.strip() + if ": " not in line: + return None, line + role, content = line.split(": ", 1) + return role.strip().lower(), content + + +@dataclass(frozen=True, slots=True) +class ParsedObservation: + user: str + tool: str + other: str + + +def split_observation(observation: str) -> ParsedObservation: + user_lines: list[str] = [] + tool_lines: list[str] = [] + other_lines: list[str] = [] + + for raw_line in (observation or "").splitlines(): + role, content = _strip_role_prefix(raw_line) + content = content.strip() + if not content: + continue + if role == "user": + user_lines.append(content) + elif role == "tool": + tool_lines.append(content) + else: + other_lines.append(content if role is None else f"{role}: {content}") + + return ParsedObservation( + user="\n".join(user_lines).strip(), + tool="\n".join(tool_lines).strip(), + other="\n".join(other_lines).strip(), + ) + + +def followup_messages_for_observation( + *, + observation: str, + last_action_call: str, + last_action_was_tool: bool, +) -> list[dict[str, str]]: + parsed = split_observation(observation) + messages: list[dict[str, str]] = [] + + if last_action_was_tool: + tool_payload = parsed.tool or parsed.other or "[no_observation]" + messages.append({"role": "user", "content": f"Tool result for {last_action_call}:\n{tool_payload}"}) + if parsed.user: + messages.append({"role": "user", "content": parsed.user}) + return messages + + user_payload = parsed.user or parsed.other or "[no_observation]" + messages.append({"role": "user", "content": user_payload}) + return messages + + +def env_action_from_parsed_action(action: ParsedAction) -> str: + if action.name == "respond": + content = action.arguments.get("content") + if not isinstance(content, str) or not content.strip(): + raise ValueError("respond requires a non-empty content string") + return content + return action.raw_action_call diff --git a/examples/tau-bench/tau2/env.py b/examples/tau-bench/tau2/env.py new file mode 100644 index 000000000..1f8cd8b41 --- /dev/null +++ b/examples/tau-bench/tau2/env.py @@ -0,0 +1,173 @@ +"""Small utilities around tau2-bench `AgentGymEnv`.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True, slots=True) +class Tau2EpisodeConfig: + domain: str + task_split: str + user_llm: str + user_llm_args: dict[str, Any] + max_steps: int = 100 + all_messages_as_observation: bool = False + + +class Tau2TaskIndex: + def __init__(self, domain: str, task_split: str): + from tau2.registry import registry + + self.domain = domain + self.task_split = task_split + self._tasks = registry.get_tasks_loader(domain)(task_split) + if not self._tasks: + raise ValueError(f"No tasks loaded for domain={domain} split={task_split}") + + def __len__(self) -> int: + return len(self._tasks) + + def task_id(self, task_index: int) -> str: + if task_index < 0 or task_index >= len(self._tasks): + raise IndexError(f"task_index={task_index} out of range (0..{len(self._tasks)-1})") + return self._tasks[task_index].id + + +class Tau2AgentGymWrapper: + def __init__(self, cfg: Tau2EpisodeConfig): + self.cfg = cfg + self._task_index = Tau2TaskIndex(cfg.domain, cfg.task_split) + self._env = None + self._task_id: str | None = None + self._tools_openai: list[dict[str, Any]] | None = None + self._policy: str | None = None + + @property + def task_count(self) -> int: + return len(self._task_index) + + def task_id_from_index(self, task_index: int) -> str: + return self._task_index.task_id(task_index) + + @property + def tools_openai_schema(self) -> list[dict[str, Any]]: + if self._tools_openai is None: + raise RuntimeError("Call reset() first to populate tool schemas.") + return self._tools_openai + + @property + def policy(self) -> str: + if self._policy is None: + raise RuntimeError("Call reset() first to populate policy.") + return self._policy + + @property + def task_id(self) -> str: + if self._task_id is None: + raise RuntimeError("Call reset() first to select a task.") + return self._task_id + + def reset(self, task_index: int) -> tuple[str, dict[str, Any]]: + from tau2.gym.gym_agent import AgentGymEnv + + self._task_id = self._task_index.task_id(task_index) + self._env = AgentGymEnv( + domain=self.cfg.domain, + task_id=self._task_id, + max_steps=self.cfg.max_steps, + solo_mode=False, + user_llm=self.cfg.user_llm, + user_llm_args=self.cfg.user_llm_args, + all_messages_as_observation=self.cfg.all_messages_as_observation, + ) + observation, info = self._env.reset() + + tools = info.get("tools", []) + self._tools_openai = [t if isinstance(t, dict) else t.openai_schema for t in tools] + + self._policy = info.get("policy", "") + + return observation, info + + def step(self, action: str) -> tuple[str, float, bool, bool, dict[str, Any]]: + if self._env is None: + raise RuntimeError("Call reset() before step().") + return self._env.step(action) + + +def parse_reward_info(info: dict[str, Any]) -> dict[str, Any]: + return parse_reward_info_value(info.get("reward_info")) + + +def parse_reward_info_value(reward_info: Any) -> dict[str, Any]: + if reward_info is None: + return {} + if isinstance(reward_info, dict): + return reward_info + if isinstance(reward_info, str): + try: + parsed = json.loads(reward_info) + return parsed if isinstance(parsed, dict) else {} + except json.JSONDecodeError: + return {} + return {} + + +@dataclass(frozen=True, slots=True) +class PartialScoreWeights: + action: float = 0.5 + communicate: float = 0.15 + env_assertion: float = 0.35 + db: float = 0.0 + + +def compute_partial_score_from_reward_info( + reward_info: dict[str, Any], + *, + weights: PartialScoreWeights = PartialScoreWeights(), + normalize_over_present: bool = True, +) -> tuple[float, dict[str, float]]: + components: dict[str, float] = {} + + action_checks = reward_info.get("action_checks") or [] + if action_checks: + matched = sum(1 for ac in action_checks if ac.get("action_match")) + components["action"] = matched / len(action_checks) + + communicate_checks = reward_info.get("communicate_checks") or [] + if communicate_checks: + met = sum(1 for cc in communicate_checks if cc.get("met")) + components["communicate"] = met / len(communicate_checks) + + env_assertions = reward_info.get("env_assertions") or [] + if env_assertions: + met = sum(1 for ea in env_assertions if ea.get("met")) + components["env_assertion"] = met / len(env_assertions) + + db_check = reward_info.get("db_check") or {} + if isinstance(db_check, dict) and "db_match" in db_check: + components["db"] = 1.0 if db_check.get("db_match") else 0.0 + + if not components: + return 0.0, {} + + weight_map = { + "action": weights.action, + "communicate": weights.communicate, + "env_assertion": weights.env_assertion, + "db": weights.db, + } + + if normalize_over_present: + present = {k: weight_map[k] for k in components.keys()} + weight_sum = sum(present.values()) + if weight_sum <= 0: + return 0.0, components + score = sum(components[k] * present[k] for k in components.keys()) / weight_sum + return float(score), components + + score = sum(components[k] * weight_map[k] for k in components.keys()) + return float(score), components diff --git a/examples/tau-bench/tau2/eval.py b/examples/tau-bench/tau2/eval.py new file mode 100644 index 000000000..88a3e2617 --- /dev/null +++ b/examples/tau-bench/tau2/eval.py @@ -0,0 +1,373 @@ +#!/usr/bin/env python3 +"""Pass@K evaluation for tau2-bench using an SGLang-served policy.""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import os +import sys +from dataclasses import asdict, dataclass +from typing import Any + +import httpx +from transformers import AutoTokenizer + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +if SCRIPT_DIR not in sys.path: + sys.path.insert(0, SCRIPT_DIR) + +from actions import env_action_from_parsed_action, followup_messages_for_observation, parse_action +from env import compute_partial_score_from_reward_info, parse_reward_info +from prompting import build_tau2_agent_system_prompt + +logger = logging.getLogger(__name__) + +DEFAULT_DOMAINS = ("airline", "retail", "telecom") +PASS_AT_K_NOTE = "pass@k = any success among k attempts (not pass^k leaderboard estimate)" + + +def _parse_csv(value: str) -> list[str]: + return [x.strip() for x in value.split(",") if x.strip()] + + +def _get_user_llm_args(*, temperature: float) -> dict[str, Any]: + args: dict[str, Any] = {"temperature": temperature} + api_base = os.environ.get("TAU2_USER_API_BASE", "").strip() + if api_base: + args["api_base"] = api_base + args["api_key"] = "dummy-key-for-local-server" + args["extra_body"] = {"chat_template_kwargs": {"enable_thinking": False}} + return args + + +@dataclass(frozen=True, slots=True) +class AttemptResult: + success: bool + reward: float + partial_score: float + partial_components: dict[str, float] + steps: int + status: str + error: str | None = None + reward_info: dict[str, Any] | None = None + + +@dataclass(frozen=True, slots=True) +class PassKResult: + domain: str + task_split: str + task_index: int + task_id: str + num_samples: int + best_success: bool + best_reward: float + best_partial_score: float + best_sample_idx: int + attempts: list[dict[str, Any]] + pass_at_1: float + pass_at_k: float + + +class SGLangClient: + def __init__(self, url: str) -> None: + self.url = url + self._client = httpx.AsyncClient(timeout=httpx.Timeout(300.0)) + + async def close(self) -> None: + await self._client.aclose() + + async def generate(self, *, text: str, sampling_params: dict[str, Any]) -> dict[str, Any]: + resp = await self._client.post(self.url, json={"text": text, "sampling_params": sampling_params}) + resp.raise_for_status() + return resp.json() + + +def _load_tasks(domain: str, task_split: str) -> list[str]: + from tau2.registry import registry + + return [t.id for t in registry.get_tasks_loader(domain)(task_split)] + + +async def _run_one_attempt( + *, + client: SGLangClient, + tokenizer, + domain: str, + task_id: str, + sampling_params: dict[str, Any], + max_steps: int, + user_llm: str, + user_llm_args: dict[str, Any], +) -> AttemptResult: + from tau2.gym.gym_agent import AgentGymEnv + + env = AgentGymEnv( + domain=domain, + task_id=task_id, + max_steps=max_steps, + solo_mode=False, + user_llm=user_llm, + user_llm_args=user_llm_args, + all_messages_as_observation=False, + ) + + observation, info = env.reset() + tools = info.get("tools", []) + tools_openai = [t if isinstance(t, dict) else t.openai_schema for t in tools] + policy = info.get("policy", "") + + system_prompt = build_tau2_agent_system_prompt(domain=domain, policy=policy, tools_openai=tools_openai) + messages: list[dict[str, str]] = [{"role": "system", "content": system_prompt}] + messages.extend( + followup_messages_for_observation( + observation=observation, + last_action_call="(reset)", + last_action_was_tool=False, + ) + ) + + reward = 0.0 + reward_info: dict[str, Any] = {} + + for step in range(max_steps): + prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + out = await client.generate(text=prompt_text, sampling_params=sampling_params) + + if out.get("meta_info", {}).get("finish_reason", {}).get("type") == "abort": + return AttemptResult( + success=False, + reward=0.0, + partial_score=0.0, + partial_components={}, + steps=step, + status="aborted", + error="sglang_abort", + ) + + assistant_text = (out.get("text") or "").strip() + if not assistant_text: + return AttemptResult( + success=False, + reward=0.0, + partial_score=0.0, + partial_components={}, + steps=step, + status="empty_generation", + error="empty_generation", + ) + + try: + parsed = parse_action(assistant_text) + except Exception as exc: + messages.append({"role": "assistant", "content": assistant_text}) + messages.append( + { + "role": "user", + "content": "FORMAT ERROR. Re-output EXACTLY in the required format: " + '{"name": "...", "arguments": {...}}. One action only.', + } + ) + repair_params = {**sampling_params, "temperature": 0.0} + out = await client.generate( + text=tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True), + sampling_params=repair_params, + ) + assistant_text = (out.get("text") or "").strip() + try: + parsed = parse_action(assistant_text) + except Exception: + partial_score, partial_components = compute_partial_score_from_reward_info(reward_info) + return AttemptResult( + success=False, + reward=float(reward), + partial_score=partial_score, + partial_components=partial_components, + steps=step + 1, + status="parse_error", + error=str(exc), + reward_info=reward_info, + ) + + messages.append({"role": "assistant", "content": assistant_text}) + + env_action = env_action_from_parsed_action(parsed) + observation, reward, terminated, _truncated, info = env.step(env_action) + + if terminated: + reward_info = parse_reward_info(info) + partial_score, partial_components = compute_partial_score_from_reward_info(reward_info) + return AttemptResult( + success=float(reward) >= 1.0, + reward=float(reward), + partial_score=partial_score, + partial_components=partial_components, + steps=step + 1, + status="completed", + reward_info=reward_info, + ) + + messages.extend( + followup_messages_for_observation( + observation=observation, + last_action_call=parsed.raw_action_call, + last_action_was_tool=(parsed.name != "respond"), + ) + ) + + partial_score, partial_components = compute_partial_score_from_reward_info(reward_info) + return AttemptResult( + success=False, + reward=float(reward), + partial_score=partial_score, + partial_components=partial_components, + steps=max_steps, + status="truncated", + reward_info=reward_info, + ) + + +def _summarize(results: list[PassKResult], *, k: int) -> dict[str, Any]: + total = len(results) + pass1 = sum(r.pass_at_1 for r in results) / total if total else 0.0 + passk = sum(r.pass_at_k for r in results) / total if total else 0.0 + return {"total": total, "pass_at_1": pass1, f"pass_at_{k}": passk} + + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Evaluate tau2-bench with Pass@K sampling") + parser.add_argument("--hf-checkpoint", required=True) + parser.add_argument("--sglang-url", required=True) + parser.add_argument("--output", required=True) + parser.add_argument("--domains", default=",".join(DEFAULT_DOMAINS)) + parser.add_argument("--task-split", default="test", choices=("train", "test", "base")) + parser.add_argument("--max-tasks-per-domain", type=int, default=None) + parser.add_argument("--max-steps", type=int, default=int(os.environ.get("TAU2_MAX_STEPS", "100"))) + parser.add_argument("--num-samples", type=int, default=4) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--top-p", type=float, default=0.8) + parser.add_argument("--top-k", type=int, default=20) + parser.add_argument("--repetition-penalty", type=float, default=1.0) + parser.add_argument("--max-new-tokens", type=int, default=1200) + parser.add_argument("--user-model", default=os.environ.get("TAU2_USER_MODEL", "gpt-4.1-mini")) + parser.add_argument("--user-temperature", type=float, default=float(os.environ.get("TAU2_USER_TEMPERATURE", "0.7"))) + return parser + + +def _validate_args(args: argparse.Namespace, parser: argparse.ArgumentParser) -> None: + if args.num_samples < 1: + parser.error("--num-samples must be >= 1") + + +async def main_async() -> None: + parser = _build_arg_parser() + args = parser.parse_args() + _validate_args(args, parser) + + domains = _parse_csv(args.domains) + tokenizer = AutoTokenizer.from_pretrained(args.hf_checkpoint, trust_remote_code=True) + + sampling_params: dict[str, Any] = { + "temperature": args.temperature, + "top_p": args.top_p, + "repetition_penalty": args.repetition_penalty, + "max_new_tokens": args.max_new_tokens, + "stop": [""], + "no_stop_trim": True, + } + if args.top_k > 0: + sampling_params["top_k"] = args.top_k + + client = SGLangClient(args.sglang_url) + try: + user_llm_args = _get_user_llm_args(temperature=args.user_temperature) + all_results: list[PassKResult] = [] + + for domain in domains: + task_ids = _load_tasks(domain, args.task_split) + if args.max_tasks_per_domain is not None: + task_ids = task_ids[: args.max_tasks_per_domain] + + logger.info(f"Evaluating domain={domain} split={args.task_split} tasks={len(task_ids)} k={args.num_samples}") + for i, task_id in enumerate(task_ids): + attempts: list[AttemptResult] = [] + for _ in range(args.num_samples): + attempts.append( + await _run_one_attempt( + client=client, + tokenizer=tokenizer, + domain=domain, + task_id=task_id, + sampling_params=sampling_params, + max_steps=args.max_steps, + user_llm=args.user_model, + user_llm_args=user_llm_args, + ) + ) + + best_idx = 0 + for j in range(1, len(attempts)): + a = attempts[j] + b = attempts[best_idx] + if a.success and not b.success: + best_idx = j + elif a.success == b.success and a.partial_score > b.partial_score: + best_idx = j + + pass_at_1 = 1.0 if attempts and attempts[0].success else 0.0 + pass_at_k = 1.0 if any(a.success for a in attempts) else 0.0 + best = attempts[best_idx] + all_results.append( + PassKResult( + domain=domain, + task_split=args.task_split, + task_index=i, + task_id=task_id, + num_samples=args.num_samples, + best_success=best.success, + best_reward=best.reward, + best_partial_score=best.partial_score, + best_sample_idx=best_idx, + attempts=[asdict(a) for a in attempts], + pass_at_1=pass_at_1, + pass_at_k=pass_at_k, + ) + ) + + by_domain: dict[str, list[PassKResult]] = {} + for r in all_results: + by_domain.setdefault(r.domain, []).append(r) + + report = { + "hf_checkpoint": args.hf_checkpoint, + "sglang_url": args.sglang_url, + "task_split": args.task_split, + "domains": domains, + "k": args.num_samples, + "metric_note": PASS_AT_K_NOTE, + "summary": _summarize(all_results, k=args.num_samples), + "by_domain": {d: _summarize(rs, k=args.num_samples) for d, rs in sorted(by_domain.items())}, + "results": [asdict(r) for r in all_results], + } + + os.makedirs(os.path.dirname(os.path.abspath(args.output)) or ".", exist_ok=True) + with open(args.output, "w") as f: + json.dump(report, f, indent=2) + + logger.info(f"Wrote {len(all_results)} results to {args.output}") + logger.info(f"Overall: {report['summary']}") + for d, s in report["by_domain"].items(): + logger.info(f"{d}: {s}") + finally: + await client.close() + + +def main() -> None: + logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + asyncio.run(main_async()) + + +if __name__ == "__main__": + main() diff --git a/examples/tau-bench/tau2/prompting.py b/examples/tau-bench/tau2/prompting.py new file mode 100644 index 000000000..9973180f9 --- /dev/null +++ b/examples/tau-bench/tau2/prompting.py @@ -0,0 +1,216 @@ +"""Prompt utilities for tau2-bench (dual-control).""" + +from __future__ import annotations + +import json +import os +from typing import Any + + +USE_COMPRESSED = os.environ.get("TAU2_USE_COMPRESSED_PROMPTS", "0") != "0" + + +COMPRESSED_POLICY_TELECOM = """# Telecom Agent Policy + +## Capabilities +Help users with: technical support, overdue bill payment, line suspension, plan changes, data roaming, data refueling. + +## Core Rules +1. Authenticate customer first (by phone, customer_id, or name+DOB). +2. One tool call OR one response per turn, never both. +3. Confirm before any state-changing action. +4. Transfer to human only if request is outside your capabilities. + +## Action Ownership (CRITICAL) +You can directly perform: customer lookup, account queries, enable/disable roaming, resume lines, refuel data, process payments. + +For device diagnostics and settings, you must INSTRUCT THE USER to perform them: +- Ask user to toggle airplane mode on/off +- Ask user to check and reset APN settings +- Ask user to restart/reboot their device +- Ask user to reseat their SIM card +- Ask user to toggle mobile data on/off +- Ask user to run a speed test + +After instructing the user, wait for them to confirm they've done it, then proceed. + +## Key Procedures + +### Customer Lookup +- Look up by phone number, customer ID, or full name (requires DOB for verification). +- Get details for customer, line, device, bill, or plan as needed. + +### Line Suspension +- Lines suspend for: overdue bills OR expired contract. +- Can resume after bills paid, UNLESS contract expired. +- After resuming a line, ask user to reboot their device. + +### Overdue Bills +- Check bill status to confirm it is overdue. +- Only one bill can be AWAITING_PAYMENT at a time. +- Flow: send payment request -> user checks and accepts -> make payment -> verify PAID. + +### Data Issues +- Check data usage vs plan limit. +- If exceeded: offer plan change OR data refuel (max 2GB). +- If abroad: check and enable roaming if needed. + +### Technical Support +- First identify the customer and check their line/device status. +- For connectivity issues, instruct user step-by-step: + 1. Toggle airplane mode on, wait 10 seconds, toggle off + 2. If still not working, restart the device + 3. If still not working, check/reset APN settings + 4. If still not working, reseat SIM card +- After each step, ask user to confirm and check if issue is resolved. +- Only transfer to human after troubleshooting steps are exhausted. +""" + +COMPRESSED_POLICY_RETAIL = """# Retail Agent Policy + +## Capabilities +Help users: cancel/modify pending orders, return/exchange delivered orders, lookup order/product info. + +## Core Rules +1. Authenticate user first (by email OR name+zip). Required even if user provides user_id. +2. One tool call OR one response per turn, never both. +3. Confirm before any database update. +4. One user per conversation only. + +## Key Procedures + +### Order Status Rules +- **pending**: can cancel or modify (address, payment, items) +- **processed**: cannot modify +- **delivered**: can return or exchange +- **cancelled**: no actions possible + +### Cancel Pending Order +- Requires: order_id + reason ("no longer needed" OR "ordered by mistake" only). +- Refund timing: gift card = immediate, other = 5-7 business days. +- Always communicate the refund timeline to user. + +### Modify Pending Order +- Can change: address, payment method, item variants (same product type only). +- Item modification can only be done ONCE per order. +- New payment must cover any price difference. +- Communicate all changes and any price differences to user. + +### Return Delivered Order +- Collect: order_id, item_ids to return, payment method for refund. +- Refund must go to original payment OR existing gift card. +- Communicate the refund method and timing. + +### Exchange Delivered Order +- Same product type only (e.g., shirt to shirt). +- Collect ALL items to exchange before proceeding. +- User pays/receives price difference. +- Communicate the price difference clearly. +""" + +COMPRESSED_POLICY_AIRLINE = """# Airline Agent Policy + +## Capabilities +Help users: book, modify, cancel flight reservations; handle refunds and compensation. + +## Core Rules +1. Authenticate user first (by user_id, email, or name+DOB). +2. One tool call OR one response per turn, never both. +3. Confirm before any booking database update. +4. Transfer to human only if outside your capabilities. + +## Key Procedures + +### Booking Flights +- Collect: user_id, trip_type (one_way/round_trip), origin, destination, dates. +- Same cabin class across all flights in reservation. +- Max 5 passengers; collect name, DOB for each. +- Payment: max 1 travel certificate, 1 credit card, 3 gift cards (all from user profile). + +### Modifying Reservations +- Look up all user reservations to find the correct one. +- Can change: flights, passengers, cabin class, baggage. +- Change fees depend on membership level. +- Basic economy has more restrictions. +- Communicate all fees and changes clearly. + +### Cancellation & Refunds +- Verify the correct reservation before cancelling. +- Refund amount depends on: membership level, cabin class, timing. +- Travel certificates: non-refundable remainder. +- Gold members have maximum flexibility. +- Communicate the refund amount and method. + +### Membership Levels +- regular: standard fees and restrictions +- silver: reduced fees, some flexibility +- gold: waived fees, maximum flexibility + +### Cabin Classes +- basic_economy: most restricted +- economy: standard +- business: premium, most flexible +""" + + +def get_compressed_policy(domain: str) -> str: + """Return compressed policy for domain, or empty string if not available.""" + policies = { + "telecom": COMPRESSED_POLICY_TELECOM, + "retail": COMPRESSED_POLICY_RETAIL, + "airline": COMPRESSED_POLICY_AIRLINE, + } + return policies.get(domain, "") + + +def format_tools_json_schema(tools_openai: list[dict[str, Any]]) -> str: + tool_schemas = [] + for tool in tools_openai: + fn = tool.get("function", tool) + schema = { + "name": fn.get("name", "unknown"), + "description": fn.get("description", ""), + "parameters": fn.get("parameters", {"type": "object", "properties": {}}), + } + tool_schemas.append(schema) + return json.dumps(tool_schemas, indent=2) + + +def build_tau2_agent_system_prompt( + *, + domain: str, + policy: str, + tools_openai: list[dict[str, Any]], +) -> str: + if USE_COMPRESSED: + compressed = get_compressed_policy(domain) + if compressed: + policy = compressed + + tools_text = format_tools_json_schema(tools_openai) + return f"""## Output Format +Every turn: exactly ONE tool call in this format: + +{{"name": "tool_name", "arguments": {{"param": "value"}}}} + + +Special actions: +- respond: {{"name": "respond", "arguments": {{"content": "message"}}}} +- done: {{"name": "done", "arguments": {{}}}} + +--- + +You are a {domain} customer support agent. Complete the user's task following the policy below. + +{policy} + +## Available Tools +{tools_text} + +## Rules +- One tool call per turn (no plain text responses) +- Authenticate user before state changes +- Confirm before modifications +- Communicate all relevant details to the user +- Call done immediately when task is complete +""" diff --git a/examples/tau-bench/tau2/reward.py b/examples/tau-bench/tau2/reward.py new file mode 100644 index 000000000..bf85a09b1 --- /dev/null +++ b/examples/tau-bench/tau2/reward.py @@ -0,0 +1,196 @@ +"""Reward shaping for tau2-bench rollouts. + +Entry point: `--custom-reward-post-process-path reward.tau2_reward_post_process` +""" + +from __future__ import annotations + +import os +import threading +from collections import defaultdict + +import torch + +from env import PartialScoreWeights, compute_partial_score_from_reward_info, parse_reward_info +from slime.utils.types import Sample + + +CURRICULUM_SOLVED_THRESHOLD = 0.75 +CURRICULUM_HARD_THRESHOLD = 0.25 + + +class _Curriculum: + """Process-local curriculum tracker. + + Note: In distributed Ray settings, each worker maintains its own tracker + instance. This is a best-effort heuristic; disable via TAU2_USE_CURRICULUM=0 + for fully deterministic behavior across workers. + """ + + def __init__(self) -> None: + self._lock = threading.Lock() + self._attempts: dict[str, int] = defaultdict(int) + self._successes: dict[str, int] = defaultdict(int) + + def enabled(self) -> bool: + return os.environ.get("TAU2_USE_CURRICULUM", "1") == "1" + + def min_attempts(self) -> int: + return int(os.environ.get("TAU2_CURRICULUM_MIN_ATTEMPTS", "5")) + + def solved_weight(self) -> float: + return float(os.environ.get("TAU2_CURRICULUM_SOLVED_WEIGHT", "0.1")) + + def hard_weight(self) -> float: + return float(os.environ.get("TAU2_CURRICULUM_HARD_WEIGHT", "0.5")) + + def update(self, task_id: str, reward: float) -> None: + if not self.enabled() or not task_id: + return + with self._lock: + self._attempts[task_id] += 1 + if reward >= 1.0: + self._successes[task_id] += 1 + + def weight(self, task_id: str) -> float: + if not self.enabled() or not task_id: + return 1.0 + with self._lock: + attempts = self._attempts.get(task_id, 0) + if attempts < self.min_attempts(): + return 1.0 + rate = self._successes.get(task_id, 0) / attempts if attempts else 0.0 + if rate > CURRICULUM_SOLVED_THRESHOLD: + return self.solved_weight() + if rate < CURRICULUM_HARD_THRESHOLD: + return self.hard_weight() + return 1.0 + + +_curriculum = _Curriculum() + + +def _alpha(domain: str | None) -> float: + base = float(os.environ.get("TAU2_REWARD_ALPHA", "0.25")) + if os.environ.get("TAU2_DOMAIN_ADAPTIVE_ALPHA", "1") != "1" or not domain: + return base + # Telecom uses a higher multiplier to compensate for dual-control communication overhead. + mult = {"retail": 0.8, "airline": 1.0, "telecom": 1.6}.get(domain, 1.0) + return base * mult + + +def _partial_weights(domain: str | None) -> PartialScoreWeights: + action = float(os.environ.get("TAU2_PARTIAL_ACTION_WEIGHT", "0.5")) + communicate = float(os.environ.get("TAU2_PARTIAL_COMMUNICATE_WEIGHT", "0.15")) + env_assertion = float(os.environ.get("TAU2_PARTIAL_ENV_ASSERTION_WEIGHT", "0.35")) + db = float(os.environ.get("TAU2_PARTIAL_DB_WEIGHT", "0.0")) + + if domain == "telecom" and os.environ.get("TAU2_TELECOM_COMMUNICATION_BOOST", "1") == "1": + return PartialScoreWeights(action=0.35, communicate=0.35, env_assertion=0.30, db=0.0) + + return PartialScoreWeights(action=action, communicate=communicate, env_assertion=env_assertion, db=db) + + +def _flatten(samples: list[Sample] | list[list[Sample]]) -> list[Sample]: + if not samples: + return [] + if isinstance(samples[0], list): + return [s for group in samples for s in group] + return list(samples) + + +def _normalize_rewards( + rewards: list[float], + *, + valid_mask: list[float], + n_samples_per_prompt: int, + apply_std: bool, +) -> list[float]: + if not rewards: + return [] + if n_samples_per_prompt <= 0: + raise ValueError(f"n_samples_per_prompt must be >= 1 (got {n_samples_per_prompt})") + if len(rewards) % n_samples_per_prompt != 0: + raise ValueError( + "Reward count must be a multiple of n_samples_per_prompt " + f"(count={len(rewards)}, n_samples_per_prompt={n_samples_per_prompt})." + ) + + rewards_t = torch.tensor(rewards, dtype=torch.float).view(-1, n_samples_per_prompt) + mask_t = torch.tensor(valid_mask, dtype=torch.float).view(-1, n_samples_per_prompt) + + denom = mask_t.sum(dim=-1, keepdim=True).clamp(min=1.0) + mean = (rewards_t * mask_t).sum(dim=-1, keepdim=True) / denom + centered = (rewards_t - mean) * mask_t + + if apply_std: + var = (centered**2).sum(dim=-1, keepdim=True) / denom + centered = centered / (torch.sqrt(var) + 1e-6) + + return centered.flatten().tolist() + + +def tau2_reward_post_process(args, samples: list[Sample] | list[list[Sample]]) -> tuple[list[float], list[float]]: + flat = _flatten(samples) + + raw_rewards: list[float] = [] + shaped_rewards: list[float] = [] + sample_weights: list[float] = [] + valid_mask: list[float] = [] + + for sample in flat: + task_reward = float(sample.get_reward_value(args)) + raw_rewards.append(task_reward) + + metadata = sample.metadata or {} + domain = metadata.get("domain") + task_id = metadata.get("task_id") or metadata.get("tau2_task_id") or "" + + reward_info = parse_reward_info({"reward_info": metadata.get("reward_info")}) + partial_score, partial_components = compute_partial_score_from_reward_info( + reward_info, + weights=_partial_weights(domain), + normalize_over_present=True, + ) + + shaped = task_reward + _alpha(domain) * partial_score + is_valid = not sample.remove_sample + valid_mask.append(1.0 if is_valid else 0.0) + shaped_rewards.append(float(shaped) if is_valid else 0.0) + + _curriculum.update(task_id, task_reward) + w = _curriculum.weight(task_id) if is_valid else 0.0 + sample_weights.append(w) + + sample.metadata = { + **metadata, + "raw_reward": task_reward, + "partial_score": partial_score, + "partial_components": partial_components, + "shaped_reward": shaped, + "curriculum_weight": w, + } + + if ( + args.advantage_estimator in {"grpo", "gspo", "reinforce_plus_plus_baseline"} + and args.rewards_normalization + and shaped_rewards + ): + rewards = _normalize_rewards( + shaped_rewards, + valid_mask=valid_mask, + n_samples_per_prompt=args.n_samples_per_prompt, + apply_std=(args.advantage_estimator in {"grpo", "gspo"} and args.grpo_std_normalization), + ) + + if os.environ.get("TAU2_APPLY_CURRICULUM_WEIGHTS", "1") == "1": + weights = torch.tensor(sample_weights, dtype=torch.float).view(-1, args.n_samples_per_prompt) + rewards_t = torch.tensor(rewards, dtype=torch.float).view_as(weights) + rewards = (rewards_t * weights).flatten().tolist() + + return raw_rewards, rewards + + if os.environ.get("TAU2_APPLY_CURRICULUM_WEIGHTS", "1") == "1": + shaped_rewards = [r * sample_weights[i] for i, r in enumerate(shaped_rewards)] + + return raw_rewards, shaped_rewards diff --git a/examples/tau-bench/tau2/rollout.py b/examples/tau-bench/tau2/rollout.py new file mode 100644 index 000000000..b0248b682 --- /dev/null +++ b/examples/tau-bench/tau2/rollout.py @@ -0,0 +1,211 @@ +"""Tau2 (dual-control) rollout generation for slime. + +Used via `--custom-generate-function-path rollout.generate`. +""" + +from __future__ import annotations + +import os +from typing import Any + +from transformers import AutoTokenizer + +from slime.rollout.sglang_rollout import generate as sglang_generate +from slime.utils.mask_utils import MultiTurnLossMaskGenerator +from slime.utils.types import Sample + +from actions import env_action_from_parsed_action, followup_messages_for_observation, parse_action +from env import compute_partial_score_from_reward_info, parse_reward_info +from prompting import build_tau2_agent_system_prompt + +TOKENIZER = None +MASK_GENERATOR = None + + +def _get_tokenizer_and_mask_generator(args) -> tuple[Any, MultiTurnLossMaskGenerator]: + global TOKENIZER, MASK_GENERATOR + if TOKENIZER is None: + TOKENIZER = AutoTokenizer.from_pretrained(args.hf_checkpoint, trust_remote_code=True) + if MASK_GENERATOR is None: + MASK_GENERATOR = MultiTurnLossMaskGenerator(TOKENIZER, tokenizer_type=args.loss_mask_type) + return TOKENIZER, MASK_GENERATOR + + +def _get_tau2_user_sim_config() -> tuple[str, dict[str, Any]]: + user_model = os.environ.get("TAU2_USER_MODEL", "openai/Qwen/Qwen3-4B-Instruct-2507") + user_temperature = float(os.environ.get("TAU2_USER_TEMPERATURE", "0.7")) + user_api_base = os.environ.get("TAU2_USER_API_BASE", "http://127.0.0.1:30001/v1") + + user_llm_args: dict[str, Any] = {"temperature": user_temperature} + + if user_api_base: + user_llm_args["api_base"] = user_api_base + user_llm_args["api_key"] = "dummy-key-for-local-server" + user_llm_args["extra_body"] = {"chat_template_kwargs": {"enable_thinking": False}} + + return user_model, user_llm_args + + +def _get_tau2_max_steps(args) -> int: + return int(os.environ.get("TAU2_MAX_STEPS", "100")) + + +async def _generate_one_action( + args, + *, + messages: list[dict[str, str]], + sampling_params: dict[str, Any], + max_retries: int = 1, +) -> tuple[str, Any | None, str | None]: + stops = list(sampling_params.get("stop") or []) + if "" not in stops: + stops.append("") + sampling_params = {**sampling_params, "stop": stops, "no_stop_trim": True} + + working_messages = list(messages) + last_text = "" + last_error: str | None = None + + for attempt in range(max_retries + 1): + turn_sample = Sample(prompt=working_messages, metadata={}) + turn_sample = await sglang_generate(args, turn_sample, sampling_params.copy()) + if turn_sample.status == Sample.Status.ABORTED: + return "", None, "sglang_aborted" + + text = (turn_sample.response or "").strip() + last_text = text + try: + parsed = parse_action(text) + return text, parsed, None + except Exception as exc: + last_error = str(exc) + if attempt >= max_retries: + break + working_messages = working_messages + [ + {"role": "assistant", "content": text}, + { + "role": "user", + "content": "FORMAT ERROR. Re-output EXACTLY in the required format: " + '{"name": "...", "arguments": {...}}. One action only.', + }, + ] + sampling_params = {**sampling_params, "temperature": 0.0} + + return last_text, None, last_error + + +async def generate(args, sample: Sample, sampling_params: dict) -> Sample: + """Multi-turn rollout for tau2-bench (dual-control).""" + from tau2.gym.gym_agent import AgentGymEnv + + metadata = sample.metadata or {} + task_index = metadata.get("task_index", 0) + domain = metadata["domain"] + task_split = metadata.get("split", "train") + task_id = metadata["task_id"] + + user_llm, user_llm_args = _get_tau2_user_sim_config() + max_steps = _get_tau2_max_steps(args) + + env = AgentGymEnv( + domain=domain, + task_id=task_id, + max_steps=max_steps, + solo_mode=False, + user_llm=user_llm, + user_llm_args=user_llm_args, + all_messages_as_observation=False, + ) + + observation, info = env.reset() + tools = info.get("tools", []) + tools_openai = [t if isinstance(t, dict) else t.openai_schema for t in tools] + + policy = info.get("policy", "") + + system_prompt = build_tau2_agent_system_prompt(domain=domain, policy=policy, tools_openai=tools_openai) + + messages: list[dict[str, str]] = [{"role": "system", "content": system_prompt}] + messages.extend( + followup_messages_for_observation( + observation=observation, + last_action_call="(reset)", + last_action_was_tool=False, + ) + ) + + assistant_turn_texts: list[str] = [] + tool_sequence: list[str] = [] + reward = 0.0 + reward_info: dict[str, Any] = {} + info: dict[str, Any] = {} + + terminated = False + for _ in range(max_steps): + assistant_text, parsed, err = await _generate_one_action( + args, + messages=messages, + sampling_params=sampling_params, + max_retries=1, + ) + if assistant_text: + messages.append({"role": "assistant", "content": assistant_text}) + assistant_turn_texts.append(assistant_text) + + if parsed is None: + sample.remove_sample = True + break + + if parsed.name not in ("respond", "done"): + tool_sequence.append(parsed.name) + + env_action = env_action_from_parsed_action(parsed) + observation, reward, terminated, _truncated, info = env.step(env_action) + + if terminated: + reward_info = parse_reward_info(info) + break + + messages.extend( + followup_messages_for_observation( + observation=observation, + last_action_call=parsed.raw_action_call, + last_action_was_tool=(parsed.name != "respond"), + ) + ) + + sample.prompt = messages + sample.response = "\n\n".join(assistant_turn_texts).strip() + sample.reward = float(reward) + sample.status = Sample.Status.COMPLETED if terminated else Sample.Status.TRUNCATED + + reward_info = reward_info or parse_reward_info(info) + partial_score, partial_components = compute_partial_score_from_reward_info(reward_info) + + if sample.metadata is None: + sample.metadata = {} + sample.metadata.update( + { + "domain": domain, + "split": task_split, + "task_id": task_id, + "task_index": task_index, + "user_model": user_llm, + "reward_info": reward_info, + "partial_score": partial_score, + "partial_components": partial_components, + "tool_sequence": tool_sequence, + "terminated": terminated, + } + ) + + # Build tokens/loss-mask from the final multi-turn message list. + _, mask_gen = _get_tokenizer_and_mask_generator(args) + token_ids, full_loss_mask = mask_gen.get_loss_mask(messages) + response_length = mask_gen.get_response_lengths([full_loss_mask])[0] + + sample.tokens = token_ids + sample.response_length = response_length + sample.loss_mask = full_loss_mask[-response_length:] if response_length > 0 else [] + + return sample diff --git a/examples/tau-bench/tau2/run_grpo.sh b/examples/tau-bench/tau2/run_grpo.sh new file mode 100644 index 000000000..dd8c18749 --- /dev/null +++ b/examples/tau-bench/tau2/run_grpo.sh @@ -0,0 +1,196 @@ +#!/bin/bash + +# Tau2 GRPO training (tau2-bench). + +set -euo pipefail + +if [ "${TAU2_CLEANUP:-0}" = "1" ]; then + pkill -9 sglang || true + ray stop --force || true + pkill -9 ray || true +fi + +set -ex + +export PYTHONUNBUFFERED=1 + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/../../.." &>/dev/null && pwd)" +TAU_BENCH_OUT_DIR="${TAU_BENCH_OUT_DIR:-${SCRIPT_DIR}/../outputs}" +MEGATRON_LM_DIR="${MEGATRON_LM_DIR:-/root/Megatron-LM}" + +# ---- User-configurable paths ---- +HF_DIR="${HF_DIR:-${TAU_BENCH_OUT_DIR}/models/Qwen3-4B-Instruct-2507}" +SFT_CKPT_DIR="${SFT_CKPT_DIR:-${TAU_BENCH_OUT_DIR}/tau2/checkpoints/Qwen3-4B-tau2-sft1}" +REF_CKPT_DIR="${REF_CKPT_DIR:-${SFT_CKPT_DIR}}" +SAVE_DIR="${SAVE_DIR:-${TAU_BENCH_OUT_DIR}/tau2/checkpoints/Qwen3-4B-tau2-grpo-v1}" + +TAU2_TRAIN_TASKS_JSONL="${TAU2_TRAIN_TASKS_JSONL:-${TAU_BENCH_OUT_DIR}/tau2/tasks/tau2_train_all_tasks.jsonl}" + +NUM_GPUS="${NUM_GPUS:-4}" + +export TAU2_USER_MODEL="${TAU2_USER_MODEL:-openai/Qwen/Qwen3-4B-Instruct-2507}" +export TAU2_USER_API_BASE="${TAU2_USER_API_BASE:-http://127.0.0.1:30001/v1}" +export TAU2_USER_TEMPERATURE="${TAU2_USER_TEMPERATURE:-0.7}" +export TAU2_MAX_STEPS="${TAU2_MAX_STEPS:-100}" + +export TAU2_REWARD_ALPHA="${TAU2_REWARD_ALPHA:-0.25}" +export TAU2_DOMAIN_ADAPTIVE_ALPHA="${TAU2_DOMAIN_ADAPTIVE_ALPHA:-1}" +export TAU2_PARTIAL_ACTION_WEIGHT="${TAU2_PARTIAL_ACTION_WEIGHT:-0.5}" +export TAU2_PARTIAL_COMMUNICATE_WEIGHT="${TAU2_PARTIAL_COMMUNICATE_WEIGHT:-0.15}" +export TAU2_PARTIAL_ENV_ASSERTION_WEIGHT="${TAU2_PARTIAL_ENV_ASSERTION_WEIGHT:-0.35}" +export TAU2_PARTIAL_DB_WEIGHT="${TAU2_PARTIAL_DB_WEIGHT:-0.0}" + +export TAU2_TELECOM_COMMUNICATION_BOOST="${TAU2_TELECOM_COMMUNICATION_BOOST:-1}" + +export TAU2_USE_COMPRESSED_PROMPTS="${TAU2_USE_COMPRESSED_PROMPTS:-0}" + +export TAU2_USE_CURRICULUM="${TAU2_USE_CURRICULUM:-1}" +export TAU2_CURRICULUM_MIN_ATTEMPTS="${TAU2_CURRICULUM_MIN_ATTEMPTS:-5}" +export TAU2_CURRICULUM_SOLVED_WEIGHT="${TAU2_CURRICULUM_SOLVED_WEIGHT:-0.1}" +export TAU2_CURRICULUM_HARD_WEIGHT="${TAU2_CURRICULUM_HARD_WEIGHT:-0.5}" + +source "${SCRIPT_DIR}/../../../scripts/models/qwen3-4B-Instruct-2507.sh" + +CKPT_ARGS=( + --hf-checkpoint "${HF_DIR}" + --ref-load "${REF_CKPT_DIR}" + --load "${SFT_CKPT_DIR}" + --no-load-optim # Start optimizer fresh for GRPO + --save "${SAVE_DIR}" + --save-interval 10 +) + +ROLLOUT_ARGS=( + --prompt-data "${TAU2_TRAIN_TASKS_JSONL}" + --input-key text + --metadata-key metadata + --apply-chat-template + --rollout-shuffle + + # Optimized for 4xH100: better task coverage per iteration + # NOTE: rollout_batch_size * n_samples_per_prompt must be multiple of global_batch_size + --num-rollout 200 + --rollout-batch-size 16 + --n-samples-per-prompt 4 + --rollout-max-response-len 4096 + --rollout-temperature 0.7 + # Note: --rollout-stop removed; rollout.py ensures is in stops + + --global-batch-size 64 + --balance-data +) + +GRPO_ARGS=( + --advantage-estimator grpo + --eps-clip 0.2 + --eps-clip-high 0.28 + --entropy-coef 0.001 + --use-kl-loss + --kl-loss-coef 0.01 + --kl-loss-type low_var_kl +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --use-dynamic-batch-size + --max-tokens-per-gpu 20480 + + # Gradient checkpointing for H100 memory efficiency + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.70 # Increased from 0.60 since user sim is on separate GPU + # If rollouts abort, uncomment: + # --no-offload-rollout +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash +) + +CUSTOM_ARGS=( + --custom-generate-function-path rollout.generate + --custom-reward-post-process-path reward.tau2_reward_post_process +) + +if [ -n "${WANDB_API_KEY:-}" ]; then + WANDB_ARGS=( + --use-wandb + --wandb-project "tau2-cookbook" + --wandb-group "grpo-qwen3-4b" + --wandb-exp-name "tau2-grpo-v1" + ) +else + echo "NOTE: WANDB_API_KEY not set, running without WandB logging" + WANDB_ARGS=() +fi + +export MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}" +RAY_DASHBOARD_HOST="${RAY_DASHBOARD_HOST:-127.0.0.1}" +ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" \ + --disable-usage-stats --dashboard-host="${RAY_DASHBOARD_HOST}" --dashboard-port=8265 + +RUNTIME_ENV_JSON="{ + \"working_dir\": \"${REPO_ROOT}\", + \"env_vars\": { + \"PYTHONPATH\": \"${MEGATRON_LM_DIR}:${SCRIPT_DIR}:${REPO_ROOT}\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"TAU2_DATA_DIR\": \"${TAU2_DATA_DIR:-${TAU_BENCH_OUT_DIR}/_external/tau2-bench/data}\", + \"TAU2_USER_MODEL\": \"${TAU2_USER_MODEL}\", + \"TAU2_USER_API_BASE\": \"${TAU2_USER_API_BASE:-}\", + \"TAU2_USER_TEMPERATURE\": \"${TAU2_USER_TEMPERATURE}\", + \"TAU2_MAX_STEPS\": \"${TAU2_MAX_STEPS}\", + \"TAU2_REWARD_ALPHA\": \"${TAU2_REWARD_ALPHA}\", + \"TAU2_DOMAIN_ADAPTIVE_ALPHA\": \"${TAU2_DOMAIN_ADAPTIVE_ALPHA}\", + \"TAU2_PARTIAL_ACTION_WEIGHT\": \"${TAU2_PARTIAL_ACTION_WEIGHT}\", + \"TAU2_PARTIAL_COMMUNICATE_WEIGHT\": \"${TAU2_PARTIAL_COMMUNICATE_WEIGHT}\", + \"TAU2_PARTIAL_ENV_ASSERTION_WEIGHT\": \"${TAU2_PARTIAL_ENV_ASSERTION_WEIGHT}\", + \"TAU2_PARTIAL_DB_WEIGHT\": \"${TAU2_PARTIAL_DB_WEIGHT}\", + \"TAU2_TELECOM_COMMUNICATION_BOOST\": \"${TAU2_TELECOM_COMMUNICATION_BOOST}\", + \"TAU2_USE_COMPRESSED_PROMPTS\": \"${TAU2_USE_COMPRESSED_PROMPTS}\", + \"TAU2_USE_CURRICULUM\": \"${TAU2_USE_CURRICULUM}\", + \"TAU2_CURRICULUM_MIN_ATTEMPTS\": \"${TAU2_CURRICULUM_MIN_ATTEMPTS}\", + \"TAU2_CURRICULUM_SOLVED_WEIGHT\": \"${TAU2_CURRICULUM_SOLVED_WEIGHT}\", + \"TAU2_CURRICULUM_HARD_WEIGHT\": \"${TAU2_CURRICULUM_HARD_WEIGHT}\", + \"WANDB_API_KEY\": \"${WANDB_API_KEY:-}\", + \"HF_TOKEN\": \"${HF_TOKEN:-}\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node "${NUM_GPUS}" \ + --rollout-num-gpus "${NUM_GPUS}" \ + --colocate \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${CUSTOM_ARGS[@]} \ + ${WANDB_ARGS[@]+"${WANDB_ARGS[@]}"} diff --git a/examples/tau-bench/tau2/run_sft.sh b/examples/tau-bench/tau2/run_sft.sh new file mode 100644 index 000000000..0440bae0f --- /dev/null +++ b/examples/tau-bench/tau2/run_sft.sh @@ -0,0 +1,128 @@ +#!/bin/bash + +# Tau2 SFT training (tau2-bench). + +set -euo pipefail + +if [ "${TAU2_CLEANUP:-0}" = "1" ]; then + pkill -9 sglang || true + ray stop --force || true + pkill -9 ray || true +fi + +set -ex + +export PYTHONUNBUFFERED=1 + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/../../.." &>/dev/null && pwd)" +TAU_BENCH_OUT_DIR="${TAU_BENCH_OUT_DIR:-${SCRIPT_DIR}/../outputs}" +MEGATRON_LM_DIR="${MEGATRON_LM_DIR:-/root/Megatron-LM}" + +HF_DIR="${HF_DIR:-${TAU_BENCH_OUT_DIR}/models/Qwen3-4B-Instruct-2507}" +TORCH_DIST_DIR="${TORCH_DIST_DIR:-${TAU_BENCH_OUT_DIR}/models/Qwen3-4B-Instruct-2507_torch_dist}" +TAU2_SFT_DATA_DIR="${TAU2_SFT_DATA_DIR:-${TAU_BENCH_OUT_DIR}/tau2/data/sft1}" +SFT_DATA_JSONL="${SFT_DATA_JSONL:-${TAU2_SFT_DATA_DIR}/tau2_sft_merged_v3_rft.jsonl}" +if [ ! -f "${SFT_DATA_JSONL}" ] && [ -f "${TAU2_SFT_DATA_DIR}/seed_sft_v3.jsonl" ]; then + SFT_DATA_JSONL="${TAU2_SFT_DATA_DIR}/seed_sft_v3.jsonl" +fi +SAVE_DIR="${SAVE_DIR:-${TAU_BENCH_OUT_DIR}/tau2/checkpoints/Qwen3-4B-tau2-sft1}" + +NUM_GPUS="${NUM_GPUS:-4}" + +source "${SCRIPT_DIR}/../../../scripts/models/qwen3-4B-Instruct-2507.sh" + +CKPT_ARGS=( + --hf-checkpoint "${HF_DIR}" + --load "${TORCH_DIST_DIR}" + --save "${SAVE_DIR}" + --save-interval 50 +) + +SFT_ARGS=( + --prompt-data "${SFT_DATA_JSONL}" + --input-key prompt + --apply-chat-template + --loss-mask-type qwen3 + --loss-type sft_loss + --calculate-per-token-loss + --disable-compute-advantages-and-returns + --rollout-function-path slime.rollout.sft_rollout.generate_rollout + + --num-epoch 2 + --rollout-batch-size 16 + --n-samples-per-prompt 1 + --rollout-shuffle + --rollout-max-response-len 4096 + --global-batch-size 16 +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --use-dynamic-batch-size + --max-tokens-per-gpu 12288 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-5 + --lr-decay-style cosine + --lr-warmup-fraction 0.05 + --weight-decay 0.01 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash +) + +if [ -n "${WANDB_API_KEY:-}" ]; then + WANDB_ARGS=( + --use-wandb + --wandb-project "tau2-cookbook" + --wandb-group "sft-qwen3-4b" + --wandb-exp-name "tau2-sft-v1" + ) +else + echo "NOTE: WANDB_API_KEY not set, running without WandB logging" + WANDB_ARGS=() +fi + +export MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}" +RAY_DASHBOARD_HOST="${RAY_DASHBOARD_HOST:-127.0.0.1}" +ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" \ + --disable-usage-stats --dashboard-host="${RAY_DASHBOARD_HOST}" --dashboard-port=8265 + +RUNTIME_ENV_JSON="{ + \"working_dir\": \"${REPO_ROOT}\", + \"env_vars\": { + \"PYTHONPATH\": \"${MEGATRON_LM_DIR}:${SCRIPT_DIR}:${REPO_ROOT}\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"WANDB_API_KEY\": \"${WANDB_API_KEY:-}\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train_async.py \ + --debug-train-only \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node "${NUM_GPUS}" \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${SFT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${WANDB_ARGS[@]} diff --git a/examples/tau-bench/tau2/start_user_sim_server.sh b/examples/tau-bench/tau2/start_user_sim_server.sh new file mode 100755 index 000000000..bbfc5e7ab --- /dev/null +++ b/examples/tau-bench/tau2/start_user_sim_server.sh @@ -0,0 +1,26 @@ +#!/bin/bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +TAU_BENCH_OUT_DIR="${TAU_BENCH_OUT_DIR:-${SCRIPT_DIR}/../outputs}" + +MODEL_DIR="${MODEL_DIR:-${TAU_BENCH_OUT_DIR}/models/Qwen3-4B-Instruct-2507}" +HOST="${HOST:-127.0.0.1}" +PORT="${PORT:-30001}" +# Keep these GPUs separate from training CUDA_VISIBLE_DEVICES. +GPUS="${GPUS:-2,3}" +TP="${TP:-2}" +MEM_FRACTION="${MEM_FRACTION:-0.85}" + +if [ ! -d "${MODEL_DIR}" ]; then + echo "Missing model directory: ${MODEL_DIR}" + echo "Download first (example): huggingface-cli download Qwen/Qwen3-4B-Instruct-2507 --local-dir \"${MODEL_DIR}\"" + exit 1 +fi + +CUDA_VISIBLE_DEVICES="${GPUS}" python3 -m sglang.launch_server \ + --model-path "${MODEL_DIR}" \ + --host "${HOST}" \ + --port "${PORT}" \ + --tp "${TP}" \ + --mem-fraction-static "${MEM_FRACTION}" diff --git a/examples/tau-bench/tau2/tasks.py b/examples/tau-bench/tau2/tasks.py new file mode 100644 index 000000000..c617a6819 --- /dev/null +++ b/examples/tau-bench/tau2/tasks.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +"""Preprocess tau2-bench tasks into JSONL index files for slime.""" + +from __future__ import annotations + +import argparse +import json +import os +from typing import Any + + +DEFAULT_DOMAINS = ("retail", "airline", "telecom") +DEFAULT_SPLITS = ("train", "test", "base") + + +def _parse_csv(value: str) -> list[str]: + items = [x.strip() for x in value.split(",")] + return [x for x in items if x] + + +def main() -> None: + parser = argparse.ArgumentParser(description="Preprocess tau2-bench tasks to JSONL for slime") + parser.add_argument( + "--local_dir", + required=True, + help="Output directory for `{domain}_{split}_tasks.jsonl` files", + ) + parser.add_argument( + "--domains", + default=",".join(DEFAULT_DOMAINS), + help=f"Comma-separated list of domains (default: {','.join(DEFAULT_DOMAINS)})", + ) + parser.add_argument( + "--splits", + default=",".join(DEFAULT_SPLITS), + help=f"Comma-separated list of task splits (default: {','.join(DEFAULT_SPLITS)})", + ) + parser.add_argument( + "--limit", + type=int, + default=None, + help="Max tasks per (domain, split) for smoke testing", + ) + args = parser.parse_args() + + from tau2.registry import registry + + local_dir = args.local_dir + os.makedirs(local_dir, exist_ok=True) + + domains = _parse_csv(args.domains) + splits = _parse_csv(args.splits) + + for split in splits: + all_rows: list[dict[str, Any]] = [] + for domain in domains: + tasks_loader = registry.get_tasks_loader(domain) + tasks = tasks_loader(split) + if args.limit is not None: + tasks = tasks[: args.limit] + + output_path = os.path.join(local_dir, f"{domain}_{split}_tasks.jsonl") + with open(output_path, "w") as f: + for i, task in enumerate(tasks): + row: dict[str, Any] = { + "text": [{"role": "user", "content": "task"}], + "metadata": { + "domain": domain, + "split": split, + "task_id": task.id, + "task_index": i, + "task": task.model_dump(), + }, + } + f.write(json.dumps(row) + "\n") + all_rows.append(row) + + print(f"Saved {len(tasks)} tasks to {output_path}") + + merged_path = os.path.join(local_dir, f"tau2_{split}_all_tasks.jsonl") + with open(merged_path, "w") as f: + for row in all_rows: + f.write(json.dumps(row) + "\n") + print(f"Saved {len(all_rows)} tasks to {merged_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/tau-bench/training_cookbook.md b/examples/tau-bench/training_cookbook.md new file mode 100644 index 000000000..0d66e0190 --- /dev/null +++ b/examples/tau-bench/training_cookbook.md @@ -0,0 +1,394 @@ +# Training Multi-Turn Tool-Use Agents: SFT → RFT → GRPO + +We report a 4B parameter model achieving **57.1% Pass@4** on tau2-bench (test split): **4× better than the base model** and competitive with models 6-60× larger. The model is faster, cheaper to run, and demonstrates that progressive training (SFT → rejection sampling → GRPO) works for complex, multi-turn tool-use tasks. + +This cookbook shows you how we did it. Everything is public and open source: [training data](https://huggingface.co/datasets/Jarrodbarnes/tau2-sft-seed-v3), [checkpoints](https://huggingface.co/Jarrodbarnes/Qwen3-4B-tau2-grpo-v1), and [code](https://github.com/THUDM/slime/tree/main/examples/tau-bench). + +![Tau2 pipeline overview](public/slime-pipeline-tau2.jpeg) + +*Figure 1: Three-stage training pipeline (SFT → rejection sampling/RFT → GRPO) for multi-turn tool-use agents.* + +## TLDR + +**Setup** (inside `slimerl/slime:latest` container): +```bash +export SLIME_ROOT="$(pwd)" TAU_BENCH_OUT_DIR="${SLIME_ROOT}/examples/tau-bench/outputs" +git clone https://github.com/sierra-research/tau2-bench.git "${TAU_BENCH_OUT_DIR}/_external/tau2-bench" +cd "${TAU_BENCH_OUT_DIR}/_external/tau2-bench" && git checkout 337326e62d8e0ca74c353b004a9c5d748e0ba914 && pip install -e . --no-deps && cd "${SLIME_ROOT}" +export TAU2_DATA_DIR="${TAU_BENCH_OUT_DIR}/_external/tau2-bench/data" +pip install gymnasium addict deepdiff fs langfuse plotly pydantic-argparse redis ruff scikit-learn seaborn tenacity watchdog "litellm==1.65.0" +cp examples/tau-bench/tau2/.env.template examples/tau-bench/tau2/.env # ADD OPENAI_API_KEY +set -a && source examples/tau-bench/tau2/.env && set +a +``` + +**Evaluate** (uses published checkpoint, ~2h on 2xH100): +```bash +# Terminal 1: Policy server (use `--tp 1` on single-GPU) +CUDA_VISIBLE_DEVICES=0,1 python3 -m sglang.launch_server \ + --model-path Jarrodbarnes/Qwen3-4B-tau2-grpo-v1 \ + --host 0.0.0.0 --port 30000 --tp 2 --mem-fraction-static 0.70 + +# Terminal 2: Run evaluation (uses GPT-4.1-mini as user simulator; requires OPENAI_API_KEY) +python3 examples/tau-bench/tau2/eval.py \ + --hf-checkpoint Jarrodbarnes/Qwen3-4B-tau2-grpo-v1 \ + --sglang-url http://127.0.0.1:30000/generate \ + --domains airline,retail,telecom --task-split test --num-samples 4 \ + --temperature 0.8 --top-p 1.0 --top-k 20 \ + --output "${TAU_BENCH_OUT_DIR}/tau2/eval/eval_pass4.json" +``` + +To train from scratch, see [Train from Scratch](#train-from-scratch-optional). + +--- + +## Contents + +- [TLDR](#tldr) +- [Why Tau2-Bench?](#why-tau2-bench) +- [Performance snapshot](#performance-snapshot) +- [Before You Start](#before-you-start) +- [Setup (Tau2)](#setup-tau2) +- [Resources](#resources) +- [Methodology (why this works)](#methodology-why-this-works) +- [Implementation Details](#implementation-details) +- [Quickstart: Reproduce Pass@4](#quickstart-reproduce-pass4) +- [Train from Scratch (Optional)](#train-from-scratch-optional) +- [Smoke tests (documented)](#smoke-tests-documented) +- [Troubleshooting](#troubleshooting) + +## Why Tau2-Bench? + +Tau2-bench ([paper](https://arxiv.org/pdf/2506.07982), [repo](https://github.com/sierra-research/tau2-bench)) tests multi-turn agents in realistic scenarios: airline bookings, retail purchases, and telecom troubleshooting. Unlike simpler benchmarks, it requires agents to maintain protocol correctness across dozens of turns while managing complex tool schemas. + +The telecom domain is particularly challenging. It uses **dual-control**, meaning diagnostic actions are user-only. The agent must *instruct* the user to perform diagnostics rather than executing them directly. This mirrors real customer support workflows and pushes difficulty into communication strategy. + +## Performance snapshot + +Complete performance comparison (test split; Pass@4 is the headline metric): + +| Stage | Overall | Airline | Retail | Telecom | +|------------------------------|---------|---------|--------|---------| +| Baseline (Qwen3-4B-Instruct) | 14.3% | 5.0% | 16.0% | 20.0% | +| SFT | 8.57% | 5.0% | 20.0% | 0.0% | +| SFT1 (RFT) | 27.0% | 20.0% | 50.0% | 7.5% | +| GRPO (Pass@1, greedy) | 32.9% | 15.0% | 76.0% | 4.0% | +| GRPO (Pass@4, temp=0.8, **reported**) | 57.1% | 50.0% | 76.0% | 44.0% | +| Delta (Pass@4 vs Baseline) | +42.8% | +45.0% | +60.0% | +24.0% | + +**What worked:** +- **Progressive training compounds**: Baseline → SFT+RFT (27%) → GRPO (32.9%) → Pass@4 (57.1%, reported). Each stage builds on the last. +- **Pass@K matters for RL**: Multi-sampling at inference (Pass@4) gains +24.2 percentage points over greedy decoding. RL models benefit more from exploration than prompted baselines. +- **Domain-specific gains**: Retail (76%) and airline (50%) saw massive improvements. Telecom (44%), constrained by dual-control complexity, still improved 2.2× over baseline. + +[WandB runs (public, SFT + GRPO v1 only: b7d80rfe, e3jgp9aj) →](https://wandb.ai/jbarnes850-near-protocol/tau2-cookbook) + +![Tau2 performance comparison](public/performance-chart.jpeg) + +*Figure 2: Qwen3-4B with progressive training (57.1% Pass@4, reported) achieves competitive performance against models 6-60× larger. Stacked bar shows contribution from SFT+RFT and GRPO stages.* + +**Local reproduction (Dec 28, 2025)** using the eval command below and full policies (no compressed prompt), with reported sampling settings (`top_p=1.0`): + +| Metric | Overall | Airline | Retail | Telecom | +|--------|---------|---------|--------|---------| +| Pass@1 | 36.0% | 20.0% | 50.0% | 30.0% | +| Pass@4 | 57.0% | 30.0% | 82.5% | 45.0% | + +Config: `Jarrodbarnes/Qwen3-4B-tau2-grpo-v1`, `tau2-bench` commit `337326e62d8e0ca74c353b004a9c5d748e0ba914`, `TAU2_USE_COMPRESSED_PROMPTS=0`, `TAU2_MAX_STEPS=100`, `TAU2_USER_MODEL=gpt-4.1-mini`, `TAU2_USER_TEMPERATURE=0.7`, `temperature=0.8`, `top_p=1.0`, `top_k=20`, `num_samples=4`. + +Reported Pass@4 settings: `TAU2_MAX_STEPS=100`, `TAU2_USER_TEMPERATURE=0.7`, `temperature=0.8`, `top_p=1.0`, `top_k=20`, `num_samples=4`. + +## Before You Start + +All scripts use `slimerl/slime:latest` and assume you're in the repo root. If you're not already inside the container, start it first: + +```bash +docker run --gpus all --rm -it -v "$(pwd)":/workspace/slime -w /workspace/slime slimerl/slime:latest +``` + +Everything outputs to `TAU_BENCH_OUT_DIR` (defaults to `examples/tau-bench/outputs`): + +```bash +export SLIME_ROOT="$(pwd)" +export TAU_BENCH_OUT_DIR="${SLIME_ROOT}/examples/tau-bench/outputs" +``` + +Training is stochastic. You'll get comparable results, not identical ones. Checkpoints and datasets live on Hugging Face; local runs write to gitignored `outputs/`. + +## Setup (Tau2) + +This assumes you are running inside `slimerl/slime:latest` and are in the slime repo root. + +### 0) Install tau2-bench (official) + +```bash +mkdir -p "${TAU_BENCH_OUT_DIR}/_external" +git clone https://github.com/sierra-research/tau2-bench.git "${TAU_BENCH_OUT_DIR}/_external/tau2-bench" +cd "${TAU_BENCH_OUT_DIR}/_external/tau2-bench" +git checkout 337326e62d8e0ca74c353b004a9c5d748e0ba914 +# Avoid dependency conflicts with sglang inside slimerl/slime:latest. +pip install -e . --no-deps +export TAU2_DATA_DIR="${TAU_BENCH_OUT_DIR}/_external/tau2-bench/data" +cd "${SLIME_ROOT}" +``` + +### 1) Python deps (minimal) + +Install the tau2-bench runtime deps explicitly (pin `litellm` to avoid upgrading `openai`/`grpcio`): + +```bash +pip install gymnasium addict deepdiff fs langfuse plotly pydantic-argparse redis ruff \ + scikit-learn seaborn tenacity watchdog "litellm==1.65.0" +``` + +Do not run `pip install -e .` without `--no-deps`; it will downgrade `grpcio` and upgrade `openai`, breaking `sglang` in the base image. + +Optional (recommended for experiment logging): `wandb`, `weave`. + +### 2) API keys and environment + +Create `examples/tau-bench/tau2/.env` from the template and source it: + +```bash +cp examples/tau-bench/tau2/.env.template examples/tau-bench/tau2/.env +set -a && source examples/tau-bench/tau2/.env && set +a +``` + +## Resources + +**Models** (public on Hugging Face): +- [Qwen3-4B-Instruct-2507](https://huggingface.co/Qwen/Qwen3-4B-Instruct-2507) - Base model +- [Qwen3-4B-tau2-sft1](https://huggingface.co/Jarrodbarnes/Qwen3-4B-tau2-sft1) - After SFT+RFT +- [Qwen3-4B-tau2-grpo-v1](https://huggingface.co/Jarrodbarnes/Qwen3-4B-tau2-grpo-v1) - Final GRPO checkpoint + +**Dataset** (public): [tau2-sft-seed-v3](https://huggingface.co/datasets/Jarrodbarnes/tau2-sft-seed-v3) - Filtered trajectories from rejection sampling + +**Training logs**: [WandB project](https://wandb.ai/jbarnes850-near-protocol/tau2-cookbook) - Public SFT + GRPO v1 runs only (`b7d80rfe`, `e3jgp9aj`) + +## Methodology (why this works) + +### The Problem: Credit Assignment in Multi-Turn Tool-Use + +Consider a telecom troubleshooting task where the agent must guide a user through 20+ turns of diagnostics before solving their MMS issue. At step 15, the agent asks the user to grant app permissions (a critical action that enables MMS). But the final reward (success/failure) only arrives at step 20. + +**How does the model know step 15 mattered?** + +This is the credit assignment problem. Standard outcome-based rewards (0 for failure, 1 for success) provide essentially zero gradient across the 19 intermediate steps. The model sees: +- Steps 1-19: no signal +- Step 20: success or failure + +For a prompted model, this is catastrophic. Without seeing thousands of examples of this exact interaction pattern, it cannot learn which intermediate actions lead to success. Early SFT attempts on tau2-bench showed this clearly: models achieved 8.57% on the test split (*worse* than the unprompted baseline). + +### Stage 1: SFT Warm-Start (Teaching Protocol) + +Supervised fine-tuning addresses the protocol learning problem. Before a model can *optimize* tool-use, it must first understand: + +1. **Turn structure**: One action per turn, wait for environment response +2. **Tool schemas**: 30+ tools across domains with complex argument structures +3. **Dual-control coordination**: In telecom, the agent coaches users through diagnostics rather than executing them + +Example trajectory from SFT data: +``` +Agent: get_customer_by_phone(phone_number="555-123-2002") +Env: {customer_id: "C1001", ...} +Agent: get_details_by_id(id="L1002") +Env: {line_id: "L1002", status: "Active", ...} +Agent: "Please toggle airplane mode ON, wait 10 seconds, then OFF..." +User: "Done. Still no data." +Agent: "Now open Settings > Apps > Messaging and check permissions..." +``` + +Without SFT, RL training thrashes. The model doesn't know the rules of the game. With SFT, we achieve 27% on test (after rejection filtering), establishing a foundation for exploration. + +### Stage 2: Rejection Sampling (RFT) - Concentrating Success Patterns + +After SFT, the model can complete tasks but inconsistently. A critical insight from recent research ([Statistical Rejection Sampling Improves Preference Optimization](https://openreview.net/forum?id=xbjSwwrQOe)): sampling multiple on-policy rollouts and keeping only successes concentrates the training distribution on viable strategies. + +Our rejection sampling process: +1. Sample 4-8 attempts per training task at temperature 0.8 +2. Keep trajectories where `reward >= 1.0` (true successes only) +3. For tasks with no successes, keep the highest `partial_score` trajectory if ≥ 0.6 +4. Retrain SFT on this filtered dataset + +This serves two purposes: +- **Exploration**: High temperature discovers diverse solution paths +- **Quality gates**: Hard filters prevent training on broken strategies + +The published [tau2-sft-seed-v3](https://huggingface.co/datasets/Jarrodbarnes/tau2-sft-seed-v3) dataset is the result of this filtering with a 25% success rate during RFT. + +⚠️ **Limitation**: Rejection sampling requires an SFT policy that can occasionally succeed. On new domains where SFT achieves <5%, you may need teacher demonstrations or curriculum learning first. + +### Stage 3: GRPO + Turn-Level Reward Shaping + +GRPO solves the credit assignment problem through two mechanisms: + +**1. Group-based advantage estimation** + +For each prompt, GRPO samples K trajectories (K=4 in our setup), scores them, and trains the model to increase probability of high-reward actions relative to the group average: + +``` +advantage_k = (reward_k - mean(rewards)) / std(rewards) +loss = -mean(log_prob * advantage) +``` + +This is *relative* optimization. The model learns "this action was better than my other attempts" rather than "this action is objectively good." For multi-turn tasks with many valid paths, this is exactly what we want. + +**2. Dense reward shaping from partial scores** + +Tau2-bench provides `reward_info` with turn-level evaluation: +- `action_checks`: Did the agent call expected tools? +- `communicate_checks`: Did the agent mention required information? +- `env_assertions`: Are environment states correct? + +We extract a `partial_score` from these signals: + +```python +partial_score = 0.5 * (correct_actions / total_expected) + + 0.35 * (env_assertions_met / total_assertions) + + 0.15 * (communication_checks / total_checks) +``` + +The final shaped reward becomes: +```python +shaped_reward = task_reward + alpha * partial_score +``` + +Default `alpha` is 0.25 (domain-adaptive scaling applies a 1.6× multiplier in telecom). +When `TAU2_TELECOM_COMMUNICATION_BOOST=1`, telecom uses 0.35/0.35/0.30 (action/communication/env) weights. + +This provides gradient at every turn, not just at task completion. Research on [turn-level credit assignment](https://arxiv.org/html/2505.11821v1) shows this is critical for multi-turn learning. Trajectory-level rewards fail to distinguish which *turns* contributed to success. + +⚠️ **Watch for reward hacking**: We observed that adding partial credit for "taking more turns" caused the model to repeat tool calls indefinitely. Dense rewards must align with true task objectives. + +### Why This Pipeline Works: Empirical Evidence + +Recent research validates the SFT→RFT→GRPO progression: + +1. **SFT establishes foundation**: Models learn reasoning patterns and task structure ([On the Generalization of SFT](https://arxiv.org/pdf/2508.05629)) +2. **RFT enables exploration**: Statistical rejection sampling improves policy estimation over pure SFT ([Statistical Rejection Sampling](https://openreview.net/forum?id=xbjSwwrQOe)) +3. **GRPO optimizes efficiency**: Group comparisons stabilize learning without critic overhead ([Two-Stage SFT+GRPO Pipeline](https://www.emergentmind.com/topics/two-stage-sft-grpo-training-pipeline)) + +Hybrid RL branching (SFT→RFT→GRPO) reaches maximum SFT performance with only 55% of the compute while pushing the Pareto frontier on both accuracy and efficiency. + +For tau2-bench specifically, the progression shows: +- Baseline: 14.3% (no task understanding) +- SFT+RFT: 27.0% (protocol learned, inconsistent execution) +- GRPO Pass@1: 32.9% (optimized for single best path) +- GRPO Pass@4: 57.1% (reported; robust across multiple sampling attempts) + +The 24.2 percentage point gain from Pass@1 to Pass@4 demonstrates that RL-trained models benefit significantly from inference-time exploration. They've learned multiple viable strategies rather than overfitting to a single path. + +## Implementation Details + +**Dual-control (telecom)**: Diagnostic actions are user-only. The agent instructs rather than executes: +``` +Agent: "Please toggle airplane mode ON, wait 10 seconds, then OFF. Tell me what changes." +User: "Done. Still no data." +``` + +**Function calling**: Qwen3 uses native format `{...}`. Include `` in stop sequences. + +**Chat templates**: Training on multi-turn conversations requires `--apply-chat-template` flag. + +**User simulator**: Training uses a local instruct model on port 30001 (`TAU2_USER_API_BASE=http://127.0.0.1:30001/v1`). Evaluation defaults to GPT-4.1-mini (OpenAI); set `TAU2_USER_MODEL=gpt-4.1-2025-04-14` if you want the larger model. + +## Quickstart: Reproduce Pass@4 + +Download the [GRPO checkpoint](https://huggingface.co/Jarrodbarnes/Qwen3-4B-tau2-grpo-v1), start the policy server, and run evaluation: + +**1. Policy model (port 30000)** (use `--tp 1` on single-GPU): +```bash +CUDA_VISIBLE_DEVICES=0,1 python3 -m sglang.launch_server \ + --model-path Jarrodbarnes/Qwen3-4B-tau2-grpo-v1 \ + --host 0.0.0.0 --port 30000 --tp 2 --mem-fraction-static 0.70 +``` + +**2. Run evaluation** (uses GPT-4.1-mini as user simulator; requires `OPENAI_API_KEY`): +```bash +python3 examples/tau-bench/tau2/eval.py \ + --hf-checkpoint Jarrodbarnes/Qwen3-4B-tau2-grpo-v1 \ + --sglang-url http://127.0.0.1:30000/generate \ + --domains airline,retail,telecom --task-split test --num-samples 4 \ + --temperature 0.8 --top-p 1.0 --top-k 20 \ + --output "${TAU_BENCH_OUT_DIR}/tau2/eval/eval_pass4.json" +``` + +This takes ~2 hours on 2×H100. Results: Pass@1 and Pass@4 metrics across all domains. + +The script outputs both Pass@1 and Pass@4. Results are stochastic; see the local reproduction table above for a concrete run and config. + +Note: `eval.py` reports pass@k (any success among k attempts). The official tau2-bench leaderboard uses pass^k; use tau2-bench metrics if you need leaderboard-comparable numbers. + +To run without external API keys, start the local user simulator and set: +```bash +export TAU2_USER_API_BASE=http://127.0.0.1:30001/v1 +export TAU2_USER_MODEL=openai/Qwen/Qwen3-4B-Instruct-2507 +``` + +## Train from Scratch (Optional) + +We publish the [SFT checkpoint](https://huggingface.co/Jarrodbarnes/Qwen3-4B-tau2-sft1) and [GRPO checkpoint](https://huggingface.co/Jarrodbarnes/Qwen3-4B-tau2-grpo-v1) (public; login only needed for uploads). To train from scratch: + +### Prerequisites + +**1. Download base model and SFT training data**: +```bash +huggingface-cli download Qwen/Qwen3-4B-Instruct-2507 --local-dir "${TAU_BENCH_OUT_DIR}/models/Qwen3-4B-Instruct-2507" +mkdir -p "${TAU_BENCH_OUT_DIR}/tau2/data/sft1" +huggingface-cli download Jarrodbarnes/tau2-sft-seed-v3 --local-dir "${TAU_BENCH_OUT_DIR}/tau2/data/sft1" --repo-type dataset +export TAU2_SFT_DATA_DIR="${TAU_BENCH_OUT_DIR}/tau2/data/sft1" +export SFT_DATA_JSONL="${TAU2_SFT_DATA_DIR}/tau2_sft_merged_v3_rft.jsonl" +``` + +**2. Convert to Megatron format**: +```bash +source scripts/models/qwen3-4B-Instruct-2507.sh +python3 tools/convert_hf_to_torch_dist.py \ + --hf-checkpoint "${TAU_BENCH_OUT_DIR}/models/Qwen3-4B-Instruct-2507" \ + --save "${TAU_BENCH_OUT_DIR}/models/Qwen3-4B-Instruct-2507_torch_dist" \ + ${MODEL_ARGS[@]} +``` + +### Stage 1: SFT + +Run supervised fine-tuning on the filtered RFT trajectories: +```bash +bash examples/tau-bench/tau2/run_sft.sh +``` +For a smaller debug run, set `SFT_DATA_JSONL="${TAU2_SFT_DATA_DIR}/seed_sft_v3.jsonl"`. + +### Stage 2: GRPO + +**Generate task indices**: +```bash +python3 examples/tau-bench/tau2/tasks.py \ + --local_dir "${TAU_BENCH_OUT_DIR}/tau2/tasks" \ + --domains airline,retail,telecom --splits train +``` + +**Start user simulator** (separate terminal, keep GPUs distinct from training): +```bash +GPUS=2,3 bash examples/tau-bench/tau2/start_user_sim_server.sh +``` + +**Run GRPO training** (example for 4 GPUs total): +```bash +CUDA_VISIBLE_DEVICES=0,1 NUM_GPUS=2 bash examples/tau-bench/tau2/run_grpo.sh +``` + +Adjust `GPUS`, `CUDA_VISIBLE_DEVICES`, and `NUM_GPUS` for your machine to avoid overlap. + +Training takes ~2 hours on 8×H100s. [Reference logs (SFT + GRPO v1 only)](https://wandb.ai/jbarnes850-near-protocol/tau2-cookbook). + +## Smoke tests (documented) + +- Import: `python3 -c "from tau2.gym.gym_agent import AgentGymEnv; print('ok')"` +- Task index: run `examples/tau-bench/tau2/tasks.py` with `--limit 1` on `train,test` +- Prompt formatting: ensure `--apply-chat-template` is passed for multi-turn training +- Tiny eval sanity: run `eval.py` with `--max-tasks-per-domain 1` + +## Troubleshooting + +- **SGLang abort/OOM**: reduce `--mem-fraction-static`, reduce `--max-tokens-per-gpu`, reduce `--rollout-batch-size`. +- **Ray working directory issues**: the provided scripts submit Ray jobs with `working_dir` set to the slime repo root and `PYTHONPATH` set explicitly; avoid running from random directories. +- **Ray dashboard exposure**: `run_grpo.sh` binds the dashboard to `127.0.0.1` by default. If you override `RAY_DASHBOARD_HOST`, avoid exposing it on shared networks. +- **Telecom is slow / low Pass@K**: dual-control pushes difficulty into communication. Inspect failures for (a) tool ownership violations, (b) premature `done`, (c) missing follow-up questions after user diagnostics. diff --git a/tests/test_tau2_actions.py b/tests/test_tau2_actions.py new file mode 100644 index 000000000..4a4142a2c --- /dev/null +++ b/tests/test_tau2_actions.py @@ -0,0 +1,40 @@ +import pathlib +import sys + +import pytest + +ROOT = pathlib.Path(__file__).resolve().parents[1] +TAU2_DIR = ROOT / "examples" / "tau-bench" / "tau2" +sys.path.insert(0, str(TAU2_DIR)) + +from actions import parse_action + + +def test_parse_action_rejects_multiple_tool_calls(): + text = ( + '{"name": "respond", "arguments": {"content": "hi"}}' + "\n" + '{"name": "done", "arguments": {}}' + ) + try: + parse_action(text) + except ValueError as exc: + assert "Multiple " in str(exc) + else: + raise AssertionError("Expected ValueError for multiple tool calls") + + +def test_parse_action_handles_nested_arguments(): + text = ( + '{"name": "book", "arguments": {"passengers": [{"name": "Ada", "age": 30}]}}' + ) + action = parse_action(text) + assert action.arguments["passengers"][0]["name"] == "Ada" + + +def test_parse_action_rejects_extra_text(): + text = ( + 'oops {"name": "respond", "arguments": {"content": "hi"}}' + ) + with pytest.raises(ValueError, match="Unexpected text"): + parse_action(text) diff --git a/tests/test_tau2_eval.py b/tests/test_tau2_eval.py new file mode 100644 index 000000000..26288ebcc --- /dev/null +++ b/tests/test_tau2_eval.py @@ -0,0 +1,58 @@ +import pathlib +import sys +import types + +import pytest + +ROOT = pathlib.Path(__file__).resolve().parents[1] +TAU2_DIR = ROOT / "examples" / "tau-bench" / "tau2" +sys.path.insert(0, str(TAU2_DIR)) + +httpx = types.ModuleType("httpx") + + +class _DummyTimeout: + def __init__(self, *args, **kwargs) -> None: + pass + + +class _DummyAsyncClient: + def __init__(self, *args, **kwargs) -> None: + pass + + +httpx.Timeout = _DummyTimeout +httpx.AsyncClient = _DummyAsyncClient +sys.modules["httpx"] = httpx + +transformers = types.ModuleType("transformers") + + +class _DummyAutoTokenizer: + @classmethod + def from_pretrained(cls, *args, **kwargs): + return None + + +transformers.AutoTokenizer = _DummyAutoTokenizer +sys.modules["transformers"] = transformers + +import eval as tau2_eval + + +def test_eval_num_samples_guard(): + parser = tau2_eval._build_arg_parser() + args = parser.parse_args( + [ + "--hf-checkpoint", + "dummy", + "--sglang-url", + "http://localhost:30000/generate", + "--output", + "eval.json", + "--num-samples", + "0", + ] + ) + with pytest.raises(SystemExit): + tau2_eval._validate_args(args, parser) diff --git a/tests/test_tau2_reward.py b/tests/test_tau2_reward.py new file mode 100644 index 000000000..730aef32f --- /dev/null +++ b/tests/test_tau2_reward.py @@ -0,0 +1,23 @@ +import pathlib +import sys + +import pytest + +ROOT = pathlib.Path(__file__).resolve().parents[1] +TAU2_DIR = ROOT / "examples" / "tau-bench" / "tau2" +sys.path.insert(0, str(ROOT)) +sys.path.insert(0, str(TAU2_DIR)) + +reward = pytest.importorskip("reward") + + +def test_normalize_rewards_masks_removed_samples(): + rewards = [1.0, 2.0, 3.0, 4.0] + valid_mask = [1.0, 1.0, 0.0, 1.0] + normalized = reward._normalize_rewards( + rewards, + valid_mask=valid_mask, + n_samples_per_prompt=2, + apply_std=False, + ) + assert pytest.approx(normalized) == [-0.5, 0.5, 0.0, 0.0]