2323
2424def test_config ():
2525 # mcc object
26- mcc1 = MatthewsCorrelationCoefficient (num_classes = 1 )
27- assert mcc1 .num_classes == 1
26+ mcc1 = MatthewsCorrelationCoefficient (num_classes = 2 )
27+ assert mcc1 .num_classes == 2
2828 assert mcc1 .dtype == tf .float32
2929 # check configure
3030 mcc2 = MatthewsCorrelationCoefficient .from_config (mcc1 .get_config ())
31- assert mcc2 .num_classes == 1
31+ assert mcc2 .num_classes == 2
3232 assert mcc2 .dtype == tf .float32
3333
3434
3535def check_results (obj , value ):
3636 np .testing .assert_allclose (value , obj .result ().numpy (), atol = 1e-6 )
3737
3838
39+ def test_binary_classes_sparse ():
40+ gt_label = tf .constant ([[1.0 ], [1.0 ], [1.0 ], [0.0 ]], dtype = tf .float32 )
41+ preds = tf .constant ([[1.0 ], [0.0 ], [1.0 ], [1.0 ]], dtype = tf .float32 )
42+ # Initialize
43+ mcc = MatthewsCorrelationCoefficient (1 )
44+ # Update
45+ mcc .update_state (gt_label , preds )
46+ # Check results
47+ check_results (mcc , [- 0.33333334 ])
48+
49+
3950def test_binary_classes ():
4051 gt_label = tf .constant (
4152 [[0.0 , 1.0 ], [0.0 , 1.0 ], [0.0 , 1.0 ], [1.0 , 0.0 ]], dtype = tf .float32
@@ -91,6 +102,16 @@ def test_multiple_classes():
91102 sklearn_result = sklearn_matthew (gt_label .argmax (axis = 1 ), preds .argmax (axis = 1 ))
92103 check_results (mcc , sklearn_result )
93104
105+ gt_label_sparse = tf .constant (
106+ [[0.0 ], [2.0 ], [0.0 ], [2.0 ], [1.0 ], [1.0 ], [0.0 ], [0.0 ], [2.0 ], [1.0 ]]
107+ )
108+ preds_sparse = tf .constant (
109+ [[2.0 ], [0.0 ], [2.0 ], [2.0 ], [2.0 ], [2.0 ], [2.0 ], [0.0 ], [2.0 ], [2.0 ]]
110+ )
111+ mcc = MatthewsCorrelationCoefficient (3 )
112+ mcc .update_state (gt_label_sparse , preds_sparse )
113+ check_results (mcc , sklearn_result )
114+
94115
95116# Keras model API check
96117def test_keras_model ():
@@ -110,13 +131,9 @@ def test_keras_model():
110131
111132
112133def test_reset_states_graph ():
113- gt_label = tf .constant (
114- [[0.0 , 1.0 ], [0.0 , 1.0 ], [0.0 , 1.0 ], [1.0 , 0.0 ]], dtype = tf .float32
115- )
116- preds = tf .constant (
117- [[0.0 , 1.0 ], [1.0 , 0.0 ], [0.0 , 1.0 ], [0.0 , 1.0 ]], dtype = tf .float32
118- )
119- mcc = MatthewsCorrelationCoefficient (2 )
134+ gt_label = tf .constant ([[1.0 ], [1.0 ], [1.0 ], [0.0 ]], dtype = tf .float32 )
135+ preds = tf .constant ([[1.0 ], [0.0 ], [1.0 ], [1.0 ]], dtype = tf .float32 )
136+ mcc = MatthewsCorrelationCoefficient (1 )
120137 mcc .update_state (gt_label , preds )
121138
122139 @tf .function
0 commit comments