Skip to content

Commit 8891e81

Browse files
committed
Migrate to nextgen-kernels-api message ID encoding and fix execution state sync
- Replace message_cache with extract_src_id/extract_channel utilities
- Remove obsolete test_kernel_message_cache.py
- Fix awareness sync to send execution states to reconnecting clients
- Add cell_msg_ids tracking for re-execution detection
1 parent 3b7dabc commit 8891e81

File tree

3 files changed

+63
-325
lines changed

3 files changed

+63
-325
lines changed

jupyter_server_documents/kernel_client.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import typing as t
1313

1414
from nextgen_kernels_api.services.kernels.client import JupyterServerKernelClient
15+
from nextgen_kernels_api.services.kernels.message_utils import extract_src_id, extract_channel
1516
from traitlets import Instance, Set, Type, default
1617

1718
from jupyter_server_documents.outputs import OutputProcessor
@@ -41,6 +42,9 @@ def _default_output_processor(self) -> OutputProcessor:
4142
def __init__(self, *args, **kwargs):
4243
super().__init__(*args, **kwargs)
4344

45+
# Track last message ID per cell to detect re-executions
46+
self._cell_msg_ids: dict[str, str] = {}
47+
4448
# Register listener for document-related messages
4549
# Combines state updates and outputs to share deserialization logic
4650
self.add_listener(
@@ -92,18 +96,19 @@ async def _handle_document_messages(self, channel_name: str, msg: list[bytes]):
9296
self.log.debug(f"Skipping message that can't be deserialized: {e}")
9397
return
9498

95-
# Extract parent message context for cell ID lookup
99+
# Extract parent message context for cell ID and channel lookup
100+
# Cell ID and channel are now encoded directly in the parent msg_id
96101
parent_msg_id = dmsg.get("parent_header", {}).get("msg_id")
97-
parent_msg_data = self.message_cache.get(parent_msg_id) if parent_msg_id else None
98-
cell_id = parent_msg_data.get("cell_id") if parent_msg_data else None
102+
cell_id = extract_src_id(parent_msg_id) if parent_msg_id else None
103+
parent_channel = extract_channel(parent_msg_id) if parent_msg_id else None
99104

100105
# Dispatch to appropriate handler
101106
msg_type = dmsg.get("msg_type")
102107
match msg_type:
103108
case "kernel_info_reply":
104109
await self._handle_kernel_info_reply(dmsg)
105110
case "status":
106-
await self._handle_status_message(dmsg, parent_msg_data, cell_id)
111+
await self._handle_status_message(dmsg, parent_channel, cell_id)
107112
case "execute_input":
108113
await self._handle_execute_input(dmsg, cell_id)
109114
case "stream" | "display_data" | "execute_result" | "error" | "update_display_data" | "clear_output":
@@ -124,7 +129,7 @@ async def _handle_kernel_info_reply(self, msg: dict):
124129
self.log.warning(f"Failed to update language info for yroom: {e}")
125130

126131
async def _handle_status_message(
127-
self, dmsg: dict, parent_msg_data: dict | None, cell_id: str | None
132+
self, dmsg: dict, parent_channel: str | None, cell_id: str | None
128133
):
129134
"""Update kernel and cell execution states from status messages.
130135
@@ -135,20 +140,14 @@ async def _handle_status_message(
135140
execution_state = content.get("execution_state")
136141

137142
for yroom in self._yrooms:
138-
awareness = yroom.get_awareness()
139-
if awareness is None:
140-
continue
141-
142143
# Update document-level kernel status if this is a top-level status message
143-
if parent_msg_data and parent_msg_data.get("channel") == "shell":
144-
awareness.set_local_state_field(
145-
"kernel", {"execution_state": execution_state}
146-
)
144+
# (i.e., parent message came from shell channel)
145+
if parent_channel == "shell":
146+
yroom.set_kernel_execution_state(execution_state)
147147

148148
# Update cell execution state for persistence and awareness
149149
if cell_id:
150150
yroom.set_cell_execution_state(cell_id, execution_state)
151-
yroom.set_cell_awareness_state(cell_id, execution_state)
152151
break
153152

154153
async def _handle_execute_input(self, dmsg: dict, cell_id: str | None):
@@ -205,10 +204,14 @@ def handle_incoming_message(self, channel_name: str, msg: list[bytes]):
205204

206205
if cell_id:
207206
# Clear outputs if this is a re-execution of the same cell
208-
existing = self.message_cache.get(cell_id=cell_id)
209-
if existing and existing["msg_id"] != msg_id:
207+
# (different msg_id for the same cell_id)
208+
last_msg_id = self._cell_msg_ids.get(cell_id)
209+
if last_msg_id and last_msg_id != msg_id:
210210
asyncio.create_task(self.output_processor.clear_cell_outputs(cell_id))
211211

212+
# Track this message ID for the cell
213+
self._cell_msg_ids[cell_id] = msg_id
214+
212215
# Set awareness state immediately for queued cells
213216
if msg_type == "execute_request" and channel_name == "shell":
214217
for yroom in self._yrooms:

jupyter_server_documents/rooms/yroom.py

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -405,38 +405,41 @@ def get_awareness(self, on_reset: Callable[[pycrdt.Awareness], Any] | None = Non
405405
self._on_reset_callbacks['awareness'].append(on_reset)
406406
return self._awareness
407407

408-
def get_cell_execution_states(self) -> dict:
409-
"""
410-
Returns the persistent cell execution states for this room.
411-
These states survive client disconnections but are not saved to disk.
412-
"""
413-
if not hasattr(self, '_cell_execution_states'):
414-
self._cell_execution_states: dict[str, str] = {}
415-
return self._cell_execution_states
416-
417408
def set_cell_execution_state(self, cell_id: str, execution_state: str) -> None:
418409
"""
419-
Sets the execution state for a specific cell.
420-
This state persists across client disconnections.
410+
Sets the execution state for a specific cell in the awareness system.
411+
This provides real-time updates to all connected clients and persists
412+
while the server is running (survives client reconnections).
421413
"""
422-
if not hasattr(self, '_cell_execution_states'):
423-
self._cell_execution_states = {}
424-
self._cell_execution_states[cell_id] = execution_state
414+
awareness = self.get_awareness()
415+
if awareness is None:
416+
return
417+
418+
local_state = awareness.get_local_state()
419+
if local_state is not None:
420+
cell_states = local_state.get("cell_execution_states", {})
421+
else:
422+
cell_states = {}
423+
424+
cell_states[cell_id] = execution_state
425+
awareness.set_local_state_field("cell_execution_states", cell_states)
425426

426427
def set_cell_awareness_state(self, cell_id: str, execution_state: str) -> None:
427428
"""
428-
Sets the execution state for a specific cell in the awareness system.
429+
Alias for set_cell_execution_state for backward compatibility.
430+
"""
431+
self.set_cell_execution_state(cell_id, execution_state)
432+
433+
def set_kernel_execution_state(self, execution_state: str) -> None:
434+
"""
435+
Sets the kernel execution state in awareness.
429436
This provides real-time updates to all connected clients.
430437
"""
431438
awareness = self.get_awareness()
432439
if awareness is not None:
433-
local_state = awareness.get_local_state()
434-
if local_state is not None:
435-
cell_states = local_state.get("cell_execution_states", {})
436-
else:
437-
cell_states = {}
438-
cell_states[cell_id] = execution_state
439-
awareness.set_local_state_field("cell_execution_states", cell_states)
440+
awareness.set_local_state_field(
441+
"kernel", {"execution_state": execution_state}
442+
)
440443

441444
def add_message(self, client_id: str, message: bytes) -> None:
442445
"""
@@ -540,6 +543,7 @@ def handle_sync_step1(self, client_id: str, message: bytes) -> None:
540543
- Computing a SyncStep2 reply,
541544
- Sending the reply to the client over WS, and
542545
- Sending a new SyncStep1 message immediately after.
546+
- Sending awareness state to the new client.
543547
"""
544548
# Mark client as desynced
545549
new_client = self.clients.get(client_id)
@@ -586,6 +590,22 @@ def handle_sync_step1(self, client_id: str, message: bytes) -> None:
586590
)
587591
self.log.exception(e)
588592

593+
# Send current awareness state to the new client
594+
try:
595+
# Get all awareness client IDs and broadcast the current state
596+
all_client_ids = list(self._awareness._states.keys())
597+
if all_client_ids:
598+
awareness_update = self._awareness.encode_awareness_update(all_client_ids)
599+
awareness_message = pycrdt.create_awareness_message(awareness_update)
600+
assert isinstance(new_client.websocket, WebSocketHandler)
601+
new_client.websocket.write_message(awareness_message, binary=True)
602+
except Exception as e:
603+
self.log.error(
604+
f"An exception occurred when sending awareness to "
605+
f"newly-synced client '{new_client.id}':"
606+
)
607+
self.log.exception(e)
608+
589609

590610
def handle_sync_step2(self, client_id: str, message: bytes) -> None:
591611
"""
@@ -791,8 +811,8 @@ def _on_awareness_update(self, type: str, changes: tuple[dict[str, Any], Any]) -
791811
Arguments:
792812
type: The change type.
793813
changes: The awareness changes.
794-
"""
795-
814+
"""
815+
796816
self.log.debug(f"awareness update, type={type}, changes={changes}, changes[1]={changes[1]}, meta={self._awareness.meta}, ydoc.clientid={self._ydoc.client_id}, roomId={self.room_id}")
797817
updated_clients = [v for value in changes[0].values() for v in value]
798818
self.log.debug(f"awareness update, updated_clients={updated_clients}")

0 commit comments

Comments
 (0)