7
7
import random
8
8
import unittest
9
9
from torch .testing import make_tensor
10
+ from torch .testing ._internal .common_dtype import (
11
+ all_types_and_complex ,
12
+ )
10
13
from torch .testing ._internal .common_utils import TestCase , run_tests , skipIfRocm , do_test_dtypes , \
11
14
do_test_empty_full , load_tests , TEST_NUMPY , IS_WINDOWS , gradcheck , coalescedonoff , \
12
15
DeterministicGuard
19
22
(SM53OrLater , SM80OrLater , CUDA11OrLater )
20
23
from torch .testing ._internal .common_device_type import \
21
24
(instantiate_device_type_tests , ops , dtypes , dtypesIfCUDA , onlyCPU , onlyCUDA , precisionOverride ,
22
- deviceCountAtLeast )
25
+ deviceCountAtLeast , OpDTypes )
23
26
from torch .testing ._internal .common_methods_invocations import \
24
27
(sparse_unary_ufuncs )
25
28
from torch .testing ._internal .common_dtype import (
@@ -3410,15 +3413,23 @@ def test_cuda_sparse_cpu_dense_add(self):
3410
3413
with self .assertRaisesRegex (RuntimeError , "add: expected 'self' to be a CUDA tensor, but got a CPU tensor" ):
3411
3414
x + sparse_y
3412
3415
3416
+
3417
+ def _sparse_to_dense (tensor ):
3418
+ if tensor .dtype != torch .bool :
3419
+ return tensor .to_dense ()
3420
+
3421
+ # to_dense uses coalesce which isn't implemented for bool
3422
+ return tensor .to (torch .int8 ).to_dense ().to (torch .bool )
3423
+
3424
+
3425
# Shared decorator for the sparse unary-ufunc tests below: instantiates each
# test once per OpInfo in sparse_unary_ufuncs, over the dtypes the op
# supports intersected with all_types_and_complex() (presumably this filters
# out half/bfloat16 and bool, which sparse doesn't handle -- confirm against
# common_dtype).
_sparse_unary_ops = ops(sparse_unary_ufuncs, dtypes=OpDTypes.supported,
                        allowed_dtypes=all_types_and_complex())
3413
3427
class TestSparseUnaryUfuncs (TestCase ):
3414
3428
exact_dtype = True
3415
3429
3416
- @ops (sparse_unary_ufuncs )
3417
- def test_sparse_consistency (self , device , dtype , op ):
3418
- unsupportedTypes = [torch .bfloat16 , torch .float16 ]
3419
- if dtype in unsupportedTypes :
3420
- self .skipTest ('Skipped! Unsupported dtypes for Sparse' )
3421
3430
3431
+ @_sparse_unary_ops
3432
+ def test_sparse_consistency (self , device , dtype , op ):
3422
3433
samples = op .sample_inputs (device , dtype )
3423
3434
3424
3435
if len (samples ) == 0 :
@@ -3428,27 +3439,99 @@ def test_sparse_consistency(self, device, dtype, op):
3428
3439
3429
3440
assert isinstance (sample .input , torch .Tensor )
3430
3441
3431
- expected = op (sample .input )
3442
+ expected = op (sample .input , * sample . args , ** sample . kwargs )
3432
3443
assert torch .is_tensor (expected )
3433
- output = op (sample .input .to_sparse ())
3444
+ output = op (sample .input .to_sparse (), * sample . args , ** sample . kwargs )
3434
3445
assert torch .is_tensor (output )
3435
- self .assertEqual (output . to_dense ( ), expected )
3446
+ self .assertEqual (_sparse_to_dense ( output ), expected )
3436
3447
3437
- @ops (sparse_unary_ufuncs )
3438
- def test_sparse_zero_dims (self , device , dtype , op ):
3439
- # test 0x0 sparse_coo_tensor
3448
+ @_sparse_unary_ops
3449
+ def test_out (self , device , dtype , op ):
3450
+ samples = op .sample_inputs (device , dtype )
3451
+
3452
+ if len (samples ) == 0 :
3453
+ self .skipTest ("Skipped! No sample inputs!" )
3454
+
3455
+ if not op .supports_out :
3456
+ self .skipTest ("Skipped! Out not supported" )
3457
+
3458
+ sample = samples [0 ]
3459
+ sample .input = sample .input .to_sparse ()
3460
+ expect = op (sample .input , * sample .args , ** sample .kwargs )
3461
+
3462
+ out = torch .zeros (sample .input .shape , device = device ,
3463
+ dtype = expect .dtype , layout = torch .sparse_coo )
3464
+ op (sample .input , * sample .args , ** sample .kwargs , out = out )
3465
+ self .assertEqual (out , expect )
3466
+
3467
+ @_sparse_unary_ops
3468
+ def test_inplace (self , device , dtype , op ):
3469
+ samples = op .sample_inputs (device , dtype )
3470
+
3471
+ if len (samples ) == 0 :
3472
+ self .skipTest ("Skipped! No sample inputs!" )
3473
+
3474
+ if op .inplace_variant is None :
3475
+ self .skipTest ("Skipped! Out not supported" )
3440
3476
3441
- unsupportedTypes = [torch .bfloat16 , torch .float16 ]
3442
- if dtype in unsupportedTypes :
3443
- self .skipTest ('Skipped! Unsupported dtypes for Sparse' )
3477
+ sample = samples [0 ]
3478
+ sample .input = sample .input .to_sparse ().coalesce ()
3479
+ expect = op (sample .input , * sample .args , ** sample .kwargs )
3480
+
3481
+ if not torch .can_cast (expect .dtype , dtype ):
3482
+ with self .assertRaisesRegex (RuntimeError , "result type" ):
3483
+ op .inplace_variant (sample .input , * sample .args , ** sample .kwargs )
3484
+ return
3485
+
3486
+ actual = op .inplace_variant (sample .input , * sample .args , ** sample .kwargs )
3487
+ self .assertIs (actual , sample .input )
3488
+ self .assertEqual (actual , expect )
3444
3489
3490
+ @_sparse_unary_ops
3491
+ def test_sparse_zero_dims (self , device , dtype , op ):
3492
+ # test 0x0 sparse_coo_tensor
3445
3493
indices = torch .empty (2 , 0 , dtype = torch .int64 )
3446
3494
values = torch .empty (0 , dtype = dtype )
3447
3495
sparse_0x0 = torch .sparse_coo_tensor (indices , values , (0 , 0 ))
3448
3496
expected = torch .sparse_coo_tensor (indices , op (values ), (0 , 0 ))
3449
3497
actual = op (sparse_0x0 )
3450
3498
self .assertEqual (expected , actual )
3451
3499
3500
+ @_sparse_unary_ops
3501
+ def test_sparse_zeros (self , device , dtype , op ):
3502
+ samples = op .sample_inputs (device , dtype )
3503
+
3504
+ zero_input = torch .zeros ((), device = device , dtype = dtype )
3505
+ sparse_input = torch .zeros ((), dtype = dtype , device = device ,
3506
+ layout = torch .sparse_coo )
3507
+
3508
+ expect = op (zero_input )
3509
+ actual = op (sparse_input )
3510
+ self .assertEqual (expect , _sparse_to_dense (actual ))
3511
+
3512
+ @ops (sparse_unary_ufuncs , dtypes = OpDTypes .supported ,
3513
+ allowed_dtypes = [torch .double , torch .cdouble ])
3514
+ def test_sparse_fn_grad (self , device , dtype , op ):
3515
+ if not op .supports_autograd :
3516
+ self .skipTest ("Skipped! Op doesn't support autograd" )
3517
+
3518
+ for sample in op .sample_inputs (device , dtype ):
3519
+ sparse_input = sample .input .to_sparse ().detach ().requires_grad_ (True )
3520
+
3521
+ def fn (x ):
3522
+ return _sparse_to_dense (
3523
+ op (x , * sample .args , ** sample .kwargs ))
3524
+
3525
+ self .assertTrue (gradcheck (
3526
+ fn ,
3527
+ (sparse_input ,),
3528
+ check_batched_grad = False ,
3529
+ check_grad_dtypes = True ,
3530
+ check_sparse_nnz = True ,
3531
+ nondet_tol = op .gradcheck_nondet_tol ,
3532
+ fast_mode = op .gradcheck_fast_mode ))
3533
+
3534
+
3452
3535
# e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA
3453
3536
instantiate_device_type_tests (TestSparseUnaryUfuncs , globals (), except_for = 'meta' )
3454
3537
0 commit comments