Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions examples/tau-bench/episode_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# examples/tau-bench/episode_logger.py
from __future__ import annotations

import hashlib
import json
import os
import time
from dataclasses import dataclass
from typing import Any


def _truncate(s: str | None, max_chars: int = 8000) -> str | None:
if s is None:
return None
if len(s) <= max_chars:
return s
return s[:max_chars] + f"\n...[truncated {len(s)-max_chars} chars]"


def _sha256_text(s: str) -> str:
return hashlib.sha256(s.encode("utf-8", errors="ignore")).hexdigest()


@dataclass
class EpisodeLogger:
    """Append-only logger for a single tau-bench episode.

    Writes three artifacts under ``log_dir``:
      * ``run_meta.json``  - static run metadata, written once at construction
      * ``episode.jsonl``  - one JSON record per step, appended by log_step()
      * ``summary.json``   - final episode summary, written by finalize()
    """

    # Directory that receives all log artifacts (created if missing).
    log_dir: str
    # Static run metadata dumped to run_meta.json on construction.
    run_meta: dict[str, Any]

    # String fields that are truncated (with a SHA-256 hash kept alongside)
    # before being written, to bound per-record file size.
    _TEXT_KEYS = (
        "assistant_raw",
        "assistant",
        "user_text",
        "observation",
        "tool_result",
        "env_state",
        "normal_text",
        "tool_parse_error",
        "error",
    )

    def __post_init__(self):
        """Create ``log_dir`` and persist ``run_meta`` immediately."""
        os.makedirs(self.log_dir, exist_ok=True)
        self._jsonl_path = os.path.join(self.log_dir, "episode.jsonl")
        self._summary_path = os.path.join(self.log_dir, "summary.json")
        self._run_meta_path = os.path.join(self.log_dir, "run_meta.json")

        with open(self._run_meta_path, "w", encoding="utf-8") as f:
            json.dump(self.run_meta, f, ensure_ascii=False, indent=2)

    def log_step(self, rec: dict[str, Any]):
        """Append one step record to ``episode.jsonl``.

        Long string fields listed in ``_TEXT_KEYS`` are truncated, and the
        full text's SHA-256 is preserved under ``<key>_hash``. A ``ts``
        timestamp is added. The record is shallow-copied first so the
        caller's dict is never mutated (previous versions mutated it).
        """
        out = dict(rec)  # bug fix: do not mutate the caller's record
        for k in self._TEXT_KEYS:
            v = out.get(k)
            if isinstance(v, str):
                out[k + "_hash"] = _sha256_text(v)
                out[k] = _truncate(v)

        out["ts"] = time.time()
        with open(self._jsonl_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(out, ensure_ascii=False) + "\n")

    def finalize(self, summary: dict[str, Any]):
        """Write the episode ``summary`` to ``summary.json`` (overwrites)."""
        with open(self._summary_path, "w", encoding="utf-8") as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)
36 changes: 36 additions & 0 deletions examples/tau-bench/generate_with_tau.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@

import logging
import os
import time
import uuid
from typing import Any

from episode_logger import EpisodeLogger
from tau_bench.envs import get_env
from tau_bench.types import RunConfig
from trainable_agents import InteractionResult, Status, agent_factory
Expand All @@ -29,6 +32,7 @@
"model_provider": "auto_router", # Unused, required
"model": "qwen3-4b", # Unused, required
"user_model_provider": "gemini",
# "user_model_provider": "stub",
}
# Replace with your actual API key for user sim
GEMINI_API_KEY = "NONE"
Expand Down Expand Up @@ -97,6 +101,10 @@ def res_to_sample(res: InteractionResult, task_index: int) -> Sample:
return sample


def _default_run_root() -> str:
return os.environ.get("TAU_RUN_DIR", os.path.join(os.getcwd(), "runs", "tau1"))


async def generate(args: dict[str, Any], sample: Sample, sampling_params: dict) -> Sample:
"""
Generate a complete agent-environment interaction trajectory for tau-bench.
Expand All @@ -121,6 +129,25 @@ async def generate(args: dict[str, Any], sample: Sample, sampling_params: dict)

# Extract task index from sample prompt
task_index = int(sample.prompt)

run_root = _default_run_root()
run_id = time.strftime("%Y%m%d_%H%M%S") + f"_{os.getpid()}_{uuid.uuid4().hex[:8]}"
episode_dir = os.path.join(run_root, run_id, f"task_{task_index:06d}")

run_meta = {
"run_id": run_id,
"task_index": task_index,
"tau_config": TAU_CONFIGS,
"sampling_params": {
k: sampling_params.get(k)
for k in ["temperature", "top_p", "top_k", "max_new_tokens"]
if k in sampling_params
},
"pid": os.getpid(),
}
os.makedirs(episode_dir, exist_ok=True)
ep_logger = EpisodeLogger(log_dir=episode_dir, run_meta=run_meta)

logger.info(f"Starting agent-environment interaction for task {task_index}")

# Initialize tau-bench environment
Expand All @@ -140,6 +167,7 @@ async def generate(args: dict[str, Any], sample: Sample, sampling_params: dict)
config=tau_config,
rollout_args=args,
sampling_params=sampling_params,
episode_logger=ep_logger,
)

# Execute agent-environment interaction
Expand All @@ -149,5 +177,13 @@ async def generate(args: dict[str, Any], sample: Sample, sampling_params: dict)
# Convert to slime Sample format
result_sample = res_to_sample(interaction_result, task_index)

ep_logger.finalize(
{
"status": str(interaction_result.status),
"reward": interaction_result.reward,
"response_length": getattr(interaction_result, "response_length", None),
}
)

logger.info(f"Finished agent-environment interaction for task {task_index}")
return result_sample
78 changes: 70 additions & 8 deletions examples/tau-bench/openai_tool_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,21 @@
logger = logging.getLogger(__name__)


def _parse_tools_compat(parse_tools_fn, response: str, tools_info, parser_type: str | None):
"""
Compatible wrapper for parse_tools() across versions.

Some versions: parse_tools(response, tools_info)
Some versions: parse_tools(response, tools_info, parser_type)
"""
try:
# try 3-arg signature first
return parse_tools_fn(response, tools_info, parser_type)
except TypeError:
# fallback to 2-arg signature
return parse_tools_fn(response, tools_info)


@dataclass
class OpenAIToolCall:
"""OpenAI format tool call structure"""
Expand Down Expand Up @@ -62,20 +77,67 @@ def parse_response_to_openai_format(self, response: str) -> dict[str, Any]:
Exception: Thrown when parsing fails
"""
try:
# Use existing parser to parse tool calls
parsed = parse_tools(response, self.tools_info, self.parser_type)

# Extract parsing results
normal_text = parsed["normal_text"]
calls = parsed["calls"]
parsed = _parse_tools_compat(parse_tools, response, self.tools_info, self.parser_type)

# --- normalize parsed output ---
normal_text = None
calls = []

# Case A: dict output
if isinstance(parsed, dict):
# common keys
if "normal_text" in parsed:
normal_text = parsed.get("normal_text")
elif "text" in parsed:
normal_text = parsed.get("text")
elif "content" in parsed:
normal_text = parsed.get("content")
elif "assistant" in parsed:
normal_text = parsed.get("assistant")
else:
# fallback: if no known key, just use raw response
normal_text = response

# tool calls keys variants
if "calls" in parsed and isinstance(parsed.get("calls"), list):
calls = parsed.get("calls", [])
elif "tool_calls" in parsed and isinstance(parsed.get("tool_calls"), list):
calls = parsed.get("tool_calls", [])
elif "function_calls" in parsed and isinstance(parsed.get("function_calls"), list):
calls = parsed.get("function_calls", [])
else:
calls = []

# Case B: tuple/list output (e.g., (normal_text, calls))
elif isinstance(parsed, (list, tuple)):
if len(parsed) >= 1:
normal_text = parsed[0]
if len(parsed) >= 2 and isinstance(parsed[1], list):
calls = parsed[1]
if normal_text is None:
normal_text = response

# Case C: unknown type
else:
normal_text = response
calls = []

# Ensure types
if normal_text is None:
normal_text = response
if not isinstance(calls, list):
calls = []

# Convert to OpenAI format
openai_message = self._convert_to_openai_message(normal_text, calls)

return {"openai_message": openai_message, "parsed_result": parsed, "success": True}
# For downstream code, we still provide parsed_result in the "tau expected" schema
parsed_result = {"normal_text": normal_text, "calls": calls}

return {"openai_message": openai_message, "parsed_result": parsed_result, "success": True}

except Exception as e:
logger.warning(f"Parsing failed with error: {str(e)}")
logger.warning(f"Parsing failed with error: {type(e).__name__}: {str(e)}")
return {"openai_message": None, "parsed_result": None, "success": False, "error": str(e)}

def _convert_to_openai_message(self, normal_text: str, calls: list[dict[str, Any]]) -> OpenAIAssistantMessage:
Expand Down
82 changes: 57 additions & 25 deletions examples/tau-bench/sglang_tool_parser.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,63 @@
import logging
from typing import Any

from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import Function, Tool
logger = logging.getLogger(__name__)

# We make sglang an OPTIONAL dependency.
# On many dev machines (e.g., macOS / CPU-only / Python 3.13), importing sglang.srt can
# pull in heavy deps like triton. For Tau-Bench integration we should not hard-require it.

def parse_tools(response: str, tools: list[dict[str, Any]], parser: str = "qwen25"):
_SGLANG_AVAILABLE = False
_FunctionCallParser = None
_Function = None
_Tool = None

try:
# These imports may trigger triton dependency in some sglang versions.
from sglang.srt.function_call.function_call_parser import FunctionCallParser as _FunctionCallParser
from sglang.srt.managers.io_struct import Function as _Function
from sglang.srt.managers.io_struct import Tool as _Tool

_SGLANG_AVAILABLE = True
except Exception as e:
logger.warning(f"sglang tool parser unavailable (optional). Falling back to no-tool parsing. Error: {e}")


def parse_tools(tools_info: list[dict[str, Any]], text: str) -> dict[str, Any]:
"""
This function mimics the function call parser API from
https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server.py#L952
But running locally
Parse tool calls from model output text.

Returns a dict compatible with openai_tool_adapter expectations:
{
"success": bool,
"error": str | None,
"parsed_result": {"normal_text": str, "calls": list[dict]}
}
"""
tools_list = [
Tool(
function=Function(
name=tool["function"]["name"],
description=tool["function"]["description"],
parameters=tool["function"]["parameters"],
),
type=tool["type"],
)
for tool in tools
]
parser = FunctionCallParser(tools=tools_list, tool_call_parser=parser)

normal_text, calls = parser.parse_non_stream(response)

return {
"normal_text": normal_text,
"calls": [call.model_dump() for call in calls], # Convert pydantic objects to dictionaries
}
# Fallback: treat everything as normal text; no tool calls.
if not _SGLANG_AVAILABLE:
return {
"success": True,
"error": None,
"parsed_result": {"normal_text": text, "calls": []},
}

# If sglang parser is available, use it.
try:
tools = [_Tool(**t) for t in tools_info]
functions = [_Function.from_tool(t) for t in tools] # depending on sglang version
parser = _FunctionCallParser(functions)

normal_text, calls = parser.parse(text)
# calls should be list of dict: [{"name":..., "parameters": "...json..."}]
return {
"success": True,
"error": None,
"parsed_result": {"normal_text": normal_text, "calls": calls or []},
}
except Exception as e:
return {
"success": False,
"error": repr(e),
"parsed_result": {"normal_text": text, "calls": []},
}
Loading