
Commit f5fa91b

peterbell10 authored and facebook-github-bot committed
Sparse: Add additional opinfo tests (pytorch#68886)
Summary: Pull Request resolved: pytorch#68886

cc nikitaved pearu cpuhrsch IvanYashchuk

Test Plan: Imported from OSS

Reviewed By: jbschlosser

Differential Revision: D32697933

Pulled By: cpuhrsch

fbshipit-source-id: fffdd1bc663cc1bc49abe8cf3680982d1cb497bc
1 parent 3bd7dbf · commit f5fa91b

File tree

2 files changed: +109 −16 lines

test/test_sparse.py (+98 −15)
@@ -7,6 +7,9 @@
 import random
 import unittest
 from torch.testing import make_tensor
+from torch.testing._internal.common_dtype import (
+    all_types_and_complex,
+)
 from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
     do_test_empty_full, load_tests, TEST_NUMPY, IS_WINDOWS, gradcheck, coalescedonoff, \
     DeterministicGuard
@@ -19,7 +22,7 @@
     (SM53OrLater, SM80OrLater, CUDA11OrLater)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride,
-     deviceCountAtLeast)
+     deviceCountAtLeast, OpDTypes)
 from torch.testing._internal.common_methods_invocations import \
     (sparse_unary_ufuncs)
 from torch.testing._internal.common_dtype import (
@@ -3410,15 +3413,23 @@ def test_cuda_sparse_cpu_dense_add(self):
         with self.assertRaisesRegex(RuntimeError, "add: expected 'self' to be a CUDA tensor, but got a CPU tensor"):
             x + sparse_y
 
+
+def _sparse_to_dense(tensor):
+    if tensor.dtype != torch.bool:
+        return tensor.to_dense()
+
+    # to_dense uses coalesce which isn't implemented for bool
+    return tensor.to(torch.int8).to_dense().to(torch.bool)
+
+
+_sparse_unary_ops = ops(sparse_unary_ufuncs, dtypes=OpDTypes.supported,
+                        allowed_dtypes=all_types_and_complex())
 class TestSparseUnaryUfuncs(TestCase):
     exact_dtype = True
 
-    @ops(sparse_unary_ufuncs)
-    def test_sparse_consistency(self, device, dtype, op):
-        unsupportedTypes = [torch.bfloat16, torch.float16]
-        if dtype in unsupportedTypes:
-            self.skipTest('Skipped! Unsupported dtypes for Sparse')
 
+    @_sparse_unary_ops
+    def test_sparse_consistency(self, device, dtype, op):
         samples = op.sample_inputs(device, dtype)
 
         if len(samples) == 0:
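
Note on the hunk above: `_sparse_to_dense` exists because `Tensor.to_dense` goes through `coalesce`, which (per the comment in the diff) is not implemented for bool, so the helper round-trips bool tensors through int8; `_sparse_unary_ops` simply pre-binds the `ops` decorator so each test below runs over every supported dtype, restricted to `all_types_and_complex()`. A minimal standalone sketch of the bool workaround — the tensor values here are illustrative, not taken from the test suite:

import torch

# Hypothetical demo of the workaround inside _sparse_to_dense: cast a bool
# sparse tensor to int8, densify, then cast back to bool, since to_dense()
# relies on coalesce(), which lacks a bool kernel.
indices = torch.tensor([[0, 1], [1, 0]])
values = torch.tensor([True, True])
sparse_bool = torch.sparse_coo_tensor(indices, values, (2, 2))

dense_bool = sparse_bool.to(torch.int8).to_dense().to(torch.bool)
print(dense_bool)
# tensor([[False,  True],
#         [ True, False]])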
@@ -3428,27 +3439,99 @@ def test_sparse_consistency(self, device, dtype, op):
 
             assert isinstance(sample.input, torch.Tensor)
 
-            expected = op(sample.input)
+            expected = op(sample.input, *sample.args, **sample.kwargs)
             assert torch.is_tensor(expected)
-            output = op(sample.input.to_sparse())
+            output = op(sample.input.to_sparse(), *sample.args, **sample.kwargs)
             assert torch.is_tensor(output)
-            self.assertEqual(output.to_dense(), expected)
+            self.assertEqual(_sparse_to_dense(output), expected)
 
-    @ops(sparse_unary_ufuncs)
-    def test_sparse_zero_dims(self, device, dtype, op):
-        # test 0x0 sparse_coo_tensor
+    @_sparse_unary_ops
+    def test_out(self, device, dtype, op):
+        samples = op.sample_inputs(device, dtype)
+
+        if len(samples) == 0:
+            self.skipTest("Skipped! No sample inputs!")
+
+        if not op.supports_out:
+            self.skipTest("Skipped! Out not supported")
+
+        sample = samples[0]
+        sample.input = sample.input.to_sparse()
+        expect = op(sample.input, *sample.args, **sample.kwargs)
+
+        out = torch.zeros(sample.input.shape, device=device,
+                          dtype=expect.dtype, layout=torch.sparse_coo)
+        op(sample.input, *sample.args, **sample.kwargs, out=out)
+        self.assertEqual(out, expect)
+
+    @_sparse_unary_ops
+    def test_inplace(self, device, dtype, op):
+        samples = op.sample_inputs(device, dtype)
+
+        if len(samples) == 0:
+            self.skipTest("Skipped! No sample inputs!")
+
+        if op.inplace_variant is None:
+            self.skipTest("Skipped! Out not supported")
 
-        unsupportedTypes = [torch.bfloat16, torch.float16]
-        if dtype in unsupportedTypes:
-            self.skipTest('Skipped! Unsupported dtypes for Sparse')
+        sample = samples[0]
+        sample.input = sample.input.to_sparse().coalesce()
+        expect = op(sample.input, *sample.args, **sample.kwargs)
+
+        if not torch.can_cast(expect.dtype, dtype):
+            with self.assertRaisesRegex(RuntimeError, "result type"):
+                op.inplace_variant(sample.input, *sample.args, **sample.kwargs)
+            return
+
+        actual = op.inplace_variant(sample.input, *sample.args, **sample.kwargs)
+        self.assertIs(actual, sample.input)
+        self.assertEqual(actual, expect)
 
+    @_sparse_unary_ops
+    def test_sparse_zero_dims(self, device, dtype, op):
+        # test 0x0 sparse_coo_tensor
         indices = torch.empty(2, 0, dtype=torch.int64)
         values = torch.empty(0, dtype=dtype)
         sparse_0x0 = torch.sparse_coo_tensor(indices, values, (0, 0))
         expected = torch.sparse_coo_tensor(indices, op(values), (0, 0))
         actual = op(sparse_0x0)
         self.assertEqual(expected, actual)
 
+    @_sparse_unary_ops
+    def test_sparse_zeros(self, device, dtype, op):
+        samples = op.sample_inputs(device, dtype)
+
+        zero_input = torch.zeros((), device=device, dtype=dtype)
+        sparse_input = torch.zeros((), dtype=dtype, device=device,
+                                   layout=torch.sparse_coo)
+
+        expect = op(zero_input)
+        actual = op(sparse_input)
+        self.assertEqual(expect, _sparse_to_dense(actual))
+
+    @ops(sparse_unary_ufuncs, dtypes=OpDTypes.supported,
+         allowed_dtypes=[torch.double, torch.cdouble])
+    def test_sparse_fn_grad(self, device, dtype, op):
+        if not op.supports_autograd:
+            self.skipTest("Skipped! Op doesn't support autograd")
+
+        for sample in op.sample_inputs(device, dtype):
+            sparse_input = sample.input.to_sparse().detach().requires_grad_(True)
+
+            def fn(x):
+                return _sparse_to_dense(
+                    op(x, *sample.args, **sample.kwargs))
+
+            self.assertTrue(gradcheck(
+                fn,
+                (sparse_input,),
+                check_batched_grad=False,
+                check_grad_dtypes=True,
+                check_sparse_nnz=True,
+                nondet_tol=op.gradcheck_nondet_tol,
+                fast_mode=op.gradcheck_fast_mode))
+
+
 # e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA
 instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta')

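The new `test_sparse_fn_grad` drives `torch.autograd.gradcheck` through a densifying wrapper with `check_sparse_nnz=True`, which perturbs only the stored nnz values rather than materializing the whole dense tensor, while `test_inplace` uses `torch.can_cast` to predict when an in-place variant must raise. A hedged sketch of both patterns, using `to_dense` (itself differentiable) as a stand-in for the op under test plus `_sparse_to_dense`:

import torch
from torch.autograd import gradcheck

# Stand-in for "unary op followed by _sparse_to_dense": to_dense is
# differentiable, so gradcheck can run end-to-end on a sparse leaf tensor.
def fn(x):
    return x.to_dense()

dense = torch.randn(3, 3, dtype=torch.double)
sparse_input = dense.to_sparse().detach().requires_grad_(True)

# check_sparse_nnz=True is required for sparse inputs; without it,
# gradcheck rejects non-dense tensors outright.
assert gradcheck(fn, (sparse_input,), check_sparse_nnz=True,
                 check_batched_grad=False)

# test_inplace's dtype guard in miniature: a result dtype that cannot be
# cast back to the input dtype means the in-place variant must raise.
assert torch.can_cast(torch.double, torch.double)
assert not torch.can_cast(torch.cdouble, torch.double)
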
torch/testing/_internal/common_methods_invocations.py (+11 −1)
@@ -8071,6 +8071,8 @@ def ref_pairwise_distance(input1, input2):
                        DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard',
                                     device_type='cuda', dtypes=[torch.cdouble],
                                     active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
+                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
                    )),
     # NOTE: derivative for inplace asinh is not implemented
     UnaryUfuncInfo('asinh',
@@ -8328,6 +8330,8 @@ def ref_pairwise_distance(input1, input2):
                        # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":118,
                        # please report a bug to PyTorch.
                        DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, )),
+                       DecorateInfo(unittest.skip("Skipped! conj_physical_ not implemented for sparse"),
+                                    'TestSparseUnaryUfuncs', 'test_inplace'),
                    )),
     OpInfo('resolve_conj',
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
@@ -11295,7 +11299,12 @@ def ref_pairwise_distance(input1, input2):
                                     active_if=IS_MACOS),
                        # Reference: https://github.com/pytorch/pytorch/pull/47293#issuecomment-721774436
                        DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard',
-                                    dtypes=[torch.bfloat16])),
+                                    dtypes=[torch.bfloat16]),
+                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
+                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+                       DecorateInfo(unittest.skip("Skipped! sqrt_ not implemented for sparse"),
+                                    'TestSparseUnaryUfuncs', 'test_inplace'),
+                   ),
                    safe_casts_outputs=True,
                    handles_complex_extremals=False),
     UnaryUfuncInfo('square',
@@ -11391,6 +11400,7 @@ def ref_pairwise_distance(input1, input2):
            ref=np.isnan,
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
            supports_out=False,
+           supports_sparse=True,
            supports_autograd=False),
     OpInfo('linalg.solve',
            aten_name='linalg_solve',

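For context on the common_methods_invocations.py side: `supports_sparse=True` is what enrolls an op in `TestSparseUnaryUfuncs` (via the `sparse_unary_ufuncs` list imported at the top of test_sparse.py), and each `DecorateInfo` pins a skip decorator to one (test class, test name) pair. A sketch of how such a filter over the OpInfo database can be expressed — the comprehension is illustrative, not the exact upstream definition:

from torch.testing._internal.common_methods_invocations import (
    UnaryUfuncInfo, op_db)

# Illustrative filter: every unary-ufunc OpInfo that opted into sparse
# testing. op_db, UnaryUfuncInfo, and the supports_sparse attribute are
# real names in this module; the expression itself is a sketch, not a
# verbatim copy of the upstream definition of sparse_unary_ufuncs.
sparse_unary_ufuncs = [op for op in op_db
                       if isinstance(op, UnaryUfuncInfo) and op.supports_sparse]

print(sorted(op.name for op in sparse_unary_ufuncs))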