@@ -5450,7 +5450,18 @@ def test_kernel_image(self, dtype, device):
54505450 def test_kernel_video (self ):
54515451 check_kernel (F .equalize_image , make_video ())
54525452
5453- @pytest .mark .parametrize ("make_input" , [make_image_tensor , make_image_pil , make_image , make_video ])
5453+ @pytest .mark .parametrize (
5454+ "make_input" ,
5455+ [
5456+ make_image_tensor ,
5457+ make_image_pil ,
5458+ make_image ,
5459+ make_video ,
5460+ pytest .param (
5461+ make_image_cvcuda , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" )
5462+ ),
5463+ ],
5464+ )
54545465 def test_functional (self , make_input ):
54555466 check_functional (F .equalize , make_input ())
54565467
@@ -5461,33 +5472,71 @@ def test_functional(self, make_input):
54615472 (F ._color ._equalize_image_pil , PIL .Image .Image ),
54625473 (F .equalize_image , tv_tensors .Image ),
54635474 (F .equalize_video , tv_tensors .Video ),
5475+ pytest .param (
5476+ F ._color ._equalize_cvcuda ,
5477+ "cvcuda.Tensor" ,
5478+ marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" ),
5479+ ),
54645480 ],
54655481 )
54665482 def test_functional_signature (self , kernel , input_type ):
5483+ if input_type == "cvcuda.Tensor" :
5484+ input_type = _import_cvcuda ().Tensor
54675485 check_functional_kernel_signature_match (F .equalize , kernel = kernel , input_type = input_type )
54685486
54695487 @pytest .mark .parametrize (
54705488 "make_input" ,
5471- [make_image_tensor , make_image_pil , make_image , make_video ],
5489+ [
5490+ make_image_tensor ,
5491+ make_image_pil ,
5492+ make_image ,
5493+ make_video ,
5494+ pytest .param (
5495+ make_image_cvcuda , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" )
5496+ ),
5497+ ],
54725498 )
54735499 def test_transform (self , make_input ):
54745500 check_transform (transforms .RandomEqualize (p = 1 ), make_input ())
54755501
54765502 @pytest .mark .parametrize (("low" , "high" ), [(0 , 64 ), (64 , 192 ), (192 , 256 ), (0 , 1 ), (127 , 128 ), (255 , 256 )])
5503+ @pytest .mark .parametrize (
5504+ "tensor_type" ,
5505+ [
5506+ torch .Tensor ,
5507+ pytest .param (
5508+ "cvcuda.Tensor" , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" )
5509+ ),
5510+ ],
5511+ )
54775512 @pytest .mark .parametrize ("fn" , [F .equalize , transform_cls_to_functional (transforms .RandomEqualize , p = 1 )])
5478- def test_image_correctness (self , low , high , fn ):
5513+ def test_image_correctness (self , low , high , tensor_type , fn ):
54795514 # We are not using the default `make_image` here since that uniformly samples the values over the whole value
54805515 # range. Since the whole point of F.equalize is to transform an arbitrary distribution of values into a uniform
54815516 # one over the full range, the information gain is low if we already provide something really close to the
54825517 # expected value.
5483- image = tv_tensors .Image (
5484- torch .testing .make_tensor ((3 , 117 , 253 ), dtype = torch .uint8 , device = "cpu" , low = low , high = high )
5485- )
5518+ shape = (3 , 117 , 253 )
5519+ if tensor_type == "cvcuda.Tensor" :
5520+ shape = (1 , * shape )
5521+ image = tv_tensors .Image (torch .testing .make_tensor (shape , dtype = torch .uint8 , device = "cpu" , low = low , high = high ))
5522+
5523+ if tensor_type == "cvcuda.Tensor" :
5524+ image = F .to_cvcuda_tensor (image )
54865525
54875526 actual = fn (image )
5527+
5528+ if tensor_type == "cvcuda.Tensor" :
5529+ actual = F .cvcuda_to_tensor (actual ).to (device = "cpu" )
5530+ actual = actual .squeeze (0 )
5531+ image = F .cvcuda_to_tensor (image )
5532+ image = image .squeeze (0 )
5533+
54885534 expected = F .to_image (F .equalize (F .to_pil_image (image )))
54895535
5490- assert_equal (actual , expected )
5536+ if tensor_type == "cvcuda.Tensor" :
5537+ torch .testing .assert_close (actual , expected , rtol = 1e-10 , atol = 1 )
5538+ else :
5539+ assert_equal (actual , expected )
54915540
54925541
54935542class TestUniformTemporalSubsample :
0 commit comments