Add support for dLLM encoder-decoder models (DiffusionGemma) [tied-weight PTQ export support ] #1707
Changes from all commits
47d4ab6
60d4ebb
225072b
a0d9b65
e351c0f
d684477
d0a735e
0543907
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -42,13 +42,24 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None: | |
| {E}.gate_proj.weight, {E}.gate_proj.weight_scale, ... | ||
| {E}.up_proj.weight, {E}.up_proj.weight_scale, ... | ||
| {E}.down_proj.weight, {E}.down_proj.weight_scale, ... | ||
|
|
||
| Tied-experts dedup: when multiple fused-expert modules share their 3-D | ||
| source params via HF ``_tied_weights_keys``, the unpacking creates fresh | ||
| per-expert tensors that break the tie. We cache the source ``data_ptr()`` | ||
| at entry and on a later cache hit alias the per-expert ``weight`` / | ||
| ``weight_scale`` / ``weight_scale_2`` back to the prior module so | ||
| downstream dedup catches them. ``input_scale`` is left per-side. | ||
| """ | ||
| from modelopt.torch.export.unified_export_hf import _export_quantized_weight | ||
| from modelopt.torch.quantization.plugins.huggingface import _get_fused_expert_intermediate_dim | ||
|
|
||
| n = module.num_experts | ||
| expert_dim = _get_fused_expert_intermediate_dim(module) | ||
|
|
||
| # Capture source tensor identities BEFORE unpacking (the source | ||
| # attrs are deleted at the end of this function). | ||
| _source_key = (module.gate_up_proj.data_ptr(), module.down_proj.data_ptr()) | ||
|
|
||
| # 1. Shared input quantizers — one per projection type, shared across all experts. | ||
| gate_up_input_q = module.gate_up_proj_input_quantizer | ||
| down_input_q = module.down_proj_input_quantizer | ||
|
|
@@ -178,6 +189,46 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None: | |
| if hasattr(module, attr): | ||
| delattr(module, attr) | ||
|
|
||
| # 5. Tied-experts dedup: if this module's source params have been seen | ||
| # before, alias the bit-identical per-expert buffers (weight, | ||
| # weight_scale, weight_scale_2, input_scale) to the previously-unpacked | ||
| # module. input_scale is safe to alias because sync_tied_input_amax | ||
| # runs earlier in _export_transformers_checkpoint and max-merges the | ||
| # shared input_quantizer amaxes across tied fused-experts modules, so | ||
| # both sides now derive bit-identical input_scale values. | ||
| _cache = _export_fused_experts.__dict__.setdefault("_tied_unpacked_cache", {}) | ||
|
Contributor
🔴 High (companion to the |
||
| _prior = _cache.get(_source_key) | ||
| if _prior is not None and _prior is not module: | ||
|
Contributor
🟡 Medium: second tied module is fully unpacked, then discarded. On a cache hit the later tied module has already run the entire unpack/pack path above; this block then aliases all of it back to |
||
| for _idx in range(n): | ||
| _cur_expert = getattr(module, str(_idx), None) | ||
| _prior_expert = getattr(_prior, str(_idx), None) | ||
| if _cur_expert is None or _prior_expert is None: | ||
| continue | ||
| for _proj_name in ("gate_proj", "up_proj", "down_proj"): | ||
| _cur_proj = getattr(_cur_expert, _proj_name, None) | ||
| _prior_proj = getattr(_prior_expert, _proj_name, None) | ||
| if _cur_proj is None or _prior_proj is None: | ||
| continue | ||
| # Alias the weight (Parameter) so both sides reference the | ||
| # same nn.Parameter → same data_ptr() → existing dedup | ||
| # in postprocess_state_dict will drop the duplicate. | ||
| if hasattr(_prior_proj, "weight"): | ||
| _cur_proj.weight = _prior_proj.weight | ||
| # Alias the bit-identical scale buffers (including | ||
| # input_scale, made safe by sync_tied_input_amax pre-export | ||
| # merging). Re-register to ensure data_ptr() matches the | ||
| # prior side's tensor. | ||
| for _attr in ("weight_scale", "weight_scale_2", "input_scale"): | ||
| if not hasattr(_prior_proj, _attr): | ||
| continue | ||
| if _attr in _cur_proj._buffers: | ||
| del _cur_proj._buffers[_attr] | ||
| elif hasattr(_cur_proj, _attr): | ||
| delattr(_cur_proj, _attr) | ||
| _cur_proj.register_buffer(_attr, getattr(_prior_proj, _attr)) | ||
| else: | ||
| _cache[_source_key] = module | ||
|
|
||
|
|
||
| def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | None = None): | ||
| """Collect expert_token_count from all quantized MoE layers and save as an HTML table. | ||
|
|
||
Docstring contradicts implementation for `input_scale` aliasing.

Line 51 states `input_scale` is left per-side, but line 221 explicitly aliases `input_scale` along with `weight_scale` and `weight_scale_2`. The implementation comment at lines 195-198 correctly explains that `input_scale` IS aliased because `sync_tied_input_amax` runs earlier.

📝 Suggested docstring fix

  Tied-experts dedup: when multiple fused-expert modules share their 3-D
  source params via HF ``_tied_weights_keys``, the unpacking creates fresh
  per-expert tensors that break the tie. We cache the source ``data_ptr()``
  at entry and on a later cache hit alias the per-expert ``weight`` /
- ``weight_scale`` / ``weight_scale_2`` back to the prior module so
- downstream dedup catches them. ``input_scale`` is left per-side.
+ ``weight_scale`` / ``weight_scale_2`` / ``input_scale`` back to the prior
+ module so downstream dedup catches them. ``input_scale`` aliasing is safe
+ because ``sync_tied_input_amax`` runs earlier and max-merges the shared
+ input_quantizer amaxes, so both sides derive bit-identical values.
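For readers following the thread, the flow the docstring describes (capture the source identity, unpack, alias on a cache hit, then let a `data_ptr()`-based dedup drop the duplicates) can be sketched without torch. This is only an illustration under stated assumptions: `Buf`, `Experts`, `export`, and `dedup` are hypothetical stand-ins, object identity plays the role of a storage address, and none of this is the actual modelopt code.

```python
# Torch-free stand-in: Buf mimics a tensor whose data_ptr() identifies its storage.
class Buf:
    def __init__(self, values):
        self.values = list(values)

    def data_ptr(self):
        # Assumption for illustration: object identity stands in for a storage address.
        return id(self)


class Experts:
    """Hypothetical fused-experts module holding one (possibly shared) source buffer."""

    def __init__(self, source):
        self.source = source
        self.unpacked = {}


def export(module, cache):
    # Capture the source identity before unpacking, as the diff does.
    key = module.source.data_ptr()
    # Unpacking copies the data, so the result no longer shares storage with the source.
    module.unpacked = {
        "weight": Buf(module.source.values),
        "input_scale": Buf([0.5]),
    }
    prior = cache.get(key)
    if prior is not None and prior is not module:
        # Cache hit: point every unpacked buffer at the prior module's copy.
        for name in module.unpacked:
            module.unpacked[name] = prior.unpacked[name]
    else:
        cache[key] = module


def dedup(named_buffers):
    """Keep only the first name seen for each data_ptr()."""
    seen, kept = set(), {}
    for name, buf in named_buffers.items():
        if buf.data_ptr() not in seen:
            seen.add(buf.data_ptr())
            kept[name] = buf
    return kept


shared = Buf([1.0, 2.0])
encoder, decoder = Experts(shared), Experts(shared)
cache = {}
export(encoder, cache)
export(decoder, cache)

state = {
    f"{side}.{name}": buf
    for side, mod in (("encoder", encoder), ("decoder", decoder))
    for name, buf in mod.unpacked.items()
}
kept = dedup(state)
print(sorted(kept))
```

The sketch unpacks the second module before checking the cache, mirroring the order in the diff. Aliasing `input_scale` here is only valid because both sides hold equal values, which the implementation comment attributes to `sync_tied_input_amax` running earlier.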