refactor: Add other bedrock models #432

Merged · 33 commits · Nov 13, 2024

Commits
f786e76
add langchain_integration folder; support lazy loading of models and chains
Oct 17, 2024
0e9cc92
modify
Oct 18, 2024
7d34b82
modify
Oct 18, 2024
b68322d
Merge branch 'dev' of https://github.com/aws-samples/Intelli-Agent in…
Oct 18, 2024
52cdcc8
refactor: unified tool to adapt to langchain's tool
Oct 24, 2024
d62f3a9
refactor: modify module import
Oct 24, 2024
e3f063e
refactor: add llm chain tool_calling_api and its prompt template
Oct 24, 2024
a4c99df
complete the tool refactor, next to test
Oct 29, 2024
541c3e7
fix: one rag tool should correspond to one index; refactor: add register…
Oct 31, 2024
ffb1f43
move langchain_integration into common_logic
Nov 2, 2024
4401f38
modify agent prompt; add python_repl tool; adapt to pydantic v2
Nov 2, 2024
e6e3066
remove lambda invoke in intention stage
Nov 3, 2024
6577e26
move retrievers to common_logic
Nov 3, 2024
d9b6851
move functions to __functions
Nov 5, 2024
cdb3e15
add lambda tool test
Nov 6, 2024
e526644
remove functions layer
Nov 7, 2024
bcc3742
fix bug in streaming
Nov 7, 2024
de41e80
add new model in ui
Nov 7, 2024
c94f7b1
modify logger, fix bug about inaccurate filename output
Nov 7, 2024
406a72f
remove llm_generate_utils
Nov 7, 2024
2442434
add CLAUDE_3_5_SONNET_V2 and CLAUDE_3_5_HAIKU models
Nov 7, 2024
882e17a
modify online requirements
Nov 7, 2024
57f5fdc
add enable_prefill parameter; optimize prompt
Nov 7, 2024
ec997ba
modify PythonREPL, fix bug running on lambda
Nov 8, 2024
6a45433
add llama-3.2
Nov 12, 2024
5171b31
merge from dev
Nov 12, 2024
5b622d7
modify agent prompt; add new intent logic
Nov 12, 2024
0e89072
debug intention logic
Nov 12, 2024
0686dc3
add sso example
Nov 12, 2024
75f194c
modify .viperlightignore
Nov 13, 2024
8f8df04
remove __functions __llm_generate_utils
Nov 13, 2024
7aa5eec
modify according to the pr comments
Nov 13, 2024
6976c5a
modify glue job requirements
Nov 13, 2024
1 change: 1 addition & 0 deletions .viperlightignore
@@ -13,6 +13,7 @@ api_test/test_data/*
api_test/gen-report-lambda.py
source/portal/src/utils/const.ts
source/lambda/online/lambda_main/test/main_local_test_retail.py
source/lambda/online/lambda_main/test/main_local_test_common.py
source/lambda/online/functions/retail_tools/lambda_product_information_search/product_information_search.py
source/lambda/job/test/prepare_data.py
README.md
40 changes: 20 additions & 20 deletions source/infrastructure/lib/chat/chat-stack.ts
@@ -69,7 +69,7 @@ export class ChatStack extends NestedStack implements ChatStackOutputs {
private lambdaOnlineAgent: Function;
private lambdaOnlineLLMGenerate: Function;
private chatbotTableName: string;
private lambdaOnlineFunctions: Function;
// private lambdaOnlineFunctions: Function;

constructor(scope: Construct, id: string, props: ChatStackProps) {
super(scope, id);
@@ -282,23 +282,23 @@ export class ChatStack extends NestedStack implements ChatStackOutputs {
this.lambdaOnlineLLMGenerate.addToRolePolicy(this.iamHelper.dynamodbStatement);


const lambdaOnlineFunctions = new LambdaFunction(this, "lambdaOnlineFunctions", {
runtime: Runtime.PYTHON_3_12,
handler: "lambda_tools.lambda_handler",
code: Code.fromAsset(
join(__dirname, "../../../lambda/online/functions/functions_utils"),
),
memorySize: 4096,
vpc: vpc,
securityGroups: securityGroups,
layers: [apiLambdaOnlineSourceLayer, apiLambdaJobSourceLayer],
environment: {
CHATBOT_TABLE: props.sharedConstructOutputs.chatbotTable.tableName,
INDEX_TABLE: this.indexTableName,
MODEL_TABLE: this.modelTableName,
},
});
this.lambdaOnlineFunctions = lambdaOnlineFunctions.function;
// const lambdaOnlineFunctions = new LambdaFunction(this, "lambdaOnlineFunctions", {
// runtime: Runtime.PYTHON_3_12,
// handler: "lambda_tools.lambda_handler",
// code: Code.fromAsset(
// join(__dirname, "../../../lambda/online/functions/functions_utils"),
// ),
// memorySize: 4096,
// vpc: vpc,
// securityGroups: securityGroups,
// layers: [apiLambdaOnlineSourceLayer, apiLambdaJobSourceLayer],
// environment: {
// CHATBOT_TABLE: props.sharedConstructOutputs.chatbotTable.tableName,
// INDEX_TABLE: this.indexTableName,
// MODEL_TABLE: this.modelTableName,
// },
// });
// this.lambdaOnlineFunctions = lambdaOnlineFunctions.function;

this.lambdaOnlineQueryPreprocess.grantInvoke(this.lambdaOnlineMain);

@@ -310,8 +310,8 @@ export class ChatStack extends NestedStack implements ChatStackOutputs {
this.lambdaOnlineLLMGenerate.grantInvoke(this.lambdaOnlineQueryPreprocess);
this.lambdaOnlineLLMGenerate.grantInvoke(this.lambdaOnlineAgent);

this.lambdaOnlineFunctions.grantInvoke(this.lambdaOnlineMain);
this.lambdaOnlineFunctions.grantInvoke(this.lambdaOnlineIntentionDetection);
// this.lambdaOnlineFunctions.grantInvoke(this.lambdaOnlineMain);
// this.lambdaOnlineFunctions.grantInvoke(this.lambdaOnlineIntentionDetection);

if (props.config.chat.amazonConnect.enabled) {
new ConnectConstruct(this, "connect-construct", {
@@ -233,7 +233,7 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac
"--PORTAL_BUCKET": this.uiPortalBucketName,
"--CHATBOT_TABLE": props.sharedConstructOutputs.chatbotTable.tableName,
"--additional-python-modules":
"langchain==0.1.11,beautifulsoup4==4.12.2,requests-aws4auth==1.2.3,boto3==1.28.84,openai==0.28.1,pyOpenSSL==23.3.0,tenacity==8.2.3,markdownify==0.11.6,mammoth==1.6.0,chardet==5.2.0,python-docx==1.1.0,nltk==3.8.1,pdfminer.six==20221105,smart-open==7.0.4,lxml==5.2.2,pandas==2.1.2,openpyxl==3.1.5,xlrd==2.0.1",
"langchain==0.1.11,beautifulsoup4==4.12.2,requests-aws4auth==1.2.3,boto3==1.28.84,openai==0.28.1,pyOpenSSL==23.3.0,tenacity==8.2.3,markdownify==0.11.6,mammoth==1.6.0,chardet==5.2.0,python-docx==1.1.0,nltk==3.8.1,pdfminer.six==20221105,smart-open==7.0.4,lxml==5.2.2,pandas==2.1.2,openpyxl==3.1.5,xlrd==2.0.1,langchain_community==0.3.5",
// Add multiple extra python files
"--extra-py-files": extraPythonFilesList
},
15 changes: 8 additions & 7 deletions source/lambda/job/dep/llm_bot_dep/sm_utils.py
@@ -1,11 +1,12 @@
import json
import io
from typing import Any, Dict, Iterator, List, Mapping, Optional
from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint
from langchain.embeddings import SagemakerEndpointEmbeddings, BedrockEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
from langchain_community.llms import SagemakerEndpoint
Collaborator:

Does the glue job's requirements.txt need to be updated as well? (langchain -> langchain_community)

Collaborator (Author):

Yes, we need to add langchain_community to its requirements.txt.
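For reference, that addition appears in the requirements.txt diff further down in this PR; the pinned line is:

langchain_community==0.3.5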

from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
from langchain_community.embeddings import SagemakerEndpointEmbeddings,BedrockEmbeddings
from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.utils import enforce_stop_tokens
from langchain_community.llms.utils import enforce_stop_tokens
from typing import Dict, List, Optional, Any,Iterator
from langchain_core.outputs import GenerationChunk
import boto3
@@ -234,12 +235,12 @@ def transform_output(self, output: bytes) -> str:
function. See `boto3`_. docs for more info.
.. _boto3: <https://boto3.amazonaws.com/v1/documentation/api/latest/index.html>
"""
content_type = "application/json"
accepts = "application/json"
content_type: str = "application/json"
accepts: str = "application/json"
class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid
extra = Extra.forbid.value

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
3 changes: 2 additions & 1 deletion source/lambda/job/dep/requirements.txt
@@ -16,4 +16,5 @@ opensearch-py==2.6.0
lxml==5.2.2
pandas==2.1.2
openpyxl==3.1.5
xlrd==2.0.1
xlrd==2.0.1
langchain_community==0.3.5
Binary file removed source/lambda/online/common_entry_agent_workflow.png
Binary file modified source/lambda/online/common_entry_workflow.png
19 changes: 14 additions & 5 deletions source/lambda/online/common_logic/common_utils/constant.py
@@ -82,17 +82,19 @@ class LLMTaskType(ConstantBase):
HYDE_TYPE = "hyde"
CONVERSATION_SUMMARY_TYPE = "conversation_summary"
RETAIL_CONVERSATION_SUMMARY_TYPE = "retail_conversation_summary"

MKT_CONVERSATION_SUMMARY_TYPE = "mkt_conversation_summary"
MKT_QUERY_REWRITE_TYPE = "mkt_query_rewrite"
STEPBACK_PROMPTING_TYPE = "stepback_prompting"
TOOL_CALLING = "tool_calling"
TOOL_CALLING_XML = "tool_calling_xml"
TOOL_CALLING_API = "tool_calling_api"
RETAIL_TOOL_CALLING = "retail_tool_calling"
RAG = "rag"
MTK_RAG = "mkt_rag"
CHAT = 'chat'
AUTO_EVALUATION = "auto_evaluation"



class MessageType(ConstantBase):
HUMAN_MESSAGE_TYPE = 'human'
AI_MESSAGE_TYPE = 'ai'
@@ -126,19 +128,26 @@ class LLMModelType(ConstantBase):
CLAUDE_2 = "anthropic.claude-v2"
CLAUDE_21 = "anthropic.claude-v2:1"
CLAUDE_3_HAIKU = "anthropic.claude-3-haiku-20240307-v1:0"
CLAUDE_3_5_HAIKU = "anthropic.claude-3-5-haiku-20241022-v1:0"
CLAUDE_3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0"
CLAUDE_3_5_SONNET = "anthropic.claude-3-5-sonnet-20240620-v1:0"
CLAUDE_3_5_SONNET_V2 = "anthropic.claude-3-5-sonnet-20241022-v2:0"
MIXTRAL_8X7B_INSTRUCT = "mistral.mixtral-8x7b-instruct-v0:1"
BAICHUAN2_13B_CHAT = "Baichuan2-13B-Chat-4bits"
INTERNLM2_CHAT_7B = "internlm2-chat-7b"
INTERNLM2_CHAT_20B = "internlm2-chat-20b"
GLM_4_9B_CHAT = "glm-4-9b-chat"
CHATGPT_35_TURBO = "gpt-3.5-turbo-0125"
CHATGPT_35_TURBO_0125 = "gpt-3.5-turbo-0125"
CHATGPT_4_TURBO = "gpt-4-turbo"
CHATGPT_4O = "gpt-4o"
QWEN2INSTRUCT7B = "qwen2-7B-instruct"
QWEN2INSTRUCT72B = "qwen2-72B-instruct"
QWEN15INSTRUCT32B = "qwen1_5-32B-instruct"
LLAMA3_1_70B_INSTRUCT = "meta.llama3-1-70b-instruct-v1:0"
LLAMA3_2_90B_INSTRUCT = "us.meta.llama3-2-90b-instruct-v1:0"
MISTRAL_LARGE_2407 = "mistral.mistral-large-2407-v1:0"
COHERE_COMMAND_R_PLUS = "cohere.command-r-plus-v1:0"



class EmbeddingModelType(ConstantBase):
@@ -170,13 +179,13 @@ class IndexTag(Enum):

@unique
class KBType(Enum):
AOS = "aos"

AOS = "aos"

GUIDE_INTENTION_NOT_FOUND = "Intention not found, please add intentions first when using agent mode, refer to https://amzn-chn.feishu.cn/docx/HlxvduJYgoOz8CxITxXc43XWn8e"
INDEX_DESC = "Answer question based on search result"


class Threshold(ConstantBase):
QQ_IN_RAG_CONTEXT = 0.5
INTENTION_ALL_KNOWLEDGE_RETRIEVAL = 0.4
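For readers unfamiliar with these identifiers: the newly added entries (CLAUDE_3_5_SONNET_V2, CLAUDE_3_5_HAIKU, LLAMA3_2_90B_INSTRUCT, MISTRAL_LARGE_2407, COHERE_COMMAND_R_PLUS) are Bedrock model IDs passed to the Bedrock runtime at invocation time. A minimal sketch of how such an ID is used, assuming default AWS credentials and Bedrock model access in the region — the client setup and prompt are illustrative, not part of this PR:

import boto3

# Hypothetical usage of one of the model IDs added in this PR.
client = boto3.client("bedrock-runtime", region_name="us-east-1")

response = client.converse(
    modelId="anthropic.claude-3-5-sonnet-20241022-v2:0",  # LLMModelType.CLAUDE_3_5_SONNET_V2
    messages=[{"role": "user", "content": [{"text": "Hello"}]}],
    inferenceConfig={"maxTokens": 256},
)
print(response["output"]["message"]["content"][0]["text"])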

@@ -3,18 +3,23 @@
import importlib
import json
import time
import os
from typing import Any, Dict, Optional, Callable, Union
import threading

import requests
from common_logic.common_utils.constant import StreamMessageType
from common_logic.common_utils.logger_utils import get_logger
from common_logic.common_utils.websocket_utils import is_websocket_request, send_to_ws_client
from langchain.pydantic_v1 import BaseModel, Field, root_validator
from pydantic import BaseModel, Field, model_validator


from .exceptions import LambdaInvokeError

logger = get_logger("lambda_invoke_utils")

# thread_local = threading.local()
thread_local = threading.local()
CURRENT_STATE = None

__FUNC_NAME_MAP = {
"query_preprocess": "Preprocess for Multi-round Conversation",
@@ -26,6 +31,38 @@
"llm_direct_results_generation": "LLM Response"
}


class StateContext:

def __init__(self,state):
self.state=state

@classmethod
def get_current_state(cls):
# print("thread id",threading.get_ident(),'parent id',threading.)
# state = getattr(thread_local,'state',None)
state = CURRENT_STATE
assert state is not None, "There is no valid state in the current context"
return state

@classmethod
def set_current_state(cls, state):
global CURRENT_STATE
assert CURRENT_STATE is None, "Parallel node executions are not allowed"
CURRENT_STATE = state

@classmethod
def clear_state(cls):
global CURRENT_STATE
CURRENT_STATE = None

def __enter__(self):
self.set_current_state(self.state)

def __exit__(self, exc_type, exc_val, exc_tb):
self.clear_state()
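StateContext publishes the current node's state through a module-level global so helpers deep in the call stack can read it without threading it through every signature; the assertion in set_current_state rejects overlapping executions, and the switch away from thread_local matches that single-execution assumption. A minimal usage sketch — the state dict is illustrative:

state = {"query": "hi", "trace_infos": []}
with StateContext(state):
    # Anywhere inside the block, helpers can recover the same object.
    assert StateContext.get_current_state() is state
# On exit the global is cleared; get_current_state() would now fail its assert.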


class LAMBDA_INVOKE_MODE(enum.Enum):
LAMBDA = "lambda"
LOCAL = "local"
@@ -55,26 +92,24 @@ class LambdaInvoker(BaseModel):
region_name: str = None
credentials_profile_name: Optional[str] = Field(default=None, exclude=True)

@root_validator()
@model_validator(mode="before")
def validate_environment(cls, values: Dict):
if values.get("client") is not None:
return values
try:
import boto3

try:
if values["credentials_profile_name"] is not None:
if values.get("credentials_profile_name") is not None:
session = boto3.Session(
profile_name=values["credentials_profile_name"]
)
else:
# use default credentials
session = boto3.Session()

values["client"] = session.client(
"lambda", region_name=values["region_name"]
"lambda",
region_name=values.get("region_name",os.environ['AWS_REGION'])
)

except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
@@ -97,8 +132,9 @@ def invoke_with_lambda(self, lambda_name: str, event_body: dict):
)
response_body = invoke_response["Payload"]
response_str = response_body.read().decode()

response_body = json.loads(response_str)
if "body" in response_body:
response_body = json.loads(response_body['body'])

if "errorType" in response_body:
error = (
Expand All @@ -108,7 +144,6 @@ def invoke_with_lambda(self, lambda_name: str, event_body: dict):
+ f"{response_body['errorType']}: {response_body['errorMessage']}"
)
raise LambdaInvokeError(error)

return response_body

def invoke_with_local(
@@ -285,7 +320,10 @@ def wrapper(state: Dict[str, Any]) -> Dict[str, Any]:
current_stream_use, ws_connection_id, enable_trace)
state['trace_infos'].append(
f"Enter: {func.__name__}, time: {time.time()}")
output = func(state)

with StateContext(state):
output = func(state)

current_monitor_infos = output.get(monitor_key, None)
if current_monitor_infos is not None:
send_trace(f"\n\n {current_monitor_infos}",
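One behavioral change in invoke_with_lambda above: the payload is now unwrapped a second time when the callee returns an API-Gateway-style envelope carrying its result as a JSON string under "body". A self-contained sketch of that decoding logic, with made-up payloads for both shapes:

import json

def unwrap_lambda_payload(response_str: str) -> dict:
    # First decode: the Lambda Payload stream is JSON text.
    body = json.loads(response_str)
    # Second decode: proxy-style handlers nest the real result as a JSON
    # string under "body"; plain handlers return the dict directly.
    if "body" in body:
        body = json.loads(body["body"])
    return body

print(unwrap_lambda_payload('{"answer": 42}'))                                      # {'answer': 42}
print(unwrap_lambda_payload('{"statusCode": 200, "body": "{\\"answer\\": 42}"}'))   # {'answer': 42}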
30 changes: 17 additions & 13 deletions source/lambda/online/common_logic/common_utils/logger_utils.py
@@ -1,4 +1,3 @@

import logging
import threading
import os
@@ -10,15 +9,13 @@
logger_lock = threading.Lock()


def cloud_print_wrapper(fn):
@wraps(fn)
def _inner(msg, *args, **kwargs):
class CloudStreamHandler(logging.StreamHandler):
def emit(self, record):
from common_logic.common_utils.lambda_invoke_utils import is_running_local
if not is_running_local:
# enable multiline as one message in cloudwatch
msg = msg.replace("\n", "\r")
return fn(msg, *args, **kwargs)
return _inner
record.msg = record.msg.replace("\n", "\r")
return super().emit(record)
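This refactor replaces the per-method cloud_print_wrapper patches with a single logging.StreamHandler subclass, so the newline rewrite happens once in emit() for every log level. In isolation the pattern looks like this — logger name and message are illustrative; the \n -> \r rewrite keeps a multiline message as one CloudWatch event:

import logging

class CRNewlineHandler(logging.StreamHandler):
    # Rewrite newlines so CloudWatch groups a multiline record as one event.
    def emit(self, record):
        record.msg = str(record.msg).replace("\n", "\r")
        return super().emit(record)

log = logging.getLogger("demo")
log.addHandler(CRNewlineHandler())
log.setLevel(logging.INFO)
log.info("line one\nline two")  # arrives as a single log event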


class Logger:
@@ -37,16 +34,11 @@ def _get_logger(
logger = logging.getLogger(name)
logger.propagate = 0
# Create a handler
c_handler = logging.StreamHandler()
c_handler = CloudStreamHandler()
formatter = logging.Formatter(format, datefmt=datefmt)
c_handler.setFormatter(formatter)
logger.addHandler(c_handler)
logger.setLevel(level)
logger.info = cloud_print_wrapper(logger.info)
logger.error = cloud_print_wrapper(logger.error)
logger.warning = cloud_print_wrapper(logger.warning)
logger.critical = cloud_print_wrapper(logger.critical)
logger.debug = cloud_print_wrapper(logger.debug)
cls.logger_map[name] = logger
return logger

@@ -72,3 +64,15 @@ def print_llm_messages(msg, logger=logger):
"ENABLE_PRINT_MESSAGES", 'True').lower() in ('true', '1', 't')
if enable_print_messages:
logger.info(msg)


def llm_messages_print_decorator(fn):
@wraps(fn)
def _inner(*args, **kwargs):
if args:
print_llm_messages(args)
if kwargs:
print_llm_messages(kwargs)
return fn(*args, **kwargs)
return _inner
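llm_messages_print_decorator logs a callable's positional and keyword arguments through print_llm_messages (gated by the ENABLE_PRINT_MESSAGES environment variable) before delegating to the wrapped function. A usage sketch — the wrapped function is invented for illustration:

def call_model(prompt, temperature=0.1):
    # Hypothetical target; any callable whose inputs should be logged works.
    return f"echo: {prompt}"

call_model = llm_messages_print_decorator(call_model)
call_model("hi there", temperature=0.3)
# Logs ("hi there",) and {"temperature": 0.3}, then returns "echo: hi there".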
