diff --git a/oryx/core/interpreters/inverse/rules.py b/oryx/core/interpreters/inverse/rules.py index 2071fd9..9aade29 100644 --- a/oryx/core/interpreters/inverse/rules.py +++ b/oryx/core/interpreters/inverse/rules.py @@ -258,10 +258,10 @@ def convert_element_type_ildj(incells, outcells, *, new_dtype, **params): jax.scipy.special.logit = custom_inverse(jax.scipy.special.logit) jax.nn.sigmoid = jax.scipy.special.expit jax.nn.softplus = custom_inverse(jax.nn.softplus) -jax.scipy.special.expit.def_inverse_unary(f_inv=jax.scipy.special.logit, - f_ildj=expit_ildj) -jax.scipy.special.logit.def_inverse_unary(f_inv=jax.scipy.special.expit, - f_ildj=logit_ildj) -jax.nn.softplus.def_inverse_unary(f_inv=softplus_inv, - f_ildj=softplus_ildj) - +jax.scipy.special.expit.def_inverse_unary( + f_inv=jax.scipy.special.logit, f_ildj=expit_ildj +) # pytype: disable=attribute-error +jax.scipy.special.logit.def_inverse_unary( + f_inv=jax.scipy.special.expit, f_ildj=logit_ildj +) # pytype: disable=attribute-error +jax.nn.softplus.def_inverse_unary(f_inv=softplus_inv, f_ildj=softplus_ildj) # pytype: disable=attribute-error