
Commit 94d9f0e

Added new data
1 parent f9a4dee commit 94d9f0e

File tree

5 files changed: 361 additions & 0 deletions


STS_Research/configs.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
import os
from datetime import datetime

from mltu.configs import BaseModelConfigs

class ModelConfigs(BaseModelConfigs):
    def __init__(self):
        super().__init__()
        self.model_path = os.path.join("Models/04_sentence_recognition", datetime.strftime(datetime.now(), "%Y%m%d%H%M"))
        self.vocab = ""
        self.height = 96
        self.width = 1408
        self.max_text_length = 0
        self.batch_size = 32
        self.learning_rate = 0.0005
        self.train_epochs = 1000
        self.train_workers = 20
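
ModelConfigs only declares defaults here; vocab and max_text_length start empty/zero and are filled in by train.py before configs.save() persists them. A minimal sketch of that round trip, assuming (as the load path in inferenceModel.py below suggests) that save() writes a configs.yaml under model_path:

    from configs import ModelConfigs
    from mltu.configs import BaseModelConfigs

    configs = ModelConfigs()
    configs.vocab = "abc"            # normally built from the dataset in train.py
    configs.max_text_length = 16
    configs.save()                   # assumed to write <model_path>/configs.yaml

    # inferenceModel.py later restores the same values by path:
    restored = BaseModelConfigs.load(f"{configs.model_path}/configs.yaml")
    assert restored.vocab == "abc"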

STS_Research/inferenceModel.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
import cv2
import typing
import numpy as np

from mltu.inferenceModel import OnnxInferenceModel
from mltu.utils.text_utils import ctc_decoder, get_cer, get_wer
from mltu.transformers import ImageResizer

class ImageToWordModel(OnnxInferenceModel):
    def __init__(self, char_list: typing.Union[str, list], *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.char_list = char_list

    def predict(self, image: np.ndarray):
        image = ImageResizer.resize_maintaining_aspect_ratio(image, *self.input_shapes[0][1:3][::-1])

        image_pred = np.expand_dims(image, axis=0).astype(np.float32)

        preds = self.model.run(self.output_names, {self.input_names[0]: image_pred})[0]

        text = ctc_decoder(preds, self.char_list)[0]

        return text

if __name__ == "__main__":
    import pandas as pd
    from tqdm import tqdm
    from mltu.configs import BaseModelConfigs

    configs = BaseModelConfigs.load("Models/04_sentence_recognition/202301131202/configs.yaml")

    model = ImageToWordModel(model_path=configs.model_path, char_list=configs.vocab)

    df = pd.read_csv("Models/04_sentence_recognition/202301131202/val.csv").values.tolist()

    accum_cer, accum_wer = [], []
    for image_path, label in tqdm(df):
        image = cv2.imread(image_path.replace("\\", "/"))

        prediction_text = model.predict(image)

        cer = get_cer(prediction_text, label)
        wer = get_wer(prediction_text, label)
        print("Image: ", image_path)
        print("Label:", label)
        print("Prediction: ", prediction_text)
        print(f"CER: {cer}; WER: {wer}")

        accum_cer.append(cer)
        accum_wer.append(wer)

        cv2.imshow(prediction_text, image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    print(f"Average CER: {np.average(accum_cer)}, Average WER: {np.average(accum_wer)}")

STS_Research/login.html

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
<!DOCTYPE html>
<html lang="en">

<head>
    <title>Kaggle: Your Home for Data Science</title>
    <meta charset="utf-8" />
    <meta name="robots" content="index, follow" />
    <meta name="description" content="Kaggle is the world&#x2019;s largest data science community with powerful tools and resources to help you achieve your data science goals." />
    <meta name="turbolinks-cache-control" content="no-cache" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=5.0, minimum-scale=1.0">
    <meta name="theme-color" content="#008ABC" />
    <script nonce="EiRBuHfyGAC/q&#x2B;OmJ6g0zQ==" type="text/javascript">
        window["pageRequestStartTime"] = 1716135025964;
        window["pageRequestEndTime"] = 1716135025968;
        window["initialPageLoadStartTime"] = new Date().getTime();
    </script>
    <script nonce="EiRBuHfyGAC/q&#x2B;OmJ6g0zQ==" id="gsi-client" src="https://accounts.google.com/gsi/client" async defer></script>
    <script nonce="EiRBuHfyGAC/q&#x2B;OmJ6g0zQ==">window.KAGGLE_JUPYTERLAB_PATH = "/static/assets/jupyterlab/jupyterlab-index-9369516b66d9c0a10de7.html";</script>
    <link rel="preconnect" href="https://www.google-analytics.com" crossorigin="anonymous" /><link rel="preconnect" href="https://stats.g.doubleclick.net" /><link rel="preconnect" href="https://storage.googleapis.com" /><link rel="preconnect" href="https://apis.google.com" />
    <link href="/static/images/favicon.ico" rel="shortcut icon" type="image/x-icon" />
    <link rel="manifest" href="/static/json/manifest.json" crossorigin="use-credentials">

    <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin />

    <link href="https://fonts.googleapis.com/css?family=Inter:400,400i,500,500i,600,600i,700,700i&display=swap"
        rel="preload" as="style" />
    <link href="https://fonts.googleapis.com/css2?family=Google+Symbols:[email protected]&display=block"
        rel="preload" as="style" />
    <link href="https://fonts.googleapis.com/css?family=Inter:400,400i,500,500i,600,600i,700,700i&display=swap"
        rel="stylesheet" media="print" id="async-google-font-1" />
    <link href="https://fonts.googleapis.com/css2?family=Google+Symbols:[email protected]&display=block"
        rel="stylesheet" media="print" id="async-google-font-2" />
    <script nonce="EiRBuHfyGAC/q&#x2B;OmJ6g0zQ==" type="text/javascript">
        const styleSheetIds = ["async-google-font-1", "async-google-font-2"];
        styleSheetIds.forEach(function (id) {
            document.getElementById(id).addEventListener("load", function() {
                this.media = "all";
            });
        });
    </script>

    <script nonce="EiRBuHfyGAC/q&#x2B;OmJ6g0zQ==" src="https://www.google.com/recaptcha/enterprise.js?render=6LcW02cpAAAAAJlaJemsQQEwAiTEYB4aR6FYE_rD&waf=session" async defer></script>
    <style>.grecaptcha-badge { visibility: hidden; }</style>

    <link rel="stylesheet" type="text/css" href="/static/assets/vendor.css?v=dne" />
    <link rel="stylesheet" type="text/css" href="/static/assets/app.css?v=62d595175a876550f3e6" />

    <script nonce="EiRBuHfyGAC/q&#x2B;OmJ6g0zQ==">
        try{(function(a,s,y,n,c,h,i,d,e){d=s.createElement("style");
        d.appendChild(s.createTextNode(""));s.head.appendChild(d);d=d.sheet;
        y=y.map(x => d.insertRule(x + "{ opacity: 0 !important }"));
        h.start=1*new Date;h.end=i=function(){y.forEach(x => x<d.cssRules.length ? d.deleteRule(x) : {})};
        (a[n]=a[n]||[]).hide=h;setTimeout(function(){i();h.end=null},c);h.timeout=c;
        })(window,document,['.site-header-react__nav'],'dataLayer',2000,{'GTM-52LNT9S':true});}catch(ex){}
    </script>
    <script nonce="EiRBuHfyGAC/q&#x2B;OmJ6g0zQ==">
        window.dataLayer = window.dataLayer || [];
        function gtag() { dataLayer.push(arguments); }
        gtag('js', new Date());
        gtag('config', 'G-T7QHS60L4Q', {
            'optimize_id': 'GTM-52LNT9S',
            'displayFeaturesTask': null,
            'send_page_view': false,
            'content_group1': 'Account'
        });
    </script>
    <script nonce="EiRBuHfyGAC/q&#x2B;OmJ6g0zQ==" async src="https://www.googletagmanager.com/gtag/js?id=G-T7QHS60L4Q"></script>

    <meta name="twitter:site" content="@Kaggle" />

    <script nonce="EiRBuHfyGAC/q&#x2B;OmJ6g0zQ==">window['useKaggleAnalytics'] = true;</script>

    <script id="gapi-target" nonce="EiRBuHfyGAC/q&#x2B;OmJ6g0zQ==" src="https://apis.google.com/js/api.js" defer
        async></script>
    <script nonce="EiRBuHfyGAC/q+OmJ6g0zQ==" src="/static/assets/runtime.js?v=96c24c08a967efb5ee0c" data-turbolinks-track="reload"></script>
    <script nonce="EiRBuHfyGAC/q+OmJ6g0zQ==" src="/static/assets/vendor.js?v=38b5f1af3c791be446c0" data-turbolinks-track="reload"></script>
    <script nonce="EiRBuHfyGAC/q+OmJ6g0zQ==" src="/static/assets/app.js?v=0cb512098feb97aac492" data-turbolinks-track="reload"></script>
    <script nonce="EiRBuHfyGAC/q&#x2B;OmJ6g0zQ==" type="text/javascript">
        window.kaggleStackdriverConfig = {
            key: 'AIzaSyA4eNqUdRRskJsCZWVz-qL655Xa5JEMreE',
            projectId: 'kaggle-161607',
            service: 'web-fe',
            version: 'ci',
            userId: '0'
        }
    </script>
</head>

<body data-turbolinks="false">
    <main>
        <div id="site-container"></div>

        <div id="site-body" class="hide">
        </div>
    </main>
</body>

</html>

STS_Research/model.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
from keras import layers
from keras.models import Model

from mltu.tensorflow.model_utils import residual_block


def train_model(input_dim, output_dim, activation="leaky_relu", dropout=0.2):

    inputs = layers.Input(shape=input_dim, name="input")

    # Normalize images here instead of in the preprocessing step
    normalized = layers.Lambda(lambda x: x / 255)(inputs)

    x1 = residual_block(normalized, 32, activation=activation, skip_conv=True, strides=1, dropout=dropout)

    x2 = residual_block(x1, 32, activation=activation, skip_conv=True, strides=2, dropout=dropout)
    x3 = residual_block(x2, 32, activation=activation, skip_conv=False, strides=1, dropout=dropout)

    x4 = residual_block(x3, 64, activation=activation, skip_conv=True, strides=2, dropout=dropout)
    x5 = residual_block(x4, 64, activation=activation, skip_conv=False, strides=1, dropout=dropout)

    x6 = residual_block(x5, 128, activation=activation, skip_conv=True, strides=2, dropout=dropout)
    x7 = residual_block(x6, 128, activation=activation, skip_conv=True, strides=1, dropout=dropout)

    x8 = residual_block(x7, 128, activation=activation, skip_conv=True, strides=2, dropout=dropout)
    x9 = residual_block(x8, 128, activation=activation, skip_conv=False, strides=1, dropout=dropout)

    # Collapse the height axis into the time axis for the sequence model
    squeezed = layers.Reshape((x9.shape[-3] * x9.shape[-2], x9.shape[-1]))(x9)

    blstm = layers.Bidirectional(layers.LSTM(256, return_sequences=True))(squeezed)
    blstm = layers.Dropout(dropout)(blstm)

    blstm = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(blstm)
    blstm = layers.Dropout(dropout)(blstm)

    # output_dim + 1: one extra class for the CTC blank token
    output = layers.Dense(output_dim + 1, activation="softmax", name="output")(blstm)

    model = Model(inputs=inputs, outputs=output)
    return model
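
With the 96×1408 input from configs.py, the four stride-2 blocks downsample the feature map to 6×88, so the Reshape feeds the BLSTMs a 528-step sequence. residual_block itself lives in mltu and is not part of this commit; purely as an illustration, a block matching the call signature used above (skip_conv, strides, dropout) could look like the following hypothetical sketch, not mltu's actual code:

    from keras import layers

    def residual_block_sketch(x, filters, activation="leaky_relu", skip_conv=True, strides=1, dropout=0.2):
        # Hypothetical stand-in for mltu.tensorflow.model_utils.residual_block.
        def act(t):
            # Not every Keras version resolves the string "leaky_relu", so use the layer directly
            return layers.LeakyReLU()(t) if activation == "leaky_relu" else layers.Activation(activation)(t)

        skip = x
        out = layers.Conv2D(filters, 3, strides=strides, padding="same")(x)
        out = layers.BatchNormalization()(out)
        out = act(out)
        out = layers.Conv2D(filters, 3, padding="same")(out)
        out = layers.BatchNormalization()(out)
        if skip_conv:
            # 1x1 convolution so the skip path matches the main path's shape
            skip = layers.Conv2D(filters, 1, strides=strides, padding="same")(x)
        out = layers.Add()([out, skip])
        out = act(out)
        if dropout:
            out = layers.Dropout(dropout)(out)
        return out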

STS_Research/train.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
import tensorflow as tf
try: [tf.config.experimental.set_memory_growth(gpu, True) for gpu in tf.config.experimental.list_physical_devices("GPU")]
except: pass

from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard

from mltu.preprocessors import ImageReader
from mltu.transformers import ImageResizer, LabelIndexer, LabelPadding, ImageShowCV2
from mltu.augmentors import RandomBrightness, RandomRotate, RandomErodeDilate, RandomSharpen
from mltu.annotations.images import CVImage

from mltu.tensorflow.dataProvider import DataProvider
from mltu.tensorflow.losses import CTCloss
from mltu.tensorflow.callbacks import Model2onnx, TrainLogger
from mltu.tensorflow.metrics import CERMetric, WERMetric

from model import train_model
from configs import ModelConfigs

import os
from tqdm import tqdm

# Must download and extract the dataset manually from https://fki.tic.heia-fr.ch/databases/download-the-iam-handwriting-database to Datasets/IAM_Sentences
sentences_txt_path = os.path.join("Datasets", "IAM_Sentences", "ascii", "sentences.txt")
sentences_folder_path = os.path.join("Datasets", "IAM_Sentences", "sentences")

dataset, vocab, max_len = [], set(), 0
words = open(sentences_txt_path, "r").readlines()
for line in tqdm(words):
    if line.startswith("#"):
        continue

    line_split = line.split(" ")
    if line_split[2] == "err":
        continue

    folder1 = line_split[0][:3]
    folder2 = "-".join(line_split[0].split("-")[:2])
    file_name = line_split[0] + ".png"
    label = line_split[-1].rstrip("\n")

    # Replace the "|" word separators with spaces in the label
    label = label.replace("|", " ")

    rel_path = os.path.join(sentences_folder_path, folder1, folder2, file_name)
    if not os.path.exists(rel_path):
        print(f"File not found: {rel_path}")
        continue

    dataset.append([rel_path, label])
    vocab.update(list(label))
    max_len = max(max_len, len(label))

# Create a ModelConfigs object to store model configurations
configs = ModelConfigs()

# Save vocab and maximum text length to configs
configs.vocab = "".join(vocab)
configs.max_text_length = max_len
configs.save()

# Create a data provider for the dataset
data_provider = DataProvider(
    dataset=dataset,
    skip_validation=True,
    batch_size=configs.batch_size,
    data_preprocessors=[ImageReader(CVImage)],
    transformers=[
        ImageResizer(configs.width, configs.height, keep_aspect_ratio=True),
        LabelIndexer(configs.vocab),
        LabelPadding(max_word_length=configs.max_text_length, padding_value=len(configs.vocab)),
    ],
)

# Split the dataset into training and validation sets
train_data_provider, val_data_provider = data_provider.split(split=0.9)

# Augment training data with random brightness, erode/dilate and sharpening
train_data_provider.augmentors = [
    RandomBrightness(),
    RandomErodeDilate(),
    RandomSharpen(),
]

# Create the TensorFlow model architecture
model = train_model(
    input_dim=(configs.height, configs.width, 3),
    output_dim=len(configs.vocab),
)

# Compile the model and print a summary
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=configs.learning_rate),
    loss=CTCloss(),
    metrics=[
        CERMetric(vocabulary=configs.vocab),
        WERMetric(vocabulary=configs.vocab),
    ],
    run_eagerly=False,
)
model.summary(line_length=110)

# Define callbacks
earlystopper = EarlyStopping(monitor="val_CER", patience=20, verbose=1, mode="min")
checkpoint = ModelCheckpoint(f"{configs.model_path}/model.h5", monitor="val_CER", verbose=1, save_best_only=True, mode="min")
trainLogger = TrainLogger(configs.model_path)
tb_callback = TensorBoard(f"{configs.model_path}/logs", update_freq=1)
reduceLROnPlat = ReduceLROnPlateau(monitor="val_CER", factor=0.9, min_delta=1e-10, patience=5, verbose=1, mode="auto")
model2onnx = Model2onnx(f"{configs.model_path}/model.h5")

# Train the model
model.fit(
    train_data_provider,
    validation_data=val_data_provider,
    epochs=configs.train_epochs,
    callbacks=[earlystopper, checkpoint, trainLogger, reduceLROnPlat, tb_callback, model2onnx],
    workers=configs.train_workers,
)

# Save the training and validation datasets as csv files
train_data_provider.to_csv(os.path.join(configs.model_path, "train.csv"))
val_data_provider.to_csv(os.path.join(configs.model_path, "val.csv"))
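
The parsing loop above rebuilds each image path from the dash-separated sentence id and takes the "|"-separated transcription from the last field, skipping lines whose segmentation flag (index 2) is "err". A worked example with an illustrative sentences.txt line (the numeric fields are placeholders; only fields 0, 2 and the last one are used):

    import os

    line = "a01-000u-s00-00 0 ok 154 19 408 746 1663 91 A|MOVE|to|stop|Mr.|Gaitskell\n"
    line_split = line.split(" ")

    folder1 = line_split[0][:3]                       # "a01"
    folder2 = "-".join(line_split[0].split("-")[:2])  # "a01-000u"
    file_name = line_split[0] + ".png"                # "a01-000u-s00-00.png"
    label = line_split[-1].rstrip("\n").replace("|", " ")  # "A MOVE to stop Mr. Gaitskell"

    print(os.path.join("Datasets", "IAM_Sentences", "sentences", folder1, folder2, file_name))
    # Datasets/IAM_Sentences/sentences/a01/a01-000u/a01-000u-s00-00.png (POSIX separators)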
