Skip to content

Commit 01c12f5

Browse files
author
andy
committed
change to python3
1 parent 74af41f commit 01c12f5

11 files changed

+94
-77
lines changed

README.md

+13-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# LSTM + CTC + Tensorflow Example
22

3-
This is a demo using lstm and ctc to recognize a picture of a series numbers with blanks all at once.
3+
This is a demo using lstm and ctc to recognize a picture of a series of numbers with blanks all at once. The code is compatible with Python3.
44

55
For example: given the picture below, the model would give the result `73791096754314441539`.
66

@@ -9,23 +9,31 @@ For example:given the piture below the model would give result `7379109675431444
99

1010
## Installation
1111
```
12+
# on mac
13+
pip install pillow
1214
pip install opencv-python
1315
brew install cmake
1416
brew tap homebrew/science
1517
brew install opencv
1618
sh ./prepare_train_data.sh
1719
```
20+
21+
```
22+
# on ubuntu
23+
pip install pillow
24+
pip install opencv-python
25+
pip install tensorflow-gpu
26+
```
27+
28+
1829
The `prepare_train_data.sh` script downloads the [SUN database](http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz) and extracts the pictures to the `bgs` dir. Then you can run `python gen.py` to generate the test and train dirs.
1930

2031
When the train and test data sets are ready, you can start the training process with `nohup python lstm_and_ctc_ocr_train.py`.
2132

2233
## Requirements
2334

24-
- Python 2.7+
35+
- Python 2.7+ / Python 3.5+
2536
- Tensorflow 1.0+
26-
- python_speech_features
27-
- numpy
28-
- scipy
2937

3038
##
3139
## License

common.py

+29-23
Original file line numberDiff line numberDiff line change
@@ -32,50 +32,52 @@
3232
import time
3333

3434
SPACE_INDEX = 0
35-
FIRST_INDEX = ord('0') - 1 # 0 is reserved to space
35+
# FIRST_INDEX = ord('0') - 1 # 0 is reserved to space
36+
FIRST_INDEX = 1 # 0 is reserved to space
3637

3738
SPACE_TOKEN = '<space>'
3839

3940
__all__ = (
4041
'DIGITS',
4142
'sigmoid',
4243
'softmax',
44+
'CHARS'
4345
)
4446

45-
OUTPUT_SHAPE = (64, 256)
4647

4748
DIGITS = "0123456789"
48-
# LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
49+
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
4950

51+
CHARS = list(DIGITS + LETTERS)
5052

51-
CHARS = DIGITS
52-
LENGTH = 16
53-
LENGTHS = [16, 20] # the number of digits varies from LENGTHS[0] to LENGTHS[1] in a image
54-
TEST_SIZE = 200
55-
ADD_BLANK = True # if add a blank between digits
53+
LENGTHS = [6, 6] # the number of digits varies from LENGTHS[0] to LENGTHS[1] in a image
54+
TEST_SIZE = 100
55+
ADD_BLANK = True # if add a blank between digits
5656
LEARNING_RATE_DECAY_FACTOR = 0.9 # The learning rate decay factor
5757
INITIAL_LEARNING_RATE = 1e-3
5858
DECAY_STEPS = 5000
5959

6060
# parameters for bdlstm ctc
6161
BATCH_SIZE = 64
62-
BATCHES = 10
62+
BATCHES = 100
63+
64+
OUTPUT_SHAPE = (BATCH_SIZE, 256)
6365

6466
TRAIN_SIZE = BATCH_SIZE * BATCHES
6567

6668
MOMENTUM = 0.9
67-
REPORT_STEPS = 100
69+
REPORT_STEPS = 1000
6870

6971
# Hyper-parameters
70-
num_epochs = 200
71-
num_hidden = 64
72-
num_layers = 1
72+
num_epochs = 2000
73+
num_hidden = 128
74+
num_layers = 2
7375

7476
# Some configs
7577
# Accounting the 0th indice + space + blank label = 28 characters
7678
# num_classes = ord('9') - ord('0') + 1 + 1 + 1
77-
num_classes = len(DIGITS) + 1 + 1 # 10 digits + blank + ctc blank
78-
print num_classes
79+
num_classes = len(CHARS) + 1 + 1 # 10 digits + blank + ctc blank
80+
print(num_classes)
7981

8082

8183
def softmax(a):
@@ -96,8 +98,8 @@ def sigmoid(a):
9698
def load_data_set(dirname):
9799
fname_list = glob.glob(dirname + "/*.png")
98100
result = dict()
101+
print("loading", dirname)
99102
for fname in sorted(fname_list):
100-
print "loading", fname
101103
im = cv2.imread(fname)[:, :, 0].astype(numpy.float32) / 255.
102104
code = list(fname.split("/")[1].split("_")[1])
103105
index = fname.split("/")[1].split("_")[0]
@@ -108,7 +110,7 @@ def load_data_set(dirname):
108110
def read_data_for_lstm_ctc(dirname, start_index=None, end_index=None):
109111
start = time.time()
110112
fname_list = []
111-
if not data_set.has_key(dirname):
113+
if dirname not in data_set.keys():
112114
load_data_set(dirname)
113115

114116
if start_index is None:
@@ -127,12 +129,16 @@ def read_data_for_lstm_ctc(dirname, start_index=None, end_index=None):
127129
# im = cv2.imread(fname)[:, :, 0].astype(numpy.float32) / 255.
128130
# code = list(fname.split("/")[1].split("_")[1])
129131
im, code = dir_data_set.get(fname)
130-
yield im, numpy.asarray([SPACE_INDEX if x == SPACE_TOKEN else (ord(x) - FIRST_INDEX) for x in list(code)])
132+
yield im, numpy.asarray(
133+
[SPACE_INDEX if x == SPACE_TOKEN else (CHARS.index(x) + FIRST_INDEX) for x in list(code)])
131134
# print("get time ", time.time() - start)
132135

133136

137+
# print numpy.asarray([SPACE_INDEX if x == SPACE_TOKEN else (CHARS.index(x) + FIRST_INDEX) for x in list(code)])
138+
139+
134140
def convert_original_code_train_code(code):
    """Convert a sequence of label characters into an array of class indices.

    ``SPACE_TOKEN`` is mapped to ``SPACE_INDEX``; every other character is
    mapped to its position in ``CHARS`` shifted up by ``FIRST_INDEX``. This is
    the same encoding that ``read_data_for_lstm_ctc`` yields for its targets.

    :param code: iterable of label symbols (characters from ``CHARS`` or
        ``SPACE_TOKEN``).
    :return: ``numpy`` array with one integer index per symbol.
    :raises ValueError: if a symbol is neither ``SPACE_TOKEN`` nor present in
        ``CHARS`` (raised by ``list.index``).
    """
    # The offset must be added, not subtracted: index 0 is reserved for the
    # space label, so subtracting FIRST_INDEX (1) would map CHARS[0] to -1 and
    # CHARS[1] onto SPACE_INDEX, disagreeing with read_data_for_lstm_ctc.
    return numpy.asarray(
        [SPACE_INDEX if x == SPACE_TOKEN else (CHARS.index(x) + FIRST_INDEX) for x in code])
136142

137143

138144
def unzip(b):
@@ -144,9 +150,9 @@ def unzip(b):
144150

145151
if __name__ == '__main__':
146152
train_inputs, train_codes = unzip(list(read_data_for_lstm_ctc("test"))[:2])
147-
print train_inputs.shape
148-
print train_codes
153+
print(train_inputs.shape)
154+
print(train_codes)
149155
print("train_codes", train_codes)
150156
targets = np.asarray(train_codes).flat[:]
151-
print targets
152-
print list(read_data_for_lstm_ctc("test", 0, 10))
157+
print(targets)
158+
print(list(read_data_for_lstm_ctc("test", 0, 10)))

detect.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,6 @@ def detect(test_inputs, test_targets, test_seq_len):
5353

5454
if __name__ == '__main__':
5555
test_inputs, test_targets, test_seq_len = utils.get_data_set('small_test')
56-
print test_inputs[0].shape
57-
print detect(test_inputs, test_targets, test_seq_len)
56+
print(test_inputs[0].shape)
57+
print(detect(test_inputs, test_targets, test_seq_len))
5858
# print_tensors_in_checkpoint_file("model/ocr.model.50", None)

extractbgs.py

-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ def members():
8585
if im.shape[0] > 256:
8686
im = cv2.resize(im, (256, 256))
8787
fname = "bgs/{:08}.jpg".format(index)
88-
print fname
8988
rc = cv2.imwrite(fname, im)
9089
if not rc:
9190
raise Exception("Failed to write file {}".format(fname))

gen.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,15 @@
4545

4646
import common
4747
from common import OUTPUT_SHAPE
48-
fonts = ["fonts/Farrington-7B-Qiqi.ttf", "fonts/Arial.ttf", "fonts/times.ttf"]
49-
# fonts = ["fonts/times.ttf"]
48+
49+
# fonts = ["fonts/Farrington-7B-Qiqi.ttf", "fonts/Arial.ttf", "fonts/times.ttf"]
50+
fonts = ["fonts/times.ttf"]
5051
FONT_HEIGHT = 32 # Pixel size to which the chars are resized
5152

52-
CHARS = common.CHARS + " "
5353

5454

55+
CHARS=common.CHARS[:]
56+
CHARS.append(" ")
5557
def make_char_ims(output_height, font):
5658
font_size = output_height * 4
5759
font = ImageFont.truetype(font, font_size)
@@ -166,7 +168,7 @@ def generate_code():
166168
for i in range(length):
167169
if 0 == i % 4 and append_blank:
168170
f = f + blank
169-
f = f + random.choice(common.DIGITS)
171+
f = f + random.choice(common.CHARS)
170172
return f
171173

172174

@@ -283,5 +285,5 @@ def generate_ims(num_images):
283285
im_gen = generate_ims(size.get(dir_name))
284286
for img_idx, (im, c, p) in enumerate(im_gen):
285287
fname = dir_name + "/{:08d}_{}_{}.png".format(img_idx, c, "1" if p else "0")
286-
print '\''+fname+'\','
288+
print('\'' + fname + '\',')
287289
cv2.imwrite(fname, im * 255.)

gen_no_plate_shape_version.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@
5050
# fonts = ["fonts/times.ttf"]
5151
FONT_HEIGHT = 32 # Pixel size to which the chars are resized
5252

53-
CHARS = common.CHARS + " "
54-
53+
CHARS=common.CHARS[:]
54+
CHARS.append(" ")
5555

5656
def make_char_ims(output_height, font):
5757
font_size = output_height * 4
@@ -292,5 +292,5 @@ def generate_ims(num_images):
292292
im_gen = generate_ims(size.get(dir_name))
293293
for img_idx, (im, c, p) in enumerate(im_gen):
294294
fname = dir_name + "/{:08d}_{}_{}.png".format(img_idx, c, "1" if p else "0")
295-
print '\'' + fname + '\','
295+
print('\'' + fname + '\',')
296296
cv2.imwrite(fname, im * 255.)

lstm_and_ctc_ocr_train.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
# Some configs
1818
# Accounting the 0th indice + space + blank label = 28 characters
19-
num_classes = ord('9') - ord('0') + 1 + 1 + 1
19+
# num_classes = ord('9') - ord('0') + 1 + 1 + 1
20+
num_classes = common.num_classes
2021
print("num_classes", num_classes)
2122
# Hyper-parameters
2223
num_epochs = 10000
@@ -59,7 +60,7 @@ def train():
5960
staircase=True)
6061
logits, inputs, targets, seq_len, W, b = model.get_train_model()
6162

62-
loss = tf.nn.ctc_loss( targets, logits, seq_len)
63+
loss = tf.nn.ctc_loss(targets, logits, seq_len)
6364
cost = tf.reduce_mean(loss)
6465

6566
optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
@@ -89,25 +90,27 @@ def do_batch():
8990
if steps > 0 and steps % common.REPORT_STEPS == 0:
9091
do_report()
9192
save_path = saver.save(session, "models/ocr.model", global_step=steps)
92-
# print(save_path)
93+
#print(save_path)
9394
return b_cost, steps
9495

95-
with tf.Session() as session:
96+
config = tf.ConfigProto()
97+
config.gpu_options.allow_growth = True
98+
with tf.Session(config=config) as session:
9699
session.run(init)
97100
saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
98-
for curr_epoch in xrange(num_epochs):
101+
for curr_epoch in range(num_epochs):
99102
# variables = tf.all_variables()
100103
# for i in variables:
101104
# print(i.name)
102105

103106
print("Epoch.......", curr_epoch)
104107
train_cost = train_ler = 0
105-
for batch in xrange(common.BATCHES):
108+
for batch in range(common.BATCHES):
106109
start = time.time()
107110
train_inputs, train_targets, train_seq_len = utils.get_data_set('train', batch * common.BATCH_SIZE,
108111
(batch + 1) * common.BATCH_SIZE)
109112

110-
print("get data time", time.time() - start)
113+
#print("get data time", time.time() - start)
111114
start = time.time()
112115
c, steps = do_batch()
113116
train_cost += c * common.BATCH_SIZE
@@ -116,7 +119,6 @@ def do_batch():
116119

117120
train_cost /= common.TRAIN_SIZE
118121
# train_ler /= common.TRAIN_SIZE
119-
120122
val_feed = {inputs: train_inputs,
121123
targets: train_targets,
122124
seq_len: train_seq_len}

model.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ def avg_pool(x, ksize=(2, 2), stride=(2, 2)):
3636
def convolutional_layers():
3737
"""
3838
Get the convolutional layers of the model.
39-
4039
"""
4140

4241
inputs = tf.placeholder(tf.float32, [None, None, common.OUTPUT_SHAPE[0]])
@@ -71,15 +70,20 @@ def convolutional_layers():
7170
features = tf.nn.relu(tf.matmul(conv_layer_flat, W_fc1) + b_fc1)
7271
shape = tf.shape(features)
7372
features = tf.reshape(features, [shape[0], common.OUTPUT_SHAPE[1], 1]) # batchsize * outputshape * 1
74-
return features
73+
return inputs, features
74+
75+
76+
def lstm_cell():
    """Build a fresh LSTM cell whose unit count is ``common.num_hidden``."""
    hidden_units = common.num_hidden
    return tf.contrib.rnn.LSTMCell(hidden_units)
7578

7679

7780
def get_train_model():
7881
# Has size [batch_size, max_stepsize, num_features], but the
7982
# batch_size and max_stepsize can vary along each step
80-
#features = convolutional_layers()
81-
#print features.get_shape()
82-
inputs = tf.placeholder(tf.float32, [None, None, common.OUTPUT_SHAPE[0]])
83+
inputs, features = convolutional_layers()
84+
# print features.get_shape()
85+
86+
# inputs = tf.placeholder(tf.float32, [None, None, common.OUTPUT_SHAPE[0]])
8387

8488
# Here we use sparse_placeholder that will generate a
8589
# SparseTensor required by ctc_loss op.
@@ -92,16 +96,16 @@ def get_train_model():
9296
# Can be:
9397
# tf.nn.rnn_cell.RNNCell
9498
# tf.nn.rnn_cell.GRUCell
95-
cell = tf.contrib.rnn.core_rnn_cell.LSTMCell(common.num_hidden, state_is_tuple=True)
99+
# cell = tf.contrib.rnn.LSTMCell(common.num_hidden, state_is_tuple=True)
96100

97101
# Stacking rnn cells
98-
stack = tf.contrib.rnn.core_rnn_cell.MultiRNNCell([cell] * common.num_layers,
102+
stack = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(0, common.num_layers)],
99103
state_is_tuple=True)
100104

101105
# The second output is the last state and we will no use that
102-
outputs, _ = tf.nn.dynamic_rnn(cell, inputs, seq_len, dtype=tf.float32)
106+
outputs, _ = tf.nn.dynamic_rnn(stack, features, seq_len, dtype=tf.float32)
103107

104-
shape = tf.shape(inputs)
108+
shape = tf.shape(features)
105109
batch_s, max_timesteps = shape[0], shape[1]
106110

107111
# Reshaping to apply the same weights over the timesteps

models/README

Whitespace-only changes.

test.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
import utils
88

99
__author__ = "andy"
10-
for batch in xrange(common.BATCHES):
11-
train_inputs, train_targets, train_seq_len = utils.get_data_set('train', batch*common.BATCH_SIZE, (batch + 1) * common.BATCH_SIZE)
12-
print batch, train_inputs.shape
13-
# pickle_file = 'test/test.pickle' + str(batch)
14-
# f = open(pickle_file, 'wb')
15-
# pickle.dump(batch_data, f, pickle.HIGHEST_PROTOCOL)
10+
a = ['a','b','c','d']
11+
print (a.index('d'))
12+
13+
#for batch in xrange(common.BATCHES):
14+
# train_inputs, train_targets, train_seq_len = utils.get_data_set('train', batch*common.BATCH_SIZE, (batch + 1) * common.BATCH_SIZE)
15+
# print batch, train_inputs.shape
16+

0 commit comments

Comments
 (0)