Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mojo_opset/backends/torch_npu/operators/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def forward(
self.variance_epsilon,
)
quantized, scale = _dynamic_quant(normed, self.quant_dtype, smooth_scale)
return quantized, residual_before_norm, scale
return quantized, residual_before_norm if self.norm_pos == "pre" else normed, scale


class TorchNpuResidualAddLayerNormQuant(MojoResidualAddLayerNormQuant, default_priority=0):
Expand Down
2 changes: 1 addition & 1 deletion mojo_opset/core/operators/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def forward(
weight=self.weight,
eps=self.variance_epsilon,
)
residual = hidden_state
residual = normed

normed_fp = _apply_optional_smooth_scale(normed, smooth_scale)
scale = normed_fp.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / self.q_max
Expand Down
2 changes: 1 addition & 1 deletion mojo_opset/core/operators/position_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class MojoMRoPE(MojoOperator):
Reference: https://qwenlm.github.io/blog/qwen2-vl/
"""

supported_platforms_list = ["npu"]
supported_platforms_list = ["npu", "mlu", "meta_device", "ilu"]

def __init__(self, **kwargs):
super().__init__(**kwargs)
Expand Down
10 changes: 10 additions & 0 deletions mojo_opset/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,17 @@
from .operators.indexer import MojoIndexer
from .operators.indexer import MojoLightningIndexer
from .operators.kv_cache import MojoStorePagedMLAKVCache
from .operators.kv_cache import MojoStorePagedKVCacheC8
from .operators.kv_cache import MojoDequantFromPagedKVCache
from .operators.moe import MojoFusedSwiGLUMoEScaleDynamicQuantize
from .operators.moe import MojoMoEInitRoutingDynamicQuant
from .operators.normalization import MojoChannelRMSNorm
from .operators.normalization import MojoGroupLayerNorm
from .operators.normalization import MojoRMSNormInplace
from .operators.normalization import MojoGroupRMSNormInplace
from .operators.position_embedding import MojoGridRoPE
from .operators.position_embedding import MojoRelativeEmbedding
from .operators.position_embedding import MojoMRoPEInplace
from .operators.store_lowrank import MojoStoreLowrank

__all__ = [
Expand All @@ -55,12 +60,17 @@
"MojoFusedAttnOutputGate",
"MojoPagedPrefillSageGQA",
"MojoStorePagedMLAKVCache",
"MojoStorePagedKVCacheC8",
"MojoDequantFromPagedKVCache",
"MojoMoEInitRoutingDynamicQuant",
"MojoFusedSwiGLUMoEScaleDynamicQuantize",
"MojoGroupLayerNorm",
"MojoChannelRMSNorm",
"MojoRMSNormInplace",
"MojoGroupRMSNormInplace",
"MojoRelativeEmbedding",
"MojoGridRoPE",
"MojoStoreLowrank",
"MojoIndexer",
"MojoMRoPEInplace",
]
10 changes: 10 additions & 0 deletions mojo_opset/experimental/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,17 @@
from .indexer import MojoIndexer
from .indexer import MojoLightningIndexer
from .kv_cache import MojoStorePagedMLAKVCache
from .kv_cache import MojoStorePagedKVCacheC8
from .kv_cache import MojoDequantFromPagedKVCache
from .moe import MojoFusedSwiGLUMoEScaleDynamicQuantize
from .moe import MojoMoEInitRoutingDynamicQuant
from .normalization import MojoChannelRMSNorm
from .normalization import MojoGroupLayerNorm
from .normalization import MojoRMSNormInplace
from .normalization import MojoGroupRMSNormInplace
from .position_embedding import MojoGridRoPE
from .position_embedding import MojoRelativeEmbedding
from .position_embedding import MojoMRoPEInplace
from .store_lowrank import MojoStoreLowrank

__all__ = [
Expand All @@ -42,12 +47,17 @@
"MojoPagedPrefillSWAWithKVDequant",
"MojoPagedDecodeSWAWithKVDequant",
"MojoStorePagedMLAKVCache",
"MojoStorePagedKVCacheC8",
"MojoDequantFromPagedKVCache",
"MojoMoEInitRoutingDynamicQuant",
"MojoFusedSwiGLUMoEScaleDynamicQuantize",
"MojoGroupLayerNorm",
"MojoChannelRMSNorm",
"MojoRelativeEmbedding",
"MojoRMSNormInplace",
"MojoGroupRMSNormInplace",
"MojoGridRoPE",
"MojoStoreLowrank",
"MojoQuantBatchGemmReduceSum",
"MojoMRoPEInplace",
]
167 changes: 165 additions & 2 deletions mojo_opset/experimental/operators/kv_cache.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import Tuple
from typing import Tuple, Optional

import torch

from mojo_opset.core.operators.kv_cache import assert_paged_kv_layout_contract
from mojo_opset.core.operators.kv_cache import (
assert_paged_kv_layout_contract,
assert_paged_kv_store_contract,
build_paged_kv_chunk_metadata
)
from mojo_opset.core.operator import MojoOperator


Expand Down Expand Up @@ -102,7 +106,166 @@ def forward(

return compressed_kv_cache, k_pe_cache

class MojoStorePagedKVCacheC8(MojoOperator):
    """Quantize new key/value tokens to int8 and write them into a paged KV cache."""

    def __init__(self):
        super().__init__()

    def forward(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        key_scale: torch.Tensor,
        value_scale: torch.Tensor,
        block_table: Optional[torch.Tensor] = None,
        cu_q_lens: Optional[torch.Tensor] = None,
        context_kv_lens: Optional[torch.Tensor] = None,
        *,
        chunk_metadata: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Store new K/V tokens in a paged KV cache using int8 per-channel quantization.

        Exactly one addressing mode must be used: either the legacy triple
        (``block_table``, ``cu_q_lens``, ``context_kv_lens``), from which a store plan
        is built, or a precomputed ``chunk_metadata`` plan.

        Args:
            key_states (torch.Tensor): Shape (token_num, kv_head_num, head_dim) — new key tokens.
            value_states (torch.Tensor): Shape (token_num, kv_head_num, head_dim) — new value tokens.
            key_cache (torch.Tensor): Shape (total_phys_blocks, kv_heads, block_size, head_dim) — key cache.
            value_cache (torch.Tensor): Shape (total_phys_blocks, kv_heads, block_size, head_dim) — value cache.
            key_scale (torch.Tensor): Shape (kv_head_num, head_dim) — divisor applied to keys before rounding.
            value_scale (torch.Tensor): Shape (kv_head_num, head_dim) — divisor applied to values before rounding.
            block_table (torch.Tensor | None): Legacy logical-to-physical block mapping.
            cu_q_lens (torch.Tensor | None): Legacy cumulative query lengths. ``None`` indicates decode mode.
            context_kv_lens (torch.Tensor | None): Legacy KV lengths before storing the current tokens.
            chunk_metadata (torch.Tensor | None): Precomputed store plan with shape ``(num_chunks, 4)``
                and per-row ``(src_token_start, dst_block_id, dst_block_offset, chunk_len)``.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``(key_cache, value_cache)`` after the in-place writes.
        """
        states_are_rank3 = key_states.dim() == 3 and value_states.dim() == 3
        assert states_are_rank3 and key_states.shape == value_states.shape, (
            "key/value states must be (token_num, kv_head_num, head_dim), please check."
        )

        if chunk_metadata is not None:
            legacy_inputs = (block_table, cu_q_lens, context_kv_lens)
            assert all(arg is None for arg in legacy_inputs), (
                "chunk_metadata path should not be mixed with block_table/cu_q_lens/context_kv_lens."
            )
        else:
            assert block_table is not None, "block_table is required when chunk_metadata is not provided."
            assert context_kv_lens is not None, "context_kv_lens is required when chunk_metadata is not provided."
            # The block size is read from the cache layout (dim 2).
            block_size = key_cache.shape[2]
            chunk_metadata = build_paged_kv_chunk_metadata(block_table, cu_q_lens, context_kv_lens, block_size)

        assert key_scale is not None and value_scale is not None
        assert_paged_kv_store_contract(chunk_metadata)

        # An empty plan means there is nothing to store.
        if chunk_metadata.shape[0] == 0:
            return key_cache, value_cache

        def to_int8(states: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
            # Divide by the scale, round, and saturate to the int8 range.
            return torch.round(states / scale).clamp(-128, 127).to(torch.int8)

        key_q = to_int8(key_states, key_scale)
        value_q = to_int8(value_states, value_scale)

        for src_begin, block_id, block_offset, length in chunk_metadata.tolist():
            token_span = slice(src_begin, src_begin + length)
            slot_span = slice(block_offset, block_offset + length)
            # Tokens are (chunk_len, heads, dim); a cache block stores (heads, slots, dim).
            key_cache[block_id, :, slot_span, :] = key_q[token_span].transpose(0, 1)
            value_cache[block_id, :, slot_span, :] = value_q[token_span].transpose(0, 1)

        return key_cache, value_cache

class MojoDequantFromPagedKVCache(MojoOperator):
    """Dequantize an int8 paged K/V cache into linear (packed) K/V tensors."""

    def __init__(self):
        super().__init__()

    def forward(
        self,
        *,
        key: torch.Tensor,
        value: Optional[torch.Tensor] = None,
        key_cache: torch.Tensor,
        key_cache_scale: torch.Tensor,
        value_cache: Optional[torch.Tensor] = None,
        value_cache_scale: Optional[torch.Tensor] = None,
        context_lengths: Optional[torch.Tensor] = None,
        max_context_len: int,
        context_seq_offset: Optional[torch.Tensor] = None,
        block_tables: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        r"""
        Copy and dequantize from an int8 paged K/V cache into linear K/V states.

        Args:
            key (torch.Tensor): Shape (total_seq_len, head_num, head_size) — destination key states.
            value (torch.Tensor | None): Shape (total_seq_len, head_num, head_size) — destination value states.
            key_cache (torch.Tensor): Shape (block_num, head_num, block_size, head_size) — key cache.
            key_cache_scale (torch.Tensor): Shape (head_num, head_size) — key cache scale.
            value_cache (torch.Tensor | None): Shape (block_num, head_num, block_size, head_size) — value cache.
            value_cache_scale (torch.Tensor | None): Shape (head_num, head_size) — value cache scale.
            context_lengths (torch.Tensor): Shape (batch_size,) — valid sequence length of each batch sample.
                Required, even though the signature gives it a ``None`` default.
            max_context_len (int): Maximum valid sequence length across the batch. Not used by this
                implementation; kept for interface compatibility.
            context_seq_offset (torch.Tensor | None): Shape (batch_size,) — start offset of each sequence in
                ``key``/``value``. When ``None``, sequences are assumed to be packed back to back.
            block_tables (torch.Tensor): Shape (batch_size, max_block_num) — logical-to-physical block mapping.

        Returns:
            Tuple[torch.Tensor, torch.Tensor | None]: The ``key`` and ``value`` arguments after the
            in-place writes. Values are only dequantized when ``value``, ``value_cache`` and
            ``value_cache_scale`` are all provided.

        Raises:
            ValueError: If ``context_lengths`` is ``None``.
        """
        if context_lengths is None:
            raise ValueError("context_lengths is required to locate each sequence in the paged cache.")

        if context_seq_offset is None:
            # Exclusive prefix sum: offset[i] = sum(context_lengths[:i]).
            context_seq_offset = torch.cumsum(context_lengths, dim=-1) - context_lengths

        block_size = key_cache.size(2)
        dequant_value = value is not None and value_cache is not None and value_cache_scale is not None

        def gather_dequant(
            cache: torch.Tensor,
            scale: torch.Tensor,
            block_ids: torch.Tensor,
            num_tokens: int,
        ) -> torch.Tensor:
            # (num_blocks, head_num, block_size, head_size) -> (head_num, num_blocks, block_size, head_size)
            blocks = cache[block_ids].transpose(0, 1)
            # Flatten the blocks into one token axis and drop the unused tail of the last block.
            tokens = blocks.reshape(blocks.size(0), -1, blocks.size(-1))[:, :num_tokens]
            # Multiply in fp32; the scale broadcasts over the token axis.
            return tokens.to(torch.float) * scale.to(torch.float)[..., None, :]

        # Read lengths and offsets once, rather than one device sync per sequence.
        lengths = context_lengths.tolist()
        offsets = context_seq_offset.tolist()

        for i, context_len in enumerate(lengths):
            if context_len == 0:
                # Nothing cached for this sequence; gathering zero blocks has nothing to write.
                continue
            seq_begin = offsets[i]
            seq_end = seq_begin + context_len
            num_blocks = (context_len + block_size - 1) // block_size
            block_ids = block_tables[i, :num_blocks].long()

            # The gather result is (head_num, context_len, head_size); the destination is token-major.
            dequant_key = gather_dequant(key_cache, key_cache_scale, block_ids, context_len)
            key[seq_begin:seq_end] = dequant_key.transpose(0, 1).to(key.dtype)

            if dequant_value:
                dequant_val = gather_dequant(value_cache, value_cache_scale, block_ids, context_len)
                value[seq_begin:seq_end] = dequant_val.transpose(0, 1).to(value.dtype)

        return key, value

# Names exported by `from <module> import *`.
__all__ = [
"MojoStorePagedMLAKVCache",
"MojoStorePagedKVCacheC8",
"MojoDequantFromPagedKVCache",
]
91 changes: 91 additions & 0 deletions mojo_opset/experimental/operators/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,98 @@ def extra_repr(self) -> str:
)


class MojoRMSNormInplace(MojoOperator):
    """RMSNorm over the last dimension, optionally writing the result back into the input."""

    def __init__(
        self,
        norm_size: int,
        eps: float = 1e-5,
        inplace: bool = False,
        **kwargs,
    ):
        """
        Set up the RMSNorm weight and options.

        Args:
            norm_size (int): Length of the 1-D affine scale vector.
            eps (float, default=1e-5): Epsilon passed to the normalization for numerical stability.
            inplace (bool, default=False): If True, the normalized values overwrite the input tensor.
            **kwargs: Forwarded to the base operator; the weight is created with the resulting
                tensor factory kwargs (e.g. device, dtype).
        """
        super().__init__(**kwargs)
        self.inplace = inplace
        self.variance_epsilon = eps
        self.norm_size = norm_size
        # The weight is allocated uninitialized; values are expected to be set afterwards.
        self.weight = torch.nn.Parameter(torch.empty(norm_size, **self.tensor_factory_kwargs))

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """
        Normalize ``hidden_state`` with RMSNorm across its last dimension.

        Args:
            hidden_state (torch.Tensor): Input of shape (..., D); normalization runs over D.

        Returns:
            torch.Tensor: ``hidden_state`` itself (overwritten) when ``inplace`` is set,
            otherwise a new tensor holding the normalized values.
        """
        last_dim = hidden_state.shape[-1]
        result = F.rms_norm(hidden_state, [last_dim], weight=self.weight, eps=self.variance_epsilon)
        if not self.inplace:
            return result
        # copy_ writes into the caller's tensor and returns that same tensor.
        return hidden_state.copy_(result)

    def extra_repr(self) -> str:
        description = f"{self.norm_size=}, {self.variance_epsilon=}"
        return description.replace("self.", "")

class MojoGroupRMSNormInplace(MojoOperator):
    """Apply an independent RMSNorm to each tensor in a list of groups."""

    def __init__(self, num_groups, norm_size, eps, elementwise_affine=True, inplace=False, **kwargs):
        """
        Set up per-group RMSNorm weights and options.

        Args:
            num_groups (int): Number of groups, i.e. how many tensors ``forward`` normalizes.
            norm_size (int): Size of the trailing dimension that is normalized.
            eps (float): Epsilon passed to the normalization for numerical stability.
            elementwise_affine (bool, default=True): If True, allocate a ``(num_groups, norm_size)``
                weight; otherwise no weight is applied.
            inplace (bool, default=False): If True, results overwrite the input tensors.
            **kwargs: Forwarded to the base operator; the weight is created with the resulting
                tensor factory kwargs.
        """
        super().__init__(**kwargs)
        self.num_groups = num_groups
        self.norm_size = norm_size
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = torch.nn.Parameter(torch.empty((num_groups, norm_size), **self.tensor_factory_kwargs))
        else:
            self.weight = None
        self.variance_epsilon = eps
        self.inplace = inplace

    def forward(self, input_groups):
        """
        Normalize the first ``num_groups`` tensors of ``input_groups``.

        Args:
            input_groups: Indexable collection of tensors; entry ``g`` is normalized over its
                trailing ``norm_size`` elements using weight row ``g`` (when a weight exists).

        Returns:
            ``input_groups`` itself (tensors overwritten) when ``inplace`` is set, otherwise a
            new list with one normalized tensor per group.
        """
        output_groups = []
        for group_id in range(self.num_groups):
            group = input_groups[group_id]
            # Without elementwise_affine there is no weight; pass None rather than indexing it.
            weight = self.weight[group_id] if self.weight is not None else None
            normalized = F.rms_norm(
                group,
                (self.norm_size,),
                weight=weight,
                eps=self.variance_epsilon,
            )
            if self.inplace:
                # Write the result back into the caller's tensor.
                group.copy_(normalized)
            else:
                output_groups.append(normalized)

        return input_groups if self.inplace else output_groups

    def extra_repr(self) -> str:
        return f"{self.num_groups=}, {self.norm_size=}, {self.variance_epsilon=}, {self.elementwise_affine=}".replace("self.", "")

# Names exported by `from <module> import *`.
__all__ = [
"MojoGroupLayerNorm",
"MojoChannelRMSNorm",
"MojoRMSNormInplace",
"MojoGroupRMSNormInplace",
]
Loading
Loading