Skip to content

Commit bb75d6a

Browse files
authored
Standardize checkpoints for all bots (#301)
1 parent 3891bc8 commit bb75d6a

File tree

14 files changed

+1141
-53
lines changed

14 files changed

+1141
-53
lines changed

ai_chatbots/admin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class UserChatSessionAdmin(admin.ModelAdmin):
1313
list_filter = ("agent", "user")
1414
search_fields = ("title", "thread_id")
1515
ordering = ("-updated_on",)
16-
readonly_fields = ("agent", "thread_id", "created_on", "updated_on")
16+
readonly_fields = ("agent", "thread_id", "created_on", "updated_on", "user")
1717

1818

1919
@admin.register(LLMModel)

ai_chatbots/api.py

Lines changed: 213 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22

33
import json
44
import logging
5-
from typing import Any, Optional, cast
5+
from typing import Any, Optional, Union, cast
66
from uuid import UUID, uuid4
77

88
import litellm
9+
from channels.db import database_sync_to_async
910
from django.conf import settings
11+
from django.db import transaction
1012
from langchain_core.language_models import LanguageModelLike
1113
from langchain_core.messages import (
1214
AIMessage,
@@ -35,6 +37,9 @@
3537
from pydantic import BaseModel
3638
from typing_extensions import TypedDict
3739

40+
from ai_chatbots.models import DjangoCheckpoint, TutorBotOutput, UserChatSession
41+
from main.utils import now_in_utc
42+
3843
log = logging.getLogger(__name__)
3944

4045

@@ -586,3 +591,210 @@ def on_llm_end(
586591
super().on_llm_end(
587592
response, run_id=run_id, parent_run_id=parent_run_id, **kwargs
588593
)
594+
595+
596+
@database_sync_to_async
def query_tutorbot_output(thread_id: str) -> Optional[TutorBotOutput]:
    """Look up the last TutorBotOutput stored for ``thread_id``.

    Args:
        thread_id: Identifier of the chat thread to search for.

    Returns:
        The last matching TutorBotOutput (as given by ``QuerySet.last()``),
        or ``None`` if the thread has no outputs.
    """
    outputs_for_thread = TutorBotOutput.objects.filter(thread_id=thread_id)
    return outputs_for_thread.last()
600+
601+
602+
@database_sync_to_async
def create_tutorbot_output_and_checkpoints(
    thread_id: str, chat_json: Union[str, dict], edx_module_id: Optional[str]
) -> tuple[TutorBotOutput, list[DjangoCheckpoint]]:
    """Persist a TutorBotOutput and its checkpoints in a single transaction.

    Args:
        thread_id: Identifier of the chat thread being recorded.
        chat_json: Tutor chat payload, either a JSON string or a dict.
        edx_module_id: Optional edX module id; stored as "" when falsy.

    Returns:
        A tuple of the newly created TutorBotOutput and the list of
        DjangoCheckpoint objects returned by ``create_tutor_checkpoints``.
    """
    with transaction.atomic():
        # The most recent earlier output (highest id) is the baseline used to
        # work out which messages are new. It must be read before the new row
        # is inserted.
        outputs_for_thread = TutorBotOutput.objects.filter(thread_id=thread_id)
        latest_prior = outputs_for_thread.order_by("-id").first()
        prior_chat_json = None
        if latest_prior:
            prior_chat_json = latest_prior.chat_json

        new_output = TutorBotOutput.objects.create(
            thread_id=thread_id,
            chat_json=chat_json,
            edx_module_id=edx_module_id or "",
        )

        new_checkpoints = create_tutor_checkpoints(
            thread_id, chat_json, prior_chat_json
        )

        return new_output, new_checkpoints
624+
625+
626+
def _should_create_checkpoint(msg: dict) -> bool:
627+
"""Determine if a message should have a checkpoint created for it."""
628+
# Skip ToolMessage type or tool_calls
629+
return not (msg.get("type") == "ToolMessage" or msg.get("tool_calls"))
630+
631+
632+
def _identify_new_messages(
633+
filtered_messages: list[dict], previous_chat_json: Optional[Union[str, dict]]
634+
) -> list[dict]:
635+
"""Identify which messages are new by comparing with previous chat data."""
636+
if not previous_chat_json:
637+
return filtered_messages
638+
639+
previous_chat_data = (
640+
json.loads(previous_chat_json)
641+
if isinstance(previous_chat_json, str)
642+
else previous_chat_json
643+
)
644+
previous_messages = previous_chat_data.get("chat_history", [])
645+
646+
# Get set of existing message IDs from previous chat
647+
existing_message_ids = {
648+
msg.get("id")
649+
for msg in previous_messages
650+
if _should_create_checkpoint(msg) and msg.get("id")
651+
}
652+
653+
# Find messages with IDs that don't exist in previous chat
654+
return [
655+
msg for msg in filtered_messages if msg.get("id") not in existing_message_ids
656+
]
657+
658+
659+
def _create_langchain_message(message: dict) -> dict:
660+
"""Create a message in LangChain format."""
661+
message_id = str(uuid4())
662+
return {
663+
"id": ["langchain", "schema", "messages", message["type"]],
664+
"lc": 1,
665+
"type": "constructor",
666+
"kwargs": {
667+
"id": message_id,
668+
"type": message["type"].lower().replace("message", ""),
669+
"content": message["content"],
670+
},
671+
}
672+
673+
674+
def _create_checkpoint_data(checkpoint_id: str, step: int, chat_data: dict) -> dict:
    """Build the checkpoint payload stored on a DjangoCheckpoint.

    Args:
        checkpoint_id: Identifier to record as the checkpoint's "id".
        step: Step index of the checkpoint; when non-negative,
            ``versions_seen["__start__"]`` records ``step + 1``.
        chat_data: Parsed tutor chat payload. Reads "chat_history",
            "intent_history", "assessment_history" and "metadata".

    Returns:
        A dict with the keys "v", "id", "ts", "pending_sends",
        "versions_seen", "channel_values" and "channel_versions".
    """
    chat_history = chat_data.get("chat_history", [])
    # The original counted chat_data["messages"], a key that nothing else in
    # this module reads; the message list lives under "chat_history". Fall
    # back to that list so the recorded version is not stuck at 0, while
    # still honouring an explicit "messages" entry if one is ever supplied.
    versioned_messages = chat_data.get("messages", chat_history)
    return {
        "v": 4,
        "id": checkpoint_id,
        "ts": now_in_utc().isoformat(),
        "pending_sends": [],
        "versions_seen": {
            "__input__": {},
            "__start__": {"__start__": step + 1} if step >= 0 else {},
        },
        "channel_values": {
            "messages": chat_history,
            # Preserve tutor-specific data
            "intent_history": chat_data.get("intent_history"),
            "assessment_history": chat_data.get("assessment_history"),
            # Include metadata for reference
            "tutor_metadata": chat_data.get("metadata", {}),
            # Add other channel values that might be needed
            "branch:to:pre_model_hook": None,
        },
        "channel_versions": {"messages": len(versioned_messages)},
    }
697+
698+
699+
def _create_checkpoint_metadata(
700+
tutor_meta: dict, message: dict, step: int, thread_id: str
701+
) -> dict:
702+
"""Create metadata for the checkpoint based on message type."""
703+
source = (
704+
"input" if message.get("kwargs", {}).get("type") == "HumanMessage" else "loop"
705+
)
706+
writes = {"__start__": {"messages": [message], **tutor_meta}}
707+
708+
return {
709+
"step": step,
710+
"source": source,
711+
"writes": writes,
712+
"parents": {},
713+
"thread_id": thread_id,
714+
}
715+
716+
717+
def create_tutor_checkpoints(
    thread_id: str,
    chat_json: Union[str, dict],
    previous_chat_json: Optional[Union[str, dict]] = None,
) -> list[DjangoCheckpoint]:
    """Create DjangoCheckpoint records from tutor chat data (synchronous).

    One checkpoint is written per new, checkpoint-eligible message in
    ``chat_json``; each checkpoint is chained to the previous one through
    ``parent_checkpoint_id``.

    Args:
        thread_id: Thread whose UserChatSession the checkpoints belong to.
        chat_json: Current tutor chat payload, as a JSON string or a dict.
        previous_chat_json: Prior chat payload (JSON string or dict) used to
            detect new messages and to derive the starting step; ``None``
            when there is no prior chat.

    Returns:
        The checkpoints created, in creation order. The list is empty when no
        session exists for ``thread_id``, the chat history is empty, no
        message is eligible, or no message is new.
    """
    # Get the associated session
    try:
        session = UserChatSession.objects.get(thread_id=thread_id)
    except UserChatSession.DoesNotExist:
        return []

    # Parse and validate chat data
    chat_data = json.loads(chat_json) if isinstance(chat_json, str) else chat_json
    messages = chat_data.get("chat_history", [])
    if not messages:
        return []

    # Filter out ToolMessage types and AI messages with tool_calls
    filtered_messages = [msg for msg in messages if _should_create_checkpoint(msg)]
    if not filtered_messages:
        return []

    # Get previous checkpoint if any
    latest_checkpoint = (
        DjangoCheckpoint.objects.filter(
            thread_id=thread_id,
            checkpoint__channel_values__tutor_metadata__isnull=False,
        )
        .only("checkpoint_id")
        .order_by("-id")
        .first()
    )
    parent_checkpoint_id = (
        latest_checkpoint.checkpoint_id if latest_checkpoint else None
    )

    # Determine new messages by comparing message IDs
    new_messages = _identify_new_messages(filtered_messages, previous_chat_json)
    if not new_messages:
        return []  # No new messages to checkpoint

    # Calculate starting step based on length of previous chat history.
    # previous_chat_json may already be a dict; json.loads() accepts only
    # str/bytes and would raise TypeError, so decode only when it is a string.
    if not previous_chat_json:
        previous_messages = []
    elif isinstance(previous_chat_json, str):
        previous_messages = json.loads(previous_chat_json).get("chat_history", [])
    else:
        previous_messages = previous_chat_json.get("chat_history", [])
    step = len(previous_messages)
    checkpoints_created = []

    # Create checkpoints only for the NEW messages
    for message in new_messages:
        checkpoint_id = str(uuid4())

        # Create checkpoint data structure
        checkpoint_data = _create_checkpoint_data(checkpoint_id, step, chat_data)

        # Create message with LangChain format
        langchain_message = _create_langchain_message(message)

        # Create metadata for this step
        metadata = _create_checkpoint_metadata(
            chat_data.get("metadata", {}), langchain_message, step, thread_id
        )

        # Create and save the checkpoint
        checkpoint, _ = DjangoCheckpoint.objects.update_or_create(
            session=session,
            thread_id=thread_id,
            checkpoint_id=checkpoint_id,
            defaults={
                "checkpoint_ns": "",
                "parent_checkpoint_id": parent_checkpoint_id,
                "type": "msgpack",
                "checkpoint": checkpoint_data,
                "metadata": metadata,
            },
        )
        # The next checkpoint in this batch chains to the one just written.
        parent_checkpoint_id = checkpoint_id
        checkpoints_created.append(checkpoint)
        step += 1

    return checkpoints_created

0 commit comments

Comments
 (0)