diff --git a/source/model/etl/code/aikits_utils.py b/source/model/etl/code/aikits_utils.py
deleted file mode 100644
index 5ba1f5dc2..000000000
--- a/source/model/etl/code/aikits_utils.py
+++ /dev/null
@@ -1,50 +0,0 @@
-from io import BytesIO
-import boto3
-import base64
-import numpy as np
-from PIL import Image
-import cv2
-try:
-    import urllib.request as urllib2
-    from urllib.parse import urlparse
-except ImportError:
-    import urllib2
-    from urlparse import urlparse
-
-def readimg(body, keys=None):
-    if keys is None:
-        keys = body.keys()
-    inputs = dict()
-    for key in keys:
-        try:
-            if key.startswith('url'):  # url形式
-                if body[key].startswith('http'):  # http url
-                    image_string = urllib2.urlopen(body[key]).read()
-                elif body[key].startswith('s3'):  # s3 key
-                    o = urlparse(body[key])
-                    bucket = o.netloc
-                    path = o.path.lstrip('/')
-                    s3 = boto3.resource('s3')
-                    img_obj = s3.Object(bucket, path)
-                    image_string = img_obj.get()['Body'].read()
-                else:
-                    raise
-            elif key.startswith('img'):  # base64形式
-                image_string = base64.b64decode(body[key])
-            else:
-                raise
-            inputs[key] = np.array(Image.open(BytesIO(image_string)).convert('RGB'))[:, :, :3]
-        except:
-            inputs[key] = None
-    return inputs
-
-def lambda_return(statusCode, body):
-    return {
-        'statusCode': statusCode,
-        'headers': {
-            'Access-Control-Allow-Headers': '*',
-            'Access-Control-Allow-Origin': '*',
-            'Access-Control-Allow-Methods': '*'
-        },
-        'body': body
-    }
\ No newline at end of file
diff --git a/source/model/etl/code/config.py b/source/model/etl/code/config.py
new file mode 100644
index 000000000..d155ac9b9
--- /dev/null
+++ b/source/model/etl/code/config.py
@@ -0,0 +1,75 @@
+import os
+
+class ModelConfig:
+    """Manages per-language model files and post-processing settings."""
+
+    LANG_CONFIGS = {
+        'ch': {
+            'det': {
+                'model': 'det_cn.onnx',
+                'postprocess': {
+                    'name': 'DBPostProcess',
+                    'thresh': 0.1,
+                    'box_thresh': 0.1,
+                    'max_candidates': 1000,
+                    'unclip_ratio': 1.5,
+                    'use_dilation': False,
+                    'score_mode': 'fast',
+                    'box_type': 'quad'
+                }
+            },
+            'rec': {
+                'model': 'rec_ch.onnx',
+                'dict_path': 'ppocr_keys_v1.txt',
+                'postprocess': {
+                    'name': 'CTCLabelDecode',
+                    'character_type': 'ch',
+                    'use_space_char': True
+                }
+            }
+        },
+        'en': {
+            'det': {
+                'model': 'det_en.onnx',
+                'postprocess': {
+                    'name': 'DBPostProcess',
+                    'thresh': 0.1,
+                    'box_thresh': 0.1,
+                    'max_candidates': 1000,
+                    'unclip_ratio': 1.5,
+                    'use_dilation': False,
+                    'score_mode': 'fast',
+                    'box_type': 'quad'
+                }
+            },
+            'rec': {
+                'model': 'rec_en.onnx',
+                'dict_path': 'en_dict.txt',
+                'postprocess': {
+                    'name': 'CTCLabelDecode',
+                    'use_space_char': True
+                }
+            }
+        }
+    }
+
+    @classmethod
+    def get_model_path(cls, lang, model_type):
+        """Return the model file path for the given language and model type."""
+        return os.path.join(
+            os.environ['MODEL_PATH'],
+            cls.LANG_CONFIGS[lang][model_type]['model']
+        )
+
+    @classmethod
+    def get_dict_path(cls, lang):
+        """Return the character dictionary path for the given language."""
+        return os.path.join(
+            os.environ['MODEL_PATH'],
+            cls.LANG_CONFIGS[lang]['rec']['dict_path']
+        )
+
+    @classmethod
+    def get_postprocess_config(cls, lang, model_type):
+        """Return the post-processing config for the given language and model type."""
+        return cls.LANG_CONFIGS[lang][model_type]['postprocess']
\ No newline at end of file
diff --git a/source/model/etl/code/figure_llm.py b/source/model/etl/code/figure_llm.py
index 078b21640..ec117aaba 100644
--- a/source/model/etl/code/figure_llm.py
+++ b/source/model/etl/code/figure_llm.py
@@ -5,15 +5,74 @@
 import io
 import base64
 import json
+import logging
+import os
+import openai
+
+# Add logger configuration
+logger = logging.getLogger(__name__)
 
 class figureUnderstand():
-    def __init__(self):
-        self.bedrock_runtime = boto3.client(service_name='bedrock-runtime')
+    def __init__(self, model_provider="bedrock", api_secret_name=None):
+        self.model_provider = model_provider
+        if model_provider == "bedrock":
+            self.bedrock_runtime = boto3.client(service_name='bedrock-runtime')
+        elif model_provider == "openai":
+            self.openai_api_key = self._get_api_key(api_secret_name)
+            if not self.openai_api_key:
+                raise ValueError("Failed to retrieve OpenAI API key from Secrets Manager")
+            openai.api_key = self.openai_api_key
+            # Set OpenAI base URL from environment variable if provided
+            base_url = os.environ.get("OPENAI_API_BASE")
+            if base_url:
+                openai.base_url = base_url
+        else:
+            raise ValueError("Unsupported model provider. Choose 'bedrock' or 'openai'")
+
         self.mermaid_prompt = json.load(open('prompt/mermaid.json', 'r'))
-    def invoke_llm(self, img, prompt, prefix="<output>", stop="</output>"):
+
+    def _get_api_key(self, api_secret_name):
+        """
+        Get the API key from AWS Secrets Manager.
+        Args:
+            api_secret_name (str): The name of the secret in AWS Secrets Manager containing the API key.
+        Returns:
+            str: The API key.
+        """
+        if not api_secret_name:
+            raise ValueError("api_secret_name must be provided when using OpenAI")
+
+        try:
+            secrets_client = boto3.client("secretsmanager")
+            secret_response = secrets_client.get_secret_value(
+                SecretId=api_secret_name
+            )
+            if "SecretString" in secret_response:
+                secret_data = json.loads(secret_response["SecretString"])
+                api_key = secret_data.get("api_key")
+                logger.info(
+                    f"Successfully retrieved API key from secret: {api_secret_name}"
+                )
+                return api_key
+        except Exception as e:
+            logger.error(f"Error retrieving secret {api_secret_name}: {str(e)}")
+            raise
+        return None
+
+    def _image_to_base64(self, img):
+        """Convert PIL Image to base64 encoded string"""
         image_stream = io.BytesIO()
         img.save(image_stream, format="JPEG")
-        base64_encoded = base64.b64encode(image_stream.getvalue()).decode('utf-8')
+        return base64.b64encode(image_stream.getvalue()).decode('utf-8')
+
+    def invoke_llm(self, img, prompt, prefix="<output>", stop="</output>"):
+        if self.model_provider == "bedrock":
+            return self._invoke_bedrock(img, prompt, prefix, stop)
+        elif self.model_provider == "openai":
+            return self._invoke_openai(img, prompt, prefix, stop)
+
+    def _invoke_bedrock(self, img, prompt, prefix="", stop=""):
+        base64_encoded = self._image_to_base64(img)
         messages = [
             {
                 "role": "user",
@@ -34,7 +93,7 @@ def invoke_llm(self, img, prompt, prefix="<output>", stop="</output>"):
             },
             {"role": "assistant", "content": prefix},
         ]
-        model_id = "anthropic.claude-3-sonnet-20240229-v1:0"
+        model_id = "anthropic.claude-3-5-sonnet-20241022-v2:0"
         body = json.dumps({
             "anthropic_version": "bedrock-2023-05-31",
             "max_tokens": 4096,
@@ -46,51 +105,67 @@ def invoke_llm(self, img, prompt, prefix="<output>", stop="</output>"):
         response_body = json.loads(response.get('body').read())
         result = prefix + response_body['content'][0]['text'] + stop
         return result
+
+    def _invoke_openai(self, img, prompt, prefix="", stop=""):
+        base64_encoded = self._image_to_base64(img)
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": prompt
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{base64_encoded}"
+                        }
+                    }
+                ]
+            },
+            {
+                "role": "assistant",
+                "content": prefix
+            }
+        ]
+
+        response = openai.chat.completions.create(
+            model="gpt-4-vision-preview",
+            messages=messages,
+            max_tokens=4096,
+            stop=[stop] if stop else None
+        )
+
+        result = prefix + response.choices[0].message.content + stop
+        return result
+
     def get_classification(self, img):
         with open('prompt/figure_classification.txt') as f:
             figure_classification_prompt = f.read()
         output = self.invoke_llm(img, figure_classification_prompt)
         return output
+
     def get_chart(self, img, context, tag):
-        prompt = '''您是文档阅读专家。您的任务是将图片中的图表转换成Markdown格式。以下是说明:
-1. 找到图片中的图表。
-2. 仔细观察图表,了解其中包含的结构和数据。
-3. 使用<context>标签中的上下文信息来帮助你更好地理解和描述这张图表。上下文中的{tag}就是指该图表。
-4. 按照以下指南将图表数据转换成 Markdown 表格格式:
-    - 使用 | 字符分隔列
-    - 使用 --- 行表示标题行
-    - 确保表格格式正确并对齐
-    - 对于不确定的数字,请根据图片估算。
-5. 仔细检查您的 Markdown 表格是否准确地反映了图表图像中的数据。
-6. 在 <output></output> xml 标签中仅返回 Markdown,不含其他文本。
-
-<context>
-{context}
-</context>
-请将你的描述写在<output></output>xml标签之间。
-'''.strip()
-        output = self.invoke_llm(img, prompt)
+        with open('prompt/chart.txt') as f:
+            chart_prompt = f.read()
+        output = self.invoke_llm(img, chart_prompt.format(context=context, tag=tag))
         return output
-    def get_description(self, img, context, tag):
-        prompt = '''
-你是一位资深的图像分析专家。你的任务是仔细观察给出的插图,并按照以下步骤进行:
-1. 清晰地描述这张图片中显示的内容细节。如果图片中包含任何文字,请确保在描述中准确无误地包含这些文字。
-2. 使用<context>标签中的上下文信息来帮助你更好地理解和描述这张图片。上下文中的{tag}就是指该插图。
-3. 将你的描述写在<output></output>标签之间。
-<context>
-{context}
-</context>
-请将你的描述写在<output></output>xml标签之间。
-'''.strip()
-        output = self.invoke_llm(img, prompt.format(context=context, tag=tag))
+
+    def get_description(self, img, context, tag):
+        with open('prompt/description.txt') as f:
+            description_prompt = f.read()
+        output = self.invoke_llm(img, description_prompt.format(context=context, tag=tag))
         return f'![{output}]()'
+
     def get_mermaid(self, img, classification):
         with open('prompt/mermaid_template.txt') as f:
             mermaid_prompt = f.read()
         prompt = mermaid_prompt.format(diagram_type=classification, diagram_example=self.mermaid_prompt[classification])
         output = self.invoke_llm(img, prompt, prefix='<description>', stop='</mermaid>')
         return output
+
     def parse_result(self, llm_output, tag):
         try:
             pattern = fr"<{tag}>(.*?)</{tag}>"
@@ -98,6 +173,7 @@ def parse_result(self, llm_output, tag):
         except:
             output = llm_output.replace(f"<{tag}>", '').replace(f"</{tag}>", '')
         return output
+
     def __call__(self, img, context, tag, s3_link):
         classification = self.get_classification(img)
         classification = self.parse_result(classification, 'output')
diff --git a/source/model/etl/code/gpu_config.py b/source/model/etl/code/gpu_config.py
new file mode 100644
index 000000000..02b1df65e
--- /dev/null
+++ b/source/model/etl/code/gpu_config.py
@@ -0,0 +1,13 @@
+import GPUtil
+
+def get_provider_config():
+    if len(GPUtil.getGPUs()):
+        provider = [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "HEURISTIC"}), "CPUExecutionProvider"]
+        rec_batch_num = 6
+        layout_model = 'layout.onnx'
+    else:
+        provider = ["CPUExecutionProvider"]
+        rec_batch_num = 1
+        layout_model = 'layout_s.onnx'
+
+    return provider, rec_batch_num, layout_model
\ No newline at end of file
diff --git a/source/model/etl/code/imaug/__init__.py b/source/model/etl/code/imaug/__init__.py
index 461a24738..419b5c999 100644
--- a/source/model/etl/code/imaug/__init__.py
+++ b/source/model/etl/code/imaug/__init__.py
@@ -4,6 +4,10 @@ from __future__ import unicode_literals
 from .operators import *
 from .table_ops import *
+from .preprocess import preprocess
+
+__all__ = ["preprocess"]
+
 def transform(data, ops=None):
     """ transform """
     if ops is None:
diff --git a/source/model/etl/code/imaug/preprocess.py b/source/model/etl/code/imaug/preprocess.py
new file mode 100644
index 000000000..1a55e087e
--- /dev/null
+++ b/source/model/etl/code/imaug/preprocess.py
@@ -0,0 +1,35 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import cv2
+import numpy as np
+
+__all__ = ["preprocess"]
+
+def preprocess(img, input_size, swap=(2, 0, 1)):
+    """Preprocess image for model input.
+
+    Args:
+        img: Input image
+        input_size: Target size (height, width)
+        swap: Channel swap order
+
+    Returns:
+        Preprocessed image and scale ratio
+    """
+    if len(img.shape) == 3:
+        padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
+    else:
+        padded_img = np.ones(input_size, dtype=np.uint8) * 114
+
+    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
+    resized_img = cv2.resize(
+        img,
+        (int(img.shape[1] * r), int(img.shape[0] * r)),
+        interpolation=cv2.INTER_LINEAR,
+    ).astype(np.uint8)
+    padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
+
+    padded_img = padded_img.transpose(swap)
+    padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
+    return padded_img, r
\ No newline at end of file
diff --git a/source/model/etl/code/layout.py b/source/model/etl/code/layout.py
index 7b4f8d091..ab67ac8b5 100644
--- a/source/model/etl/code/layout.py
+++ b/source/model/etl/code/layout.py
@@ -17,33 +17,35 @@
 import numpy as np
 import time
-from utils import preprocess, multiclass_nms, postprocess
+from imaug import preprocess
+from postprocess import multiclass_nms, postprocess
 import onnxruntime
-import GPUtil
-if len(GPUtil.getGPUs()):
-    provider = [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "HEURISTIC"}), "CPUExecutionProvider"]
-    model = 'layout.onnx'
-else:
-    provider = ["CPUExecutionProvider"]
-    model = 'layout_s.onnx'
+from gpu_config import get_provider_config
+from model_config import LAYOUT_CONFIG
+
+provider, _, layout_model = get_provider_config()
 
 class LayoutPredictor(object):
     def __init__(self):
-        self.ort_session = onnxruntime.InferenceSession(os.path.join(os.environ['MODEL_PATH'], model), providers=provider)
-        #_ = self.ort_session.run(['output'], {'images': np.zeros((1,3,640,640), dtype='float32')})[0]
-        self.categorys = ['text', 'title', 'figure', 'table']
+        self.ort_session = onnxruntime.InferenceSession(os.path.join(os.environ['MODEL_PATH'], layout_model), providers=provider)
+        self.categorys = LAYOUT_CONFIG['categories']
+        self.nms_thr = LAYOUT_CONFIG['nms_threshold']
+        self.score_thr = LAYOUT_CONFIG['score_threshold']
+        self.image_size = LAYOUT_CONFIG['image_size']
+        self.aspect_ratio_threshold = LAYOUT_CONFIG['aspect_ratio_threshold']
+
     def __call__(self, img):
         ori_im = img.copy()
         starttime = time.time()
         h,w,_ = img.shape
         h_ori, w_ori, _ = img.shape
-        if max(h_ori, w_ori)/min(h_ori, w_ori)>2:
-            s = 640/min(h_ori, w_ori)
+        if max(h_ori, w_ori)/min(h_ori, w_ori) > self.aspect_ratio_threshold:
+            s = self.image_size/min(h_ori, w_ori)
             h_new = int((h_ori*s)//32*32)
             w_new = int((w_ori*s)//32*32)
             h, w = (h_new, w_new)
         else:
-            h, w = (640, 640)
+            h, w = (self.image_size, self.image_size)
         image, ratio = preprocess(img, (h, w))
         res = self.ort_session.run(['output'], {'images': image[np.newaxis,:]})[0]
         predictions = postprocess(res, (h, w), p6=False)[0]
@@ -58,7 +60,7 @@ def __call__(self, img):
         boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
         boxes_xyxy /= ratio
-        dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.15, score_thr=0.3)
+        dets = multiclass_nms(boxes_xyxy, scores, nms_thr=self.nms_thr, score_thr=self.score_thr)
         if dets is None:
             return [], time.time() - starttime
         scores = dets[:, 4]
diff --git a/source/model/etl/code/main.py b/source/model/etl/code/main.py
index 94c1860c7..23e51beca 100644
--- a/source/model/etl/code/main.py
+++ b/source/model/etl/code/main.py
@@ -6,6 +6,7 @@
 import re
 import subprocess
 from pathlib import Path
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
 from ocr import TextSystem
 from table import TableSystem
@@ -182,6 +183,14 @@ def remove_symbols(text):
     return cleaned_text
 
 
+def process_figure(figure_data, doc, figure_idx):
+    k, v = figure_data
+    image, region_text = v[0], v[1] if v[1] is not None else ''
+    start_pos = doc.index(k)
+    context = doc[max(start_pos-200, 0): min(start_pos+200, len(doc))]
+    return k, figure_understand(image, context, k, s3_link=f'{figure_idx:05d}.jpg')
+
+
 def structure_predict(file_path: Path, lang: str, auto_dpi, figure_rec) -> str:
     """
     Extracts structured information from images in the given file path and returns a formatted document.
@@ -193,8 +202,6 @@ def structure_predict(file_path: Path, lang: str, auto_dpi, figure_rec) -> str:
         str: The formatted document containing the extracted information.
     """
 
-    # img_list, flag_gif, flag_pdf are returned from check_and_read
-    #img_list, _, _ = check_and_read(file_path)
     all_res = []
     for index, img in enumerate(check_and_read(file_path)):
@@ -260,15 +267,37 @@ def structure_predict(file_path: Path, lang: str, auto_dpi, figure_rec) -> str:
             doc += "\n\n"
     doc = re.sub("\n{2,}", "\n\n", doc.strip())
     images = {}
-    for figure_idx, (k,v) in enumerate(figure.items()):
-        images[f'{figure_idx:05d}.jpg'] = v[0]
-        region_text = v[1] if not v[1] is None else ''
-        if figure_rec:
-            start_pos = doc.index(k)
-            context = doc[max(start_pos-200, 0): min(start_pos+200, len(doc))]
-            doc = doc.replace(k, figure_understand(v[0], context, k, s3_link=f'{figure_idx:05d}.jpg'))
-        else:
+    start_time = time.time()
+    if figure_rec:
+        # Process figures concurrently using ThreadPoolExecutor
+        replacements = {}
+        max_workers = int(os.environ.get('FIGURE_UNDERSTANDING_MAX_WORKERS', 4))
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            future_to_figure = {
+                executor.submit(process_figure, (k,v), doc, figure_idx): (figure_idx, k, v)
+                for figure_idx, (k,v) in enumerate(figure.items())
+            }
+
+            for future in as_completed(future_to_figure):
+                figure_idx = future_to_figure[future][0]
+                images[f'{figure_idx:05d}.jpg'] = future_to_figure[future][2][0]
+                try:
+                    k, replacement = future.result()
+                    replacements[k] = replacement
+                except Exception as e:
+                    logger.error(f"Error processing figure {figure_idx}: {str(e)}")
+
+        # Apply all replacements after concurrent processing
+        for k, replacement in replacements.items():
+            doc = doc.replace(k, replacement)
+    else:
+        for figure_idx, (k,v) in enumerate(figure.items()):
+            images[f'{figure_idx:05d}.jpg'] = v[0]
+            region_text = v[1] if not v[1] is None else ''
             doc = doc.replace(k, f"\n<figure>\n<link>{figure_idx:05d}.jpg</link>\n<type>ocr</type>\n<desp>\n{region_text}\n</desp>\n</figure>\n")
+
+    end_time = time.time()
+    logger.info(f"Figure processing time: {end_time - start_time:.2f} seconds")
     doc = re.sub("\n{2,}", "\n\n", doc.strip())
     return doc, images
diff --git a/source/model/etl/code/model_config.py b/source/model/etl/code/model_config.py
new file mode 100644
index 000000000..67107685d
--- /dev/null
+++ b/source/model/etl/code/model_config.py
@@ -0,0 +1,78 @@
+MODEL_CONFIGS = {
+    'ch': {
+        'det': 'det_cn.onnx',
+        'rec': 'rec_ch.onnx',
+        'dict_path': 'ppocr_keys_v1.txt',
+        'character_type': 'ch',
+        'use_space_char': True
+    },
+    'en': {
+        'det': 'det_en.onnx',
+        'rec': 'rec_en.onnx',
+        'dict_path': 'en_dict.txt',
+        'character_type': 'en',
+        'use_space_char': True
+    },
+    'multi': {
+        'det': 'det_cn.onnx',
+        'rec': 'rec_multi_large.onnx',
+        'dict_path': 'keys_en_chs_cht_vi_ja_ko.txt',
+        'character_type': 'ch',
+        'use_space_char': True
+    }
+}
+
+LAYOUT_CONFIG = {
+    'categories': ['text', 'title', 'figure', 'table'],
+    'nms_threshold': 0.15,
+    'score_threshold': 0.3,
+    'image_size': 640,  # Default image processing size
+    'aspect_ratio_threshold': 2  # Max ratio threshold for special handling
+}
+
+TABLE_CONFIG = {
+    'model': {
+        'name': 'table_sim.onnx',
+        'session_options': {
+            'intra_op_num_threads': 8,
+            'execution_mode': 'ORT_SEQUENTIAL',
+            'optimization_level': 'ORT_ENABLE_ALL'
+        }
+    },
+    'preprocess': [
+        {
+            'ResizeTableImage': {
+                'max_len': 488
+            }
+        },
+        {
+            'NormalizeImage': {
+                'std': [0.229, 0.224, 0.225],
+                'mean': [0.485, 0.456, 0.406],
+                'scale': '1./255.',
+                'order': 'hwc'
+            }
+        },
+        {
+            'PaddingTableImage': {
+                'size': [488, 488]
+            }
+        },
+        {
+            'ToCHWImage': None
+        },
+        {
+            'KeepKeys': {
+                'keep_keys': ['image', 'shape']
+            }
+        }
+    ],
+    'postprocess': {
+        'name': 'TableLabelDecode',
+        'dict_path': 'table_structure_dict_ch.txt',
+        'merge_no_span_structure': True
+    },
+    'table_match': {
+        'filter_ocr_result': True
+    }
+}
\ No newline at end of file
diff --git a/source/model/etl/code/ocr.py b/source/model/etl/code/ocr.py
index 7a8326c88..61d783469 100644
--- a/source/model/etl/code/ocr.py
+++ b/source/model/etl/code/ocr.py
@@ -9,17 +9,14 @@
 import cv2
 from imaug import create_operators, transform
 from postprocess import build_post_process
-import GPUtil
-if len(GPUtil.getGPUs()):
-    provider = [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "HEURISTIC"}), "CPUExecutionProvider"]
-    rec_batch_num = 6
-else:
-    provider = ["CPUExecutionProvider"]
-    rec_batch_num = 1
+from gpu_config import get_provider_config
+from model_config import MODEL_CONFIGS
+
+provider, rec_batch_num, _ = get_provider_config()
 
 class TextClassifier():
     def __init__(self):
-        self.weights_path = os.environ['MODEL_PATH'] + 'classifier.onnx'
+        self.weights_path = os.path.join(os.environ['MODEL_PATH'], 'classifier.onnx')
         self.cls_image_shape = [3, 48, 192]
         self.cls_batch_num = 30
@@ -43,7 +40,6 @@ def resize_norm_img(self, img):
         else:
             resized_w = int(math.ceil(imgH * ratio))
         resized_image = np.array(Image.fromarray(img).resize((resized_w, imgH)))
-        #resized_image = cv2.resize(img, (resized_w, imgH))
         resized_image = resized_image.astype('float32')
         if self.cls_image_shape[0] == 1:
             resized_image = resized_image / 255
@@ -95,15 +91,8 @@ def __call__(self, img_list):
 
 class TextDetector():
     def __init__(self, lang):
-
-        if lang=='ch':
-            modelName = 'det_cn.onnx'
-        elif lang=='en':
-            modelName = 'det_en.onnx'
-        else:
-            modelName = 'det_cn.onnx'
-
-        self.weights_path = os.environ['MODEL_PATH'] + modelName
+        model_config = MODEL_CONFIGS.get(lang, MODEL_CONFIGS['multi'])
+        self.weights_path = os.path.join(os.environ['MODEL_PATH'], model_config['det'])
         self.det_algorithm = 'DB'
         self.use_zero_copy_run = False
@@ -209,30 +198,15 @@ def __call__(self, img, scale=None):
 
 class TextRecognizer():
     def __init__(self, lang='ch'):
-        if lang=='ch':
-            modelName = 'rec_ch.onnx'
-            postprocess_params = {
-                'name': 'CTCLabelDecode',
-                "character_type": 'ch',
-                "character_dict_path": os.environ['MODEL_PATH'] + 'ppocr_keys_v1.txt',
-                "use_space_char": True
-            }
-        elif lang=='en':
-            modelName = 'rec_en.onnx'
-            postprocess_params = {
-                'name': 'CTCLabelDecode',
-                "character_dict_path": os.environ['MODEL_PATH'] + 'en_dict.txt',
-                "use_space_char": True
-            }
-        else:
-            modelName = 'rec_multi_large.onnx'
-            postprocess_params = {
-                'name': 'CTCLabelDecode',
-                "character_type": 'ch',
-                "character_dict_path": os.environ['MODEL_PATH'] + 'keys_en_chs_cht_vi_ja_ko.txt',
-                "use_space_char": True
-            }
-        self.weights_path = os.environ['MODEL_PATH'] + modelName
+        model_config = MODEL_CONFIGS.get(lang, MODEL_CONFIGS['multi'])
+        self.weights_path = os.path.join(os.environ['MODEL_PATH'], model_config['rec'])
+
+        postprocess_params = {
+            'name': 'CTCLabelDecode',
+            "character_type": model_config['character_type'],
+            "character_dict_path": os.path.join(os.environ['MODEL_PATH'], model_config['dict_path']),
+            "use_space_char": model_config['use_space_char']
+        }
         self.limited_max_width = 1280
         self.limited_min_width = 16
@@ -243,7 +217,6 @@ def __init__(self, lang='ch'):
         self.use_zero_copy_run = False
         self.postprocess_op = build_post_process(postprocess_params)
-
         self.ort_session = onnxruntime.InferenceSession(self.weights_path, providers=provider)
 
     def resize_norm_img(self, img, max_wh_ratio):
diff --git a/source/model/etl/code/postprocess/__init__.py b/source/model/etl/code/postprocess/__init__.py
index 935708e11..ff0895548 100644
--- a/source/model/etl/code/postprocess/__init__.py
+++ b/source/model/etl/code/postprocess/__init__.py
@@ -5,7 +5,14 @@
 
 import copy
 
-__all__ = ['build_post_process']
+from .nms import nms, multiclass_nms, postprocess
+
+__all__ = [
+    "build_post_process",
+    "nms",
+    "multiclass_nms",
+    "postprocess"
+]
 
 def build_post_process(config, global_config=None):
diff --git a/source/model/etl/code/postprocess/nms.py b/source/model/etl/code/postprocess/nms.py
new file mode 100644
index 000000000..29b70b7ec
--- /dev/null
+++ b/source/model/etl/code/postprocess/nms.py
@@ -0,0 +1,111 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import numpy as np
+
+__all__ = ["nms", "multiclass_nms", "postprocess"]
+
+def nms(boxes, scores, nms_thr):
+    """Single class NMS implemented in Numpy."""
+    x1 = boxes[:, 0]
+    y1 = boxes[:, 1]
+    x2 = boxes[:, 2]
+    y2 = boxes[:, 3]
+
+    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+    order = scores.argsort()[::-1]
+
+    keep = []
+    while order.size > 0:
+        i = order[0]
+        keep.append(i)
+        xx1 = np.maximum(x1[i], x1[order[1:]])
+        yy1 = np.maximum(y1[i], y1[order[1:]])
+        xx2 = np.minimum(x2[i], x2[order[1:]])
+        yy2 = np.minimum(y2[i], y2[order[1:]])
+
+        w = np.maximum(0.0, xx2 - xx1 + 1)
+        h = np.maximum(0.0, yy2 - yy1 + 1)
+        inter = w * h
+        ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+        inds = np.where(ovr <= nms_thr)[0]
+        order = order[inds + 1]
+
+    return keep
+
+def multiclass_nms(boxes, scores, nms_thr, score_thr, class_agnostic=True):
+    """Multiclass NMS implemented in Numpy"""
+    if class_agnostic:
+        nms_method = multiclass_nms_class_agnostic
+    else:
+        nms_method = multiclass_nms_class_aware
+    return nms_method(boxes, scores, nms_thr, score_thr)
+
+def multiclass_nms_class_aware(boxes, scores, nms_thr, score_thr):
+    """Multiclass NMS implemented in Numpy. Class-aware version."""
+    final_dets = []
+    num_classes = scores.shape[1]
+    for cls_ind in range(num_classes):
+        cls_scores = scores[:, cls_ind]
+        valid_score_mask = cls_scores > score_thr
+        if valid_score_mask.sum() == 0:
+            continue
+        else:
+            valid_scores = cls_scores[valid_score_mask]
+            valid_boxes = boxes[valid_score_mask]
+            keep = nms(valid_boxes, valid_scores, nms_thr)
+            if len(keep) > 0:
+                cls_inds = np.ones((len(keep), 1)) * cls_ind
+                dets = np.concatenate(
+                    [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
+                )
+                final_dets.append(dets)
+    if len(final_dets) == 0:
+        return None
+    return np.concatenate(final_dets, 0)
+
+def multiclass_nms_class_agnostic(boxes, scores, nms_thr, score_thr):
+    """Multiclass NMS implemented in Numpy. Class-agnostic version."""
+    cls_inds = scores.argmax(1)
+    cls_scores = scores[np.arange(len(cls_inds)), cls_inds]
+
+    valid_score_mask = cls_scores > score_thr
+    if valid_score_mask.sum() == 0:
+        return None
+    valid_scores = cls_scores[valid_score_mask]
+    valid_boxes = boxes[valid_score_mask]
+    valid_cls_inds = cls_inds[valid_score_mask]
+    keep = nms(valid_boxes, valid_scores, nms_thr)
+    if keep:
+        dets = np.concatenate(
+            [valid_boxes[keep], valid_scores[keep, None], valid_cls_inds[keep, None]], 1
+        )
+        return dets
+
+def postprocess(outputs, img_size, p6=False):
+    """Post-process model outputs."""
+    grids = []
+    expanded_strides = []
+
+    if not p6:
+        strides = [8, 16, 32]
+    else:
+        strides = [8, 16, 32, 64]
+
+    hsizes = [img_size[0] // stride for stride in strides]
+    wsizes = [img_size[1] // stride for stride in strides]
+
+    for hsize, wsize, stride in zip(hsizes, wsizes, strides):
+        xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
+        grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
+        grids.append(grid)
+        shape = grid.shape[:2]
+        expanded_strides.append(np.full((*shape, 1), stride))
+
+    grids = np.concatenate(grids, 1)
+    expanded_strides = np.concatenate(expanded_strides, 1)
+    outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
+    outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
+
+    return outputs
\ No newline at end of file
diff --git a/source/model/etl/code/prompt/chart.txt b/source/model/etl/code/prompt/chart.txt
new file mode 100644
index 000000000..8de485baf
--- /dev/null
+++ b/source/model/etl/code/prompt/chart.txt
@@ -0,0 +1,16 @@
+您是文档阅读专家。您的任务是将图片中的图表转换成Markdown格式。以下是说明:
+1. 找到图片中的图表。
+2. 仔细观察图表,了解其中包含的结构和数据。
+3. 使用<context>标签中的上下文信息来帮助你更好地理解和描述这张图表。上下文中的{tag}就是指该图表。
+4. 按照以下指南将图表数据转换成 Markdown 表格格式:
+    - 使用 | 字符分隔列
+    - 使用 --- 行表示标题行
+    - 确保表格格式正确并对齐
+    - 对于不确定的数字,请根据图片估算。
+5. 仔细检查您的 Markdown 表格是否准确地反映了图表图像中的数据。
+6. 在 <output></output> xml 标签中仅返回 Markdown,不含其他文本。
+
+<context>
+{context}
+</context>
+请将你的描述写在<output></output>xml标签之间。
\ No newline at end of file
diff --git a/source/model/etl/code/prompt/description.txt b/source/model/etl/code/prompt/description.txt
new file mode 100644
index 000000000..22fa1a862
--- /dev/null
+++ b/source/model/etl/code/prompt/description.txt
@@ -0,0 +1,8 @@
+你是一位资深的图像分析专家。你的任务是仔细观察给出的插图,并按照以下步骤进行:
+1. 清晰地描述这张图片中显示的内容细节。如果图片中包含任何文字,请确保在描述中准确无误地包含这些文字。
+2. 使用<context>标签中的上下文信息来帮助你更好地理解和描述这张图片。上下文中的{tag}就是指该插图。
+3. 将你的描述写在<output></output>标签之间。
+<context>
+{context}
+</context>
+请将你的描述写在<output></output>xml标签之间。
\ No newline at end of file
diff --git a/source/model/etl/code/prompt/mermaid_template.txt b/source/model/etl/code/prompt/mermaid_template.txt
index eae16c000..fcfc9d1e3 100644
--- a/source/model/etl/code/prompt/mermaid_template.txt
+++ b/source/model/etl/code/prompt/mermaid_template.txt
@@ -19,7 +19,6 @@
 Your response should be structured as follows:
 <mermaid>
 [Mermaid chart code representing the workflow]
 </mermaid>
-
 Below are example of {diagram_type} mermaid templates. The example include a detailed description of the workflow along with Mermaid chart codes. Use the example as reference when analyzing the workflow diagram in the image.
 
diff --git a/source/model/etl/code/requirements.txt b/source/model/etl/code/requirements.txt
index 695227c67..4fd7ee918 100644
--- a/source/model/etl/code/requirements.txt
+++ b/source/model/etl/code/requirements.txt
@@ -8,4 +8,5 @@ PyMuPDF<1.21.0
 markdownify
 flask
 gevent
-GPUtil
\ No newline at end of file
+GPUtil
+openai>=1.0.0
\ No newline at end of file
diff --git a/source/model/etl/code/sm_predictor.py b/source/model/etl/code/sm_predictor.py
index 5cbacb949..00e3bc26d 100644
--- a/source/model/etl/code/sm_predictor.py
+++ b/source/model/etl/code/sm_predictor.py
@@ -1,4 +1,3 @@
-from aikits_utils import lambda_return
 from main import process_pdf_pipeline
 from gevent import pywsgi
 import flask
@@ -9,7 +8,7 @@
 
 def handler(event, context):
     if 'body' not in event:
-        return lambda_return(400, 'invalid param')
+        return create_response('invalid param', 400)
     try:
         if isinstance(event['body'], str):
             body = json.loads(event['body'])
@@ -17,14 +16,30 @@ def handler(event, context):
             body = event['body']
 
         if 's3_bucket' not in body or 'object_key' not in body:
-            return lambda_return(400, 'Must specify the `s3_bucket` and `object_key` for the file')
+            return create_response('Must specify the `s3_bucket` and `object_key` for the file', 400)
     except:
-        return lambda_return(400, 'invalid param')
+        return create_response('invalid param', 400)
     output = process_pdf_pipeline(body)
+    return create_response(json.dumps(output), 200)
 
-    return lambda_return(200, json.dumps(output))
+
+def create_response(body, status_code=200):
+    """Create a Flask response with CORS headers.
+
+    Args:
+        body: Response body
+        status_code: HTTP status code (default: 200)
+
+    Returns:
+        Flask Response object
+    """
+    response = flask.make_response(body, status_code)
+    response.headers['Access-Control-Allow-Headers'] = '*'
+    response.headers['Access-Control-Allow-Origin'] = '*'
+    response.headers['Access-Control-Allow-Methods'] = '*'
+    return response
 
 
 @app.route('/ping', methods=['GET'])
@@ -48,14 +63,10 @@ def transformation():
     if flask.request.content_type == 'application/json':
         request_body = flask.request.data.decode('utf-8')
         body = json.loads(request_body)
-        req = handler({'body': body}, None)
-        return flask.Response(
-            response=req['body'],
-            status=req['statusCode'], mimetype='application/json')
+        response = handler({'body': body}, None)
+        return response
     else:
-        return flask.Response(
-            response='Only supports application/json data',
-            status=415, mimetype='application/json')
+        return create_response('Only supports application/json data', 415)
 
 
 server = pywsgi.WSGIServer(('0.0.0.0', 8080), app)
diff --git a/source/model/etl/code/table.py b/source/model/etl/code/table.py
index c6e891088..a7fb0a788 100644
--- a/source/model/etl/code/table.py
+++ b/source/model/etl/code/table.py
@@ -7,12 +7,16 @@
 import numpy as np
 import os
 import onnxruntime as ort
+from model_config import TABLE_CONFIG
 
-sess_options = ort.SessionOptions()
+def get_session_options():
+    sess_options = ort.SessionOptions()
+    config = TABLE_CONFIG['model']['session_options']
+    sess_options.intra_op_num_threads = config['intra_op_num_threads']
+    sess_options.execution_mode = getattr(ort.ExecutionMode, config['execution_mode'])
+    sess_options.graph_optimization_level = getattr(ort.GraphOptimizationLevel, config['optimization_level'])
+    return sess_options
 
-sess_options.intra_op_num_threads = 8
-sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
-sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
 
 def sorted_boxes(dt_boxes):
     """ Sort text boxes in order from top to bottom, left to right
@@ -33,24 +37,24 @@ def sorted_boxes(dt_boxes):
             _boxes[i] = _boxes[i + 1]
             _boxes[i + 1] = tmp
     return _boxes
+
 
 class TableStructurer(object):
     def __init__(self):
-        self.use_onnx = True #args.use_onnx
-        pre_process_list = [{'ResizeTableImage': {'max_len': 488}}, {'NormalizeImage': {'std': [0.229, 0.224, 0.225], 'mean': [0.485, 0.456, 0.406], 'scale': '1./255.', 'order': 'hwc'}}, {'PaddingTableImage': {'size': [488, 488]}}, {'ToCHWImage': None}, {'KeepKeys': {'keep_keys': ['image', 'shape']}}]
-
-        postprocess_params = {
-            'name': 'TableLabelDecode',
-            "character_dict_path": os.environ['MODEL_PATH'] + 'table_structure_dict_ch.txt',
-            'merge_no_span_structure': True
-        }
-        self.preprocess_op = create_operators(pre_process_list)
+        self.use_onnx = True
+        self.preprocess_op = create_operators(TABLE_CONFIG['preprocess'])
+
+        postprocess_params = dict(TABLE_CONFIG['postprocess'])
+        postprocess_params['character_dict_path'] = os.path.join(os.environ['MODEL_PATH'], postprocess_params.pop('dict_path'))
         self.postprocess_op = build_post_process(postprocess_params)
-        sess = ort.InferenceSession(os.environ['MODEL_PATH'] + 'table_sim.onnx', providers=['CPUExecutionProvider']) #, sess_options=sess_options, providers=[("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"})]
+        sess = ort.InferenceSession(
+            os.path.join(os.environ['MODEL_PATH'], TABLE_CONFIG['model']['name']),
+            providers=['CPUExecutionProvider'],
+            sess_options=get_session_options()
+        )
         _ = sess.run(None, {'x': np.zeros((1, 3, 488, 488), dtype='float32')})
         self.predictor, self.input_tensor, self.output_tensors, self.config = sess, sess.get_inputs()[0], None, None
-
     def __call__(self, img):
         starttime = time.time()
         ori_im = img.copy()
@@ -78,6 +82,7 @@ def __call__(self, img):
         ] + structure_str_list + ['</table>', '</body>', '</html>']
         elapse = time.time() - starttime
         return (structure_str_list, bbox_list), elapse
+
 def expand(pix, det_box, shape):
     x0, y0, x1, y1 = det_box
     h, w, c = shape
@@ -95,9 +100,8 @@ class TableSystem(object):
     def __init__(self, text_detector=None, text_recognizer=None):
         self.text_detector = text_detector
         self.text_recognizer = text_recognizer
-
         self.table_structurer = TableStructurer()
-        self.match = TableMatch(filter_ocr_result=True)
+        self.match = TableMatch(**TABLE_CONFIG['table_match'])
 
     def __call__(self, img, return_ocr_result_in_table=False, lang='ch'):
         result = dict()
diff --git a/source/model/etl/code/untitled.txt b/source/model/etl/code/untitled.txt
deleted file mode 100644
index e69de29bb..000000000
diff --git a/source/model/etl/code/utils.py b/source/model/etl/code/utils.py
index a5b56cd07..b5e69927c 100644
--- a/source/model/etl/code/utils.py
+++ b/source/model/etl/code/utils.py
@@ -1,154 +1,46 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
-# Copyright (c) Megvii Inc. All rights reserved.
-import os
+import os
 import cv2
+import logging
 import numpy as np
-
-# from transformers import StoppingCriteria
+from PIL import Image
+from io import BytesIO
+import boto3
+import base64
+try:
+    import urllib.request as urllib2
+    from urllib.parse import urlparse
+except ImportError:
+    import urllib2
+    from urlparse import urlparse
 
 __all__ = [
-    "preprocess",
-    "nms",
-    "multiclass_nms",
-    "postprocess",
-    "StoppingCriteriaScores",
-    "markdown_compatible",
+    "check_and_read",
+    "readimg",
+    "lambda_return"
 ]
-
-def preprocess(img, input_size, swap=(2, 0, 1)):
-    if len(img.shape) == 3:
-        padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
-    else:
-        padded_img = np.ones(input_size, dtype=np.uint8) * 114
-
-    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
-    resized_img = cv2.resize(
-        img,
-        (int(img.shape[1] * r), int(img.shape[0] * r)),
-        interpolation=cv2.INTER_LINEAR,
-    ).astype(np.uint8)
-    padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
-
-    padded_img = padded_img.transpose(swap)
-    padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
-    return padded_img, r
-
-
-def nms(boxes, scores, nms_thr):
-    """Single class NMS implemented in Numpy."""
-    x1 = boxes[:, 0]
-    y1 = boxes[:, 1]
-    x2 = boxes[:, 2]
-    y2 = boxes[:, 3]
-
-    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
-    order = scores.argsort()[::-1]
-
-    keep = []
-    while order.size > 0:
-        i = order[0]
-        keep.append(i)
-        xx1 = np.maximum(x1[i], x1[order[1:]])
-        yy1 = np.maximum(y1[i], y1[order[1:]])
-        xx2 = np.minimum(x2[i], x2[order[1:]])
-        yy2 = np.minimum(y2[i], y2[order[1:]])
-
-        w = np.maximum(0.0, xx2 - xx1 + 1)
-        h = np.maximum(0.0, yy2 - yy1 + 1)
-        inter = w * h
-        ovr = inter / (areas[i] + areas[order[1:]] - inter)
-
-        inds = np.where(ovr <= nms_thr)[0]
-        order = order[inds + 1]
-
-    return keep
-
-
-def multiclass_nms(boxes, scores, nms_thr, score_thr, class_agnostic=True):
-    """Multiclass NMS implemented in Numpy"""
-    if class_agnostic:
-        nms_method = multiclass_nms_class_agnostic
-    else:
-        nms_method = multiclass_nms_class_aware
-    return nms_method(boxes, scores, nms_thr, score_thr)
-
-
-def multiclass_nms_class_aware(boxes, scores, nms_thr, score_thr):
-    """Multiclass NMS implemented in Numpy. Class-aware version."""
-    final_dets = []
-    num_classes = scores.shape[1]
-    for cls_ind in range(num_classes):
-        cls_scores = scores[:, cls_ind]
-        valid_score_mask = cls_scores > score_thr
-        if valid_score_mask.sum() == 0:
-            continue
-        else:
-            valid_scores = cls_scores[valid_score_mask]
-            valid_boxes = boxes[valid_score_mask]
-            keep = nms(valid_boxes, valid_scores, nms_thr)
-            if len(keep) > 0:
-                cls_inds = np.ones((len(keep), 1)) * cls_ind
-                dets = np.concatenate(
-                    [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
-                )
-                final_dets.append(dets)
-    if len(final_dets) == 0:
-        return None
-    return np.concatenate(final_dets, 0)
-
-
-def multiclass_nms_class_agnostic(boxes, scores, nms_thr, score_thr):
-    """Multiclass NMS implemented in Numpy. Class-agnostic version."""
-    cls_inds = scores.argmax(1)
-    cls_scores = scores[np.arange(len(cls_inds)), cls_inds]
-
-    valid_score_mask = cls_scores > score_thr
-    if valid_score_mask.sum() == 0:
-        return None
-    valid_scores = cls_scores[valid_score_mask]
-    valid_boxes = boxes[valid_score_mask]
-    valid_cls_inds = cls_inds[valid_score_mask]
-    keep = nms(valid_boxes, valid_scores, nms_thr)
-    if keep:
-        dets = np.concatenate(
-            [valid_boxes[keep], valid_scores[keep, None], valid_cls_inds[keep, None]], 1
-        )
-        return dets
-
-
-def postprocess(outputs, img_size, p6=False):
-
-    grids = []
-    expanded_strides = []
-
-    if not p6:
-        strides = [8, 16, 32]
-    else:
-        strides = [8, 16, 32, 64]
-
-    hsizes = [img_size[0] // stride for stride in strides]
-    wsizes = [img_size[1] // stride for stride in strides]
-
-    for hsize, wsize, stride in zip(hsizes, wsizes, strides):
-        xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
-        grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
-        grids.append(grid)
-        shape = grid.shape[:2]
-        expanded_strides.append(np.full((*shape, 1), stride))
-
-    grids = np.concatenate(grids, 1)
-    expanded_strides = np.concatenate(expanded_strides, 1)
-    outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
-    outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
-
-    return outputs
-
-
 def check_and_read(img_path):
-    if os.path.basename(img_path)[-3:].lower() == "gif":
+    """Check and read image file in different formats.
+
+    Supports:
+    - Common image formats (JPG, PNG, etc.): yields the single image
+    - GIF: yields the first frame
+    - PDF: yields every page as an image
+
+    Args:
+        img_path: Path to the image file
+
+    Returns:
+        Generator yielding the decoded image(s) as numpy arrays
+        (one item for GIF/regular images, one per page for PDFs)
+    """
+    # Get file extension
+    ext = os.path.basename(img_path)[-3:].lower()
+    # Handle GIF files
+    if ext == "gif":
         gif = cv2.VideoCapture(img_path)
         ret, frame = gif.read()
         if not ret:
@@ -158,22 +50,86 @@ def check_and_read(img_path):
         if len(frame.shape) == 2 or frame.shape[-1] == 1:
             frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
         imgvalue = frame[:, :, ::-1]
-        return imgvalue, True, False
-    elif os.path.basename(img_path)[-3:].lower() == "pdf":
+        yield imgvalue
+
+    # Handle PDF files
+    elif ext == "pdf":
         import fitz
-        from PIL import Image
-
-        imgs = []
         with fitz.open(img_path) as pdf:
             for pg in range(0, pdf.page_count):
                 page = pdf[pg]
                 mat = fitz.Matrix(3, 3)
                 pm = page.get_pixmap(matrix=mat, alpha=False)
-
-                # if width or height > 2000 pixels, don't enlarge the image
-                # if pm.width > 2000 or pm.height > 2000:
-                #     pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
-
                 img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
                 img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
                 yield img
+
+    # Handle common image formats (JPG, PNG, etc.)
+    else:
+        img = cv2.imread(img_path)
+        if img is None:
+            return None
+        if len(img.shape) == 2 or img.shape[-1] == 1:
+            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+        yield img[:, :, ::-1]
+
+def readimg(body, keys=None):
+    """Read images from various sources in a request body.
+
+    Supports:
+    - HTTP URLs
+    - S3 URLs
+    - Base64 encoded images
+
+    Args:
+        body: Request body containing image data
+        keys: Keys to look for in the body. If None, uses all keys
+
+    Returns:
+        Dict mapping keys to numpy arrays of images
+    """
+    if keys is None:
+        keys = body.keys()
+    inputs = dict()
+    for key in keys:
+        try:
+            if key.startswith('url'):  # URL-style input
+                if body[key].startswith('http'):  # http url
+                    image_string = urllib2.urlopen(body[key]).read()
+                elif body[key].startswith('s3'):  # s3 key
+                    o = urlparse(body[key])
+                    bucket = o.netloc
+                    path = o.path.lstrip('/')
+                    s3 = boto3.resource('s3')
+                    img_obj = s3.Object(bucket, path)
+                    image_string = img_obj.get()['Body'].read()
+                else:
+                    raise ValueError('unsupported url scheme')
+            elif key.startswith('img'):  # base64-encoded input
+                image_string = base64.b64decode(body[key])
+            else:
+                raise ValueError('unknown input key type')
+            inputs[key] = np.array(Image.open(BytesIO(image_string)).convert('RGB'))[:, :, :3]
+        except:
+            inputs[key] = None
+    return inputs
+
+def lambda_return(statusCode, body):
+    """Create a standardized Lambda function response.
+
+    Args:
+        statusCode: HTTP status code
+        body: Response body
+
+    Returns:
+        Dict containing the Lambda response structure
+    """
+    return {
+        'statusCode': statusCode,
+        'headers': {
+            'Access-Control-Allow-Headers': '*',
+            'Access-Control-Allow-Origin': '*',
+            'Access-Control-Allow-Methods': '*'
+        },
+        'body': body
+    }
\ No newline at end of file