Add Marian encoder-decoder support (static cache)#382
Conversation
MarianEncoderWrapper / MarianDecoderWrapper with WinMLStaticCache
(multi-dim index_put_ -> ScatterND). The decoder OnnxConfig exposes
cache_position as the position input -- under static cache, stock HF's
position_ids=cache_position already indexes the sin/cos table correctly,
so no embedding patch is needed at runtime. The PATCHING_SPECS infra
(_patched_marian_sinusoidal_forward / _build_marian_patching_specs) is
kept in the file with the assignment commented out, for reference and
for future sliding-window re-export. Registered composite model:
("marian", "translation") -> WinMLMarianModel.
kv_cache:
- WinMLStaticCache.update now uses multi-dim index_put_ with full
(batch, head, pos) coord tuples so ONNX export emits ScatterND
instead of ScatterElements.
- PastKeyValueInputGenerator falls back to hidden_size / num_heads
when NormalizedConfig has no explicit head_dim attr.
Verified via temp/test_marian_e2e.py: fr->en, all three test sentences
exact-match vs HF PyTorch (MarianMTModel) reference.
BartEncoderWrapper / BartDecoderWrapper with WinMLStaticCache. The
decoder OnnxConfig exposes cache_position as the position input --
under static cache, stock HF's position_ids=cache_position is applied
with the +self.offset shift inside BartLearnedPositionalEmbedding.forward,
so the learned-position lookup is correct without any patch.
The PATCHING_SPECS infra (_patched_bart_learned_forward /
_build_bart_patching_specs) is kept in the file with its assignment
commented, for reference and for future sliding-window re-export.
Registered composite model: ("bart", "summarization") -> WinMLBartModel.
BartConfig has no head_dim attr; PastKeyValueInputGenerator derives it
from hidden_size / num_attention_heads (covered by the fallback added
in the preceding Marian commit).
Verified via temp/test_bart_e2e.py on facebook/bart-large-cnn: two
CNN-style articles, WinML output is a bit-identical prefix of the HF
PyTorch (BartForConditionalGeneration) reference. The remaining EOS
argmax flip at the natural sentence boundary is the expected fp32
drift from fused ops (gelu_fusion, matmul_add_fusion).
Previously read ``config.num_hidden_layers`` from the outer PretrainedConfig, which for BART-family encoder-decoder models maps to ``encoder_layers``. On symmetric models (bart-large-cnn, Marian opus-mt-*, T5) encoder_layers == decoder_layers so this was correct by coincidence. For asymmetric distilled variants like sshleifer/distilbart-cnn-12-6 (encoder_layers=12, decoder_layers=6), it gave 12 while HF StaticCache.__init__ allocated only 6 layer buffers via config.get_text_config(decoder=True).num_hidden_layers. WinMLCache.reset() and update_all_layers() then iterated past the end of self.layers and raised IndexError. Replace config.num_hidden_layers with len(self.layers) in __init__, after super().__init__() has already built the correct per-decoder-layer list. Works for both symmetric and asymmetric models; no behavior change for any previously-passing model. Verified via temp/test_distilbart_e2e.py: sshleifer/distilbart-cnn-12-6 summarization PREFIX_MATCH vs HF PyTorch reference on two CNN articles.
…lback MarianConfig / BartConfig have no native ``head_dim`` attr, so PastKeyValueInputGenerator used to fall back to ``hidden_size // num_attention_heads`` in a try/except. Move that derivation to per-model NormalizedConfig subclasses (_MarianDecoderNormalizedConfig, _BartDecoderNormalizedConfig) exposing ``head_dim`` as a property. kv_cache can now read ``normalized_config.head_dim`` unconditionally. Switch both OnnxConfigs from ``NormalizedConfig.with_args(...)`` (a functools.partial) to real NormalizedConfig subclasses with UPPERCASE class-level attribute mappings (same pattern as NormalizedTextConfig), which lets us attach the ``head_dim`` property. Verified: - temp/test_marian_e2e.py: 3/3 translations exact MATCH vs HF PyTorch - temp/test_bart_e2e.py: 2/2 PREFIX_MATCH vs HF PyTorch - temp/test_distilbart_e2e.py: 2/2 PREFIX_MATCH vs HF PyTorch
|
do we need tests? although no one will review them (?) |
The comment described Slice+Concat eviction, which only applies to WinMLSlidingWindowCache. Both BartDecoderWrapper and MarianDecoderWrapper currently use WinMLStaticCache (the sliding-window line is commented out), so the comment was misleading.
Tier 1 (tests/unit/export/test_io.py): WinMLStaticCache.update writes
new KV at the requested cache_position via multi-dim index_put_;
WinMLCache.num_layers follows the decoder layer count for asymmetric
encoder-decoder configs (distilbart-cnn-12-6 regression); Marian/BART
decoder dummy inputs use the right KV shape and decoder-layer count.
Tier 2 (tests/unit/models/{marian,bart}/test_onnx_config.py): TasksManager
registration, NormalizedConfig attribute mapping (head_dim derived,
num_layers=decoder_layers), DUMMY_INPUT_GENERATOR_CLASSES, encoder/decoder
inputs/outputs, and WinMLMarianModel/WinMLBartModel composite registration
plus get_cache_class().
UTs added |
The Qwen3 composite is registered for ("qwen3", "text-generation"), but
WinMLQwen3Model._SUB_MODEL_CONFIG also used "text-generation" as the
decoder_gen sub-task. When the composite built that sub-component, the
auto-dispatch in WinMLAutoModel.from_pretrained re-resolved the same
(qwen3, text-generation) registry key and recursed back into
WinMLCompositeModel.from_pretrained until the call stack overflowed.
Rename the decoder_gen sub-task to "text2text-generation" so the leaf
build no longer collides with the composite key. The user-facing
composite registration "text-generation" is unchanged. MODEL_CLASS_MAPPING
routes (qwen3, text2text-generation) to QwenDecoderWrapper (loads via
AutoModelForCausalLM), so Optimum's default seq2seq loader for that task
is bypassed.
Verified end-to-end with scripts/test_qwen.py on QNN/NPU: prefill and gen
sub-models build, model.generate() returns expected output.
Required by Helsinki-NLP/opus-mt-* tokenizers (Marian translation models).
|
Overall is good to me. I leveraged AI to do a check as well. I think both AI comments can be solved by refactoring the code to use the common 1. Heavy structural duplication between bart.py and marian.pyThese two files are ~90% structurally identical:
This is independent of any other PR — even between just bart and marian, a shared Suggestion: factor at least the IOConfig 2. Positional
|
Summary
Marian (
Helsinki-NLP/opus-mt-*) encoder-decoder support:MarianEncoderWrapper/MarianDecoderWrapperusingWinMLStaticCache(multi-dimindex_put_→ ScatterND). The decoder'sOnnxConfigexposescache_positionas the position input; under static cache, stock HF'sposition_ids=cache_positionalready indexes the sin/cos table correctly, so no embedding patch is applied at runtime. ThePATCHING_SPECSinfra (_patched_marian_sinusoidal_forward/_build_marian_patching_specs) is kept in the file with its assignment commented, for reference and future sliding-window re-export. Registered composite:("marian", "translation")→WinMLMarianModel.BART (
facebook/bart-large-cnn,sshleifer/distilbart-cnn-12-6) encoder-decoder support:BartEncoderWrapper/BartDecoderWrapperusingWinMLStaticCache. Same static-cache approach as Marian —BartLearnedPositionalEmbedding.forwardcorrectly receivesposition_ids=cache_positionfrom stock HF and applies its+self.offsetshift, so no embedding patch is applied at runtime. The_patched_bart_learned_forwardpatch spec remains in the file (commented out on the decoder config) for future sliding-window re-export. Registered composite:("bart", "summarization")→WinMLBartModel.kv_cachechangesWinMLStaticCache.updateuses multi-dimindex_put_with full(batch, head, pos)coord tuples so ONNX export emits ScatterND instead of ScatterElements.WinMLCache.num_layersnow readslen(self.layers)aftersuper().__init__(HF'sStaticCache.__init__buildsself.layersfromconfig.get_text_config(decoder=True).num_hidden_layers, i.e. the decoder's layer count). Previouslyconfig.num_hidden_layerson the outerBartConfigmaps toencoder_layers, which fails for asymmetric distilled variants likedistilbart-cnn-12-6(encoder_layers=12, decoder_layers=6). No-op for symmetric models.PastKeyValueInputGeneratorreadsnormalized_config.head_dimunconditionally (no try/except fallback).MarianConfig/BartConfighave no nativehead_dimattr, so_MarianDecoderNormalizedConfig/_BartDecoderNormalizedConfig— realNormalizedConfigsubclasses with UPPERCASE class-level attribute mappings (same pattern as stockNormalizedTextConfig) — exposehead_dimas a computed property.Verified
Helsinki-NLP/opus-mt-fr-enMarianMTModelfacebook/bart-large-cnnBartForConditionalGenerationsshleifer/distilbart-cnn-12-6(asymmetric: 12 enc / 6 dec)PREFIX_MATCHmeans WinML ONNX and HF PyTorch agree bit-identically for the full shared prefix (e.g. 53 tokens for BART article 1, 40 for article 2) and then diverge at a single EOS-vs-continuation argmax flip near the natural sentence boundary — the expected fingerprint of fp32 drift from fused ops (gelu_fusion,matmul_add_fusion).Commits