Skip to content

Commit a481035

Browse files
committed
feat: jang implementation, fix jang nemotron weights, add jang to venvstacks
1 parent 517806f commit a481035

File tree

2 files changed

+86
-110
lines changed

omlx/engine/jang.py

Lines changed: 84 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -201,133 +201,107 @@ def _fix_nemotron_h_weights(self) -> None:
201201
202202
JANG v2 stores Nemotron-H weights with different naming and quantized
203203
gate weights that mlx-lm's nemotron_h.py cannot handle directly.
204-
This applies three fixes:
205204
206-
1. Rename switch_mlp.up_proj/down_proj -> switch_mlp.fc1/fc2
207-
2. Dequantize MoE gate weights (stored as quantized uint32, but
208-
mlx-lm expects nn.Linear with float weights)
209-
3. Drop mtp.* keys (multi-token prediction, unused at inference)
205+
The gate weights are nn.Linear in mlx-lm's model skeleton, but JANG
206+
stores them as quantized uint32. When jang-tools loads with strict=False,
207+
the gate.weight (uint32) is loaded but gate.scales and gate.biases are
208+
dropped because nn.Linear doesn't declare them. We must read the
209+
scales/biases from the original safetensors files to dequantize.
210210
"""
211-
import mlx.nn as nn
212-
from mlx.utils import tree_flatten
213-
211+
model_path = Path(self._model_name)
214212
use_bfloat16 = self._needs_bfloat16()
215213
target_dtype = mx.bfloat16 if use_bfloat16 else mx.float16
216214

217-
# Flatten current weights
218-
weights = dict(tree_flatten(self._model.parameters()))
215+
# Read the shard index to find which files contain gate weights
216+
index_path = model_path / "model.safetensors.index.json"
217+
if not index_path.exists():
218+
# Try consolidated format
219+
index_path = model_path / "consolidated.safetensors.index.json"
220+
if not index_path.exists():
221+
logger.warning("Nemotron-H: no safetensors index found, skipping gate fixup")
222+
return
219223

220-
# --- 1. Rename up_proj/down_proj -> fc1/fc2 ---
221-
renames = {
222-
"switch_mlp.up_proj": "switch_mlp.fc1",
223-
"switch_mlp.down_proj": "switch_mlp.fc2",
224-
}
225-
renamed = {}
226-
rename_count = 0
227-
for k, v in weights.items():
228-
new_k = k
229-
for old, new in renames.items():
230-
if old in k:
231-
new_k = k.replace(old, new)
232-
rename_count += 1
233-
break
234-
renamed[new_k] = v
235-
weights = renamed
236-
if rename_count > 0:
237-
logger.info(f"Nemotron-H: renamed {rename_count} switch_mlp weight keys (up_proj->fc1, down_proj->fc2)")
238-
239-
# --- 2. Dequantize gate weights ---
240-
# Collect gate quantization parts: {prefix: {weight, scales, biases}}
241-
gate_parts: dict[str, dict[str, mx.array]] = {}
242-
non_gate_weights = {}
243-
for k, v in weights.items():
244-
if ".gate." in k:
245-
# Extract prefix (everything before .gate.)
246-
prefix = k[:k.index(".gate.") + len(".gate")]
247-
suffix = k[k.index(".gate.") + len(".gate."):]
224+
with open(index_path) as f:
225+
index = json.load(f)
226+
weight_map = index.get("weight_map", {})
227+
228+
# Find all gate weight/scales/biases keys in the safetensors index
229+
# Group by gate prefix (e.g., "backbone.layers.0.mixer.gate")
230+
gate_parts: dict[str, dict[str, str]] = {} # prefix -> {suffix -> shard_file}
231+
for key, shard in weight_map.items():
232+
if ".gate." in key:
233+
prefix = key[:key.index(".gate.") + len(".gate")]
234+
suffix = key[key.index(".gate.") + len(".gate."):]
248235
if prefix not in gate_parts:
249236
gate_parts[prefix] = {}
250-
if suffix == "weight":
251-
gate_parts[prefix]["weight"] = v
252-
elif suffix == "scales":
253-
gate_parts[prefix]["scales"] = v
254-
elif suffix == "biases":
255-
gate_parts[prefix]["biases"] = v
256-
else:
257-
# Other gate sub-keys, keep as-is
258-
non_gate_weights[k] = v
259-
else:
260-
non_gate_weights[k] = v
237+
gate_parts[prefix][suffix] = shard
261238

262-
weights = non_gate_weights
263-
dequant_count = 0
264-
for prefix, parts in gate_parts.items():
265-
gate_weight = parts.get("weight")
266-
scales = parts.get("scales")
267-
biases = parts.get("biases")
239+
if not gate_parts:
240+
logger.info("Nemotron-H: no gate weights found in index, skipping")
241+
return
242+
243+
# Load gate tensors from safetensors and dequantize
244+
dequantized_weights: list[tuple[str, mx.array]] = []
245+
# Cache loaded shards to avoid re-reading
246+
shard_cache: dict[str, dict[str, mx.array]] = {}
268247

269-
if gate_weight is None:
248+
for prefix, parts in gate_parts.items():
249+
if "weight" not in parts:
250+
continue
251+
if "scales" not in parts or "biases" not in parts:
252+
# Gate is not quantized (no scales/biases), skip
270253
continue
271254

272-
if scales is not None and biases is not None:
273-
# Gate is quantized — dequantize by trying bits in order
274-
# Gate is typically 8-bit (CRITICAL tier)
275-
dequantized = None
276-
for bits in [8, 6, 4, 3, 2]:
277-
elem_per_u32 = 32 // bits
278-
real_cols = gate_weight.shape[-1] * elem_per_u32
279-
gs = real_cols // scales.shape[-1]
280-
if gs > 0 and gs * scales.shape[-1] == real_cols:
281-
dequantized = mx.dequantize(
282-
gate_weight, scales, biases, gs, bits
283-
)
284-
dequantized = dequantized.astype(target_dtype)
285-
logger.debug(
286-
f"Nemotron-H: dequantized {prefix}.weight "
287-
f"({bits}-bit, group_size={gs}) -> {dequantized.shape}"
288-
)
289-
break
290-
if dequantized is not None:
291-
weights[f"{prefix}.weight"] = dequantized
292-
dequant_count += 1
293-
else:
294-
# Could not dequantize — keep original parts
295-
logger.warning(
296-
f"Nemotron-H: could not dequantize {prefix}, "
297-
f"keeping original quantized weights"
255+
# Load the required tensors from safetensors
256+
tensors: dict[str, mx.array] = {}
257+
for suffix in ("weight", "scales", "biases"):
258+
full_key = f"{prefix}.{suffix}"
259+
shard_file = parts[suffix]
260+
if shard_file not in shard_cache:
261+
shard_cache[shard_file] = mx.load(str(model_path / shard_file))
262+
tensors[suffix] = shard_cache[shard_file][full_key]
263+
264+
gate_weight = tensors["weight"]
265+
scales = tensors["scales"]
266+
biases = tensors["biases"]
267+
268+
# Dequantize by trying bit widths (gate is typically 8-bit CRITICAL tier)
269+
dequantized = None
270+
for bits in [8, 6, 4, 3, 2]:
271+
elem_per_u32 = 32 // bits
272+
real_cols = gate_weight.shape[-1] * elem_per_u32
273+
gs = real_cols // scales.shape[-1]
274+
if gs > 0 and gs * scales.shape[-1] == real_cols:
275+
dequantized = mx.dequantize(
276+
gate_weight, scales, biases, gs, bits
298277
)
299-
weights[f"{prefix}.weight"] = gate_weight
300-
weights[f"{prefix}.scales"] = scales
301-
weights[f"{prefix}.biases"] = biases
278+
dequantized = dequantized.astype(target_dtype)
279+
logger.info(
280+
f"Nemotron-H: dequantized {prefix}.weight "
281+
f"({bits}-bit, group_size={gs}) "
282+
f"{gate_weight.shape} -> {dequantized.shape}"
283+
)
284+
break
285+
286+
if dequantized is not None:
287+
dequantized_weights.append((f"{prefix}.weight", dequantized))
302288
else:
303-
# Gate is not quantized, keep weight as-is
304-
weights[f"{prefix}.weight"] = gate_weight
289+
logger.warning(
290+
f"Nemotron-H: could not dequantize {prefix}, "
291+
f"weight={gate_weight.shape}, scales={scales.shape}"
292+
)
293+
294+
# Free shard cache
295+
del shard_cache
305296

306-
if dequant_count > 0:
297+
if dequantized_weights:
298+
self._model.load_weights(dequantized_weights, strict=False)
307299
logger.info(
308-
f"Nemotron-H: dequantized {dequant_count} gate weights to {target_dtype}"
300+
f"Nemotron-H: dequantized {len(dequantized_weights)} "
301+
f"gate weights to {target_dtype}"
309302
)
310-
311-
# --- 3. Drop mtp.* keys ---
312-
mtp_count = sum(1 for k in weights if k.startswith("mtp."))
313-
if mtp_count > 0:
314-
weights = {k: v for k, v in weights.items() if not k.startswith("mtp.")}
315-
logger.info(f"Nemotron-H: dropped {mtp_count} mtp.* keys")
316-
317-
# Reload fixed weights into model (strict=False required per JANG guide)
318-
weight_list = list(weights.items())
319-
self._model.load_weights(weight_list, strict=False)
320-
321-
# Clean up stale quantization attributes on gate modules (nn.Linear)
322-
# After dequantization, gate modules may still have scales/biases attrs
323-
# from the original quantized load — remove them so they don't waste memory
324-
for path, module in tree_flatten(
325-
self._model.leaf_modules(), is_leaf=nn.Module.is_module
326-
):
327-
if ".gate" in path and isinstance(module, nn.Linear):
328-
for attr in ("scales", "biases"):
329-
if hasattr(module, attr):
330-
delattr(module, attr)
303+
else:
304+
logger.info("Nemotron-H: no gate weights needed dequantization")
331305

332306
logger.info("Nemotron-H: weight fixup complete")
333307

packaging/venvstacks.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ requirements = [
5050
"Pillow>=9.0.0",
5151
# ModelScope SDK for downloading models from ModelScope Hub
5252
"modelscope>=1.10.0",
53+
# JANG model support (mixed-precision quantization)
54+
"jang[mlx]>=0.1.0",
5355
]
5456
platforms = [
5557
"macosx_arm64",

0 commit comments

Comments (0)