From d1b0c6eccd8332f5144b8b9227fe9f9338c4e8af Mon Sep 17 00:00:00 2001 From: Fengzdadi <453788063@qq.com> Date: Fri, 19 Dec 2025 18:35:17 -0500 Subject: [PATCH] tau-bench: add offline stub user + tool parser compat + logging fixes --- examples/tau-bench/episode_logger.py | 61 +++++++ examples/tau-bench/generate_with_tau.py | 36 ++++ examples/tau-bench/openai_tool_adapter.py | 78 ++++++++- examples/tau-bench/sglang_tool_parser.py | 82 ++++++--- examples/tau-bench/trainable_agents.py | 197 +++++++++++++++++----- 5 files changed, 378 insertions(+), 76 deletions(-) create mode 100644 examples/tau-bench/episode_logger.py diff --git a/examples/tau-bench/episode_logger.py b/examples/tau-bench/episode_logger.py new file mode 100644 index 000000000..c5f4e7c52 --- /dev/null +++ b/examples/tau-bench/episode_logger.py @@ -0,0 +1,61 @@ +# examples/tau-bench/episode_logger.py +from __future__ import annotations + +import hashlib +import json +import os +import time +from dataclasses import dataclass +from typing import Any + + +def _truncate(s: str | None, max_chars: int = 8000) -> str | None: + if s is None: + return None + if len(s) <= max_chars: + return s + return s[:max_chars] + f"\n...[truncated {len(s)-max_chars} chars]" + + +def _sha256_text(s: str) -> str: + return hashlib.sha256(s.encode("utf-8", errors="ignore")).hexdigest() + + +@dataclass +class EpisodeLogger: + log_dir: str + run_meta: dict[str, Any] + + def __post_init__(self): + os.makedirs(self.log_dir, exist_ok=True) + self._jsonl_path = os.path.join(self.log_dir, "episode.jsonl") + self._summary_path = os.path.join(self.log_dir, "summary.json") + self._run_meta_path = os.path.join(self.log_dir, "run_meta.json") + + with open(self._run_meta_path, "w", encoding="utf-8") as f: + json.dump(self.run_meta, f, ensure_ascii=False, indent=2) + + def log_step(self, rec: dict[str, Any]): + # Controlling field size: Truncate long text and append a hash. 
+ for k in [ + "assistant_raw", + "assistant", + "user_text", + "observation", + "tool_result", + "env_state", + "normal_text", + "tool_parse_error", + "error", + ]: + if k in rec and isinstance(rec[k], str): + rec[k + "_hash"] = _sha256_text(rec[k]) + rec[k] = _truncate(rec[k]) + + rec["ts"] = time.time() + with open(self._jsonl_path, "a", encoding="utf-8") as f: + f.write(json.dumps(rec, ensure_ascii=False) + "\n") + + def finalize(self, summary: dict[str, Any]): + with open(self._summary_path, "w", encoding="utf-8") as f: + json.dump(summary, f, ensure_ascii=False, indent=2) diff --git a/examples/tau-bench/generate_with_tau.py b/examples/tau-bench/generate_with_tau.py index ce80c8d99..c37d2b667 100644 --- a/examples/tau-bench/generate_with_tau.py +++ b/examples/tau-bench/generate_with_tau.py @@ -8,8 +8,11 @@ import logging import os +import time +import uuid from typing import Any +from episode_logger import EpisodeLogger from tau_bench.envs import get_env from tau_bench.types import RunConfig from trainable_agents import InteractionResult, Status, agent_factory @@ -29,6 +32,7 @@ "model_provider": "auto_router", # Unused, required "model": "qwen3-4b", # Unused, required "user_model_provider": "gemini", + # "user_model_provider": "stub", } # Replace with your actual API key for user sim GEMINI_API_KEY = "NONE" @@ -97,6 +101,10 @@ def res_to_sample(res: InteractionResult, task_index: int) -> Sample: return sample +def _default_run_root() -> str: + return os.environ.get("TAU_RUN_DIR", os.path.join(os.getcwd(), "runs", "tau1")) + + async def generate(args: dict[str, Any], sample: Sample, sampling_params: dict) -> Sample: """ Generate a complete agent-environment interaction trajectory for tau-bench. 
@@ -121,6 +129,25 @@ async def generate(args: dict[str, Any], sample: Sample, sampling_params: dict) # Extract task index from sample prompt task_index = int(sample.prompt) + + run_root = _default_run_root() + run_id = time.strftime("%Y%m%d_%H%M%S") + f"_{os.getpid()}_{uuid.uuid4().hex[:8]}" + episode_dir = os.path.join(run_root, run_id, f"task_{task_index:06d}") + + run_meta = { + "run_id": run_id, + "task_index": task_index, + "tau_config": TAU_CONFIGS, + "sampling_params": { + k: sampling_params.get(k) + for k in ["temperature", "top_p", "top_k", "max_new_tokens"] + if k in sampling_params + }, + "pid": os.getpid(), + } + os.makedirs(episode_dir, exist_ok=True) + ep_logger = EpisodeLogger(log_dir=episode_dir, run_meta=run_meta) + logger.info(f"Starting agent-environment interaction for task {task_index}") # Initialize tau-bench environment @@ -140,6 +167,7 @@ async def generate(args: dict[str, Any], sample: Sample, sampling_params: dict) config=tau_config, rollout_args=args, sampling_params=sampling_params, + episode_logger=ep_logger, ) # Execute agent-environment interaction @@ -149,5 +177,13 @@ async def generate(args: dict[str, Any], sample: Sample, sampling_params: dict) # Convert to slime Sample format result_sample = res_to_sample(interaction_result, task_index) + ep_logger.finalize( + { + "status": str(interaction_result.status), + "reward": interaction_result.reward, + "response_length": getattr(interaction_result, "response_length", None), + } + ) + logger.info(f"Finished agent-environment interaction for task {task_index}") return result_sample diff --git a/examples/tau-bench/openai_tool_adapter.py b/examples/tau-bench/openai_tool_adapter.py index b580f28b2..79ab3b0db 100644 --- a/examples/tau-bench/openai_tool_adapter.py +++ b/examples/tau-bench/openai_tool_adapter.py @@ -11,6 +11,21 @@ logger = logging.getLogger(__name__) +def _parse_tools_compat(parse_tools_fn, response: str, tools_info, parser_type: str | None): + """ + Compatible wrapper for 
parse_tools() across versions. + + Some versions: parse_tools(response, tools_info) + Some versions: parse_tools(response, tools_info, parser_type) + """ + try: + # try 3-arg signature first + return parse_tools_fn(response, tools_info, parser_type) + except TypeError: + # fallback to 2-arg signature + return parse_tools_fn(response, tools_info) + + @dataclass class OpenAIToolCall: """OpenAI format tool call structure""" @@ -62,20 +77,67 @@ def parse_response_to_openai_format(self, response: str) -> dict[str, Any]: Exception: Thrown when parsing fails """ try: - # Use existing parser to parse tool calls - parsed = parse_tools(response, self.tools_info, self.parser_type) - - # Extract parsing results - normal_text = parsed["normal_text"] - calls = parsed["calls"] + parsed = _parse_tools_compat(parse_tools, response, self.tools_info, self.parser_type) + + # --- normalize parsed output --- + normal_text = None + calls = [] + + # Case A: dict output + if isinstance(parsed, dict): + # common keys + if "normal_text" in parsed: + normal_text = parsed.get("normal_text") + elif "text" in parsed: + normal_text = parsed.get("text") + elif "content" in parsed: + normal_text = parsed.get("content") + elif "assistant" in parsed: + normal_text = parsed.get("assistant") + else: + # fallback: if no known key, just use raw response + normal_text = response + + # tool calls keys variants + if "calls" in parsed and isinstance(parsed.get("calls"), list): + calls = parsed.get("calls", []) + elif "tool_calls" in parsed and isinstance(parsed.get("tool_calls"), list): + calls = parsed.get("tool_calls", []) + elif "function_calls" in parsed and isinstance(parsed.get("function_calls"), list): + calls = parsed.get("function_calls", []) + else: + calls = [] + + # Case B: tuple/list output (e.g., (normal_text, calls)) + elif isinstance(parsed, (list, tuple)): + if len(parsed) >= 1: + normal_text = parsed[0] + if len(parsed) >= 2 and isinstance(parsed[1], list): + calls = parsed[1] + if 
import logging
from typing import Any

logger = logging.getLogger(__name__)

# sglang is an OPTIONAL dependency: importing sglang.srt can pull in heavy
# deps (e.g. triton) that are unavailable on dev machines (macOS, CPU-only,
# newer Pythons).  When it is missing we degrade to "no tool calls" parsing
# instead of failing at import time.
_SGLANG_AVAILABLE = False
_FunctionCallParser = None
_Function = None
_Tool = None

try:
    # These imports may trigger the triton dependency in some sglang versions.
    from sglang.srt.function_call.function_call_parser import FunctionCallParser as _FunctionCallParser
    from sglang.srt.managers.io_struct import Function as _Function
    from sglang.srt.managers.io_struct import Tool as _Tool

    _SGLANG_AVAILABLE = True
except Exception as e:  # pragma: no cover - depends on local environment
    logger.warning(f"sglang tool parser unavailable (optional). Falling back to no-tool parsing. Error: {e}")


def parse_tools(response: str, tools: list[dict[str, Any]], parser: str = "qwen25") -> dict[str, Any]:
    """Parse tool calls from raw model output, running the sglang parser locally.

    The signature (response first, then tools, then parser name) matches what
    openai_tool_adapter passes via ``_parse_tools_compat``, and the return
    shape matches what that adapter's dict normalization expects.

    Args:
        response: raw model output text.
        tools: OpenAI-style tool schemas ({"type": ..., "function": {...}}).
        parser: sglang tool-call parser name (default "qwen25").

    Returns:
        {"normal_text": str, "calls": list[dict]} — the plain-text part of the
        response plus extracted tool calls as plain dicts.  When sglang is
        unavailable or parsing fails, the whole response is returned as normal
        text with no calls.
    """
    if not _SGLANG_AVAILABLE:
        # Fallback: treat everything as normal text; no tool calls.
        return {"normal_text": response, "calls": []}

    try:
        tools_list = [
            _Tool(
                function=_Function(
                    name=tool["function"]["name"],
                    description=tool["function"]["description"],
                    parameters=tool["function"]["parameters"],
                ),
                type=tool["type"],
            )
            for tool in tools
        ]
        call_parser = _FunctionCallParser(tools=tools_list, tool_call_parser=parser)
        normal_text, calls = call_parser.parse_non_stream(response)
        # Convert pydantic call objects to plain dicts for downstream code.
        return {"normal_text": normal_text, "calls": [call.model_dump() for call in calls]}
    except Exception as e:
        logger.warning(f"sglang tool parsing failed; treating response as plain text. Error: {e!r}")
        return {"normal_text": response, "calls": []}
def call_to_action_sglang(calls: list[Any], text_response: str) -> Action:
    """Convert parsed tool calls into a tau-bench Action.

    Only the first tool call is honored.  Anything malformed (unparseable or
    non-dict parameters, missing tool name) degrades to a plain RESPOND
    action carrying the raw text response.
    """
    fallback = Action(name=RESPOND_ACTION_NAME, kwargs={"content": text_response})
    if not calls:
        return fallback
    if len(calls) > 1:
        logger.debug("Multiple tool calls identified, only taking first.")

    first = calls[0]
    name = first.get("name")

    # Parameters may arrive under "parameters" (dict or stringified JSON),
    # under "arguments", or be absent entirely -- normalize defensively.
    raw = first.get("parameters", None)
    if raw is None:
        raw = first.get("arguments", None)

    try:
        if isinstance(raw, dict):
            kwargs = raw
        elif isinstance(raw, str):
            stripped = raw.strip()
            kwargs = json.loads(stripped) if stripped else {}
        else:
            kwargs = {}
    except Exception as e:
        logger.warning(f"Failed to parse tool params: {e}; raw_params={raw!r}")
        return fallback

    if not isinstance(kwargs, dict):
        logger.warning(f"Tool params is not a dict: {kwargs!r}")
        return fallback

    if not name:
        logger.warning(f"Tool call missing name: {first!r}")
        return fallback

    return Action(name=name, kwargs=kwargs)
"episode_logger", None): + self.episode_logger.log_step( + { + "step_id": step_id, + "phase": "llm_output", + "finish_type": finish_type, + "assistant_raw": output.get("text", ""), + "messages_len": len(messages), + } + ) res.status = Status.ABORTED return self._build_final_result( res, total_reward, info, messages, loss_masks, prompt_token_ids, response_token_ids ) - response = output["text"] + raw_response = output.get("text", "") + response = raw_response # Remove end of conversation token if present if response.endswith("<|im_end|>"): response = response[:-10] - # Parse tool calls using OpenAI adapter - logger.debug(f"Using OpenAI adapter to parse response: {response[:100]}...") + if getattr(self, "episode_logger", None): + self.episode_logger.log_step( + { + "step_id": step_id, + "phase": "llm_output", + "finish_type": finish_type, + "assistant_raw": raw_response, + "assistant": response, + "text_input_chars": len(text_input) if isinstance(text_input, str) else None, + "messages_len": len(messages), + } + ) + try: openai_result = self._parse_tool(response) - logger.debug(f"OpenAI adapter result: success={openai_result['success']}") - - if not openai_result["success"]: - logger.warning(f"OpenAI adapter failed: {openai_result['error']}") - logger.warning( - f"rollout response: {response} can not be parsed into " f"tool calls {openai_result['error']}" - ) + parse_ok = bool(openai_result.get("success", False)) + + if not parse_ok: + if getattr(self, "episode_logger", None): + self.episode_logger.log_step( + { + "step_id": step_id, + "phase": "tool_parse", + "tool_parse_ok": False, + "tool_parse_error": openai_result.get("error"), + "assistant": response, + } + ) res.status = Status.ABORTED return self._build_final_result( res, total_reward, info, messages, loss_masks, prompt_token_ids, response_token_ids ) - # Extract parsed results - parsed = openai_result["parsed_result"] - logger.debug( - f"Successfully parsed - normal_text: '{parsed['normal_text']}', " 
f"calls: {parsed['calls']}" - ) + parsed = openai_result.get("parsed_result") or {} + if getattr(self, "episode_logger", None): + self.episode_logger.log_step( + { + "step_id": step_id, + "phase": "tool_parse", + "tool_parse_ok": True, + "normal_text": parsed.get("normal_text"), + "tool_calls": parsed.get("calls"), + } + ) + + try: + logger.debug( + f"Successfully parsed - normal_text: '{parsed.get('normal_text')}', calls: {parsed.get('calls')}" + ) + except Exception: + pass except Exception as e: logger.warning(f"Exception in OpenAI adapter: {e}") - logger.warning(f"rollout response: {response} can not be parsed into " f"tool calls {e}") + logger.warning(f"rollout response: {response} can not be parsed into tool calls") + if getattr(self, "episode_logger", None): + self.episode_logger.log_step( + { + "step_id": step_id, + "phase": "tool_parse", + "tool_parse_ok": False, + "tool_parse_error": repr(e), + "assistant_raw": response, + } + ) res.status = Status.ABORTED return self._build_final_result( res, total_reward, info, messages, loss_masks, prompt_token_ids, response_token_ids @@ -280,33 +362,53 @@ async def asolve( # Execute action in environment agent_content, calls = parsed["normal_text"], parsed["calls"] - logger.debug(f"Creating action from - content: '{agent_content}', " f"calls: {calls}") action = call_to_action_sglang(calls, agent_content) - logger.debug(f"Created action: {action}") try: env_response = await self._execute_tool(env, action) + + if getattr(self, "episode_logger", None): + info_dict = None + try: + info_dict = env_response.info.model_dump() + except Exception: + info_dict = None + + self.episode_logger.log_step( + { + "step_id": step_id, + "phase": "env_step", + "env_step_ok": True, + "action_name": action.name, + "action_kwargs": getattr(action, "kwargs", None), + "reward": env_response.reward, + "done": env_response.done, + "observation": env_response.observation, + "info_keys": list(info_dict.keys()) if isinstance(info_dict, dict) else 
None, + } + ) + except Exception as e: - logger.warning("Environment step failed, this is usually related to " "the User simulation call.") - logger.warning(f"Error: {e}") + if getattr(self, "episode_logger", None): + self.episode_logger.log_step( + { + "step_id": step_id, + "phase": "env_step", + "env_step_ok": False, + "error": repr(e), + "action_name": action.name, + "action_kwargs": getattr(action, "kwargs", None), + } + ) res.status = Status.ABORTED return self._build_final_result( res, total_reward, info, messages, loss_masks, prompt_token_ids, response_token_ids ) - logger.debug(f"Environment response: reward={env_response.reward}, " f"done={env_response.done}") - # Update message history based on action type if action.name != RESPOND_ACTION_NAME: - messages.append( - { - "role": "tool", - "name": action.name, - "content": env_response.observation, - } - ) + messages.append({"role": "tool", "name": action.name, "content": env_response.observation}) else: - # Direct response from user messages.append({"role": "user", "content": env_response.observation}) # Update token tracking @@ -321,10 +423,15 @@ async def asolve( # Check if done if env_response.done: res.status = Status.COMPLETED + step_id += 1 break + step_id += 1 + # Handle truncation - if not env_response.done: + if env_response is None: + res.status = Status.ABORTED + elif not env_response.done: res.status = Status.TRUNCATED return self._build_final_result( @@ -427,6 +534,7 @@ def __init__( temperature: float = 0.0, rollout_args: dict[str, Any] | None = None, sampling_params: dict[str, Any] | None = None, + episode_logger=None, ): # Initialize the parent ToolCallingAgent super().__init__( @@ -451,6 +559,7 @@ def __init__( } # Initialize OpenAI adapter self.openai_adapter = create_openai_adapter(tools_info=self.tools_info, parser_type="qwen25") + self.episode_logger = episode_logger def agent_factory( @@ -459,6 +568,7 @@ def agent_factory( config: RunConfig, rollout_args: dict[str, Any] | None = None, 
sampling_params: dict[str, Any] | None = None, + episode_logger=None, ) -> Agent: if config.agent_strategy == "tool-calling": return TrainableToolCallingAgent( @@ -469,6 +579,7 @@ def agent_factory( temperature=config.temperature, rollout_args=rollout_args, sampling_params=sampling_params, + episode_logger=episode_logger, ) else: raise NotImplementedError(f"Unsupported agent strategy: {config.agent_strategy}")