diff --git a/.gitignore b/.gitignore index 6da6dff..cd1ddda 100644 --- a/.gitignore +++ b/.gitignore @@ -128,3 +128,5 @@ dmypy.json # For local testing playground/ + +.vscode/ diff --git a/jupyter_ai_router/extension.py b/jupyter_ai_router/extension.py index d666f38..f88711f 100644 --- a/jupyter_ai_router/extension.py +++ b/jupyter_ai_router/extension.py @@ -1,8 +1,11 @@ from __future__ import annotations +import asyncio +import json from typing import TYPE_CHECKING import time from jupyter_events import EventLogger from jupyter_server.extension.application import ExtensionApp +from jupyter_ydoc.ybasedoc import YBaseDoc from jupyter_ai_router.handlers import RouteHandler @@ -41,6 +44,18 @@ class RouterExtension(ExtensionApp): router: MessageRouter + @property + def event_loop(self) -> asyncio.AbstractEventLoop: + """ + Returns a reference to the asyncio event loop. + """ + return asyncio.get_event_loop_policy().get_event_loop() + + @property + def fileid_manager(self): + return self.serverapp.web_app.settings["file_id_manager"] + + def initialize_settings(self): """Initialize router settings and event listeners.""" start = time.time() @@ -59,10 +74,32 @@ def initialize_settings(self): self.event_logger.add_listener( schema_id=JUPYTER_COLLABORATION_EVENTS_URI, listener=self._on_chat_event ) - + elapsed = time.time() - start self.log.info(f"Initialized RouterExtension in {elapsed:.2f}s") + def _get_global_awareness(self): + # TODO: make this compatible with jcollab + jcollab_api = self.serverapp.web_app.settings["jupyter_server_ydoc"] + yroom_manager = jcollab_api.yroom_manager + yroom = yroom_manager.get_room("JupyterLab:globalAwareness") + return yroom.get_awareness() + + async def _room_id_from_path(self, path: str) -> str | None: + """Returns room_id from document path""" + # TODO: Make this compatible with jcollab + yroom_manager = self.serverapp.web_app.settings["yroom_manager"] + for room_id in yroom_manager._rooms_by_id: + if room_id == 
"JupyterLab:globalAwareness": + continue + ydoc = await self._get_doc(room_id) + state = ydoc.awareness.get_local_state() + file_id = state["file_id"] + ydoc_path = self.fileid_manager.get_path(file_id) + if ydoc_path == path: + print(f"Found match in path {path}") + return room_id + async def _on_chat_event( self, logger: EventLogger, schema_id: str, data: dict ) -> None: @@ -87,6 +124,84 @@ async def _on_chat_event( # Connect chat to router self.router.connect_chat(room_id, ychat) + async def _on_notebook_event( + self, logger: EventLogger, schema_id: str, data: dict + ) -> None: + """Handle notebook room events and connect new chats to router.""" + # Only handle notebook room initialization events + if not ( + data["room"].startswith("json:notebook:") + and data["action"] == "initialize" + and data["msg"] == "Room initialized" + ): + return + + room_id = data["room"] + self.log.info(f"New notebook room detected: {room_id}") + + # Get YDoc document for the room + ydoc = await self._get_doc(room_id) + if ydoc is None: + self.log.error(f"Failed to get YDoc for room {room_id}") + return + + # Connect notebook to router + self.router.connect_notebook(room_id, ydoc) + + async def _get_doc(self, room_id: str) -> YBaseDoc | None: + """ + Get YDoc instance for a room ID. + + Dispatches to either `_get_doc_jcollab()` or `_get_doc_jsd()` based on + whether `jupyter_server_documents` is installed. + """ + + if JSD_PRESENT: + return await self._get_doc_jsd(room_id) + else: + return await self._get_doc_jcollab(room_id) + + async def _get_doc_jcollab(self, room_id: str) -> YBaseDoc | None: + """ + Method used to retrieve the `YDoc` instance for a given room when + `jupyter_server_documents` **is not** installed. 
async def _get_doc_jsd(self, room_id: str) -> YBaseDoc | None:
    """
    Fetch the `YDoc` of a room through the `jupyter_server_documents` room
    manager.

    A reset callback is passed along with the request; when it fires, the
    replacement document is forwarded to `router._on_notebook_reset()`
    together with `room_id`.

    Returns `None` when no server app is attached or when the lookup raises
    (the error is logged).
    """
    if not self.serverapp:
        return None

    try:
        settings = self.serverapp.web_app.settings
        room = settings["jupyter_server_ydoc"].yroom_manager.get_room(room_id)

        def forward_reset(fresh_ydoc: YBaseDoc):
            # Relay the replacement document to the router.
            self.router._on_notebook_reset(room_id, fresh_ydoc)

        return await room.get_jupyter_ydoc(on_reset=forward_reset)
    except Exception as e:
        self.log.error(f"Error getting ydoc for {room_id}: {e}")
        return None
class RoomTracker:
    """
    Tracks a room's observation state and metadata.

    Holds the room's document, the timestamps of the last edit and the last
    trigger, the set of usernames currently interested in the room, and the
    handles returned when subscribing to the document's awareness and cells.
    """

    def __init__(self, room_id: str, ydoc: "YBaseDoc"):
        self.room_id = room_id
        self.ydoc = ydoc
        self.last_edit_time = 0.0
        self.last_trigger_time = 0.0
        self.active_users: set[str] = set()
        # Handles returned by `observe()` / `observe_deep()`; `None` means
        # "not currently subscribed".
        self.subscribers = {
            "awareness": None,
            "notebook": None,
        }

    def start_observing(self, awareness_cb, notebook_cb):
        """
        Start observing this room's awareness and notebook changes.

        Any subscription left over from a previous call is released first so
        that calling this method twice does not leak observers.
        """
        self.stop_observing()
        self.subscribers["awareness"] = self.ydoc.awareness.observe(awareness_cb)
        self.subscribers["notebook"] = self.ydoc._ycells.observe_deep(notebook_cb)

    def stop_observing(self):
        """
        Stop observing this room and clean up subscriptions.

        Safe to call repeatedly: handles are forgotten before unsubscribing,
        so a second call is a no-op. The notebook subscription is released
        even if releasing the awareness subscription raises.
        """
        awareness_sub = self.subscribers["awareness"]
        notebook_sub = self.subscribers["notebook"]
        self.subscribers["awareness"] = None
        self.subscribers["notebook"] = None

        # Compare against None rather than relying on truthiness, so that a
        # falsy-but-valid handle is still unsubscribed.
        try:
            if awareness_sub is not None:
                self.ydoc.awareness.unobserve(awareness_sub)
        finally:
            if notebook_sub is not None:
                self.ydoc._ycells.unobserve(notebook_sub)
async def _start_observing_room(self, room_id: str, username: str):
    """
    Start observing a room's document and awareness changes.

    If the room is already tracked, `username` is simply added to its set of
    active users. Otherwise the room's document is fetched, a `RoomTracker`
    is created, awareness and notebook observers are attached, and the
    tracker is stored in `self.rooms`.

    Args:
        room_id: ID of the room to observe.
        username: User on whose behalf the room is observed.
    """
    existing = self.rooms.get(room_id)
    if existing is not None:
        # Already observing, just add user
        existing.active_users.add(username)
        self.log.info(f"Added user {username} to existing room observation {room_id}")
        return

    # Get the document for this room
    ydoc = await self._get_doc(room_id)
    if ydoc is None:
        self.log.error(f"Could not get document for room {room_id}")
        return

    # Re-check after the await: another coroutine may have registered this
    # room while we were suspended. Creating a second tracker here would
    # overwrite the first one in `self.rooms` and leak its observers.
    existing = self.rooms.get(room_id)
    if existing is not None:
        existing.active_users.add(username)
        self.log.info(f"Added user {username} to existing room observation {room_id}")
        return

    # Create room tracker
    room = RoomTracker(room_id, ydoc)
    room.active_users.add(username)

    # Set up observers, binding the room context into each callback
    awareness_cb = partial(self._on_awareness_change, room_id, ydoc)
    notebook_change_cb = partial(self._on_notebook_change, room_id)
    room.start_observing(awareness_cb, notebook_change_cb)

    # Track this room
    self.rooms[room_id] = room

    self.log.info(f"Started observing room {room_id} for user {username}")
@@ -123,6 +257,286 @@ def observe_chat_msg( self.chat_msg_observers[room_id].append(callback) self.log.info("Registered message callback") + + def observe_notebook_activity( + self, username: str, callback: Callable[[Dict], Any] + ) -> int: + """ + Register a callback for notebook activity for a specific user. + Returns observer_id for unregistering. + """ + observer_id = self.observer_counter + self.observer_counter += 1 + + # Create or get user tracker + if username not in self.users: + self.users[username] = UserTracker(username) + + user = self.users[username] + user.observer_ids.append(observer_id) + + # Store observer callback (we still need this for the actual callback) + self._observer_callbacks[observer_id] = { + "username": username, + "callback": callback + } + + # Check user's current notebook and start observing it + self.event_loop.create_task(self._observe_active_notebook(username)) + + self.log.info(f"Registered notebook activity observer {observer_id} for user {username}") + return observer_id + + def unobserve_notebook_activity(self, observer_id: int) -> bool: + """Remove a notebook activity observer by ID.""" + if observer_id not in self._observer_callbacks: + return False + + observer_info = self._observer_callbacks[observer_id] + username = observer_info["username"] + + # Remove from tracking + del self._observer_callbacks[observer_id] + + if username in self.users: + user = self.users[username] + if observer_id in user.observer_ids: + user.observer_ids.remove(observer_id) + + # If user has no more observers, clean up + if not user.observer_ids: + del self.users[username] + + # Check if we can stop observing any rooms + self._cleanup_unused_room_observers() + + self.log.info(f"Unregistered observer {observer_id}") + return True + + async def _observe_active_notebook(self, username: str): + """Finds user's active notebook and starts observing it.""" + + try: + awareness = self._get_global_awareness() + for _, state in awareness.states.items(): + 
def _cleanup_unused_room_observers(self):
    """
    Stop observing rooms that no registered observer cares about any more.

    For each tracked room, keep only the users that are still present in
    `self.users` and still hold at least one observer ID. Rooms left with no
    such user are stopped; the others have their user set narrowed.
    """
    stale_room_ids = []

    for rid, tracker in self.rooms.items():
        still_watching = {
            name
            for name in tracker.active_users
            if name in self.users and self.users[name].observer_ids
        }
        if still_watching:
            tracker.active_users = still_watching
        else:
            stale_room_ids.append(rid)

    # Stopping a room removes it from `self.rooms`, so do it only after the
    # iteration above has finished.
    for rid in stale_room_ids:
        self._stop_observing_room(rid)
+ """ + + _, room = updates + + if isinstance(room, str): + return + + awareness = room.get_awareness() + for _, state in awareness.states.items(): + username = state.get("user", {}).get("username", None) + if not (username and username in self.users): + continue + + active_doc = state.get("current") + + if not (active_doc and active_doc.startswith("notebook:")): + continue + + user = self.users[username] + prev_doc = user.current_document + + # Check if user switched to a different document + if prev_doc != active_doc: + self.log.info( + f"User {username} switched from {prev_doc} to {active_doc}" + ) + + # Update stored current document + user.current_document = active_doc + + self.event_loop.create_task( + self._handle_user_document_switch(username, active_doc, prev_doc) + ) + + async def _handle_user_document_switch( + self, username: str, current_doc: str, prev_doc: str | None + ): + """Handles user switching documents, unobserves current doc room, + and registers new observers for the room that becomes active.""" + try: + # Only handle users we have observers for + if username not in self.users: + return + + path = current_doc.split(":", maxsplit=1)[1] + # Convert document path to room_id (async) + room_id = await self._room_id_from_path(path) + + if room_id: + + # Start observing new room + await self._start_observing_room(room_id, username) + + # Initialize user awareness state for new room if needed + user = self.users[username] + if room_id not in user.room_states: + user.room_states[room_id] = { + "active_cell": None, + "notebook_path": current_doc, + "last_check": 0 + } + self.log.info(f"Initialized tracking for user {username} in room {room_id}") + + # Stop observing old room if needed + if prev_doc: + old_path = prev_doc.split(":", maxsplit=1)[1] + old_room_id = await self._room_id_from_path(old_path) + if old_room_id: + self._maybe_stop_observing_room(old_room_id, username) + + except Exception as e: + self.log.error(f"Error handling document switch 
def _on_notebook_change(self, room_id: str, events):
    """
    Handle notebook document changes and record when the last edit happened.

    Event details are emitted at DEBUG level only. This callback runs for
    every change to the observed cells, so the per-event details are not
    formatted at all unless debug logging is enabled.

    Args:
        room_id: ID of the room whose notebook changed.
        events: Iterable of change events delivered by the deep observer.
    """
    # Local import keeps this block self-contained.
    import logging

    if self.log.isEnabledFor(logging.DEBUG):
        for event in events:
            if isinstance(event, MapEvent):
                self.log.debug("Keys: %s", event.keys)
            else:
                self.log.debug("Change type: %s", event.delta)
            self.log.debug("Target: %s", event.target)
            self.log.debug("Event Path: %s", event.path)

    # Save the timestamp that a change was made indicating notebook has changed
    current_time = time.time()
    if room_id in self.rooms:
        self.rooms[room_id].last_edit_time = current_time
    self.log.debug("Notebook cells changed in %s at %s", room_id, current_time)
def _notify_notebook_activity_observers(
    self, username: str, prev_active_cell: str, notebook_path: str
) -> None:
    """
    Notify all notebook activity observers registered for `username`.

    Each callback is invoked as
    `callback(username, prev_active_cell, notebook_path)`. An exception
    raised by one callback is logged and does not prevent the remaining
    callbacks from running.

    Args:
        username: User whose activity triggered the notification.
        prev_active_cell: ID of the cell the user was on before moving.
        notebook_path: Notebook path reported in the user's awareness state.
    """
    user = self.users.get(username)
    if user is None:
        return

    # Iterate over a snapshot: a callback may unregister itself (or another
    # observer), which mutates `user.observer_ids` and would otherwise make
    # the loop skip the observers that follow.
    for observer_id in list(user.observer_ids):
        entry = self._observer_callbacks.get(observer_id)
        if entry is None:
            continue
        try:
            entry["callback"](username, prev_active_cell, notebook_path)
        except Exception as e:
            self.log.error(f"Notebook activity observer error for {username}: {e}")
def _cleanup_awareness_observers(self) -> None:
    """
    Clean up global and local awareness observers.

    Unsubscribes from the global awareness object if a subscription is
    recorded, empties the notebook activity observer registry and resets the
    observer ID counter. A failure to unsubscribe is logged as a warning and
    leaves the recorded subscription ID untouched.
    """
    # Clean up global awareness observer
    if self._global_awareness_subscriber_id is not None:
        try:
            awareness = self._get_global_awareness()
            awareness.unobserve(self._global_awareness_subscriber_id)
        except Exception as e:
            self.log.warning(f"Failed to clean up global awareness observer: {e}")
        else:
            self._global_awareness_subscriber_id = None
            self.log.debug("Cleaned up global awareness observer")

    # Only the registry has to be emptied here; it is cleared in place so
    # that other references to the same dict see the change too.
    self._observer_callbacks.clear()

    # Reset observer counter
    self.observer_counter = 0
list(self.active_chats.keys()) for room_id in room_ids: self.disconnect_chat(room_id) - # Clear callbacks + # Clear all observer callback lists self.chat_init_observers.clear() self.slash_cmd_observers.clear() self.chat_msg_observers.clear() + self.chat_reset_observers.clear() + self.notebook_reset_observers.clear() self.log.info("MessageRouter cleanup complete") diff --git a/pyproject.toml b/pyproject.toml index 1a1f8e1..0b069e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,8 @@ classifiers = [ dependencies = [ "jupyter_server>=2.4.0,<3", "jupyterlab-chat>=0.17.0", - "jupyter-collaboration>=4.0.0" + "jupyter-collaboration>=4.0.0", + "jupyterlab-notebook-awareness>=0.2.0" ] dynamic = ["version", "description", "authors", "urls", "keywords"]