@@ -93,8 +93,8 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
     out_dtype = rv.type.dtype
     out_size = rv.type.shape

-    if op.ndim_supp > 0:
-        out_size = node.outputs[1].type.shape[: -op.ndim_supp]
+    batch_ndim = op.batch_ndim(node)
+    out_size = node.default_output().type.shape[:batch_ndim]

     # If one dimension has unknown size, either the size is determined
     # by a `Shape` operator in which case JAX will compile, or it is
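The change above replaces the `ndim_supp > 0` special case with `op.batch_ndim(node)`. As a hedged illustration (a hypothetical helper, not PyTensor's actual implementation), the convention assumed here is that an RV output's dimensions split into leading batch dimensions followed by `ndim_supp` core (support) dimensions, so the static `size` is just the leading batch part:

    def batch_ndim(out_ndim, ndim_supp):
        # batch dims = everything to the left of the core (support) dims
        return out_ndim - ndim_supp

    out_shape = (2, 3, 4)  # e.g. a (2, 3) batch of draws from a length-4 multivariate normal
    out_size = out_shape[:batch_ndim(len(out_shape), ndim_supp=1)]
    print(out_size)  # (2, 3)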
@@ -106,18 +106,18 @@ def sample_fn(rng, size, dtype, *parameters):
             # PyTensor uses empty size to represent size = None
             if jax.numpy.asarray(size).shape == (0,):
                 size = None
-            return jax_sample_fn(op)(rng, size, out_dtype, *parameters)
+            return jax_sample_fn(op, node=node)(rng, size, out_dtype, *parameters)

     else:

         def sample_fn(rng, size, dtype, *parameters):
-            return jax_sample_fn(op)(rng, out_size, out_dtype, *parameters)
+            return jax_sample_fn(op, node=node)(rng, out_size, out_dtype, *parameters)

     return sample_fn


 @singledispatch
-def jax_sample_fn(op):
+def jax_sample_fn(op, node):
     name = op.name
     raise NotImplementedError(
         f"No JAX implementation for the given distribution: {name}"
@@ -128,7 +128,7 @@ def jax_sample_fn(op):
 @jax_sample_fn.register(ptr.DirichletRV)
 @jax_sample_fn.register(ptr.PoissonRV)
 @jax_sample_fn.register(ptr.MvNormalRV)
-def jax_sample_fn_generic(op):
+def jax_sample_fn_generic(op, node):
     """Generic JAX implementation of random variables."""
     name = op.name
     jax_op = getattr(jax.random, name)
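The generic sampler relies on the op's `name` matching a function in `jax.random`. A minimal sketch of that `getattr` lookup, using `"poisson"` as an assumed example:

    import jax

    key = jax.random.PRNGKey(0)
    jax_op = getattr(jax.random, "poisson")  # same object as jax.random.poisson
    draw = jax_op(key, 3.0, shape=(2,))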
@@ -149,7 +149,7 @@ def sample_fn(rng, size, dtype, *parameters):
 @jax_sample_fn.register(ptr.LogisticRV)
 @jax_sample_fn.register(ptr.NormalRV)
 @jax_sample_fn.register(ptr.StandardNormalRV)
-def jax_sample_fn_loc_scale(op):
+def jax_sample_fn_loc_scale(op, node):
     """JAX implementation of random variables in the loc-scale families.

     JAX only implements the standard version of random variables in the
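The docstring is cut off by the hunk, but the loc-scale trick it describes can be sketched as follows (a minimal example with assumed parameters, not the dispatch code itself): draw from the standard distribution, then shift and scale.

    import jax

    key = jax.random.PRNGKey(0)
    loc, scale = 2.0, 3.0
    sample = loc + jax.random.normal(key, shape=(5,)) * scale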
@@ -174,7 +174,7 @@ def sample_fn(rng, size, dtype, *parameters):


 @jax_sample_fn.register(ptr.BernoulliRV)
-def jax_sample_fn_bernoulli(op):
+def jax_sample_fn_bernoulli(op, node):
     """JAX implementation of `BernoulliRV`."""

     # We need a separate dispatch, because there is no dtype argument for Bernoulli in JAX
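A sketch of the issue the comment raises: `jax.random.bernoulli` has no `dtype` argument and returns booleans, so a cast is presumably applied after sampling (the exact cast is not shown in this hunk).

    import jax

    key = jax.random.PRNGKey(0)
    draw = jax.random.bernoulli(key, 0.7, shape=(4,)).astype("int32")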
@@ -189,7 +189,7 @@ def sample_fn(rng, size, dtype, p):


 @jax_sample_fn.register(ptr.CategoricalRV)
-def jax_sample_fn_categorical(op):
+def jax_sample_fn_categorical(op, node):
     """JAX implementation of `CategoricalRV`."""

     # We need a separate dispatch because Categorical expects logits in JAX
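A sketch of the conversion the comment alludes to: `jax.random.categorical` is parametrized by logits, while PyTensor's `CategoricalRV` takes probabilities, so a log is presumably applied (assumed here, not shown in the hunk).

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    p = jnp.array([0.2, 0.5, 0.3])
    draw = jax.random.categorical(key, logits=jnp.log(p))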
@@ -208,7 +208,7 @@ def sample_fn(rng, size, dtype, p):
 @jax_sample_fn.register(ptr.RandIntRV)
 @jax_sample_fn.register(ptr.IntegersRV)
 @jax_sample_fn.register(ptr.UniformRV)
-def jax_sample_fn_uniform(op):
+def jax_sample_fn_uniform(op, node):
     """JAX implementation of random variables with uniform density.

     We need to pass the arguments as keyword arguments since the order
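A sketch of why keyword arguments matter here: `jax.random.uniform` takes its bounds as `minval`/`maxval` after `key`, `shape`, and `dtype`, so forwarding PyTensor's `(low, high)` positionally would misalign them.

    import jax

    key = jax.random.PRNGKey(0)
    draw = jax.random.uniform(key, shape=(4,), minval=-1.0, maxval=1.0)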
@@ -236,7 +236,7 @@ def sample_fn(rng, size, dtype, *parameters):

 @jax_sample_fn.register(ptr.ParetoRV)
 @jax_sample_fn.register(ptr.GammaRV)
-def jax_sample_fn_shape_scale(op):
+def jax_sample_fn_shape_scale(op, node):
     """JAX implementation of random variables in the shape-scale family.

     JAX only implements the standard version of random variables in the
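The shape-scale standardization the docstring describes, sketched with assumed parameters: `jax.random.gamma` only samples `Gamma(shape, 1)`, so the scale is applied by multiplication afterwards.

    import jax

    key = jax.random.PRNGKey(0)
    shape_param, scale = 2.0, 3.0
    draw = jax.random.gamma(key, shape_param, shape=(5,)) * scale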
@@ -259,7 +259,7 @@ def sample_fn(rng, size, dtype, shape, scale):


 @jax_sample_fn.register(ptr.ExponentialRV)
-def jax_sample_fn_exponential(op):
+def jax_sample_fn_exponential(op, node):
     """JAX implementation of `ExponentialRV`."""

     def sample_fn(rng, size, dtype, scale):
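`jax.random.exponential` draws with rate 1, so the `scale` parameter is presumably applied by multiplication, roughly (assumed sketch, body of the hunk not shown):

    import jax

    key = jax.random.PRNGKey(0)
    scale = 2.0
    draw = jax.random.exponential(key, shape=(5,)) * scale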
@@ -275,7 +275,7 @@ def sample_fn(rng, size, dtype, scale):


 @jax_sample_fn.register(ptr.StudentTRV)
-def jax_sample_fn_t(op):
+def jax_sample_fn_t(op, node):
     """JAX implementation of `StudentTRV`."""

     def sample_fn(rng, size, dtype, df, loc, scale):
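Presumably the Student-t draw is standardized the same way as the loc-scale families, along the lines of this assumed sketch:

    import jax

    key = jax.random.PRNGKey(0)
    df, loc, scale = 4.0, 1.0, 2.0
    draw = loc + jax.random.t(key, df, shape=(5,)) * scale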
@@ -290,38 +290,119 @@ def sample_fn(rng, size, dtype, df, loc, scale):
     return sample_fn


-@jax_sample_fn.register(ptr.ChoiceRV)
-def jax_funcify_choice(op):
+@jax_sample_fn.register(ptr.ChoiceWithoutReplacement)
+def jax_funcify_choice(op, node):
     """JAX implementation of `ChoiceRV`."""

+    batch_ndim = op.batch_ndim(node)
+    a, *p, core_shape = node.inputs[3:]
+    a_core_ndim, *p_core_ndim, _ = op.ndims_params
+
+    if batch_ndim and a_core_ndim == 0:
+        raise NotImplementedError(
+            "Batch dimensions are not supported for 0d arrays. "
+            "A default JAX rewrite should have materialized the implicit arange"
+        )
+
+    a_batch_ndim = a.type.ndim - a_core_ndim
+    if op.has_p_param:
+        [p] = p
+        [p_core_ndim] = p_core_ndim
+        p_batch_ndim = p.type.ndim - p_core_ndim
+
     def sample_fn(rng, size, dtype, *parameters):
         rng_key = rng["jax_state"]
         rng_key, sampling_key = jax.random.split(rng_key, 2)
-        (a, p, replace) = parameters
-        smpl_value = jax.random.choice(sampling_key, a, size, replace, p)
+
+        if op.has_p_param:
+            a, p, core_shape = parameters
+        else:
+            a, core_shape = parameters
+            p = None
+        core_shape = tuple(np.asarray(core_shape))
+
+        if batch_ndim == 0:
+            sample = jax.random.choice(
+                sampling_key, a, shape=core_shape, replace=False, p=p
+            )
+
+        else:
+            if size is None:
+                if p is None:
+                    size = a.shape[:a_batch_ndim]
+                else:
+                    size = jax.numpy.broadcast_shapes(
+                        a.shape[:a_batch_ndim],
+                        p.shape[:p_batch_ndim],
+                    )
+
+            a = jax.numpy.broadcast_to(a, size + a.shape[a_batch_ndim:])
+            if p is not None:
+                p = jax.numpy.broadcast_to(p, size + p.shape[p_batch_ndim:])
+
+            batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))
+
+            # Ravel the batch dimensions because vmap only works along a single axis
+            raveled_batch_a = a.reshape((-1,) + a.shape[batch_ndim:])
+            if p is None:
+                raveled_sample = jax.vmap(
+                    lambda key, a: jax.random.choice(
+                        key, a, shape=core_shape, replace=False, p=None
+                    )
+                )(batch_sampling_keys, raveled_batch_a)
+            else:
+                raveled_batch_p = p.reshape((-1,) + p.shape[batch_ndim:])
+                raveled_sample = jax.vmap(
+                    lambda key, a, p: jax.random.choice(
+                        key, a, shape=core_shape, replace=False, p=p
+                    )
+                )(batch_sampling_keys, raveled_batch_a, raveled_batch_p)
+
+            # Reshape the batch dimensions
+            sample = raveled_sample.reshape(size + raveled_sample.shape[1:])
+
         rng["jax_state"] = rng_key
-        return (rng, smpl_value)
+        return (rng, sample)

     return sample_fn


 @jax_sample_fn.register(ptr.PermutationRV)
-def jax_sample_fn_permutation(op):
+def jax_sample_fn_permutation(op, node):
     """JAX implementation of `PermutationRV`."""

+    batch_ndim = op.batch_ndim(node)
+    x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0]
+
     def sample_fn(rng, size, dtype, *parameters):
         rng_key = rng["jax_state"]
         rng_key, sampling_key = jax.random.split(rng_key, 2)
         (x,) = parameters
-        sample = jax.random.permutation(sampling_key, x)
+        if batch_ndim:
+            # jax.random.permutation has no concept of batch dims
+            x_core_shape = x.shape[x_batch_ndim:]
+            if size is None:
+                size = x.shape[:x_batch_ndim]
+            else:
+                x = jax.numpy.broadcast_to(x, size + x_core_shape)
+
+            batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))
+            raveled_batch_x = x.reshape((-1,) + x.shape[batch_ndim:])
+            raveled_sample = jax.vmap(lambda key, x: jax.random.permutation(key, x))(
+                batch_sampling_keys, raveled_batch_x
+            )
+            sample = raveled_sample.reshape(size + raveled_sample.shape[1:])
+        else:
+            sample = jax.random.permutation(sampling_key, x)
+
         rng["jax_state"] = rng_key
         return (rng, sample)

     return sample_fn


 @jax_sample_fn.register(ptr.BinomialRV)
-def jax_sample_fn_binomial(op):
+def jax_sample_fn_binomial(op, node):
     if not numpyro_available:
         raise NotImplementedError(
             f"No JAX implementation for the given distribution: {op.name}. "
@@ -344,7 +425,7 @@ def sample_fn(rng, size, dtype, n, p):
344425
345426
346427@jax_sample_fn .register (ptr .MultinomialRV )
347- def jax_sample_fn_multinomial (op ):
428+ def jax_sample_fn_multinomial (op , node ):
348429 if not numpyro_available :
349430 raise NotImplementedError (
350431 f"No JAX implementation for the given distribution: { op .name } . "
@@ -367,7 +448,7 @@ def sample_fn(rng, size, dtype, n, p):
367448
368449
369450@jax_sample_fn .register (ptr .VonMisesRV )
370- def jax_sample_fn_vonmises (op ):
451+ def jax_sample_fn_vonmises (op , node ):
371452 if not numpyro_available :
372453 raise NotImplementedError (
373454 f"No JAX implementation for the given distribution: { op .name } . "