diff --git a/source/model/etl/code/aikits_utils.py b/source/model/etl/code/aikits_utils.py
deleted file mode 100644
index 5ba1f5dc2..000000000
--- a/source/model/etl/code/aikits_utils.py
+++ /dev/null
@@ -1,50 +0,0 @@
-from io import BytesIO
-import boto3
-import base64
-import numpy as np
-from PIL import Image
-import cv2
-try:
- import urllib.request as urllib2
- from urllib.parse import urlparse
-except ImportError:
- import urllib2
- from urlparse import urlparse
-
-def readimg(body, keys=None):
- if keys is None:
- keys = body.keys()
- inputs = dict()
- for key in keys:
- try:
-            if key.startswith('url'): # URL reference
- if body[key].startswith('http'): # http url
- image_string = urllib2.urlopen(body[key]).read()
- elif body[key].startswith('s3'): # s3 key
- o = urlparse(body[key])
- bucket = o.netloc
- path = o.path.lstrip('/')
- s3 = boto3.resource('s3')
- img_obj = s3.Object(bucket, path)
- image_string = img_obj.get()['Body'].read()
- else:
- raise
-            elif key.startswith('img'): # base64-encoded image
- image_string = base64.b64decode(body[key])
- else:
- raise
- inputs[key] = np.array(Image.open(BytesIO(image_string)).convert('RGB'))[:, :, :3]
- except:
- inputs[key] = None
- return inputs
-
-def lambda_return(statusCode, body):
- return {
- 'statusCode': statusCode,
- 'headers': {
- 'Access-Control-Allow-Headers': '*',
- 'Access-Control-Allow-Origin': '*',
- 'Access-Control-Allow-Methods': '*'
- },
- 'body': body
- }
\ No newline at end of file
diff --git a/source/model/etl/code/config.py b/source/model/etl/code/config.py
new file mode 100644
index 000000000..d155ac9b9
--- /dev/null
+++ b/source/model/etl/code/config.py
@@ -0,0 +1,75 @@
+import os
+
+class ModelConfig:
+    """Model configuration management class."""
+
+ LANG_CONFIGS = {
+ 'ch': {
+ 'det': {
+ 'model': 'det_cn.onnx',
+ 'postprocess': {
+ 'name': 'DBPostProcess',
+ 'thresh': 0.1,
+ 'box_thresh': 0.1,
+ 'max_candidates': 1000,
+ 'unclip_ratio': 1.5,
+ 'use_dilation': False,
+ 'score_mode': 'fast',
+ 'box_type': 'quad'
+ }
+ },
+ 'rec': {
+ 'model': 'rec_ch.onnx',
+ 'dict_path': 'ppocr_keys_v1.txt',
+ 'postprocess': {
+ 'name': 'CTCLabelDecode',
+ 'character_type': 'ch',
+ 'use_space_char': True
+ }
+ }
+ },
+ 'en': {
+ 'det': {
+ 'model': 'det_en.onnx',
+ 'postprocess': {
+ 'name': 'DBPostProcess',
+ 'thresh': 0.1,
+ 'box_thresh': 0.1,
+ 'max_candidates': 1000,
+ 'unclip_ratio': 1.5,
+ 'use_dilation': False,
+ 'score_mode': 'fast',
+ 'box_type': 'quad'
+ }
+ },
+ 'rec': {
+ 'model': 'rec_en.onnx',
+ 'dict_path': 'en_dict.txt',
+ 'postprocess': {
+ 'name': 'CTCLabelDecode',
+ 'use_space_char': True
+ }
+ }
+ }
+ }
+
+ @classmethod
+ def get_model_path(cls, lang, model_type):
+ """获取模型路径"""
+ return os.path.join(
+ os.environ['MODEL_PATH'],
+ cls.LANG_CONFIGS[lang][model_type]['model']
+ )
+
+ @classmethod
+ def get_dict_path(cls, lang):
+ """获取字典路径"""
+ return os.path.join(
+ os.environ['MODEL_PATH'],
+ cls.LANG_CONFIGS[lang]['rec']['dict_path']
+ )
+
+ @classmethod
+ def get_postprocess_config(cls, lang, model_type):
+ """获取后处理配置"""
+ return cls.LANG_CONFIGS[lang][model_type]['postprocess']
\ No newline at end of file
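
A minimal usage sketch of the new `ModelConfig` helper (with the `import os` fix applied). `MODEL_PATH` is the environment variable the class already reads; the `/opt/ml/model` value is an illustrative assumption:

```python
import os

os.environ.setdefault('MODEL_PATH', '/opt/ml/model')  # hypothetical model dir

from config import ModelConfig

det_model = ModelConfig.get_model_path('ch', 'det')          # .../det_cn.onnx
rec_dict = ModelConfig.get_dict_path('ch')                   # .../ppocr_keys_v1.txt
db_params = ModelConfig.get_postprocess_config('ch', 'det')  # DBPostProcess settings
```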
diff --git a/source/model/etl/code/figure_llm.py b/source/model/etl/code/figure_llm.py
index 078b21640..ec117aaba 100644
--- a/source/model/etl/code/figure_llm.py
+++ b/source/model/etl/code/figure_llm.py
@@ -5,15 +5,74 @@
import io
import base64
import json
+import logging
+import os
+import openai
+
+# Add logger configuration
+logger = logging.getLogger(__name__)
class figureUnderstand():
- def __init__(self):
- self.bedrock_runtime = boto3.client(service_name='bedrock-runtime')
+ def __init__(self, model_provider="bedrock", api_secret_name=None):
+ self.model_provider = model_provider
+ if model_provider == "bedrock":
+ self.bedrock_runtime = boto3.client(service_name='bedrock-runtime')
+ elif model_provider == "openai":
+ self.openai_api_key = self._get_api_key(api_secret_name)
+ if not self.openai_api_key:
+ raise ValueError("Failed to retrieve OpenAI API key from Secrets Manager")
+ openai.api_key = self.openai_api_key
+ # Set OpenAI base URL from environment variable if provided
+ base_url = os.environ.get("OPENAI_API_BASE")
+ if base_url:
+ openai.base_url = base_url
+ else:
+ raise ValueError("Unsupported model provider. Choose 'bedrock' or 'openai'")
+
self.mermaid_prompt = json.load(open('prompt/mermaid.json', 'r'))
- def invoke_llm(self, img, prompt, prefix=""):
+
+ def _get_api_key(self, api_secret_name):
+ """
+ Get the API key from AWS Secrets Manager.
+ Args:
+ api_secret_name (str): The name of the secret in AWS Secrets Manager containing the API key.
+ Returns:
+ str: The API key.
+ """
+ if not api_secret_name:
+ raise ValueError("api_secret_name must be provided when using OpenAI")
+
+ try:
+ secrets_client = boto3.client("secretsmanager")
+ secret_response = secrets_client.get_secret_value(
+ SecretId=api_secret_name
+ )
+ if "SecretString" in secret_response:
+ secret_data = json.loads(secret_response["SecretString"])
+ api_key = secret_data.get("api_key")
+ logger.info(
+ f"Successfully retrieved API key from secret: {api_secret_name}"
+ )
+ return api_key
+ except Exception as e:
+ logger.error(f"Error retrieving secret {api_secret_name}: {str(e)}")
+ raise
+ return None
+
+ def _image_to_base64(self, img):
+ """Convert PIL Image to base64 encoded string"""
image_stream = io.BytesIO()
img.save(image_stream, format="JPEG")
- base64_encoded = base64.b64encode(image_stream.getvalue()).decode('utf-8')
+ return base64.b64encode(image_stream.getvalue()).decode('utf-8')
+
+ def invoke_llm(self, img, prompt, prefix=""):
+ if self.model_provider == "bedrock":
+ return self._invoke_bedrock(img, prompt, prefix, stop)
+ elif self.model_provider == "openai":
+ return self._invoke_openai(img, prompt, prefix, stop)
+
+ def _invoke_bedrock(self, img, prompt, prefix=""):
+ base64_encoded = self._image_to_base64(img)
messages = [
{
"role": "user",
@@ -34,7 +92,7 @@ def invoke_llm(self, img, prompt, prefix="<output>", stop="</output>"):
},
{"role": "assistant", "content": prefix},
]
- model_id = "anthropic.claude-3-sonnet-20240229-v1:0"
+ model_id = "anthropic.claude-3-5-sonnet-20241022-v2:0"
body = json.dumps({
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": 4096,
@@ -46,51 +104,67 @@ def invoke_llm(self, img, prompt, prefix=""):
response_body = json.loads(response.get('body').read())
result = prefix + response_body['content'][0]['text'] + stop
return result
+
+ def _invoke_openai(self, img, prompt, prefix=""):
+ base64_encoded = self._image_to_base64(img)
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": prompt
+ },
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": f"data:image/jpeg;base64,{base64_encoded}"
+ }
+ }
+ ]
+ },
+ {
+ "role": "assistant",
+ "content": prefix
+ }
+ ]
+
+ response = openai.chat.completions.create(
+ model="gpt-4-vision-preview",
+ messages=messages,
+ max_tokens=4096,
+ stop=[stop] if stop else None
+ )
+
+ result = prefix + response.choices[0].message.content + stop
+ return result
+
def get_classification(self, img):
with open('prompt/figure_classification.txt') as f:
figure_classification_prompt = f.read()
output = self.invoke_llm(img, figure_classification_prompt)
return output
+
def get_chart(self, img, context, tag):
- prompt = '''您是文档阅读专家。您的任务是将图片中的图表转换成Markdown格式。以下是说明:
-1. 找到图片中的图表。
-2. 仔细观察图表,了解其中包含的结构和数据。
-3. 使用标签中的上下文信息来帮助你更好地理解和描述这张图表。上下文中的{tag}就是指该图表。
-4. 按照以下指南将图表数据转换成 Markdown 表格格式:
- - 使用 | 字符分隔列
- - 使用 --- 行表示标题行
- - 确保表格格式正确并对齐
- - 对于不确定的数字,请根据图片估算。
-5. 仔细检查您的 Markdown 表格是否准确地反映了图表图像中的数据。
-6. 在 xml 标签中仅返回 Markdown,不含其他文本。
-
-
-{context}
-
-请将你的描述写在xml标签之间。
-'''.strip()
- output = self.invoke_llm(img, prompt)
+ with open('prompt/chart.txt') as f:
+ chart_prompt = f.read()
+ output = self.invoke_llm(img, chart_prompt.format(context=context, tag=tag))
return output
- def get_description(self, img, context, tag):
- prompt = '''
-你是一位资深的图像分析专家。你的任务是仔细观察给出的插图,并按照以下步骤进行:
-1. 清晰地描述这张图片中显示的内容细节。如果图片中包含任何文字,请确保在描述中准确无误地包含这些文字。
-2. 使用标签中的上下文信息来帮助你更好地理解和描述这张图片。上下文中的{tag}就是指该插图。
-3. 将你的描述写在标签之间。
-
-{context}
-
-请将你的描述写在xml标签之间。
-'''.strip()
- output = self.invoke_llm(img, prompt.format(context=context, tag=tag))
+ def get_description(self, img, context, tag):
+ with open('prompt/description.txt') as f:
+ description_prompt = f.read()
+ output = self.invoke_llm(img, description_prompt.format(context=context, tag=tag))
return f'![{output}]()'
+
def get_mermaid(self, img, classification):
with open('prompt/mermaid_template.txt') as f:
mermaid_prompt = f.read()
prompt = mermaid_prompt.format(diagram_type=classification, diagram_example=self.mermaid_prompt[classification])
output = self.invoke_llm(img, prompt, prefix='', stop='')
return output
+
def parse_result(self, llm_output, tag):
try:
            pattern = fr"<{tag}>(.*?)</{tag}>"
@@ -98,6 +172,7 @@ def parse_result(self, llm_output, tag):
except:
            output = llm_output.replace(f"<{tag}>", '').replace(f"</{tag}>", '')
return output
+
def __call__(self, img, context, tag, s3_link):
classification = self.get_classification(img)
classification = self.parse_result(classification, 'output')
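
For reference, a minimal construction sketch for both providers. It assumes the `prompt/` assets are present in the working directory and valid AWS credentials are available; the secret name is a placeholder for a Secrets Manager secret of the form `{"api_key": "..."}`:

```python
import os
from figure_llm import figureUnderstand

# Bedrock is the default provider and needs no extra configuration
fig = figureUnderstand()

# OpenAI requires an api_secret_name; OPENAI_API_BASE is an optional override
os.environ["OPENAI_API_BASE"] = "https://api.openai.com/v1"
fig = figureUnderstand(model_provider="openai",
                       api_secret_name="etl/openai-api-key")  # placeholder name
```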
diff --git a/source/model/etl/code/gpu_config.py b/source/model/etl/code/gpu_config.py
new file mode 100644
index 000000000..02b1df65e
--- /dev/null
+++ b/source/model/etl/code/gpu_config.py
@@ -0,0 +1,13 @@
+import GPUtil
+
+def get_provider_config():
+ if len(GPUtil.getGPUs()):
+ provider = [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "HEURISTIC"}), "CPUExecutionProvider"]
+ rec_batch_num = 6
+ layout_model = 'layout.onnx'
+ else:
+ provider = ["CPUExecutionProvider"]
+ rec_batch_num = 1
+ layout_model = 'layout_s.onnx'
+
+ return provider, rec_batch_num, layout_model
\ No newline at end of file
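
A sketch of how the shared provider tuple is consumed, mirroring `layout.py` and `ocr.py` below (`MODEL_PATH` must point at the ONNX model directory):

```python
import os
import onnxruntime
from gpu_config import get_provider_config

provider, rec_batch_num, layout_model = get_provider_config()

# One code path for CPU-only and GPU hosts: the provider list and the
# layout model name (layout.onnx vs. layout_s.onnx) switch together.
session = onnxruntime.InferenceSession(
    os.path.join(os.environ['MODEL_PATH'], layout_model),
    providers=provider,
)
```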
diff --git a/source/model/etl/code/imaug/__init__.py b/source/model/etl/code/imaug/__init__.py
index 461a24738..419b5c999 100644
--- a/source/model/etl/code/imaug/__init__.py
+++ b/source/model/etl/code/imaug/__init__.py
@@ -4,6 +4,10 @@
from __future__ import unicode_literals
from .operators import *
from .table_ops import *
+from .preprocess import preprocess
+
+__all__ = ["preprocess"]
+
def transform(data, ops=None):
""" transform """
if ops is None:
diff --git a/source/model/etl/code/imaug/preprocess.py b/source/model/etl/code/imaug/preprocess.py
new file mode 100644
index 000000000..1a55e087e
--- /dev/null
+++ b/source/model/etl/code/imaug/preprocess.py
@@ -0,0 +1,35 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import cv2
+import numpy as np
+
+__all__ = ["preprocess"]
+
+def preprocess(img, input_size, swap=(2, 0, 1)):
+ """Preprocess image for model input.
+
+ Args:
+ img: Input image
+ input_size: Target size (height, width)
+ swap: Channel swap order
+
+ Returns:
+ Preprocessed image and scale ratio
+ """
+ if len(img.shape) == 3:
+ padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
+ else:
+ padded_img = np.ones(input_size, dtype=np.uint8) * 114
+
+ r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
+ resized_img = cv2.resize(
+ img,
+ (int(img.shape[1] * r), int(img.shape[0] * r)),
+ interpolation=cv2.INTER_LINEAR,
+ ).astype(np.uint8)
+ padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
+
+ padded_img = padded_img.transpose(swap)
+ padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
+ return padded_img, r
\ No newline at end of file
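
A quick sanity check of the letterbox contract, using a synthetic frame (shapes are illustrative):

```python
import numpy as np
from imaug import preprocess

img = np.zeros((720, 1280, 3), dtype=np.uint8)  # dummy HxWx3 image
blob, ratio = preprocess(img, (640, 640))

assert blob.shape == (3, 640, 640)   # CHW, float32, padded with 114
assert abs(ratio - 0.5) < 1e-6       # min(640/720, 640/1280)
# Detections on the padded canvas map back to the source image by dividing
# coordinates by `ratio`, as layout.py does with boxes_xyxy /= ratio.
```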
diff --git a/source/model/etl/code/layout.py b/source/model/etl/code/layout.py
index 7b4f8d091..ab67ac8b5 100644
--- a/source/model/etl/code/layout.py
+++ b/source/model/etl/code/layout.py
@@ -17,33 +17,35 @@
import numpy as np
import time
-from utils import preprocess, multiclass_nms, postprocess
+from imaug import preprocess
+from postprocess import multiclass_nms, postprocess
import onnxruntime
-import GPUtil
-if len(GPUtil.getGPUs()):
- provider = [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "HEURISTIC"}), "CPUExecutionProvider"]
- model = 'layout.onnx'
-else:
- provider = ["CPUExecutionProvider"]
- model = 'layout_s.onnx'
+from gpu_config import get_provider_config
+from model_config import LAYOUT_CONFIG
+
+provider, _, layout_model = get_provider_config()
class LayoutPredictor(object):
def __init__(self):
- self.ort_session = onnxruntime.InferenceSession(os.path.join(os.environ['MODEL_PATH'], model), providers=provider)
- #_ = self.ort_session.run(['output'], {'images': np.zeros((1,3,640,640), dtype='float32')})[0]
- self.categorys = ['text', 'title', 'figure', 'table']
+ self.ort_session = onnxruntime.InferenceSession(os.path.join(os.environ['MODEL_PATH'], layout_model), providers=provider)
+ self.categorys = LAYOUT_CONFIG['categories']
+ self.nms_thr = LAYOUT_CONFIG['nms_threshold']
+ self.score_thr = LAYOUT_CONFIG['score_threshold']
+ self.image_size = LAYOUT_CONFIG['image_size']
+ self.aspect_ratio_threshold = LAYOUT_CONFIG['aspect_ratio_threshold']
+
def __call__(self, img):
ori_im = img.copy()
starttime = time.time()
h,w,_ = img.shape
h_ori, w_ori, _ = img.shape
- if max(h_ori, w_ori)/min(h_ori, w_ori)>2:
- s = 640/min(h_ori, w_ori)
+ if max(h_ori, w_ori)/min(h_ori, w_ori) > self.aspect_ratio_threshold:
+ s = self.image_size/min(h_ori, w_ori)
h_new = int((h_ori*s)//32*32)
w_new = int((w_ori*s)//32*32)
h, w = (h_new, w_new)
else:
- h, w = (640, 640)
+ h, w = (self.image_size, self.image_size)
image, ratio = preprocess(img, (h, w))
res = self.ort_session.run(['output'], {'images': image[np.newaxis,:]})[0]
predictions = postprocess(res, (h, w), p6=False)[0]
@@ -58,7 +60,7 @@ def __call__(self, img):
boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
boxes_xyxy /= ratio
- dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.15, score_thr=0.3)
+ dets = multiclass_nms(boxes_xyxy, scores, nms_thr=self.nms_thr, score_thr=self.score_thr)
if dets is None:
return [], time.time() - starttime
scores = dets[:, 4]
diff --git a/source/model/etl/code/main.py b/source/model/etl/code/main.py
index 94c1860c7..23e51beca 100644
--- a/source/model/etl/code/main.py
+++ b/source/model/etl/code/main.py
@@ -6,6 +6,7 @@
import re
import subprocess
from pathlib import Path
+from concurrent.futures import ThreadPoolExecutor, as_completed
from ocr import TextSystem
from table import TableSystem
@@ -182,6 +183,14 @@ def remove_symbols(text):
return cleaned_text
+def process_figure(figure_data, doc, figure_idx):
+ k, v = figure_data
+    image = v[0]
+ start_pos = doc.index(k)
+ context = doc[max(start_pos-200, 0): min(start_pos+200, len(doc))]
+ return k, figure_understand(image, context, k, s3_link=f'{figure_idx:05d}.jpg')
+
+
def structure_predict(file_path: Path, lang: str, auto_dpi, figure_rec) -> str:
"""
Extracts structured information from images in the given file path and returns a formatted document.
@@ -193,8 +202,6 @@ def structure_predict(file_path: Path, lang: str, auto_dpi, figure_rec) -> str:
str: The formatted document containing the extracted information.
"""
- # img_list, flag_gif, flag_pdf are returned from check_and_read
- #img_list, _, _ = check_and_read(file_path)
all_res = []
for index, img in enumerate(check_and_read(file_path)):
@@ -260,15 +267,37 @@ def structure_predict(file_path: Path, lang: str, auto_dpi, figure_rec) -> str:
doc += "\n\n"
doc = re.sub("\n{2,}", "\n\n", doc.strip())
images = {}
- for figure_idx, (k,v) in enumerate(figure.items()):
- images[f'{figure_idx:05d}.jpg'] = v[0]
- region_text = v[1] if not v[1] is None else ''
- if figure_rec:
- start_pos = doc.index(k)
- context = doc[max(start_pos-200, 0): min(start_pos+200, len(doc))]
- doc = doc.replace(k, figure_understand(v[0], context, k, s3_link=f'{figure_idx:05d}.jpg'))
- else:
+ start_time = time.time()
+ if figure_rec:
+ # Process figures concurrently using ThreadPoolExecutor
+ replacements = {}
+ max_workers = int(os.environ.get('FIGURE_UNDERSTANDING_MAX_WORKERS', 4))
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ future_to_figure = {
+ executor.submit(process_figure, (k,v), doc, figure_idx): (figure_idx, k, v)
+ for figure_idx, (k,v) in enumerate(figure.items())
+ }
+
+ for future in as_completed(future_to_figure):
+ figure_idx = future_to_figure[future][0]
+ images[f'{figure_idx:05d}.jpg'] = future_to_figure[future][2][0]
+ try:
+ k, replacement = future.result()
+ replacements[k] = replacement
+ except Exception as e:
+ logger.error(f"Error processing figure {figure_idx}: {str(e)}")
+
+ # Apply all replacements after concurrent processing
+ for k, replacement in replacements.items():
+ doc = doc.replace(k, replacement)
+ else:
+ for figure_idx, (k,v) in enumerate(figure.items()):
+ images[f'{figure_idx:05d}.jpg'] = v[0]
+                region_text = v[1] if v[1] is not None else ''
doc = doc.replace(k, f"\n\n{figure_idx:05d}.jpg\nocr\n\n{region_text}\n\n\n")
+
+ end_time = time.time()
+ logger.info(f"Figure processing time: {end_time - start_time:.2f} seconds")
doc = re.sub("\n{2,}", "\n\n", doc.strip())
return doc, images
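
The new worker pool is sized from the environment at call time, so the knob can be set in the container definition or before the pipeline runs; a one-line sketch (the value 8 is illustrative):

```python
import os

# Defaults to 4 workers when unset; raise it if the LLM endpoint has spare throughput
os.environ['FIGURE_UNDERSTANDING_MAX_WORKERS'] = '8'
```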
diff --git a/source/model/etl/code/model_config.py b/source/model/etl/code/model_config.py
new file mode 100644
index 000000000..67107685d
--- /dev/null
+++ b/source/model/etl/code/model_config.py
@@ -0,0 +1,78 @@
+MODEL_CONFIGS = {
+ 'ch': {
+ 'det': 'det_cn.onnx',
+ 'rec': 'rec_ch.onnx',
+ 'dict_path': 'ppocr_keys_v1.txt',
+ 'character_type': 'ch',
+ 'use_space_char': True
+ },
+ 'en': {
+ 'det': 'det_en.onnx',
+ 'rec': 'rec_en.onnx',
+ 'dict_path': 'en_dict.txt',
+ 'character_type': 'en',
+ 'use_space_char': True
+ },
+ 'multi': {
+ 'det': 'det_cn.onnx',
+ 'rec': 'rec_multi_large.onnx',
+ 'dict_path': 'keys_en_chs_cht_vi_ja_ko.txt',
+ 'character_type': 'ch',
+ 'use_space_char': True
+ }
+}
+
+LAYOUT_CONFIG = {
+ 'categories': ['text', 'title', 'figure', 'table'],
+ 'nms_threshold': 0.15,
+ 'score_threshold': 0.3,
+ 'image_size': 640, # Default image processing size
+ 'aspect_ratio_threshold': 2 # Max ratio threshold for special handling
+}
+
+TABLE_CONFIG = {
+ 'model': {
+ 'name': 'table_sim.onnx',
+ 'session_options': {
+ 'intra_op_num_threads': 8,
+ 'execution_mode': 'ORT_SEQUENTIAL',
+ 'optimization_level': 'ORT_ENABLE_ALL'
+ }
+ },
+ 'preprocess': [
+ {
+ 'ResizeTableImage': {
+ 'max_len': 488
+ }
+ },
+ {
+ 'NormalizeImage': {
+ 'std': [0.229, 0.224, 0.225],
+ 'mean': [0.485, 0.456, 0.406],
+ 'scale': '1./255.',
+ 'order': 'hwc'
+ }
+ },
+ {
+ 'PaddingTableImage': {
+ 'size': [488, 488]
+ }
+ },
+ {
+ 'ToCHWImage': None
+ },
+ {
+ 'KeepKeys': {
+ 'keep_keys': ['image', 'shape']
+ }
+ }
+ ],
+ 'postprocess': {
+ 'name': 'TableLabelDecode',
+ 'dict_path': 'table_structure_dict_ch.txt',
+ 'merge_no_span_structure': True
+ },
+ 'table_match': {
+ 'filter_ocr_result': True
+ }
+}
\ No newline at end of file
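
The flat `MODEL_CONFIGS` table is consumed by joining its entries against `MODEL_PATH`, as `ocr.py` below does; a minimal sketch:

```python
import os
from model_config import MODEL_CONFIGS

cfg = MODEL_CONFIGS['multi']
det_path = os.path.join(os.environ['MODEL_PATH'], cfg['det'])
rec_path = os.path.join(os.environ['MODEL_PATH'], cfg['rec'])
dict_path = os.path.join(os.environ['MODEL_PATH'], cfg['dict_path'])
```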
diff --git a/source/model/etl/code/ocr.py b/source/model/etl/code/ocr.py
index 7a8326c88..61d783469 100644
--- a/source/model/etl/code/ocr.py
+++ b/source/model/etl/code/ocr.py
@@ -9,17 +9,14 @@
import cv2
from imaug import create_operators, transform
from postprocess import build_post_process
-import GPUtil
-if len(GPUtil.getGPUs()):
- provider = [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "HEURISTIC"}), "CPUExecutionProvider"]
- rec_batch_num = 6
-else:
- provider = ["CPUExecutionProvider"]
- rec_batch_num = 1
+from gpu_config import get_provider_config
+from model_config import MODEL_CONFIGS
+
+provider, rec_batch_num, _ = get_provider_config()
class TextClassifier():
def __init__(self):
- self.weights_path = os.environ['MODEL_PATH'] + 'classifier.onnx'
+ self.weights_path = os.path.join(os.environ['MODEL_PATH'], 'classifier.onnx')
self.cls_image_shape = [3, 48, 192]
self.cls_batch_num = 30
@@ -43,7 +40,6 @@ def resize_norm_img(self, img):
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = np.array(Image.fromarray(img).resize((resized_w, imgH)))
- #resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
if self.cls_image_shape[0] == 1:
resized_image = resized_image / 255
@@ -95,15 +91,8 @@ def __call__(self, img_list):
class TextDetector():
def __init__(self, lang):
-
- if lang=='ch':
- modelName = 'det_cn.onnx'
- elif lang=='en':
- modelName = 'det_en.onnx'
- else:
- modelName = 'det_cn.onnx'
-
- self.weights_path = os.environ['MODEL_PATH'] + modelName
+ model_config = MODEL_CONFIGS[lang]
+ self.weights_path = os.path.join(os.environ['MODEL_PATH'], model_config['det'])
self.det_algorithm = 'DB'
self.use_zero_copy_run = False
@@ -209,30 +198,15 @@ def __call__(self, img, scale=None):
class TextRecognizer():
def __init__(self, lang='ch'):
- if lang=='ch':
- modelName = 'rec_ch.onnx'
- postprocess_params = {
- 'name': 'CTCLabelDecode',
- "character_type": 'ch',
- "character_dict_path": os.environ['MODEL_PATH'] + 'ppocr_keys_v1.txt',
- "use_space_char": True
- }
- elif lang=='en':
- modelName = 'rec_en.onnx'
- postprocess_params = {
- 'name': 'CTCLabelDecode',
- "character_dict_path": os.environ['MODEL_PATH'] + 'en_dict.txt',
- "use_space_char": True
- }
- else:
- modelName = 'rec_multi_large.onnx'
- postprocess_params = {
- 'name': 'CTCLabelDecode',
- "character_type": 'ch',
- "character_dict_path": os.environ['MODEL_PATH'] + 'keys_en_chs_cht_vi_ja_ko.txt',
- "use_space_char": True
- }
- self.weights_path = os.environ['MODEL_PATH'] + modelName
+ model_config = MODEL_CONFIGS[lang]
+ self.weights_path = os.path.join(os.environ['MODEL_PATH'], model_config['rec'])
+
+ postprocess_params = {
+ 'name': 'CTCLabelDecode',
+ "character_type": model_config['character_type'],
+ "character_dict_path": os.path.join(os.environ['MODEL_PATH'], model_config['dict_path']),
+ "use_space_char": model_config['use_space_char']
+ }
self.limited_max_width = 1280
self.limited_min_width = 16
@@ -243,7 +217,6 @@ def __init__(self, lang='ch'):
self.use_zero_copy_run = False
self.postprocess_op = build_post_process(postprocess_params)
-
self.ort_session = onnxruntime.InferenceSession(self.weights_path, providers=provider)
def resize_norm_img(self, img, max_wh_ratio):
diff --git a/source/model/etl/code/postprocess/__init__.py b/source/model/etl/code/postprocess/__init__.py
index 935708e11..ff0895548 100644
--- a/source/model/etl/code/postprocess/__init__.py
+++ b/source/model/etl/code/postprocess/__init__.py
@@ -5,7 +5,14 @@
import copy
-__all__ = ['build_post_process']
+from .nms import nms, multiclass_nms, postprocess
+
+__all__ = [
+    "build_post_process",
+    "nms",
+    "multiclass_nms",
+    "postprocess",
+]
def build_post_process(config, global_config=None):
diff --git a/source/model/etl/code/postprocess/nms.py b/source/model/etl/code/postprocess/nms.py
new file mode 100644
index 000000000..29b70b7ec
--- /dev/null
+++ b/source/model/etl/code/postprocess/nms.py
@@ -0,0 +1,111 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import numpy as np
+
+__all__ = ["nms", "multiclass_nms", "postprocess"]
+
+def nms(boxes, scores, nms_thr):
+ """Single class NMS implemented in Numpy."""
+ x1 = boxes[:, 0]
+ y1 = boxes[:, 1]
+ x2 = boxes[:, 2]
+ y2 = boxes[:, 3]
+
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ order = scores.argsort()[::-1]
+
+ keep = []
+ while order.size > 0:
+ i = order[0]
+ keep.append(i)
+ xx1 = np.maximum(x1[i], x1[order[1:]])
+ yy1 = np.maximum(y1[i], y1[order[1:]])
+ xx2 = np.minimum(x2[i], x2[order[1:]])
+ yy2 = np.minimum(y2[i], y2[order[1:]])
+
+ w = np.maximum(0.0, xx2 - xx1 + 1)
+ h = np.maximum(0.0, yy2 - yy1 + 1)
+ inter = w * h
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+ inds = np.where(ovr <= nms_thr)[0]
+ order = order[inds + 1]
+
+ return keep
+
+def multiclass_nms(boxes, scores, nms_thr, score_thr, class_agnostic=True):
+ """Multiclass NMS implemented in Numpy"""
+ if class_agnostic:
+ nms_method = multiclass_nms_class_agnostic
+ else:
+ nms_method = multiclass_nms_class_aware
+ return nms_method(boxes, scores, nms_thr, score_thr)
+
+def multiclass_nms_class_aware(boxes, scores, nms_thr, score_thr):
+ """Multiclass NMS implemented in Numpy. Class-aware version."""
+ final_dets = []
+ num_classes = scores.shape[1]
+ for cls_ind in range(num_classes):
+ cls_scores = scores[:, cls_ind]
+ valid_score_mask = cls_scores > score_thr
+ if valid_score_mask.sum() == 0:
+ continue
+ else:
+ valid_scores = cls_scores[valid_score_mask]
+ valid_boxes = boxes[valid_score_mask]
+ keep = nms(valid_boxes, valid_scores, nms_thr)
+ if len(keep) > 0:
+ cls_inds = np.ones((len(keep), 1)) * cls_ind
+ dets = np.concatenate(
+ [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
+ )
+ final_dets.append(dets)
+ if len(final_dets) == 0:
+ return None
+ return np.concatenate(final_dets, 0)
+
+def multiclass_nms_class_agnostic(boxes, scores, nms_thr, score_thr):
+ """Multiclass NMS implemented in Numpy. Class-agnostic version."""
+ cls_inds = scores.argmax(1)
+ cls_scores = scores[np.arange(len(cls_inds)), cls_inds]
+
+ valid_score_mask = cls_scores > score_thr
+ if valid_score_mask.sum() == 0:
+ return None
+ valid_scores = cls_scores[valid_score_mask]
+ valid_boxes = boxes[valid_score_mask]
+ valid_cls_inds = cls_inds[valid_score_mask]
+ keep = nms(valid_boxes, valid_scores, nms_thr)
+ if keep:
+ dets = np.concatenate(
+ [valid_boxes[keep], valid_scores[keep, None], valid_cls_inds[keep, None]], 1
+ )
+ return dets
+
+def postprocess(outputs, img_size, p6=False):
+ """Post-process model outputs."""
+ grids = []
+ expanded_strides = []
+
+ if not p6:
+ strides = [8, 16, 32]
+ else:
+ strides = [8, 16, 32, 64]
+
+ hsizes = [img_size[0] // stride for stride in strides]
+ wsizes = [img_size[1] // stride for stride in strides]
+
+ for hsize, wsize, stride in zip(hsizes, wsizes, strides):
+ xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
+ grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
+ grids.append(grid)
+ shape = grid.shape[:2]
+ expanded_strides.append(np.full((*shape, 1), stride))
+
+ grids = np.concatenate(grids, 1)
+ expanded_strides = np.concatenate(expanded_strides, 1)
+ outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
+ outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
+
+ return outputs
\ No newline at end of file
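
A self-contained smoke test for the extracted NMS helpers (dummy boxes, one class):

```python
import numpy as np
from postprocess.nms import multiclass_nms

boxes = np.array([[0, 0, 100, 100],      # overlaps the next box heavily
                  [5, 5, 105, 105],
                  [200, 200, 300, 300]], dtype=np.float32)
scores = np.array([[0.90], [0.80], [0.95]], dtype=np.float32)  # single class

dets = multiclass_nms(boxes, scores, nms_thr=0.15, score_thr=0.3)
# Two rows survive: the 0.80 box is suppressed by the overlapping 0.90 box.
# Columns are x1, y1, x2, y2, score, class_id.
```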
diff --git a/source/model/etl/code/prompt/chart.txt b/source/model/etl/code/prompt/chart.txt
new file mode 100644
index 000000000..8de485baf
--- /dev/null
+++ b/source/model/etl/code/prompt/chart.txt
@@ -0,0 +1,16 @@
+您是文档阅读专家。您的任务是将图片中的图表转换成Markdown格式。以下是说明:
+1. 找到图片中的图表。
+2. 仔细观察图表,了解其中包含的结构和数据。
+3. 使用标签中的上下文信息来帮助你更好地理解和描述这张图表。上下文中的{tag}就是指该图表。
+4. 按照以下指南将图表数据转换成 Markdown 表格格式:
+ - 使用 | 字符分隔列
+ - 使用 --- 行表示标题行
+ - 确保表格格式正确并对齐
+ - 对于不确定的数字,请根据图片估算。
+5. 仔细检查您的 Markdown 表格是否准确地反映了图表图像中的数据。
+6. 在 xml 标签中仅返回 Markdown,不含其他文本。
+
+
+{context}
+
+请将你的描述写在xml标签之间。
\ No newline at end of file
diff --git a/source/model/etl/code/prompt/description.txt b/source/model/etl/code/prompt/description.txt
new file mode 100644
index 000000000..22fa1a862
--- /dev/null
+++ b/source/model/etl/code/prompt/description.txt
@@ -0,0 +1,8 @@
+你是一位资深的图像分析专家。你的任务是仔细观察给出的插图,并按照以下步骤进行:
+1. 清晰地描述这张图片中显示的内容细节。如果图片中包含任何文字,请确保在描述中准确无误地包含这些文字。
+2. 使用标签中的上下文信息来帮助你更好地理解和描述这张图片。上下文中的{tag}就是指该插图。
+3. 将你的描述写在标签之间。
+
+{context}
+
+请将你的描述写在xml标签之间。
\ No newline at end of file
diff --git a/source/model/etl/code/prompt/mermaid_template.txt b/source/model/etl/code/prompt/mermaid_template.txt
index eae16c000..fcfc9d1e3 100644
--- a/source/model/etl/code/prompt/mermaid_template.txt
+++ b/source/model/etl/code/prompt/mermaid_template.txt
@@ -19,7 +19,6 @@ Your response should be structured as follows:
[Mermaid chart code representing the workflow]
-
Below are example of {diagram_type} mermaid templates. The example include a detailed description of the workflow along with Mermaid chart codes. Use the example as reference when analyzing the workflow diagram in the image.
diff --git a/source/model/etl/code/requirements.txt b/source/model/etl/code/requirements.txt
index 695227c67..4fd7ee918 100644
--- a/source/model/etl/code/requirements.txt
+++ b/source/model/etl/code/requirements.txt
@@ -8,4 +8,5 @@ PyMuPDF<1.21.0
markdownify
flask
gevent
-GPUtil
\ No newline at end of file
+GPUtil
+openai>=1.0.0
\ No newline at end of file
diff --git a/source/model/etl/code/sm_predictor.py b/source/model/etl/code/sm_predictor.py
index 5cbacb949..00e3bc26d 100644
--- a/source/model/etl/code/sm_predictor.py
+++ b/source/model/etl/code/sm_predictor.py
@@ -1,4 +1,3 @@
-from aikits_utils import lambda_return
from main import process_pdf_pipeline
from gevent import pywsgi
import flask
@@ -9,7 +8,7 @@
def handler(event, context):
if 'body' not in event:
- return lambda_return(400, 'invalid param')
+ return create_response('invalid param', 400)
try:
if isinstance(event['body'], str):
body = json.loads(event['body'])
@@ -17,14 +16,30 @@ def handler(event, context):
body = event['body']
if 's3_bucket' not in body or 'object_key' not in body:
- return lambda_return(400, 'Must specify the `s3_bucket` and `object_key` for the file')
+ return create_response('Must specify the `s3_bucket` and `object_key` for the file', 400)
except:
- return lambda_return(400, 'invalid param')
+ return create_response('invalid param', 400)
output = process_pdf_pipeline(body)
-    return lambda_return(200, json.dumps(output))
+    return create_response(json.dumps(output), 200)
+
+def create_response(body, status_code=200):
+ """Create a Flask response with CORS headers.
+
+ Args:
+ body: Response body
+ status_code: HTTP status code (default: 200)
+
+ Returns:
+ Flask Response object
+ """
+ response = flask.make_response(body, status_code)
+ response.headers['Access-Control-Allow-Headers'] = '*'
+ response.headers['Access-Control-Allow-Origin'] = '*'
+ response.headers['Access-Control-Allow-Methods'] = '*'
+ return response
@app.route('/ping', methods=['GET'])
@@ -48,14 +63,10 @@ def transformation():
if flask.request.content_type == 'application/json':
request_body = flask.request.data.decode('utf-8')
body = json.loads(request_body)
- req = handler({'body': body}, None)
- return flask.Response(
- response=req['body'],
- status=req['statusCode'], mimetype='application/json')
+ response = handler({'body': body}, None)
+ return response
else:
- return flask.Response(
- response='Only supports application/json data',
- status=415, mimetype='application/json')
+ return create_response('Only supports application/json data', 415)
server = pywsgi.WSGIServer(('0.0.0.0', 8080), app)
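
End to end, the service keeps the SageMaker serving contract; a hedged local smoke test, assuming the standard `/invocations` route (the decorator is outside this hunk) and placeholder bucket/key values:

```python
import requests

payload = {"s3_bucket": "my-bucket", "object_key": "docs/sample.pdf"}  # placeholders
resp = requests.post("http://localhost:8080/invocations", json=payload, timeout=300)
print(resp.status_code, resp.text)  # 400 with a message on bad input, 200 with pipeline output otherwise
```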
diff --git a/source/model/etl/code/table.py b/source/model/etl/code/table.py
index c6e891088..a7fb0a788 100644
--- a/source/model/etl/code/table.py
+++ b/source/model/etl/code/table.py
@@ -7,12 +7,16 @@
import numpy as np
import os
import onnxruntime as ort
+from model_config import TABLE_CONFIG
-sess_options = ort.SessionOptions()
+def get_session_options():
+ sess_options = ort.SessionOptions()
+ config = TABLE_CONFIG['model']['session_options']
+ sess_options.intra_op_num_threads = config['intra_op_num_threads']
+ sess_options.execution_mode = getattr(ort.ExecutionMode, config['execution_mode'])
+ sess_options.graph_optimization_level = getattr(ort.GraphOptimizationLevel, config['optimization_level'])
+ return sess_options
-sess_options.intra_op_num_threads = 8
-sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
-sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
def sorted_boxes(dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
@@ -33,24 +37,24 @@ def sorted_boxes(dt_boxes):
_boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp
return _boxes
+
class TableStructurer(object):
def __init__(self):
- self.use_onnx = True #args.use_onnx
- pre_process_list = [{'ResizeTableImage': {'max_len': 488}}, {'NormalizeImage': {'std': [0.229, 0.224, 0.225], 'mean': [0.485, 0.456, 0.406], 'scale': '1./255.', 'order': 'hwc'}}, {'PaddingTableImage': {'size': [488, 488]}}, {'ToCHWImage': None}, {'KeepKeys': {'keep_keys': ['image', 'shape']}}]
-
- postprocess_params = {
- 'name': 'TableLabelDecode',
- "character_dict_path": os.environ['MODEL_PATH'] + 'table_structure_dict_ch.txt',
- 'merge_no_span_structure': True
- }
- self.preprocess_op = create_operators(pre_process_list)
+ self.use_onnx = True
+ self.preprocess_op = create_operators(TABLE_CONFIG['preprocess'])
+
+ postprocess_params = dict(TABLE_CONFIG['postprocess'])
+ postprocess_params['character_dict_path'] = os.path.join(os.environ['MODEL_PATH'], postprocess_params.pop('dict_path'))
self.postprocess_op = build_post_process(postprocess_params)
- sess = ort.InferenceSession(os.environ['MODEL_PATH'] + 'table_sim.onnx', providers=['CPUExecutionProvider']) #, sess_options=sess_options, providers=[("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"})]
+ sess = ort.InferenceSession(
+ os.path.join(os.environ['MODEL_PATH'], TABLE_CONFIG['model']['name']),
+ providers=['CPUExecutionProvider'],
+ sess_options=get_session_options()
+ )
_ = sess.run(None, {'x': np.zeros((1, 3, 488, 488), dtype='float32')})
self.predictor, self.input_tensor, self.output_tensors, self.config = sess, sess.get_inputs()[0], None, None
-
def __call__(self, img):
starttime = time.time()
ori_im = img.copy()
@@ -78,6 +82,7 @@ def __call__(self, img):
] + structure_str_list + ['', '