
Commit fcaa672

Authored by seanpmorgan, AakashKumarNain, and failure-to-thrive
Patch 0.8.2 (#1079)
* Fix kappa (#1047)
  * fix kappa
  * add more tests and rename regression variable
  * add cross_entropy test for binary class model
  (cherry picked from commit 83df531)
* Typo in type #1067 (#1069) (cherry picked from commit 99352d0)
* Bump to 0.8.2
* Add CI to release branches

Co-authored-by: Aakash Kumar Nain <[email protected]>
Co-authored-by: failure-to-thrive <[email protected]>
1 parent a68be5c commit fcaa672

File tree

5 files changed: +191 -71 lines changed

.github/workflows/ci_test.yml

Lines changed: 2 additions & 0 deletions
@@ -4,9 +4,11 @@ on:
   push:
     branches:
       - master
+      - r*
   pull_request:
     branches:
       - master
+      - r*
 
 env:
   BAZEL_VERSION: 1.1.0
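
The new r* glob matches release branches such as r0.8, so pushes and pull requests against release branches now trigger CI alongside master. For reference, a sketch of the resulting trigger block (surrounding keys assumed from a standard workflow file):

on:
  push:
    branches:
      - master
      - r*
  pull_request:
    branches:
      - master
      - r*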

tensorflow_addons/layers/normalizations.py

Lines changed: 60 additions & 56 deletions
@@ -22,7 +22,7 @@
 from tensorflow_addons.utils import types
 
 
-@tf.keras.utils.register_keras_serializable(package='Addons')
+@tf.keras.utils.register_keras_serializable(package="Addons")
 class GroupNormalization(tf.keras.layers.Layer):
     """Group normalization layer.
 
@@ -71,19 +71,21 @@ class GroupNormalization(tf.keras.layers.Layer):
     """
 
     @typechecked
-    def __init__(self,
-                 groups: int = 2,
-                 axis: int = -1,
-                 epsilon: int = 1e-3,
-                 center: bool = True,
-                 scale: bool = True,
-                 beta_initializer: types.Initializer = 'zeros',
-                 gamma_initializer: types.Initializer = 'ones',
-                 beta_regularizer: types.Regularizer = None,
-                 gamma_regularizer: types.Regularizer = None,
-                 beta_constraint: types.Constraint = None,
-                 gamma_constraint: types.Constraint = None,
-                 **kwargs):
+    def __init__(
+        self,
+        groups: int = 2,
+        axis: int = -1,
+        epsilon: float = 1e-3,
+        center: bool = True,
+        scale: bool = True,
+        beta_initializer: types.Initializer = "zeros",
+        gamma_initializer: types.Initializer = "ones",
+        beta_regularizer: types.Regularizer = None,
+        gamma_regularizer: types.Regularizer = None,
+        beta_constraint: types.Constraint = None,
+        gamma_constraint: types.Constraint = None,
+        **kwargs
+    ):
         super().__init__(**kwargs)
         self.supports_masking = True
         self.groups = groups
@@ -117,39 +119,32 @@ def call(self, inputs):
         tensor_input_shape = tf.shape(inputs)
 
         reshaped_inputs, group_shape = self._reshape_into_groups(
-            inputs, input_shape, tensor_input_shape)
+            inputs, input_shape, tensor_input_shape
+        )
 
-        normalized_inputs = self._apply_normalization(reshaped_inputs,
-                                                      input_shape)
+        normalized_inputs = self._apply_normalization(reshaped_inputs, input_shape)
 
         outputs = tf.reshape(normalized_inputs, tensor_input_shape)
 
         return outputs
 
     def get_config(self):
         config = {
-            'groups':
-                self.groups,
-            'axis':
-                self.axis,
-            'epsilon':
-                self.epsilon,
-            'center':
-                self.center,
-            'scale':
-                self.scale,
-            'beta_initializer':
-                tf.keras.initializers.serialize(self.beta_initializer),
-            'gamma_initializer':
-                tf.keras.initializers.serialize(self.gamma_initializer),
-            'beta_regularizer':
-                tf.keras.regularizers.serialize(self.beta_regularizer),
-            'gamma_regularizer':
-                tf.keras.regularizers.serialize(self.gamma_regularizer),
-            'beta_constraint':
-                tf.keras.constraints.serialize(self.beta_constraint),
-            'gamma_constraint':
-                tf.keras.constraints.serialize(self.gamma_constraint)
+            "groups": self.groups,
+            "axis": self.axis,
+            "epsilon": self.epsilon,
+            "center": self.center,
+            "scale": self.scale,
+            "beta_initializer": tf.keras.initializers.serialize(self.beta_initializer),
+            "gamma_initializer": tf.keras.initializers.serialize(
+                self.gamma_initializer
+            ),
+            "beta_regularizer": tf.keras.regularizers.serialize(self.beta_regularizer),
+            "gamma_regularizer": tf.keras.regularizers.serialize(
+                self.gamma_regularizer
+            ),
+            "beta_constraint": tf.keras.constraints.serialize(self.beta_constraint),
+            "gamma_constraint": tf.keras.constraints.serialize(self.gamma_constraint),
         }
         base_config = super().get_config()
         return {**base_config, **config}
@@ -174,7 +169,8 @@ def _apply_normalization(self, reshaped_inputs, input_shape):
         group_reduction_axes.pop(axis)
 
         mean, variance = tf.nn.moments(
-            reshaped_inputs, group_reduction_axes, keepdims=True)
+            reshaped_inputs, group_reduction_axes, keepdims=True
+        )
 
         gamma, beta = self._get_reshaped_weights(input_shape)
         normalized_inputs = tf.nn.batch_normalization(
@@ -183,7 +179,8 @@ def _apply_normalization(self, reshaped_inputs, input_shape):
             variance=variance,
             scale=gamma,
             offset=beta,
-            variance_epsilon=self.epsilon)
+            variance_epsilon=self.epsilon,
+        )
         return normalized_inputs
 
     def _get_reshaped_weights(self, input_shape):
@@ -200,10 +197,11 @@ def _get_reshaped_weights(self, input_shape):
     def _check_if_input_shape_is_none(self, input_shape):
         dim = input_shape[self.axis]
         if dim is None:
-            raise ValueError('Axis ' + str(self.axis) + ' of '
-                             'input tensor should have a defined dimension '
-                             'but the layer received an input with shape ' +
-                             str(input_shape) + '.')
+            raise ValueError(
+                "Axis " + str(self.axis) + " of "
+                "input tensor should have a defined dimension "
+                "but the layer received an input with shape " + str(input_shape) + "."
+            )
 
     def _set_number_of_groups_for_instance_norm(self, input_shape):
         dim = input_shape[self.axis]
@@ -216,26 +214,30 @@ def _check_size_of_dimensions(self, input_shape):
         dim = input_shape[self.axis]
         if dim < self.groups:
             raise ValueError(
-                'Number of groups (' + str(self.groups) + ') cannot be '
-                'more than the number of channels (' + str(dim) + ').')
+                "Number of groups (" + str(self.groups) + ") cannot be "
+                "more than the number of channels (" + str(dim) + ")."
+            )
 
         if dim % self.groups != 0:
             raise ValueError(
-                'Number of groups (' + str(self.groups) + ') must be a '
-                'multiple of the number of channels (' + str(dim) + ').')
+                "Number of groups (" + str(self.groups) + ") must be a "
+                "multiple of the number of channels (" + str(dim) + ")."
+            )
 
     def _check_axis(self):
 
         if self.axis == 0:
             raise ValueError(
                 "You are trying to normalize your batch axis. Do you want to "
-                "use tf.layer.batch_normalization instead")
+                "use tf.layer.batch_normalization instead"
+            )
 
     def _create_input_spec(self, input_shape):
 
         dim = input_shape[self.axis]
         self.input_spec = tf.keras.layers.InputSpec(
-            ndim=len(input_shape), axes={self.axis: dim})
+            ndim=len(input_shape), axes={self.axis: dim}
+        )
 
     def _add_gamma_weight(self, input_shape):
 
@@ -245,10 +247,11 @@ def _add_gamma_weight(self, input_shape):
         if self.scale:
             self.gamma = self.add_weight(
                 shape=shape,
-                name='gamma',
+                name="gamma",
                 initializer=self.gamma_initializer,
                 regularizer=self.gamma_regularizer,
-                constraint=self.gamma_constraint)
+                constraint=self.gamma_constraint,
+            )
         else:
             self.gamma = None
 
@@ -260,10 +263,11 @@ def _add_beta_weight(self, input_shape):
         if self.center:
             self.beta = self.add_weight(
                 shape=shape,
-                name='beta',
+                name="beta",
                 initializer=self.beta_initializer,
                 regularizer=self.beta_regularizer,
-                constraint=self.beta_constraint)
+                constraint=self.beta_constraint,
+            )
         else:
             self.beta = None
 
@@ -274,7 +278,7 @@ def _create_broadcast_shape(self, input_shape):
         return broadcast_shape
 
 
-@tf.keras.utils.register_keras_serializable(package='Addons')
+@tf.keras.utils.register_keras_serializable(package="Addons")
 class InstanceNormalization(GroupNormalization):
     """Instance normalization layer.
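
Since this diff reworks both the constructor signature and get_config, a minimal usage sketch may help; the shapes and argument values below are illustrative, not taken from this patch:

import tensorflow as tf
import tensorflow_addons as tfa

# Illustrative input: a batch of 8 feature maps with 16 channels.
x = tf.random.normal((8, 32, 32, 16))

# 4 groups of 4 channels each; epsilon is a float, matching the
# type annotation fixed by this patch.
gn = tfa.layers.GroupNormalization(groups=4, axis=-1, epsilon=1e-3)
y = gn(x)

# The rewritten get_config() still supports a serialization round-trip.
restored = tfa.layers.GroupNormalization.from_config(gn.get_config())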

tensorflow_addons/metrics/cohens_kappa.py

Lines changed: 51 additions & 11 deletions
@@ -67,18 +67,26 @@ def __init__(self,
                  num_classes: FloatTensorLike,
                  name: str = 'cohen_kappa',
                  weightage: Optional[str] = None,
+                 sparse_labels: bool = False,
+                 regression: bool = False,
                  dtype: AcceptableDTypes = None,
                  **kwargs):
         """Creates a `CohenKappa` instance.
 
         Args:
           num_classes: Number of unique classes in your dataset.
-          name: (Optional) String name of the metric instance.
-          weightage: (Optional) Weighting to be considered for calculating
+          weightage: (optional) Weighting to be considered for calculating
             kappa statistics. A valid value is one of
-            [None, 'linear', 'quadratic']. Defaults to `None`.
-          dtype: (Optional) Data type of the metric result.
-            Defaults to `None`.
+            [None, 'linear', 'quadratic']. Defaults to `None`
+          sparse_labels: (bool) Valid only for multi-class scenario.
+            If True, ground truth labels are expected to be integers
+            and not one-hot encoded
+          regression: (bool) If set, that means the problem is being treated
+            as a regression problem where you are regressing the predictions.
+            **Note:** If you are regressing for the values, the output layer
+            should contain a single unit.
+          name: (optional) String name of the metric instance
+          dtype: (optional) Data type of the metric result. Defaults to `None`
 
         Raises:
           ValueError: If the value passed for `weightage` is invalid
@@ -89,8 +97,18 @@ def __init__(self,
         if weightage not in (None, 'linear', 'quadratic'):
             raise ValueError("Unknown kappa weighting type.")
 
+        if num_classes == 2:
+            self._update = self._update_binary_class_model
+        elif num_classes > 2:
+            self._update = self._update_multi_class_model
+        else:
+            raise ValueError("""Number of classes must be
+                              greater than or equal to two""")
+
         self.weightage = weightage
         self.num_classes = num_classes
+        self.regression = regression
+        self.sparse_labels = sparse_labels
         self.conf_mtx = self.add_weight(
             'conf_mtx',
             shape=(self.num_classes, self.num_classes),
@@ -114,22 +132,42 @@ def update_state(self, y_true, y_pred, sample_weight=None):
         Returns:
           Update op.
         """
+        return self._update(y_true, y_pred, sample_weight)
+
+    def _update_binary_class_model(self, y_true, y_pred, sample_weight=None):
         y_true = tf.cast(y_true, dtype=tf.int64)
-        y_pred = tf.cast(y_pred, dtype=tf.int64)
+        y_pred = tf.cast(y_pred, dtype=tf.float32)
+        y_pred = tf.cast(y_pred > 0.5, dtype=tf.int64)
+        return self._update_confusion_matrix(y_true, y_pred, sample_weight)
+
+    def _update_multi_class_model(self, y_true, y_pred, sample_weight=None):
+        if not self.sparse_labels:
+            y_true = tf.cast(tf.argmax(y_true, axis=-1), dtype=tf.int64)
+        else:
+            y_true = tf.cast(y_true, dtype=tf.int64)
+
+        if tf.rank(y_pred) > 1:
+            if not self.regression:
+                y_pred = tf.cast(tf.argmax(y_pred, axis=-1), dtype=tf.int64)
+            else:
+                y_pred = tf.math.round(tf.math.abs(y_pred))
+                y_pred = tf.cast(y_pred, dtype=tf.int64)
+        else:
+            y_pred = tf.cast(y_pred, dtype=tf.int64)
+
+        return self._update_confusion_matrix(y_true, y_pred, sample_weight)
 
-        if y_true.shape != y_pred.shape:
-            raise ValueError(
-                "Number of samples in `y_true` and `y_pred` are different")
+    def _update_confusion_matrix(self, y_true, y_pred, sample_weight):
+        y_true = tf.squeeze(y_true)
+        y_pred = tf.squeeze(y_pred)
 
-        # compute the new values of the confusion matrix
         new_conf_mtx = tf.math.confusion_matrix(
             labels=y_true,
             predictions=y_pred,
             num_classes=self.num_classes,
             weights=sample_weight,
             dtype=tf.float32)
 
-        # update the values in the original confusion matrix
         return self.conf_mtx.assign_add(new_conf_mtx)
 
     def result(self):
@@ -179,6 +217,8 @@ def get_config(self):
         config = {
             "num_classes": self.num_classes,
             "weightage": self.weightage,
+            "sparse_labels": self.sparse_labels,
+            "regression": self.regression
         }
         base_config = super().get_config()
         return {**base_config, **config}
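
A short sketch of the two update paths introduced by this patch; the inputs are illustrative and assume the tfa.metrics.CohenKappa API exactly as shown above:

import tensorflow as tf
import tensorflow_addons as tfa

# Binary model (num_classes == 2): raw probabilities are thresholded
# at 0.5 internally before updating the confusion matrix.
kappa_bin = tfa.metrics.CohenKappa(num_classes=2)
kappa_bin.update_state(tf.constant([1, 0, 1, 1]),
                       tf.constant([0.9, 0.2, 0.4, 0.7]))
print(kappa_bin.result().numpy())

# Multi-class model with integer (sparse) labels; softmax outputs are
# argmax-ed internally because regression defaults to False.
kappa_multi = tfa.metrics.CohenKappa(num_classes=3, sparse_labels=True)
kappa_multi.update_state(tf.constant([0, 2, 1]),
                         tf.constant([[0.8, 0.1, 0.1],
                                      [0.2, 0.2, 0.6],
                                      [0.3, 0.5, 0.2]]))
print(kappa_multi.result().numpy())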
