Skip to content

Commit 129f736

Browse files
committed
将CNN中的dropout替换为L2正则化
1 parent 90dd55c commit 129f736

File tree

3 files changed

+70
-5
lines changed

3 files changed

+70
-5
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
.DS_Store
2+
pb/
3+
tflite/
24
data/cnews
35
data/thucnews
46
__pycache__

cnn_model.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# coding: utf-8
2+
from functools import partial
23

34
import tensorflow as tf
45

@@ -24,6 +25,8 @@ class TCNNConfig(object):
2425
print_per_batch = 100 # 每多少轮输出一次结果
2526
save_per_batch = 10 # 每多少轮存入tensorboard
2627

28+
scale = 0.01
29+
2730

2831
class TextCNN(object):
2932
"""文本分类,CNN模型"""
@@ -40,6 +43,11 @@ def __init__(self, config):
4043

4144
def cnn(self):
4245
"""CNN模型"""
46+
my_dense_layer = partial(
47+
tf.layers.dense, activation=tf.nn.relu,
48+
# 在这里传入了L2正则化函数,并在函数中传入正则化系数。
49+
kernel_regularizer=tf.contrib.layers.l2_regularizer(self.config.scale)
50+
)
4351
# 词向量映射
4452
with tf.device('/cpu:0'):
4553
embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])
@@ -52,10 +60,11 @@ def cnn(self):
5260
gmp = tf.reduce_max(conv, reduction_indices=[1], name='gmp')
5361

5462
with tf.name_scope("score"):
55-
# 全连接层,后面接dropout以及relu激活
56-
fc = tf.layers.dense(gmp, self.config.hidden_dim, name='fc1')
57-
fc = tf.contrib.layers.dropout(fc, self.keep_prob)
58-
fc = tf.nn.relu(fc)
63+
# 全连接层
64+
fc = my_dense_layer(gmp, self.config.hidden_dim, name='fc1')
65+
# fc = tf.layers.dense(gmp, self.config.hidden_dim, name='fc1')
66+
# fc = tf.contrib.layers.dropout(fc, self.keep_prob)
67+
# fc = tf.nn.relu(fc)
5968

6069
# 分类器
6170
self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
@@ -64,7 +73,9 @@ def cnn(self):
6473
with tf.name_scope("optimize"):
6574
# 损失函数,交叉熵
6675
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
67-
self.loss = tf.reduce_mean(cross_entropy)
76+
reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
77+
self.loss = tf.add_n([tf.reduce_mean(cross_entropy)] + reg_losses)
78+
# self.loss = tf.reduce_mean(cross_entropy)
6879
# 优化器
6980
self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)
7081

freeze.py

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import tensorflow as tf
2+
3+
version = 1.0
4+
output_graph = "./pb/model_{}.pb".format(version)
5+
output_tflite_model = "./tflite/model_{}.tflite".format(version)
6+
7+
8+
def freeze_graph(input_checkpoint):
9+
"""
10+
:param input_checkpoint:
11+
:return:
12+
"""
13+
# checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
14+
# input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
15+
16+
# 指定输出的节点名称,该节点名称必须是原模型中存在的节点
17+
output_node_names = "score/fc2/BiasAdd"
18+
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
19+
20+
with tf.Session() as sess:
21+
saver.restore(sess, input_checkpoint) # 恢复图并得到数据
22+
output_graph_def = tf.graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
23+
sess=sess,
24+
input_graph_def=sess.graph_def, # 等于:sess.graph_def
25+
output_node_names=output_node_names.split(",")
26+
) # 如果有多个输出节点,以逗号隔开
27+
28+
with tf.gfile.GFile(output_graph, "wb") as f: # 保存模型
29+
f.write(output_graph_def.SerializeToString()) # 序列化输出
30+
31+
32+
def convert_to_tflite():
33+
input_tensors = [
34+
"input_x"
35+
]
36+
output_tensors = [
37+
"score/fc2/BiasAdd"
38+
]
39+
converter = tf.lite.TFLiteConverter.from_frozen_graph(
40+
output_graph,
41+
input_tensors,
42+
output_tensors)
43+
converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
44+
tf.lite.OpsSet.SELECT_TF_OPS]
45+
tflite_model = converter.convert()
46+
open(output_tflite_model, "wb").write(tflite_model)
47+
48+
49+
if __name__ == "__main__":
50+
# freeze_graph("./checkpoints/textcnn/best_validation")
51+
convert_to_tflite()
52+

0 commit comments

Comments
 (0)