feat: add support for multi-model and KV cache #206
Closed
vortex-captain wants to merge 24 commits into
added 14 commits
April 1, 2026 13:42
Split T5ForConditionalGeneration into encoder (feature-extraction) and
decoder (text2text-generation) ONNX models with static append-only KV
cache using scatter-write at cache_position.
Key components:
- StaticWriteCache: append-only KV cache layer using torch.scatter,
fully traceable by torch.jit.trace, fixed-shape I/O
- T5EncoderWrapper / T5DecoderWithStaticCache: export wrappers with
from_pretrained() (SAM2 pattern), registered via MODEL_CLASS_MAPPING
- T5EncoderIOConfig / T5DecoderIOConfig: OnnxConfig registrations with
custom DummyInputGenerators for KV cache tensors
- WinMLModelForSeq2SeqLM: inference wrapper composing encoder + decoder
WinMLAutoModel instances, compatible with transformers.pipeline
- get_export_args protocol in HTPExporter for positional arg export
Verified: pipeline('translation_en_to_fr') produces exact match with
PyTorch reference.
Known limitation: the OpenVINO EP does not support ScatterElements, so
decoder inference requires CPU EP fallback.
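The scatter-write at cache_position described in this commit can be sketched as follows. This is only an illustration under stated assumptions: a cache buffer of shape [batch, heads, max_len, d_kv], a single-element cache_position, and a hypothetical helper name scatter_write. It is not the PR's StaticWriteCache.

```python
import torch


def scatter_write(cache: torch.Tensor, new_kv: torch.Tensor,
                  cache_position: torch.Tensor) -> torch.Tensor:
    """Write new_kv [B, H, 1, D] into cache [B, H, L, D] at cache_position.

    Hypothetical helper for illustration. Input and output shapes are fixed,
    which is what keeps the op friendly to tracing and static-shape export.
    """
    batch, heads, _, d_kv = cache.shape
    # Broadcast the single position to an index shaped like new_kv.
    index = cache_position.view(1, 1, 1, 1).expand(batch, heads, 1, d_kv)
    # Out-of-place scatter along the sequence axis (dim=2).
    return torch.scatter(cache, 2, index, new_kv)


cache = torch.zeros(1, 2, 4, 3)    # [batch, heads, max_len, d_kv]
new_kv = torch.ones(1, 2, 1, 3)    # KV for one new token
updated = scatter_write(cache, new_kv, torch.tensor([2]))
```

Only slot 2 along the sequence axis changes; the input buffer itself is left untouched because torch.scatter returns a new tensor.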
ONNX decoder now outputs only the new token's KV [batch,heads,1,d_kv] instead of the full static buffer. The inference wrapper uses HF's StaticCache (mutated in-place via index_copy_) as the stateful cache that flows through GenerationMixin's generation loop.
- Rename T5DecoderWithStaticCache to T5DecoderWrapper
- Extract single-token KV via gather at cache_position in export wrapper
- Replace custom StaticWriteEncoderDecoderCache with HF StaticCache in WinMLModelForSeq2SeqLM inference wrapper
- Cache is initialized once, mutated in-place each step, same object reference returned in Seq2SeqLMOutput
Replace custom StaticWriteCache/StaticWriteEncoderDecoderCache with HF's StaticCache in both the ONNX export wrapper and inference wrapper. The same cache class is now used end-to-end:
- Export (t5.py): constructs StaticCache from input KV tensors, wraps in EncoderDecoderCache, extracts new token KV via gather after index_copy_
- Inference (seq2seq.py): StaticCache persists across generation steps, updated in-place via cache.update() (index_copy_ at cache_position)
- Handle GenerationMixin wrapping cache in EncoderDecoderCache
- Remove models/cache.py (no longer imported)
- Update study doc with final design
Three-layer class hierarchy for multi-component ONNX models:
- WinMLPipelineModel: base with _SUB_MODEL_CONFIG, from_pretrained that builds each sub-component via WinMLAutoModel
- WinMLGenerationModel: encoder-decoder forward with StaticCache, GenerationMixin interface (get_encoder, prepare_inputs, can_generate)
- WinMLT5Model: T5-specific sub-model tasks and generation_config
Use HF StaticCache for both export (index_copy_ traces correctly) and inference (stateful, mutated in-place via cache.update). ONNX decoder outputs only new token KV [batch,heads,1,d_kv]; inference wrapper writes it back into the cache buffer.
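The write-back and extraction steps mentioned in these commits can be illustrated with plain tensors. This is a rough sketch, assuming a [batch, heads, max_len, d_kv] buffer and a single-element cache_position; it uses torch.Tensor.index_copy_ and torch.gather directly rather than HF's StaticCache, whose internals are not shown in this PR text.

```python
import torch

max_len, d_kv = 4, 3
cache = torch.zeros(1, 2, max_len, d_kv)     # [batch, heads, max_len, d_kv]
cache_position = torch.tensor([1])           # slot of the token being decoded
new_kv = torch.full((1, 2, 1, d_kv), 5.0)    # single-token KV from the decoder

# Inference side: write the new token's KV into the buffer in place,
# along the sequence axis (dim=2).
cache.index_copy_(2, cache_position, new_kv)

# Export side: read the same slot back out with gather, so the graph
# can emit only [batch, heads, 1, d_kv] instead of the whole buffer.
index = cache_position.view(1, 1, 1, 1).expand(1, 2, 1, d_kv)
extracted = torch.gather(cache, 2, index)
```

The gathered slice equals the tensor that was written, and the other cache slots stay zero.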
…nputs
- Replace _run_encoder + _EncoderProxy with _EncoderWithInputPadding that wraps the raw encoder with auto-padding; assigned as self._encoder so all callsites (forward, get_encoder) use it directly
- Extract shared _pad_inputs(source, expected) helper used by both encoder and decoder feed building
- forward() takes *, encoder_outputs, past_key_values, input_ids, **model_kwargs (no hardcoded input names)
- Encoder input names read from ONNX io_config, not assumed
- Use encoder_outputs["last_hidden_state"] (explicit key)
- Derive _num_kv_layers and _max_dec from ONNX shapes
- feeds.setdefault for generated inputs (encoder_hidden_states, decoder_attention_mask, cache_position)
…d docstring
- Rename WinMLGenerationModel to WinMLEncoderDecoderModel
- Add PIPELINE_MODEL_REGISTRY + @register_pipeline_model decorator
- Register WinMLT5Model as ("t5", "translation")
- Add comprehensive module docstring with architecture, KV cache
findings, and design principles for onboarding
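A registry with a class decorator, as named in this commit, might look roughly like the sketch below. The names PIPELINE_MODEL_REGISTRY, register_pipeline_model, WinMLPipelineModel, and WinMLT5Model come from the commit message; the bodies, the duplicate check, and the resolve_pipeline_model lookup helper are assumptions made for illustration.

```python
from typing import Callable, Dict, Tuple, Type

# Maps (model_type, task) to the pipeline model class registered for it.
PIPELINE_MODEL_REGISTRY: Dict[Tuple[str, str], Type] = {}


def register_pipeline_model(model_type: str, task: str) -> Callable[[Type], Type]:
    """Class decorator that records the class under (model_type, task)."""
    def decorator(cls: Type) -> Type:
        key = (model_type, task)
        if key in PIPELINE_MODEL_REGISTRY:
            # Assumed policy: reject silent overwrites.
            raise ValueError(f"duplicate registration for {key}")
        PIPELINE_MODEL_REGISTRY[key] = cls
        return cls  # Return the class unchanged so normal use still works.
    return decorator


class WinMLPipelineModel:
    """Empty stand-in for the real base class (sketch only)."""


@register_pipeline_model("t5", "translation")
class WinMLT5Model(WinMLPipelineModel):
    """Empty stand-in for the T5 pipeline model (sketch only)."""


def resolve_pipeline_model(model_type: str, task: str) -> Type:
    """Hypothetical lookup helper; raises KeyError for unknown pairs."""
    return PIPELINE_MODEL_REGISTRY[(model_type, task)]
```

Keying on the (model_type, task) pair lets one architecture register different pipeline classes for different tasks without changing the lookup code.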
added 4 commits
April 7, 2026 17:07
- Extract CapturingStaticCache to shared kv_cache.py (used by T5 and Qwen3)
- Add QwenDecoderWrapper with prefill/gen OnnxConfig registrations
- Add WinMLDecoderOnlyModel with GenerationMixin, chunked prefill, StaticCache
- Register qwen3 build config (dynamo=True, opset 18)
- Fix dynamo field missing from _merge_export_config
Verified: ONNX logits match HF exactly, generate() produces correct answers
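Chunked prefill with a fixed-shape prefill graph can be pictured as splitting the prompt into equal-length chunks and tracking the cache positions each chunk fills. The sketch below is an assumption about the general technique, not the PR's code: the function name, the pad_id default, and the choice to right-pad the final chunk are all hypothetical.

```python
from typing import Iterator, List, Tuple


def iter_prefill_chunks(
    input_ids: List[int], chunk_len: int, pad_id: int = 0
) -> Iterator[Tuple[List[int], List[int], int]]:
    """Yield (padded_chunk, cache_positions, num_valid) per prefill step."""
    if chunk_len <= 0:
        raise ValueError("chunk_len must be positive")
    for start in range(0, len(input_ids), chunk_len):
        chunk = input_ids[start:start + chunk_len]
        num_valid = len(chunk)
        # Right-pad the final chunk so every call sees the same static shape.
        padded = chunk + [pad_id] * (chunk_len - num_valid)
        # Positions in the KV cache that this chunk's tokens would occupy.
        positions = list(range(start, start + chunk_len))
        yield padded, positions, num_valid


steps = list(iter_prefill_chunks([11, 12, 13, 14, 15], chunk_len=2))
```

A caller would need to mask or ignore the padded slots (those beyond num_valid) when writing into the cache and when reading logits.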
added 6 commits
April 9, 2026 15:13
Replace duplicate T5KVCacheInputGenerator and QwenKVCacheInputGenerator with a single PastKeyValueInputGenerator in kv_cache.py that reads num_layers, num_attention_heads, head_dim, and max_cache_len from NormalizedConfig. Update pipeline-model.md to use WinMLPipelineModel.
Move WinMLEncoderDecoderModel + EncoderDecoderInputGenerator to models/hf/encoder_decoder.py, DecoderOnlyInputGenerator to models/winml/decoder_only.py. Colocate WinMLT5Model in t5.py and WinMLQwen3Model in qwen.py alongside their export configs. Rename seq2seq.py to pipeline_model.py (now only WinMLPipelineModel + registry). Add sub_model_kwargs to from_pretrained for per-component shape_config. Generators accept max_cache_len/seq_len kwargs as overrides from shape_config.
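The shape logic of a shared KV-cache dummy-input generator, including a max_cache_len override of the kind shape_config would supply, could be sketched as below. KVCacheDims is a hypothetical stand-in for the fields read from NormalizedConfig, and the past_key_values.{i}.key / .value input names are an assumed naming scheme, not confirmed by this PR.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

Shape = Tuple[int, int, int, int]


@dataclass(frozen=True)
class KVCacheDims:
    """Stand-in for the fields the generator reads from NormalizedConfig."""
    num_layers: int
    num_attention_heads: int
    head_dim: int
    max_cache_len: int


def past_key_value_shapes(
    dims: KVCacheDims, batch_size: int = 1, max_cache_len: Optional[int] = None
) -> Dict[str, Shape]:
    """Return the static shape of every per-layer key/value cache input."""
    # An explicit override wins over the config default.
    cache_len = dims.max_cache_len if max_cache_len is None else max_cache_len
    shape: Shape = (batch_size, dims.num_attention_heads, cache_len, dims.head_dim)
    shapes: Dict[str, Shape] = {}
    for layer in range(dims.num_layers):
        shapes[f"past_key_values.{layer}.key"] = shape     # assumed name
        shapes[f"past_key_values.{layer}.value"] = shape   # assumed name
    return shapes


dims = KVCacheDims(num_layers=2, num_attention_heads=8, head_dim=64, max_cache_len=128)
default_shapes = past_key_value_shapes(dims)
overridden = past_key_value_shapes(dims, max_cache_len=32)
```

Because every model-specific value comes from the config object, a single generator of this kind can serve both T5 and Qwen3 instead of one class per model.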
Add Mu2EncoderWrapper, Mu2DecoderWrapper (with CapturingStaticCache), OnnxConfig registrations, and WinMLMu2Model for translation pipeline. Decoder delegates to modeling_mu.py's own decoder (patched with past_key_values + cache_position support). Also forward trust_remote_code through config/build CLI and WinMLAutoModel/WinMLPipelineModel for custom auto_map models. Accept sequence_length kwarg in EncoderDecoderInputGenerator for shape_config overrides.
Summary
Multi-model pipeline support for encoder-decoder architectures (T5 translation), with a class hierarchy designed to extend to other multi-component models (SD, vision-language).
Class hierarchy
Key components
- @register_pipeline_model("t5", "translation"): winml config generates one config per sub-component automatically
- … (models/hf/t5.py): T5EncoderWrapper + T5DecoderWrapper with StaticCache (HF's index_copy_ traces correctly), registered via MODEL_CLASS_MAPPING
- … [batch, heads, 1, d_kv]
- … (models/winml/seq2seq.py): WinMLEncoderDecoderModel with StaticCache mutated in-place via cache.update(), compatible with transformers.pipeline
- get_export_args protocol in HTPExporter for positional arg export
Usage
Verified
- winml config --task translation generates encoder + decoder configs
- winml build succeeds for both components
- transformers.pipeline("translation") exact match vs PyTorch reference