Adding EffOCR support within the layout parser #189

Open: wants to merge 3 commits into base: main

133 changes: 0 additions & 133 deletions README.md

This file was deleted.

70 changes: 0 additions & 70 deletions installation.md

This file was deleted.

@@ -23,6 +23,7 @@
is_effdet_available,
is_pytesseract_available,
is_gcv_available,
is_effocr_available,
)

_import_structure = {
@@ -51,6 +52,7 @@
"is_paddle_available",
"is_pytesseract_available",
"is_gcv_available",
"is_effocr_available",
"requires_backends"
],
"tools": [
@@ -80,6 +82,9 @@
if is_gcv_available():
_import_structure["ocr.gcv_agent"] = ["GCVAgent", "GCVFeatureType"]

if is_effocr_available():
_import_structure["ocr.effocr_agent"] = ["EffOCRAgent", "EffOCRFeatureType"]

sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
19 changes: 19 additions & 0 deletions src/layoutparser/file_utils.py → src/effocr-layout/file_utils.py
@@ -88,6 +88,14 @@
except ModuleNotFoundError:
_gcv_available = False

try:
_effocr_available = importlib.util.find_spec("onnxruntime") is not None \
and importlib.util.find_spec("onnx") is not None \
and importlib.util.find_spec("faiss") is not None
except ModuleNotFoundError:
_effocr_available = False



def is_torch_available():
return _torch_available
@@ -121,6 +129,9 @@ def is_pytesseract_available():
def is_gcv_available():
return _gcv_available

def is_effocr_available():
return _effocr_available


PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
@@ -154,6 +165,13 @@ def is_gcv_available():
`pip install google-cloud-vision==1`
"""

EFFOCR_IMPORT_ERROR = """
{0} requires the onnxruntime, onnx, and faiss libraries, but at least one was not found in your environment. You can install them with pip:
`pip install onnxruntime onnx faiss-cpu`
Note that `faiss` is distributed as either a CPU (`faiss-cpu`) or a GPU (`faiss-gpu`) build; the GPU build requires CUDA. See
https://github.com/facebookresearch/faiss/blob/main/INSTALL.md for more details.
"""

BACKENDS_MAPPING = dict(
[
("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
@@ -162,6 +180,7 @@ def is_gcv_available():
("effdet", (is_effdet_available, EFFDET_IMPORT_ERROR)),
("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)),
("google-cloud-vision", (is_gcv_available, GCV_IMPORT_ERROR)),
("effocr", (is_effocr_available, EFFOCR_IMPORT_ERROR))
]
)
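
For reference, a minimal sketch of how the new "effocr" backend entry would be consumed (an assumed usage pattern mirroring the existing backends: the import path presumes the module is still importable as layoutparser.file_utils, and NeedsEffOCR is a hypothetical class):

from layoutparser.file_utils import is_effocr_available, requires_backends

class NeedsEffOCR:
    DEPENDENCIES = ["effocr"]

    def __init__(self):
        # Surfaces EFFOCR_IMPORT_ERROR if onnxruntime, onnx, or faiss is missing
        requires_backends(self, self.DEPENDENCIES)

if is_effocr_available():
    agent = NeedsEffOCR()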

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -13,4 +13,5 @@
# limitations under the License.

from .gcv_agent import GCVAgent, GCVFeatureType
from .tesseract_agent import TesseractAgent, TesseractFeatureType
from .effocr_agent import EffOCRAgent, EffOCRFeatureType
File renamed without changes.
3 changes: 3 additions & 0 deletions src/effocr-layout/ocr/effocr/__init__.py
@@ -0,0 +1,3 @@
from .engines import EffLineDetector, EffLocalizer, EffRecognizer
from .utils import create_paired_transform, create_paired_transform_word, letterbox, non_max_suppression
from .infer_transcripton import run_effocr_word
3 changes: 3 additions & 0 deletions src/effocr-layout/ocr/effocr/engines/__init__.py
@@ -0,0 +1,3 @@
from .localizer_engine import EffLocalizer
from .recognizer_engine import EffRecognizer
from .line_det_engine import EffLineDetector
251 changes: 251 additions & 0 deletions src/effocr-layout/ocr/effocr/engines/line_det_engine.py
@@ -0,0 +1,251 @@
import os
import sys
# import mmcv  # optional; required only when using the 'mmdetection' backend in format_line_img/load_line_img
import torch
import numpy as np
import onnxruntime as ort
import torchvision
from torchvision.ops import nms
import cv2
import onnx
from math import floor, ceil

from .ops import non_max_suppression as yolov8_nms
from .ops import get_onnx_input_name
from ..utils import letterbox, non_max_suppression

DEFAULT_MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
DEFAULT_STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)

class EffLineDetector:
"""
    Class for running the EffOCR line detection model. Essentially a wrapper for the onnxruntime
    inference session based on the model, with some additional postprocessing, especially for splitting and
    recombining especially tall layout regions.
"""

def __init__(self, model_path, iou_thresh = 0.15, conf_thresh = 0.20,
num_cores = None, providers=None, input_shape = (640, 640), model_backend='yolo',
min_seg_ratio = 2, visualize = None):
"""Instantiates the object, including setting up the wrapped ONNX InferenceSession
Args:
model_path (str): Path to ONNX model that will be used
iou_thresh (float, optional): IOU filter for line detection NMS. Defaults to 0.15.
conf_thresh (float, optional): Confidence filter for line detection NMS. Defaults to 0.20.
            num_cores (int, optional): Number of cores to use during inference. Defaults to None, meaning no intra-op thread limit.
            providers (list, optional): Any particular ONNX providers to use. Defaults to None, meaning the results of ort.get_available_providers() will be used.
            input_shape (tuple, optional): Shape of input images. Defaults to (640, 640).
            model_backend (str, optional): Original model backend being used. Defaults to 'yolo'. Options are mmdetection, detectron2, yolo, yolov8.
            min_seg_ratio (int, optional): Height-to-width ratio above which a layout region is split into overlapping crops before detection. Defaults to 2.
            visualize (str, optional): Optional path at which to save an image showing the detected lines. Defaults to None.
"""


        # Set up and instantiate an ort InferenceSession
sess_options = ort.SessionOptions()
if num_cores is not None:
sess_options.intra_op_num_threads = num_cores

if providers is None:
providers = ort.get_available_providers()

self._eng_net = ort.InferenceSession(
model_path,
sess_options,
providers=providers,
)

# Load in the model as a standard ONNX model to get the input shape and name
base_model = onnx.load(model_path)
self._input_name = get_onnx_input_name(base_model)
self._model_input_shape = self._eng_net.get_inputs()[0].shape

# Rest of the params
self._iou_thresh = iou_thresh
self._conf_thresh = conf_thresh

if isinstance(self._model_input_shape[-1], int) and isinstance(self._model_input_shape[-2], int):
self._input_shape = (self._model_input_shape[-2], self._model_input_shape[-1])
else:
self._input_shape = input_shape
self._model_backend = model_backend
self.min_seg_ratio = min_seg_ratio # Ratio that determines at what point the model will split a region into two



def __call__(self, imgs, visualize = None):
"""Wraps the run method, allowing the object to be called directly
Args:
            imgs (list or str or np.ndarray): List of image paths, list of images as np.ndarrays, a single image path, or a single image as an np.ndarray
        Returns:
            tuple: (line_crops, line_coords) -- the detected line crop images and their (y0, x0, y1, x1) coordinates in the original image
"""
return self.run(imgs, visualize = visualize)

def run(self, imgs, visualize = None):
orig_img = imgs.copy()
if isinstance(imgs, list):
if all(isinstance(img, str) for img in imgs):
imgs = [self.load_line_img(img, self._input_shape, backend=self._model_backend) for img in imgs]
elif all(isinstance(img, np.ndarray) for img in imgs):
imgs = [self.get_crops_from_layout_image(img) for img in imgs]
imgs = [self.format_line_img(img, self._input_shape, backend=self._model_backend) for img in imgs]
else:
                raise ValueError('Invalid combination of input types in line detection list! Must be all str or all np.ndarray')
elif isinstance(imgs, str):
imgs = [self.load_line_img(imgs, self._input_shape, backend=self._model_backend)]
elif isinstance(imgs, np.ndarray):
imgs = self.get_crops_from_layout_image(imgs)
orig_shapes = [img.shape for img in imgs]
imgs = [self.format_line_img(img, self._input_shape, backend=self._model_backend) for img in imgs]
else:
raise ValueError('Input type {} is not implemented'.format(type(imgs)))

results = [self._eng_net.run(None, {self._input_name: img}) for img in imgs]
return self._postprocess(results, imgs, orig_shapes, orig_img, viz_lines = visualize)

def _postprocess(self, results, imgs, orig_shapes, orig_img, viz_lines = None):
#YOLO NMS is carried out now, other backends will filter by bbox confidence score later
if self._model_backend == 'yolo':
preds = [torch.from_numpy(pred[0]) for pred in results]
preds = [non_max_suppression(pred, conf_thres = self._conf_thresh, iou_thres=self._iou_thresh, max_det=100)[0] for pred in preds]

elif self._model_backend == 'yolov8':
preds = [torch.from_numpy(pred[0]) for pred in results]
preds = [yolov8_nms(pred, conf_thres = self._conf_thresh, iou_thres=self._iou_thresh, max_det=100)[0] for pred in preds]

elif self._model_backend == 'detectron2' or self._model_backend == 'mmdetection':
return results

preds = self.adjust_line_preds(preds, imgs, orig_shapes)
final_preds = self.readjust_line_predictions(preds, imgs[0].shape[1])

line_crops, line_coords = [], []
for i, line_proj_crop in enumerate(final_preds):
x0, y0, x1, y1 = map(round, line_proj_crop)
line_crop = orig_img[y0:y1, x0:x1]
if line_crop.shape[0] == 0 or line_crop.shape[1] == 0:
continue

# Line crops becomes a list of tuples (bbox_id, line_crop [the image itself], line_proj_crop [the coordinates of the line in the layout image])
line_crops.append(np.array(line_crop).astype(np.float32))
line_coords.append((y0, x0, y1, x1))

# If asked to visualize the line detections, draw a rectangle representing each line crop on the original image
if viz_lines is not None:
cv2.rectangle(orig_img, (x0, y0), (x1, y1), (255, 0, 0), 2)

# If asked to visualize, output the image with the line detections drawn on it
if viz_lines is not None:
cv2.imwrite(viz_lines, orig_img)

return line_crops, line_coords


def adjust_line_preds(self, preds, imgs, orig_shapes):
adjusted_preds = []

for pred, shape in zip(preds, orig_shapes):
line_predictions = pred[pred[:, 1].sort()[1]]
line_bboxes, line_confs, line_labels = line_predictions[:, :4], line_predictions[:, -2], line_predictions[:, -1]

im_width, im_height = shape[1], shape[0]
if im_width > im_height:
h_ratio = (im_height / im_width) * 640
h_trans = 640 * ((1 - (im_height / im_width)) / 2)
else:
h_trans = 0
h_ratio = 640

line_proj_crops = []
for line_bbox in line_bboxes:
x0, y0, x1, y1 = torch.round(line_bbox)
x0, y0, x1, y1 = 0, int(floor((y0.item() - h_trans) * im_height / h_ratio)), \
im_width, int(ceil((y1.item() - h_trans) * im_height / h_ratio))

line_proj_crops.append((x0, y0, x1, y1))

adjusted_preds.append((line_proj_crops, line_confs, line_labels))

return adjusted_preds

def readjust_line_predictions(self, line_preds, orig_img_width):
y0 = 0
dif = int(orig_img_width * 1.5)
all_preds, final_preds = [], []
for j in range(len(line_preds)):
preds, probs, labels = line_preds[j]
for i, pred in enumerate(preds):
all_preds.append((pred[0], pred[1] + y0, pred[2], pred[3] + y0, probs[i]))
y0 += dif

all_preds = torch.tensor(all_preds)
if all_preds.dim() > 1:
keep_preds = nms(all_preds[:, :4], all_preds[:, -1], iou_threshold=0.15)
filtered_preds = all_preds[keep_preds, :4]
filtered_preds = filtered_preds[filtered_preds[:, 1].sort()[1]]
for pred in filtered_preds:
x0, y0, x1, y1 = torch.round(pred)
x0, y0, x1, y1 = x0.item(), y0.item(), x1.item(), y1.item()
final_preds.append((x0, y0, x1, y1))
return final_preds
else:
return []

def format_line_img(self, img, input_shape, backend='yolo'):
if backend == 'yolo' or backend == 'yolov8':
im = letterbox(img, input_shape, stride=32, auto=False)[0] # padded resize
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
im = np.ascontiguousarray(im) # contiguous
im = im.astype(np.float32) / 255.0 # 0 - 255 to 0.0 - 1.0
if im.ndim == 3:
im = np.expand_dims(im, 0)

elif backend == 'detectron2':
im = letterbox(img, input_shape, stride=32, auto=False)[0] # padded resize
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
im = np.ascontiguousarray(im) # contiguous
im = im.astype(np.float32)

elif backend == 'mmdetection':
im = mmcv.imrescale(img, (input_shape[0], input_shape[1]))
im = mmcv.impad(im, shape = input_shape, pad_val=0)
im = mmcv.imnormalize(im, DEFAULT_MEAN, DEFAULT_STD, to_rgb=True)
im = im.transpose(2, 0, 1)
if im.ndim == 3:
im = np.expand_dims(im, 0)


else:
raise NotImplementedError('Backend {} is not implemented'.format(backend))

return im

def load_line_img(self, input_path, input_shape, backend='yolo'):
if backend == 'yolo' or backend == 'yolov8' or backend == 'detectron2':
im0 = cv2.imread(input_path)
im0 = self.get_crops_from_layout_image(im0)
return [self.format_line_img(im, input_shape, backend=backend) for im in im0]
elif backend == 'mmdetection':
one_img = mmcv.imread(input_path)
one_img = self.get_crops_from_layout_image(one_img)
return [self.format_line_img(one_im, input_shape, backend=backend) for one_im in one_img]
else:
raise NotImplementedError('Backend {} is not implemented'.format(backend))

    def get_crops_from_layout_image(self, image):
        im_height, im_width = image.shape[0], image.shape[1]
        if im_height <= im_width * self.min_seg_ratio:
            return [image]
        else:
            y0 = 0
            y1 = im_width * self.min_seg_ratio
            crops = []
            while y1 <= im_height:
                crops.append(image[y0:y1, :])
                y0 += int(im_width * self.min_seg_ratio * 0.75) # .75 factor ensures there is overlap between crops
                y1 += int(im_width * self.min_seg_ratio * 0.75)

            crops.append(image[y0:, :])
            return crops
320 changes: 320 additions & 0 deletions src/effocr-layout/ocr/effocr/engines/localizer_engine.py
@@ -0,0 +1,320 @@
import os
import sys
# import mmcv  # optional; required only when using the 'mmdetection' backend in EffLocalizer
import torch
import numpy as np
import onnxruntime as ort
import torchvision
import cv2
import onnx

from .ops import non_max_suppression as yolov8_nms

DEFAULT_MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
DEFAULT_STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)

class EffLocalizer:

def __init__(self, model_path, iou_thresh = 0.01, conf_thresh = 0.30, vertical = False,
num_cores = None, providers=None, input_shape = (640, 640), model_backend='yolo'):
sess_options = ort.SessionOptions()
if num_cores is not None:
sess_options.intra_op_num_threads = num_cores

if providers is None:
providers = ort.get_available_providers()

self._eng_net = ort.InferenceSession(
model_path,
sess_options,
providers=providers,
)

base_model = onnx.load(model_path)
self._input_name = EffLocalizer.get_onnx_input_name(base_model)
self._model_input_shape = self._eng_net.get_inputs()[0].shape
self._iou_thresh = iou_thresh
self._conf_thresh = conf_thresh
self._vertical = vertical

if isinstance(self._model_input_shape[-1], int) and isinstance(self._model_input_shape[-2], int):
self._input_shape = (self._model_input_shape[-2], self._model_input_shape[-1])
else:
self._input_shape = input_shape
self._model_backend = model_backend



def __call__(self, imgs):
return self.run(imgs)

def run(self, imgs):
if isinstance(imgs, list):
if isinstance(imgs[0], str):
imgs = [EffLocalizer.load_localizer_img(img, self._input_shape, backend=self._model_backend) for img in imgs]
else:
imgs = [EffLocalizer.format_localizer_img(img, self._input_shape, backend=self._model_backend) for img in imgs]
elif isinstance(imgs, str):
imgs = [EffLocalizer.load_localizer_img(imgs, self._input_shape, backend=self._model_backend)]
elif isinstance(imgs, np.ndarray):
imgs = [EffLocalizer.format_localizer_img(imgs, self._input_shape, backend=self._model_backend)]
else:
raise NotImplementedError('Input type {} is not implemented'.format(type(imgs)))

results = [self._eng_net.run(None, {self._input_name: img}) for img in imgs]
return self._postprocess(results)

def _postprocess(self, results):
#YOLO NMS is carried out now, other backends will filter by bbox confidence score later
if self._model_backend == 'yolo':

preds = [torch.from_numpy(pred[0]) for pred in results]
preds = [self.non_max_suppression(pred, conf_thres = self._conf_thresh, iou_thres=self._iou_thresh, max_det=1000)[0] for pred in preds]
return preds

elif self._model_backend == 'yolov8':
preds = [torch.from_numpy(pred[0]) for pred in results]
preds = [yolov8_nms(pred, conf_thres = self._conf_thresh, iou_thres=self._iou_thresh, max_det=50)[0] for pred in preds]
return preds

elif self._model_backend == 'detectron2' or self._model_backend == 'mmdetection':
return results

@staticmethod
def get_onnx_input_name(model):
input_all = [node.name for node in model.graph.input]
input_initializer = [node.name for node in model.graph.initializer]
net_feed_input = list(set(input_all) - set(input_initializer))
return net_feed_input[0]

@staticmethod
def format_localizer_img(img, input_shape, backend='yolo'):
if backend == 'yolo' or backend == 'yolov8':
im = EffLocalizer.letterbox(img, input_shape, stride=32, auto=False)[0] # padded resize
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
im = np.ascontiguousarray(im) # contiguous
im = im.astype(np.float32) / 255.0 # 0 - 255 to 0.0 - 1.0
if im.ndim == 3:
im = np.expand_dims(im, 0)
return im
elif backend == 'detectron2':
im = EffLocalizer.letterbox(img, input_shape, stride=32, auto=False)[0] # padded resize
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
im = np.ascontiguousarray(im) # contiguous
im = im.astype(np.float32)
return im
elif backend == 'mmdetection':
one_img = mmcv.imrescale(img, (input_shape[0], input_shape[1]))
one_img = mmcv.impad(one_img, shape = input_shape, pad_val=0)
one_img = mmcv.imnormalize(one_img, DEFAULT_MEAN, DEFAULT_STD, to_rgb=True)
one_img = one_img.transpose(2, 0, 1)
if one_img.ndim == 3:
one_img = np.expand_dims(one_img, 0)

return one_img
else:
raise NotImplementedError('Backend {} is not implemented'.format(backend))

@staticmethod
def load_localizer_img(input_path, input_shape, backend='yolo'):
if backend == 'yolo' or backend == 'yolov8':
im0 = cv2.imread(input_path)
im = EffLocalizer.letterbox(im0, input_shape, stride=32, auto=False)[0] # padded resize
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
im = np.ascontiguousarray(im) # contiguous
im = im.astype(np.float32) / 255.0 # 0 - 255 to 0.0 - 1.0
if im.ndim == 3:
im = np.expand_dims(im, 0)
return im
elif backend == 'detectron2':
im0 = cv2.imread(input_path)
im = EffLocalizer.letterbox(im0, input_shape, stride=32, auto=False)[0] # padded resize
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
im = np.ascontiguousarray(im) # contiguous
im = im.astype(np.float32)
return im
elif backend == 'mmdetection':
one_img = mmcv.imread(input_path)
one_img = mmcv.imrescale(one_img, (input_shape[0], input_shape[1]))
one_img = mmcv.impad(one_img, shape = input_shape, pad_val=0)
one_img = mmcv.imnormalize(one_img, DEFAULT_MEAN, DEFAULT_STD, to_rgb=True)
one_img = one_img.transpose(2, 0, 1)
if one_img.ndim == 3:
one_img = np.expand_dims(one_img, 0)

return one_img
else:
raise NotImplementedError('Backend {} is not implemented'.format(backend))


@staticmethod
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
# Resize and pad image while meeting stride-multiple constraints
shape = im.shape[:2] # current shape [height, width]
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)

# Scale ratio (new / old)
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
if not scaleup: # only scale down, do not scale up (for better val mAP)
r = min(r, 1.0)

# Compute padding
ratio = r, r # width, height ratios
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
if auto: # minimum rectangle
dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
elif scaleFill: # stretch
dw, dh = 0.0, 0.0
new_unpad = (new_shape[1], new_shape[0])
ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios

dw /= 2 # divide padding into 2 sides
dh /= 2

if shape[::-1] != new_unpad: # resize
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
return im, ratio, (dw, dh)

@staticmethod
def xywh2xyxy(x):
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
return y

@staticmethod
def box_iou(box1, box2, eps=1e-7):
# https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
"""
Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Arguments:
box1 (Tensor[N, 4])
box2 (Tensor[M, 4])
Returns:
iou (Tensor[N, M]): the NxM matrix containing the pairwise
IoU values for every element in boxes1 and boxes2
"""

# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
(a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)

# IoU = inter / (area1 + area2 - inter)
return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)

@staticmethod
def non_max_suppression(
prediction,
conf_thres=0.25,
iou_thres=0.45,
classes=None,
agnostic=False,
multi_label=False,
labels=(),
max_det=300,
nm=0, ):

        if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation mode, output = (inference_out, loss_out)
prediction = prediction[0] # select only inference output

device = prediction.device
mps = 'mps' in device.type # Apple MPS
if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
prediction = prediction.cpu()
bs = prediction.shape[0] # batch size
nc = prediction.shape[2] - nm - 5 # number of classes
xc = prediction[..., 4] > conf_thres # candidates

# Checks
assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

# Settings
# min_wh = 2 # (pixels) minimum box width and height
max_wh = 7680 # (pixels) maximum box width and height
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
time_limit = 0.5 + 0.05 * bs # seconds to quit after
redundant = True # require redundant detections
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
merge = False # use merge-NMS

mi = 5 + nc # mask start index
output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
x = x[xc[xi]] # confidence

# Cat apriori labels if autolabelling
if labels and len(labels[xi]):
lb = labels[xi]
v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
v[:, :4] = lb[:, 1:5] # box
v[:, 4] = 1.0 # conf
v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls
x = torch.cat((x, v), 0)

# If none remain process next image
if not x.shape[0]:
continue

# Compute conf
x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf

# Box/Mask
            box = EffLocalizer.xywh2xyxy(x[:, :4])  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
mask = x[:, mi:] # zero columns if no masks

# Detections matrix nx6 (xyxy, conf, cls)
if multi_label:
i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
else: # best class only
conf, j = x[:, 5:mi].max(1, keepdim=True)
x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

# Filter by class
if classes is not None:
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

# Apply finite constraint
# if not torch.isfinite(x).all():
# x = x[torch.isfinite(x).all(1)]

# Check shape
n = x.shape[0] # number of boxes
if not n: # no boxes
continue
elif n > max_nms: # excess boxes
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
else:
x = x[x[:, 4].argsort(descending=True)] # sort by confidence

# Batched NMS
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
if i.shape[0] > max_det: # limit detections
i = i[:max_det]
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
iou = EffLocalizer.box_iou(boxes[i], boxes) > iou_thres # iou matrix
weights = iou * scores[None] # box weights
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
if redundant:
i = i[iou.sum(1) > 1] # require redundancy

output[xi] = x[i]
if mps:
output[xi] = output[xi].to(device)

return output
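
A corresponding sketch for EffLocalizer (hypothetical model path); the localizer detects word and character boxes within a single line crop:

localizer = EffLocalizer("models/effocr/localizer.onnx", model_backend="yolov8")  # hypothetical path
boxes = localizer("tests/fixtures/ocr/test_effocr_image.jpg")[0]
# With the yolo/yolov8 backends, each row of `boxes` is (x1, y1, x2, y2, conf, cls) after NMS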
723 changes: 723 additions & 0 deletions src/effocr-layout/ocr/effocr/engines/ops.py

Large diffs are not rendered by default.

41 changes: 41 additions & 0 deletions src/effocr-layout/ocr/effocr/engines/recognizer_engine.py
@@ -0,0 +1,41 @@
import os
import sys
import torch
import onnxruntime as ort
import numpy as np


class EffRecognizer:

def __init__(self, model, transform = None, num_cores = None, providers=None, char=True):

sess_options = ort.SessionOptions()
if num_cores is not None:
sess_options.intra_op_num_threads = num_cores

if providers is None:
providers = ort.get_available_providers()

self.transform = transform
# null_input = torch.zeros((3, 224, 224)) if char else torch.zeros((1, 224, 224))
self._eng_net = ort.InferenceSession(
model,
sess_options,
providers=providers,
)

def __call__(self, imgs):
return self.run(imgs)

def run(self, imgs):
trans_imgs = []
for img in imgs:
try:
trans_imgs.append(self.transform(img.astype(np.uint8))[0])
except Exception as e:
trans_imgs.append(torch.zeros((3, 224, 224)))

onnx_input = torch.nn.functional.pad(torch.stack(trans_imgs), (0, 0, 0, 0, 0, 0, 0, 64 - len(imgs))).numpy()

return self._eng_net.run(None, {'imgs': onnx_input})
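
A minimal usage sketch for EffRecognizer (hypothetical model path; create_paired_transform is defined in ocr/effocr/utils/dataset_utils.py elsewhere in this PR):

import numpy as np

recognizer = EffRecognizer(
    "models/effocr/char_recognizer/enc.onnx",     # hypothetical path
    transform=create_paired_transform(lang="en"),
)
crops = [np.zeros((32, 32, 3), dtype=np.uint8)]   # stand-in character crops
embeddings = recognizer(crops)[0]                 # ONNX output; the batch is padded out to 64 slots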

Empty file.
527 changes: 527 additions & 0 deletions src/effocr-layout/ocr/effocr/infer_transcripton.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/effocr-layout/ocr/effocr/utils/__init__.py
@@ -0,0 +1,2 @@
from .dataset_utils import create_paired_transform, create_paired_transform_word
from .image_utils import letterbox, non_max_suppression
79 changes: 79 additions & 0 deletions src/effocr-layout/ocr/effocr/utils/dataset_utils.py
@@ -0,0 +1,79 @@
import numpy as np
from torchvision import transforms as T
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from PIL import Image
import time
from timeit import default_timer as timer


def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i:i + n]

class MedianPadWord:
"""This padding preserves the aspect ratio of the image. It also pads the image with the median value of the border pixels.
Note how it also centres the ROI in the padded image."""
def __init__(self, override=None,aspect_cutoff=0):
self.override = override
self.aspect_cutoff=aspect_cutoff
def __call__(self, image):
##Convert to RGB
image = image.convert("RGB") if isinstance(image, Image.Image) else image
image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
max_side = max(image.size)
aspect_ratio = image.size[0] / image.size[1]
if aspect_ratio<self.aspect_cutoff:
pad_x, pad_y = [int(0.75*max_side) for _ in image.size]
else:
pad_x, pad_y = [max_side - s for s in image.size]
padding = (round((10+pad_x)/2), round((5+pad_y)/2), round((10+pad_x)/2), round((5+pad_y)/2)) ##Added some extra to avoid info on the long edge

imgarray = np.array(image)
h, w , c= imgarray.shape
rightb, leftb = imgarray[:,w-1,:], imgarray[:,0,:]
topb, bottomb = imgarray[0,:,:], imgarray[h-1,:,:]
bordervals = np.concatenate([rightb, leftb, topb, bottomb], axis=0)
medval = tuple([int(v) for v in np.median(bordervals, axis=0)])
return T.Pad(padding, fill=medval if self.override is None else self.override)(image)

class MedianPad:

def __init__(self, fill):
self.fill = fill

def __call__(self, imgarray):
max_side = max(imgarray.shape)
pad_y, pad_x, _ = [max_side - s for s in imgarray.shape]
padding = (0, 0, pad_x, pad_y)
pil_im = Image.fromarray(imgarray)
return T.Pad(padding, fill=self.fill)(pil_im)


def timerhelper(s, x):
print(s, timer())
time.sleep(1)
return x


def create_paired_transform(lang, size=224):
return T.Compose([
MedianPad(fill=(255,255,255)),
T.ToTensor(),
T.Resize((size, size)),
T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
lambda x: x.unsqueeze(0)
])

def create_paired_transform_word(lang, size=224,aspect_cutoff=0):
return T.Compose([
# SquarePad(),
MedianPadWord(aspect_cutoff=aspect_cutoff),
# T.Resize(size=(224,224)),
# patch_resize,
T.ToTensor(),
T.Resize((size, size)),
T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
lambda x: x.unsqueeze(0)
# featx_transform,
])
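
As a quick sanity check of the paired transform above (synthetic input; the output shape follows from the pad, resize, and final unsqueeze steps):

import numpy as np

transform = create_paired_transform(lang="en")
crop = (np.random.rand(30, 80, 3) * 255).astype(np.uint8)
tensor = transform(crop)
print(tensor.shape)  # torch.Size([1, 3, 224, 224]): padded to square, resized, normalized, batch dim added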
176 changes: 176 additions & 0 deletions src/effocr-layout/ocr/effocr/utils/image_utils.py
@@ -0,0 +1,176 @@
import cv2
import numpy as np

import torch
import torchvision

def get_onnx_input_name(model):
input_all = [node.name for node in model.graph.input]
input_initializer = [node.name for node in model.graph.initializer]
net_feed_input = list(set(input_all) - set(input_initializer))
return net_feed_input[0]

def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
# Resize and pad image while meeting stride-multiple constraints
shape = im.shape[:2] # current shape [height, width]
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)

# Scale ratio (new / old)
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
if not scaleup: # only scale down, do not scale up (for better val mAP)
r = min(r, 1.0)

# Compute padding
ratio = r, r # width, height ratios
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
if auto: # minimum rectangle
dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
elif scaleFill: # stretch
dw, dh = 0.0, 0.0
new_unpad = (new_shape[1], new_shape[0])
ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios

dw /= 2 # divide padding into 2 sides
dh /= 2

if shape[::-1] != new_unpad: # resize
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
return im, ratio, (dw, dh)

def xywh2xyxy(x):
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
return y

def box_iou(box1, box2, eps=1e-7):
# https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
"""
Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Arguments:
box1 (Tensor[N, 4])
box2 (Tensor[M, 4])
Returns:
iou (Tensor[N, M]): the NxM matrix containing the pairwise
IoU values for every element in boxes1 and boxes2
"""

# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
(a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)

# IoU = inter / (area1 + area2 - inter)
return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)

def non_max_suppression(
prediction,
conf_thres=0.25,
iou_thres=0.45,
classes=None,
agnostic=False,
multi_label=False,
labels=(),
max_det=300,
nm=0, ):

    if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation mode, output = (inference_out, loss_out)
prediction = prediction[0] # select only inference output

device = prediction.device
mps = 'mps' in device.type # Apple MPS
if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
prediction = prediction.cpu()
bs = prediction.shape[0] # batch size
nc = prediction.shape[2] - nm - 5 # number of classes
xc = prediction[..., 4] > conf_thres # candidates

# Checks
assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

# Settings
# min_wh = 2 # (pixels) minimum box width and height
max_wh = 7680 # (pixels) maximum box width and height
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
time_limit = 0.5 + 0.05 * bs # seconds to quit after
redundant = True # require redundant detections
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
merge = False # use merge-NMS

mi = 5 + nc # mask start index
output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
x = x[xc[xi]] # confidence

# Cat apriori labels if autolabelling
if labels and len(labels[xi]):
lb = labels[xi]
v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
v[:, :4] = lb[:, 1:5] # box
v[:, 4] = 1.0 # conf
v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls
x = torch.cat((x, v), 0)

# If none remain process next image
if not x.shape[0]:
continue

# Compute conf
x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf

# Box/Mask
        box = xywh2xyxy(x[:, :4])  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
mask = x[:, mi:] # zero columns if no masks

# Detections matrix nx6 (xyxy, conf, cls)
if multi_label:
i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
else: # best class only
conf, j = x[:, 5:mi].max(1, keepdim=True)
x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

# Filter by class
if classes is not None:
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

# Check shape
n = x.shape[0] # number of boxes
if not n: # no boxes
continue
elif n > max_nms: # excess boxes
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
else:
x = x[x[:, 4].argsort(descending=True)] # sort by confidence

# Batched NMS
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
if i.shape[0] > max_det: # limit detections
i = i[:max_det]
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
weights = iou * scores[None] # box weights
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
if redundant:
i = i[iou.sum(1) > 1] # require redundancy

output[xi] = x[i]
if mps:
output[xi] = output[xi].to(device)

return output
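
A quick illustrative check of letterbox and xywh2xyxy as defined above (values are synthetic):

import numpy as np

img = np.zeros((480, 800, 3), dtype=np.uint8)
padded, ratio, (dw, dh) = letterbox(img, new_shape=(640, 640), auto=False)
print(padded.shape)  # (640, 640, 3): resized to 640x384, then padded 128 px on top and bottom

boxes_xywh = np.array([[320.0, 240.0, 100.0, 50.0]])
print(xywh2xyxy(boxes_xywh))  # [[270. 215. 370. 265.]]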

264 changes: 264 additions & 0 deletions src/effocr-layout/ocr/effocr_agent.py
@@ -0,0 +1,264 @@
import io
import os
import json
import warnings

import numpy as np
from cv2 import imencode
import multiprocessing
import faiss
from huggingface_hub import hf_hub_download
import joblib

from .base import BaseOCRAgent, BaseOCRElementType
from .effocr import EffLocalizer, EffRecognizer, EffLineDetector, \
run_effocr_word, create_paired_transform, create_paired_transform_word

EFFOCR_DEFAULT_CONFIG = {
"line_model": "",
"line_backend": "yolov8",
"line_input_shape": (640, 640),
"localizer_model": "",
"localizer_backend": "yolov8",
"localizer_input_shape": (640, 640),
"word_recognizer_model": "./src/layoutparser/models/effocr/word_recognizer/enc.onnx",
"word_index": "./src/layoutparser/models/effocr/word_recognizer/word_index.index",
"word_ref": "./src/layoutparser/models/effocr/word_recognizer/word_ref.txt",
"char_recognizer_model": "./src/layoutparser/models/effocr/char_recognizer/enc.onnx",
"char_index": "./src/layoutparser/models/effocr/char_recognizer/char_index.index",
"char_ref": "./src/layoutparser/models/effocr/char_recognizer/char_ref.txt",
"localizer_iou_thresh": 0.10,
"localizer_conf_thresh": 0.20,
"line_iou_thresh": 0.05,
"line_conf_thresh": 0.50,
"word_dist_thresh": 0.90,
"lang": "en",
}

HUGGINGFACE_MODEL_MAP = {
'line_model': 'line.onnx',
'localizer_model': 'localizer.onnx',
'word_recognizer_model': 'word_recognizer/enc.onnx',
'char_recognizer_model': 'char_recognizer/enc.onnx',
'word_index': 'word_recognizer/word_index.index',
'word_ref': 'word_recognizer/word_ref.txt',
'char_index': 'char_recognizer/char_index.index',
'char_ref': 'char_recognizer/char_ref.txt'
}

HUGGINGFACE_REPO_NAME = 'dell-research-harvard/effocr_en'

class EffOCRFeatureType(BaseOCRElementType):
"""
The element types from EffOCR
"""

PAGE = 0
PARA = 1
LINE = 2
WORD = 3
CHAR = 4

@property
def attr_name(self):
name_cvt = {
EffOCRFeatureType.BLOCK: "blocks",
EffOCRFeatureType.PARA: "paragraphs",
EffOCRFeatureType.LINE: "lines",
EffOCRFeatureType.WORD: "words",
}
return name_cvt[self]

@property
def child_level(self):
child_cvt = {
            EffOCRFeatureType.PAGE: EffOCRFeatureType.PARA,
            EffOCRFeatureType.PARA: EffOCRFeatureType.LINE,
            EffOCRFeatureType.LINE: EffOCRFeatureType.WORD,
            EffOCRFeatureType.WORD: EffOCRFeatureType.CHAR,
            EffOCRFeatureType.CHAR: None,
}
return child_cvt[self]



class EffOCRAgent(BaseOCRAgent):
"""EffOCR Inference -- Implements method described in https://scholar.harvard.edu/sites/scholar.harvard.edu/files/dell/files/effocr.pdf
Note:
TODO: Fill in with info once implemented
"""

# TODO: Fill in with package dependencies
DEPENDENCIES = ["effocr"]

def __init__(self, languages="eng", **kwargs):
"""Create a EffOCR Agent.
Args:
languages (:obj:`list` or :obj:`str`, optional):
                You can specify the language code(s) of the documents to detect, which determines
                the language EffOCR uses when transcribing them. As of 7/24, the only supported
                option is English, but Japanese EffOCR will be implemented soon.
                Defaults to 'eng'.
"""
if languages != 'eng':
raise NotImplementedError("EffOCR only supports English at this time.")

self.lang = languages if isinstance(languages, str) else "+".join(languages)

        self.config = EFFOCR_DEFAULT_CONFIG.copy()  # copy so per-instance overrides do not mutate the module-level defaults
for key, value in kwargs.items():
if key in self.config.keys():
self.config[key] = value
else:
warnings.warn(f"Unknown config parameter {key} for {self.__class__.__name__}. Ignoring it.")

self._check_and_download_models()
self._check_and_download_indices()
self._load_models()
self._load_indices()

def _check_and_download_models(self):
'''
Checks if all of line, localizer, word recognizer, and char recognizer are downloaded,
then downloads them if they are not.
'''

model_keys = ['line_model', 'localizer_model', 'word_recognizer_model', 'char_recognizer_model']
for key in model_keys:
if not os.path.exists(self.config[key]) or not self.config[key].endswith('.onnx'):
self.config[key] = hf_hub_download(HUGGINGFACE_REPO_NAME, HUGGINGFACE_MODEL_MAP[key])
# TODO: replace FileNotFoundError with download code

def _check_and_download_indices(self):
'''
        Checks if the word and character recognizers' indices and reference files are downloaded,
then downloads them if they are not.
'''

index_keys = ['word_index', 'char_index']
ref_keys = ['word_ref', 'char_ref']

for key in index_keys:
if not os.path.exists(self.config[key]):
self.config[key] = hf_hub_download(HUGGINGFACE_REPO_NAME, HUGGINGFACE_MODEL_MAP[key])

for key in ref_keys:
if not os.path.exists(self.config[key]):
self.config[key] = hf_hub_download(HUGGINGFACE_REPO_NAME, HUGGINGFACE_MODEL_MAP[key])

def _load_models(self):
'''
Function to instantiate each of the line model,
localizer model, word recognizer model, and char recognizer model.
'''

self.localizer_engine = EffLocalizer(
self.config['localizer_model'],
iou_thresh = self.config['localizer_iou_thresh'],
conf_thresh = self.config['localizer_conf_thresh'],
vertical = False if self.config['lang'] == "en" else True,
num_cores = multiprocessing.cpu_count(),
model_backend = self.config['localizer_backend'],
input_shape = self.config['localizer_input_shape']
)

# TODO: Fix imports for paired_transforms
char_transform = create_paired_transform(lang='en')
word_transform = create_paired_transform_word(lang='en')

self.word_recognizer_engine = EffRecognizer(
model = self.config['word_recognizer_model'],
transform = char_transform,
num_cores=multiprocessing.cpu_count(),
)

self.char_recognizer_engine = EffRecognizer(
model = self.config['char_recognizer_model'],
transform = char_transform,
num_cores=multiprocessing.cpu_count(),
)

self.line_detector_engine = EffLineDetector(
self.config['line_model'],
iou_thresh = self.config['line_iou_thresh'],
conf_thresh = self.config['line_conf_thresh'],
num_cores = multiprocessing.cpu_count(),
model_backend = self.config['line_backend'],
input_shape = self.config['line_input_shape']
)

def _load_indices(self):
'''
Function to instantiate the faiss indices for each of the word and character recognizers.
Indicies are responsible for storing base vectors for each word/character and performing
similarity search on unknown symbols.
'''

# char index
self.char_index = faiss.read_index(self.config['char_index'])
with open(self.config['char_ref']) as ref_file:
self.candidate_chars = ref_file.read().split()

# word index
self.word_index = faiss.read_index(self.config['word_index'])
with open(self.config['word_ref']) as ref_file:
self.candidate_words = ref_file.read().split()

def _detect(self, image, viz_lines_path=None):
'''
Function to detect text in an image using EffOCR.
        Each of the two main parts, line detection and line transcription, is abstracted out here.
'''

# Line Detection
        line_crops, line_coords = self.line_detector_engine(image, visualize = viz_lines_path)

# Line Transcription
text_results = run_effocr_word(line_crops, self.localizer_engine, self.word_recognizer_engine, self.char_recognizer_engine, self.candidate_chars,
self.candidate_words, self.config['lang'], self.word_index, self.char_index, num_streams=multiprocessing.cpu_count(), vertical=False,
localizer_output = None, conf_thres=self.config['localizer_conf_thresh'], recognizer_thresh = self.config['word_dist_thresh'],
bbox_output = False, punc_padding = 0, insert_paragraph_breaks = True)

return text_results

def detect(self, image, return_response=False, return_only_text=True, agg_output_level=None, viz_lines_path = None):
"""Send the input image for OCR by the EffOCR agent.
Args:
image (:obj:`np.ndarray` or :obj:`str`):
The input image array or the name of the image file
return_response (:obj:`bool`, optional):
                Whether to directly return the raw EffOCR output.
                Defaults to `False`.
            return_only_text (:obj:`bool`, optional):
                Whether to return only the text in the OCR results.
                Defaults to `True`.
agg_output_level (:obj:`~EffOCRFeatureType`, optional):
When set, aggregate the EffOCR output with respect to the
specified aggregation level. Defaults to `None`.
Returns:
:obj:`dict` or :obj:`str`:
The OCR results in the specified format.
"""

res = self._detect(image, viz_lines_path = viz_lines_path)

if return_response:
return res

if return_only_text:
return res["text"]

if agg_output_level is not None:
return self.gather_data(res, agg_output_level)

return res["text"]


if __name__ == '__main__':
    import cv2
    agent = EffOCRAgent()
    img_path = r'C:\Users\bryan\Documents\NBER\layout-parser\tests\fixtures\ocr\test_effocr_image.jpg'
    print(agent.detect(cv2.imread(img_path)))
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Binary file added tests/fixtures/ocr/line_dets.png
Binary file added tests/fixtures/ocr/test_effocr_image.jpg
30 changes: 29 additions & 1 deletion tests/test_ocr.py
@@ -17,10 +17,13 @@
GCVFeatureType,
TesseractAgent,
TesseractFeatureType,
EffOCRAgent,
EffOCRFeatureType,
)
import json, cv2, os

image = cv2.imread("tests/fixtures/ocr/test_gcv_image.jpg")
effocr_image = cv2.imread("tests/fixtures/ocr/test_effocr_image.jpg")


def test_gcv_agent(test_detect=False):
@@ -76,4 +79,29 @@ def test_tesseract(test_detect=False):
assert r2 == ocr_agent.gather_data(res, agg_level=TesseractFeatureType.BLOCK)
assert r3 == ocr_agent.gather_data(res, agg_level=TesseractFeatureType.PARA)
assert r4 == ocr_agent.gather_data(res, agg_level=TesseractFeatureType.LINE)
    assert r5 == ocr_agent.gather_data(res, agg_level=TesseractFeatureType.WORD)

'''
Test the EffOCRAgent, which implements EffOCR -- https://scholar.harvard.edu/sites/scholar.harvard.edu/files/dell/files/effocr.pdf
'''
def test_effocr(test_detect=True):
ocr_agent = EffOCRAgent()

# res = ocr_agent.load_response("tests/fixtures/ocr/test_effocr_response.json")
# r0 = ocr_agent.gather_text_annotations(res)
# r1 = ocr_agent.gather_data(res, agg_level=EffOCRFeatureType.BLOCK)
# r2 = ocr_agent.gather_data(res, agg_level=EffOCRFeatureType.PARA)
# r3 = ocr_agent.gather_data(res, agg_level=EffOCRFeatureType.LINE)
# r4 = ocr_agent.gather_data(res, agg_level=EffOCRFeatureType.WORD)
# r5 = ocr_agent.gather_data(res, agg_level=EffOCRFeatureType.CHAR)

if test_detect:
res = ocr_agent.detect(effocr_image, return_response=True)
assert "The tug boat Alice" in res[0]
# assert r0 == res["text"]
# assert r1 == ocr_agent.gather_data(res, agg_level=EffOCRFeatureType.BLOCK)
# assert r2 == ocr_agent.gather_data(res, agg_level=EffOCRFeatureType.PARA)
# assert r3 == ocr_agent.gather_data(res, agg_level=EffOCRFeatureType.LINE)
# assert r4 == ocr_agent.gather_data(res, agg_level=EffOCRFeatureType.WORD)
# assert r5 == ocr_agent.gather_data(res, agg_level=EffOCRFeatureType.CHAR)