Skip to content

Commit d6b0308

Browse files
khushi-411 authored and pytorchmergebot committed
[primTorch] special: j0, j1, spherical_j0 (pytorch#86049)
Adds prims and refs for special functions (bessel_j0, bessel_j1, spherical_bessel_j0). Thanks! Pull Request resolved: pytorch#86049 Approved by: https://github.com/mruberry
1 parent 8bce2f3 commit d6b0308

File tree

5 files changed

+74
-6
lines changed

5 files changed

+74
-6
lines changed

docs/source/special.rst

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ Functions
1313
-----------------------
1414

1515
.. autofunction:: airy_ai
16+
.. autofunction:: bessel_j0
17+
.. autofunction:: bessel_j1
1618
.. autofunction:: digamma
1719
.. autofunction:: entr
1820
.. autofunction:: erf

test/test_proxy_tensor.py

-3
Original file line numberDiff line numberDiff line change
@@ -1271,8 +1271,6 @@ def f(a, b, c, d, e):
12711271
xfail('slice_scatter', ''), # aten.slice_scatter.default - couldn't find symbolic meta function/decomposition
12721272
xfail('sort', ''), # aten.sort.default - couldn't find symbolic meta function/decomposition
12731273
xfail('special.airy_ai', ''), # aten.special_airy_ai.default - couldn't find symbolic meta function/decomposition
1274-
xfail('special.bessel_j0', ''), # aten.special_bessel_j0.default - couldn't find symbolic meta function/decomposition
1275-
xfail('special.bessel_j1', ''), # aten.special_bessel_j1.default - couldn't find symbolic meta function/decomposition
12761274
xfail('special.bessel_y0', ''), # aten.special_bessel_y0.default - couldn't find symbolic meta function/decomposition
12771275
xfail('special.bessel_y1', ''), # aten.special_bessel_y1.default - couldn't find symbolic meta function/decomposition
12781276
xfail('special.chebyshev_polynomial_t', ''), # aten.special_chebyshev_polynomial_t.default - couldn't find symbolic me...
@@ -1291,7 +1289,6 @@ def f(a, b, c, d, e):
12911289
xfail('special.polygamma', 'special_polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic meta function/...
12921290
xfail('special.scaled_modified_bessel_k0', ''), # aten.special_scaled_modified_bessel_k0.default - couldn't find symbo...
12931291
xfail('special.scaled_modified_bessel_k1', ''), # aten.special_scaled_modified_bessel_k1.default - couldn't find symbo...
1294-
xfail('special.spherical_bessel_j0', ''), # aten.special_spherical_bessel_j0.default - couldn't find symbolic meta fun...
12951292
xfail('special.xlog1py', ''), # aten.special_xlog1py.default - couldn't find symbolic meta function/decomposition
12961293
xfail('split', ''), # 'torch._C.SymIntNode' and 'int'
12971294
xfail('split', 'list_args'), # aten.size.default - couldn't find symbolic meta function/decomposition

torch/_prims/__init__.py

+24
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@
6161
"bessel_i0e",
6262
"bessel_i1",
6363
"bessel_i1e",
64+
"bessel_j0",
65+
"bessel_j1",
6466
"bitwise_not",
6567
"cbrt",
6668
"ceil",
@@ -89,6 +91,7 @@
8991
"signbit",
9092
"sin",
9193
"sinh",
94+
"spherical_bessel_j0",
9295
"sqrt",
9396
"tan",
9497
"tanh",
@@ -497,6 +500,20 @@ def _not_impl(*args, **kwargs):
497500
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
498501
)
499502

503+
bessel_j0 = _make_elementwise_unary_prim(
504+
"bessel_j0",
505+
impl_aten=torch.special.bessel_j0,
506+
doc="",
507+
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
508+
)
509+
510+
bessel_j1 = _make_elementwise_unary_prim(
511+
"bessel_j1",
512+
impl_aten=torch.special.bessel_j1,
513+
doc="",
514+
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
515+
)
516+
500517
bessel_i0 = _make_elementwise_unary_prim(
501518
"bessel_i0",
502519
impl_aten=torch.i0,
@@ -778,6 +795,13 @@ def _fill_aten(a: Tensor, value: NumberType) -> Tensor:
778795
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
779796
)
780797

798+
spherical_bessel_j0 = _make_elementwise_unary_prim(
799+
"spherical_bessel_j0",
800+
impl_aten=torch.special.spherical_bessel_j0,
801+
doc="",
802+
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
803+
)
804+
781805
sqrt = _make_elementwise_unary_prim(
782806
"sqrt",
783807
impl_aten=torch.sqrt,

torch/_refs/special/__init__.py

+30-3
Original file line numberDiff line numberDiff line change
@@ -17,33 +17,52 @@
1717

1818

1919
__all__ = [
20+
"bessel_j0",
21+
"bessel_j1",
2022
"i0e",
2123
"i1",
2224
"i1e",
2325
"logit",
2426
"multigammaln",
27+
"spherical_bessel_j0",
2528
"zeta",
2629
]
2730

2831

32+
@_make_elementwise_unary_reference(
33+
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
34+
aten_op=torch.ops.aten.special_bessel_j0,
35+
)
36+
def bessel_j0(a: TensorLikeType) -> TensorLikeType:
37+
return prims.bessel_j0(a)
38+
39+
40+
@_make_elementwise_unary_reference(
41+
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
42+
aten_op=torch.ops.aten.special_bessel_j1,
43+
)
44+
def bessel_j1(a: TensorLikeType) -> TensorLikeType:
45+
return prims.bessel_j1(a)
46+
47+
2948
@_make_elementwise_unary_reference(
3049
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=torch.ops.aten.special_i0e
3150
)
32-
def i0e(a):
51+
def i0e(a: TensorLikeType) -> TensorLikeType:
3352
return prims.bessel_i0e(a)
3453

3554

3655
@_make_elementwise_unary_reference(
3756
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=torch.ops.aten.special_i1
3857
)
39-
def i1(a):
58+
def i1(a: TensorLikeType) -> TensorLikeType:
4059
return prims.bessel_i1(a)
4160

4261

4362
@_make_elementwise_unary_reference(
4463
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=torch.ops.aten.special_i1e
4564
)
46-
def i1e(a):
65+
def i1e(a: TensorLikeType) -> TensorLikeType:
4766
return prims.bessel_i1e(a)
4867

4968

@@ -74,6 +93,14 @@ def multigammaln(a: TensorLikeType, p: int) -> TensorLikeType:
7493
return torch.sum(torch.lgamma(a.unsqueeze(-1) + b), dim=-1) + c
7594

7695

96+
@_make_elementwise_unary_reference(
97+
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
98+
aten_op=torch.ops.aten.special_spherical_bessel_j0,
99+
)
100+
def spherical_bessel_j0(a: TensorLikeType) -> TensorLikeType:
101+
return prims.spherical_bessel_j0(a)
102+
103+
77104
zeta = _make_elementwise_binary_reference(
78105
prims.zeta, # type: ignore[has-type]
79106
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,

torch/testing/_internal/opinfo/definitions/special.py

+18
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,18 @@ def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs):
645645
#
646646
# Elementwise Unary Special OpInfos
647647
#
648+
ElementwiseUnaryPythonRefInfo(
649+
"_refs.special.bessel_j0",
650+
torch_opinfo_name="special.bessel_j0",
651+
supports_nvfuser=False,
652+
op_db=op_db,
653+
),
654+
ElementwiseUnaryPythonRefInfo(
655+
"_refs.special.bessel_j1",
656+
torch_opinfo_name="special.bessel_j1",
657+
supports_nvfuser=False,
658+
op_db=op_db,
659+
),
648660
ElementwiseUnaryPythonRefInfo(
649661
"_refs.special.i0e",
650662
torch_opinfo_name="special.i0e",
@@ -663,6 +675,12 @@ def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs):
663675
supports_nvfuser=False,
664676
op_db=op_db,
665677
),
678+
ElementwiseUnaryPythonRefInfo(
679+
"_refs.special.spherical_bessel_j0",
680+
torch_opinfo_name="special.spherical_bessel_j0",
681+
supports_nvfuser=False,
682+
op_db=op_db,
683+
),
666684
#
667685
# Elementwise Binary Special OpInfos
668686
#

0 commit comments

Comments (0)