Commit a3a368d

Addition of Sparsemax activation (keras-team#20558)
* add: sparsemax ops
* add: sparsemax api references to inits
* add: sparsemax tests
* edit: changes after test
* edit: test case
* rename: function in numpy
* add: pointers to rest inits
* edit: docstrings
* change: x to logits in docstring

1 parent 75522e4 commit a3a368d

File tree

16 files changed: +217 −1 lines changed

.gitignore

+2 −1

```diff
@@ -18,4 +18,5 @@ dist/**
 examples/**/*.jpg
 .python-version
 .coverage
-*coverage.xml
+*coverage.xml
+.ruff_cache
```

keras/api/_tf_keras/keras/activations/__init__.py

+1

```diff
@@ -33,6 +33,7 @@
 from keras.src.activations.activations import softplus
 from keras.src.activations.activations import softsign
 from keras.src.activations.activations import sparse_plus
+from keras.src.activations.activations import sparsemax
 from keras.src.activations.activations import squareplus
 from keras.src.activations.activations import tanh
 from keras.src.activations.activations import tanh_shrink
```

keras/api/_tf_keras/keras/ops/__init__.py

+1

```diff
@@ -100,6 +100,7 @@
 from keras.src.ops.nn import softsign
 from keras.src.ops.nn import sparse_categorical_crossentropy
 from keras.src.ops.nn import sparse_plus
+from keras.src.ops.nn import sparsemax
 from keras.src.ops.nn import squareplus
 from keras.src.ops.nn import tanh_shrink
 from keras.src.ops.numpy import abs
```

keras/api/_tf_keras/keras/ops/nn/__init__.py

+1

```diff
@@ -45,5 +45,6 @@
 from keras.src.ops.nn import softsign
 from keras.src.ops.nn import sparse_categorical_crossentropy
 from keras.src.ops.nn import sparse_plus
+from keras.src.ops.nn import sparsemax
 from keras.src.ops.nn import squareplus
 from keras.src.ops.nn import tanh_shrink
```

keras/api/activations/__init__.py

+1

```diff
@@ -33,6 +33,7 @@
 from keras.src.activations.activations import softplus
 from keras.src.activations.activations import softsign
 from keras.src.activations.activations import sparse_plus
+from keras.src.activations.activations import sparsemax
 from keras.src.activations.activations import squareplus
 from keras.src.activations.activations import tanh
 from keras.src.activations.activations import tanh_shrink
```

keras/api/ops/__init__.py

+1

```diff
@@ -100,6 +100,7 @@
 from keras.src.ops.nn import softsign
 from keras.src.ops.nn import sparse_categorical_crossentropy
 from keras.src.ops.nn import sparse_plus
+from keras.src.ops.nn import sparsemax
 from keras.src.ops.nn import squareplus
 from keras.src.ops.nn import tanh_shrink
 from keras.src.ops.numpy import abs
```

keras/api/ops/nn/__init__.py

+1

```diff
@@ -45,5 +45,6 @@
 from keras.src.ops.nn import softsign
 from keras.src.ops.nn import sparse_categorical_crossentropy
 from keras.src.ops.nn import sparse_plus
+from keras.src.ops.nn import sparsemax
 from keras.src.ops.nn import squareplus
 from keras.src.ops.nn import tanh_shrink
```

keras/src/activations/__init__.py

+2

```diff
@@ -24,6 +24,7 @@
 from keras.src.activations.activations import softplus
 from keras.src.activations.activations import softsign
 from keras.src.activations.activations import sparse_plus
+from keras.src.activations.activations import sparsemax
 from keras.src.activations.activations import squareplus
 from keras.src.activations.activations import tanh
 from keras.src.activations.activations import tanh_shrink
@@ -59,6 +60,7 @@
     mish,
     log_softmax,
     log_sigmoid,
+    sparsemax,
 }
 
 ALL_OBJECTS_DICT = {fn.__name__: fn for fn in ALL_OBJECTS}
```

keras/src/activations/activations.py

+25

```diff
@@ -617,3 +617,28 @@ def log_softmax(x, axis=-1):
         axis: Integer, axis along which the softmax is applied.
     """
     return ops.log_softmax(x, axis=axis)
+
+
+@keras_export(["keras.activations.sparsemax"])
+def sparsemax(x, axis=-1):
+    """Sparsemax activation function.
+
+    For each batch `i`, and class `j`,
+    sparsemax activation function is defined as:
+
+    `sparsemax(x)[i, j] = max(x[i, j] - τ(x[i, :]), 0).`
+
+    Args:
+        x: Input tensor.
+        axis: `int`, axis along which the sparsemax operation is applied.
+
+    Returns:
+        A tensor, output of sparsemax transformation. Has the same type and
+        shape as `x`.
+
+    Reference:
+
+    - [Martins et.al., 2016](https://arxiv.org/abs/1602.02068)
+    """
+    x = backend.convert_to_tensor(x)
+    return ops.sparsemax(x, axis)
```

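To make the `τ(x[i, :])` threshold in the docstring concrete: sparsemax projects the logits onto the probability simplex, so the output is non-negative, sums to 1, and can contain exact zeros, unlike softmax. Below is a minimal standalone NumPy sketch of the 1-D projection from Martins et al., 2016; it is illustrative only, not the Keras implementation, and the helper name `sparsemax_1d` is invented for this example.

```python
import numpy as np

def sparsemax_1d(z):
    # Sort descending and take the running sum.
    z_sorted = np.sort(z)[::-1]
    cumsum = np.cumsum(z_sorted)
    r = np.arange(1, z.size + 1)
    # Support size: the number of positions with 1 + k * z_(k) > cumsum_k.
    k = np.sum(1 + r * z_sorted > cumsum)
    # The threshold tau uses the cumulative sum at the last supported position.
    tau = (cumsum[k - 1] - 1) / k
    return np.maximum(z - tau, 0.0)

print(sparsemax_1d(np.array([-1.0, 0.0, 1.0])))  # [0. 0. 1.]
```

The printed result matches the example given in the `keras.ops.sparsemax` docstring further down this page.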
keras/src/activations/activations_test.py

+49

```diff
@@ -896,6 +896,55 @@ def test_linear(self):
         x_int32 = np.random.randint(-10, 10, (10, 5)).astype(np.int32)
         self.assertAllClose(x_int32, activations.linear(x_int32))
 
+    def test_sparsemax(self):
+        # result check with 1d
+        x_1d = np.linspace(1, 12, num=12)
+        expected_result = np.zeros_like(x_1d)
+        expected_result[-1] = 1.0
+        self.assertAllClose(expected_result, activations.sparsemax(x_1d))
+
+        # result check with 2d
+        x_2d = np.linspace(1, 12, num=12).reshape(-1, 2)
+        expected_result = np.zeros_like(x_2d)
+        expected_result[:, -1] = 1.0
+        self.assertAllClose(expected_result, activations.sparsemax(x_2d))
+
+        # result check with 3d
+        x_3d = np.linspace(1, 12, num=12).reshape(-1, 1, 3)
+        expected_result = np.zeros_like(x_3d)
+        expected_result[:, :, -1] = 1.0
+        self.assertAllClose(expected_result, activations.sparsemax(x_3d))
+
+        # result check with axis=-2 with 2d input
+        x_2d = np.linspace(1, 12, num=12).reshape(-1, 2)
+        expected_result = np.zeros_like(x_2d)
+        expected_result[-1, :] = 1.0
+        self.assertAllClose(
+            expected_result, activations.sparsemax(x_2d, axis=-2)
+        )
+
+        # result check with axis=-2 with 3d input
+        x_3d = np.linspace(1, 12, num=12).reshape(-1, 1, 3)
+        expected_result = np.ones_like(x_3d)
+        self.assertAllClose(
+            expected_result, activations.sparsemax(x_3d, axis=-2)
+        )
+
+        # result check with axis=-3 with 3d input
+        x_3d = np.linspace(1, 12, num=12).reshape(-1, 1, 3)
+        expected_result = np.zeros_like(x_3d)
+        expected_result[-1, :, :] = 1.0
+        self.assertAllClose(
+            expected_result, activations.sparsemax(x_3d, axis=-3)
+        )
+
+        # result check with axis=-3 with 4d input
+        x_4d = np.linspace(1, 12, num=12).reshape(-1, 1, 1, 2)
+        expected_result = np.ones_like(x_4d)
+        self.assertAllClose(
+            expected_result, activations.sparsemax(x_4d, axis=-3)
+        )
+
     def test_get_method(self):
         obj = activations.get("relu")
         self.assertEqual(obj, activations.relu)
```

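A note on why every expected value above is one-hot (or all ones): with `np.linspace(1, 12, num=12)`, consecutive entries along the reduced axis differ by exactly 1, and the second-largest element belongs to the support only when the top gap is strictly less than 1, so the support collapses to the argmax alone. In the `axis=-2` 3-D case and the `axis=-3` 4-D case the reduced axis has size 1, and sparsemax of a singleton is trivially 1.0. Here is a self-contained check of the 2-D case using the canonical projection (the reference helper `_sparsemax_ref` is hypothetical, not part of the commit):

```python
import numpy as np

def _sparsemax_ref(z):
    # Canonical 1-D sparsemax projection (Martins et al., 2016).
    z_sorted = np.sort(z)[::-1]
    cumsum = np.cumsum(z_sorted)
    r = np.arange(1, z.size + 1)
    k = np.sum(1 + r * z_sorted > cumsum)
    tau = (cumsum[k - 1] - 1) / k
    return np.maximum(z - tau, 0.0)

x = np.linspace(1, 12, num=12).reshape(-1, 2)
# Each row is [2m - 1, 2m]: a top gap of exactly 1, so only the argmax
# is supported and every row projects to the one-hot vector [0, 1].
print(np.apply_along_axis(_sparsemax_ref, -1, x))
```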
keras/src/backend/jax/nn.py

+18

```diff
@@ -142,6 +142,24 @@ def log_softmax(x, axis=-1):
     return jnn.log_softmax(x, axis=axis)
 
 
+def sparsemax(logits, axis=-1):
+    # Sort logits along the specified axis in descending order
+    logits = convert_to_tensor(logits)
+    logits_sorted = -1.0 * jnp.sort(logits * -1.0, axis=axis)
+    logits_cumsum = jnp.cumsum(logits_sorted, axis=axis)  # find cumulative sum
+    r = jnp.arange(1, logits.shape[axis] + 1)  # Determine the sparsity
+    r_shape = [1] * logits.ndim
+    r_shape[axis] = -1  # Broadcast to match the target axis
+    r = r.reshape(r_shape)
+    support = logits_sorted - (logits_cumsum - 1) / r > 0
+    # Find the threshold
+    k = jnp.sum(support, axis=axis, keepdims=True)
+    logits_cumsum_safe = jnp.where(support, logits_cumsum, 0.0)
+    tau = (jnp.sum(logits_cumsum_safe, axis=axis, keepdims=True) - 1) / k
+    output = jnp.maximum(logits - tau, 0.0)
+    return output
+
+
 def _convert_to_spatial_operand(
     x,
     num_spatial_dims,
```

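One detail worth calling out in this JAX branch: `r` starts as a 1-D vector `[1, ..., n]`, and the `r_shape[axis] = -1` reshape places that length-`n` dimension at the reduction axis so it broadcasts against `logits_sorted` for any `axis`, not just the last one. A minimal sketch of just that trick, with names mirroring the diff:

```python
import numpy as np

logits = np.zeros((4, 5, 3))
axis = 1  # works for any axis, not only -1
r = np.arange(1, logits.shape[axis] + 1)  # shape (5,)
r_shape = [1] * logits.ndim
r_shape[axis] = -1                        # -> [1, -1, 1]
r = r.reshape(r_shape)                    # shape (1, 5, 1)
print((logits + r).shape)                 # (4, 5, 3): broadcasts cleanly
```

The NumPy backend below is line-for-line the same construction with `np` in place of `jnp`.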
keras/src/backend/numpy/nn.py

+18

```diff
@@ -191,6 +191,24 @@ def log_softmax(x, axis=None):
     return x - max_x - logsumexp
 
 
+def sparsemax(logits, axis=-1):
+    # Sort logits along the specified axis in descending order
+    logits = convert_to_tensor(logits)
+    logits_sorted = -1.0 * np.sort(-1.0 * logits, axis=axis)
+    logits_cumsum = np.cumsum(logits_sorted, axis=axis)
+    r = np.arange(1, logits.shape[axis] + 1)
+    r_shape = [1] * logits.ndim
+    r_shape[axis] = -1  # Broadcast to match the target axis
+    r = r.reshape(r_shape)
+    support = logits_sorted - (logits_cumsum - 1) / r > 0
+    # Find the threshold
+    k = np.sum(support, axis=axis, keepdims=True)
+    logits_cumsum_safe = np.where(support, logits_cumsum, 0.0)
+    tau = (np.sum(logits_cumsum_safe, axis=axis, keepdims=True) - 1) / k
+    output = np.maximum(logits - tau, 0.0)
+    return output
+
+
 def _convert_to_spatial_operand(
     x,
     num_spatial_dims,
```

keras/src/backend/tensorflow/nn.py

+18

```diff
@@ -151,6 +151,24 @@ def log_softmax(x, axis=-1):
     return tf.nn.log_softmax(x, axis=axis)
 
 
+def sparsemax(logits, axis=-1):
+    # Sort logits along the specified axis in descending order
+    logits = convert_to_tensor(logits)
+    logits_sorted = tf.sort(logits, direction="DESCENDING", axis=axis)
+    logits_cumsum = tf.cumsum(logits_sorted, axis=axis)
+    r = tf.range(1, tf.shape(logits)[axis] + 1, dtype=logits.dtype)
+    r_shape = [1] * len(logits.shape)
+    r_shape[axis] = -1  # Broadcast to match the target axis
+    r = tf.reshape(r, r_shape)  # Reshape for broadcasting
+    support = logits_sorted - (logits_cumsum - 1) / r > 0
+    # Find the threshold
+    logits_cumsum_safe = tf.where(support, logits_cumsum, 0.0)
+    k = tf.reduce_sum(tf.cast(support, logits.dtype), axis=axis, keepdims=True)
+    tau = (tf.reduce_sum(logits_cumsum_safe, axis=axis, keepdims=True) - 1) / k
+    output = tf.maximum(logits - tau, 0.0)
+    return output
+
+
 def _transpose_spatial_inputs(inputs):
     num_spatial_dims = len(inputs.shape) - 2
     # Tensorflow pooling does not support `channels_first` format, so
```

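Unlike the JAX and NumPy branches, which leave `r` as an integer vector and rely on dtype promotion, this branch builds `r` with `dtype=logits.dtype`. TensorFlow does not silently promote mixed integer/float operands, so dividing a float `logits_cumsum` by an integer `r` would raise a dtype-mismatch error. A small illustration, assuming a TensorFlow install:

```python
import tensorflow as tf

cumsum = tf.constant([1.0, 3.0, 6.0])
# r_int = tf.range(1, 4) would be int32; (cumsum - 1) / r_int fails in TF.
r = tf.range(1, 4, dtype=cumsum.dtype)  # float32, matching the diff
print((cumsum - 1) / r)  # tf.Tensor([0. 1. 1.6666666], shape=(3,), ...)
```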
keras/src/backend/torch/nn.py

+22

```diff
@@ -174,6 +174,28 @@ def log_softmax(x, axis=-1):
     return cast(output, dtype)
 
 
+def sparsemax(logits, axis=-1):
+    # Sort logits along the specified axis in descending order
+    logits = convert_to_tensor(logits)
+    logits_sorted, _ = torch.sort(logits, dim=axis, descending=True)
+    logits_cumsum = torch.cumsum(logits_sorted, dim=axis)
+    r = torch.arange(
+        1, logits.size(axis) + 1, device=logits.device, dtype=logits.dtype
+    )
+    r_shape = [1] * logits.ndim
+    r_shape[axis] = -1  # Broadcast to match the target axis
+    r = r.view(r_shape)
+    support = logits_sorted - (logits_cumsum - 1) / r > 0
+    # Find the threshold
+    k = torch.sum(support, dim=axis, keepdim=True)
+    logits_cumsum_safe = torch.where(
+        support, logits_cumsum, torch.tensor(0.0, device=logits.device)
+    )
+    tau = (torch.sum(logits_cumsum_safe, dim=axis, keepdim=True) - 1) / k
+    output = torch.clamp(logits - tau, min=0.0)
+    return output
+
+
 def _compute_padding_length(
     input_length, kernel_length, stride, dilation_rate=1
 ):
```

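Two backend-specific wrinkles in the Torch branch: `torch.sort` returns a `(values, indices)` pair rather than a bare tensor, hence the `logits_sorted, _ =` unpacking, and the `torch.where` fill value is materialized as a tensor on the logits' device so the op does not mix devices. A quick illustration of the first point:

```python
import torch

values, indices = torch.sort(torch.tensor([3.0, 1.0, 2.0]), descending=True)
print(values)   # tensor([3., 2., 1.])
print(indices)  # tensor([0, 2, 1]), the original positions of the sorted values
```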
keras/src/ops/nn.py

+42

```diff
@@ -951,6 +951,42 @@ def log_softmax(x, axis=-1):
     return backend.nn.log_softmax(x, axis=axis)
 
 
+class Sparsemax(Operation):
+    def __init__(self, axis=-1):
+        super().__init__()
+        self.axis = axis
+
+    def call(self, x):
+        return backend.nn.sparsemax(x, axis=self.axis)
+
+    def compute_output_spec(self, x):
+        return KerasTensor(x.shape, dtype=x.dtype)
+
+
+@keras_export(["keras.ops.sparsemax", "keras.ops.nn.sparsemax"])
+def sparsemax(x, axis=-1):
+    """Sparsemax activation function.
+
+    For each batch `i`, and class `j`,
+    sparsemax activation function is defined as:
+
+    `sparsemax(x)[i, j] = max(x[i, j] - τ(x[i, :]), 0).`
+
+    Args:
+        x: Input tensor.
+        axis: `int`, axis along which the sparsemax operation is applied.
+
+    Returns:
+        A tensor, output of sparsemax transformation. Has the same type and
+        shape as `x`.
+
+    Example:
+
+    >>> x = np.array([-1., 0., 1.])
+    >>> x_sparsemax = keras.ops.sparsemax(x)
+    >>> print(x_sparsemax)
+    array([0., 0., 1.], shape=(3,), dtype=float64)
+
+    """
+    if any_symbolic_tensors((x,)):
+        return Sparsemax(axis).symbolic_call(x)
+    return backend.nn.sparsemax(x, axis=axis)
+
+
 class MaxPool(Operation):
     def __init__(
         self,
```

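The `Sparsemax` class plus the `any_symbolic_tensors` guard is the standard Keras 3 ops pattern: symbolic `KerasTensor` inputs go through `symbolic_call`, which only runs `compute_output_spec` for shape and dtype inference, while eager inputs are dispatched straight to the backend kernel. A usage sketch, assuming a Keras build that contains this commit (the eager output is the docstring's own example):

```python
import numpy as np
from keras import KerasTensor, ops

# Eager path: concrete values from the backend kernel.
print(ops.sparsemax(np.array([-1.0, 0.0, 1.0])))  # [0., 0., 1.]

# Symbolic path: shape inference only, no computation happens.
sym = ops.sparsemax(KerasTensor((None, 2, 3)))
print(sym.shape)  # (None, 2, 3), via Sparsemax.compute_output_spec
```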
keras/src/ops/nn_test.py

+15

```diff
@@ -200,6 +200,10 @@ def test_log_softmax(self):
         self.assertEqual(knn.log_softmax(x, axis=1).shape, (None, 2, 3))
         self.assertEqual(knn.log_softmax(x, axis=-1).shape, (None, 2, 3))
 
+    def test_sparsemax(self):
+        x = KerasTensor([None, 2, 3])
+        self.assertEqual(knn.sparsemax(x).shape, (None, 2, 3))
+
     def test_max_pool(self):
         data_format = backend.config.image_data_format()
         if data_format == "channels_last":
@@ -861,6 +865,10 @@ def test_log_softmax(self):
         self.assertEqual(knn.log_softmax(x, axis=1).shape, (1, 2, 3))
         self.assertEqual(knn.log_softmax(x, axis=-1).shape, (1, 2, 3))
 
+    def test_sparsemax(self):
+        x = KerasTensor([1, 2, 3])
+        self.assertEqual(knn.sparsemax(x).shape, (1, 2, 3))
+
     def test_max_pool(self):
         data_format = backend.config.image_data_format()
         if data_format == "channels_last":
@@ -1487,6 +1495,13 @@ def test_log_softmax_correctness_with_axis_tuple(self):
         )
         self.assertAllClose(normalized_sum_by_axis, 1.0)
 
+    def test_sparsemax(self):
+        x = np.array([-0.5, 0, 1, 2, 3], dtype=np.float32)
+        self.assertAllClose(
+            knn.sparsemax(x),
+            [0.0, 0.0, 0.0, 0.0, 1.0],
+        )
+
     def test_max_pool(self):
         data_format = backend.config.image_data_format()
         # Test 1D max pooling.
```
