Skip to content

Add Marian encoder-decoder support (static cache)#382

Merged
vortex-captain merged 12 commits into
mainfrom
reny/marian_bart
Apr 28, 2026
Merged

Add Marian encoder-decoder support (static cache)#382
vortex-captain merged 12 commits into
mainfrom
reny/marian_bart

Conversation

@vortex-captain

@vortex-captain vortex-captain commented Apr 23, 2026

Copy link
Copy Markdown
Contributor

Summary

  • Marian (Helsinki-NLP/opus-mt-*) encoder-decoder support: MarianEncoderWrapper / MarianDecoderWrapper using WinMLStaticCache (multi-dim index_put_ → ScatterND). The decoder's OnnxConfig exposes cache_position as the position input; under static cache, stock HF's position_ids=cache_position already indexes the sin/cos table correctly, so no embedding patch is applied at runtime. The PATCHING_SPECS infra (_patched_marian_sinusoidal_forward / _build_marian_patching_specs) is kept in the file with its assignment commented, for reference and future sliding-window re-export. Registered composite: ("marian", "translation")WinMLMarianModel.

  • BART (facebook/bart-large-cnn, sshleifer/distilbart-cnn-12-6) encoder-decoder support: BartEncoderWrapper / BartDecoderWrapper using WinMLStaticCache. Same static-cache approach as Marian — BartLearnedPositionalEmbedding.forward correctly receives position_ids=cache_position from stock HF and applies its +self.offset shift, so no embedding patch is applied at runtime. The _patched_bart_learned_forward patch spec remains in the file (commented out on the decoder config) for future sliding-window re-export. Registered composite: ("bart", "summarization")WinMLBartModel.

kv_cache changes

  • WinMLStaticCache.update uses multi-dim index_put_ with full (batch, head, pos) coord tuples so ONNX export emits ScatterND instead of ScatterElements.
  • WinMLCache.num_layers now reads len(self.layers) after super().__init__ (HF's StaticCache.__init__ builds self.layers from config.get_text_config(decoder=True).num_hidden_layers, i.e. the decoder's layer count). Previously config.num_hidden_layers on the outer BartConfig maps to encoder_layers, which fails for asymmetric distilled variants like distilbart-cnn-12-6 (encoder_layers=12, decoder_layers=6). No-op for symmetric models.
  • PastKeyValueInputGenerator reads normalized_config.head_dim unconditionally (no try/except fallback). MarianConfig / BartConfig have no native head_dim attr, so _MarianDecoderNormalizedConfig / _BartDecoderNormalizedConfig — real NormalizedConfig subclasses with UPPERCASE class-level attribute mappings (same pattern as stock NormalizedTextConfig) — expose head_dim as a computed property.

Verified

Model Test Result
Helsinki-NLP/opus-mt-fr-en 3 fr→en translations 3/3 exact MATCH vs HF PyTorch MarianMTModel
facebook/bart-large-cnn 2 CNN-style summarizations 2/2 PREFIX_MATCH vs HF PyTorch BartForConditionalGeneration
sshleifer/distilbart-cnn-12-6 (asymmetric: 12 enc / 6 dec) 2 CNN-style summarizations 2/2 PREFIX_MATCH vs HF PyTorch

PREFIX_MATCH means WinML ONNX and HF PyTorch agree bit-identically for the full shared prefix (e.g. 53 tokens for BART article 1, 40 for article 2) and then diverge at a single EOS-vs-continuation argmax flip near the natural sentence boundary — the expected fingerprint of fp32 drift from fused ops (gelu_fusion, matmul_add_fusion).

Commits

  1. Add Marian (Helsinki-NLP/opus-mt) encoder-decoder support
  2. Add BART (facebook/bart-large-cnn) encoder-decoder support
  3. Fix WinMLCache.num_layers for asymmetric encoder-decoder models
  4. Derive head_dim in per-model NormalizedConfig instead of kv_cache fallback

Yi Ren added 4 commits April 23, 2026 14:33
MarianEncoderWrapper / MarianDecoderWrapper with WinMLStaticCache
(multi-dim index_put_ -> ScatterND). The decoder OnnxConfig exposes
cache_position as the position input -- under static cache, stock HF's
position_ids=cache_position already indexes the sin/cos table correctly,
so no embedding patch is needed at runtime. The PATCHING_SPECS infra
(_patched_marian_sinusoidal_forward / _build_marian_patching_specs) is
kept in the file with the assignment commented out, for reference and
for future sliding-window re-export. Registered composite model:
("marian", "translation") -> WinMLMarianModel.

kv_cache:
- WinMLStaticCache.update now uses multi-dim index_put_ with full
  (batch, head, pos) coord tuples so ONNX export emits ScatterND
  instead of ScatterElements.
- PastKeyValueInputGenerator falls back to hidden_size / num_heads
  when NormalizedConfig has no explicit head_dim attr.

Verified via temp/test_marian_e2e.py: fr->en, all three test sentences
exact-match vs HF PyTorch (MarianMTModel) reference.
BartEncoderWrapper / BartDecoderWrapper with WinMLStaticCache. The
decoder OnnxConfig exposes cache_position as the position input --
under static cache, stock HF's position_ids=cache_position is applied
with the +self.offset shift inside BartLearnedPositionalEmbedding.forward,
so the learned-position lookup is correct without any patch.

The PATCHING_SPECS infra (_patched_bart_learned_forward /
_build_bart_patching_specs) is kept in the file with its assignment
commented, for reference and for future sliding-window re-export.
Registered composite model: ("bart", "summarization") -> WinMLBartModel.

BartConfig has no head_dim attr; PastKeyValueInputGenerator derives it
from hidden_size / num_attention_heads (covered by the fallback added
in the preceding Marian commit).

Verified via temp/test_bart_e2e.py on facebook/bart-large-cnn: two
CNN-style articles, WinML output is a bit-identical prefix of the HF
PyTorch (BartForConditionalGeneration) reference. The remaining EOS
argmax flip at the natural sentence boundary is the expected fp32
drift from fused ops (gelu_fusion, matmul_add_fusion).
Previously read ``config.num_hidden_layers`` from the outer PretrainedConfig,
which for BART-family encoder-decoder models maps to ``encoder_layers``. On
symmetric models (bart-large-cnn, Marian opus-mt-*, T5) encoder_layers ==
decoder_layers so this was correct by coincidence. For asymmetric distilled
variants like sshleifer/distilbart-cnn-12-6 (encoder_layers=12,
decoder_layers=6), it gave 12 while HF StaticCache.__init__ allocated only
6 layer buffers via config.get_text_config(decoder=True).num_hidden_layers.
WinMLCache.reset() and update_all_layers() then iterated past the end of
self.layers and raised IndexError.

Replace config.num_hidden_layers with len(self.layers) in __init__, after
super().__init__() has already built the correct per-decoder-layer list.
Works for both symmetric and asymmetric models; no behavior change for any
previously-passing model.

Verified via temp/test_distilbart_e2e.py: sshleifer/distilbart-cnn-12-6
summarization PREFIX_MATCH vs HF PyTorch reference on two CNN articles.
…lback

MarianConfig / BartConfig have no native ``head_dim`` attr, so
PastKeyValueInputGenerator used to fall back to
``hidden_size // num_attention_heads`` in a try/except. Move that
derivation to per-model NormalizedConfig subclasses
(_MarianDecoderNormalizedConfig, _BartDecoderNormalizedConfig) exposing
``head_dim`` as a property. kv_cache can now read
``normalized_config.head_dim`` unconditionally.

Switch both OnnxConfigs from ``NormalizedConfig.with_args(...)`` (a
functools.partial) to real NormalizedConfig subclasses with UPPERCASE
class-level attribute mappings (same pattern as NormalizedTextConfig),
which lets us attach the ``head_dim`` property.

Verified:
- temp/test_marian_e2e.py: 3/3 translations exact MATCH vs HF PyTorch
- temp/test_bart_e2e.py:   2/2 PREFIX_MATCH vs HF PyTorch
- temp/test_distilbart_e2e.py: 2/2 PREFIX_MATCH vs HF PyTorch
@vortex-captain vortex-captain marked this pull request as ready for review April 23, 2026 09:55
@vortex-captain vortex-captain requested a review from a team as a code owner April 23, 2026 09:55
@xieofxie

Copy link
Copy Markdown
Contributor

do we need tests? although no one will review them (?)

Comment thread src/winml/modelkit/models/hf/bart.py Outdated
Comment thread src/winml/modelkit/models/hf/bart.py
Comment thread src/winml/modelkit/models/hf/bart.py
Comment thread src/winml/modelkit/models/hf/bart.py

@zhenchaoni zhenchaoni left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewed the bart.py

Comment thread src/winml/modelkit/models/hf/bart.py
Yi Ren added 4 commits April 27, 2026 15:47
The comment described Slice+Concat eviction, which only applies to
WinMLSlidingWindowCache. Both BartDecoderWrapper and MarianDecoderWrapper
currently use WinMLStaticCache (the sliding-window line is commented out),
so the comment was misleading.
Tier 1 (tests/unit/export/test_io.py): WinMLStaticCache.update writes
new KV at the requested cache_position via multi-dim index_put_;
WinMLCache.num_layers follows the decoder layer count for asymmetric
encoder-decoder configs (distilbart-cnn-12-6 regression); Marian/BART
decoder dummy inputs use the right KV shape and decoder-layer count.

Tier 2 (tests/unit/models/{marian,bart}/test_onnx_config.py): TasksManager
registration, NormalizedConfig attribute mapping (head_dim derived,
num_layers=decoder_layers), DUMMY_INPUT_GENERATOR_CLASSES, encoder/decoder
inputs/outputs, and WinMLMarianModel/WinMLBartModel composite registration
plus get_cache_class().
@vortex-captain

Copy link
Copy Markdown
Contributor Author

do we need tests? although no one will review them (?)

UTs added

Yi Ren added 3 commits April 28, 2026 10:44
The Qwen3 composite is registered for ("qwen3", "text-generation"), but
WinMLQwen3Model._SUB_MODEL_CONFIG also used "text-generation" as the
decoder_gen sub-task. When the composite built that sub-component, the
auto-dispatch in WinMLAutoModel.from_pretrained re-resolved the same
(qwen3, text-generation) registry key and recursed back into
WinMLCompositeModel.from_pretrained until the call stack overflowed.

Rename the decoder_gen sub-task to "text2text-generation" so the leaf
build no longer collides with the composite key. The user-facing
composite registration "text-generation" is unchanged. MODEL_CLASS_MAPPING
routes (qwen3, text2text-generation) to QwenDecoderWrapper (loads via
AutoModelForCausalLM), so Optimum's default seq2seq loader for that task
is bypassed.

Verified end-to-end with scripts/test_qwen.py on QNN/NPU: prefill and gen
sub-models build, model.generate() returns expected output.
Required by Helsinki-NLP/opus-mt-* tokenizers (Marian translation models).
@zhenchaoni

Copy link
Copy Markdown
Member

Overall is good to me. I leveraged AI to do a check as well. I think both AI comments can be solved by refactoring the code to use the common decoder_wrapper. You can consider doing it in a separate PR.

1. Heavy structural duplication between bart.py and marian.py

These two files are ~90% structurally identical:

Component bart.py marian.py Difference
*EncoderWrapper class name + outer model class
*DecoderWrapper.forward line 281-352 line 319-392 bit-identical body, just renamed class
*DecoderIOConfig.inputs/outputs line 432-456 line 472-496 bit-identical
_*DecoderNormalizedConfig line 388-403 line 428-443 bit-identical field map
_patched_*_forward line 142-184 line 179-214 same shape (different embedding type)
_build_*_patching_specs line 187-200 line 217-233 bit-identical pattern

This is independent of any other PR — even between just bart and marian, a shared _BartFamilyDecoderWrapperBase (or even just _BartFamilyDecoderIOConfigBase for the inputs/outputs body) would cut ~250-300 LOC of duplicated code. Two parallel files mean every future bug fix has to be applied twice.

Suggestion: factor at least the IOConfig inputs/outputs body and the NormalizedConfig field map into a shared base in either bart.py (and marian.py imports from there) or a new _bart_family.py. The decoder wrapper's forward body is harder to share without a true base class, but the IOConfig body and the NormalizedConfig are easy wins.


2. Positional args[N] indexing couples the wrapper to the IOConfig's input order — fragile

In [bart.py:293-313]:

decoder_input_ids = args[0]
encoder_hidden_states = args[1]
attention_mask = args[2]
decoder_attention_mask = args[3]
cache_position = args[4]
kv_start = 5

max_cache_len = args[kv_start].size(2)
...
for i in range(self.num_layers):
    self_attn_cache.layers[i].keys = args[kv_start + i * 2]
    self_attn_cache.layers[i].values = args[kv_start + i * 2 + 1]

The contract is implicit: "the IOConfig's inputs property declares names in this exact order; forward receives them positionally in the same order; we read args[N] by hardcoded index." Nothing in the wrapper checks that the IOConfig matches this expectation — and nothing in the IOConfig knows that the wrapper depends on its order. If anyone reorders or inserts an input in the IOConfig's inputs dict (e.g., adds head_mask), forward silently misreads and produces nonsense ONNX.

Suggestion: in forward, build a name-keyed dict from the IOConfig:

inputs = dict(zip(self.onnx_config.inputs.keys(), args, strict=True))
decoder_input_ids = inputs["decoder_input_ids"]
# ...
for i in range(self.num_layers):
    self_attn_cache.layers[i].keys = inputs[f"past_{i}_key"]

This makes the wrapper explicit about which names it depends on, fails loudly with strict=True if the count mismatches, and is robust to IOConfig reorderings.


@vortex-captain vortex-captain enabled auto-merge (squash) April 28, 2026 07:40
@vortex-captain vortex-captain merged commit d14c5aa into main Apr 28, 2026
9 checks passed
@vortex-captain vortex-captain deleted the reny/marian_bart branch April 28, 2026 07:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants