196 changes: 170 additions & 26 deletions ai_chatbots/chatbots.py
@@ -606,7 +606,7 @@ async def get_completion(
full_response = ""
new_history = []
try:
generator, new_intent_history, new_assessment_history = message_tutor(
result = message_tutor(
self.problem,
self.problem_set,
self.llm,
@@ -618,36 +618,180 @@ async def get_completion(
variant=self.variant,
)

async for chunk in generator:
# the generator yields message chunks for a streaming response
# then finally yields the full response as the last chunk
if (
chunk[0] == "messages"
and chunk[1]
and isinstance(chunk[1][0], AIMessageChunk)
):
full_response += chunk[1][0].content
yield chunk[1][0].content
# Handle A/B testing responses
if isinstance(result, tuple) and len(result) > 0 and isinstance(result[0], dict) and result[0].get("is_ab_test"):
async for response_chunk in self._handle_ab_test_response(result, message):
yield response_chunk
else:
# Normal single response - backward compatibility
generator, new_intent_history, new_assessment_history = result

async for chunk in generator:
# the generator yields message chunks for a streaming response
# then finally yields the full response as the last chunk
if (
chunk[0] == "messages"
and chunk[1]
and isinstance(chunk[1][0], AIMessageChunk)
):
full_response += chunk[1][0].content
yield chunk[1][0].content

elif chunk[0] == "values":
new_history = filter_out_system_messages(chunk[1]["messages"])

metadata = {
"edx_module_id": self.edx_module_id,
"tutor_model": self.model,
"problem_set_title": self.problem_set_title,
"run_readable_id": self.run_readable_id,
}
json_output = tutor_output_to_json(
new_history, new_intent_history, new_assessment_history, metadata
)
await create_tutorbot_output(
self.thread_id, json_output, self.edx_module_id
)

elif chunk[0] == "values":
new_history = filter_out_system_messages(chunk[1]["messages"])
except Exception:
yield '<!-- {"error":{"message":"An error occurred, please try again"}} -->'
log.exception("Error running AI agent")

metadata = {
async def _handle_ab_test_response(self, result, original_message: str) -> AsyncGenerator[str, None]:
"""Handle A/B test responses by collecting both variants and yielding structured data"""
ab_test_data, new_intent_history, new_assessment_history = result

# Collect both responses completely
control_response = ""
treatment_response = ""
control_history = []
treatment_history = []

# Process control variant
control_generator = ab_test_data["responses"][0]["stream"]
async for chunk in control_generator:
if (
chunk[0] == "messages"
and chunk[1]
and isinstance(chunk[1][0], AIMessageChunk)
):
control_response += chunk[1][0].content
elif chunk[0] == "values":
control_history = filter_out_system_messages(chunk[1]["messages"])

# Process treatment variant
treatment_generator = ab_test_data["responses"][1]["stream"]
async for chunk in treatment_generator:
if (
chunk[0] == "messages"
and chunk[1]
and isinstance(chunk[1][0], AIMessageChunk)
):
treatment_response += chunk[1][0].content
elif chunk[0] == "values":
treatment_history = filter_out_system_messages(chunk[1]["messages"])

# Convert message objects to serializable format
def serialize_messages(messages):
"""Convert message objects to serializable format"""
serialized = []
for msg in messages:
if hasattr(msg, 'content'):
serialized.append({
"type": msg.__class__.__name__,
"content": msg.content
})
else:
serialized.append(str(msg))
return serialized

def serialize_intent_history(intent_history):
"""Convert intent history to serializable format"""
serialized = []
for intent_data in intent_history:
if isinstance(intent_data, dict):
# If it's already a dict, make sure all values are serializable
serialized_intent = {}
for key, value in intent_data.items():
if hasattr(value, '__dict__'):
serialized_intent[key] = str(value)
else:
serialized_intent[key] = value
serialized.append(serialized_intent)
else:
serialized.append(str(intent_data))
return serialized

# Create A/B test response structure for frontend
ab_response = {
"type": "ab_test_response",
"control": {
"content": control_response,
"variant": "control"
},
"treatment": {
"content": treatment_response,
"variant": "treatment"
},
"metadata": {
"test_name": "tutor_problem", # Could be extracted from ab_test_data if needed
"thread_id": self.thread_id,
"original_message": original_message,
"edx_module_id": self.edx_module_id,
"tutor_model": self.model,
"problem_set_title": self.problem_set_title,
"run_readable_id": self.run_readable_id,
}
json_output = tutor_output_to_json(
new_history, new_intent_history, new_assessment_history, metadata
)
await create_tutorbot_output(
self.thread_id, json_output, self.edx_module_id
)

except Exception:
yield '<!-- {"error":{"message":"An error occurred, please try again"}} -->'
log.exception("Error running AI agent")
},
# Store histories for when the user makes a choice (serialized)
"_control_history": serialize_messages(control_history),
"_treatment_history": serialize_messages(treatment_history),
"_intent_history": serialize_intent_history(new_intent_history),
"_assessment_history": serialize_messages(new_assessment_history),
}

# Yield the structured A/B test response as JSON
yield f'<!-- {json.dumps(ab_response)} -->'

async def save_ab_test_choice(self, ab_response_data: dict, chosen_variant: str, user_preference_reason: str = ""):
"""Save the user's A/B test choice and update chat history"""

# Get the chosen response data
chosen_response_data = ab_response_data[chosen_variant]
chosen_content = chosen_response_data["content"]

# Get the appropriate history based on choice
if chosen_variant == "control":
new_history = ab_response_data["_control_history"]
else:
new_history = ab_response_data["_treatment_history"]

# Get other data
new_intent_history = ab_response_data["_intent_history"]
new_assessment_history = ab_response_data["_assessment_history"]

# Create metadata including A/B test information
metadata = {
"edx_module_id": self.edx_module_id,
"tutor_model": self.model,
"problem_set_title": self.problem_set_title,
"run_readable_id": self.run_readable_id,
"ab_test_chosen_variant": chosen_variant,
"ab_test_metadata": ab_response_data["metadata"],
"user_preference_reason": user_preference_reason,
}

# Save to database
json_output = tutor_output_to_json(
new_history, new_intent_history, new_assessment_history, metadata
)
await create_tutorbot_output(
self.thread_id, json_output, self.edx_module_id
)

return {
"success": True,
"chosen_content": chosen_content,
"variant": chosen_variant,
}
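
For reference, a minimal sketch of how a consumer could unpack the '<!-- {...} -->' comment that get_completion yields in the A/B branch and hand the chosen variant back to save_ab_test_choice. The extract_ab_payload helper, its regex, and the commented driver are hypothetical illustrations, not part of this PR.

import json
import re

# Matches the HTML-comment wrapper that get_completion uses for structured payloads.
AB_COMMENT_RE = re.compile(r"<!--\s*(\{.*\})\s*-->", re.DOTALL)


def extract_ab_payload(chunk: str) -> dict | None:
    """Return the A/B test payload embedded in a streamed chunk, if any."""
    match = AB_COMMENT_RE.search(chunk)
    if not match:
        return None
    payload = json.loads(match.group(1))
    return payload if payload.get("type") == "ab_test_response" else None


# Hypothetical driver, assuming `bot` is an already-configured TutorBot:
#
#     async for chunk in bot.get_completion("What should I try first?"):
#         payload = extract_ab_payload(chunk)
#         if payload:
#             await bot.save_ab_test_choice(
#                 ab_response_data=payload,
#                 chosen_variant="treatment",
#                 user_preference_reason="Clearer first step",
#             )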


def get_problem_from_edx_block(edx_module_id: str, block_siblings: list[str]):
78 changes: 78 additions & 0 deletions ai_chatbots/chatbots_test.py
@@ -850,3 +850,81 @@ async def test_bad_request(mocker, mock_checkpointer):
async for _ in chatbot.get_completion("hello"):
chatbot.agent.astream.assert_called_once()
mock_log.assert_called_once_with("Bad request error")


async def test_tutor_bot_ab_testing(mocker, mock_checkpointer):
"""Test that TutorBot properly handles A/B testing responses."""
# Mock the A/B test response from message_tutor
mock_control_generator = MockAsyncIterator([
("messages", [AIMessageChunkFactory(content="Control response part 1")]),
("messages", [AIMessageChunkFactory(content=" Control response part 2")]),
("values", {"messages": [HumanMessage(content="test"), AIMessage(content="Control response part 1 Control response part 2")]}),
])

mock_treatment_generator = MockAsyncIterator([
("messages", [AIMessageChunkFactory(content="Treatment response part 1")]),
("messages", [AIMessageChunkFactory(content=" Treatment response part 2")]),
("values", {"messages": [HumanMessage(content="test"), AIMessage(content="Treatment response part 1 Treatment response part 2")]}),
])

ab_test_response = {
"is_ab_test": True,
"responses": [
{"variant": "control", "stream": mock_control_generator},
{"variant": "treatment", "stream": mock_treatment_generator}
]
}

# Mock message_tutor to return A/B test response
mock_message_tutor = mocker.patch("ai_chatbots.chatbots.message_tutor")
mock_message_tutor.return_value = (
ab_test_response,
[[Intent.S_STRATEGY]], # new_intent_history
[HumanMessage(content="test"), AIMessage(content="assessment")] # new_assessment_history
)

# Mock get_history to return None (new conversation)
mocker.patch("ai_chatbots.chatbots.get_history", return_value=None)

# Create TutorBot instance
tutor_bot = TutorBot(
user_id="test_user",
checkpointer=mock_checkpointer,
thread_id="test_thread",
problem_set_title="Test Problem Set",
run_readable_id="test_run",
)

# Mock the callback setup
tutor_bot.llm.callbacks = []
mock_get_tool_metadata = mocker.patch.object(tutor_bot, "get_tool_metadata")
mock_get_tool_metadata.return_value = '{"test": "metadata"}'
mock_set_callbacks = mocker.patch.object(tutor_bot, "set_callbacks")
mock_set_callbacks.return_value = []

# Test the completion
responses = []
async for response_chunk in tutor_bot.get_completion("What should I try first?"):
responses.append(response_chunk)

# Should get exactly one response with A/B test structure
assert len(responses) == 1

# Parse the JSON response
import json
ab_response_json = responses[0].replace('<!-- ', '').replace(' -->', '')
ab_response_data = json.loads(ab_response_json)

# Verify A/B test structure
assert ab_response_data["type"] == "ab_test_response"
assert "control" in ab_response_data
assert "treatment" in ab_response_data
assert ab_response_data["control"]["content"] == "Control response part 1 Control response part 2"
assert ab_response_data["control"]["variant"] == "control"
assert ab_response_data["treatment"]["content"] == "Treatment response part 1 Treatment response part 2"
assert ab_response_data["treatment"]["variant"] == "treatment"

# Verify metadata is included
assert "metadata" in ab_response_data
assert ab_response_data["metadata"]["thread_id"] == "test_thread"
assert ab_response_data["metadata"]["problem_set_title"] == "Test Problem Set"
15 changes: 15 additions & 0 deletions ai_chatbots/serializers.py
@@ -159,6 +159,21 @@ def get_object_id_field(self, obj):
return f"{obj.get('run_readable_id', '')} - {obj.get('problem_set_title', '')}"


class ABTestChoiceSerializer(serializers.Serializer):
"""
Serializer for A/B test choice submissions.
"""

thread_id = serializers.CharField(required=True, allow_blank=False)
chosen_variant = serializers.ChoiceField(choices=["control", "treatment"], required=True)
ab_response_data = serializers.JSONField(required=True)
user_preference_reason = serializers.CharField(required=False, allow_blank=True, default="")

# Canvas-specific fields to identify the chatbot
problem_set_title = serializers.CharField(required=True, allow_blank=False)
run_readable_id = serializers.CharField(required=True, allow_blank=False)
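
A quick sketch of a payload that should validate against the ABTestChoiceSerializer above; the concrete values are illustrative, and the snippet assumes it runs inside the project's Django environment (e.g. the test suite or a shell).

from ai_chatbots.serializers import ABTestChoiceSerializer

# Example choice submission; ab_response_data mirrors the structure emitted by
# TutorBot._handle_ab_test_response.
payload = {
    "thread_id": "test_thread",
    "chosen_variant": "treatment",
    "ab_response_data": {
        "type": "ab_test_response",
        "control": {"content": "Try isolating x first.", "variant": "control"},
        "treatment": {"content": "Start by sketching the graph.", "variant": "treatment"},
        "metadata": {"test_name": "tutor_problem", "thread_id": "test_thread"},
    },
    "user_preference_reason": "More concrete first step",
    "problem_set_title": "Test Problem Set",
    "run_readable_id": "test_run",
}

serializer = ABTestChoiceSerializer(data=payload)
assert serializer.is_valid(), serializer.errors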


class LLMModelSerializer(serializers.ModelSerializer):
class Meta:
model = LLMModel
5 changes: 5 additions & 0 deletions ai_chatbots/urls.py
@@ -32,6 +32,11 @@
views.ProblemSetList.as_view(),
name="problem_set_list",
),
path(
r"ab_test_choice/",
views.ABTestChoiceView.as_view(),
name="ab_test_choice",
),
]

urlpatterns = [
51 changes: 51 additions & 0 deletions ai_chatbots/views.py
@@ -19,6 +19,7 @@
from ai_chatbots.permissions import IsThreadOwner
from ai_chatbots.prompts import CHATBOT_PROMPT_MAPPING
from ai_chatbots.serializers import (
ABTestChoiceSerializer,
ChatMessageSerializer,
LLMModelSerializer,
SystemPromptSerializer,
@@ -316,3 +317,53 @@ def retrieve(self, request, *args, **kwargs):  # noqa: ARG002
instance = self.get_object()
serializer = self.get_serializer(instance)
return Response(serializer.data)


@extend_schema(
request=ABTestChoiceSerializer,
responses={
200: OpenApiResponse(description="A/B test choice saved successfully"),
400: OpenApiResponse(description="Invalid request data"),
500: OpenApiResponse(description="Error saving choice"),
},
)
class ABTestChoiceView(APIView):
"""
API endpoint to save user's A/B test choice and update chat history.
"""

http_method_names = ["post"]
permission_classes = (AllowAny,) # Can change to IsAuthenticated if needed

def post(self, request, *args, **kwargs): # noqa: ARG002
"""Save user's A/B test choice."""
serializer = ABTestChoiceSerializer(data=request.data)

if not serializer.is_valid():
return Response(serializer.errors, status=400)

try:
# Extract validated data
thread_id = serializer.validated_data["thread_id"]
chosen_variant = serializer.validated_data["chosen_variant"]
ab_response_data = serializer.validated_data["ab_response_data"]
user_preference_reason = serializer.validated_data.get("user_preference_reason", "")
problem_set_title = serializer.validated_data["problem_set_title"]
run_readable_id = serializer.validated_data["run_readable_id"]

# For now, just return success without saving to database
# This will be implemented later when we have proper async handling
chosen_response_data = ab_response_data[chosen_variant]
chosen_content = chosen_response_data["content"]

return Response({
"success": True,
"message": "A/B test choice received successfully",
"chosen_variant": chosen_variant,
"chosen_content": chosen_content,
}, status=200)

except Exception as e:
return Response({
"error": f"Failed to process A/B test choice: {str(e)}"
}, status=500)