Skip to content

Commit cdfb782

Browse files
authored
Improve and Fix: response parser, invalid response and others (stitionai#522)
* chore: minor updates * Add: send live inference time to frontend * add: timeout in settings * Improve: response parsing, add temperature to models * patches close stitionai#510, stitionai#507, stitionai#502, stitionai#468
1 parent 75df1c6 commit cdfb782

22 files changed

+245
-180
lines changed

app.dockerfile

Lines changed: 0 additions & 29 deletions
This file was deleted.

devika.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,7 @@ def real_time_logs():
186186
@route_logger(logger)
187187
def set_settings():
188188
data = request.json
189-
print("Data: ", data)
190-
config.config.update(data)
191-
config.save_config()
189+
config.update_config(data)
192190
return jsonify({"message": "Settings updated"})
193191

194192

sample.config.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,6 @@ OPENAI = "https://api.openai.com/v1"
2626
[LOGGING]
2727
LOG_REST_API = "true"
2828
LOG_PROMPTS = "false"
29+
30+
[TIMEOUT]
31+
INFERENCE = 60

src/agents/action/action.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from jinja2 import Environment, BaseLoader
44

5-
from src.services.utils import retry_wrapper
5+
from src.services.utils import retry_wrapper, validate_responses
66
from src.config import Config
77
from src.llm import LLM
88

@@ -24,17 +24,8 @@ def render(
2424
conversation=conversation
2525
)
2626

27+
@validate_responses
2728
def validate_response(self, response: str):
28-
response = response.strip().replace("```json", "```")
29-
30-
if response.startswith("```") and response.endswith("```"):
31-
response = response[3:-3].strip()
32-
33-
try:
34-
response = json.loads(response)
35-
except Exception as _:
36-
return False
37-
3829
if "response" not in response and "action" not in response:
3930
return False
4031
else:

src/agents/answer/answer.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from jinja2 import Environment, BaseLoader
44

5-
from src.services.utils import retry_wrapper
5+
from src.services.utils import retry_wrapper, validate_responses
66
from src.config import Config
77
from src.llm import LLM
88

@@ -25,17 +25,8 @@ def render(
2525
code_markdown=code_markdown
2626
)
2727

28+
@validate_responses
2829
def validate_response(self, response: str):
29-
response = response.strip().replace("```json", "```")
30-
31-
if response.startswith("```") and response.endswith("```"):
32-
response = response[3:-3].strip()
33-
34-
try:
35-
response = json.loads(response)
36-
except Exception as _:
37-
return False
38-
3930
if "response" not in response:
4031
return False
4132
else:

src/agents/decision/decision.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from jinja2 import Environment, BaseLoader
44

5-
from src.services.utils import retry_wrapper
5+
from src.services.utils import retry_wrapper, validate_responses
66
from src.llm import LLM
77

88
PROMPT = open("src/agents/decision/prompt.jinja2").read().strip()
@@ -16,17 +16,8 @@ def render(self, prompt: str) -> str:
1616
template = env.from_string(PROMPT)
1717
return template.render(prompt=prompt)
1818

19+
@validate_responses
1920
def validate_response(self, response: str):
20-
response = response.strip().replace("```json", "```")
21-
22-
if response.startswith("```") and response.endswith("```"):
23-
response = response[3:-3].strip()
24-
25-
try:
26-
response = json.loads(response)
27-
except Exception as _:
28-
return False
29-
3021
for item in response:
3122
if "function" not in item or "args" not in item or "reply" not in item:
3223
return False

src/agents/internal_monologue/internal_monologue.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from jinja2 import Environment, BaseLoader
44

55
from src.llm import LLM
6-
from src.services.utils import retry_wrapper
6+
from src.services.utils import retry_wrapper, validate_responses
77

88
PROMPT = open("src/agents/internal_monologue/prompt.jinja2").read().strip()
99

@@ -16,19 +16,10 @@ def render(self, current_prompt: str) -> str:
1616
template = env.from_string(PROMPT)
1717
return template.render(current_prompt=current_prompt)
1818

19+
@validate_responses
1920
def validate_response(self, response: str):
20-
response = response.strip().replace("```json", "```")
21-
22-
if response.startswith("```") and response.endswith("```"):
23-
response = response[3:-3].strip()
24-
25-
try:
26-
response = json.loads(response)
27-
except Exception as _:
28-
return False
29-
30-
response = {k.replace("\\", ""): v for k, v in response.items()}
31-
21+
print('-------------------> ', response)
22+
print("####", type(response))
3223
if "internal_monologue" not in response:
3324
return False
3425
else:

src/agents/researcher/prompt.jinja2

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,20 @@ Only respond in the following JSON format:
1111

1212
```
1313
{
14-
"queries": [
15-
"<QUERY 1>",
16-
"<QUERY 2>"
17-
],
18-
"ask_user": "<ASK INPUT FROM USER>"
14+
"queries": ["<QUERY 1>", "<QUERY 2>", "<QUERY 3>", ... ],
15+
"ask_user": "<ASK INPUT FROM USER IF REQUIRED, OTHERWISE LEAVE EMPTY STRING>"
16+
}
17+
```
18+
Example =>
19+
```
20+
{
21+
"queries": ["How to do Bing Search via API in Python", "Claude API Documentation Python"],
22+
"ask_user": "Can you please provide API Keys for Claude, OpenAI, and Firebase?"
1923
}
2024
```
2125

2226
Keywords for Search Query: {{ contextual_keywords }}
2327

24-
Example "queries": ["How to do Bing Search via API in Python", "Claude API Documentation Python"]
25-
Example "ask_user": "Can you please provide API Keys for Claude, OpenAI, and Firebase?"
2628

2729
Rules:
2830
- Only search for a maximum of 3 queries.
@@ -33,13 +35,6 @@ Rules:
3335
- Do not search for basic queries, only search for advanced and specific queries. You are allowed to leave the "queries" field empty if no search queries are needed for the step.
3436
- DO NOT EVER SEARCH FOR BASIC QUERIES. ONLY SEARCH FOR ADVANCED QUERIES.
3537
- YOU ARE ALLOWED TO LEAVE THE "queries" FIELD EMPTY IF NO SEARCH QUERIES ARE NEEDED FOR THE STEP.
36-
37-
Remember to only make search queries for resources that might require external information (like Documentation or a Blog or an Article). If the information is already known to you or commonly known, there is no need to search for it.
38-
39-
The `queries` key and the `ask_user` key can be empty list and string respectively if no search queries or user input are needed for the step. Try to keep the number of search queries to a minimum to save context window. One query per subject.
40-
41-
Only search for documentation or articles that are relevant to the task at hand. Do not search for general information.
42-
43-
Try to include contextual keywords into your search queries, adding relevant keywords and phrases to make the search queries as specific as possible.
38+
- you only have to return one JSON object with the queries and ask_user fields. You can't return multiple JSON objects.
4439

4540
Only the provided JSON response format is accepted. Any other response format will be rejected.

src/agents/researcher/researcher.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from jinja2 import Environment, BaseLoader
55

66
from src.llm import LLM
7-
from src.services.utils import retry_wrapper
7+
from src.services.utils import retry_wrapper, validate_responses
88
from src.browser.search import BingSearch
99

1010
PROMPT = open("src/agents/researcher/prompt.jinja2").read().strip()
@@ -23,17 +23,8 @@ def render(self, step_by_step_plan: str, contextual_keywords: str) -> str:
2323
contextual_keywords=contextual_keywords
2424
)
2525

26+
@validate_responses
2627
def validate_response(self, response: str) -> dict | bool:
27-
response = response.strip().replace("```json", "```")
28-
29-
if response.startswith("```") and response.endswith("```"):
30-
response = response[3:-3].strip()
31-
try:
32-
response = json.loads(response)
33-
except Exception as _:
34-
return False
35-
36-
response = {k.replace("\\", ""): v for k, v in response.items()}
3728

3829
if "queries" not in response and "ask_user" not in response:
3930
return False

src/agents/runner/runner.py

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from src.llm import LLM
1111
from src.state import AgentState
1212
from src.project import ProjectManager
13-
from src.services.utils import retry_wrapper
13+
from src.services.utils import retry_wrapper, validate_responses
1414

1515
PROMPT = open("src/agents/runner/prompt.jinja2", "r").read().strip()
1616
RERUNNER_PROMPT = open("src/agents/runner/rerunner.jinja2", "r").read().strip()
@@ -52,37 +52,15 @@ def render_rerunner(
5252
error=error
5353
)
5454

55+
@validate_responses
5556
def validate_response(self, response: str):
56-
response = response.strip().replace("```json", "```")
57-
58-
if response.startswith("```") and response.endswith("```"):
59-
response = response[3:-3].strip()
60-
61-
try:
62-
response = json.loads(response)
63-
except Exception as _:
64-
return False
65-
6657
if "commands" not in response:
6758
return False
6859
else:
6960
return response["commands"]
70-
61+
62+
@validate_responses
7163
def validate_rerunner_response(self, response: str):
72-
response = response.strip().replace("```json", "```")
73-
74-
if response.startswith("```") and response.endswith("```"):
75-
response = response[3:-3].strip()
76-
77-
print(response)
78-
79-
try:
80-
response = json.loads(response)
81-
except Exception as _:
82-
return False
83-
84-
print(response)
85-
8664
if "action" not in response and "response" not in response:
8765
return False
8866
else:

src/config.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ def get_logging_rest_api(self):
104104

105105
def get_logging_prompts(self):
106106
return self.config["LOGGING"]["LOG_PROMPTS"] == "true"
107+
108+
def get_timeout_inference(self):
    """Return the ``INFERENCE`` value of the ``[TIMEOUT]`` config section.

    A config.toml created before the ``[TIMEOUT]`` section existed does not
    contain it, so a missing section or key falls back to 60, the default
    shipped in sample.config.toml, instead of raising ``KeyError``.
    """
    try:
        return self.config["TIMEOUT"]["INFERENCE"]
    except KeyError:
        # Older config file without the [TIMEOUT] section / INFERENCE key.
        return 60
107110

108111
def set_bing_api_key(self, key):
109112
self.config["API_KEYS"]["BING"] = key
@@ -157,30 +160,6 @@ def set_netlify_api_key(self, key):
157160
self.config["API_KEYS"]["NETLIFY"] = key
158161
self.save_config()
159162

160-
def set_sqlite_db(self, db):
161-
self.config["STORAGE"]["SQLITE_DB"] = db
162-
self.save_config()
163-
164-
def set_screenshots_dir(self, dir):
165-
self.config["STORAGE"]["SCREENSHOTS_DIR"] = dir
166-
self.save_config()
167-
168-
def set_pdfs_dir(self, dir):
169-
self.config["STORAGE"]["PDFS_DIR"] = dir
170-
self.save_config()
171-
172-
def set_projects_dir(self, dir):
173-
self.config["STORAGE"]["PROJECTS_DIR"] = dir
174-
self.save_config()
175-
176-
def set_logs_dir(self, dir):
177-
self.config["STORAGE"]["LOGS_DIR"] = dir
178-
self.save_config()
179-
180-
def set_repos_dir(self, dir):
181-
self.config["STORAGE"]["REPOS_DIR"] = dir
182-
self.save_config()
183-
184163
def set_logging_rest_api(self, value):
185164
self.config["LOGGING"]["LOG_REST_API"] = "true" if value else "false"
186165
self.save_config()
@@ -189,6 +168,21 @@ def set_logging_prompts(self, value):
189168
self.config["LOGGING"]["LOG_PROMPTS"] = "true" if value else "false"
190169
self.save_config()
191170

171+
def set_timeout_inference(self, value):
    """Store *value* as ``[TIMEOUT].INFERENCE`` and persist the config.

    The ``TIMEOUT`` section is created when absent so that a config loaded
    from a file without that section does not raise ``KeyError`` here.
    """
    self.config.setdefault("TIMEOUT", {})["INFERENCE"] = value
    self.save_config()
174+
192175
def save_config(self):
193176
with open("config.toml", "w") as f:
194177
toml.dump(self.config, f)
178+
179+
def update_config(self, data):
    """Merge *data* into the in-memory config and into ``config.toml``.

    *data* maps section names to ``{key: value}`` mappings. Sections that
    are not already present in ``self.config`` are ignored. For every
    accepted section each key is written both to ``self.config`` and to the
    copy loaded from disk, which is then written back.

    The file is read and rewritten once for the whole update rather than
    once per section.
    """
    # Only sections the loaded configuration already knows about are applied.
    updates = {key: value for key, value in data.items() if key in self.config}
    if not updates:
        # Nothing to apply: leave the file untouched, as before.
        return

    with open("config.toml", "r+") as f:
        config = toml.load(f)
        for key, value in updates.items():
            # The on-disk file may lack a section the in-memory config has;
            # create it instead of failing half-way through the update.
            section = config.setdefault(key, {})
            for sub_key, sub_value in value.items():
                self.config[key][sub_key] = sub_value
                section[sub_key] = sub_value
        f.seek(0)
        toml.dump(config, f)
        # Rewriting in place: drop leftover bytes when the new content is
        # shorter than the old, otherwise the TOML file ends up corrupted.
        f.truncate()

src/llm/claude_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def inference(self, model_id: str, prompt: str) -> str:
2020
}
2121
],
2222
model=model_id,
23+
temperature=0
2324
)
2425

2526
return message.content[0].text

src/llm/gemini_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ def __init__(self):
1010
genai.configure(api_key=api_key)
1111

1212
def inference(self, model_id: str, prompt: str) -> str:
13-
model = genai.GenerativeModel(model_id)
13+
config = genai.GenerationConfig(temperature=0)
14+
model = genai.GenerativeModel(model_id, generation_config=config)
1415
# Set safety settings for the request
1516
safety_settings = {
1617
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,

src/llm/groq_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def inference(self, model_id: str, prompt: str) -> str:
1818
}
1919
],
2020
model=model_id,
21+
temperature=0
2122
)
2223

2324
return chat_completion.choices[0].message.content

0 commit comments

Comments
 (0)