I have a model domain that is a nested dict of arrays. I want to mark some of the arrays as constant for `optimize_kl`, so I pass a dictionary with the same structure but with booleans instead of arrays as leaves. This raises a TypeError: `primals` and `point_estimates` pytree structre do no match. Interestingly, if I have a flat parameter dictionary and pass a tuple of strings for `constants`, it works fine. A minimal reproduction script is attached below.
```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jax.random as jr
import nifty.re as jft

rng = jr.key(123)


class Model(jft.Model):
    def __init__(self):
        domain = {
            "fruit": jax.ShapeDtypeStruct((10,), jnp.float64),
            "vegetables": jax.ShapeDtypeStruct((10,), jnp.float64),
        }
        super().__init__(domain=domain)

    def __call__(self, x):
        return x["fruit"] + x["vegetables"]


data = jnp.ones((10,))
likelihood = jft.Gaussian(data).amend(Model())

rng, k_i, k_o = jr.split(rng, 3)
samples, state = jft.optimize_kl(
    likelihood,
    jft.Vector(likelihood.init(k_i)),
    # constants=("fruit",),  # this works
    constants={"fruit": True, "vegetables": False},  # this fails with the error above
    n_total_iterations=1,
    n_samples=2,
    key=k_o,
    sample_mode="linear_resample",
)
```