1
1
import json
2
2
import boto3
3
3
import os
4
- from requests_aws4auth import AWS4Auth
5
-
6
- from langchain_aws import BedrockLLM
7
- from langchain_community .chat_models import BedrockChat
8
- from langchain .memory import ConversationBufferMemory
9
- from langchain_core .prompts import PromptTemplate
10
- from langchain .chains import ConversationalRetrievalChain
11
- from langchain_community .embeddings import BedrockEmbeddings
12
- from langchain_community .chat_message_histories import DynamoDBChatMessageHistory
13
- from opensearchpy import OpenSearch , RequestsHttpConnection
14
- from langchain_community .vectorstores import OpenSearchVectorSearch
15
-
16
- if __name__ == "__main__" :
17
- main ()
18
-
19
- def main ():
20
-
21
- #prompt = event["prompt"]
22
- prompt = os .environ ['prompt' ]
23
- #bedrock_model_id = event["bedrock_model_id"]
24
- bedrock_model_id = os .environ ['bedrock_model_id' ]
25
- #model_kwargs = event["model_kwargs"]
26
- #model_kwargs = os.environ['model_kwargs']
27
- #metadata = event["metadata"]
28
- metadata = os .environ ['metadata' ]
29
- #memory_window = event["memory_window"]
30
- memory_window = os .environ ['memory_window' ]
31
- #session_id = event["session_id"]
32
- session_id = os .environ ['session_id' ]
33
- region = os .environ ['AWS_REGION' ]
34
-
35
- model_kwargs_json = {
36
- "temperature" : 1.0 ,
37
- "top_p" : 1.0 ,
38
- "top_k" : 500
39
- }
40
-
41
- model_kwargs = json .dumps (model_kwargs_json )
42
-
43
- if "temperature" in model_kwargs and (model_kwargs ["temperature" ] < 0 or model_kwargs ["temperature" ] > 1 ):
44
- return {
45
- 'statusCode' : 400 ,
46
- 'body' : "Invalid input. temperature value must be between 0 and 1."
47
- }
48
- if "top_p" in model_kwargs and (model_kwargs ["top_p" ] < 0 or model_kwargs ["top_p" ] > 1 ):
49
- return {
50
- 'statusCode' : 400 ,
51
- 'body' : "Invalid input. top_p value must be between 0 and 1."
52
- }
53
-
54
- # Check if top_k is between 0 and 1
55
- if "top_k" in model_kwargs and (model_kwargs ["top_k" ] < 0 or model_kwargs ["top_k" ] > 500 ):
56
- return {
57
- 'statusCode' : 400 ,
58
- 'body' : "Invalid input. top_k value must be between 0 and 500."
59
- }
60
-
61
- os_host = os .environ ['aoss_host' ]
62
- if not os_host :
63
- return {
64
- 'statusCode' : 400 ,
65
- 'body' : "Invalid input. os_host is empty."
4
+ import textwrap
5
+ from fastapi import FastAPI , HTTPException
6
+ from pydantic import BaseModel
7
+ from io import StringIO
8
+ import sys
9
+
10
# Initialize FastAPI app (module-level: created once at import time).
app = FastAPI()


# Boto3 client for the AWS Bedrock Agent Runtime service, used by
# retrieve_and_generate below.
# NOTE(review): the region is hardcoded to us-east-1 — confirm this matches
# the deployment region (the foundation-model ARN used elsewhere in this
# file is also pinned to us-east-1).
bedrock_agent_runtime = boto3.client(service_name="bedrock-agent-runtime", region_name='us-east-1')
16
+
17
class QueryRequest(BaseModel):
    """Request payload for the POST /query endpoint."""

    # Natural-language question forwarded to Bedrock.
    prompt: str
    # Knowledge-base id supplied by the caller.
    kbId: str
20
+
21
def print_ww(*args, width: int = 100, **kwargs):
    """Print *args* like ``print``, word-wrapping every output line to *width* columns.

    The underlying ``print`` call is captured into an in-memory buffer,
    split into lines, and each line is re-emitted wrapped via
    ``textwrap.wrap``.
    """
    sink = StringIO()
    saved_stdout = sys.stdout
    sys.stdout = sink
    try:
        print(*args, **kwargs)
        captured = sink.getvalue()
    finally:
        # Always restore the real stdout, even if print() raised.
        sys.stdout = saved_stdout
    wrapped = (textwrap.wrap(raw_line, width=width) for raw_line in captured.splitlines())
    for pieces in wrapped:
        print("\n".join(pieces))
33
+
34
def retrieve_and_generate(query: str, kb_id: str, model_arn: str = 'arn:aws:bedrock:us-east-1::foundation-model/amazon.titan-text-premier-v1:0'):
    """Query a Bedrock knowledge base and return the generated answer text.

    Args:
        query: Natural-language question sent to the knowledge base.
        kb_id: Id of the Bedrock knowledge base to retrieve from.
        model_arn: Foundation-model ARN used for answer generation.
            Defaults to Titan Text Premier in us-east-1 (the previous
            hardcoded value), so existing callers are unaffected.

    Returns:
        The generated answer text from the retrieve-and-generate response
        (``response['output']['text']``).

    Raises:
        Whatever the boto3 Bedrock Agent Runtime client raises on API
        failure (e.g. botocore ClientError) — propagated to the caller.
    """
    # BUG FIX: kb_id was previously ignored — the knowledge-base id was
    # hardcoded to 'PNMTFQRPDF', so the kbId supplied by the API caller
    # never took effect. It is now passed through.
    response = bedrock_agent_runtime.retrieve_and_generate(
        input={'text': query},
        retrieveAndGenerateConfiguration={
            'type': 'KNOWLEDGE_BASE',
            'knowledgeBaseConfiguration': {
                'knowledgeBaseId': kb_id,
                'modelArn': model_arn,
            }
        }
    )
    return response['output']['text']
47
+
48
@app.post("/query")
async def handle_query(request: QueryRequest):
    """API endpoint to process user queries.

    Forwards the prompt and knowledge-base id to Bedrock via
    ``retrieve_and_generate``; any failure is surfaced as HTTP 500
    carrying the error message.
    """
    try:
        answer = retrieve_and_generate(request.prompt, request.kbId)
    except Exception as exc:
        # Boundary catch: translate backend/AWS errors into an HTTP error
        # instead of an unhandled server exception.
        raise HTTPException(status_code=500, detail=str(exc))
    return {"response": answer}
56
+
57
+ # Test changes
0 commit comments