2 | 2 |
3 | 3 | import json
4 | 4 | import logging
5 | | -from typing import Any, Optional, cast
| 5 | +from typing import Any, Optional, Union, cast
6 | 6 | from uuid import UUID, uuid4
7 | 7 |
8 | 8 | import litellm
| 9 | +from channels.db import database_sync_to_async
9 | 10 | from django.conf import settings
| 11 | +from django.db import transaction
10 | 12 | from langchain_core.language_models import LanguageModelLike
11 | 13 | from langchain_core.messages import (
12 | 14 |     AIMessage,
35 | 37 | from pydantic import BaseModel
36 | 38 | from typing_extensions import TypedDict
37 | 39 |
| 40 | +from ai_chatbots.models import DjangoCheckpoint, TutorBotOutput, UserChatSession
| 41 | +from main.utils import now_in_utc
| 42 | +
38 | 43 | log = logging.getLogger(__name__)
39 | 44 |
40 | 45 |
@@ -586,3 +591,210 @@ def on_llm_end(
586 | 591 |         super().on_llm_end(
587 | 592 |             response, run_id=run_id, parent_run_id=parent_run_id, **kwargs
588 | 593 |         )
| 594 | +
| 595 | +
| 596 | +@database_sync_to_async
| 597 | +def query_tutorbot_output(thread_id: str) -> Optional[TutorBotOutput]:
| 598 | +    """Return the latest TutorBotOutput for a given thread_id"""
| 599 | +    return TutorBotOutput.objects.filter(thread_id=thread_id).last()
| 600 | +
| 601 | +
| 602 | +@database_sync_to_async
| 603 | +def create_tutorbot_output_and_checkpoints(
| 604 | +    thread_id: str, chat_json: Union[str, dict], edx_module_id: Optional[str]
| 605 | +) -> tuple[TutorBotOutput, list[DjangoCheckpoint]]:
| 606 | +    """Atomically create both TutorBotOutput and DjangoCheckpoint objects"""
| 607 | +    with transaction.atomic():
| 608 | +        # Get the previous TutorBotOutput to compare messages
| 609 | +        previous_output = (
| 610 | +            TutorBotOutput.objects.filter(thread_id=thread_id).order_by("-id").first()
| 611 | +        )
| 612 | +        previous_chat_json = previous_output.chat_json if previous_output else None
| 613 | +
| 614 | +        # Create TutorBotOutput
| 615 | +        tutorbot_output = TutorBotOutput.objects.create(
| 616 | +            thread_id=thread_id,
| 617 | +            chat_json=chat_json,
| 618 | +            edx_module_id=edx_module_id or "",
| 619 | +        )
| 620 | +
| 621 | +        checkpoints = create_tutor_checkpoints(thread_id, chat_json, previous_chat_json)
| 622 | +
| 623 | +        return tutorbot_output, checkpoints
| 624 | +
| 625 | +
| 626 | +def _should_create_checkpoint(msg: dict) -> bool:
| 627 | +    """Determine if a message should have a checkpoint created for it."""
| 628 | +    # Skip ToolMessage type or tool_calls
| 629 | +    return not (msg.get("type") == "ToolMessage" or msg.get("tool_calls"))
| 630 | +
| 631 | +
| 632 | +def _identify_new_messages(
| 633 | +    filtered_messages: list[dict], previous_chat_json: Optional[Union[str, dict]]
| 634 | +) -> list[dict]:
| 635 | +    """Identify which messages are new by comparing with previous chat data."""
| 636 | +    if not previous_chat_json:
| 637 | +        return filtered_messages
| 638 | +
| 639 | +    previous_chat_data = (
| 640 | +        json.loads(previous_chat_json)
| 641 | +        if isinstance(previous_chat_json, str)
| 642 | +        else previous_chat_json
| 643 | +    )
| 644 | +    previous_messages = previous_chat_data.get("chat_history", [])
| 645 | +
| 646 | +    # Get set of existing message IDs from previous chat
| 647 | +    existing_message_ids = {
| 648 | +        msg.get("id")
| 649 | +        for msg in previous_messages
| 650 | +        if _should_create_checkpoint(msg) and msg.get("id")
| 651 | +    }
| 652 | +
| 653 | +    # Find messages with IDs that don't exist in previous chat
| 654 | +    return [
| 655 | +        msg for msg in filtered_messages if msg.get("id") not in existing_message_ids
| 656 | +    ]
| 657 | +
| 658 | +
| 659 | +def _create_langchain_message(message: dict) -> dict:
| 660 | +    """Create a message in LangChain format."""
| 661 | +    message_id = str(uuid4())
| 662 | +    return {
| 663 | +        "id": ["langchain", "schema", "messages", message["type"]],
| 664 | +        "lc": 1,
| 665 | +        "type": "constructor",
| 666 | +        "kwargs": {
| 667 | +            "id": message_id,
| 668 | +            "type": message["type"].lower().replace("message", ""),
| 669 | +            "content": message["content"],
| 670 | +        },
| 671 | +    }
| 672 | +
| 673 | +
| 674 | +def _create_checkpoint_data(checkpoint_id: str, step: int, chat_data: dict) -> dict:
| 675 | +    """Create the checkpoint data structure."""
| 676 | +    return {
| 677 | +        "v": 4,
| 678 | +        "id": checkpoint_id,
| 679 | +        "ts": now_in_utc().isoformat(),
| 680 | +        "pending_sends": [],
| 681 | +        "versions_seen": {
| 682 | +            "__input__": {},
| 683 | +            "__start__": {"__start__": step + 1} if step >= 0 else {},
| 684 | +        },
| 685 | +        "channel_values": {
| 686 | +            "messages": chat_data.get("chat_history", []),
| 687 | +            # Preserve tutor-specific data
| 688 | +            "intent_history": chat_data.get("intent_history"),
| 689 | +            "assessment_history": chat_data.get("assessment_history"),
| 690 | +            # Include metadata for reference
| 691 | +            "tutor_metadata": chat_data.get("metadata", {}),
| 692 | +            # Add other channel values that might be needed
| 693 | +            "branch:to:pre_model_hook": None,
| 694 | +        },
| 695 | + "channel_versions": {"messages": len(chat_data.get("messages", []))}, |
| 696 | +    }
| 697 | +
| 698 | +
| 699 | +def _create_checkpoint_metadata(
| 700 | +    tutor_meta: dict, message: dict, step: int, thread_id: str
| 701 | +) -> dict:
| 702 | +    """Create metadata for the checkpoint based on message type."""
| 703 | +    source = (
| 704 | +        "input" if message.get("kwargs", {}).get("type") == "human" else "loop"
| 705 | +    )
| 706 | +    writes = {"__start__": {"messages": [message], **tutor_meta}}
| 707 | +
| 708 | +    return {
| 709 | +        "step": step,
| 710 | +        "source": source,
| 711 | +        "writes": writes,
| 712 | +        "parents": {},
| 713 | +        "thread_id": thread_id,
| 714 | +    }
| 715 | +
| 716 | +
| 717 | +def create_tutor_checkpoints(
| 718 | +    thread_id: str,
| 719 | +    chat_json: Union[str, dict],
| 720 | +    previous_chat_json: Optional[Union[str, dict]] = None,
| 721 | +) -> list[DjangoCheckpoint]:
| 722 | +    """Create DjangoCheckpoint records from tutor chat data (synchronous)"""
| 723 | +    # Get the associated session
| 724 | +    try:
| 725 | +        session = UserChatSession.objects.get(thread_id=thread_id)
| 726 | +    except UserChatSession.DoesNotExist:
| 727 | +        return []
| 728 | +
| 729 | +    # Parse and validate chat data
| 730 | +    chat_data = json.loads(chat_json) if isinstance(chat_json, str) else chat_json
| 731 | +    messages = chat_data.get("chat_history", [])
| 732 | +    if not messages:
| 733 | +        return []
| 734 | +
| 735 | +    # Filter out ToolMessage types and AI messages with tool_calls
| 736 | +    filtered_messages = [msg for msg in messages if _should_create_checkpoint(msg)]
| 737 | +    if not filtered_messages:
| 738 | +        return []
| 739 | +
| 740 | +    # Get previous checkpoint if any
| 741 | +    latest_checkpoint = (
| 742 | +        DjangoCheckpoint.objects.filter(
| 743 | +            thread_id=thread_id,
| 744 | +            checkpoint__channel_values__tutor_metadata__isnull=False,
| 745 | +        )
| 746 | +        .only("checkpoint_id")
| 747 | +        .order_by("-id")
| 748 | +        .first()
| 749 | +    )
| 750 | +    parent_checkpoint_id = (
| 751 | +        latest_checkpoint.checkpoint_id if latest_checkpoint else None
| 752 | +    )
| 753 | +
| 754 | +    # Determine new messages by comparing message IDs
| 755 | +    new_messages = _identify_new_messages(filtered_messages, previous_chat_json)
| 756 | +    if not new_messages:
| 757 | +        return []  # No new messages to checkpoint
| 758 | +
| 759 | +    # Calculate starting step based on length of previous chat history
| 760 | +    previous_messages = (
| 761 | +        json.loads(previous_chat_json)
| 762 | +        if isinstance(previous_chat_json, str)
| 763 | +        else (previous_chat_json or {})
| 764 | +    ).get("chat_history", [])
| 765 | +    step = len(previous_messages)
| 766 | +    checkpoints_created = []
| 767 | +
| 768 | +    # Create checkpoints only for the NEW messages
| 769 | +    for message in new_messages:
| 770 | +        checkpoint_id = str(uuid4())
| 771 | +
| 772 | +        # Create checkpoint data structure
| 773 | +        checkpoint_data = _create_checkpoint_data(checkpoint_id, step, chat_data)
| 774 | +
| 775 | +        # Convert the message to LangChain serialized format
| 776 | +        langchain_message = _create_langchain_message(message)
| 777 | +
| 778 | +        # Create metadata for this step
| 779 | +        metadata = _create_checkpoint_metadata(
| 780 | +            chat_data.get("metadata", {}), langchain_message, step, thread_id
| 781 | +        )
| 782 | +
| 783 | +        # Create and save the checkpoint
| 784 | +        checkpoint, _ = DjangoCheckpoint.objects.update_or_create(
| 785 | +            session=session,
| 786 | +            thread_id=thread_id,
| 787 | +            checkpoint_id=checkpoint_id,
| 788 | +            defaults={
| 789 | +                "checkpoint_ns": "",
| 790 | +                "parent_checkpoint_id": parent_checkpoint_id,
| 791 | +                "type": "msgpack",
| 792 | +                "checkpoint": checkpoint_data,
| 793 | +                "metadata": metadata,
| 794 | +            },
| 795 | +        )
| 796 | +        parent_checkpoint_id = checkpoint_id
| 797 | +        checkpoints_created.append(checkpoint)
| 798 | +        step += 1
| 799 | +
| 800 | +    return checkpoints_created
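A minimal usage sketch of the new async entry point. The module path (ai_chatbots.api), the calling coroutine name, and the payload shape below are assumptions for illustration only; the helper name, its signature, and the fact that it is wrapped with database_sync_to_async come from the diff above.

import json

from ai_chatbots.api import create_tutorbot_output_and_checkpoints  # assumed module path


async def persist_tutor_turn(thread_id: str, chat_payload: dict) -> None:
    # database_sync_to_async makes the helper awaitable, so it can be called
    # directly from a Channels consumer or any other async context.
    output, checkpoints = await create_tutorbot_output_and_checkpoints(
        thread_id=thread_id,
        chat_json=json.dumps(chat_payload),  # a dict is also accepted per the signature
        edx_module_id=None,
    )
    # `output` is the newly created TutorBotOutput row; `checkpoints` are the
    # DjangoCheckpoint rows created for the new non-tool messages.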