forked from AntoniZap/IBM-Chatbot
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
287 lines (244 loc) · 9.36 KB
/
app.py
File metadata and controls
287 lines (244 loc) · 9.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
import os.path
# Document loading and the link
from langchain_core.runnables import RunnablePassthrough
from langchain_community.retrievers import BM25Retriever
# ✨AI✨
from langchain.memory import ChatMessageHistory
from langchain_core.messages import AIMessage, SystemMessage
# Our own stuff
from csv_to_langchain import CSVLoader
from local import resolve
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, Iterator, List, Optional, Union
from dataclasses import dataclass, field
import datetime
from dotenv import load_dotenv, find_dotenv, dotenv_values
from agent.sql import AggregationRAG, LLMUnreliableException
from db import get_db
from llm import get_llm, get_raw_llm
@dataclass
class PendingInferenceComplete:
    """State marker: an inference request was accepted and is still running."""
    # Raw request payload that triggered the inference.
    data: object
    # Human-readable creation time, captured when this state is entered.
    timestamp: str = field(default_factory=lambda: f"{datetime.datetime.now()}")
@dataclass
class PendingResponseChoice:
    """State marker: inference finished; waiting for the user to pick an answer."""
    # Maps each LLM's name to the answer that LLM produced.
    answers: Dict[str, str]
def query_chain(retriever):
    """Build a runnable that feeds the newest user message's text into *retriever*.

    Args:
        retriever: a LangChain retriever (supports the ``|`` pipe protocol).

    Returns:
        A runnable: extracts ``messages[-1].content`` and pipes it to *retriever*.
    """
    def _latest_message_text(params):
        # Newest chat message is last; its .content is the query string.
        return params["messages"][-1].content
    return _latest_message_text | retriever
def set_state(next_state):
    """Replace the module-level LLM state-machine value and log the transition.

    Args:
        next_state (PendingInferenceComplete | PendingResponseChoice | None):
            the state the LLM enters next.
    """
    global state
    state = next_state
    print("Transitioned to next state " + format(state))
def get_state():
    """Return the current module-level LLM state.

    Returns:
        PendingInferenceComplete | PendingResponseChoice | None: current state.
    """
    # Reading a module global needs no `global` declaration.
    return state
@dataclass
class Closure:
    """Streaming callback bound to one LLM's name.

    Instances are called with each streamed update and forward the
    accumulated answer to the UI over the "socket" websocket channel.
    """
    # Name of the LLM this callback reports for.
    llm: str

    def __call__(self, part, whole):
        """Publish one streamed LLM update to the UI via the websocket.

        Args:
            part (dict): the incremental delta (e.g. one new token batch).
            whole (dict): the full accumulated output after applying the delta.
        """
        print(f"got message from `{self.llm}`: {part}")
        payload = {
            "llm": self.llm,
            "answer": whole.get("answer", "")
        }
        socketio.emit("socket", payload)
def _get_data(messages, llm_choices, sql=False):
    """Fan the latest question out to the chosen LLMs and collect their answers.

    The last message in *messages* is treated as the question. A single
    retrieval against the current CSV-backed database supplies shared context
    for every LLM, then each LLM runs concurrently in a worker thread.

    Args:
        messages (list): chat message history; messages[-1].content is the question.
        llm_choices (list[str]): LLM identifiers, e.g. "ai21", "chatGPT", "llama".
        sql (bool): forwarded to infer(); enables the SQL aggregation pre-step.

    Returns:
        tuple[list[dict], list]: per-LLM answer dicts (each tagged with an
        "llm" key) and the retrieved source documents.
    """
    question = messages[-1].content
    global filename
    # One retrieval shared by all LLMs; k=1 keeps the injected context small.
    context = get_db(filename).as_retriever(k=1).invoke(question)
    context_source = RunnablePassthrough.assign(context=lambda _: context)
    # Context manager guarantees worker threads are joined and released
    # (the original leaked the executor on every call).
    with ThreadPoolExecutor(4) as pool:
        jobs = []
        for llm_choice in llm_choices:
            llm = get_raw_llm(llm_choice)
            chain = context_source | get_llm(llm_choice)
            print(f"Submitting task for `{llm_choice}`")
            job = pool.submit(infer, messages, llm, chain, Closure(llm_choice), sql=sql)
            jobs.append((llm_choice, job))
        # .result() re-raises any worker exception here, inside the pool scope.
        answers = [{**job.result(), "llm": llm} for llm, job in jobs]
    return answers, context
def infer(messages, llm, chain, callback, sql=False):
    """Answer the latest question using *chain*, optionally trying SQL first.

    When *sql* is true, an AggregationRAG pass runs first; if it yields a
    tabular result, that result is returned directly. Otherwise the chain is
    streamed and *callback* (see the Closure class) is invoked with each delta
    plus the accumulated output so far.

    Args:
        messages (list): chat message history; messages[-1] is the question.
        llm: raw LLM handle; used only by the SQL aggregation step.
        chain: runnable accepting {"messages": [...]} whose .stream() yields
            chunks carrying an "answer" key.
        callback (callable | None): invoked as callback(delta, accumulated).
        sql (bool): enable the SQL aggregation pre-inference step.

    Returns:
        dict: {"type": "tabular", ...} from the aggregation step, or
        {"type": "regular", "answer": str} from the streamed chain.
    """
    if sql:
        print("Running pre-inference step")
        try:
            # Aggregation progress events are surfaced to the UI through the
            # same callback, shaped like a partial answer.
            agg = AggregationRAG(llm, verbose=True, notify_cb=lambda event: callback({}, { "answer" : event }))
            result = agg.answer(messages[-1])
            if result is not None:
                # A tabular result wins; skip regular inference entirely.
                return { **result.__dict__, "type" : "tabular" }
            print("got empty result, but things were otherwise okay")
        except LLMUnreliableException as e:
            # Fall through to regular inference on an unreliable LLM.
            print(f"LLM not reliable: {e}")
    else:
        print("SQL Aggregation tool disabled")
    print("Starting inference for LLM")
    payload = {
        "messages": [
            *messages,
            SystemMessage(content="No tabular output could be generated. Use the sources provided to answer the question. Note that these sources are limited.")
        ]
    }
    full = None
    for item in chain.stream(payload):
        if full is None:
            full = item
        else:
            full += item
        if callback is not None:
            callback(item, full)
    if full is None:
        # Stream produced nothing: return an empty answer instead of
        # crashing on full["answer"] (the original raised TypeError here).
        return { "type" : "regular", "answer" : "" }
    return { "type" : "regular", "answer" : full["answer"] }
# Start with no pending inference/response state.
set_state(None)
# Shared chat history; user messages and chosen AI answers accumulate here.
memory = ChatMessageHistory()
def refresh_environment_vars():
    """Copy every key/value pair from the local .env file into os.environ.

    Keys that appear in .env without a value are skipped: dotenv_values
    yields None for them, and assigning None to os.environ raises TypeError.
    """
    for env_key, env_value in dotenv_values('.env').items():
        if env_value is not None:
            os.environ[env_key] = env_value
if __name__ == "__main__":
    # Flask/websocket wiring happens only when run as a script, so the
    # inference helpers above stay importable without side effects.
    from flask import Flask, request, jsonify
    from flask_cors import CORS
    from flask_socketio import SocketIO
    from os import listdir
    app = Flask(__name__)
    # Used by Closure.__call__ to stream partial answers to the UI.
    socketio = SocketIO(app, cors_allowed_origins="*")
    CORS(app)
    # Ensure a .env file exists, then load it into the process environment.
    path_to_env = find_dotenv()
    if path_to_env:
        load_dotenv(path_to_env)
    else:
        # NOTE(review): this file handle is never closed; prefer
        # `with open('.env', 'w'): pass`.
        open('.env', 'w')
        load_dotenv('.env')
    refresh_environment_vars()
    # `global` at module scope is a no-op; the binding itself happens on the
    # next line. Route handlers re-declare it so they can reassign it.
    global filename
    # Default CSV dataset served until /setFile selects another one.
    filename = "_Datafiniti_Amazon_Consumer_Reviews_of_Amazon_Products.csv"
    @app.route('/message', methods=['POST'])
    def get_data():
        """get_data returns the answers and sources for a query
        Returns:
            Response: A json response of answers and corresponding sources in json format
        """
        global state
        # Reject concurrent requests: only one inference may run at a time.
        if state is not None:
            return jsonify(
                error="Already running inference",
                state=str(state)
            ), 400
        data = request.json
        llm_choices = data.get('llms') or []
        if len(llm_choices) == 0:
            return jsonify(answers={})
        message = data.get('message')
        sql = data.get('sql')
        set_state(PendingInferenceComplete(data=data))
        memory.add_user_message(message)
        answers, sources = _get_data(memory.messages, llm_choices, sql=sql)
        # Hold the answers until the user picks one via /selectAnswer.
        set_state(PendingResponseChoice(answers={ answer["llm"]: answer for answer in answers }))
        return jsonify({
            "answers": answers,
            "sources" : [ { "pageContent" : source.page_content,
                            "title" : source.metadata["title"],
                            "rating" : source.metadata["rating"],
                            "productName" : source.metadata["name"] }
                          for source in sources ]
        })
    @app.route('/selectAnswer', methods=['POST'])
    def select_response():
        """Adds the selected answer to memory
        Returns:
            Response: a json response object
        """
        data = request.json
        global state
        # Only valid while answers are awaiting a user choice.
        if type(state) is not PendingResponseChoice:
            # NOTE(review): jsonify(400) sends the number 400 as the body
            # with HTTP status 200, not an actual HTTP 400 response.
            return jsonify(400)
        chosen_answer = data.get('llm')
        global memory
        # NOTE(review): answers were keyed by the llm_choice strings as-is;
        # .lower() assumes those keys are lowercase — verify against the UI.
        chosen_answer = state.answers[chosen_answer.lower()]
        if chosen_answer["type"] == "tabular":
            memory.add_ai_message("[This question was answered in a tabular format, which was presented in the UI]")
        else:
            memory.add_ai_message(chosen_answer["answer"])
        set_state(None)
        return jsonify(200)
    @app.route('/setFile', methods=['POST'])
    def set_file():
        """Switches the CSV file backing the retrieval database."""
        data = request.json
        global filename
        filename = data.get('file')
        print('Updated filename to: "' + filename + '"')
        return jsonify(200)
    @app.route('/files', methods=['POST'])
    def get_files():
        """Lists the CSV files available in the working directory."""
        files = []
        for file in listdir("."):
            if file.endswith(".csv"):
                print(file)
                files.append(file)
        return jsonify({"files" : files})
    @app.route('/config', methods=['POST'])
    def config_llm():
        """configures the llms by reading and writing into the .env file
        Returns:
            Response: a json response object
        """
        data = request.json
        global state
        llm = data.get('llm')
        llm_key = data.get('llm_key')
        # Rewrite .env: the new key goes first, every other non-blank line
        # that is not for this llm is kept.
        with open('.env', 'r') as read_dotenv:
            lines = read_dotenv.readlines()
        lines_to_write = [f'{llm}={llm_key}']
        with open('.env', 'w') as write_dotenv:
            for line in lines:
                if not (line.startswith(llm) or line.isspace()):
                    # NOTE(review): `line` keeps its own trailing newline and
                    # another is prepended, so blank lines appear between
                    # entries (filtered out again by isspace() on rewrite).
                    lines_to_write.append(f'\n{line}')
            write_dotenv.writelines(lines_to_write)
        load_dotenv(path_to_env)
        refresh_environment_vars()
        set_state(None)
        return jsonify(200)
    app.run(port=5000)