Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Xiaoting #554

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 0 additions & 50 deletions source/model/etl/code/aikits_utils.py

This file was deleted.

73 changes: 73 additions & 0 deletions source/model/etl/code/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import os


class ModelConfig:
    """Registry of OCR model files and post-processing settings per language.

    ``LANG_CONFIGS`` maps a language code (``'ch'`` or ``'en'``) to a
    detection entry (``'det'``) and a recognition entry (``'rec'``). Each
    entry names the model file and the keyword settings of its
    post-processor; recognition entries also name a character dictionary
    file. File names are resolved against the directory given by the
    ``MODEL_PATH`` environment variable.
    """

    LANG_CONFIGS = {
        'ch': {
            'det': {
                'model': 'det_cn.onnx',
                'postprocess': {
                    'name': 'DBPostProcess',
                    'thresh': 0.1,
                    'box_thresh': 0.1,
                    'max_candidates': 1000,
                    'unclip_ratio': 1.5,
                    'use_dilation': False,
                    'score_mode': 'fast',
                    'box_type': 'quad'
                }
            },
            'rec': {
                'model': 'rec_ch.onnx',
                'dict_path': 'ppocr_keys_v1.txt',
                'postprocess': {
                    'name': 'CTCLabelDecode',
                    'character_type': 'ch',
                    'use_space_char': True
                }
            }
        },
        'en': {
            'det': {
                'model': 'det_en.onnx',
                'postprocess': {
                    'name': 'DBPostProcess',
                    'thresh': 0.1,
                    'box_thresh': 0.1,
                    'max_candidates': 1000,
                    'unclip_ratio': 1.5,
                    'use_dilation': False,
                    'score_mode': 'fast',
                    'box_type': 'quad'
                }
            },
            'rec': {
                'model': 'rec_en.onnx',
                'dict_path': 'en_dict.txt',
                'postprocess': {
                    'name': 'CTCLabelDecode',
                    'use_space_char': True
                }
            }
        }
    }

    @classmethod
    def get_model_path(cls, lang, model_type):
        """Return the path of the model file for ``lang`` and ``model_type``.

        Args:
            lang: Language code, a key of ``LANG_CONFIGS``.
            model_type: ``'det'`` or ``'rec'``.

        Raises:
            KeyError: If ``lang`` or ``model_type`` is not configured, or
                the ``MODEL_PATH`` environment variable is not set.
        """
        return os.path.join(
            os.environ['MODEL_PATH'],
            cls.LANG_CONFIGS[lang][model_type]['model']
        )

    @classmethod
    def get_dict_path(cls, lang):
        """Return the path of the recognition character dictionary for ``lang``.

        Raises:
            KeyError: If ``lang`` is not configured, or the ``MODEL_PATH``
                environment variable is not set.
        """
        return os.path.join(
            os.environ['MODEL_PATH'],
            cls.LANG_CONFIGS[lang]['rec']['dict_path']
        )

    @classmethod
    def get_postprocess_config(cls, lang, model_type):
        """Return the post-processing settings dict for ``lang``/``model_type``.

        The returned dict is the object stored in ``LANG_CONFIGS`` itself,
        not a copy, so callers should not mutate it.

        Raises:
            KeyError: If ``lang`` or ``model_type`` is not configured.
        """
        return cls.LANG_CONFIGS[lang][model_type]['postprocess']
145 changes: 110 additions & 35 deletions source/model/etl/code/figure_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,73 @@
import io
import base64
import json
import os
import openai

# Add logger configuration
logger = logging.getLogger(__name__)

class figureUnderstand():
def __init__(self):
self.bedrock_runtime = boto3.client(service_name='bedrock-runtime')
def __init__(self, model_provider="bedrock", api_secret_name=None):
self.model_provider = model_provider
if model_provider == "bedrock":
self.bedrock_runtime = boto3.client(service_name='bedrock-runtime')
elif model_provider == "openai":
self.openai_api_key = self._get_api_key(api_secret_name)
if not self.openai_api_key:
raise ValueError("Failed to retrieve OpenAI API key from Secrets Manager")
openai.api_key = self.openai_api_key
# Set OpenAI base URL from environment variable if provided
base_url = os.environ.get("OPENAI_API_BASE")
if base_url:
openai.base_url = base_url
else:
raise ValueError("Unsupported model provider. Choose 'bedrock' or 'openai'")

self.mermaid_prompt = json.load(open('prompt/mermaid.json', 'r'))
def invoke_llm(self, img, prompt, prefix="<output>", stop="</output>"):

def _get_api_key(self, api_secret_name):
"""
Get the API key from AWS Secrets Manager.
Args:
api_secret_name (str): The name of the secret in AWS Secrets Manager containing the API key.
Returns:
str: The API key.
"""
if not api_secret_name:
raise ValueError("api_secret_name must be provided when using OpenAI")

try:
secrets_client = boto3.client("secretsmanager")
secret_response = secrets_client.get_secret_value(
SecretId=api_secret_name
)
if "SecretString" in secret_response:
secret_data = json.loads(secret_response["SecretString"])
api_key = secret_data.get("api_key")
logger.info(
f"Successfully retrieved API key from secret: {api_secret_name}"
)
return api_key
except Exception as e:
logger.error(f"Error retrieving secret {api_secret_name}: {str(e)}")
raise
return None

def _image_to_base64(self, img):
"""Convert PIL Image to base64 encoded string"""
image_stream = io.BytesIO()
img.save(image_stream, format="JPEG")
base64_encoded = base64.b64encode(image_stream.getvalue()).decode('utf-8')
return base64.b64encode(image_stream.getvalue()).decode('utf-8')

def invoke_llm(self, img, prompt, prefix="<output>", stop="</output>"):
if self.model_provider == "bedrock":
return self._invoke_bedrock(img, prompt, prefix, stop)
elif self.model_provider == "openai":
return self._invoke_openai(img, prompt, prefix, stop)

def _invoke_bedrock(self, img, prompt, prefix="<output>", stop="</output>"):
base64_encoded = self._image_to_base64(img)
messages = [
{
"role": "user",
Expand All @@ -34,7 +92,7 @@ def invoke_llm(self, img, prompt, prefix="<output>", stop="</output>"):
},
{"role": "assistant", "content": prefix},
]
model_id = "anthropic.claude-3-sonnet-20240229-v1:0"
model_id = "anthropic.claude-3-5-sonnet-20241022-v2:0"
body = json.dumps({
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": 4096,
Expand All @@ -46,58 +104,75 @@ def invoke_llm(self, img, prompt, prefix="<output>", stop="</output>"):
response_body = json.loads(response.get('body').read())
result = prefix + response_body['content'][0]['text'] + stop
return result

def _invoke_openai(self, img, prompt, prefix="<output>", stop="</output>"):
    """Send the image and prompt to the OpenAI chat completions API.

    The image is embedded as a base64 JPEG data URL in the user message,
    and ``prefix`` is supplied as a trailing assistant message.

    Args:
        img: Image to send; encoded via ``self._image_to_base64``.
        prompt: Text part of the user message.
        prefix: Text placed in the assistant message and prepended to the
            returned string.
        stop: Stop sequence for the request (omitted when falsy) and
            appended to the returned string.

    Returns:
        str: ``prefix`` + the first choice's message content + ``stop``.
    """
    image_url = f"data:image/jpeg;base64,{self._image_to_base64(img)}"

    user_content = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": image_url}},
    ]
    conversation = [
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": prefix},
    ]

    stop_sequences = None
    if stop:
        stop_sequences = [stop]

    completion = openai.chat.completions.create(
        model="gpt-4-vision-preview",
        messages=conversation,
        max_tokens=4096,
        stop=stop_sequences,
    )

    return prefix + completion.choices[0].message.content + stop

def get_classification(self, img):
    """Run the figure-classification prompt against ``img``.

    The prompt text is read from ``prompt/figure_classification.txt``.

    Returns:
        The raw result of ``self.invoke_llm`` for that prompt.
    """
    with open('prompt/figure_classification.txt') as prompt_file:
        classification_prompt = prompt_file.read()
    return self.invoke_llm(img, classification_prompt)

def get_chart(self, img, context, tag):
    """Ask the LLM to convert the chart in ``img`` using the chart prompt.

    The prompt template is read from ``prompt/chart.txt`` and filled with
    ``context`` and ``tag`` via ``str.format``. The earlier inline prompt
    and its extra ``invoke_llm`` call are removed: that call sent the
    unformatted template and its result was discarded.

    Args:
        img: Image containing the chart.
        context: Surrounding document text substituted for ``{context}``.
        tag: Placeholder name substituted for ``{tag}``.

    Returns:
        The raw result of ``self.invoke_llm``.
    """
    with open('prompt/chart.txt') as f:
        chart_prompt = f.read()
    return self.invoke_llm(img, chart_prompt.format(context=context, tag=tag))
def get_description(self, img, context, tag):
    """Ask the LLM to describe ``img`` and wrap the answer as Markdown image text.

    The prompt template is read from ``prompt/description.txt`` and filled
    with ``context`` and ``tag`` via ``str.format``. This replaces the
    duplicated definition whose first, shadowed version used an inline
    prompt and returned nothing.

    Args:
        img: Image to describe.
        context: Surrounding document text substituted for ``{context}``.
        tag: Placeholder name substituted for ``{tag}``.

    Returns:
        str: The LLM output embedded as the alt text of an empty-target
        Markdown image, i.e. ``![<output>]()``.
    """
    with open('prompt/description.txt') as f:
        description_prompt = f.read()
    output = self.invoke_llm(img, description_prompt.format(context=context, tag=tag))
    return f'![{output}]()'

def get_mermaid(self, img, classification):
    """Ask the LLM to render ``img`` as a Mermaid diagram of the given type.

    The template in ``prompt/mermaid_template.txt`` is filled with the
    diagram type and the matching example from ``self.mermaid_prompt``.

    Args:
        img: Image containing the diagram.
        classification: Diagram type; must be a key of
            ``self.mermaid_prompt``.

    Returns:
        The raw result of ``self.invoke_llm`` called with the
        ``<description>`` prefix and ``</mermaid>`` stop marker.
    """
    with open('prompt/mermaid_template.txt') as template_file:
        template = template_file.read()
    filled_prompt = template.format(
        diagram_type=classification,
        diagram_example=self.mermaid_prompt[classification],
    )
    return self.invoke_llm(img, filled_prompt, prefix='<description>', stop='</mermaid>')

def parse_result(self, llm_output, tag):
    """Extract the text enclosed by ``<tag>...</tag>`` from an LLM response.

    Args:
        llm_output (str): Raw response text.
        tag (str): Tag name without angle brackets.

    Returns:
        str: The stripped content of the first ``<tag>...</tag>`` pair
        (matching across newlines). When no complete pair is present, the
        input with any ``<tag>`` / ``</tag>`` markers removed.
    """
    # Escape the tag so it is always matched literally, and test the match
    # explicitly instead of relying on a bare ``except`` around ``[0]``.
    escaped = re.escape(tag)
    match = re.search(fr"<{escaped}>(.*?)</{escaped}>", llm_output, re.DOTALL)
    if match:
        return match.group(1).strip()
    return llm_output.replace(f"<{tag}>", '').replace(f"</{tag}>", '')

def __call__(self, img, context, tag, s3_link):
classification = self.get_classification(img)
classification = self.parse_result(classification, 'output')
Expand Down
13 changes: 13 additions & 0 deletions source/model/etl/code/gpu_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import GPUtil

def get_provider_config():
    """Choose ONNX execution providers and model settings by GPU availability.

    Returns:
        tuple: ``(provider, rec_batch_num, layout_model)``. When
        ``GPUtil.getGPUs()`` reports at least one GPU: the CUDA provider
        (with heuristic cuDNN conv algorithm search) followed by the CPU
        provider, batch size 6 and ``'layout.onnx'``. Otherwise: the CPU
        provider only, batch size 1 and ``'layout_s.onnx'``.
    """
    gpu_count = len(GPUtil.getGPUs())
    if not gpu_count:
        return ["CPUExecutionProvider"], 1, 'layout_s.onnx'

    cuda_provider = ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "HEURISTIC"})
    return [cuda_provider, "CPUExecutionProvider"], 6, 'layout.onnx'
4 changes: 4 additions & 0 deletions source/model/etl/code/imaug/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from __future__ import unicode_literals
from .operators import *
from .table_ops import *
from .preprocess import preprocess

__all__ = ["preprocess"]

def transform(data, ops=None):
""" transform """
if ops is None:
Expand Down
35 changes: 35 additions & 0 deletions source/model/etl/code/imaug/preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import cv2
import numpy as np

__all__ = ["preprocess"]

def preprocess(img, input_size, swap=(2, 0, 1)):
    """Resize an image to fit ``input_size`` and pad the rest with 114.

    The image is scaled by one ratio so it fits inside the target size,
    placed in the top-left corner of a canvas filled with the value 114,
    then transposed by ``swap`` and returned as contiguous float32.

    Args:
        img: Input image array.
        input_size: Target size as (height, width).
        swap: Axis order passed to ``transpose``.

    Returns:
        tuple: ``(padded_img, r)`` - the float32 padded array and the
        scale ratio applied to the input.
    """
    canvas_shape = (input_size[0], input_size[1], 3) if len(img.shape) == 3 else input_size
    canvas = np.full(canvas_shape, 114, dtype=np.uint8)

    src_h, src_w = img.shape[0], img.shape[1]
    r = min(input_size[0] / src_h, input_size[1] / src_w)
    new_h, new_w = int(src_h * r), int(src_w * r)

    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    canvas[:new_h, :new_w] = resized.astype(np.uint8)

    transposed = canvas.transpose(swap)
    return np.ascontiguousarray(transposed, dtype=np.float32), r
Loading
Loading