def _fix_nemotron_h_weights(self) -> None:
    """Dequantize JANG v2 Nemotron-H MoE gate weights after model load.

    JANG v2 stores Nemotron-H weights with different naming and quantized
    gate weights that mlx-lm's nemotron_h.py cannot handle directly.

    The gate weights are nn.Linear in mlx-lm's model skeleton, but JANG
    stores them as quantized uint32. When jang-tools loads with strict=False,
    the gate.weight (uint32) is loaded but gate.scales and gate.biases are
    dropped because nn.Linear doesn't declare them. We must read the
    scales/biases from the original safetensors files to dequantize.

    Best-effort: if the safetensors index is missing, or a tensor listed in
    the index is absent from its shard, this logs a warning and skips rather
    than raising — the model may still be usable without the fixup.
    """
    model_path = Path(self._model_name)
    use_bfloat16 = self._needs_bfloat16()
    target_dtype = mx.bfloat16 if use_bfloat16 else mx.float16

    # Read the shard index to find which files contain gate weights
    index_path = model_path / "model.safetensors.index.json"
    if not index_path.exists():
        # Try consolidated format
        index_path = model_path / "consolidated.safetensors.index.json"
        if not index_path.exists():
            logger.warning("Nemotron-H: no safetensors index found, skipping gate fixup")
            return

    with open(index_path, encoding="utf-8") as f:
        index = json.load(f)
    weight_map = index.get("weight_map", {})

    # Find all gate weight/scales/biases keys in the safetensors index.
    # Group by gate prefix (e.g., "backbone.layers.0.mixer.gate")
    gate_parts: dict[str, dict[str, str]] = {}  # prefix -> {suffix -> shard_file}
    for key, shard in weight_map.items():
        if ".gate." in key:
            split_at = key.index(".gate.")
            prefix = key[:split_at + len(".gate")]
            suffix = key[split_at + len(".gate."):]
            if prefix not in gate_parts:
                gate_parts[prefix] = {}
            gate_parts[prefix][suffix] = shard

    if not gate_parts:
        logger.info("Nemotron-H: no gate weights found in index, skipping")
        return

    # Load gate tensors from safetensors and dequantize
    dequantized_weights: list[tuple[str, mx.array]] = []
    # Cache loaded shards to avoid re-reading (several gates usually share a shard)
    shard_cache: dict[str, dict[str, mx.array]] = {}

    for prefix, parts in gate_parts.items():
        if "weight" not in parts:
            continue
        if "scales" not in parts or "biases" not in parts:
            # Gate is not quantized (no scales/biases), skip
            continue

        # Load the required tensors from safetensors. Guard against an
        # index/shard mismatch (key listed in weight_map but missing from
        # the shard file): warn and skip this gate instead of crashing.
        tensors: dict[str, mx.array] = {}
        for suffix in ("weight", "scales", "biases"):
            full_key = f"{prefix}.{suffix}"
            shard_file = parts[suffix]
            if shard_file not in shard_cache:
                shard_cache[shard_file] = mx.load(str(model_path / shard_file))
            shard_tensors = shard_cache[shard_file]
            if full_key not in shard_tensors:
                logger.warning(
                    f"Nemotron-H: {full_key} listed in index but missing "
                    f"from {shard_file}, skipping this gate"
                )
                break
            tensors[suffix] = shard_tensors[full_key]
        if len(tensors) != 3:
            continue

        gate_weight = tensors["weight"]
        scales = tensors["scales"]
        biases = tensors["biases"]

        # Dequantize by trying bit widths (gate is typically 8-bit CRITICAL tier)
        # NOTE(review): the divisibility test below can be satisfied by more
        # than one candidate width (e.g. a 4-bit/group-32 tensor also passes
        # the 8-bit/group-16 check), so the first candidate (8) effectively
        # always wins. Correct only because JANG gates are 8-bit in practice —
        # reading bits/group_size from the quantization config would be
        # unambiguous; confirm against jang-tools.
        dequantized = None
        for bits in [8, 6, 4, 3, 2]:
            elem_per_u32 = 32 // bits
            real_cols = gate_weight.shape[-1] * elem_per_u32
            gs = real_cols // scales.shape[-1]
            if gs > 0 and gs * scales.shape[-1] == real_cols:
                dequantized = mx.dequantize(
                    gate_weight, scales, biases, gs, bits
                )
                dequantized = dequantized.astype(target_dtype)
                logger.info(
                    f"Nemotron-H: dequantized {prefix}.weight "
                    f"({bits}-bit, group_size={gs}) "
                    f"{gate_weight.shape} -> {dequantized.shape}"
                )
                break

        if dequantized is not None:
            dequantized_weights.append((f"{prefix}.weight", dequantized))
        else:
            logger.warning(
                f"Nemotron-H: could not dequantize {prefix}, "
                f"weight={gate_weight.shape}, scales={scales.shape}"
            )

    # Drop the shard cache before load_weights to keep peak memory down
    del shard_cache

    if dequantized_weights:
        # strict=False: we are only patching the gate weights, not the
        # full parameter set
        self._model.load_weights(dequantized_weights, strict=False)
        logger.info(
            f"Nemotron-H: dequantized {len(dequantized_weights)} "
            f"gate weights to {target_dtype}"
        )
    else:
        logger.info("Nemotron-H: no gate weights needed dequantization")

    logger.info("Nemotron-H: weight fixup complete")
0 commit comments