Skip to content

Commit 223406a

Browse files
zkleb-aaiAssemblyAI
andauthored
chore: sync sdk code with DeepLearning repo (#178)
Co-authored-by: AssemblyAI <engineering.sdk@assemblyai.com>
1 parent 98c9ea8 commit 223406a

4 files changed

Lines changed: 204 additions & 14 deletions

File tree

assemblyai/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.64.0"
1+
__version__ = "0.64.1"

assemblyai/streaming/v3/client.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,10 @@ def connect(self, params: StreamingParameters) -> None:
206206
logger.debug("Connected to WebSocket server")
207207

208208
def disconnect(self, terminate: bool = False) -> None:
209-
if terminate and not self._stop_event.is_set():
209+
# Enqueue Terminate even when stop is already set: `_write_message`
210+
# bypasses the stop gate for TerminateSession so the frame still
211+
# reaches the server when the write thread is alive.
212+
if terminate:
210213
self._write_queue.put(TerminateSession())
211214

212215
self._stop_event.set()
@@ -341,7 +344,10 @@ def _handle_message(self, message: EventMessage) -> None:
341344
event_type = StreamingEvents[message.type]
342345

343346
for handler in self._handlers[event_type]:
344-
handler(self, message)
347+
try:
348+
handler(self, message)
349+
except Exception:
350+
logger.exception("on_%s handler raised", event_type.name.lower())
345351

346352
def _parse_message(self, data: Dict[str, Any]) -> Optional[EventMessage]:
347353
if "type" in data:
@@ -385,7 +391,10 @@ def _handle_warning(self, warning: WarningEvent):
385391
"Streaming warning (code=%s): %s", warning.warning_code, warning.warning
386392
)
387393
for handler in self._handlers[StreamingEvents.Warning]:
388-
handler(self, warning)
394+
try:
395+
handler(self, warning)
396+
except Exception:
397+
logger.exception("on_warning handler raised")
389398

390399
def _report_server_error(self, error: ErrorEvent) -> None:
391400
self._server_error_reported = True
@@ -395,6 +404,13 @@ def _report_server_error(self, error: ErrorEvent) -> None:
395404
)
396405
logger.error("Streaming error: %s (code=%s)", error.error, error.error_code)
397406
self._dispatch_error(streaming_error)
407+
# Tear down locally so a server that sends Error without a trailing
408+
# close frame doesn't leave the read loop spinning in recv(timeout=1)
409+
# forever. `_close_websocket` is idempotent; if the trailing close
410+
# does arrive, `_report_connection_closed` will dedup via
411+
# `_server_error_reported`.
412+
self._close_websocket()
413+
self._stop_event.set()
398414

399415
def _report_connection_closed(
400416
self,

assemblyai/streaming/v3/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ class StreamingSessionParameters(BaseModel):
105105
keyterms_prompt: Optional[List[str]] = None
106106
filter_profanity: Optional[bool] = None
107107
prompt: Optional[str] = None
108+
interruption_delay: Optional[int] = None
108109

109110

110111
class Encoding(str, Enum):

tests/unit/test_streaming.py

Lines changed: 183 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,37 @@ def mocked_websocket_connect(
650650
assert "continuous_partials=True" in actual_url
651651

652652

653+
def test_client_connect_with_interruption_delay(mocker: MockFixture):
654+
# Given: client + interruption_delay=500 (U3-Pro early-partial override)
655+
actual_url = None
656+
657+
def mocked_websocket_connect(
658+
url: str, additional_headers: dict, open_timeout: float
659+
):
660+
nonlocal actual_url
661+
actual_url = url
662+
663+
mocker.patch(
664+
"assemblyai.streaming.v3.client.websocket_connect",
665+
new=mocked_websocket_connect,
666+
)
667+
_disable_rw_threads(mocker)
668+
client = StreamingClient(
669+
StreamingClientOptions(api_key="test", api_host="api.example.com")
670+
)
671+
params = StreamingParameters(
672+
sample_rate=16000,
673+
speech_model=SpeechModel.u3_rt_pro,
674+
interruption_delay=500,
675+
)
676+
677+
# When: connect
678+
client.connect(params)
679+
680+
# Then: parameter reaches the URL
681+
assert "interruption_delay=500" in actual_url
682+
683+
653684
def test_customer_support_audio_capture_warns_when_enabled(
654685
mocker: MockFixture, caplog: pytest.LogCaptureFixture
655686
):
@@ -986,7 +1017,10 @@ def on_error(self_, err):
9861017
seed_chunks=[b"\x00" * 320] * 50,
9871018
)
9881019

989-
# Then: exactly one on_error with the rich server-error content.
1020+
# Then: exactly one on_error with the rich server-error content. The
1021+
# local websocket has been closed (by _report_server_error). Whether the
1022+
# trailing close-frame race produces an additional "Connection closed"
1023+
# log depends on scheduling, but dedup ensures no second on_error fires.
9901024
assert len(received) == 1, (
9911025
f"expected exactly 1 error, got {len(received)}: {received}"
9921026
)
@@ -1001,19 +1035,10 @@ def on_error(self_, err):
10011035
for rec in caplog.records
10021036
if "Streaming error" in rec.message and "4001" in rec.message
10031037
]
1004-
close_logs = [
1005-
rec
1006-
for rec in caplog.records
1007-
if "Connection closed" in rec.message and "4001" in rec.message
1008-
]
10091038
assert len(error_logs) == 1, (
10101039
f"expected exactly 1 Streaming-error log, got {len(error_logs)}"
10111040
)
10121041
assert error_logs[0].levelno == logging.ERROR
1013-
assert len(close_logs) == 1, (
1014-
f"expected exactly 1 Connection-closed log, got {len(close_logs)}"
1015-
)
1016-
assert close_logs[0].levelno == logging.ERROR
10171042

10181043
client.disconnect(terminate=True)
10191044

@@ -1204,3 +1229,151 @@ def test_write_thread_close_is_drained_by_read_thread(mocker: MockFixture):
12041229
assert received[0].code == 1011
12051230

12061231
client.disconnect()
1232+
1233+
1234+
def test_server_error_without_trailing_close_exits_read_loop(mocker: MockFixture):
1235+
# Given: server sends an Error frame and then nothing (no close). Without
1236+
# _report_server_error setting _stop_event, the read loop would call
1237+
# recv(timeout=1) forever after dispatching the error.
1238+
error_json = json.dumps(
1239+
{"type": "Error", "error": "Server boom", "error_code": 5001}
1240+
)
1241+
fake_ws = _FakeWebSocket(recv_script=[error_json])
1242+
mocker.patch(
1243+
"assemblyai.streaming.v3.client.websocket_connect",
1244+
return_value=fake_ws,
1245+
)
1246+
received = []
1247+
client = StreamingClient(
1248+
StreamingClientOptions(api_key="test", api_host="api.example.com")
1249+
)
1250+
client.on(StreamingEvents.Error, lambda c, e: received.append(e))
1251+
1252+
# When: connect and let the read thread dispatch the Error
1253+
_connect_and_wait(client, _default_params())
1254+
1255+
# Then: error was dispatched once and the read thread exited despite the
1256+
# absence of a trailing close frame.
1257+
assert len(received) == 1
1258+
assert received[0].code == 5001
1259+
assert client._stop_event.is_set()
1260+
assert not client._read_thread.is_alive()
1261+
assert not client._write_thread.is_alive()
1262+
1263+
client.disconnect(terminate=True)
1264+
1265+
1266+
def test_disconnect_terminate_enqueues_when_stop_already_set(mocker: MockFixture):
1267+
# Given: a client whose _stop_event is already set (e.g. after a server
1268+
# error invoked _report_server_error). Threads were never started, so the
1269+
# only observable side-effect of disconnect(terminate=True) is the queue.
1270+
fake_ws = _FakeWebSocket(recv_script=[])
1271+
mocker.patch(
1272+
"assemblyai.streaming.v3.client.websocket_connect",
1273+
return_value=fake_ws,
1274+
)
1275+
client = StreamingClient(
1276+
StreamingClientOptions(api_key="test", api_host="api.example.com")
1277+
)
1278+
client._websocket = fake_ws
1279+
client._stop_event.set()
1280+
1281+
# When: disconnect(terminate=True) runs after stop is already set
1282+
client.disconnect(terminate=True)
1283+
1284+
# Then: TerminateSession was enqueued unconditionally; the disconnect-side
1285+
# guard no longer silently swallows the terminate intent.
1286+
assert client._write_queue.qsize() == 1
1287+
msg = client._write_queue.get_nowait()
1288+
assert isinstance(msg, TerminateSession)
1289+
1290+
1291+
def test_message_handler_exception_does_not_kill_read_thread(mocker: MockFixture):
1292+
# Given: a Turn handler that raises, followed by a Termination event. If
1293+
# the exception escapes _handle_message, the read thread dies before
1294+
# processing the Termination event.
1295+
turn_json = json.dumps(
1296+
{
1297+
"type": "Turn",
1298+
"turn_order": 1,
1299+
"turn_is_formatted": True,
1300+
"end_of_turn": True,
1301+
"transcript": "hi",
1302+
"end_of_turn_confidence": 0.9,
1303+
"words": [],
1304+
}
1305+
)
1306+
termination_json = json.dumps(
1307+
{
1308+
"type": "Termination",
1309+
"audio_duration_seconds": 1,
1310+
"session_duration_seconds": 1,
1311+
}
1312+
)
1313+
fake_ws = _FakeWebSocket(recv_script=[turn_json, termination_json])
1314+
mocker.patch(
1315+
"assemblyai.streaming.v3.client.websocket_connect",
1316+
return_value=fake_ws,
1317+
)
1318+
turns = []
1319+
terminations = []
1320+
1321+
def bad_turn_handler(self_, msg):
1322+
turns.append(msg)
1323+
raise RuntimeError("boom")
1324+
1325+
client = StreamingClient(
1326+
StreamingClientOptions(api_key="test", api_host="api.example.com")
1327+
)
1328+
client.on(StreamingEvents.Turn, bad_turn_handler)
1329+
client.on(StreamingEvents.Termination, lambda c, e: terminations.append(e))
1330+
1331+
# When: connect; the read thread processes the Turn (handler raises) then
1332+
# the Termination (which sets _stop_event and exits the loop)
1333+
_connect_and_wait(client, _default_params())
1334+
1335+
# Then: read thread survived the raising handler and processed Termination.
1336+
assert len(turns) == 1
1337+
assert len(terminations) == 1
1338+
assert client._stop_event.is_set()
1339+
assert not client._read_thread.is_alive()
1340+
1341+
client.disconnect()
1342+
1343+
1344+
def test_warning_handler_exception_does_not_kill_read_thread(mocker: MockFixture):
1345+
# Given: a Warning handler that raises, followed by a clean close.
1346+
warning_json = json.dumps(
1347+
{"type": "Warning", "warning": "session ending soon", "warning_code": 1234}
1348+
)
1349+
clean_close = ConnectionClosed(rcvd=Close(1000, "session ended"), sent=None)
1350+
fake_ws = _FakeWebSocket(recv_script=[warning_json, clean_close])
1351+
mocker.patch(
1352+
"assemblyai.streaming.v3.client.websocket_connect",
1353+
return_value=fake_ws,
1354+
)
1355+
warnings_received = []
1356+
errors_received = []
1357+
1358+
def bad_warning_handler(self_, w):
1359+
warnings_received.append(w)
1360+
raise RuntimeError("boom")
1361+
1362+
client = StreamingClient(
1363+
StreamingClientOptions(api_key="test", api_host="api.example.com")
1364+
)
1365+
client.on(StreamingEvents.Warning, bad_warning_handler)
1366+
client.on(StreamingEvents.Error, lambda c, e: errors_received.append(e))
1367+
1368+
# When: connect; the read thread processes the warning (handler raises)
1369+
# then the clean close
1370+
_connect_and_wait(client, _default_params())
1371+
1372+
# Then: warning was delivered, read thread survived, clean close completed
1373+
# without dispatching an error.
1374+
assert len(warnings_received) == 1
1375+
assert errors_received == []
1376+
assert client._stop_event.is_set()
1377+
assert not client._read_thread.is_alive()
1378+
1379+
client.disconnect()

0 commit comments

Comments
 (0)