Skip to content

Commit 4e06714

Browse files
authored
Fix: deep research use case (#493)
1 parent 18c8d25 commit 4e06714

File tree

8 files changed

+275
-81
lines changed

8 files changed

+275
-81
lines changed

.changeset/chilly-bats-smile.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"create-llama": patch
3+
---
4+
5+
Fix the error: Unable to view file sources due to CORS.

templates/components/agents/python/deep_research/app/workflows/agents.py

+33-8
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ class AnalysisDecision(BaseModel):
1616
description="Whether to continue research, write a report, or cancel the research after several retries"
1717
)
1818
research_questions: Optional[List[str]] = Field(
19-
description="Questions to research if continuing research. Maximum 3 questions. Set to null or empty if writing a report.",
19+
description="""
20+
If the decision is to research, provide a list of questions to research that related to the user request.
21+
Maximum 3 questions. Set to null or empty if writing a report or cancel the research.
22+
""",
2023
default_factory=list,
2124
)
2225
cancel_reason: Optional[str] = Field(
@@ -29,32 +32,53 @@ async def plan_research(
2932
memory: SimpleComposableMemory,
3033
context_nodes: List[Node],
3134
user_request: str,
35+
total_questions: int,
3236
) -> AnalysisDecision:
33-
analyze_prompt = PromptTemplate(
34-
"""
37+
analyze_prompt = """
3538
You are a professor who is guiding a researcher to research a specific request/problem.
3639
Your task is to decide on a research plan for the researcher.
40+
3741
The possible actions are:
3842
+ Provide a list of questions for the researcher to investigate, with the purpose of clarifying the request.
3943
+ Write a report if the researcher has already gathered enough research on the topic and can resolve the initial request.
4044
+ Cancel the research if most of the answers from researchers indicate there is insufficient information to research the request. Do not attempt more than 3 research iterations or too many questions.
45+
4146
The workflow should be:
4247
+ Always begin by providing some initial questions for the researcher to investigate.
4348
+ Analyze the provided answers against the initial topic/request. If the answers are insufficient to resolve the initial request, provide additional questions for the researcher to investigate.
4449
+ If the answers are sufficient to resolve the initial request, instruct the researcher to write a report.
45-
<User request>
46-
{user_request}
47-
</User request>
4850
51+
Here are the context:
4952
<Collected information>
5053
{context_str}
5154
</Collected information>
5255
5356
<Conversation context>
5457
{conversation_context}
5558
</Conversation context>
59+
60+
{enhanced_prompt}
61+
62+
Now, provide your decision in the required format for this user request:
63+
<User request>
64+
{user_request}
65+
</User request>
5666
"""
57-
)
67+
# Manually craft the prompt to avoid LLM hallucination
68+
enhanced_prompt = ""
69+
if total_questions == 0:
70+
# Avoid writing a report without any research context
71+
enhanced_prompt = """
72+
73+
The student has no questions to research. Let start by asking some questions.
74+
"""
75+
elif total_questions > 6:
76+
# Avoid asking too many questions (when the data is not ready for writing a report)
77+
enhanced_prompt = f"""
78+
79+
The student has researched {total_questions} questions. Should cancel the research if the context is not enough to write a report.
80+
"""
81+
5882
conversation_context = "\n".join(
5983
[f"{message.role}: {message.content}" for message in memory.get_all()]
6084
)
@@ -63,10 +87,11 @@ async def plan_research(
6387
)
6488
res = await Settings.llm.astructured_predict(
6589
output_cls=AnalysisDecision,
66-
prompt=analyze_prompt,
90+
prompt=PromptTemplate(template=analyze_prompt),
6791
user_request=user_request,
6892
context_str=context_str,
6993
conversation_context=conversation_context,
94+
enhanced_prompt=enhanced_prompt,
7095
)
7196
return res
7297

templates/components/agents/python/deep_research/app/workflows/deep_research.py

+28-7
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,11 @@ def __init__(
8989
)
9090

9191
@step
92-
def retrieve(self, ctx: Context, ev: StartEvent) -> PlanResearchEvent:
92+
async def retrieve(self, ctx: Context, ev: StartEvent) -> PlanResearchEvent:
9393
"""
9494
Initiate the workflow: memory, tools, agent
9595
"""
96+
await ctx.set("total_questions", 0)
9697
self.user_request = ev.get("input")
9798
self.memory.put_messages(
9899
messages=[
@@ -132,9 +133,7 @@ def retrieve(self, ctx: Context, ev: StartEvent) -> PlanResearchEvent:
132133
nodes=nodes,
133134
)
134135
)
135-
return PlanResearchEvent(
136-
context_nodes=self.context_nodes,
137-
)
136+
return PlanResearchEvent()
138137

139138
@step
140139
async def analyze(
@@ -153,10 +152,12 @@ async def analyze(
153152
},
154153
)
155154
)
155+
total_questions = await ctx.get("total_questions")
156156
res = await plan_research(
157157
memory=self.memory,
158158
context_nodes=self.context_nodes,
159159
user_request=self.user_request,
160+
total_questions=total_questions,
160161
)
161162
if res.decision == "cancel":
162163
ctx.write_event_to_stream(
@@ -172,6 +173,22 @@ async def analyze(
172173
result=res.cancel_reason,
173174
)
174175
elif res.decision == "write":
176+
# Writing a report without any research context is not allowed.
177+
# It's an LLM hallucination.
178+
if total_questions == 0:
179+
ctx.write_event_to_stream(
180+
DataEvent(
181+
type="deep_research_event",
182+
data={
183+
"event": "analyze",
184+
"state": "done",
185+
},
186+
)
187+
)
188+
return StopEvent(
189+
result="Sorry, I have a problem when analyzing the retrieved information. Please try again.",
190+
)
191+
175192
self.memory.put(
176193
message=ChatMessage(
177194
role=MessageRole.ASSISTANT,
@@ -180,7 +197,11 @@ async def analyze(
180197
)
181198
ctx.send_event(ReportEvent())
182199
else:
183-
await ctx.set("n_questions", len(res.research_questions))
200+
total_questions += len(res.research_questions)
201+
await ctx.set("total_questions", total_questions) # For tracking
202+
await ctx.set(
203+
"waiting_questions", len(res.research_questions)
204+
) # For waiting questions to be answered
184205
self.memory.put(
185206
message=ChatMessage(
186207
role=MessageRole.ASSISTANT,
@@ -270,7 +291,7 @@ async def collect_answers(
270291
"""
271292
Collect answers to all questions
272293
"""
273-
num_questions = await ctx.get("n_questions")
294+
num_questions = await ctx.get("waiting_questions")
274295
results = ctx.collect_events(
275296
ev,
276297
expected=[CollectAnswersEvent] * num_questions,
@@ -284,7 +305,7 @@ async def collect_answers(
284305
content=f"<Question>{result.question}</Question>\n<Answer>{result.answer}</Answer>",
285306
)
286307
)
287-
await ctx.set("n_questions", 0)
308+
await ctx.set("waiting_questions", 0)
288309
self.memory.put(
289310
message=ChatMessage(
290311
role=MessageRole.ASSISTANT,

templates/types/streaming/fastapi/main.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
# flake8: noqa: E402
2-
from app.config import DATA_DIR, STATIC_DIR
32
from dotenv import load_dotenv
43

4+
from app.config import DATA_DIR, STATIC_DIR
5+
56
load_dotenv()
67

78
import logging
89
import os
910

1011
import uvicorn
12+
from fastapi import FastAPI
13+
from fastapi.middleware.cors import CORSMiddleware
14+
from fastapi.responses import RedirectResponse
15+
from fastapi.staticfiles import StaticFiles
16+
1117
from app.api.routers import api_router
1218
from app.middlewares.frontend import FrontendProxyMiddleware
1319
from app.observability import init_observability
1420
from app.settings import init_settings
15-
from fastapi import FastAPI
16-
from fastapi.responses import RedirectResponse
17-
from fastapi.staticfiles import StaticFiles
1821

1922
servers = []
2023
app_name = os.getenv("FLY_APP_NAME")
@@ -28,6 +31,16 @@
2831
environment = os.getenv("ENVIRONMENT", "dev") # Default to 'dev' if not set
2932
logger = logging.getLogger("uvicorn")
3033

34+
# Add CORS middleware for development
35+
if environment == "dev":
36+
app.add_middleware(
37+
CORSMiddleware,
38+
allow_origin_regex=r"http://localhost:\d+|http://0\.0\.0\.0:\d+",
39+
allow_credentials=True,
40+
allow_methods=["*"],
41+
allow_headers=["*"],
42+
)
43+
3144

3245
def mount_static_files(directory, path, html=False):
3346
if os.path.exists(directory):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"use client";
2+
3+
import * as AccordionPrimitive from "@radix-ui/react-accordion";
4+
import { ChevronDown } from "lucide-react";
5+
import * as React from "react";
6+
import { cn } from "./lib/utils";
7+
8+
const Accordion = AccordionPrimitive.Root;
9+
10+
const AccordionItem = React.forwardRef<
11+
React.ElementRef<typeof AccordionPrimitive.Item>,
12+
React.ComponentPropsWithoutRef<typeof AccordionPrimitive.Item>
13+
>(({ className, ...props }, ref) => (
14+
<AccordionPrimitive.Item
15+
ref={ref}
16+
className={cn("border-b", className)}
17+
{...props}
18+
/>
19+
));
20+
AccordionItem.displayName = "AccordionItem";
21+
22+
const AccordionTrigger = React.forwardRef<
23+
React.ElementRef<typeof AccordionPrimitive.Trigger>,
24+
React.ComponentPropsWithoutRef<typeof AccordionPrimitive.Trigger>
25+
>(({ className, children, ...props }, ref) => (
26+
<AccordionPrimitive.Header className="flex">
27+
<AccordionPrimitive.Trigger
28+
ref={ref}
29+
className={cn(
30+
"flex flex-1 items-center justify-between py-4 text-sm font-medium transition-all hover:underline text-left [&[data-state=open]>svg]:rotate-180",
31+
className,
32+
)}
33+
{...props}
34+
>
35+
{children}
36+
<ChevronDown className="h-4 w-4 shrink-0 text-neutral-500 transition-transform duration-200 dark:text-neutral-400" />
37+
</AccordionPrimitive.Trigger>
38+
</AccordionPrimitive.Header>
39+
));
40+
AccordionTrigger.displayName = AccordionPrimitive.Trigger.displayName;
41+
42+
const AccordionContent = React.forwardRef<
43+
React.ElementRef<typeof AccordionPrimitive.Content>,
44+
React.ComponentPropsWithoutRef<typeof AccordionPrimitive.Content>
45+
>(({ className, children, ...props }, ref) => (
46+
<AccordionPrimitive.Content
47+
ref={ref}
48+
className="overflow-hidden text-sm data-[state=closed]:animate-accordion-up data-[state=open]:animate-accordion-down"
49+
{...props}
50+
>
51+
<div className={cn("pb-4 pt-0", className)}>{children}</div>
52+
</AccordionPrimitive.Content>
53+
));
54+
AccordionContent.displayName = AccordionPrimitive.Content.displayName;
55+
56+
export { Accordion, AccordionContent, AccordionItem, AccordionTrigger };
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import * as React from "react";
2+
import { cn } from "./lib/utils";
3+
4+
const Card = React.forwardRef<
5+
HTMLDivElement,
6+
React.HTMLAttributes<HTMLDivElement>
7+
>(({ className, ...props }, ref) => (
8+
<div
9+
ref={ref}
10+
className={cn(
11+
"rounded-xl border border-neutral-200 bg-white text-neutral-950 shadow dark:border-neutral-800 dark:bg-neutral-950 dark:text-neutral-50",
12+
className,
13+
)}
14+
{...props}
15+
/>
16+
));
17+
Card.displayName = "Card";
18+
19+
const CardHeader = React.forwardRef<
20+
HTMLDivElement,
21+
React.HTMLAttributes<HTMLDivElement>
22+
>(({ className, ...props }, ref) => (
23+
<div
24+
ref={ref}
25+
className={cn("flex flex-col space-y-1.5 p-6", className)}
26+
{...props}
27+
/>
28+
));
29+
CardHeader.displayName = "CardHeader";
30+
31+
const CardTitle = React.forwardRef<
32+
HTMLDivElement,
33+
React.HTMLAttributes<HTMLDivElement>
34+
>(({ className, ...props }, ref) => (
35+
<div
36+
ref={ref}
37+
className={cn("font-semibold leading-none tracking-tight", className)}
38+
{...props}
39+
/>
40+
));
41+
CardTitle.displayName = "CardTitle";
42+
43+
const CardDescription = React.forwardRef<
44+
HTMLDivElement,
45+
React.HTMLAttributes<HTMLDivElement>
46+
>(({ className, ...props }, ref) => (
47+
<div
48+
ref={ref}
49+
className={cn("text-sm text-neutral-500 dark:text-neutral-400", className)}
50+
{...props}
51+
/>
52+
));
53+
CardDescription.displayName = "CardDescription";
54+
55+
const CardContent = React.forwardRef<
56+
HTMLDivElement,
57+
React.HTMLAttributes<HTMLDivElement>
58+
>(({ className, ...props }, ref) => (
59+
<div ref={ref} className={cn("p-6 pt-0", className)} {...props} />
60+
));
61+
CardContent.displayName = "CardContent";
62+
63+
const CardFooter = React.forwardRef<
64+
HTMLDivElement,
65+
React.HTMLAttributes<HTMLDivElement>
66+
>(({ className, ...props }, ref) => (
67+
<div
68+
ref={ref}
69+
className={cn("flex items-center p-6 pt-0", className)}
70+
{...props}
71+
/>
72+
));
73+
CardFooter.displayName = "CardFooter";
74+
75+
export {
76+
Card,
77+
CardContent,
78+
CardDescription,
79+
CardFooter,
80+
CardHeader,
81+
CardTitle,
82+
};

0 commit comments

Comments
 (0)