-
Notifications
You must be signed in to change notification settings - Fork 3
/
chatbot.py
150 lines (116 loc) · 4.24 KB
/
chatbot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
import datetime
from dotenv import load_dotenv
import sys
import streamlit as st
from llama_index.llms import OpenAI as LlamaOpenAI
import openai
from system_prompt import system_prompt
from llama_index.tools import QueryEngineTool, ToolMetadata
from llama_index.agent import OpenAIAgent
from llama_index import SimpleDirectoryReader
from llama_index import Document
sys.path.append("utils")
# importing utils
from sentence_window_retrieval import build_sentence_window_index, get_sentence_window_query_engine
from automerging_retrieval import build_automerging_index, get_automerging_query_engine
from trulens_recorder import load_trulens, get_tru
load_dotenv()
# Configure the Streamlit page (must run before any other st.* call).
st.set_page_config(
    page_title="Chat with Sample Agent",
    page_icon="",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items=None,
)

# Load the OpenAI key from the environment variable.
openai_key = os.getenv("OPENAI_API_KEY")
if not openai_key:
    st.error("No OpenAI key found. Please set the OPENAI_API_KEY environment variable.")
    # Fix: halt the script here instead of continuing with an unset key,
    # which would otherwise surface later as an opaque OpenAI auth error.
    st.stop()
openai.api_key = openai_key

# LLM shared by the retrieval query engines and the chat agent.
llm = LlamaOpenAI(model="gpt-4", temperature=0.1, system_prompt=system_prompt)
@st.cache_data
def load_data():
    """Read the source PDF once (cached by Streamlit across reruns).

    Returns a pair: a single Document with all pages merged into one text
    body, and the raw per-page document list from the reader.
    """
    print("loading documents")
    raw_docs = SimpleDirectoryReader(
        input_files=["./data/eBook-How-to-Build-a-Career-in-AI.pdf"]
    ).load_data()
    merged = Document(text="\n\n".join(doc.text for doc in raw_docs))
    return merged, raw_docs


document, documents = load_data()
# import advanced RAG techniques
@st.cache_resource
def load_sentence_retrieval():
print("loading sentence retrieval")
sentence_index = build_sentence_window_index(
document,
llm,
embed_model="local:BAAI/bge-small-en-v1.5",
save_dir="sentence_index"
)
sentence_window_engine = get_sentence_window_query_engine(sentence_index)
app_id = "Sentence Retrieval"
return sentence_window_engine, app_id
@st.cache_resource
def load_automerging_retrieval():
    """Build (or load from ``merging_index/``) the auto-merging index.

    Returns the query engine plus the app id used for TruLens tracking.
    Cached by Streamlit so the index is built at most once per session.
    """
    print("loading automerging retrieval")
    index = build_automerging_index(
        documents,
        llm,
        embed_model="local:BAAI/bge-small-en-v1.5",
        save_dir="merging_index",
    )
    return get_automerging_query_engine(index), "Automerging Retrieval"
# Pick which retrieval method to use (swap in load_sentence_retrieval() to compare).
query_engine, app_id = load_automerging_retrieval()  # load_sentence_retrieval()

# Load the TruLens recorder and the tru object used by the dashboard button below.
tru_recorder = load_trulens(query_engine, app_id)
tru = get_tru()

# Expose the chosen query engine to the agent as a single tool.
tools = [
    QueryEngineTool(
        query_engine=query_engine,
        metadata=ToolMetadata(
            name="query_engine_tool",
            description="Query the supplied documents",
        ),
    ),
]

agent = OpenAIAgent.from_tools(
    llm=llm,
    tools=tools,
    system_prompt=system_prompt,
)
# CHAT
st.title("Sample Agent")

# Seed the chat history with a greeting on the very first run.
if "messages" not in st.session_state:
    st.session_state.messages = [
        {
            "role": "assistant",
            "content": "Hello there. How can I help you today?",
        },
    ]

if st.button("Launch Dashboard"):
    tru.run_dashboard()

# Keep a single agent instance alive across Streamlit reruns.
if "chat_engine" not in st.session_state:
    st.session_state.chat_engine: OpenAIAgent = agent

# Record the new user message, if any, before rendering the history.
if prompt := st.chat_input(" "):
    st.session_state.messages.append({"role": "user", "content": prompt})

# Display the prior chat messages.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])

# If the last message is not from the assistant, generate a new response.
if st.session_state.messages[-1]["role"] != "assistant":
    with st.spinner("The agent is thinking..."):
        with tru_recorder as recording:
            # NOTE(review): the agent is invoked via __call__ here — confirm
            # OpenAIAgent supports that; .chat(prompt) is the documented API.
            vector_response = st.session_state.chat_engine(prompt)
        with st.chat_message("assistant"):
            response_string = vector_response.response
            # Display the full response in the chat.
            st.markdown(response_string)
            # Append the response to the session state messages.
            st.session_state.messages.append(
                {"role": "assistant", "content": response_string}
            )