Skip to content

Commit b160de7

Browse files
committed
add SlimNetsLayer and Inception V3 example / Merge TF-Slim into TensorLayer
1 parent 7594891 commit b160de7

File tree

5 files changed

+217
-3
lines changed

5 files changed

+217
-3
lines changed

docs/modules/layers.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ In addition, if you want to update the parameters of previous 2 layers at the sa
270270
FlattenLayer
271271
ConcatLayer
272272
ReshapeLayer
273+
SlimNetsLayer
273274
MultiplexerLayer
274275
EmbeddingAttentionSeq2seqWrapper
275276
flatten_reshape
@@ -357,10 +358,20 @@ so to implement 1D CNN, you can use Reshape layer as follow.
357358

358359
.. autoclass:: Conv2dLayer
359360

361+
2D Deconvolutional layer
362+
^^^^^^^^^^^^^^^^^^^^^^^^^^
363+
364+
.. autoclass:: DeConv2dLayer
365+
366+
360367
3D Convolutional layer
361368
^^^^^^^^^^^^^^^^^^^^^^^
362369

363370
.. autoclass:: Conv3dLayer
371+
372+
3D Deconvolutional layer
373+
^^^^^^^^^^^^^^^^^^^^^^^^^^
374+
364375
.. autoclass:: DeConv3dLayer
365376

366377
Pooling layer
@@ -397,6 +408,13 @@ Reshape layer
397408

398409
.. autoclass:: ReshapeLayer
399410

411+
Merge TF-Slim
412+
^^^^^^^^^^^^^^^
413+
414+
Yes ! TF-Slim models can be merged into TensorLayer, all Google's Pre-trained model can be used easily ,
415+
see `Slim-model <https://github.com/tensorflow/models/tree/master/slim#Install>`_ .
416+
417+
.. autoclass:: SlimNetsLayer
400418

401419
Flow control layer
402420
----------------------

docs/user/example.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Computer Vision
1919
- Convolutional Network (CIFAR-10). A Convolutional neural network implementation for classifying CIFAR-10 dataset, see ``tutorial_cifar10.py`` and ``tutorial_cifar10_tfrecord.py``on `GitHub`_.
2020
- VGG 16 (ImageNet). A Convolutional neural network implementation for classifying ImageNet dataset, see ``tutorial_vgg16.py`` on `GitHub`_.
2121
- VGG 19 (ImageNet). A Convolutional neural network implementation for classifying ImageNet dataset, see ``tutorial_vgg19.py`` on `GitHub`_.
22+
- InceptionV3 (ImageNet). A Convolutional neural network implementation for classifying ImageNet dataset, see ``tutorial_inceptionV3_tfslim.py`` on `GitHub`_.
2223

2324

2425
Natural Language Processing
@@ -36,6 +37,13 @@ Reinforcement Learning
3637
- Deep Reinforcement Learning - Pong Game. Teach a machine to play Pong games, see ``tutorial_atari_pong.py`` on `GitHub`_.
3738

3839

40+
Special Examples
41+
=================
42+
43+
- Merge TF-Slim into TensorLayer. ``tutorial_inceptionV3_tfslim.py`` on `GitHub`_.
44+
- MultiplexerLayer. ``tutorial_mnist_multiplexer.py`` on `GitHub`_.
45+
46+
3947
..
4048
Applications
4149
=============

tensorlayer/layers.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1629,6 +1629,58 @@ def __init__(
16291629
self.all_drop = dict(layer.all_drop)
16301630
self.all_layers.extend( [self.outputs] )
16311631

1632+
## TF-Slim layer
1633+
class SlimNetsLayer(Layer):
1634+
"""
1635+
The :class:`SlimNetsLayer` class can be used to merge all TF-Slim nets into
1636+
TensorLayer. Model can be found in `slim-model <https://github.com/tensorflow/models/tree/master/slim#Install>`_ , more about slim
1637+
see `slim-git <https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim>`_ .
1638+
1639+
Parameters
1640+
----------
1641+
layer : a list of :class:`Layer` instances
1642+
The `Layer` class feeding into this layer.
1643+
slim_layer : a slim network function
1644+
The network you want to stack onto, end with ``return net, end_points``.
1645+
name : a string or None
1646+
An optional name to attach to this layer.
1647+
1648+
Note
1649+
-----
1650+
The due to TF-Slim stores the layers as dictionary, the ``all_layers`` in this
1651+
network is not in order ! Fortunately, the ``all_params`` are in order.
1652+
1653+
"""
1654+
def __init__(
1655+
self,
1656+
layer = None,
1657+
slim_layer = None,
1658+
slim_args = {},
1659+
name ='slim_layer',
1660+
):
1661+
Layer.__init__(self, name=name)
1662+
self.inputs = layer.outputs
1663+
print(" tensorlayer:Instantiate SlimNetsLayer %s: %s" % (self.name, slim_layer.__name__))
1664+
1665+
with tf.variable_scope(name) as vs:
1666+
net, end_points = slim_layer(self.inputs, **slim_args)
1667+
slim_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name)
1668+
1669+
self.outputs = net
1670+
1671+
slim_layers = []
1672+
for v in end_points.values():
1673+
tf.contrib.layers.summaries.summarize_activation(v)
1674+
slim_layers.append(v)
1675+
1676+
self.all_layers = list(layer.all_layers)
1677+
self.all_params = list(layer.all_params)
1678+
self.all_drop = dict(layer.all_drop)
1679+
1680+
self.all_layers.extend( slim_layers )
1681+
self.all_params.extend( slim_variables )
1682+
1683+
16321684
## Flow control layer
16331685
class MultiplexerLayer(Layer):
16341686
"""

tensorlayer/ops.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,14 @@ def get_site_packages_directory():
156156
"""Print and return the site-packages directory?
157157
"""
158158
import site
159-
loc = site.getsitepackages()
160-
print(loc)
161-
return loc
159+
try:
160+
loc = site.getsitepackages()
161+
print(" tl.ops : site-packages in ", loc)
162+
return loc
163+
except:
164+
p = ' tl.ops : You are using virtual environment'
165+
print(p)
166+
return p
162167

163168

164169

tutorial_inceptionV3_tfslim.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
#! /usr/bin/python
2+
# -*- coding: utf8 -*-
3+
4+
5+
import tensorflow as tf
6+
import tensorlayer as tl
7+
slim = tf.contrib.slim
8+
from tensorflow.contrib.slim.python.slim.nets.alexnet import alexnet_v2
9+
from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_base, inception_v3, inception_v3_arg_scope
10+
# from tensorflow.contrib.slim.python.slim.nets.resnet_v2 import resnet_v2_152
11+
# from tensorflow.contrib.slim.python.slim.nets.vgg import vgg_16
12+
import skimage
13+
import skimage.io
14+
import skimage.transform
15+
import time
16+
from data.imagenet_classes import *
17+
import numpy as np
18+
"""
19+
You will learn:
20+
1. What is TF-Slim ?
21+
1. How to combine TensorLayer and TF-Slim ?
22+
23+
Introduction of Slim : https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim
24+
Slim Pre-trained Models : https://github.com/tensorflow/models/tree/master/slim
25+
26+
With the help of SlimNetsLayer, all Slim Model can be combined into TensorLayer.
27+
All models in the following link, end with `return net, end_points`` are available.
28+
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim/python/slim/nets
29+
30+
31+
Bugs
32+
-----
33+
tf.variable_scope :
34+
https://groups.google.com/a/tensorflow.org/forum/#!topic/discuss/RoxrU3UnbFA
35+
load inception_v3 for prediction:
36+
http://stackoverflow.com/questions/39357454/restore-checkpoint-in-tensorflow-tensor-name-not-found
37+
"""
38+
def load_image(path):
39+
# load image
40+
img = skimage.io.imread(path)
41+
img = img / 255.0
42+
assert (0 <= img).all() and (img <= 1.0).all()
43+
# print "Original Image Shape: ", img.shape
44+
# we crop image from center
45+
short_edge = min(img.shape[:2])
46+
yy = int((img.shape[0] - short_edge) / 2)
47+
xx = int((img.shape[1] - short_edge) / 2)
48+
crop_img = img[yy: yy + short_edge, xx: xx + short_edge]
49+
# resize to 224, 224
50+
resized_img = skimage.transform.resize(crop_img, (299, 299))
51+
return resized_img
52+
53+
54+
def print_prob(prob):
55+
synset = class_names
56+
# print prob
57+
pred = np.argsort(prob)[::-1]
58+
# Get top1 label
59+
top1 = synset[pred[0]]
60+
print("Top1: ", top1, prob[pred[0]])
61+
# Get top5 label
62+
top5 = [(synset[pred[i]], prob[pred[i]]) for i in range(5)]
63+
print("Top5: ", top5)
64+
return top1
65+
66+
67+
## Alexnet_v2 / All Slim nets can be merged into TensorLayer
68+
# x = tf.placeholder(tf.float32, shape=[None, 299, 299, 3])
69+
# net_in = tl.layers.InputLayer(x, name='input_layer')
70+
# network = tl.layers.SlimNetsLayer(layer=net_in, slim_layer=alexnet_v2,
71+
# slim_args= {
72+
# 'num_classes' : 1000,
73+
# 'is_training' : True,
74+
# 'dropout_keep_prob' : 0.5,
75+
# 'spatial_squeeze' : True,
76+
# 'scope' : 'alexnet_v2'
77+
# }
78+
# )
79+
# sess = tf.InteractiveSession()
80+
# sess.run(tf.initialize_all_variables())
81+
# network.print_params()
82+
# exit()
83+
84+
# InceptionV3
85+
x = tf.placeholder(tf.float32, shape=[None, 299, 299, 3])
86+
net_in = tl.layers.InputLayer(x, name='input_layer') # DH
87+
with slim.arg_scope(inception_v3_arg_scope()):
88+
# logits, end_points = inception_v3(X, num_classes=1001,
89+
# is_training=False)
90+
network = tl.layers.SlimNetsLayer(layer=net_in, slim_layer=inception_v3,
91+
slim_args= {
92+
'num_classes' : 1001,
93+
'is_training' : False,
94+
# 'dropout_keep_prob' : 0.8, # for training
95+
# 'min_depth' : 16,
96+
# 'depth_multiplier' : 1.0,
97+
# 'prediction_fn' : slim.softmax,
98+
# 'spatial_squeeze' : True,
99+
# 'reuse' : None,
100+
# 'scope' : 'InceptionV3'
101+
},
102+
name=''
103+
)
104+
saver = tf.train.Saver()
105+
106+
sess = tf.InteractiveSession()
107+
sess.run(tf.initialize_all_variables())
108+
109+
# with tf.Session() as sess:
110+
saver.restore(sess, "inception_v3.ckpt") # download from https://github.com/tensorflow/models/tree/master/slim#Install
111+
print("Model Restored")
112+
network.print_params(False)
113+
114+
115+
from scipy.misc import imread, imresize
116+
y = network.outputs
117+
probs = tf.nn.softmax(y)
118+
img1 = load_image("data/puzzle.jpeg")
119+
img1 = img1.reshape((1, 299, 299, 3))
120+
121+
start_time = time.time()
122+
prob = sess.run(probs, feed_dict= {x : img1})
123+
print("End time : %.5ss" % (time.time() - start_time))
124+
print_prob(prob[0][1:]) # Note : as it have 1001 outputs, the 1st output is nothing
125+
126+
127+
128+
129+
130+
131+
#

0 commit comments

Comments
 (0)