Skip to content

Commit 2b16585

Browse files
author
Marc Ahlers
committed
refactor: add typings and split the CROHME datamodule code into multiple files in preparation for HME100K support
1 parent 8f5e44b commit 2b16585

13 files changed

+281
-239
lines changed

comer/datamodule/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from .datamodule import Batch, CROHMEDatamodule
2-
from .vocab import vocab
1+
from comer.datamodule.crohme.datamodule import Batch, CROHMEDatamodule
2+
from comer.datamodule.crohme.vocab import vocab
33

44
vocab_size = len(vocab)
55

comer/datamodule/crohme/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .entry import DataEntry, extract_data_entries
2+
from .batch import Batch, BatchTuple, build_batches_from_entries, build_dataset
3+
from .dataset import CROHMEDataset
4+
from .datamodule import CROHMEDatamodule

comer/datamodule/crohme/batch.py

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from dataclasses import dataclass
2+
from typing import List, Tuple, Callable
3+
from zipfile import ZipFile
4+
5+
import numpy as np
6+
from torch import FloatTensor, LongTensor
7+
8+
from comer.datamodule.crohme import DataEntry, extract_data_entries
9+
10+
11+
@dataclass
class Batch:
    """One collated mini-batch: file names, padded images, padding mask, targets.

    The field order is part of the constructor contract, because callers
    build a ``Batch`` positionally.
    """

    img_bases: List[str]  # [b,]
    imgs: FloatTensor  # [b, 1, H, W]
    # NOTE(review): collate_fn in datamodule.py creates this mask with
    # dtype=torch.bool, so the LongTensor annotation looks stale - confirm.
    mask: LongTensor  # [b, H, W]
    indices: List[List[int]]  # [b, l]

    def __len__(self) -> int:
        """Return the number of samples, measured on ``img_bases``."""
        return len(self.img_bases)

    def to(self, device) -> "Batch":
        """Return a new ``Batch`` whose tensors were moved with ``.to(device)``.

        ``img_bases`` and ``indices`` are plain Python lists; the new instance
        shares them with this one instead of copying them.
        """
        moved_imgs = self.imgs.to(device)
        moved_mask = self.mask.to(device)
        return Batch(self.img_bases, moved_imgs, moved_mask, self.indices)
28+
29+
30+
# A BatchTuple represents a single batch and holds three lists of equal length
# (the batch length): [file_names, images, labels]
BatchTuple = Tuple[List[str], List[np.ndarray], List[List[str]]]

# Pixel budget (width * height) used as the default for both the per-image and
# the per-batch limit of build_batches_from_entries; change according to your
# GPU memory. Kept as an int because those parameters are annotated `int` and
# the value is echoed in log messages.
MAX_SIZE = 320_000
36+
37+
38+
def _pixel_count(entry: DataEntry) -> int:
    """Return the pixel count of an entry's image (``size[0] * size[1]``)."""
    return entry.image.size[0] * entry.image.size[1]


def build_batches_from_entries(
        data: List[DataEntry],
        batch_size: int,
        batch_imagesize: int = MAX_SIZE,
        maxlen: int = 200,
        max_imagesize: int = MAX_SIZE,
) -> List[BatchTuple]:
    """Group entries into batches bounded by sample count and padded pixel count.

    Entries are visited in ascending order of image pixel count. An entry is
    dropped (with a message on stdout) when its label has more than ``maxlen``
    tokens or its image has more than ``max_imagesize`` pixels. The current
    batch is closed before adding an entry when it already holds
    ``batch_size`` entries, or when adding the entry would push
    ``entry pixels * entries in batch`` above ``batch_imagesize``.

    Args:
        data: entries to batch. The list itself is left untouched; a sorted
            copy is iterated.
        batch_size: maximum number of entries per batch.
        batch_imagesize: pixel budget of a batch, counted as if every image
            were as large as the largest one in it.
        maxlen: maximum label length (in tokens) of a kept entry.
        max_imagesize: maximum pixel count of a kept image.

    Returns:
        One ``(file_names, images, labels)`` tuple per batch; the three lists
        of a tuple have equal length and are never empty.
    """
    batches: List[BatchTuple] = []
    fnames: List[str] = []
    features: List[np.ndarray] = []
    labels: List[List[str]] = []

    # sorted() instead of list.sort(): do not reorder the caller's list.
    for entry in sorted(data, key=_pixel_count):
        size = _pixel_count(entry)

        if len(entry.label) > maxlen:
            print("label", entry.file_name, "length bigger than", maxlen, "ignore")
            continue
        if size > max_imagesize:
            # Same "rows x columns" order as the shape of np.array(entry.image).
            print(
                f"image: {entry.file_name} size: {entry.image.size[1]} x {entry.image.size[0]} = {size} bigger than {max_imagesize}, ignore"
            )
            continue

        # Entries arrive in ascending size order, so the current image is the
        # largest of the batch and fixes the padded size of every member.
        padded_pixels = size * (len(fnames) + 1)
        batch_is_full = padded_pixels > batch_imagesize or len(fnames) == batch_size
        # Only close a batch that holds something: an empty batch would make
        # downstream collation fail.
        if fnames and batch_is_full:
            batches.append((fnames, features, labels))
            fnames, features, labels = [], [], []

        fnames.append(entry.file_name)
        features.append(np.array(entry.image))
        labels.append(entry.label)

    # add last batch if it isn't empty
    if fnames:
        batches.append((fnames, features, labels))

    print("total ", len(batches), "batch data loaded")
    return batches
103+
104+
105+
def build_dataset(
        archive: ZipFile,
        folder: str,
        batch_size: int
) -> List[BatchTuple]:
    """Load every entry of ``folder`` from ``archive`` and group them into batches.

    Args:
        archive: opened dataset archive, passed on to ``extract_data_entries``.
        folder: directory name inside the archive, passed on unchanged.
        batch_size: forwarded to ``build_batches_from_entries``; all other
            batching limits keep their defaults.

    Returns:
        The batches produced by ``build_batches_from_entries``.
    """
    entries = extract_data_entries(archive, folder)
    return build_batches_from_entries(entries, batch_size)

comer/datamodule/crohme/datamodule.py

+102
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import os
2+
from typing import List, Optional
3+
from zipfile import ZipFile
4+
5+
import pytorch_lightning as pl
6+
import torch
7+
from comer.datamodule.crohme.dataset import CROHMEDataset
8+
from torch.utils.data.dataloader import DataLoader
9+
10+
from comer.datamodule.crohme import Batch, build_dataset, BatchTuple
11+
from comer.datamodule.crohme.vocab import vocab
12+
13+
14+
# Used to transform a Lightning batch into some other form (here, our custom Batch)
def collate_fn(batch: List[BatchTuple]) -> Batch:
    """Pad the images of one pre-built batch to a common size and wrap it in a ``Batch``.

    ``batch`` must contain exactly one ``(file_names, images, labels)`` tuple.
    Each image is read through ``size(1)`` / ``size(2)``, i.e. it is presumably
    a ``[1, H, W]`` tensor produced by the dataset transform - TODO confirm.

    Returns:
        A ``Batch`` with zero-padded images of shape ``[b, 1, max_H, max_W]``,
        a boolean mask of shape ``[b, max_H, max_W]`` that is ``True`` on
        padding and ``False`` on image pixels, and the labels converted with
        ``vocab.words2indices``.
    """
    assert len(batch) == 1
    sample = batch[0]
    file_names, images, labels = sample[0], sample[1], sample[2]
    targets = [vocab.words2indices(words) for words in labels]

    heights = [image.size(1) for image in images]
    widths = [image.size(2) for image in images]

    count = len(heights)
    tallest = max(heights)
    widest = max(widths)

    padded = torch.zeros(count, 1, tallest, widest)
    # Start fully masked, then clear the region each image actually covers.
    padding_mask = torch.ones(count, tallest, widest, dtype=torch.bool)
    for row, (image, height, width) in enumerate(zip(images, heights, widths)):
        padded[row, :, :height, :width] = image
        padding_mask[row, :height, :width] = False

    return Batch(file_names, padded, padding_mask, targets)
37+
38+
39+
class CROHMEDatamodule(pl.LightningDataModule):
    """Lightning data module that serves pre-batched data from a zip archive.

    The train split is read from the archive folder ``"train"``; validation
    and test data are both read from the folder named by ``test_year``.
    """

    def __init__(
        self,
        # NOTE(review): this path is resolved relative to this file's directory.
        # If this module was moved one directory deeper than before, "../.."
        # now points one level lower than it used to - confirm where data.zip lives.
        zipfile_path: str = f"{os.path.dirname(os.path.realpath(__file__))}/../../data.zip",
        test_year: str = "2014",
        train_batch_size: int = 8,
        eval_batch_size: int = 4,
        num_workers: int = 5,
        scale_aug: bool = False,
    ) -> None:
        """Store the configuration; no data is read until ``setup`` runs.

        Args:
            zipfile_path: path of the dataset archive.
            test_year: archive folder used for validation and test data;
                must be a ``str``.
            train_batch_size: maximum number of samples per training batch.
            eval_batch_size: maximum number of samples per validation/test batch.
            num_workers: worker count handed to every ``DataLoader``.
            scale_aug: forwarded to every ``CROHMEDataset``.
        """
        super().__init__()
        assert isinstance(test_year, str)
        self.zipfile_path = zipfile_path
        self.test_year = test_year
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.num_workers = num_workers
        self.scale_aug = scale_aug

        print(f"Load data from: {self.zipfile_path}")

    def _load_split(
        self, archive: ZipFile, folder: str, batch_size: int, is_train: bool
    ) -> CROHMEDataset:
        """Build the dataset for one archive folder."""
        return CROHMEDataset(
            build_dataset(archive, folder, batch_size),
            is_train,
            self.scale_aug,
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the datasets needed for ``stage``.

        ``"fit"`` builds the train and validation sets, ``"validate"`` only the
        validation set, ``"test"`` only the test set, and ``None`` all three.
        """
        with ZipFile(self.zipfile_path) as archive:
            if stage in ("fit", None):
                self.train_dataset = self._load_split(
                    archive, "train", self.train_batch_size, True
                )
            # "validate" is included so that a validation-only run finds
            # val_dataset instead of failing in val_dataloader.
            if stage in ("fit", "validate", None):
                self.val_dataset = self._load_split(
                    archive, self.test_year, self.eval_batch_size, False
                )
            if stage in ("test", None):
                self.test_dataset = self._load_split(
                    archive, self.test_year, self.eval_batch_size, False
                )

    def _make_loader(self, dataset: CROHMEDataset, shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a ``DataLoader`` that uses ``collate_fn``.

        No ``batch_size`` is passed: every dataset item is already a complete
        batch, and ``collate_fn`` asserts that it receives exactly one item.
        """
        return DataLoader(
            dataset,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=collate_fn,
        )

    def train_dataloader(self):
        """Return the shuffled loader over ``train_dataset``."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled loader over ``val_dataset``."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled loader over ``test_dataset``."""
        return self._make_loader(self.test_dataset, shuffle=False)

comer/datamodule/dataset.py renamed to comer/datamodule/crohme/dataset.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
from typing import List
2+
13
import torchvision.transforms as tr
24
from torch.utils.data.dataset import Dataset
35

4-
from .transforms import ScaleAugmentation, ScaleToLimitRange
6+
from comer.datamodule.crohme import BatchTuple
7+
from comer.datamodule.utils.transforms import ScaleAugmentation, ScaleToLimitRange
58

69
K_MIN = 0.7
710
K_MAX = 1.4
@@ -13,7 +16,9 @@
1316

1417

1518
class CROHMEDataset(Dataset):
16-
def __init__(self, ds, is_train: bool, scale_aug: bool) -> None:
19+
ds: List[BatchTuple]
20+
21+
def __init__(self, ds: List[BatchTuple], is_train: bool, scale_aug: bool) -> None:
1722
super().__init__()
1823
self.ds = ds
1924

@@ -28,11 +33,11 @@ def __init__(self, ds, is_train: bool, scale_aug: bool) -> None:
2833
self.transform = tr.Compose(trans_list)
2934

3035
def __getitem__(self, idx):
31-
fname, img, caption = self.ds[idx]
36+
file_names, images, labels = self.ds[idx]
3237

33-
img = [self.transform(im) for im in img]
38+
images = [self.transform(im) for im in images]
3439

35-
return fname, img, caption
40+
return file_names, images, labels
3641

3742
def __len__(self):
3843
return len(self.ds)

comer/datamodule/crohme/entry.py

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from dataclasses import dataclass
2+
from typing import TypedDict, List
3+
from zipfile import ZipFile
4+
5+
from PIL import Image
6+
7+
8+
@dataclass
class DataEntry:
    """One sample read from the dataset archive: name, image and label tokens."""

    file_name: str  # first token of a caption line; the image is "<file_name>.bmp"
    # `Image` is the PIL module, so the class of the stored object is `Image.Image`.
    image: Image.Image
    label: List[str]  # remaining whitespace-separated tokens of the caption line
13+
14+
15+
def extract_data_entries(archive: ZipFile, dir_name: str) -> List[DataEntry]:
    """Extract all data needed for a dataset from the zip archive.

    Reads ``data/<dir_name>/caption.txt`` and, for every non-blank line, loads
    the image ``data/<dir_name>/img/<file_name>.bmp`` named by the line's first
    token. The remaining tokens of the line become the label.

    Args:
        archive (ZipFile): opened dataset archive.
        dir_name (str): dir name in the archive zip (eg: train, test_2014......)

    Returns:
        List[DataEntry]: one entry per non-blank caption line, in file order.
    """
    with archive.open(f"data/{dir_name}/caption.txt", "r") as caption_file:
        captions = caption_file.readlines()

    data: List[DataEntry] = []
    for line in captions:
        tokens: List[str] = line.decode().strip().split()
        if not tokens:
            # Blank line (for example a trailing newline): nothing to load.
            continue
        file_name: str = tokens[0]
        label: List[str] = tokens[1:]
        with archive.open(f"data/{dir_name}/img/{file_name}.bmp", "r") as image_file:
            # move image to memory immediately, avoid lazy loading, which will lead to None pointer error in loading
            img: Image.Image = Image.open(image_file).copy()
        data.append(DataEntry(file_name, img, label))

    print(f"Extract data from: {dir_name}, with data size: {len(data)}")

    return data
File renamed without changes.

0 commit comments

Comments
 (0)