|
17 | 17 |
|
18 | 18 |
|
19 | 19 | __all__ = [
|
| 20 | + "bessel_j0", |
| 21 | + "bessel_j1", |
20 | 22 | "i0e",
|
21 | 23 | "i1",
|
22 | 24 | "i1e",
|
23 | 25 | "logit",
|
24 | 26 | "multigammaln",
|
| 27 | + "spherical_bessel_j0", |
25 | 28 | "zeta",
|
26 | 29 | ]
|
27 | 30 |
|
28 | 31 |
|
| 32 | +@_make_elementwise_unary_reference( |
| 33 | + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, |
| 34 | + aten_op=torch.ops.aten.special_bessel_j0, |
| 35 | +) |
| 36 | +def bessel_j0(a: TensorLikeType) -> TensorLikeType: |
| 37 | + return prims.bessel_j0(a) |
| 38 | + |
| 39 | + |
| 40 | +@_make_elementwise_unary_reference( |
| 41 | + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, |
| 42 | + aten_op=torch.ops.aten.special_bessel_j1, |
| 43 | +) |
| 44 | +def bessel_j1(a: TensorLikeType) -> TensorLikeType: |
| 45 | + return prims.bessel_j1(a) |
| 46 | + |
| 47 | + |
29 | 48 | @_make_elementwise_unary_reference(
|
30 | 49 | ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=torch.ops.aten.special_i0e
|
31 | 50 | )
|
32 |
| -def i0e(a): |
| 51 | +def i0e(a: TensorLikeType) -> TensorLikeType: |
33 | 52 | return prims.bessel_i0e(a)
|
34 | 53 |
|
35 | 54 |
|
36 | 55 | @_make_elementwise_unary_reference(
|
37 | 56 | ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=torch.ops.aten.special_i1
|
38 | 57 | )
|
39 |
| -def i1(a): |
| 58 | +def i1(a: TensorLikeType) -> TensorLikeType: |
40 | 59 | return prims.bessel_i1(a)
|
41 | 60 |
|
42 | 61 |
|
43 | 62 | @_make_elementwise_unary_reference(
|
44 | 63 | ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=torch.ops.aten.special_i1e
|
45 | 64 | )
|
46 |
| -def i1e(a): |
| 65 | +def i1e(a: TensorLikeType) -> TensorLikeType: |
47 | 66 | return prims.bessel_i1e(a)
|
48 | 67 |
|
49 | 68 |
|
@@ -74,6 +93,14 @@ def multigammaln(a: TensorLikeType, p: int) -> TensorLikeType:
|
74 | 93 | return torch.sum(torch.lgamma(a.unsqueeze(-1) + b), dim=-1) + c
|
75 | 94 |
|
76 | 95 |
|
| 96 | +@_make_elementwise_unary_reference( |
| 97 | + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, |
| 98 | + aten_op=torch.ops.aten.special_spherical_bessel_j0, |
| 99 | +) |
| 100 | +def spherical_bessel_j0(a: TensorLikeType) -> TensorLikeType: |
| 101 | + return prims.spherical_bessel_j0(a) |
| 102 | + |
| 103 | + |
77 | 104 | zeta = _make_elementwise_binary_reference(
|
78 | 105 | prims.zeta, # type: ignore[has-type]
|
79 | 106 | type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
|
|
0 commit comments