Skip to content
Open
204 changes: 104 additions & 100 deletions validator/modules/llm_judge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
from loguru import logger
from huggingface_hub import HfApi
from typing import List, Dict, Any
from validator.modules.llm_judge.prompt import get_prompt
from validator.modules.llm_judge.prompt import get_prompt,template_str
from validator.modules.llm_judge.utils import download_file
from validator.modules.llm_judge.constant import SUPPORTED_BASE_MODELS
from validator.exceptions import LLMJudgeException, InvalidModelParametersException
from validator.modules.llm_judge.template import template_dict
from peft import PeftModel
from jinja2 import Environment
from transformers import AutoTokenizer, AutoModelForCausalLM
from validator.modules.base import (
BaseValidationModule,
Expand All @@ -37,7 +38,7 @@
class LLMJudgeConfig(BaseConfig):
gen_batch_size: int = 1
eval_batch_size: int = 16
gen_temperature: float = 0.1
gen_temperature: float = 0.7


class LLMJudgeMetrics(BaseMetrics):
Expand Down Expand Up @@ -144,7 +145,6 @@ def _load_model(self, repo_id: str, revision: str = "main", max_params: int = No
model_kwargs = dict(
trust_remote_code=True,
torch_dtype=compute_dtype,
use_cache=False,
device_map="auto",
)
if is_lora:
Expand All @@ -157,6 +157,18 @@ def _load_model(self, repo_id: str, revision: str = "main", max_params: int = No
with open("judge/adapter_config.json", "r") as f:
adapter_config = json.load(f)
base_model = adapter_config["base_model_name_or_path"]
if base_model in SUPPORTED_BASE_MODELS:
logger.info(
f"LoRA's base model '{base_model}' is in SUPPORTED_BASE_MODELS. "
f"Using it for tokenizer."
)
else:
logger.error(
f"LoRA's base model '{base_model}' is not in SUPPORTED_BASE_MODELS. "
f"Marking assignment as failed."
)
raise

self.hf_tokenizer = AutoTokenizer.from_pretrained(
base_model, trust_remote_code=True, use_fast=True, padding_side="left"
)
Expand Down Expand Up @@ -189,88 +201,6 @@ def _load_model(self, repo_id: str, revision: str = "main", max_params: int = No
f"Model parameters {total} exceed limit {max_params}"
)

def _construct_conversation_template(
    self,
    conversation: Dict[str, Any],
    base_model: str,
) -> str:
    """Render a conversation dict into a single prompt string using the
    per-model template from ``template_dict``.

    Args:
        conversation: Mapping with a non-empty ``"conversations"`` list of
            ``{"role": ..., "content": ...}`` messages and an optional
            ``"system"`` prompt. (The previous ``List[Dict[str, str]]``
            annotation contradicted the dict validation performed below.)
        base_model: Key into ``template_dict``; unknown models fall back to
            the ``"default"`` template.

    Returns:
        The concatenated, template-formatted conversation string.

    Raises:
        LLMJudgeException: If the input is malformed or the rendered
            result is empty (all failures are wrapped and chained).
    """
    try:
        # Unknown base models fall back to the generic template.
        if base_model not in template_dict:
            logger.info(f"Template {base_model} not found, using default")
            base_model = "default"

        template = template_dict[base_model]
        conversation_parts = []

        # Validate conversation structure before rendering.
        if not isinstance(conversation, dict):
            raise LLMJudgeException(
                f"Conversation must be a dict, got {type(conversation)}"
            )
        if "conversations" not in conversation:
            raise LLMJudgeException(
                "Conversation dict must have 'conversations' key"
            )
        if not conversation["conversations"]:
            raise LLMJudgeException("Conversation 'conversations' list is empty")

        # Use the provided system prompt or fall back to a generic default.
        if template.system_format:
            system_content = (
                conversation.get("system") or "You are a helpful assistant."
            )
            conversation_parts.append(
                template.system_format.format(content=system_content)
            )

        # Multi-turn conversation: format each message according to template;
        # malformed messages are skipped (logged) rather than aborting.
        for msg in conversation["conversations"]:
            if (
                not isinstance(msg, dict)
                or "role" not in msg
                or "content" not in msg
            ):
                logger.warning(f"Skipping invalid message: {msg}")
                continue

            if msg["role"] == "user":
                conversation_parts.append(
                    template.user_format.format(
                        content=msg["content"],
                        stop_token=self.hf_tokenizer.eos_token,
                    )
                )
            elif msg["role"] == "assistant":
                conversation_parts.append(
                    template.assistant_format.format(
                        content=msg["content"],
                        stop_token=self.hf_tokenizer.eos_token,
                    )
                )

        conversation_format = "".join(conversation_parts)

        # An all-whitespace render indicates a broken template/input pair.
        if not conversation_format.strip():
            logger.error(
                f"Empty template generated. Template: {base_model}, Conversation: {conversation}, Parts: {conversation_parts}"
            )
            raise LLMJudgeException(
                "Generated conversation template is empty after formatting"
            )

    except Exception as e:
        # Wrap every failure (including our own LLMJudgeExceptions) so the
        # caller sees a single exception type with the cause chained.
        raise LLMJudgeException(
            f"Failed to construct conversation template: {e}"
        ) from e

    return conversation_format

def _generate_response(
self,
context_length: int,
Expand All @@ -294,9 +224,28 @@ def _generate_response(
# Apply chat template with fallback
batch_conversation_templates = []
for conversation in batch_conversations:
template = self._construct_conversation_template(
conversation,
base_model=base_model,

messages = []
if "system" in conversation:
messages.append({
"role": "system",
"content": conversation["system"]
})

messages += conversation["conversations"]
tools_for_template = conversation.get("tools", None)
try:
if isinstance(tools_for_template, str):
tools_for_template = json.loads(tools_for_template)
except Exception:
# leave tools_for_template as-is if parsing fails
pass
template = self.hf_tokenizer.apply_chat_template(
messages,
tools=tools_for_template,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)

# Validate template is not empty
Expand Down Expand Up @@ -353,10 +302,10 @@ def _generate_response(
outputs = self.hf_model.generate(
**model_inputs,
max_new_tokens=max_length,
temperature=self.config.gen_temperature,
temperature=self.config.gen_temperature, # Non thinking-General 0.7 ,Reasoning 1
do_sample=True,
top_p=0.95, # Nucleus sampling for stability
top_k=50, # Limit vocabulary for stability
top_p=0.8, # Non thinking-General 0.8 ,Reasoning 0.95
top_k=20, #
pad_token_id=self.hf_tokenizer.eos_token_id,
eos_token_id=self.hf_tokenizer.eos_token_id,
)
Expand Down Expand Up @@ -608,29 +557,79 @@ def _load_jsonl_conversations(
conversation_to_process = []
reference_response = None
tools_info = None

pending_tool_call_ids: list[str] = []
tool_call_counter = 0
if "conversations" in json_data:
conversations = json_data["conversations"]
if isinstance(conversations, list) and conversations:
# Filter valid messages
for msg in conversations:
role = msg.get("role", "")
content = msg.get("content", "").strip()
if (
role in ["user", "assistant", "function_call"]
and content
):
if not content:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1. Reference extraction can be misaligned with the actual prompt

You extract the reference using:

last_msg = conversations[-1]

but you truncate the processed messages with:

conversation_to_process = conversation_to_process[:-1]

Problem:

  • conversations = raw input
  • conversation_to_process = filtered + transformed version

If any messages were:

  • skipped (empty content),
  • transformed (e.g. function_call → assistant),
  • or failed parsing,

then the “last raw message” may not match the “last processed message”.

You may remove one message from the prompt, but use a different one as the reference.


2. function_call reference is raw JSON string (may not match your evaluation goal)

When the last message is a function_call, you do:

reference_response = last_msg["content"]

This gives you something like:

{"name":"get_weather","arguments":{"city":"Toronto"}}

So your reference is:

  • a raw JSON string, not
  • a structured tool call, nor
  • a natural language answer

This is only correct if your evaluation expects:

  • exact string match of the function call JSON

Otherwise it may be inconsistent with:

  • how your template represents tool calls (tool_calls structure)
  • or how your model outputs them

3. Tool call ↔ observation matching is simplified (not robust)

You assign each observation to the most recent tool call:

for prev_msg in reversed(conversation_to_process):
    if prev_msg.get("role") == "assistant" and prev_msg.get("tool_calls"):
        tool_call_id = prev_msg["tool_calls"][0]["id"]
        break

This assumes:

  • one tool call at a time
  • one observation per call
  • strictly sequential flow

Works fine for simple ReAct-style traces like:

assistant → tool_call
tool → observation
assistant → next step

But breaks or becomes ambiguous if:

  • multiple tool calls in one assistant message
  • multiple observations
  • parallel or interleaved calls

continue
if role == "function_call":
tool_call_counter += 1
tool_call_id = f"call_{tool_call_counter}"

try:
call_data = json.loads(content)
tool_call_msg = {
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": tool_call_id,
"type": "function",
"function": {
"name": call_data.get("name", ""),
"arguments": call_data.get("arguments", {})
},
}
],
}
conversation_to_process.append(tool_call_msg)
except (json.JSONDecodeError, KeyError):
tool_call_id = None
conversation_to_process.append(
{"role": "assistant", "content": content}
)
if tool_call_id:
pending_tool_call_ids.append(tool_call_id)

elif role == "observation":
if pending_tool_call_ids:
tool_call_id = pending_tool_call_ids.pop(0)
else:
tool_call_id = "call_unknown"

conversation_to_process.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": content,
}
)
elif role in ["user", "assistant"]:
conversation_to_process.append(
{"role": role, "content": content}
)

# Extract reference response (last assistant or function_call message)
reference_response = None

if conversation_to_process:
last_msg = conversation_to_process[-1]
if last_msg["role"] in ["assistant", "function_call"]:
last_msg = conversations[-1]
if last_msg["role"] in ["assistant"]:
reference_response = last_msg["content"]
conversation_to_process = conversation_to_process[:-1]
elif last_msg["role"] in ["function_call"]:
env = Environment(trim_blocks=True, lstrip_blocks=True)
conversation_template = env.from_string(template_str)
reference_response = conversation_template.render(
messages=[conversation_to_process[-1]], trim_blocks=True,
lstrip_blocks=True)
conversation_to_process = conversation_to_process[:-1]

# Extract tools information if available (for function_call evaluation)
if "tools" in json_data:
Expand All @@ -651,6 +650,8 @@ def _load_jsonl_conversations(
continue

input_conversations_data["conversations"] = conversation_to_process
if tools_info is not None:
input_conversations_data["tools"] = tools_info

input_conversations.append(
{
Expand Down Expand Up @@ -916,7 +917,10 @@ def validate(self, data: LLMJudgeInputData, **kwargs) -> LLMJudgeMetrics:
self._load_model(data.hg_repo_id, data.revision, data.max_params)
except InvalidModelParametersException as e:
# lowest possible reward for invalid model parameters
logger.info(f"Invalid model parameters: {e}")
logger.error(f"Invalid model parameters: {e}")
return LLMJudgeMetrics(score=LOWEST_POSSIBLE_SCORE)
except Exception as e:
logger.error(f"Exception when load model: {e}")
return LLMJudgeMetrics(score=LOWEST_POSSIBLE_SCORE)

# Stage 1: Generate all responses
Expand Down
12 changes: 12 additions & 0 deletions validator/modules/llm_judge/constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Allow-list of base models whose tokenizers may be loaded when a submitted
# LoRA adapter declares its `base_model_name_or_path`; adapters built on any
# other base are rejected during model loading (see _load_model).
SUPPORTED_BASE_MODELS = [
    # Qwen 3.5 family (instruct and -Base variants).
    "Qwen/Qwen3.5-0.8B",
    "Qwen/Qwen3.5-0.8B-Base",
    "Qwen/Qwen3.5-2B",
    "Qwen/Qwen3.5-2B-Base",
    "Qwen/Qwen3.5-4B",
    "Qwen/Qwen3.5-4B-Base",
    "Qwen/Qwen3.5-9B",
    "Qwen/Qwen3.5-9B-Base",
    "Qwen/Qwen3.5-27B",
]
6 changes: 3 additions & 3 deletions validator/modules/llm_judge/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ dependencies:
- openai>=1.0.0 # OpenAI API client
- httpx # HTTP client for OpenAI requests
- pydantic>=2.0.0 # Data validation and parsing
- transformers==4.49.0 # HuggingFace transformers library
- transformers==5.3.0 # HuggingFace transformers library
- torch>=1.13.1 # PyTorch for model inference
- accelerate>=0.27.2 # For efficient model loading
- loguru>=0.6.0 # Logging library
- huggingface-hub==0.29.1
- huggingface-hub==1.5.0
- tenacity
- peft>=0.10.0,<0.18.0
- peft==0.18.1
- python-dotenv # Load environment variables from .env file
42 changes: 42 additions & 0 deletions validator/modules/llm_judge/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,45 @@ def function_call_ref_eval_prompt(
tools=Tools,
assistant_response=assistant_response,
)

# Jinja2 template that renders a message list into a tagged plain-text
# transcript (<system>/<user>/<assistant>/<tool_call>/<tool_response>).
# Used to render a trailing function_call turn into a reference string.
# NOTE(review): `from_json` is NOT a built-in Jinja2 filter, and the plain
# Environment(trim_blocks=True, lstrip_blocks=True) used at the render site
# does not register one — rendering a tool call whose `arguments` field is a
# JSON string will fail at template compile time. Confirm a custom filter is
# registered, or pre-parse `arguments` to a dict before rendering.
template_str= """{% for message in messages %}

{% if message.role == "system" %}
<system>
{{ message.content }}
</system>

{% elif message.role == "user" %}
<user>
{{ message.content }}
</user>

{% elif message.role == "assistant" %}

{% if message.tool_calls %}
<tool_call>
{% for tool in message.tool_calls %}
<function={{ tool.function.name }}>
{% set args = tool.function.arguments %}
{% if args is string %}
{% set args = args | from_json %}
{% endif %}
{% for key, value in args.items() %}
<parameter={{ key }}>{{ value }}</parameter>
{% endfor %}
</function>
{% endfor %}
</tool_call>
{% else %}
<assistant>
{{ message.content }}
</assistant>
{% endif %}
{% elif message.role == "tool" %}
<tool_response>
{{ message.content }}
</tool_response>

{% endif %}

{% endfor %}"""
Loading