Skip to content

Commit

Permalink
Pusher with ping fast (#1964)
Browse files Browse the repository at this point in the history
  • Loading branch information
beastoin authored Mar 9, 2025
2 parents dc4ff48 + 81668bd commit da37feb
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 38 deletions.
52 changes: 24 additions & 28 deletions backend/routers/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,29 @@ async def _websocket_util_trigger(
websocket_active = True
websocket_close_code = 1000

# heart beat
async def send_heartbeat():
print("pusher send_heartbeat", uid)
nonlocal websocket_active
nonlocal websocket_close_code
try:
while websocket_active:
if websocket.client_state == WebSocketState.CONNECTED:
await websocket.send_json({"type": "ping"})
else:
break
await asyncio.sleep(10)
except WebSocketDisconnect:
print("WebSocket disconnected")
except Exception as e:
print(f'Heartbeat error: {e}')
websocket_close_code = 1011
finally:
websocket_active = False

# start heart beat
heartbeat_task = asyncio.create_task(send_heartbeat())

loop = asyncio.get_event_loop()

# audio bytes
Expand All @@ -50,11 +73,6 @@ async def receive_audio_bytes():
header_type = struct.unpack('<I', data[:4])[0]

# Transcript
if header_type == 100:
segments = json.loads(bytes(data[4:]).decode("utf-8"))
asyncio.run_coroutine_threadsafe(trigger_realtime_integrations(uid, segments, None), loop)
asyncio.run_coroutine_threadsafe(realtime_transcript_webhook(uid, segments), loop)
continue
if header_type == 102:
res = json.loads(bytes(data[4:]).decode("utf-8"))
segments = res.get('segments')
Expand Down Expand Up @@ -87,30 +105,8 @@ async def receive_audio_bytes():
finally:
websocket_active = False

# heart beat
async def send_heartbeat():
nonlocal websocket_active
nonlocal websocket_close_code
try:
while websocket_active:
await asyncio.sleep(20)
if websocket.client_state == WebSocketState.CONNECTED:
await websocket.send_json({"type": "ping"})
else:
break
except WebSocketDisconnect:
print("WebSocket disconnected")
except Exception as e:
print(f'Heartbeat error: {e}')
websocket_close_code = 1011
finally:
websocket_active = False

try:
receive_task = asyncio.create_task(
receive_audio_bytes()
)
heartbeat_task = asyncio.create_task(send_heartbeat())
receive_task = asyncio.create_task(receive_audio_bytes())
await asyncio.gather(receive_task, heartbeat_task)

except Exception as e:
Expand Down
11 changes: 6 additions & 5 deletions backend/routers/transcribe_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ async def transcript_consume():
print(f"Pusher transcripts Connection closed: {e}", uid)
transcript_ws = None
pusher_connected = False
await reconnect()
await connect()
except Exception as e:
print(f"Pusher transcripts failed: {e}", uid)

Expand Down Expand Up @@ -446,19 +446,19 @@ async def audio_bytes_consume():
print(f"Pusher audio_bytes Connection closed: {e}", uid)
audio_bytes_ws = None
pusher_connected = False
await reconnect()
await connect()
except Exception as e:
print(f"Pusher audio_bytes failed: {e}", uid)

async def reconnect():
async def connect():
nonlocal pusher_connected
nonlocal pusher_connect_lock
async with pusher_connect_lock:
if pusher_connected:
return
await connect()
await _connect()

async def connect():
async def _connect():
nonlocal pusher_ws
nonlocal transcript_ws
nonlocal audio_bytes_ws
Expand Down Expand Up @@ -486,6 +486,7 @@ async def close(code: int = 1000):
transcript_consume = None
audio_bytes_send = None
audio_bytes_consume = None

pusher_connect, pusher_close, \
transcript_send, transcript_consume, \
audio_bytes_send, audio_bytes_consume = create_pusher_task_handler()
Expand Down
3 changes: 2 additions & 1 deletion backend/utils/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,15 @@ def _single(app: App):


async def trigger_realtime_integrations(uid: str, segments: list[dict], memory_id: str | None):
print("trigger_realtime_integrations", uid)
"""REALTIME STREAMING"""
# TODO: don't retrieve token before knowing if to notify
token = notification_db.get_token_only(uid)
_trigger_realtime_integrations(uid, token, segments, memory_id)


async def trigger_realtime_audio_bytes(uid: str, sample_rate: int, data: bytearray):
print("trigger_realtime_audio_bytes", uid)
"""REALTIME AUDIO STREAMING"""
_trigger_realtime_audio_bytes(uid, sample_rate, data)

Expand Down Expand Up @@ -364,7 +366,6 @@ def _single(app: App):
return results



def _trigger_realtime_integrations(uid: str, token: str, segments: List[dict], memory_id: str | None) -> dict:
apps: List[App] = get_available_apps(uid)
filtered_apps = [
Expand Down
8 changes: 4 additions & 4 deletions backend/utils/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@
PusherAPI = os.getenv('HOSTED_PUSHER_API_URL')

async def connect_to_trigger_pusher(uid: str, sample_rate: int = 8000, retries: int = 3):
print("connect_to_trigger_pusher")
print("connect_to_trigger_pusher", uid)
for attempt in range(retries):
try:
return await _connect_to_trigger_pusher(uid, sample_rate)
except Exception as error:
print(f'An error occurred: {error}')
print(f'An error occurred: {error}', uid)
if attempt == retries - 1:
raise
backoff_delay = calculate_backoff_with_jitter(attempt)
print(f"Waiting {backoff_delay:.0f}ms before next retry...")
print(f"Waiting {backoff_delay:.0f}ms before next retry...", uid)
await asyncio.sleep(backoff_delay / 1000)

raise Exception(f'Could not open socket: All retry attempts failed.')
raise Exception(f'Could not open socket: All retry attempts failed.', uid)

async def _connect_to_trigger_pusher(uid: str, sample_rate: int = 8000):
try:
Expand Down
2 changes: 2 additions & 0 deletions backend/utils/webhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def day_summary_webhook(uid, summary: str):


async def realtime_transcript_webhook(uid, segments: List[dict]):
print("realtime_transcript_webhook", uid)
toggled = user_webhook_status_db(uid, WebhookType.realtime_transcript)
if toggled:
webhook_url = get_user_webhook_db(uid, WebhookType.realtime_transcript)
Expand Down Expand Up @@ -107,6 +108,7 @@ def get_audio_bytes_webhook_seconds(uid: str):


async def send_audio_bytes_developer_webhook(uid: str, sample_rate: int, data: bytearray):
print("send_audio_bytes_developer_webhook", uid)
# TODO: add a lock, send shorter segments, validate regex.
toggled = user_webhook_status_db(uid, WebhookType.audio_bytes)
if toggled:
Expand Down

0 comments on commit da37feb

Please sign in to comment.