Skip to content

feat: [MLU] add mlu experimental ops; fix post norm bug#379

Merged
Neuromancer42 merged 4 commits into
masterfrom
gyj/adapt_mlu
Jun 30, 2026
Merged

feat: [MLU] add mlu experimental ops; fix post norm bug#379
Neuromancer42 merged 4 commits into
masterfrom
gyj/adapt_mlu

Conversation

@jessicagao01

Copy link
Copy Markdown
Collaborator

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces several new experimental operators, including MojoStorePagedKVCacheC8, MojoDequantFromPagedKVCache, MojoRMSNormInplace, MojoGroupRMSNormInplace, and MojoMRoPEInplace, along with their corresponding accuracy tests and platform support updates. Feedback on the changes highlights a potential runtime TypeError in MojoGroupRMSNormInplace when elementwise_affine is disabled, a typo in the extra_repr method name (written as extra_expr), and redundant casting logic in the dequant_from_cache helper of MojoDequantFromPagedKVCache.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread mojo_opset/experimental/operators/normalization.py
Comment thread mojo_opset/experimental/operators/kv_cache.py
Comment thread mojo_opset/experimental/operators/normalization.py Outdated
@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- Critical correctness bug in residual update plus several robustness issues in new operators.

Summary

Adds inplace variants of RMSNorm/GroupRMSNorm/MRoPE, int8 paged KV cache store and dequant operators, expands MRoPE platform support, and wires up xops backend autoload for accuracy tests. Also changes the residual output semantics in MojoResidualAddRMSNormQuant.

Must fix

  • [BLOCKER] Residual semantics changed -- mojo_opset/core/operators/normalization.py:521 -- Changing residual = hidden_state to residual = normed alters the contract of MojoResidualAddRMSNormQuant: the returned residual is now the normalized tensor instead of the pre-norm sum, which breaks the standard residual+RMSNorm pattern used by callers. Revert or justify with a corresponding update to all call sites/tests.
  • [BLOCKER] Dead/incorrect code in dequant helper -- mojo_opset/experimental/operators/kv_cache.py:243-245 -- scale_data_fp32 = scale_data.clone().to(torch.float) is immediately overwritten by scale_data_fp32 = scale_data[..., None, :] (un-casted), so the float upcast is silently dropped and per-block (quant_mode!=0) scale shapes are not handled. Use the fp32 cast and branch on scale rank.
  • [BLOCKER] context_seq_offset cumsum wrong dtype/device assumption -- mojo_opset/experimental/operators/kv_cache.py:255-258 -- torch.cumsum(context_lengths, dim=-1) then context_seq_offset[1:] = cu_seq_offset[:-1] leaves index 0 as the cumulative-sum dtype zero but uses the full cumsum tensor; fine logically, but you ignore that callers may pass context_seq_offset on a different device. Also when batch_size==1 the slice assignment is a no-op (OK) but the cumsum is wasted -- guard it. More importantly, this path uses .item() inside a Python loop per batch, forcing device sync on every iteration on the hot path.
  • [BLOCKER] MojoStorePagedKVCacheC8 ignores cu_q_lens validation -- mojo_opset/experimental/operators/kv_cache.py:148-152 -- When chunk_metadata is None, cu_q_lens is passed through without the assert cu_q_lens is not None check that the sibling MojoStorePagedKVCache performs, and build_paged_kv_chunk_metadata will likely crash with a confusing error. Add the assert.
  • [BLOCKER] Silent dtype mismatch in inplace RMSNorm -- mojo_opset/experimental/operators/normalization.py:130-136 -- hidden_state.copy_(normalized) where normalized is fp32 (from F.rms_norm) into a bf16/fp16 input is fine, but the weight is created with tensor_factory_kwargs dtype while tests copy_(weight.to(torch.float32)); if the parameter dtype is fp16, the fp32 copy is silently downcast and reference vs impl may diverge. Document or enforce weight dtype.

Suggestions

Suggestions (6)
  • [MAJOR] Python-level loop on hot path -- mojo_opset/experimental/operators/kv_cache.py:160-167 -- chunk_metadata.tolist() plus per-chunk indexed assignment and .permute will be slow for large batches; consider a vectorized scatter or at least batched indexing, matching the non-C8 implementation style.
  • [MAJOR] Layering: tests reach into sibling repo -- mojo_opset/tests/accuracy/conftest.py:15-46 -- Hard-coding mojo_opset_gitlab as a sibling path and calling a private _autoload() couples this repo to an external layout; prefer an env var only, and call a public API.
  • [MAJOR] Broad ModuleNotFoundError swallow -- mojo_opset/tests/accuracy/conftest.py:39-44 -- Logging a warning and continuing on missing xops backend hides config errors in CI where the backend is expected. Make it opt-in strict via env.
  • [MAJOR] assert_paged_kv_store_contract order -- mojo_opset/experimental/operators/kv_cache.py:154-158 -- Asserting after possibly building metadata is fine, but the early-return if chunk_metadata.shape[0] == 0 skips the quantization; that is correct only if there is truly nothing to store. Confirm build_paged_kv_chunk_metadata returns an empty tensor (not None) in all decode-with-zero-q paths.
  • [MINOR] extra_expr typo -- mojo_opset/experimental/operators/normalization.py:181 -- Should be extra_repr; current method is never called by nn.Module.__repr__.
  • [MINOR] Chinese comment in shipped code -- mojo_opset/experimental/operators/position_embedding.py:230 -- Replace with English comment for consistency.

Nits

Nits (4)
  • [NIT] Missing blank line before class -- mojo_opset/experimental/operators/kv_cache.py:109 and :165 -- PEP8 expects two blank lines.
  • [NIT] else : spacing -- mojo_opset/experimental/operators/normalization.py:171 -- drop the space before colon.
  • [NIT] print(...) debug residue in test -- mojo_opset/tests/accuracy/operators/test_kv_cache.py:762-767 -- use caplog/logging or remove.
  • [NIT] Unused quant_mode/quant_bit returned from gen_args but never validated against op -- mojo_opset/tests/accuracy/operators/test_kv_cache.py:737-741 -- either thread through or drop.

Notes

  • [CHECK] The MojoResidualAddRMSNormQuant change may be an intentional API fix; if so, please flag downstream model code and update its tests in this PR.
  • [CHECK] MojoDequantFromPagedKVCache declares quant_mode != 0 per-block scales in tests but the forward only handles the per-channel layout (scale[..., None, :]); confirm the non-torch backend path handles both.

@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- Critical correctness bug in residual update for MojoResidualAddRMSNormQuant plus a few issues in new operators and test plumbing.

Summary

The PR adds new experimental operators (paged C8 KV store/dequant, inplace RMSNorm/GroupRMSNorm, inplace MRoPE) with accuracy tests, expands MRoPE platform support, and adds an autoload hook for an external xops backend in accuracy tests. It also makes a one-line change to MojoResidualAddRMSNormQuant that alters residual semantics.

Must fix

  • [BLOCKER] Residual semantics changed in RMSNorm+Quant -- mojo_opset/core/operators/normalization.py:521 -- The residual was the pre-norm hidden_state (standard residual stream); changing it to normed makes the returned residual the normalized output, which silently breaks any caller relying on the previous contract. Confirm intent or revert; if intentional, update docstring/tests.
  • [BLOCKER] dequant_from_cache ignores its scale_data argument -- mojo_opset/experimental/operators/kv_cache.py:236-240 -- scale_data_fp32 is first computed from scale_data then reassigned from the outer scale_data[..., None, :], bypassing the .to(float) cast and the cloned local. Works only because dtype promotion happens via multiplication; still wrong/confusing and risks dtype bugs. Use the local scale_data_fp32 consistently.
  • [BLOCKER] MojoDequantFromPagedKVCache crashes when batch_size == 1 -- mojo_opset/experimental/operators/kv_cache.py:248-251 -- context_seq_offset[1:] = cu_seq_offset[:-1] is fine, but cu_seq_offset = torch.cumsum(context_lengths, dim=-1) followed by assignment relies on contiguity assumptions; more importantly, when context_seq_offset is None and only one batch exists this still works, but the transpose(1,0) on a slice and then in-place key_i[...] = ... requires the source key to be contiguous on the head axis -- with context_strided=True test path the source is a transposed view. Verify the in-place write actually propagates through the transpose+slice on all backends, or assign back via the original layout.
  • [BLOCKER] MojoStorePagedKVCacheC8 cannot run in pure decode mode without metadata -- mojo_opset/experimental/operators/kv_cache.py:148-159 -- When chunk_metadata is None, only block_table and context_kv_lens are asserted; cu_q_lens may be None (decode), but is passed through to build_paged_kv_chunk_metadata unchecked. Confirm the helper handles cu_q_lens=None, otherwise add the same assertion as the legacy path or document it.

Suggestions

Suggestions (5)
  • [MAJOR] Layering violation: accuracy conftest loads an external repo from a sibling path -- mojo_opset/tests/accuracy/conftest.py:15-46 -- Hardcoding repo_root.parent / "mojo_opset_gitlab" and mutating sys.path at import time couples tests to a developer-specific layout; gate strictly on env var and log when skipped.
  • [MAJOR] Broad-ish exception swallowed via name check -- mojo_opset/tests/accuracy/conftest.py:38-43 -- Catching ModuleNotFoundError then re-raising only when exc.name differs is fragile; if a transitive import inside mojo_opset_ext_autoload fails with a different module name it will be re-raised, but a missing transitive dep of the autoloader itself will be silently swallowed. Consider catching only top-level absence explicitly.
  • [MAJOR] Calling a private _autoload() -- mojo_opset/tests/accuracy/conftest.py:45 -- Reaching into mojo_opset_ext_autoload._autoload ties tests to private API. Expose a public entry point on that package instead.
  • [MAJOR] MojoGroupRMSNormInplace indexes self.weight[group_id] even when elementwise_affine=False -- mojo_opset/experimental/operators/normalization.py:170-176 -- In that case self.weight is None and indexing will crash. Pass weight=self.weight[group_id] if self.weight is not None else None.
  • [MINOR] extra_expr typo -- mojo_opset/experimental/operators/normalization.py:186 -- Should be extra_repr; current method will never be called by PyTorch.

Nits

Nits (5)
  • [NIT] Chinese comment in source -- mojo_opset/experimental/operators/position_embedding.py:230 -- replace with English for consistency.
  • [NIT] Typos in docstring -- mojo_opset/experimental/operators/kv_cache.py:130 -- per_channal, key_cahce, value_cahce.
  • [NIT] Missing blank line between class definitions -- mojo_opset/experimental/operators/kv_cache.py:109, :163 -- PEP 8 expects two blank lines.
  • [NIT] else : spacing -- mojo_opset/experimental/operators/normalization.py:179 -- drop the space before colon.
  • [NIT] Unused imports / dead args in tests -- mojo_opset/tests/accuracy/operators/test_kv_cache.py:4-5 -- math, random, plus quant_mode/quant_bit unpacked but never used in the operator call.

Notes

  • [CHECK] MojoStorePagedKVCacheC8 uses symmetric int8 quant with round(x/scale).clamp(-128,127); the reference uses the same op, so the test compares same-formula implementations. Confirm there is a separate ground-truth test against an unquantized reference somewhere, otherwise the test only checks backend parity.
  • [CHECK] MojoMRoPE.supported_platforms_list now includes meta_device; verify the base implementation actually runs under meta tensors (no data-dependent control flow on tensor values).

@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- Several correctness issues in the new residual-norm semantics, KV-cache dequant helper, and accuracy-test bootstrap need to be addressed before merge.

Summary

Adds new experimental in-place norm / MRoPE / int8 paged KV-cache operators and accuracy tests, broadens MRoPE platform list, and changes the residual returned by MojoResidualAddRMSNormQuant for the post norm path. Also wires an external xops autoloader into the accuracy test conftest.

Must fix

  • [BLOCKER] Residual semantics change is asymmetric and likely wrong for post -- mojo_opset/core/operators/normalization.py:521 -- Reference now returns residual = normed for post, but the sibling MojoResidualAddRMSNorm and other quant variants still return the pre-norm residual. Either align all variants or document/justify; this currently silently changes the contract for one operator only.
  • [BLOCKER] NPU op returns wrong residual for post -- mojo_opset/backends/torch_npu/operators/norm.py:130 -- Returns normed (post-quant-input, fp) when norm_pos != "pre", but the underlying fused kernel produces residual_before_norm regardless; this no longer matches the kernel's actual output and was only "fixed" by relaxing tolerances in the test. Verify the kernel path actually produces the post-norm tensor or revert.
  • [BLOCKER] dequant_from_cache ignores its scale_data argument -- mojo_opset/experimental/operators/kv_cache.py:248-252 -- scale_data_fp32 = scale_data[..., None, :] overwrites the earlier .to(float) conversion and references the outer (un-cast) tensor; the .clone().to(torch.float) line is dead code. Functionally still works for fp inputs, but if scale_data is ever non-float the dequant silently uses the wrong dtype.
  • [BLOCKER] context_seq_offset derivation is wrong when not provided -- mojo_opset/experimental/operators/kv_cache.py:261-264 -- torch.cumsum(context_lengths, dim=-1) then shifting gives offsets that assume contiguous packing with no padding, but the test path explicitly uses context_lens + context_paddings. The fallback will produce incorrect offsets whenever sequences are padded; require the caller to pass it or compute with paddings.
  • [BLOCKER] Accuracy conftest hard-codes a sibling repo path and auto-loads it -- mojo_opset/tests/accuracy/conftest.py:14-46 -- parents[3] / "mojo_opset_gitlab" and silent ModuleNotFoundError swallow is a layering/security concern: tests now implicitly mutate sys.path and import third-party autoload code by default on mlu. Gate behind an opt-in env var (default off) and remove the hard-coded sibling path, or move to a plugin.

Suggestions

Suggestions (6)
  • [MAJOR] Python-loop store of int8 cache is a hot-path perf regression -- mojo_opset/experimental/operators/kv_cache.py:159-166 -- Per-chunk .tolist() + Python for-loop with permute will be slow for many chunks; consider a vectorized scatter as done elsewhere.
  • [MAJOR] Per-iteration Python .item() and list-comp torch.concat in dequant -- mojo_opset/experimental/operators/kv_cache.py:268-285 -- Forces device sync per batch and rebuilds tensors; batch via index_select on block_tables to avoid sync and concat overhead.
  • [MAJOR] MojoMRoPEInplace.forward allocates and concats even in inplace mode -- mojo_opset/experimental/operators/position_embedding.py:218-232 -- torch.cats build full new tensors then copy_; this is not actually in-place and provides no perf benefit over the non-inplace path. Consider writing back into views of orig_query/orig_key.
  • [MAJOR] MojoGroupRMSNormInplace indexes self.weight[group_id] when elementwise_affine=False -- mojo_opset/experimental/operators/normalization.py:166-172 -- self.weight is None in that case; will raise TypeError. Guard the index.
  • [MINOR] extra_expr typo (should be extra_repr) -- mojo_opset/experimental/operators/normalization.py:184 -- Method will never be called by nn.Module.__repr__.
  • [MINOR] Test relaxation hides a real diff -- mojo_opset/tests/accuracy/operators/test_normalization.py:524-525 -- atol/rtol bumped from 1e-3 to 1e-2 for the quantized residual norm; this coincides with the residual-semantics change above and should be justified, not relaxed.

Nits

Nits (5)
  • [NIT] Non-ASCII comment ("根据 self.inplace ...") -- mojo_opset/experimental/operators/position_embedding.py:235 -- repo appears to be English-only.
  • [NIT] Typos in docstring: "per_channal", "key_cahce", "value_cahce" -- mojo_opset/experimental/operators/kv_cache.py:128,143.
  • [NIT] Missing blank line between class definitions -- mojo_opset/experimental/operators/kv_cache.py:109,184.
  • [NIT] else : spacing -- mojo_opset/experimental/operators/normalization.py:177.
  • [NIT] quant_mode/quant_bit are unused by the op under test -- mojo_opset/tests/accuracy/operators/test_kv_cache.py:760-770 -- drop from parametrize or assert they match the op contract.

Notes

  • [CHECK] MojoMRoPE.supported_platforms_list was expanded to include mlu, meta_device, ilu -- confirm backend registrations actually exist for those platforms, otherwise dispatch will fail at runtime instead of being caught by bypass_not_implemented.
  • [CHECK] MojoStorePagedKVCacheC8 does symmetric per-channel quant with round/clamp(-128,127); verify this matches the NPU/MLU kernel's quant convention (some kernels use (-127, 127) or asymmetric).

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- Residual semantics change and a quant-dequant bug make this unsafe to merge as-is.

Summary

This PR adds C8 paged KV cache store/dequant, in-place variants of RMSNorm / GroupRMSNorm / MRoPE, and broadens MRoPE platform support; it also changes the residual returned by MojoResidualAddRMSNormQuant. There are also new test conftest hooks for loading an external xops backend.

Must fix

  • [BLOCKER] Residual return value changed for pre norm position -- mojo_opset/core/operators/normalization.py:521 -- For norm_pos == "pre" the residual returned is now normed instead of the original hidden_state. This silently changes the semantics of every existing caller of MojoResidualAddRMSNormQuant (and is why the tolerance had to be loosened to 1e-2 in test_normalization.py:525). Decide which branch is correct and gate by self.norm_pos consistently in both core and the NPU override (backends/torch_npu/operators/norm.py:130 only handles "pre").
  • [BLOCKER] Dequant uses unscaled tensor -- mojo_opset/experimental/operators/kv_cache.py:236 -- scale_data_fp32 = scale_data[..., None, :] overwrites the just-computed fp32 copy with the original (possibly bf16/fp16) scale_data, so the multiply runs in mixed dtype and silently breaks the fp32 contract. Use scale_data_fp32[..., None, :].
  • [BLOCKER] MojoDequantFromPagedKVCache ignores context_seq_offset padding -- mojo_opset/experimental/operators/kv_cache.py:251-256 -- When context_seq_offset is None, the fallback uses cumsum(context_lengths) but the test generator computes offsets from context_lens + context_paddings. The fallback path therefore disagrees with the documented layout and will write to wrong rows whenever padding is present.
  • [BLOCKER] MojoStorePagedKVCacheC8 does float division by a possibly-int8 scale -- mojo_opset/experimental/operators/kv_cache.py:159-160 -- torch.round(key_states / key_scale) runs in the input's (likely bf16/fp16) dtype, which loses precision and clamps near boundaries. Cast key_states/value_states to float32 before dividing, matching the dequant side.

Suggestions

Suggestions (5)
  • [MAJOR] Per-token Python loop on the hot path -- mojo_opset/experimental/operators/kv_cache.py:166-173 -- The torch reference iterates chunk_metadata.tolist() and indexes per chunk; for long contexts this dominates. Consider a vectorized scatter using index_put_ or chunk_metadata flattened indices, at least for the non-reference path.
  • [MAJOR] Dequant path is O(batch * blocks) Python concat -- mojo_opset/experimental/operators/kv_cache.py:261-268 -- Building key_cache_i via a Python list comprehension and torch.concat per batch element is very slow and triggers large temporaries; use key_cache.index_select along the block dim then reshape.
  • [MAJOR] In-place flag on a reference op is silently ignored when shape changes -- mojo_opset/experimental/operators/position_embedding.py:230-238 -- orig_query.copy_(query) requires identical shape; query was reshaped to (num_tokens, n_qh, head_dim) and then re-flattened, but if head_dim was passed with rope_dim < head_dim and the input was a non-contiguous view, the final view(num_tokens, -1) can fail or produce wrong strides. Add an explicit .reshape_as(orig_query) and assert shape match before copy_.
  • [MAJOR] Test conftest auto-loads code from a sibling repo -- mojo_opset/tests/accuracy/conftest.py:15-46 -- _load_xops_backend_for_accuracy mutates sys.path with repo_root.parent / "mojo_opset_gitlab" and imports mojo_opset_ext_autoload._autoload() (a private function). This is a layering/security concern: tests can pick up arbitrary code from a sibling checkout. At minimum, gate strictly on the env var and do not fall back to a hardcoded sibling path.
  • [MINOR] MojoStorePagedKVCacheC8.__init__ takes no args but uses super().__init__() without **kwargs -- mojo_opset/experimental/operators/kv_cache.py:111-114 -- Inconsistent with the rest of the file (e.g. MojoStorePagedMLAKVCache) and prevents passing device/dtype factory kwargs.

Nits

Nits (5)
  • [NIT] Non-ASCII em-dash in docstrings -- mojo_opset/experimental/operators/kv_cache.py:131-141 -- prefer -- for grep-ability.
  • [NIT] Chinese comment left in source -- mojo_opset/experimental/operators/position_embedding.py:231 -- "根据 self.inplace 决定...".
  • [NIT] Typos: per_channal, key_cahce, value_cahce -- mojo_opset/experimental/operators/kv_cache.py:130,148.
  • [NIT] else : with extra space -- mojo_opset/experimental/operators/normalization.py:175.
  • [NIT] Stray print(...) in test -- mojo_opset/tests/accuracy/operators/test_kv_cache.py:760-764 -- use caplog or remove.

Notes

  • [CHECK] The tolerance loosening in test_residual_add_rmsnorm_quant (test_normalization.py:525) appears to compensate for the residual-semantics change; once the residual question is resolved, the original 1e-3 tolerances should be restored.
  • [CHECK] MojoMRoPE.supported_platforms_list now claims meta_device -- verify there is actually a registered impl for meta_device, otherwise dispatch will fail at runtime rather than at registration.

@Neuromancer42 Neuromancer42 merged commit a658e17 into master Jun 30, 2026
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants