
Commit c4400fc

janeyx99 authored and facebook-github-bot committed
Retire repeat_test_for_types (pytorch#71033)
Summary: Fixes pytorch#69865

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Pull Request resolved: pytorch#71033
Reviewed By: mruberry
Differential Revision: D33486370
Pulled By: janeyx99
fbshipit-source-id: 71f9383dbc1e00b572f26eb4f04d0a94c6759e35
1 parent e1b84e1 commit c4400fc
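
For context, the replacement pattern this commit adopts looks roughly like the sketch below: the @dtypes decorator declares the dtypes a test covers, and instantiate_device_type_tests generates one concrete test per device/dtype combination, which is what makes the repeat_test_for_types helper unnecessary. The class and test names here (TestExampleDeviceType, test_linear_forward) are illustrative only, not part of the commit, and the sketch assumes a PyTorch source checkout where the torch.testing._internal helpers are importable.

import torch
import torch.nn as nn

from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    dtypes,
    onlyCUDA,
)


class TestExampleDeviceType(TestCase):
    # @dtypes declares the dtypes to cover; instantiate_device_type_tests()
    # below expands this into one named test per device/dtype pair instead of
    # looping over dtypes inside a single test, as repeat_test_for_types did.
    @onlyCUDA
    @dtypes(torch.float, torch.double, torch.half)
    def test_linear_forward(self, device, dtype):
        l = nn.Linear(10, 5).to(device, dtype)
        i = torch.randn(20, 10, device=device, dtype=dtype)
        self.assertEqual(l(i).dtype, dtype)


instantiate_device_type_tests(TestExampleDeviceType, globals())

if __name__ == "__main__":
    run_tests()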

2 files changed: 107 additions, 110 deletions


test/distributed/test_data_parallel.py

Lines changed: 107 additions & 96 deletions
@@ -12,7 +12,8 @@
 from torch.cuda.amp import autocast
 import torch.nn.parallel as dp
 from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
-from torch.testing._internal.common_utils import run_tests, TestCase, repeat_test_for_types, ALL_TENSORTYPES
+from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, onlyCUDA, skipMeta
+from torch.testing._internal.common_utils import run_tests, TestCase
 from torch.testing._internal.common_utils import _assertGradAndGradgradChecks, gradcheck
 from torch.testing._internal.common_utils import dtype2prec_DONTUSE
 from torch.testing._internal.common_utils import sandcastle_skip_if
@@ -434,93 +435,6 @@ def forward(self, *input):
         output = dp.data_parallel(Net(), input, gpus)
         self.assertEqual(output, fn(input))

-    @sandcastle_skip_if(not TEST_CUDA, "CUDA unavailable")
-    @repeat_test_for_types(ALL_TENSORTYPES)
-    def test_data_parallel_module(self, dtype=torch.float):
-        l = nn.Linear(10, 5).to("cuda", dtype)
-        i = torch.randn(20, 10, device="cuda", dtype=dtype)
-        expected_out = l(i)
-        net = nn.DataParallel(l)
-        out = net(i)
-        self.assertEqual(out.get_device(), 0)
-        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
-
-    @sandcastle_skip_if(not TEST_CUDA, "CUDA unavailable")
-    @repeat_test_for_types(ALL_TENSORTYPES)
-    def test_data_parallel_module_kwargs_only(self, dtype=torch.float):
-        class Net(nn.Module):
-            def __init__(self):
-                super(Net, self).__init__()
-                self.l = l
-
-            def forward(self, input):
-                return self.l(input)
-
-        l = nn.Linear(10, 5).to("cuda", dtype)
-        i = torch.randn(20, 10, device="cuda", dtype=dtype)
-        expected_out = l(i)
-        n = nn.DataParallel(Net())
-        out = n(input=i)
-        self.assertEqual(out.get_device(), 0)
-        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
-
-    @sandcastle_skip_if(not TEST_CUDA, "CUDA unavailable")
-    @repeat_test_for_types(ALL_TENSORTYPES)
-    def test_data_parallel_module_kwargs_only_empty_list(self, dtype=torch.float):
-        class Net(nn.Module):
-            def __init__(self):
-                super(Net, self).__init__()
-                self.l = l
-
-            def forward(self, input):
-                return self.l(input['data'])
-
-        l = nn.Linear(10, 5).to("cuda", dtype)
-        i = torch.randn(20, 10, device="cuda", dtype=dtype)
-        expected_out = l(i)
-        n = nn.DataParallel(Net())
-        out = n(input={'data': i, 'unused': []})
-        self.assertEqual(out.get_device(), 0)
-        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
-
-    @sandcastle_skip_if(not TEST_CUDA, "CUDA unavailable")
-    @repeat_test_for_types(ALL_TENSORTYPES)
-    def test_data_parallel_module_kwargs_only_empty_dict(self, dtype=torch.float):
-        class Net(nn.Module):
-            def __init__(self):
-                super(Net, self).__init__()
-                self.l = l
-
-            def forward(self, input):
-                return self.l(input['data'])
-
-        l = nn.Linear(10, 5).to("cuda", dtype)
-        i = torch.randn(20, 10, device="cuda", dtype=dtype)
-        expected_out = l(i)
-        n = nn.DataParallel(Net())
-        out = n(input={'data': i, 'unused': {}})
-        self.assertEqual(out.get_device(), 0)
-        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
-
-    @sandcastle_skip_if(not TEST_CUDA, "CUDA unavailable")
-    @repeat_test_for_types(ALL_TENSORTYPES)
-    def test_data_parallel_module_kwargs_only_empty_tuple(self, dtype=torch.float):
-        class Net(nn.Module):
-            def __init__(self):
-                super(Net, self).__init__()
-                self.l = l
-
-            def forward(self, input):
-                return self.l(input['data'])
-
-        l = nn.Linear(10, 5).to("cuda", dtype)
-        i = torch.randn(20, 10, device="cuda", dtype=dtype)
-        expected_out = l(i)
-        n = nn.DataParallel(Net())
-        out = n(input={'data': i, 'unused': ()})
-        self.assertEqual(out.get_device(), 0)
-        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
-
     @sandcastle_skip_if(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_data_parallel_module_zero_inputs(self):
         class TestModule(nn.Module):
@@ -757,13 +671,13 @@ def test_save_replica_module(self):
     @sandcastle_skip_if(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_strided_grad_layout(self):
         class ConvNet(nn.Module):
-            def __init__(self, layouts, dtypes):
+            def __init__(self, layouts, dtype_list):
                 super(ConvNet, self).__init__()
-                self.dtypes = dtypes
-                self.conv0 = torch.nn.Conv2d(8, 16, (2, 2)).to(memory_format=layouts[0], dtype=dtypes[0])
-                self.conv1 = torch.nn.Conv2d(16, 32, (2, 2)).to(memory_format=layouts[1], dtype=dtypes[1])
-                self.conv2 = torch.nn.Conv2d(32, 16, (2, 2)).to(memory_format=layouts[2], dtype=dtypes[2])
-                self.conv3 = torch.nn.Conv2d(16, 8, (2, 2)).to(memory_format=layouts[3], dtype=dtypes[3])
+                self.dtypes = dtype_list
+                self.conv0 = torch.nn.Conv2d(8, 16, (2, 2)).to(memory_format=layouts[0], dtype=dtype_list[0])
+                self.conv1 = torch.nn.Conv2d(16, 32, (2, 2)).to(memory_format=layouts[1], dtype=dtype_list[1])
+                self.conv2 = torch.nn.Conv2d(32, 16, (2, 2)).to(memory_format=layouts[2], dtype=dtype_list[2])
+                self.conv3 = torch.nn.Conv2d(16, 8, (2, 2)).to(memory_format=layouts[3], dtype=dtype_list[3])

             def forward(self, x):
                 x = x.to(self.dtypes[0])
@@ -786,10 +700,10 @@ def forward(self, x):
         device_ids = list(range(ndevs))

         with torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False):
-            for formats, dtypes in product(layer_formats, layer_dtypes):
+            for formats, dtype_list in product(layer_formats, layer_dtypes):
                 model_msg = "formats = {} dtypes = {}".format(formats, dtypes)
                 try:
-                    m = ConvNet(formats, dtypes).cuda(device="cuda:0")
+                    m = ConvNet(formats, dtype_list).cuda(device="cuda:0")
                     m_dp = dp.DataParallel(deepcopy(m), device_ids=device_ids)
                     opt = torch.optim.SGD(m.parameters(), lr=0.1)
                     opt_dp = torch.optim.SGD(m_dp.parameters(), lr=0.1)
@@ -855,5 +769,102 @@ def forward(self, inp):
         model(input)


+class TestDataParallelDeviceType(TestCase):
+
+    @onlyCUDA
+    @skipMeta
+    @dtypes(torch.float, torch.double, torch.half)
+    def test_data_parallel_module(self, device, dtype):
+        l = nn.Linear(10, 5).to(device, dtype)
+        i = torch.randn(20, 10, device=device, dtype=dtype)
+        expected_out = l(i)
+        net = nn.DataParallel(l)
+        out = net(i)
+        self.assertEqual(out.get_device(), 0)
+        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
+
+    @onlyCUDA
+    @skipMeta
+    @dtypes(torch.float, torch.double, torch.half)
+    def test_data_parallel_module_kwargs_only(self, device, dtype):
+        class Net(nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.l = l
+
+            def forward(self, input):
+                return self.l(input)
+
+        l = nn.Linear(10, 5).to(device, dtype)
+        i = torch.randn(20, 10, device=device, dtype=dtype)
+        expected_out = l(i)
+        n = nn.DataParallel(Net())
+        out = n(input=i)
+        self.assertEqual(out.get_device(), 0)
+        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
+
+    @onlyCUDA
+    @skipMeta
+    @dtypes(torch.float, torch.double, torch.half)
+    def test_data_parallel_module_kwargs_only_empty_list(self, device, dtype):
+        class Net(nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.l = l
+
+            def forward(self, input):
+                return self.l(input['data'])
+
+        l = nn.Linear(10, 5).to(device, dtype)
+        i = torch.randn(20, 10, device=device, dtype=dtype)
+        expected_out = l(i)
+        n = nn.DataParallel(Net())
+        out = n(input={'data': i, 'unused': []})
+        self.assertEqual(out.get_device(), 0)
+        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
+
+    @onlyCUDA
+    @skipMeta
+    @dtypes(torch.float, torch.double, torch.half)
+    def test_data_parallel_module_kwargs_only_empty_dict(self, device, dtype):
+        class Net(nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.l = l
+
+            def forward(self, input):
+                return self.l(input['data'])
+
+        l = nn.Linear(10, 5).to(device, dtype)
+        i = torch.randn(20, 10, device=device, dtype=dtype)
+        expected_out = l(i)
+        n = nn.DataParallel(Net())
+        out = n(input={'data': i, 'unused': {}})
+        self.assertEqual(out.get_device(), 0)
+        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
+
+    @onlyCUDA
+    @skipMeta
+    @dtypes(torch.float, torch.double, torch.half)
+    def test_data_parallel_module_kwargs_only_empty_tuple(self, device, dtype):
+        class Net(nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.l = l
+
+            def forward(self, input):
+                return self.l(input['data'])
+
+        l = nn.Linear(10, 5).to(device, dtype)
+        i = torch.randn(20, 10, device=device, dtype=dtype)
+        expected_out = l(i)
+        n = nn.DataParallel(Net())
+        out = n(input={'data': i, 'unused': ()})
+        self.assertEqual(out.get_device(), 0)
+        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
+
+
+instantiate_device_type_tests(TestDataParallelDeviceType, globals())
+
 if __name__ == '__main__':
     run_tests()

torch/testing/_internal/common_utils.py

Lines changed: 0 additions & 14 deletions
@@ -523,20 +523,6 @@ def shell(command, cwd=None, env=None):
     return wait_for_process(p)


-# Used to run the same test with different tensor types
-def repeat_test_for_types(dtypes):
-    def repeat_helper(f):
-        @wraps(f)
-        def call_helper(self, *args):
-            for dtype in dtypes:
-                with TestCase.subTest(self, dtype=dtype):
-                    f(self, *args, dtype=dtype)
-
-        return call_helper
-    return repeat_helper
-
-
-
 def discover_test_cases_recursively(suite_or_case):
     if isinstance(suite_or_case, unittest.TestCase):
         return [suite_or_case]