Skip to content

feat: add support to multi-model and KV cache#206

Closed
vortex-captain wants to merge 24 commits into
mainfrom
reny/t5
Closed

feat: add support to multi-model and KV cache#206
vortex-captain wants to merge 24 commits into
mainfrom
reny/t5

Conversation

@vortex-captain

@vortex-captain vortex-captain commented Apr 1, 2026

Copy link
Copy Markdown
Contributor

Summary

Multi-model pipeline support for encoder-decoder architectures (T5 translation), with a class hierarchy designed to extend to other multi-component models (SD, vision-language).

Class hierarchy

WinMLPipelineModel(PreTrainedModel)            — multi-component base
  └─ WinMLEncoderDecoderModel(GenerationMixin) — encoder-decoder with StaticCache
       └─ WinMLT5Model                         — T5 tasks + generation config

Key components

  • Pipeline model registry: @register_pipeline_model("t5", "translation")winml config generates one config per sub-component automatically
  • Export wrappers (models/hf/t5.py): T5EncoderWrapper + T5DecoderWrapper with StaticCache (HF's index_copy_ traces correctly), registered via MODEL_CLASS_MAPPING
  • ONNX decoder I/O: static-shape inputs (full KV buffer), outputs only new token's KV [batch, heads, 1, d_kv]
  • Inference wrapper (models/winml/seq2seq.py): WinMLEncoderDecoderModel with StaticCache mutated in-place via cache.update(), compatible with transformers.pipeline
  • get_export_args protocol in HTPExporter for positional arg export

Usage

winml config -m google-t5/t5-small --task translation -o t5.json
  → t5_encoder.json + t5_decoder.json

winml build -c t5_encoder.json -m google-t5/t5-small -o output/encoder
winml build -c t5_decoder.json -m google-t5/t5-small -o output/decoder
model = WinMLT5Model.from_pretrained("google-t5/t5-small")
pipe = pipeline("translation_en_to_fr", model=model, tokenizer=tokenizer)
pipe("Hello, how are you?", num_beams=1)
# → Bonjour, comment êtes-vous ?

Verified

  • winml config --task translation generates encoder + decoder configs
  • winml build succeeds for both components
  • transformers.pipeline("translation") exact match vs PyTorch reference

Yi Ren added 14 commits April 1, 2026 13:42
Split T5ForConditionalGeneration into encoder (feature-extraction) and
decoder (text2text-generation) ONNX models with static append-only KV
cache using scatter-write at cache_position.

Key components:
- StaticWriteCache: append-only KV cache layer using torch.scatter,
  fully traceable by torch.jit.trace, fixed-shape I/O
- T5EncoderWrapper / T5DecoderWithStaticCache: export wrappers with
  from_pretrained() (SAM2 pattern), registered via MODEL_CLASS_MAPPING
- T5EncoderIOConfig / T5DecoderIOConfig: OnnxConfig registrations with
  custom DummyInputGenerators for KV cache tensors
- WinMLModelForSeq2SeqLM: inference wrapper composing encoder + decoder
  WinMLAutoModel instances, compatible with transformers.pipeline
- get_export_args protocol in HTPExporter for positional arg export

Verified: pipeline('translation_en_to_fr') produces exact match with
PyTorch reference.

Known limitation: OpenVINO EP does not support ScatterElements, requires
CPU EP fallback for decoder inference.
ONNX decoder now outputs only the new token's KV [batch,heads,1,d_kv]
instead of the full static buffer. The inference wrapper uses HF's
StaticCache (mutated in-place via index_copy_) as the stateful cache
that flows through GenerationMixin's generation loop.

- Rename T5DecoderWithStaticCache to T5DecoderWrapper
- Extract single-token KV via gather at cache_position in export wrapper
- Replace custom StaticWriteEncoderDecoderCache with HF StaticCache
  in WinMLModelForSeq2SeqLM inference wrapper
- Cache is initialized once, mutated in-place each step, same object
  reference returned in Seq2SeqLMOutput
Replace custom StaticWriteCache/StaticWriteEncoderDecoderCache with HF's
StaticCache in both the ONNX export wrapper and inference wrapper. The
same cache class is now used end-to-end:

- Export (t5.py): constructs StaticCache from input KV tensors, wraps in
  EncoderDecoderCache, extracts new token KV via gather after index_copy_
- Inference (seq2seq.py): StaticCache persists across generation steps,
  updated in-place via cache.update() (index_copy_ at cache_position)
- Handle GenerationMixin wrapping cache in EncoderDecoderCache
- Remove models/cache.py (no longer imported)
- Update study doc with final design
Three-layer class hierarchy for multi-component ONNX models:

- WinMLPipelineModel: base with _SUB_MODEL_CONFIG, from_pretrained
  that builds each sub-component via WinMLAutoModel
- WinMLGenerationModel: encoder-decoder forward with StaticCache,
  GenerationMixin interface (get_encoder, prepare_inputs, can_generate)
- WinMLT5Model: T5-specific sub-model tasks and generation_config

Use HF StaticCache for both export (index_copy_ traces correctly) and
inference (stateful, mutated in-place via cache.update). ONNX decoder
outputs only new token KV [batch,heads,1,d_kv]; inference wrapper
writes it back into the cache buffer.
…nputs

- Replace _run_encoder + _EncoderProxy with _EncoderWithInputPadding
  that wraps the raw encoder with auto-padding; assigned as self._encoder
  so all callsites (forward, get_encoder) use it directly
- Extract shared _pad_inputs(source, expected) helper used by both
  encoder and decoder feed building
- forward() takes *, encoder_outputs, past_key_values, input_ids,
  **model_kwargs — no hardcoded input names
- Encoder input names read from ONNX io_config, not assumed
- Use encoder_outputs["last_hidden_state"] (explicit key)
- Derive _num_kv_layers and _max_dec from ONNX shapes
- feeds.setdefault for generated inputs (encoder_hidden_states,
  decoder_attention_mask, cache_position)
…d docstring

- Rename WinMLGenerationModel to WinMLEncoderDecoderModel
- Add PIPELINE_MODEL_REGISTRY + @register_pipeline_model decorator
- Register WinMLT5Model as ("t5", "translation")
- Add comprehensive module docstring with architecture, KV cache
  findings, and design principles for onboarding
@vortex-captain vortex-captain changed the title reny/t5 tracking Multi-Model Support Design Apr 7, 2026
Yi Ren added 4 commits April 7, 2026 17:07
- Extract CapturingStaticCache to shared kv_cache.py (used by T5 and Qwen3)
- Add QwenDecoderWrapper with prefill/gen OnnxConfig registrations
- Add WinMLDecoderOnlyModel with GenerationMixin, chunked prefill, StaticCache
- Register qwen3 build config (dynamo=True, opset 18)
- Fix dynamo field missing from _merge_export_config

Verified: ONNX logits match HF exactly, generate() produces correct answers
@vortex-captain vortex-captain changed the title Multi-Model Support Design feat: add support to multi-model and KV cache Apr 9, 2026
Yi Ren added 6 commits April 9, 2026 15:13
Replace duplicate T5KVCacheInputGenerator and QwenKVCacheInputGenerator
with a single PastKeyValueInputGenerator in kv_cache.py that reads
num_layers, num_attention_heads, head_dim, and max_cache_len from
NormalizedConfig. Update pipeline-model.md to use WinMLPipelineModel.
Move WinMLEncoderDecoderModel + EncoderDecoderInputGenerator to
models/hf/encoder_decoder.py, DecoderOnlyInputGenerator to
models/winml/decoder_only.py. Colocate WinMLT5Model in t5.py and
WinMLQwen3Model in qwen.py alongside their export configs.

Rename seq2seq.py to pipeline_model.py (now only WinMLPipelineModel
+ registry). Add sub_model_kwargs to from_pretrained for per-component
shape_config. Generators accept max_cache_len/seq_len kwargs as
overrides from shape_config.
Add Mu2EncoderWrapper, Mu2DecoderWrapper (with CapturingStaticCache),
OnnxConfig registrations, and WinMLMu2Model for translation pipeline.
Decoder delegates to modeling_mu.py's own decoder (patched with
past_key_values + cache_position support).

Also forward trust_remote_code through config/build CLI and
WinMLAutoModel/WinMLPipelineModel for custom auto_map models.
Accept sequence_length kwarg in EncoderDecoderInputGenerator for
shape_config overrides.
@tezheng tezheng closed this Apr 13, 2026
@tezheng tezheng deleted the reny/t5 branch April 13, 2026 14:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants