Skip to content

Commit 06b903c

Browse files
authored
Update cost.py
fix some typo, including: line 25: tf.cost.cross_entropy -> tl.cost.cross_entropy line 44: `x = ` -> `x = output`, `z = targets` -> `z = target`, logistic loss -> binary cross entropy loss line 81: mse = tf.reduce_sum(tf.squared_difference(output, target), reduction_indices = 1) -> mse = tf.reduce_mean(tf.reduce_sum(tf.squared_difference(output, target), reduction_indices = 1)) (# The original one is 'square error'; using variable name 'mse' was not so proper :D) line 227: that apply L1 regularization -> that apply Li regularization (# In fact, there is L2 regularization.)
1 parent c5a18ca commit 06b903c

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

tensorlayer/cost.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def cross_entropy(output, target, name="cross_entropy_loss"):
2222
2323
Examples
2424
--------
25-
>>> ce = tf.cost.cross_entropy(y_logits, y_target_logits)
25+
>>> ce = tl.cost.cross_entropy(y_logits, y_target_logits)
2626
2727
References
2828
-----------
@@ -41,7 +41,7 @@ def cross_entropy(output, target, name="cross_entropy_loss"):
4141
def binary_cross_entropy(output, target, epsilon=1e-8, name='bce_loss'):
4242
"""Computes binary cross entropy given `output`.
4343
44-
For brevity, let `x = `, `z = targets`. The logistic loss is
44+
For brevity, let `x = output`, `z = target`. The binary cross entropy loss is
4545
4646
loss(x, z) = - sum_i (x[i] * log(z[i]) + (1 - x[i]) * log(1 - z[i]))
4747
@@ -78,8 +78,9 @@ def mean_squared_error(output, target):
7878
A distribution with shape: [batch_size, n_feature].
7979
"""
8080
with tf.name_scope("mean_squared_error_loss"):
81-
mse = tf.reduce_sum(tf.squared_difference(output, target), reduction_indices = 1)
82-
return tf.reduce_mean(mse)
81+
mse = tf.reduce_mean(tf.reduce_sum(tf.squared_difference(output, target),
82+
reduction_indices = 1))
83+
return mse
8384

8485

8586

@@ -223,7 +224,7 @@ def li_regularizer(scale):
223224
224225
Returns
225226
--------
226-
A function with signature `li(weights, name=None)` that apply L1 regularization.
227+
A function with signature `li(weights, name=None)` that apply Li regularization.
227228
228229
Raises
229230
------

0 commit comments

Comments
 (0)