Skip to content

Commit c436389

Browse files
authored
Realtime: one guardrail trip event per response (#1458)
There was a problem with the current implementation, where for a single repsonse, we might have many different guardrails fire. We should have at most one per response.
1 parent a9b8ab3 commit c436389

File tree

2 files changed

+71
-13
lines changed

2 files changed

+71
-13
lines changed

src/agents/realtime/session.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def __init__(
9898
self._stored_exception: Exception | None = None
9999

100100
# Guardrails state tracking
101-
self._interrupted_by_guardrail = False
101+
self._interrupted_response_ids: set[str] = set()
102102
self._item_transcripts: dict[str, str] = {} # item_id -> accumulated transcript
103103
self._item_guardrail_run_counts: dict[str, int] = {} # item_id -> run count
104104
self._debounce_text_length = self._run_config.get("guardrails_settings", {}).get(
@@ -242,7 +242,8 @@ async def on_event(self, event: RealtimeModelEvent) -> None:
242242

243243
if current_length >= next_run_threshold:
244244
self._item_guardrail_run_counts[item_id] += 1
245-
self._enqueue_guardrail_task(self._item_transcripts[item_id])
245+
# Pass response_id so we can ensure only a single interrupt per response
246+
self._enqueue_guardrail_task(self._item_transcripts[item_id], event.response_id)
246247
elif event.type == "item_updated":
247248
is_new = not any(item.item_id == event.item.item_id for item in self._history)
248249
self._history = self._get_new_history(self._history, event.item)
@@ -274,7 +275,6 @@ async def on_event(self, event: RealtimeModelEvent) -> None:
274275
# Clear guardrail state for next turn
275276
self._item_transcripts.clear()
276277
self._item_guardrail_run_counts.clear()
277-
self._interrupted_by_guardrail = False
278278

279279
await self._put_event(
280280
RealtimeAgentEndEvent(
@@ -442,7 +442,7 @@ def _get_new_history(
442442
# Otherwise, add it to the end
443443
return old_history + [event]
444444

445-
async def _run_output_guardrails(self, text: str) -> bool:
445+
async def _run_output_guardrails(self, text: str, response_id: str) -> bool:
446446
"""Run output guardrails on the given text. Returns True if any guardrail was triggered."""
447447
combined_guardrails = self._current_agent.output_guardrails + self._run_config.get(
448448
"output_guardrails", []
@@ -455,7 +455,8 @@ async def _run_output_guardrails(self, text: str) -> bool:
455455
output_guardrails.append(guardrail)
456456
seen_ids.add(guardrail_id)
457457

458-
if not output_guardrails or self._interrupted_by_guardrail:
458+
# If we've already interrupted this response, skip
459+
if not output_guardrails or response_id in self._interrupted_response_ids:
459460
return False
460461

461462
triggered_results = []
@@ -475,8 +476,12 @@ async def _run_output_guardrails(self, text: str) -> bool:
475476
continue
476477

477478
if triggered_results:
478-
# Mark as interrupted to prevent multiple interrupts
479-
self._interrupted_by_guardrail = True
479+
# Double-check: bail if already interrupted for this response
480+
if response_id in self._interrupted_response_ids:
481+
return False
482+
483+
# Mark as interrupted immediately (before any awaits) to minimize race window
484+
self._interrupted_response_ids.add(response_id)
480485

481486
# Emit guardrail tripped event
482487
await self._put_event(
@@ -502,10 +507,10 @@ async def _run_output_guardrails(self, text: str) -> bool:
502507

503508
return False
504509

505-
def _enqueue_guardrail_task(self, text: str) -> None:
510+
def _enqueue_guardrail_task(self, text: str, response_id: str) -> None:
506511
# Runs the guardrails in a separate task to avoid blocking the main loop
507512

508-
task = asyncio.create_task(self._run_output_guardrails(text))
513+
task = asyncio.create_task(self._run_output_guardrails(text, response_id))
509514
self._guardrail_tasks.add(task)
510515

511516
# Add callback to remove completed tasks and handle exceptions

tests/realtime/test_session.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,7 +1050,6 @@ async def test_transcript_delta_triggers_guardrail_at_threshold(
10501050
await self._wait_for_guardrail_tasks(session)
10511051

10521052
# Should have triggered guardrail and interrupted
1053-
assert session._interrupted_by_guardrail is True
10541053
assert mock_model.interrupts_called == 1
10551054
assert len(mock_model.sent_messages) == 1
10561055
assert "triggered_guardrail" in mock_model.sent_messages[0]
@@ -1187,14 +1186,12 @@ async def test_turn_ended_clears_guardrail_state(
11871186
# Wait for async guardrail tasks to complete
11881187
await self._wait_for_guardrail_tasks(session)
11891188

1190-
assert session._interrupted_by_guardrail is True
11911189
assert len(session._item_transcripts) == 1
11921190

11931191
# End turn
11941192
await session.on_event(RealtimeModelTurnEndedEvent())
11951193

11961194
# State should be cleared
1197-
assert session._interrupted_by_guardrail is False
11981195
assert len(session._item_transcripts) == 0
11991196
assert len(session._item_guardrail_run_counts) == 0
12001197

@@ -1259,7 +1256,6 @@ async def test_agent_output_guardrails_triggered(self, mock_model, triggered_gua
12591256
await session.on_event(transcript_event)
12601257
await self._wait_for_guardrail_tasks(session)
12611258

1262-
assert session._interrupted_by_guardrail is True
12631259
assert mock_model.interrupts_called == 1
12641260
assert len(mock_model.sent_messages) == 1
12651261
assert "triggered_guardrail" in mock_model.sent_messages[0]
@@ -1272,6 +1268,63 @@ async def test_agent_output_guardrails_triggered(self, mock_model, triggered_gua
12721268
assert len(guardrail_events) == 1
12731269
assert guardrail_events[0].message == "this is more than ten characters"
12741270

1271+
@pytest.mark.asyncio
1272+
async def test_concurrent_guardrail_tasks_interrupt_once_per_response(self, mock_model):
1273+
"""Even if multiple guardrail tasks trigger concurrently for the same response_id,
1274+
only the first should interrupt and send a message."""
1275+
import asyncio
1276+
1277+
# Barrier to release both guardrail tasks at the same time
1278+
start_event = asyncio.Event()
1279+
1280+
async def async_trigger_guardrail(context, agent, output):
1281+
await start_event.wait()
1282+
return GuardrailFunctionOutput(
1283+
output_info={"reason": "concurrent"}, tripwire_triggered=True
1284+
)
1285+
1286+
concurrent_guardrail = OutputGuardrail(
1287+
guardrail_function=async_trigger_guardrail, name="concurrent_trigger"
1288+
)
1289+
1290+
run_config: RealtimeRunConfig = {
1291+
"output_guardrails": [concurrent_guardrail],
1292+
"guardrails_settings": {"debounce_text_length": 5},
1293+
}
1294+
1295+
# Use a minimal agent (guardrails from run_config)
1296+
agent = RealtimeAgent(name="agent")
1297+
session = RealtimeSession(mock_model, agent, None, run_config=run_config)
1298+
1299+
# Two deltas for same item and response to enqueue two guardrail tasks
1300+
await session.on_event(
1301+
RealtimeModelTranscriptDeltaEvent(
1302+
item_id="item_1", delta="12345", response_id="resp_same"
1303+
)
1304+
)
1305+
await session.on_event(
1306+
RealtimeModelTranscriptDeltaEvent(
1307+
item_id="item_1", delta="67890", response_id="resp_same"
1308+
)
1309+
)
1310+
1311+
# Wait until both tasks are enqueued
1312+
for _ in range(50):
1313+
if len(session._guardrail_tasks) >= 2:
1314+
break
1315+
await asyncio.sleep(0.01)
1316+
1317+
# Release both tasks concurrently
1318+
start_event.set()
1319+
1320+
# Wait for completion
1321+
if session._guardrail_tasks:
1322+
await asyncio.gather(*session._guardrail_tasks, return_exceptions=True)
1323+
1324+
# Only one interrupt and one message should be sent
1325+
assert mock_model.interrupts_called == 1
1326+
assert len(mock_model.sent_messages) == 1
1327+
12751328

12761329
class TestModelSettingsIntegration:
12771330
"""Test suite for model settings integration in RealtimeSession."""

0 commit comments

Comments
 (0)