1
1
# coding: utf-8
2
+ from functools import partial
2
3
3
4
import tensorflow as tf
4
5
@@ -24,6 +25,8 @@ class TCNNConfig(object):
24
25
print_per_batch = 100 # 每多少轮输出一次结果
25
26
save_per_batch = 10 # 每多少轮存入tensorboard
26
27
28
+ scale = 0.01
29
+
27
30
28
31
class TextCNN (object ):
29
32
"""文本分类,CNN模型"""
@@ -40,6 +43,11 @@ def __init__(self, config):
40
43
41
44
def cnn (self ):
42
45
"""CNN模型"""
46
+ my_dense_layer = partial (
47
+ tf .layers .dense , activation = tf .nn .relu ,
48
+ # 在这里传入了L2正则化函数,并在函数中传入正则化系数。
49
+ kernel_regularizer = tf .contrib .layers .l2_regularizer (self .config .scale )
50
+ )
43
51
# 词向量映射
44
52
with tf .device ('/cpu:0' ):
45
53
embedding = tf .get_variable ('embedding' , [self .config .vocab_size , self .config .embedding_dim ])
@@ -52,10 +60,11 @@ def cnn(self):
52
60
gmp = tf .reduce_max (conv , reduction_indices = [1 ], name = 'gmp' )
53
61
54
62
with tf .name_scope ("score" ):
55
- # 全连接层,后面接dropout以及relu激活
56
- fc = tf .layers .dense (gmp , self .config .hidden_dim , name = 'fc1' )
57
- fc = tf .contrib .layers .dropout (fc , self .keep_prob )
58
- fc = tf .nn .relu (fc )
63
+ # 全连接层
64
+ fc = my_dense_layer (gmp , self .config .hidden_dim , name = 'fc1' )
65
+ # fc = tf.layers.dense(gmp, self.config.hidden_dim, name='fc1')
66
+ # fc = tf.contrib.layers.dropout(fc, self.keep_prob)
67
+ # fc = tf.nn.relu(fc)
59
68
60
69
# 分类器
61
70
self .logits = tf .layers .dense (fc , self .config .num_classes , name = 'fc2' )
@@ -64,7 +73,9 @@ def cnn(self):
64
73
with tf .name_scope ("optimize" ):
65
74
# 损失函数,交叉熵
66
75
cross_entropy = tf .nn .softmax_cross_entropy_with_logits (logits = self .logits , labels = self .input_y )
67
- self .loss = tf .reduce_mean (cross_entropy )
76
+ reg_losses = tf .get_collection (tf .GraphKeys .REGULARIZATION_LOSSES )
77
+ self .loss = tf .add_n ([tf .reduce_mean (cross_entropy )] + reg_losses )
78
+ # self.loss = tf.reduce_mean(cross_entropy)
68
79
# 优化器
69
80
self .optim = tf .train .AdamOptimizer (learning_rate = self .config .learning_rate ).minimize (self .loss )
70
81
0 commit comments