diff --git a/source/infrastructure/lib/knowledge-base/knowledge-base-stack.ts b/source/infrastructure/lib/knowledge-base/knowledge-base-stack.ts
index d630cd94..5bcafac5 100644
--- a/source/infrastructure/lib/knowledge-base/knowledge-base-stack.ts
+++ b/source/infrastructure/lib/knowledge-base/knowledge-base-stack.ts
@@ -138,6 +138,7 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac
   private createKnowledgeBaseJob(props: any) {
+    const deployRegion = props.config.deployRegion;
     const connection = new glue.Connection(this, "GlueJobConnection", {
       type: glue.ConnectionType.NETWORK,
       subnet: props.sharedConstructOutputs.vpc.privateSubnets[0],
@@ -159,10 +160,8 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac
     notificationLambda.addToRolePolicy(this.iamHelper.logStatement);
     notificationLambda.addToRolePolicy(this.dynamodbStatement);
 
-    // If this.region is cn-north-1 or cn-northwest-1, use the glue-job-script-cn.py
-    const glueJobScript = "glue-job-script.py";
-    // Assemble the extra python files list using _S3Bucket.s3UrlForObject("llm_bot_dep-0.1.0-py3-none-any.whl") and _S3Bucket.s3UrlForObject("nougat_ocr-0.1.17-py3-none-any.whl") and convert to string
+    // Assemble the extra Python files list using glueLibS3Bucket.s3UrlForObject("llm_bot_dep-0.1.0-py3-none-any.whl")
     const extraPythonFilesList = [
       this.glueLibS3Bucket.s3UrlForObject("llm_bot_dep-0.1.0-py3-none-any.whl"),
     ].join(",");
 
@@ -202,36 +201,42 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac
     glueRole.addToPolicy(this.dynamodbStatement);
     glueRole.addToPolicy(this.iamHelper.dynamodbStatement);
 
+    const glueJobDefaultArguments: { [key: string]: string } = {
+      "--AOS_ENDPOINT": this.aosDomainEndpoint,
+      "--REGION": deployRegion,
+      "--ETL_MODEL_ENDPOINT": props.modelConstructOutputs.defaultKnowledgeBaseModelName,
+      "--RES_BUCKET": this.glueResultBucket.bucketName,
+      "--ETL_OBJECT_TABLE": this.etlObjTableName || "-",
+      "--PORTAL_BUCKET": this.uiPortalBucketName,
+      "--CHATBOT_TABLE": props.sharedConstructOutputs.chatbotTable.tableName,
+      "--additional-python-modules":
+        "langchain==0.3.7,beautifulsoup4==4.12.2,requests-aws4auth==1.2.3,boto3==1.35.98,openai==0.28.1,pyOpenSSL==23.3.0,tenacity==8.2.3,markdownify==0.11.6,mammoth==1.6.0,chardet==5.2.0,python-docx==1.1.0,pdfminer.six==20221105,smart-open==7.0.4,opensearch-py==2.2.0,lxml==5.2.2,pandas==2.1.2,openpyxl==3.1.5,xlrd==2.0.1,langchain_community==0.3.5,pillow==10.0.1,tiktoken==0.8.0",
+      // Extra Python wheel dependencies attached to the job
+      "--extra-py-files": extraPythonFilesList,
+    };
+
+    // Set China-specific PyPI mirror for China regions
+    if (deployRegion === "cn-north-1" || deployRegion === "cn-northwest-1") {
+      glueJobDefaultArguments["--python-modules-installer-option"] = "-i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple";
+    }
+
     // Create glue job to process files specified in s3 bucket and prefix
-    const glueJob = new glue.Job(this, "PythonShellJob", {
-      executable: glue.JobExecutable.pythonShell({
-        glueVersion: glue.GlueVersion.V3_0,
-        pythonVersion: glue.PythonVersion.THREE_NINE,
+    const glueJob = new glue.Job(this, "PythonEtlJob", {
+      executable: glue.JobExecutable.pythonEtl({
+        glueVersion: glue.GlueVersion.V4_0,
+        pythonVersion: glue.PythonVersion.THREE,
         script: glue.Code.fromAsset(
-          join(__dirname, "../../../lambda/job", glueJobScript),
+          join(__dirname, "../../../lambda/job/glue-job-script.py"),
         ),
       }),
-      // Worker Type is not supported for Job Command pythonshell and Both workerType and workerCount must be set
-      // workerType: glue.WorkerType.G_2X,
-      // workerCount: 2,
+      // Glue ETL jobs are sized with workerType and workerCount (maxCapacity does not apply)
+      workerType: glue.WorkerType.G_1X,
+      workerCount: 2,
       maxConcurrentRuns: 200,
       maxRetries: 1,
       connections: [connection],
-      maxCapacity: 1,
       role: glueRole,
-      defaultArguments: {
-        "--AOS_ENDPOINT": this.aosDomainEndpoint,
-        "--REGION": process.env.CDK_DEFAULT_REGION || "-",
-        "--ETL_MODEL_ENDPOINT": props.modelConstructOutputs.defaultKnowledgeBaseModelName,
-        "--RES_BUCKET": this.glueResultBucket.bucketName,
-        "--ETL_OBJECT_TABLE": this.etlObjTableName || "-",
-        "--PORTAL_BUCKET": this.uiPortalBucketName,
-        "--CHATBOT_TABLE": props.sharedConstructOutputs.chatbotTable.tableName,
-        "--additional-python-modules":
-          "langchain==0.3.7,beautifulsoup4==4.12.2,requests-aws4auth==1.2.3,boto3==1.35.98,openai==0.28.1,pyOpenSSL==23.3.0,tenacity==8.2.3,markdownify==0.11.6,mammoth==1.6.0,chardet==5.2.0,python-docx==1.1.0,nltk==3.9.1,pdfminer.six==20221105,smart-open==7.0.4,opensearch-py==2.2.0,lxml==5.2.2,pandas==2.1.2,openpyxl==3.1.5,xlrd==2.0.1,langchain_community==0.3.5,pillow==10.0.1,tiktoken==0.8.0",
-        // Add multiple extra python files
-        "--extra-py-files": extraPythonFilesList
-      },
+      defaultArguments: glueJobDefaultArguments,
     });
 
     // Create SNS topic and subscription to notify when glue job is completed
@@ -308,7 +313,7 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac
           "--TABLE_ITEM_ID.$": "$.tableItemId",
           "--QA_ENHANCEMENT.$": "$.qaEnhance",
           "--REGION": process.env.CDK_DEFAULT_REGION || "-",
-          "--BEDROCK_REGION": props.config.chat.bedrockRegion,
+          "--BEDROCK_REGION": props.config.chat.bedrockRegion || "-",
           "--MODEL_TABLE": props.sharedConstructOutputs.modelTable.tableName,
           "--RES_BUCKET": this.glueResultBucket.bucketName,
           "--S3_BUCKET.$": "$.s3Bucket",
diff --git a/source/infrastructure/package.json b/source/infrastructure/package.json
index dab6a416..93ff798f 100644
--- a/source/infrastructure/package.json
+++ b/source/infrastructure/package.json
@@ -6,7 +6,7 @@
     "clobber": "npx projen clobber",
     "compile": "npx projen compile",
     "default": "npx projen default",
-    "deploy": "npx cdk deploy --all",
+    "deploy": "npx cdk deploy --all --require-approval never",
    "destroy": "npx projen destroy",
    "diff": "npx projen diff",
    "eject": "npx projen eject",
diff --git a/source/lambda/job/dep/dist/llm_bot_dep-0.1.0-py3-none-any.whl b/source/lambda/job/dep/dist/llm_bot_dep-0.1.0-py3-none-any.whl
index 588b7d59..265cddfd 100644
Binary files a/source/lambda/job/dep/dist/llm_bot_dep-0.1.0-py3-none-any.whl and b/source/lambda/job/dep/dist/llm_bot_dep-0.1.0-py3-none-any.whl differ
diff --git a/source/lambda/job/dep/llm_bot_dep/enhance_utils.py b/source/lambda/job/dep/llm_bot_dep/enhance_utils.py
index 303265dd..9f3a7d43 100644
--- a/source/lambda/job/dep/llm_bot_dep/enhance_utils.py
+++ b/source/lambda/job/dep/llm_bot_dep/enhance_utils.py
@@ -5,7 +5,6 @@
 from typing import Dict, List
 
 import boto3
-import nltk
 import openai
 from langchain.docstore.document import Document
 
@@ -127,7 +126,9 @@ def EnhanceWithClaude(
                 answer = line
                 qa_content = f"{question}\n{answer}"
                 enhanced_prompt_list.append(
-                    Document(page_content=qa_content, metadata=document.metadata)
+                    Document(
+                        page_content=qa_content, metadata=document.metadata
+                    )
                 )
                 question = ""
                 answer = ""
@@ -137,7 +138,11 @@ def EnhanceWithClaude(
         return enhanced_prompt_list
 
     def EnhanceWithOpenAI(
-        self, prompt: str, solution_title: str, document: Document, zh: bool = True
+        self,
+        prompt: str,
+        solution_title: str,
+        document: Document,
+        zh: bool = True,
     ) -> List[Dict[str, str]]:
         """
         Enhances a given prompt with additional information and performs a chat completion using OpenAI's GPT-3.5 Turbo model.
@@ -165,7 +170,10 @@ def EnhanceWithOpenAI(
         # error and retry handling for openai api due to request cap limit
         try:
             response = openai.ChatCompletion.create(
-                model="gpt-3.5-turbo", messages=messages, temperature=0, max_tokens=2048
+                model="gpt-3.5-turbo",
+                messages=messages,
+                temperature=0,
+                max_tokens=2048,
             )
         except Exception as e:
             logger.error("OpenAI API request failed: {}".format(e))
@@ -208,7 +216,9 @@ def SplitDocumentByTokenNum(
     - List[Document]: A list of documents, each containing a slice of the original document.
     """
     # Get the token number of input paragraph
-    tokens = nltk.word_tokenize(document.page_content)
+    # tokens = nltk.word_tokenize(document.page_content)
+    # TODO: tokenization is disabled for now (the nltk dependency was removed)
+    tokens = []
     metadata = document.metadata
     if "content_type" in metadata:
         metadata["content_type"] = "qa"
diff --git a/source/lambda/job/dep/llm_bot_dep/sm_utils.py b/source/lambda/job/dep/llm_bot_dep/sm_utils.py
index 910b9eec..faba8d7b 100644
--- a/source/lambda/job/dep/llm_bot_dep/sm_utils.py
+++ b/source/lambda/job/dep/llm_bot_dep/sm_utils.py
@@ -11,6 +11,7 @@
     BedrockEmbeddings,
     SagemakerEndpointEmbeddings,
 )
+from langchain_community.embeddings.openai import OpenAIEmbeddings
 from langchain_community.embeddings.sagemaker_endpoint import (
     EmbeddingsContentHandler,
 )
@@ -21,14 +22,11 @@
 from langchain_core.messages import BaseMessage
 from langchain_core.outputs import GenerationChunk
 from langchain_core.pydantic_v1 import Extra, root_validator
-from langchain_community.embeddings.openai import OpenAIEmbeddings
 
 logger = logging.getLogger()
 logger.setLevel(logging.INFO)
 
 session = boto3.session.Session()
-secret_manager_client = session.client(
-    service_name="secretsmanager"
-)
+secret_manager_client = session.client(service_name="secretsmanager")
 
 def get_model_details(group_name: str, chatbot_id: str, table_name: str):
@@ -48,14 +46,13 @@ def get_model_details(group_name: str, chatbot_id: str, table_name: str):
     try:
         response = table.get_item(
-            Key={
-                "groupName": group_name,
-                "modelId": model_id
-            }
+            Key={"groupName": group_name, "modelId": model_id}
         )
 
         if "Item" not in response:
-            raise Exception(f"No model found for group {group_name} and model ID {model_id}")
+            raise Exception(
+                f"No model found for group {group_name} and model ID {model_id}"
+            )
 
         return response["Item"]
 
     except Exception as e:
@@ -512,6 +509,7 @@ def SagemakerEndpointVectorOrCross(
     )
     return genericModel(prompt=prompt, stop=stop, **kwargs)
 
+
 def getCustomEmbeddings(
     endpoint_name: str,
     region_name: str,
@@ -519,7 +517,7 @@ def getCustomEmbeddings(
     model_type: str,
     group_name: str,
     chatbot_id: str,
-    model_table: str
+    model_table: str,
 ) -> SagemakerEndpointEmbeddings:
     embeddings = None
     model_details = get_model_details(group_name, chatbot_id, model_table)
@@ -531,8 +529,10 @@ def getCustomEmbeddings(
     if model_provider not in ["Bedrock API", "OpenAI API"]:
         # Use local models
         client = boto3.client("sagemaker-runtime", region_name=region_name)
-        bedrock_client = boto3.client("bedrock-runtime", region_name=bedrock_region)
         if model_type == "bedrock":
+            bedrock_client = boto3.client(
+                "bedrock-runtime", region_name=bedrock_region
+            )
             content_handler = BedrockEmbeddings()
             embeddings = BedrockEmbeddings(
                 client=bedrock_client,
@@ -567,15 +567,17 @@ def getCustomEmbeddings(
         embeddings = OpenAIEmbeddings(
             model=endpoint_name,
             api_key=get_secret_value(api_key_arn),
-            base_url=base_url
+            base_url=base_url,
         )
     elif model_provider == "OpenAI API":
         embeddings = OpenAIEmbeddings(
             model=endpoint_name,
             api_key=get_secret_value(api_key_arn),
-            base_url=base_url
+            base_url=base_url,
         )
     else:
-        raise ValueError(f"Unsupported API inference provider: {model_provider}")
+        raise ValueError(
+            f"Unsupported API inference provider: {model_provider}"
+        )
 
     return embeddings
diff --git a/source/lambda/job/dep/llm_bot_dep/splitter_utils.py b/source/lambda/job/dep/llm_bot_dep/splitter_utils.py
index 9f9fb0e7..277ecfa7 100644
--- a/source/lambda/job/dep/llm_bot_dep/splitter_utils.py
+++ b/source/lambda/job/dep/llm_bot_dep/splitter_utils.py
@@ -16,63 +16,63 @@
 logger.setLevel(logging.INFO)
 
 
-def _make_spacy_pipeline_for_splitting(pipeline: str) -> Any:  # avoid importing spacy
-    try:
-        import spacy
-    except ImportError:
-        raise ImportError("Spacy is not installed, please install it with `pip install spacy`.")
-    if pipeline == "sentencizer":
-        from spacy.lang.en import English
+# def _make_spacy_pipeline_for_splitting(pipeline: str) -> Any:  # avoid importing spacy
+#     try:
+#         import spacy
+#     except ImportError:
+#         raise ImportError("Spacy is not installed, please install it with `pip install spacy`.")
+#     if pipeline == "sentencizer":
+#         from spacy.lang.en import English
 
-        sentencizer = English()
-        sentencizer.add_pipe("sentencizer")
-    else:
-        sentencizer = spacy.load(pipeline, exclude=["ner", "tagger"])
-    return sentencizer
+#         sentencizer = English()
+#         sentencizer.add_pipe("sentencizer")
+#     else:
+#         sentencizer = spacy.load(pipeline, exclude=["ner", "tagger"])
+#     return sentencizer
 
 
-class NLTKTextSplitter(TextSplitter):
-    """Splitting text using NLTK package."""
+# class NLTKTextSplitter(TextSplitter):
+#     """Splitting text using NLTK package."""
 
-    def __init__(self, separator: str = "\n\n", language: str = "english", **kwargs: Any) -> None:
-        """Initialize the NLTK splitter."""
-        super().__init__(**kwargs)
-        try:
-            from nltk.tokenize import sent_tokenize
+#     def __init__(self, separator: str = "\n\n", language: str = "english", **kwargs: Any) -> None:
+#         """Initialize the NLTK splitter."""
+#         super().__init__(**kwargs)
+#         try:
+#             from nltk.tokenize import sent_tokenize
 
-            self._tokenizer = sent_tokenize
-        except ImportError:
-            raise ImportError("NLTK is not installed, please install it with `pip install nltk`.")
-        self._separator = separator
-        self._language = language
+#             self._tokenizer = sent_tokenize
+#         except ImportError:
+#             raise ImportError("NLTK is not installed, please install it with `pip install nltk`.")
+#         self._separator = separator
+#         self._language = language
 
-    def split_text(self, text: str) -> List[str]:
-        """Split incoming text and return chunks."""
-        # First we naively split the large input into a bunch of smaller ones.
-        splits = self._tokenizer(text, language=self._language)
-        return self._merge_splits(splits, self._separator)
+#     def split_text(self, text: str) -> List[str]:
+#         """Split incoming text and return chunks."""
+#         # First we naively split the large input into a bunch of smaller ones.
+#         splits = self._tokenizer(text, language=self._language)
+#         return self._merge_splits(splits, self._separator)
 
 
-class SpacyTextSplitter(TextSplitter):
-    """Splitting text using Spacy package.
+# class SpacyTextSplitter(TextSplitter):
+#     """Splitting text using Spacy package.
 
-    Per default, Spacy's `en_core_web_sm` model is used. For a faster, but
-    potentially less accurate splitting, you can use `pipeline='sentencizer'`.
- """ +# Per default, Spacy's `en_core_web_sm` model is used. For a faster, but +# potentially less accurate splitting, you can use `pipeline='sentencizer'`. +# """ - def __init__( - self, separator: str = "\n\n", pipeline: str = "en_core_web_sm", **kwargs: Any - ) -> None: - """Initialize the spacy text splitter.""" - super().__init__(**kwargs) - self._tokenizer = _make_spacy_pipeline_for_splitting(pipeline) - self._separator = separator +# def __init__( +# self, separator: str = "\n\n", pipeline: str = "en_core_web_sm", **kwargs: Any +# ) -> None: +# """Initialize the spacy text splitter.""" +# super().__init__(**kwargs) +# self._tokenizer = _make_spacy_pipeline_for_splitting(pipeline) +# self._separator = separator - def split_text(self, text: str) -> List[str]: - """Split incoming text and return chunks.""" - splits = (s.text for s in self._tokenizer(text).sents) - return self._merge_splits(splits, self._separator) +# def split_text(self, text: str) -> List[str]: +# """Split incoming text and return chunks.""" +# splits = (s.text for s in self._tokenizer(text).sents) +# return self._merge_splits(splits, self._separator) def find_parent(headers: dict, level: int): @@ -130,7 +130,11 @@ def find_child(headers: dict, header_id: str): level = headers[header_id]["level"] for id, header in headers.items(): - if header["level"] == level + 1 and id not in children and header["parent"] == header_id: + if ( + header["level"] == level + 1 + and id not in children + and header["parent"] == header_id + ): children.append(id) return children @@ -181,7 +185,9 @@ def extract_headings(md_content: str): for header_obj in headers: headers[header_obj]["child"] = find_child(headers, header_obj) - headers[header_obj]["next"] = find_next_with_same_level(headers, header_obj) + headers[header_obj]["next"] = find_next_with_same_level( + headers, header_obj + ) return headers, id_index_dict @@ -232,7 +238,10 @@ def _set_chunk_id( else: # Move one step to get the next chunk_id same_heading_dict[current_heading] += 1 - if len(id_index_dict[current_heading]) > same_heading_dict[current_heading]: + if ( + len(id_index_dict[current_heading]) + > same_heading_dict[current_heading] + ): metadata["chunk_id"] = id_index_dict[current_heading][ same_heading_dict[current_heading] ] @@ -240,7 +249,9 @@ def _set_chunk_id( id_prefix = str(uuid.uuid4())[:8] metadata["chunk_id"] = f"$0-{id_prefix}" - def _get_current_heading_list(self, current_heading, current_heading_level_map): + def _get_current_heading_list( + self, current_heading, current_heading_level_map + ): try: title_symble_count = 0 for char in current_heading: @@ -263,9 +274,13 @@ def _get_current_heading_list(self, current_heading, current_heading_level_map): def split_text(self, text: Document) -> List[Document]: if self.res_bucket is not None: - save_content_to_s3(s3, text, self.res_bucket, SplittingType.BEFORE.value) + save_content_to_s3( + s3, text, self.res_bucket, SplittingType.BEFORE.value + ) else: - logger.warning("No resource bucket is defined, skip saving content into S3 bucket") + logger.warning( + "No resource bucket is defined, skip saving content into S3 bucket" + ) lines = text.page_content.strip().split("\n") chunks = [] @@ -275,7 +290,9 @@ def split_text(self, text: Document) -> List[Document]: inside_figure = False have_figure = False figure_metadata = [] - heading_hierarchy, id_index_dict = extract_headings(text.page_content.strip()) + heading_hierarchy, id_index_dict = extract_headings( + text.page_content.strip() + ) if len(lines) > 0: 
             current_heading = lines[0]
@@ -299,7 +316,10 @@ def split_text(self, text: Document) -> List[Document]:
 
                 try:
                     self._set_chunk_id(
-                        id_index_dict, current_heading, metadata, same_heading_dict
+                        id_index_dict,
+                        current_heading,
+                        metadata,
+                        same_heading_dict,
                     )
                 except KeyError:
                     logger.info(
@@ -308,7 +328,9 @@ def split_text(self, text: Document) -> List[Document]:
                     id_prefix = str(uuid.uuid4())[:8]
                     metadata["chunk_id"] = f"$0-{id_prefix}"
                 if metadata["chunk_id"] in heading_hierarchy:
-                    metadata["heading_hierarchy"] = heading_hierarchy[metadata["chunk_id"]]
+                    metadata["heading_hierarchy"] = heading_hierarchy[
+                        metadata["chunk_id"]
+                    ]
                 page_content = "\n".join(current_chunk_content)
                 metadata["complete_heading"] = current_heading_list
                 if have_figure:
@@ -340,9 +362,9 @@ def split_text(self, text: Document) -> List[Document]:
                 figure_description = xml_node.find(FigureNode.DESCRIPTION.value)
                 figure_value = xml_node.find(FigureNode.VALUE.value)
                 figure_s3_link = xml_node.findtext(FigureNode.LINK.value)
-                chunk_figure_content = etree.tostring(figure_description, encoding="utf-8").decode(
-                    "utf-8"
-                )
+                chunk_figure_content = etree.tostring(
+                    figure_description, encoding="utf-8"
+                ).decode("utf-8")
                 if figure_value is not None:
                     chunk_figure_content += "\n" + etree.tostring(
                         figure_value, encoding="utf-8"
@@ -370,13 +392,17 @@ def split_text(self, text: Document) -> List[Document]:
                 )
             current_heading = current_heading.replace("#", "").strip()
             try:
-                self._set_chunk_id(id_index_dict, current_heading, metadata, same_heading_dict)
+                self._set_chunk_id(
+                    id_index_dict, current_heading, metadata, same_heading_dict
+                )
             except KeyError:
                 logger.info(f"No standard heading found")
                 id_prefix = str(uuid.uuid4())[:8]
                 metadata["chunk_id"] = f"$0-{id_prefix}"
             if metadata["chunk_id"] in heading_hierarchy:
-                metadata["heading_hierarchy"] = heading_hierarchy[metadata["chunk_id"]]
+                metadata["heading_hierarchy"] = heading_hierarchy[
+                    metadata["chunk_id"]
+                ]
             page_content = "\n".join(current_chunk_content)
             metadata["complete_heading"] = current_heading_list
             if have_figure:
diff --git a/source/lambda/job/dep/setup.py b/source/lambda/job/dep/setup.py
index 0de856a3..93f1e138 100644
--- a/source/lambda/job/dep/setup.py
+++ b/source/lambda/job/dep/setup.py
@@ -6,7 +6,10 @@
     version="0.1.0",
     packages=find_packages(exclude=[]),
     package_data={
-        "": ["*.txt", "*.json"],  # include all .txt and .json files in any package
+        "": [
+            "*.txt",
+            "*.json",
+        ],  # include all .txt and .json files in any package
         # Or if you want to be more specific:
         # "your_package_name": ["data/*.txt", "config/*.json"]
     },
@@ -23,10 +26,9 @@
         "mammoth==1.6.0",
         "chardet==5.2.0",
         "python-docx==1.1.0",
-        "nltk==3.9.1",
         "pdfminer.six==20221105",
         "smart-open==7.0.4",
         "pillow==10.0.1",
-        "tiktoken==0.8.0"
+        "tiktoken==0.8.0",
     ],
 )
diff --git a/source/lambda/job/glue-job-script.py b/source/lambda/job/glue-job-script.py
index bf990650..c3d00f38 100644
--- a/source/lambda/job/glue-job-script.py
+++ b/source/lambda/job/glue-job-script.py
@@ -9,7 +9,6 @@
 
 import boto3
 import chardet
-import nltk
 from langchain.docstore.document import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import OpenSearchVectorSearch
@@ -23,73 +22,39 @@
 logger = logging.getLogger()
 logger.setLevel(logging.INFO)
 
-try:
-    from awsglue.utils import getResolvedOptions
-
-    args = getResolvedOptions(
-        sys.argv,
-        [
-            "AOS_ENDPOINT",
-            "BATCH_FILE_NUMBER",
-            "BATCH_INDICE",
-            "DOCUMENT_LANGUAGE",
-            "EMBEDDING_MODEL_ENDPOINT",
"ETL_MODEL_ENDPOINT", - "JOB_NAME", - "OFFLINE", - "ETL_OBJECT_TABLE", - "TABLE_ITEM_ID", - "QA_ENHANCEMENT", - "REGION", - "RES_BUCKET", - "S3_BUCKET", - "S3_PREFIX", - "CHATBOT_ID", - "INDEX_ID", - "EMBEDDING_MODEL_TYPE", - "CHATBOT_TABLE", - "INDEX_TYPE", - "OPERATION_TYPE", - "PORTAL_BUCKET", - "BEDROCK_REGION", - "MODEL_TABLE", - "GROUP_NAME", - ], - ) -except Exception as e: - logger.warning("Running locally") - import argparse - - parser = argparse.ArgumentParser(description="local ingestion parameters") - parser.add_argument("--offline", type=bool, default=True) - parser.add_argument("--batch_indice", type=int, default=0) - parser.add_argument("--batch_file_number", type=int, default=1000) - parser.add_argument("--document_language", type=str, default="zh") - parser.add_argument("--embedding_model_endpoint", type=str, required=True) - parser.add_argument("--table_item_id", type=str, default="x") - parser.add_argument("--qa_enhancement", type=str, default=False) - parser.add_argument("--s3_bucket", type=str, required=True) - parser.add_argument("--s3_prefix", type=str, required=True) - parser.add_argument("--chatbot_id", type=str, required=True) - parser.add_argument("--index_id", type=str, required=True) - parser.add_argument("--embedding_model_type", type=str, required=True) - parser.add_argument("--index_type", type=str, required=True) - parser.add_argument("--operation_type", type=str, default="create") - command_line_args = parser.parse_args() - sys.path.append("dep") - command_line_args_dict = vars(command_line_args) - args = {} - for key in command_line_args_dict.keys(): - args[key.upper()] = command_line_args_dict[key] - args["AOS_ENDPOINT"] = os.environ["AOS_ENDPOINT"] - args["CHATBOT_TABLE"] = os.environ["CHATBOT_TABLE_NAME"] - args["ETL_OBJECT_TABLE"] = os.environ["ETL_OBJECT_TABLE_NAME"] - args["ETL_MODEL_ENDPOINT"] = os.environ["ETL_ENDPOINT"] - args["RES_BUCKET"] = os.environ["RES_BUCKET"] - args["REGION"] = os.environ["REGION"] - args["BEDROCK_REGION"] = os.environ["BEDROCK_REGION"] - args["CROSS_ACCOUNT_BEDROCK_KEY"] = os.environ["CROSS_ACCOUNT_BEDROCK_KEY"] - args["PORTAL_BUCKET"] = os.environ.get("PORTAL_BUCKET", None) + +from awsglue.utils import getResolvedOptions + +args = getResolvedOptions( + sys.argv, + [ + "AOS_ENDPOINT", + "BATCH_FILE_NUMBER", + "BATCH_INDICE", + "DOCUMENT_LANGUAGE", + "EMBEDDING_MODEL_ENDPOINT", + "ETL_MODEL_ENDPOINT", + "JOB_NAME", + "OFFLINE", + "ETL_OBJECT_TABLE", + "TABLE_ITEM_ID", + "QA_ENHANCEMENT", + "REGION", + "RES_BUCKET", + "S3_BUCKET", + "S3_PREFIX", + "CHATBOT_ID", + "INDEX_ID", + "EMBEDDING_MODEL_TYPE", + "CHATBOT_TABLE", + "INDEX_TYPE", + "OPERATION_TYPE", + "PORTAL_BUCKET", + "BEDROCK_REGION", + "MODEL_TABLE", + "GROUP_NAME", + ], +) from llm_bot_dep import sm_utils from llm_bot_dep.constant import SplittingType @@ -99,7 +64,6 @@ # Adaption to allow nougat to run in AWS Glue with writable /tmp os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache" os.environ["NOUGAT_CHECKPOINT"] = "/tmp/nougat_checkpoint" -os.environ["NLTK_DATA"] = "/tmp/nltk_data" # Parse arguments if "BATCH_INDICE" not in args: @@ -145,8 +109,6 @@ credentials = boto3.Session().get_credentials() MAX_OS_DOCS_PER_PUT = 8 -nltk.data.path.append("/tmp/nltk_data") - def get_aws_auth(): try: @@ -598,7 +560,9 @@ def ingestion_pipeline( SplittingType.SEMANTIC.value, ) - gen_chunk_flag = False if file_type in ["csv", "xlsx", "xls"] else True + gen_chunk_flag = ( + False if file_type in ["csv", "xlsx", "xls"] else True + ) batches = 
         batches = batch_chunk_processor.batch_generator(res, gen_chunk_flag)
 
         for batch in batches:
@@ -744,7 +708,7 @@ def main():
             model_type=embedding_model_type,
             group_name=group_name,
             chatbot_id=chatbot_id,
-            model_table=model_table
+            model_table=model_table,
         )
         aws_auth = get_aws_auth()
         docsearch = OpenSearchVectorSearch(
@@ -789,12 +753,4 @@
 
 if __name__ == "__main__":
     logger.info("boto3 version: %s", boto3.__version__)
-    # Set the NLTK data path to the /tmp directory for AWS Glue jobs
-    nltk.data.path.append("/tmp")
-    # List of NLTK packages to download
-    nltk_packages = ["words", "punkt"]
-    # Download the required NLTK packages to /tmp
-    for package in nltk_packages:
-        # Download the package to /tmp/nltk_data
-        nltk.download(package, download_dir="/tmp/nltk_data")
     main()