def _distribute_initializer(
    init_func=None, mean=0.0, stddev=1.0, seed=None, layout=None
):
    """Distribution-aware random initializer for the JAX backend.

    Builds a JAX random array with `init_func` and shards it according to
    `layout` via `jax.jit(..., out_shardings=...)`, so each device
    materializes only its own shard (avoids OOM for large tensors such as
    token embeddings).

    Args:
        init_func: A `functools.partial` wrapping one of `jax.random.normal`,
            `jax.random.truncated_normal` or `jax.random.uniform`, with
            `shape` and `dtype` already bound. It is called with a JAX PRNG
            key as its only positional argument.
        mean: Mean applied to `normal` / `truncated_normal` samples.
        stddev: Standard deviation applied to `normal` / `truncated_normal`
            samples.
        seed: An `int`, a JAX array, or a `keras.random.SeedGenerator` used
            to derive the PRNG key.
        layout: `TensorLayout` describing the desired sharding, or `None`
            for single-device placement.

    Returns:
        A (possibly sharded) `jax.Array`.

    Raises:
        TypeError: If `init_func` is not a `functools.partial` object.
        ValueError: If `init_func` is `None` or wraps an unsupported
            function, or if `seed` has an unsupported type.
    """
    import warnings
    from functools import partial

    # Validate init_func before touching its attributes: a None or
    # non-partial object has no usable `.func`, and building the error
    # message from `init_func.func.__name__` would raise AttributeError.
    if init_func is None:
        raise ValueError(
            "init_func cannot be None. Pass a functools.partial wrapping a "
            "jax.random function with shape and dtype bound."
        )
    if not isinstance(init_func, partial):
        raise TypeError(
            "init_func must be a functools.partial object wrapping a "
            f"jax.random.* function with shape and dtype bound, got "
            f"{type(init_func)}"
        )
    # Compare against the actual jax.random callables (not just __name__) so
    # that e.g. numpy.random.normal cannot slip through validation and then
    # fall through the dispatch below.
    supported_funcs = (
        jax.random.normal,
        jax.random.truncated_normal,
        jax.random.uniform,
    )
    if init_func.func not in supported_funcs:
        func_name = getattr(init_func.func, "__name__", repr(init_func.func))
        raise ValueError(
            f"Unsupported initializer: {func_name}. Only JAX-compatible "
            "random initializers are supported. Supported jax.random "
            "funcs: normal, truncated_normal, uniform"
        )

    # Normalize `seed` to a SeedGenerator so the backend seed-state
    # variable exists (future hook for distributed key tracking).
    if isinstance(seed, jax.Array):
        seed_gen = seed_generator.SeedGenerator(seed=int(seed[0]))
    elif isinstance(seed, int):
        seed_gen = seed_generator.SeedGenerator(seed=seed)
    elif isinstance(seed, seed_generator.SeedGenerator):
        seed_gen = seed
    else:
        raise ValueError(
            f"seed must be int, JAX array, or SeedGenerator, got {type(seed)}"
        )

    # Keras stores seed state as [seed, counter]; JAX PRNG keys are
    # uint32[2] with the counter word first, so swap the two entries.
    state = seed_gen.state.value
    jax_key = jax.numpy.array(
        [state[1], state[0]], dtype=jax.numpy.uint32
    )

    # Resolve the output sharding from the tensor layout.
    if layout is None:
        warnings.warn(
            f"The layout is {layout}, sharding will default to single device"
        )
        sharding = None
    else:
        sharding = _to_backend_layout(layout)

    # JAX PRNG key handling within JIT: jax.random.* functions are
    # functional and JIT-compatible; with out_shardings set, JAX
    # initializes each shard directly on its own device.
    def _run_jitted(key, out_shardings):
        # One-line helper: jit-compile init_func with the requested
        # sharding and execute it with the given PRNG key.
        return jax.jit(init_func, out_shardings=out_shardings)(key)

    try:
        sample = _run_jitted(jax_key, sharding)
    except RuntimeError as e:
        # Best-effort fallback: keep behavior on single device rather
        # than failing hard when the mesh/sharding is unusable.
        warnings.warn(
            f"Sharding failed due to: {e}, falling back to single device"
        )
        sample = _run_jitted(jax_key, None)

    # Advance the generator state so repeated calls don't reuse the same
    # key (the returned value itself is not needed here).
    seed_gen.next()

    # Apply mean/stddev only for distributions where it makes sense.
    if init_func.func in (jax.random.normal, jax.random.truncated_normal):
        return sample * stddev + mean
    # jax.random.uniform ignores mean/stddev — warn if the caller set them.
    if mean != 0.0 or stddev != 1.0:
        warnings.warn("mean and stddev are ignored for uniform distribution")
    return sample