From dfc6ee638424d0d03fba76046d57318094410440 Mon Sep 17 00:00:00 2001 From: Nolan Miller Date: Mon, 29 Apr 2024 13:56:18 -0700 Subject: [PATCH] Fixes monkey patching attribute error triggered by cl/629129623. PiperOrigin-RevId: 629177768 --- oryx/core/interpreters/inverse/rules.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/oryx/core/interpreters/inverse/rules.py b/oryx/core/interpreters/inverse/rules.py index 2071fd9..9aade29 100644 --- a/oryx/core/interpreters/inverse/rules.py +++ b/oryx/core/interpreters/inverse/rules.py @@ -258,10 +258,10 @@ def convert_element_type_ildj(incells, outcells, *, new_dtype, **params): jax.scipy.special.logit = custom_inverse(jax.scipy.special.logit) jax.nn.sigmoid = jax.scipy.special.expit jax.nn.softplus = custom_inverse(jax.nn.softplus) -jax.scipy.special.expit.def_inverse_unary(f_inv=jax.scipy.special.logit, - f_ildj=expit_ildj) -jax.scipy.special.logit.def_inverse_unary(f_inv=jax.scipy.special.expit, - f_ildj=logit_ildj) -jax.nn.softplus.def_inverse_unary(f_inv=softplus_inv, - f_ildj=softplus_ildj) - +jax.scipy.special.expit.def_inverse_unary( + f_inv=jax.scipy.special.logit, f_ildj=expit_ildj +) # pytype: disable=attribute-error +jax.scipy.special.logit.def_inverse_unary( + f_inv=jax.scipy.special.expit, f_ildj=logit_ildj +) # pytype: disable=attribute-error +jax.nn.softplus.def_inverse_unary(f_inv=softplus_inv, f_ildj=softplus_ildj) # pytype: disable=attribute-error