diff --git a/src/pyhf/tensor/jax_backend.py b/src/pyhf/tensor/jax_backend.py index e616302446..f4ac8dba10 100644 --- a/src/pyhf/tensor/jax_backend.py +++ b/src/pyhf/tensor/jax_backend.py @@ -165,6 +165,8 @@ def tile(self, tensor_in, repeats): Returns: JAX ndarray: The tensor with repeated axes """ + if not isinstance(tensor_in, jnp.ndarray): + tensor_in = jnp.array(tensor_in) return jnp.tile(tensor_in, repeats) def conditional(self, predicate, true_callable, false_callable):