Skip to content

Commit 9258762

Browse files
committed
refactor code
1 parent 7d61b8f commit 9258762

File tree

3 files changed

+60
-177
lines changed

3 files changed

+60
-177
lines changed

eksbedrock/bedrockrag/Dockerfile

+4-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ RUN apt-get update && apt-get install -y \
99

1010
COPY ./ ./
1111

12-
RUN pip3 install -r requirements.txt
12+
# Install pinned app dependencies first, then pull current AWS SDK builds
# (the Bedrock agent-runtime API needs a recent botocore).
# --no-cache-dir on BOTH installs keeps the layer small; the original only
# applied it to the first install, leaving cached wheels in the image.
RUN pip3 install --no-cache-dir -r requirements.txt && \
    pip3 install --no-cache-dir --upgrade boto3 botocore
1314

1415
EXPOSE 8080
1516

16-
ENTRYPOINT ["python3", "bedrockrag.py"]
17+
CMD ["uvicorn", "bedrockrag:app", "--host", "0.0.0.0", "--port", "8080"]
18+

eksbedrock/bedrockrag/bedrockrag.py

+52-169
Original file line numberDiff line numberDiff line change
@@ -1,174 +1,57 @@
11
import json
22
import boto3
33
import os
4-
from requests_aws4auth import AWS4Auth
5-
6-
from langchain_aws import BedrockLLM
7-
from langchain_community.chat_models import BedrockChat
8-
from langchain.memory import ConversationBufferMemory
9-
from langchain_core.prompts import PromptTemplate
10-
from langchain.chains import ConversationalRetrievalChain
11-
from langchain_community.embeddings import BedrockEmbeddings
12-
from langchain_community.chat_message_histories import DynamoDBChatMessageHistory
13-
from opensearchpy import OpenSearch, RequestsHttpConnection
14-
from langchain_community.vectorstores import OpenSearchVectorSearch
15-
16-
if __name__ == "__main__":
17-
main()
18-
19-
def main():
20-
21-
#prompt = event["prompt"]
22-
prompt = os.environ['prompt']
23-
#bedrock_model_id = event["bedrock_model_id"]
24-
bedrock_model_id = os.environ['bedrock_model_id']
25-
#model_kwargs = event["model_kwargs"]
26-
#model_kwargs = os.environ['model_kwargs']
27-
#metadata = event["metadata"]
28-
metadata = os.environ['metadata']
29-
#memory_window = event["memory_window"]
30-
memory_window = os.environ['memory_window']
31-
#session_id = event["session_id"]
32-
session_id = os.environ['session_id']
33-
region = os.environ['AWS_REGION']
34-
35-
model_kwargs_json = {
36-
"temperature": 1.0,
37-
"top_p": 1.0,
38-
"top_k": 500
39-
}
40-
41-
model_kwargs = json.dumps(model_kwargs_json)
42-
43-
if "temperature" in model_kwargs and (model_kwargs["temperature"] < 0 or model_kwargs["temperature"] > 1):
44-
return {
45-
'statusCode': 400,
46-
'body': "Invalid input. temperature value must be between 0 and 1."
47-
}
48-
if "top_p" in model_kwargs and (model_kwargs["top_p"] < 0 or model_kwargs["top_p"] > 1):
49-
return {
50-
'statusCode': 400,
51-
'body': "Invalid input. top_p value must be between 0 and 1."
52-
}
53-
54-
# Check if top_k is between 0 and 1
55-
if "top_k" in model_kwargs and (model_kwargs["top_k"] < 0 or model_kwargs["top_k"] > 500):
56-
return {
57-
'statusCode': 400,
58-
'body': "Invalid input. top_k value must be between 0 and 500."
59-
}
60-
61-
os_host = os.environ['aoss_host']
62-
if not os_host:
63-
return {
64-
'statusCode': 400,
65-
'body': "Invalid input. os_host is empty."
4+
import textwrap
5+
from fastapi import FastAPI, HTTPException
6+
from pydantic import BaseModel
7+
from io import StringIO
8+
import sys
9+
10+
# Initialize FastAPI app
app = FastAPI()


# Boto3 client for AWS Bedrock agent runtime.
# Read the region from the standard AWS_REGION environment variable so the
# container is not pinned to one region; fall back to the previous
# hard-coded default ('us-east-1') for backward compatibility.
bedrock_agent_runtime = boto3.client(
    service_name="bedrock-agent-runtime",
    region_name=os.environ.get("AWS_REGION", "us-east-1"),
)
16+
17+
class QueryRequest(BaseModel):
    """Request payload accepted by the /query endpoint."""

    # Natural-language question to run against the knowledge base.
    prompt: str
    # Bedrock knowledge-base identifier supplied by the caller.
    kbId: str
21+
def print_ww(*args, width: int = 100, **kwargs):
    """Print *args* to stdout, word-wrapping each output line to *width* columns.

    Accepts the same positional and keyword arguments as the built-in
    ``print``.
    """
    # Render the arguments into a private buffer via print(file=...) instead
    # of temporarily reassigning sys.stdout (the original approach), which is
    # not thread-safe and can misroute output printed concurrently.
    buffer = StringIO()
    print(*args, file=buffer, **kwargs)
    for line in buffer.getvalue().splitlines():
        # textwrap.wrap returns [] for an empty line, so the join yields ""
        # and the blank line is preserved as a bare newline — same as before.
        print("\n".join(textwrap.wrap(line, width=width)))
33+
34+
def retrieve_and_generate(query: str, kb_id: str):
    """Run a retrieve-and-generate query against an AWS Bedrock knowledge base.

    Args:
        query: Natural-language question to answer from the knowledge base.
        kb_id: Bedrock knowledge-base ID to search. Falls back to the
            previously hard-coded default when empty.

    Returns:
        The generated answer text from the model.
    """
    # Bug fix: kb_id was accepted but silently ignored in favor of a
    # hard-coded knowledge-base ID, so the kbId sent by API callers had no
    # effect. Honor the caller's value, keeping the old ID as a fallback
    # for backward compatibility.
    knowledge_base_id = kb_id or 'PNMTFQRPDF'
    # NOTE(review): model ARN region is fixed to us-east-1, matching the
    # region the bedrock_agent_runtime client was created with — confirm if
    # the deployment ever moves regions.
    model_arn = 'arn:aws:bedrock:us-east-1::foundation-model/amazon.titan-text-premier-v1:0'
    response = bedrock_agent_runtime.retrieve_and_generate(
        input={'text': query},
        retrieveAndGenerateConfiguration={
            'type': 'KNOWLEDGE_BASE',
            'knowledgeBaseConfiguration': {
                'knowledgeBaseId': knowledge_base_id,
                'modelArn': model_arn,
            }
        }
    )
    return response['output']['text']
47+
48+
@app.post("/query")
async def handle_query(request: QueryRequest):
    """API endpoint to process user queries.

    Delegates to retrieve_and_generate and returns {"response": <text>};
    any failure surfaces as an HTTP 500 with the error message as detail.
    """
    try:
        answer = retrieve_and_generate(request.prompt, request.kbId)
    except Exception as exc:
        # Surface backend failures as a 500 rather than a raw traceback.
        raise HTTPException(status_code=500, detail=str(exc))
    return {"response": answer}
56+
57+
# Test changes
+4-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
boto3
2-
requests_aws4auth
3-
langchain_aws
4-
langchain
5-
langchain_community
6-
opensearch-py
1+
fastapi==0.100.0 # Web framework for the API
2+
uvicorn==0.22.0 # ASGI server to run FastAPI
3+
pydantic==1.10.0 # Data validation and serialization (NOTE(review): verify this pin — FastAPI 0.100.0 targets Pydantic v2; either upgrade pydantic to 2.x or pin an older fastapi)
4+
requests==2.31.0 # HTTP client (if needed for API calls)
# boto3/botocore are intentionally absent here: the Dockerfile installs the latest versions directly via `pip3 install --upgrade boto3 botocore`.

0 commit comments

Comments
 (0)