 from torch.cuda.amp import autocast
 import torch.nn.parallel as dp
 from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
-from torch.testing._internal.common_utils import run_tests, TestCase, repeat_test_for_types, ALL_TENSORTYPES
+from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, onlyCUDA, skipMeta
+from torch.testing._internal.common_utils import run_tests, TestCase
 from torch.testing._internal.common_utils import _assertGradAndGradgradChecks, gradcheck
 from torch.testing._internal.common_utils import dtype2prec_DONTUSE
 from torch.testing._internal.common_utils import sandcastle_skip_if
@@ -434,93 +435,6 @@ def forward(self, *input):
         output = dp.data_parallel(Net(), input, gpus)
         self.assertEqual(output, fn(input))

-    @sandcastle_skip_if(not TEST_CUDA, "CUDA unavailable")
-    @repeat_test_for_types(ALL_TENSORTYPES)
-    def test_data_parallel_module(self, dtype=torch.float):
-        l = nn.Linear(10, 5).to("cuda", dtype)
-        i = torch.randn(20, 10, device="cuda", dtype=dtype)
-        expected_out = l(i)
-        net = nn.DataParallel(l)
-        out = net(i)
-        self.assertEqual(out.get_device(), 0)
-        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
-
-    @sandcastle_skip_if(not TEST_CUDA, "CUDA unavailable")
-    @repeat_test_for_types(ALL_TENSORTYPES)
-    def test_data_parallel_module_kwargs_only(self, dtype=torch.float):
-        class Net(nn.Module):
-            def __init__(self):
-                super(Net, self).__init__()
-                self.l = l
-
-            def forward(self, input):
-                return self.l(input)
-
-        l = nn.Linear(10, 5).to("cuda", dtype)
-        i = torch.randn(20, 10, device="cuda", dtype=dtype)
-        expected_out = l(i)
-        n = nn.DataParallel(Net())
-        out = n(input=i)
-        self.assertEqual(out.get_device(), 0)
-        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
-
-    @sandcastle_skip_if(not TEST_CUDA, "CUDA unavailable")
-    @repeat_test_for_types(ALL_TENSORTYPES)
-    def test_data_parallel_module_kwargs_only_empty_list(self, dtype=torch.float):
-        class Net(nn.Module):
-            def __init__(self):
-                super(Net, self).__init__()
-                self.l = l
-
-            def forward(self, input):
-                return self.l(input['data'])
-
-        l = nn.Linear(10, 5).to("cuda", dtype)
-        i = torch.randn(20, 10, device="cuda", dtype=dtype)
-        expected_out = l(i)
-        n = nn.DataParallel(Net())
-        out = n(input={'data': i, 'unused': []})
-        self.assertEqual(out.get_device(), 0)
-        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
-
-    @sandcastle_skip_if(not TEST_CUDA, "CUDA unavailable")
-    @repeat_test_for_types(ALL_TENSORTYPES)
-    def test_data_parallel_module_kwargs_only_empty_dict(self, dtype=torch.float):
-        class Net(nn.Module):
-            def __init__(self):
-                super(Net, self).__init__()
-                self.l = l
-
-            def forward(self, input):
-                return self.l(input['data'])
-
-        l = nn.Linear(10, 5).to("cuda", dtype)
-        i = torch.randn(20, 10, device="cuda", dtype=dtype)
-        expected_out = l(i)
-        n = nn.DataParallel(Net())
-        out = n(input={'data': i, 'unused': {}})
-        self.assertEqual(out.get_device(), 0)
-        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
-
-    @sandcastle_skip_if(not TEST_CUDA, "CUDA unavailable")
-    @repeat_test_for_types(ALL_TENSORTYPES)
-    def test_data_parallel_module_kwargs_only_empty_tuple(self, dtype=torch.float):
-        class Net(nn.Module):
-            def __init__(self):
-                super(Net, self).__init__()
-                self.l = l
-
-            def forward(self, input):
-                return self.l(input['data'])
-
-        l = nn.Linear(10, 5).to("cuda", dtype)
-        i = torch.randn(20, 10, device="cuda", dtype=dtype)
-        expected_out = l(i)
-        n = nn.DataParallel(Net())
-        out = n(input={'data': i, 'unused': ()})
-        self.assertEqual(out.get_device(), 0)
-        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
-
     @sandcastle_skip_if(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_data_parallel_module_zero_inputs(self):
         class TestModule(nn.Module):
@@ -757,13 +671,13 @@ def test_save_replica_module(self):
     @sandcastle_skip_if(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_strided_grad_layout(self):
         class ConvNet(nn.Module):
-            def __init__(self, layouts, dtypes):
+            def __init__(self, layouts, dtype_list):
                 super(ConvNet, self).__init__()
-                self.dtypes = dtypes
-                self.conv0 = torch.nn.Conv2d(8, 16, (2, 2)).to(memory_format=layouts[0], dtype=dtypes[0])
-                self.conv1 = torch.nn.Conv2d(16, 32, (2, 2)).to(memory_format=layouts[1], dtype=dtypes[1])
-                self.conv2 = torch.nn.Conv2d(32, 16, (2, 2)).to(memory_format=layouts[2], dtype=dtypes[2])
-                self.conv3 = torch.nn.Conv2d(16, 8, (2, 2)).to(memory_format=layouts[3], dtype=dtypes[3])
+                self.dtypes = dtype_list
+                self.conv0 = torch.nn.Conv2d(8, 16, (2, 2)).to(memory_format=layouts[0], dtype=dtype_list[0])
+                self.conv1 = torch.nn.Conv2d(16, 32, (2, 2)).to(memory_format=layouts[1], dtype=dtype_list[1])
+                self.conv2 = torch.nn.Conv2d(32, 16, (2, 2)).to(memory_format=layouts[2], dtype=dtype_list[2])
+                self.conv3 = torch.nn.Conv2d(16, 8, (2, 2)).to(memory_format=layouts[3], dtype=dtype_list[3])

             def forward(self, x):
                 x = x.to(self.dtypes[0])
@@ -786,10 +700,10 @@ def forward(self, x):
         device_ids = list(range(ndevs))

         with torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False):
-            for formats, dtypes in product(layer_formats, layer_dtypes):
+            for formats, dtype_list in product(layer_formats, layer_dtypes):
                 model_msg = "formats = {} dtypes = {}".format(formats, dtypes)
                 try:
-                    m = ConvNet(formats, dtypes).cuda(device="cuda:0")
+                    m = ConvNet(formats, dtype_list).cuda(device="cuda:0")
                     m_dp = dp.DataParallel(deepcopy(m), device_ids=device_ids)
                     opt = torch.optim.SGD(m.parameters(), lr=0.1)
                     opt_dp = torch.optim.SGD(m_dp.parameters(), lr=0.1)
@@ -855,5 +769,102 @@ def forward(self, inp):
         model(input)


+class TestDataParallelDeviceType(TestCase):
+
+    @onlyCUDA
+    @skipMeta
+    @dtypes(torch.float, torch.double, torch.half)
+    def test_data_parallel_module(self, device, dtype):
+        l = nn.Linear(10, 5).to(device, dtype)
+        i = torch.randn(20, 10, device=device, dtype=dtype)
+        expected_out = l(i)
+        net = nn.DataParallel(l)
+        out = net(i)
+        self.assertEqual(out.get_device(), 0)
+        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
+
+    @onlyCUDA
+    @skipMeta
+    @dtypes(torch.float, torch.double, torch.half)
+    def test_data_parallel_module_kwargs_only(self, device, dtype):
+        class Net(nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.l = l
+
+            def forward(self, input):
+                return self.l(input)
+
+        l = nn.Linear(10, 5).to(device, dtype)
+        i = torch.randn(20, 10, device=device, dtype=dtype)
+        expected_out = l(i)
+        n = nn.DataParallel(Net())
+        out = n(input=i)
+        self.assertEqual(out.get_device(), 0)
+        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
+
+    @onlyCUDA
+    @skipMeta
+    @dtypes(torch.float, torch.double, torch.half)
+    def test_data_parallel_module_kwargs_only_empty_list(self, device, dtype):
+        class Net(nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.l = l
+
+            def forward(self, input):
+                return self.l(input['data'])
+
+        l = nn.Linear(10, 5).to(device, dtype)
+        i = torch.randn(20, 10, device=device, dtype=dtype)
+        expected_out = l(i)
+        n = nn.DataParallel(Net())
+        out = n(input={'data': i, 'unused': []})
+        self.assertEqual(out.get_device(), 0)
+        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
+
+    @onlyCUDA
+    @skipMeta
+    @dtypes(torch.float, torch.double, torch.half)
+    def test_data_parallel_module_kwargs_only_empty_dict(self, device, dtype):
+        class Net(nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.l = l
+
+            def forward(self, input):
+                return self.l(input['data'])
+
+        l = nn.Linear(10, 5).to(device, dtype)
+        i = torch.randn(20, 10, device=device, dtype=dtype)
+        expected_out = l(i)
+        n = nn.DataParallel(Net())
+        out = n(input={'data': i, 'unused': {}})
+        self.assertEqual(out.get_device(), 0)
+        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
+
+    @onlyCUDA
+    @skipMeta
+    @dtypes(torch.float, torch.double, torch.half)
+    def test_data_parallel_module_kwargs_only_empty_tuple(self, device, dtype):
+        class Net(nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.l = l
+
+            def forward(self, input):
+                return self.l(input['data'])
+
+        l = nn.Linear(10, 5).to(device, dtype)
+        i = torch.randn(20, 10, device=device, dtype=dtype)
+        expected_out = l(i)
+        n = nn.DataParallel(Net())
+        out = n(input={'data': i, 'unused': ()})
+        self.assertEqual(out.get_device(), 0)
+        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
+
+
+instantiate_device_type_tests(TestDataParallelDeviceType, globals())
+
 if __name__ == '__main__':
     run_tests()
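For readers unfamiliar with the device-generic framework this diff migrates to: `instantiate_device_type_tests` expands a template class such as `TestDataParallelDeviceType` into per-device classes (for example a CUDA-specific class), and `@dtypes(...)` further expands each test per dtype, passing concrete `device` and `dtype` arguments to the method. The following is a minimal, hedged sketch of that pattern using the same PyTorch-internal helpers imported above; the class and test names in the sketch are illustrative only and are not part of this change.

```python
# Minimal sketch of the device-generic test pattern adopted in the diff.
# The helpers come from the same internal modules imported above; the
# class and test names here are hypothetical, chosen for illustration.
import torch
import torch.nn as nn
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, dtypes, onlyCUDA, skipMeta)
from torch.testing._internal.common_utils import run_tests, TestCase


class TestExampleDeviceType(TestCase):

    @onlyCUDA           # generate this test only for the CUDA device class
    @skipMeta           # do not generate a 'meta' device variant
    @dtypes(torch.float, torch.double, torch.half)
    def test_linear_shape(self, device, dtype):
        # Each generated variant receives a concrete device string
        # (e.g. "cuda:0") and one dtype from the @dtypes list.
        layer = nn.Linear(4, 2).to(device, dtype)
        x = torch.randn(3, 4, device=device, dtype=dtype)
        self.assertEqual(layer(x).shape, torch.Size([3, 2]))


# Creates device-specific classes (e.g. TestExampleDeviceTypeCUDA) in this
# module's namespace, one test per (device, dtype) combination, so the
# standard unittest discovery used by run_tests() picks them up.
instantiate_device_type_tests(TestExampleDeviceType, globals())

if __name__ == "__main__":
    run_tests()
```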