@@ -15,9 +15,11 @@
 
 import numpy as np
 import torch
+import torch.distributed as dist
 
 from monai.handlers import AveragePrecision
 from monai.transforms import Activations, AsDiscrete
+from tests.utils import DistCall, DistTestCase
 
 
 class TestHandlerAveragePrecision(unittest.TestCase):
@@ -44,5 +46,34 @@ def test_compute(self):
         np.testing.assert_allclose(0.8333333, ap)
 
 
+class DistributedAveragePrecision(DistTestCase):
+
+    @DistCall(nnodes=1, nproc_per_node=2, node_rank=0)
+    def test_compute(self):
+        ap_metric = AveragePrecision()
+        act = Activations(softmax=True)
+        to_onehot = AsDiscrete(to_onehot=2)
+
+        device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu"
+        if dist.get_rank() == 0:
+            y_pred = [torch.tensor([0.1, 0.9], device=device), torch.tensor([0.3, 1.4], device=device)]
+            y = [torch.tensor([0], device=device), torch.tensor([1], device=device)]
+
+        if dist.get_rank() == 1:
+            y_pred = [
+                torch.tensor([0.2, 0.1], device=device),
+                torch.tensor([0.1, 0.5], device=device),
+                torch.tensor([0.3, 0.4], device=device),
+            ]
+            y = [torch.tensor([0], device=device), torch.tensor([1], device=device), torch.tensor([1], device=device)]
+
+        y_pred = [act(p) for p in y_pred]
+        y = [to_onehot(y_) for y_ in y]
+        ap_metric.update([y_pred, y])
+
+        result = ap_metric.compute()
+        np.testing.assert_allclose(0.7778, result, rtol=1e-4)
+
+
 if __name__ == "__main__":
     unittest.main()
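
As a sanity check on the expected value (not part of the diff): the 0.7778 asserted above can be reproduced outside MONAI by pooling the five samples from both ranks and macro-averaging per-class average precision, which matches the handler's default averaging. A minimal sketch, assuming scikit-learn is available:

import numpy as np
from sklearn.metrics import average_precision_score

# Logits and labels from both ranks, gathered in rank order (rank 0 then rank 1).
logits = np.array([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.3, 0.4]])
labels = np.array([0, 1, 0, 1, 1])

# Softmax over the class axis mirrors Activations(softmax=True).
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
# One-hot encoding mirrors AsDiscrete(to_onehot=2).
onehot = np.eye(2)[labels]

# Macro average over the two one-hot classes: class 0 AP is 0.75,
# class 1 AP is 0.8056, so the mean is ~0.7778.
print(average_precision_score(onehot, probs, average="macro"))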