diff --git a/keras/src/backend/jax/distribution_lib.py b/keras/src/backend/jax/distribution_lib.py
index 6b5bf37314c..6f3cf0b7d4c 100644
--- a/keras/src/backend/jax/distribution_lib.py
+++ b/keras/src/backend/jax/distribution_lib.py
@@ -246,3 +246,116 @@ def _to_backend_layout(tensor_layout):
     partition_spec = jax.sharding.PartitionSpec(*tensor_layout.axes)
     jax_mesh = tensor_layout.device_mesh.backend_mesh
     return jax.sharding.NamedSharding(jax_mesh, partition_spec)
+
+
+def _distribute_initializer(
+    init_func=None, mean=0.0, stddev=1.0, seed=None, layout=None
+):
+    """Distribution-aware random initializer for the JAX backend.
+
+    Creates a JAX random array via `init_func` and distributes it
+    according to the given tensor `layout`.
+
+    Args:
+        init_func: A `functools.partial`-wrapped `jax.random` function
+            taking a PRNG key as its only positional argument. Shape
+            and dtype must already be bound via `partial`. Supported:
+            `normal`, `truncated_normal`, `uniform`.
+        mean: Mean of the distribution (applied to normal and
+            truncated_normal samples).
+        stddev: Standard deviation of the distribution.
+        seed: Random seed: an int, a JAX array, or a
+            `seed_generator.SeedGenerator`.
+        layout: `TensorLayout` for the distributed tensor, or None for
+            single-device placement.
+
+    Returns:
+        A (possibly sharded) distributed `jax.Array`.
+
+    Raises:
+        TypeError: If `init_func` is not a `functools.partial` object.
+        ValueError: If `init_func` is None or wraps an unsupported
+            random function, or if `seed` has an unsupported type.
+    """
+    import warnings
+    from functools import partial
+
+    # Validate init_func before touching its attributes: a plain
+    # function has no `.func`, so the partial check must come first.
+    if init_func is None:
+        raise ValueError(
+            "init_func cannot be None. Pass a functools.partial-wrapped "
+            "jax.random function (normal, truncated_normal, uniform)."
+        )
+    if not isinstance(init_func, partial):
+        raise TypeError(
+            "init_func must be a functools.partial object, got "
+            f"{type(init_func)}. Wrap a jax.random.* function with "
+            "shape and dtype bound via partial."
+        )
+    if init_func.func.__name__ not in (
+        "normal",
+        "truncated_normal",
+        "uniform",
+    ):
+        raise ValueError(
+            f"Unsupported initializer: {init_func.func.__name__}. Only "
+            "JAX-compatible random initializers are supported. Supported "
+            "jax.random funcs: normal, truncated_normal, uniform"
+        )
+
+    # Create a SeedGenerator to ensure the backend variable exists.
+    # Future state tracking for distributed keys (base/split keys,
+    # number of devices sharded) can hang off this object.
+    if isinstance(seed, jax.Array):
+        seed_gen = seed_generator.SeedGenerator(seed=int(seed[0]))
+    elif isinstance(seed, int):
+        seed_gen = seed_generator.SeedGenerator(seed=seed)
+    elif isinstance(seed, seed_generator.SeedGenerator):
+        seed_gen = seed
+    else:
+        raise ValueError(
+            f"seed must be int, JAX array, or SeedGenerator, got {type(seed)}"
+        )
+
+    # Extract the state value as a JAX array and convert it to the JAX
+    # PRNG key format (swap counter and seed value).
+    jax_seed = seed_gen.state.value
+    jax_compatible_seed = jax.numpy.array(
+        [jax_seed[1], jax_seed[0]], dtype=jax.numpy.uint32
+    )
+
+    # Shard based on the tensor layout; no layout means single device.
+    if layout is None:
+        warnings.warn(
+            f"The layout is {layout}, sharding will default to single device"
+        )
+        sharding = None
+    else:
+        sharding = _to_backend_layout(layout)
+
+    # The PRNG key is passed directly to the jax.random.* function,
+    # which is JIT-compatible and functional. JAX ensures different
+    # random values per shard when out_shardings is specified.
+    try:
+        compiled_init = jax.jit(init_func, out_shardings=sharding)
+        sample = compiled_init(jax_compatible_seed)
+    except RuntimeError as e:
+        warnings.warn(
+            f"Sharding failed due to: {e}, falling back to single device"
+        )
+        compiled_init = jax.jit(init_func, out_shardings=None)
+        sample = compiled_init(jax_compatible_seed)
+
+    # Advance the generator so repeated calls draw fresh keys.
+    seed_gen.next()
+
+    # Apply mean/stddev only where the distribution supports them.
+    if init_func.func in (jax.random.normal, jax.random.truncated_normal):
+        return sample * stddev + mean
+    if init_func.func == jax.random.uniform and (
+        mean != 0.0 or stddev != 1.0
+    ):
+        # Uniform does not use mean/stddev; tell the caller.
+        warnings.warn("mean and stddev are ignored for uniform distribution")
+    return sample