Skip to content

Commit 6895ddb

Browse files
authored
fix: fix bugs in data structure (#402)
- fix bug when an empty list is passed into `TextRegions.from_list` - fix bug when concatenating a list of `LayoutElements` the class id maps is not updated correctly
1 parent 4309e9e commit 6895ddb

File tree

5 files changed

+51
-6
lines changed

5 files changed

+51
-6
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1+
## 0.8.2
2+
3+
* fix: fix bug when an empty list is passed into `TextRegions.from_list` triggers `IndexError`
4+
* fix: fix bug when concatenate a list of `LayoutElements` the class id mapping is no properly
5+
updated
6+
17
## 0.8.1
8+
29
* fix: fix list index out of range error caused by calling LayoutElements.from_list() with empty list
310

411
## 0.8.0

test_unstructured_inference/test_elements.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,3 +441,31 @@ def test_layoutelements_to_list_and_back(test_layoutelements):
441441
def test_layoutelements_from_list_no_elements():
442442
back = LayoutElements.from_list(elements=[])
443443
assert back.source is None
444+
assert back.element_coords.size == 0
445+
446+
447+
def test_textregions_from_list_no_elements():
448+
back = TextRegions.from_list(regions=[])
449+
assert back.source is None
450+
assert back.element_coords.size == 0
451+
452+
453+
def test_layoutelements_concatenate():
454+
layout1 = LayoutElements(
455+
element_coords=np.array([[0, 0, 1, 1], [1, 1, 2, 2]]),
456+
texts=np.array(["a", "two"]),
457+
source=None,
458+
element_class_ids=np.array([0, 1]),
459+
element_class_id_map={0: "type0", 1: "type1"},
460+
)
461+
layout2 = LayoutElements(
462+
element_coords=np.array([[10, 10, 2, 2], [20, 20, 1, 1]]),
463+
texts=np.array(["three", "4"]),
464+
source=None,
465+
element_class_ids=np.array([0, 1]),
466+
element_class_id_map={0: "type1", 1: "type2"},
467+
)
468+
joint = LayoutElements.concatenate([layout1, layout2])
469+
assert joint.texts.tolist() == ["a", "two", "three", "4"]
470+
assert joint.element_class_ids.tolist() == [0, 1, 1, 2]
471+
assert joint.element_class_id_map == {0: "type0", 1: "type1", 2: "type2"}

unstructured_inference/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.8.1" # pragma: no cover
1+
__version__ = "0.8.2" # pragma: no cover

unstructured_inference/inference/elements.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,8 @@ def from_list(cls, regions: list):
244244
for region in regions:
245245
coords.append((region.bbox.x1, region.bbox.y1, region.bbox.x2, region.bbox.y2))
246246
texts.append(region.text)
247-
return cls(element_coords=np.array(coords), texts=np.array(texts), source=regions[0].source)
247+
source = regions[0].source if regions else None
248+
return cls(element_coords=np.array(coords), texts=np.array(texts), source=source)
248249

249250
def __len__(self):
250251
return self.element_coords.shape[0]

unstructured_inference/inference/layoutelement.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,22 +74,31 @@ def slice(self, indices) -> LayoutElements:
7474
def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements:
7575
"""concatenate a sequence of LayoutElements in order as one LayoutElements"""
7676
coords, texts, probs, class_ids, sources = [], [], [], [], []
77-
class_id_map = {}
77+
class_id_reverse_map: dict[str, int] = {}
7878
for group in groups:
7979
coords.append(group.element_coords)
8080
texts.append(group.texts)
8181
probs.append(group.element_probs)
82-
class_ids.append(group.element_class_ids)
8382
if group.source:
8483
sources.append(group.source)
84+
85+
idx = group.element_class_ids.copy()
8586
if group.element_class_id_map:
86-
class_id_map.update(group.element_class_id_map)
87+
for class_id, class_name in group.element_class_id_map.items():
88+
if class_name in class_id_reverse_map:
89+
idx[group.element_class_ids == class_id] = class_id_reverse_map[class_name]
90+
continue
91+
new_id = len(class_id_reverse_map)
92+
class_id_reverse_map[class_name] = new_id
93+
idx[group.element_class_ids == class_id] = new_id
94+
class_ids.append(idx)
95+
8796
return cls(
8897
element_coords=np.concatenate(coords),
8998
texts=np.concatenate(texts),
9099
element_probs=np.concatenate(probs),
91100
element_class_ids=np.concatenate(class_ids),
92-
element_class_id_map=class_id_map,
101+
element_class_id_map={v: k for k, v in class_id_reverse_map.items()},
93102
source=sources[0] if sources else None,
94103
)
95104

0 commit comments

Comments
 (0)