@@ -896,6 +896,55 @@ def test_linear(self):
896
896
x_int32 = np .random .randint (- 10 , 10 , (10 , 5 )).astype (np .int32 )
897
897
self .assertAllClose (x_int32 , activations .linear (x_int32 ))
898
898
899
+ def test_sparsemax (self ):
900
+ # result check with 1d
901
+ x_1d = np .linspace (1 , 12 , num = 12 )
902
+ expected_result = np .zeros_like (x_1d )
903
+ expected_result [- 1 ] = 1.0
904
+ self .assertAllClose (expected_result , activations .sparsemax (x_1d ))
905
+
906
+ # result check with 2d
907
+ x_2d = np .linspace (1 , 12 , num = 12 ).reshape (- 1 , 2 )
908
+ expected_result = np .zeros_like (x_2d )
909
+ expected_result [:, - 1 ] = 1.0
910
+ self .assertAllClose (expected_result , activations .sparsemax (x_2d ))
911
+
912
+ # result check with 3d
913
+ x_3d = np .linspace (1 , 12 , num = 12 ).reshape (- 1 , 1 , 3 )
914
+ expected_result = np .zeros_like (x_3d )
915
+ expected_result [:, :, - 1 ] = 1.0
916
+ self .assertAllClose (expected_result , activations .sparsemax (x_3d ))
917
+
918
+ # result check with axis=-2 with 2d input
919
+ x_2d = np .linspace (1 , 12 , num = 12 ).reshape (- 1 , 2 )
920
+ expected_result = np .zeros_like (x_2d )
921
+ expected_result [- 1 , :] = 1.0
922
+ self .assertAllClose (
923
+ expected_result , activations .sparsemax (x_2d , axis = - 2 )
924
+ )
925
+
926
+ # result check with axis=-2 with 3d input
927
+ x_3d = np .linspace (1 , 12 , num = 12 ).reshape (- 1 , 1 , 3 )
928
+ expected_result = np .ones_like (x_3d )
929
+ self .assertAllClose (
930
+ expected_result , activations .sparsemax (x_3d , axis = - 2 )
931
+ )
932
+
933
+ # result check with axis=-3 with 3d input
934
+ x_3d = np .linspace (1 , 12 , num = 12 ).reshape (- 1 , 1 , 3 )
935
+ expected_result = np .zeros_like (x_3d )
936
+ expected_result [- 1 , :, :] = 1.0
937
+ self .assertAllClose (
938
+ expected_result , activations .sparsemax (x_3d , axis = - 3 )
939
+ )
940
+
941
+ # result check with axis=-3 with 4d input
942
+ x_4d = np .linspace (1 , 12 , num = 12 ).reshape (- 1 , 1 , 1 , 2 )
943
+ expected_result = np .ones_like (x_4d )
944
+ self .assertAllClose (
945
+ expected_result , activations .sparsemax (x_4d , axis = - 3 )
946
+ )
947
+
899
948
def test_get_method (self ):
900
949
obj = activations .get ("relu" )
901
950
self .assertEqual (obj , activations .relu )
0 commit comments