diff --git a/mojo_opset/backends/torch_npu/operators/norm.py b/mojo_opset/backends/torch_npu/operators/norm.py index e88946df..98e74486 100644 --- a/mojo_opset/backends/torch_npu/operators/norm.py +++ b/mojo_opset/backends/torch_npu/operators/norm.py @@ -127,7 +127,7 @@ def forward( self.variance_epsilon, ) quantized, scale = _dynamic_quant(normed, self.quant_dtype, smooth_scale) - return quantized, residual_before_norm, scale + return quantized, residual_before_norm if self.norm_pos == "pre" else normed, scale class TorchNpuResidualAddLayerNormQuant(MojoResidualAddLayerNormQuant, default_priority=0): diff --git a/mojo_opset/core/operators/normalization.py b/mojo_opset/core/operators/normalization.py index 3bc366df..09e1556d 100644 --- a/mojo_opset/core/operators/normalization.py +++ b/mojo_opset/core/operators/normalization.py @@ -518,7 +518,7 @@ def forward( weight=self.weight, eps=self.variance_epsilon, ) - residual = hidden_state + residual = normed normed_fp = _apply_optional_smooth_scale(normed, smooth_scale) scale = normed_fp.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / self.q_max diff --git a/mojo_opset/core/operators/position_embedding.py b/mojo_opset/core/operators/position_embedding.py index 9dbc30fc..a933d33a 100644 --- a/mojo_opset/core/operators/position_embedding.py +++ b/mojo_opset/core/operators/position_embedding.py @@ -184,7 +184,7 @@ class MojoMRoPE(MojoOperator): Reference: https://qwenlm.github.io/blog/qwen2-vl/ """ - supported_platforms_list = ["npu"] + supported_platforms_list = ["npu", "mlu", "meta_device", "ilu"] def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/mojo_opset/experimental/__init__.py b/mojo_opset/experimental/__init__.py index c47a24d2..9780d24a 100755 --- a/mojo_opset/experimental/__init__.py +++ b/mojo_opset/experimental/__init__.py @@ -25,12 +25,17 @@ from .operators.indexer import MojoIndexer from .operators.indexer import MojoLightningIndexer from .operators.kv_cache import 
MojoStorePagedMLAKVCache +from .operators.kv_cache import MojoStorePagedKVCacheC8 +from .operators.kv_cache import MojoDequantFromPagedKVCache from .operators.moe import MojoFusedSwiGLUMoEScaleDynamicQuantize from .operators.moe import MojoMoEInitRoutingDynamicQuant from .operators.normalization import MojoChannelRMSNorm from .operators.normalization import MojoGroupLayerNorm +from .operators.normalization import MojoRMSNormInplace +from .operators.normalization import MojoGroupRMSNormInplace from .operators.position_embedding import MojoGridRoPE from .operators.position_embedding import MojoRelativeEmbedding +from .operators.position_embedding import MojoMRoPEInplace from .operators.store_lowrank import MojoStoreLowrank __all__ = [ @@ -55,12 +60,17 @@ "MojoFusedAttnOutputGate", "MojoPagedPrefillSageGQA", "MojoStorePagedMLAKVCache", + "MojoStorePagedKVCacheC8", + "MojoDequantFromPagedKVCache", "MojoMoEInitRoutingDynamicQuant", "MojoFusedSwiGLUMoEScaleDynamicQuantize", "MojoGroupLayerNorm", "MojoChannelRMSNorm", + "MojoRMSNormInplace", + "MojoGroupRMSNormInplace", "MojoRelativeEmbedding", "MojoGridRoPE", "MojoStoreLowrank", "MojoIndexer", + "MojoMRoPEInplace", ] diff --git a/mojo_opset/experimental/operators/__init__.py b/mojo_opset/experimental/operators/__init__.py index d0c747c3..0d3adebd 100644 --- a/mojo_opset/experimental/operators/__init__.py +++ b/mojo_opset/experimental/operators/__init__.py @@ -16,12 +16,17 @@ from .indexer import MojoIndexer from .indexer import MojoLightningIndexer from .kv_cache import MojoStorePagedMLAKVCache +from .kv_cache import MojoStorePagedKVCacheC8 +from .kv_cache import MojoDequantFromPagedKVCache from .moe import MojoFusedSwiGLUMoEScaleDynamicQuantize from .moe import MojoMoEInitRoutingDynamicQuant from .normalization import MojoChannelRMSNorm from .normalization import MojoGroupLayerNorm +from .normalization import MojoRMSNormInplace +from .normalization import MojoGroupRMSNormInplace from .position_embedding import 
MojoGridRoPE from .position_embedding import MojoRelativeEmbedding +from .position_embedding import MojoMRoPEInplace from .store_lowrank import MojoStoreLowrank __all__ = [ @@ -42,12 +47,17 @@ "MojoPagedPrefillSWAWithKVDequant", "MojoPagedDecodeSWAWithKVDequant", "MojoStorePagedMLAKVCache", + "MojoStorePagedKVCacheC8", + "MojoDequantFromPagedKVCache", "MojoMoEInitRoutingDynamicQuant", "MojoFusedSwiGLUMoEScaleDynamicQuantize", "MojoGroupLayerNorm", "MojoChannelRMSNorm", "MojoRelativeEmbedding", + "MojoRMSNormInplace", + "MojoGroupRMSNormInplace", "MojoGridRoPE", "MojoStoreLowrank", "MojoQuantBatchGemmReduceSum", + "MojoMRoPEInplace", ] diff --git a/mojo_opset/experimental/operators/kv_cache.py b/mojo_opset/experimental/operators/kv_cache.py index 1066c3dc..ee33fd28 100644 --- a/mojo_opset/experimental/operators/kv_cache.py +++ b/mojo_opset/experimental/operators/kv_cache.py @@ -1,8 +1,12 @@ -from typing import Tuple +from typing import Tuple, Optional import torch -from mojo_opset.core.operators.kv_cache import assert_paged_kv_layout_contract +from mojo_opset.core.operators.kv_cache import ( + assert_paged_kv_layout_contract, + assert_paged_kv_store_contract, + build_paged_kv_chunk_metadata +) from mojo_opset.core.operator import MojoOperator @@ -102,7 +106,166 @@ def forward( return compressed_kv_cache, k_pe_cache +class MojoStorePagedKVCacheC8(MojoOperator): + def __init__( + self, + ): + super().__init__() + + def forward( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + key_scale: torch.Tensor, + value_scale: torch.Tensor, + block_table: Optional[torch.Tensor] = None, + cu_q_lens: Optional[torch.Tensor] = None, + context_kv_lens: Optional[torch.Tensor] = None, + *, + chunk_metadata: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Copy new K/V tokens into a paged KV cache with int8 per_channal quant. 
+ + Args: + key_states (torch.Tensor): Shape (token_num, kv_head_num, head_dim) — new key tokens. + value_states (torch.Tensor): Shape (token_num, kv_head_num, head_dim) — new value tokens. + key_cache (torch.Tensor): Shape (total_phys_blocks, kv_heads, block_size, head_dim) — key cache. + value_cache (torch.Tensor): Shape (total_phys_blocks, kv_heads, block_size, head_dim) — value cache. + key_scale (torch.Tensor): Shape (kv_head_num, head_dim) — key scale. + value_scale (torch.Tensor): Shape (kv_head_num, head_dim) — value scale. + block_table (torch.Tensor | None): Legacy logical-to-physical block mapping. + cu_q_lens (torch.Tensor | None): Legacy cumulative query lengths. ``None`` indicates decode mode. + context_kv_lens (torch.Tensor | None): Legacy KV lengths before storing current tokens. + chunk_metadata (torch.Tensor | None): Optimized precomputed store plan with shape ``(num_chunks, 4)`` + and per-row ``(src_token_start, dst_block_id, dst_block_offset, chunk_len)``. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Updated `(key_cahce, value_cahce)` after in-place writes. + """ + assert len(key_states.shape) == 3 and len(value_states.shape) == 3 and key_states.shape == value_states.shape, ( + "key/value states must be (token_num, kv_head_num, head_dim), please check." + ) + if chunk_metadata is None: + assert block_table is not None, "block_table is required when chunk_metadata is not provided." + assert context_kv_lens is not None, "context_kv_lens is required when chunk_metadata is not provided." + chunk_metadata = build_paged_kv_chunk_metadata( + block_table, + cu_q_lens, + context_kv_lens, + key_cache.shape[2], + ) + else: + assert block_table is None and cu_q_lens is None and context_kv_lens is None, ( + "chunk_metadata path should not be mixed with block_table/cu_q_lens/context_kv_lens." 
class MojoDequantFromPagedKVCache(MojoOperator):
    """Dequantize a paged, quantized K/V cache into linear K/V tensors.

    For every sequence in the batch, the cache blocks listed in ``block_tables``
    are gathered, multiplied by the per-channel scale in float32, cast to the
    destination dtype and written in place into the matching token slice of
    ``key`` (and ``value`` when all value inputs are supplied).
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _dequant_from_cache(quant_data: torch.Tensor, scale_data: torch.Tensor) -> torch.Tensor:
        """Return ``quant_data * scale_data`` computed in float32.

        ``scale_data`` gets a singleton axis inserted before its last dimension so
        that a ``(head_num, head_size)`` scale broadcasts over the token axis of a
        ``(head_num, token_num, head_size)`` cache slice.
        """
        quant_data_fp32 = quant_data.to(torch.float)
        # The cast result must be the tensor that is reshaped; re-reading
        # ``scale_data`` here would silently drop the float32 conversion.
        scale_data_fp32 = scale_data.to(torch.float)[..., None, :]
        return quant_data_fp32 * scale_data_fp32

    @staticmethod
    def _gather_sequence(
        cache: torch.Tensor,
        block_ids: list,
        full_block_num: int,
        rem_token_num: int,
    ) -> torch.Tensor:
        """Concatenate one sequence's cache blocks along the token axis (dim -2).

        The first ``full_block_num`` blocks are taken whole; when
        ``rem_token_num > 0`` only the leading ``rem_token_num`` tokens of the
        following block are appended.
        """
        pieces = [cache[block_id] for block_id in block_ids[:full_block_num]]
        if rem_token_num > 0:
            pieces.append(cache[block_ids[full_block_num], :, :rem_token_num, :])
        return torch.concat(pieces, dim=-2)

    def forward(
        self,
        *,
        key: torch.Tensor,
        value: Optional[torch.Tensor] = None,
        key_cache: torch.Tensor,
        key_cache_scale: torch.Tensor,
        value_cache: Optional[torch.Tensor] = None,
        value_cache_scale: Optional[torch.Tensor] = None,
        context_lengths: Optional[torch.Tensor] = None,
        max_context_len: int,
        context_seq_offset: Optional[torch.Tensor] = None,
        block_tables: torch.Tensor,
    ):
        r"""
        Copy and dequantize from Transformer int8 paged K/V cache to linear K/V states.

        Args:
            key (torch.Tensor): Shape (total_seq_len, head_num, head_size) — destination key states.
            value (torch.Tensor | None): Shape (total_seq_len, head_num, head_size) — destination value states.
            key_cache (torch.Tensor): Shape (block_num, head_num, block_size, head_size) — key cache.
            key_cache_scale (torch.Tensor): Shape (head_num, head_size) — key cache scale.
            value_cache (torch.Tensor | None): Shape (block_num, head_num, block_size, head_size) — value cache.
            value_cache_scale (torch.Tensor | None): Shape (head_num, head_size) — value cache scale.
            context_lengths (torch.Tensor): Shape (batch_size,) — valid sequence length of each batch
                sample. Required even though the signature carries a ``None`` default.
            max_context_len (int): Maximum valid sequence length across the batch. Accepted for
                interface parity; this implementation does not read it.
            context_seq_offset (torch.Tensor | None): Shape (batch_size,) — start offset of each
                sequence inside ``key``/``value``. When ``None``, sequences are assumed to be packed
                back to back and the offsets are the exclusive prefix sum of ``context_lengths``.
            block_tables (torch.Tensor): Shape (batch_size, max_block_num) — logical-to-physical block
                mapping of the paged cache.

        Returns:
            Tuple[torch.Tensor, torch.Tensor | None]: The same ``key`` and ``value`` objects that were
            passed in; all writes are performed in place. Values are only dequantized when ``value``,
            ``value_cache`` and ``value_cache_scale`` are all provided.

        Raises:
            ValueError: If ``context_lengths`` is ``None``.
        """
        if context_lengths is None:
            raise ValueError("context_lengths is required to dequantize from the paged KV cache.")

        if context_seq_offset is None:
            # Exclusive prefix sum: sequence i starts where sequence i - 1 ended.
            cu_seq_offset = torch.cumsum(context_lengths, dim=-1)
            context_seq_offset = torch.zeros_like(cu_seq_offset)
            context_seq_offset[1:] = cu_seq_offset[:-1]

        has_value = not (value_cache is None or value is None or value_cache_scale is None)
        block_size = key_cache.size(2)

        # Pull the per-sequence scalars to the host once rather than issuing
        # two ``.item()`` calls per sequence.
        seq_lens = context_lengths.tolist()
        seq_begins = context_seq_offset.tolist()

        for batch_idx, (context_len, seq_begin) in enumerate(zip(seq_lens, seq_begins)):
            if context_len == 0:
                # Nothing is cached for this sequence; concatenating zero blocks would raise.
                continue

            seq_end = seq_begin + context_len
            full_block_num, rem_token_num = divmod(context_len, block_size)
            block_ids = block_tables[batch_idx].tolist()

            # Transposed view of the destination: (head_num, context_len, head_size),
            # matching the layout of the gathered cache; writes land in ``key``.
            key_i = key[seq_begin:seq_end].transpose(1, 0)
            key_cache_i = self._gather_sequence(key_cache, block_ids, full_block_num, rem_token_num)
            key_i[...] = self._dequant_from_cache(key_cache_i, key_cache_scale).to(key_i.dtype)

            if has_value:
                value_i = value[seq_begin:seq_end].transpose(1, 0)
                value_cache_i = self._gather_sequence(value_cache, block_ids, full_block_num, rem_token_num)
                value_i[...] = self._dequant_from_cache(value_cache_i, value_cache_scale).to(value_i.dtype)

        return key, value
class MojoGroupRMSNormInplace(MojoOperator):
    """RMSNorm applied independently to each tensor of a group, optionally in place.

    Each group owns its own row of the ``(num_groups, norm_size)`` weight when
    ``elementwise_affine`` is enabled; otherwise the groups are normalized
    without an affine scale.
    """

    def __init__(self, num_groups, norm_size, eps, elementwise_affine=True, inplace=False, **kwargs):
        """
        Args:
            num_groups (int): Number of input tensors handled per call.
            norm_size (int): Size of the trailing dimension that is normalized.
            eps (float): Epsilon forwarded to ``F.rms_norm``.
            elementwise_affine (bool, default=True): Whether to create a learnable
                per-group weight of shape ``(num_groups, norm_size)``.
            inplace (bool, default=False): Whether results are copied back into the input tensors.
            **kwargs: Forwarded to the base operator (the weight is created with
                ``self.tensor_factory_kwargs``).
        """
        super().__init__(**kwargs)
        self.num_groups = num_groups
        self.norm_size = norm_size
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = torch.nn.Parameter(torch.empty((num_groups, norm_size), **self.tensor_factory_kwargs))
        else:
            self.weight = None
        self.variance_epsilon = eps
        self.inplace = inplace

    def forward(self, input_groups):
        """
        Normalize the first ``num_groups`` entries of ``input_groups``.

        Args:
            input_groups: Indexable collection of tensors whose last dimension is ``norm_size``.

        Returns:
            ``input_groups`` itself (its tensors overwritten) when ``inplace`` is set,
            otherwise a new list holding the normalized tensors.
        """
        output_groups = []
        for group_id in range(self.num_groups):
            group_input = input_groups[group_id]
            # Without elementwise affine there is no weight to index; pass None so
            # F.rms_norm performs a plain (unscaled) normalization.
            group_weight = None if self.weight is None else self.weight[group_id]
            normalized = F.rms_norm(
                group_input,
                (self.norm_size,),
                weight=group_weight,
                eps=self.variance_epsilon,
            )
            if self.inplace:
                # Overwrite the caller's tensor with the normalized values.
                group_input.copy_(normalized)
            else:
                output_groups.append(normalized)

        return input_groups if self.inplace else output_groups

    def extra_repr(self) -> str:
        return f"{self.num_groups=}, {self.norm_size=}, {self.variance_epsilon=}, {self.elementwise_affine=}".replace("self.", "")
class MojoMRoPEInplace(MojoOperator):
    """Multimodal Rotary Position Embedding (MRoPE) for Qwen2-VL, with an in-place option.

    Rotates the leading ``2 * sum(mrope_section)`` channels of every query/key head
    using cos/sin tables that may carry separate temporal (T), height (H) and
    width (W) planes. Both interleaved and non-interleaved section layouts are
    supported.

    Reference: https://qwenlm.github.io/blog/qwen2-vl/
    """

    supported_platforms_list = ["npu", "mlu", "meta_device", "ilu"]

    def __init__(self, inplace: bool = False, **kwargs):
        """
        Args:
            inplace (bool, default=False): When True, ``forward`` copies its results back
                into the tensors it was given and returns those same tensors.
        """
        super().__init__(**kwargs)
        self.inplace = inplace

    def extra_repr(self) -> str:
        return ""

    @staticmethod
    def _rotate_half(hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``cat(-second_half, first_half)`` taken along the last dimension."""
        midpoint = hidden_states.shape[-1] // 2
        front = hidden_states[..., :midpoint]
        back = hidden_states[..., midpoint:]
        return torch.cat((-back, front), dim=-1)

    @staticmethod
    def _apply_interleaved_mrope(
        cos_table: torch.Tensor,
        sin_table: torch.Tensor,
        mrope_section: List[int],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Merge the three planes of each table using the interleaved channel pattern.

        Starts from a copy of plane 0 and overwrites every third channel, beginning
        at offset 1 (from plane 1) and offset 2 (from plane 2), up to
        ``mrope_section[plane] * 3``.
        """

        def _merge(table: torch.Tensor) -> torch.Tensor:
            merged = table[0].clone()
            for plane in (1, 2):
                stripe = slice(plane, mrope_section[plane] * 3, 3)
                merged[..., stripe] = table[plane, ..., stripe]
            return merged

        return _merge(cos_table), _merge(sin_table)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        cos_table: torch.Tensor,
        sin_table: torch.Tensor,
        mrope_section: List[int],
        is_interleaved: bool = False,
        head_dim: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply MRoPE to flattened query/key tensors.

        Args:
            query (torch.Tensor): 2-D tensor ``(num_tokens, num_q_heads * head_dim)``.
            key (torch.Tensor): 2-D tensor ``(num_tokens_k, num_kv_heads * head_dim)``.
            cos_table (torch.Tensor): Cosine table. A 3-D table is first reduced across its
                leading plane axis; the result must be viewable as
                ``(num_tokens, sum(mrope_section))``.
            sin_table (torch.Tensor): Sine table with the same layout as ``cos_table``.
            mrope_section (List[int]): Channel counts per plane; the rotary width is twice their sum.
            is_interleaved (bool, default=False): Selects the interleaved merge for 3-D tables.
            head_dim (int | None): Per-head width; defaults to the rotary width.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Rotated ``(query, key)`` in the input's 2-D layout.
            With ``inplace`` enabled these are the original input tensors, overwritten.
        """
        query_in, key_in = query, key

        q_tokens, q_width = query.shape
        k_tokens, k_width = key.shape

        rotary_width = sum(mrope_section) * 2
        half_width = rotary_width // 2
        if head_dim is None:
            head_dim = rotary_width

        q_head_count = q_width // head_dim
        k_head_count = k_width // head_dim
        split_sizes = [rotary_width, head_dim - rotary_width]

        # Split each head into the rotated slice and the pass-through remainder.
        q_rot, q_pass = query.view(q_tokens, q_head_count, head_dim).split(split_sizes, dim=-1)
        k_rot, k_pass = key.view(k_tokens, k_head_count, head_dim).split(split_sizes, dim=-1)

        if cos_table.dim() == 3:
            if is_interleaved:
                cos_table, sin_table = self._apply_interleaved_mrope(cos_table, sin_table, mrope_section)
            else:
                # Section i of the channel axis is taken from plane i.
                def _pick_sections(table: torch.Tensor) -> torch.Tensor:
                    chunks = table.split(mrope_section, dim=-1)
                    return torch.cat([chunk[plane] for plane, chunk in enumerate(chunks)], dim=-1)

                cos_table = _pick_sections(cos_table)
                sin_table = _pick_sections(sin_table)

        cos_table = cos_table.view(q_tokens, half_width)
        sin_table = sin_table.view(q_tokens, half_width)
        # Singleton head axis so the tables broadcast over all heads.
        cos = cos_table.unsqueeze(1)
        sin = sin_table.unsqueeze(1)

        def _rotate(rot: torch.Tensor) -> torch.Tensor:
            lower = rot[..., :half_width]
            upper = rot[..., half_width:]
            new_lower = lower * cos - upper * sin
            new_upper = upper * cos + lower * sin
            return torch.cat([new_lower, new_upper], dim=-1)

        q_rotated = _rotate(q_rot)
        k_rotated = _rotate(k_rot)

        query_out = torch.cat([q_rotated, q_pass], dim=-1).view(q_tokens, -1)
        key_out = torch.cat([k_rotated, k_pass], dim=-1).view(k_tokens, -1)

        if not self.inplace:
            return query_out, key_out

        # In-place mode: write the results back into the caller's tensors and return them.
        query_in.copy_(query_out)
        key_in.copy_(key_out)
        return query_in, key_in
def _assert_int8_cache_close(result, ref, mismatch_ratio=1e-5):
    """Assert two quantized caches agree up to rare off-by-one differences.

    Shapes and dtypes must match exactly, no element may differ by more than 1,
    and the fraction of differing elements must not exceed ``mismatch_ratio``.
    """
    assert result.shape == ref.shape
    assert result.dtype == ref.dtype

    # Widen to int16 on the host before subtracting so the difference of two
    # int8 values cannot wrap around.
    result_wide = result.cpu().to(torch.int16)
    ref_wide = ref.cpu().to(torch.int16)
    abs_err = torch.abs(result_wide - ref_wide)

    largest_err = abs_err.max().item()
    assert largest_err <= 1

    mismatched = torch.count_nonzero(abs_err).item()
    assert mismatched / abs_err.numel() <= mismatch_ratio
(3, 2, 128, 128, [32, -1, 35], [1, 1, 1]), + (3, 2, 128, 128, [0, -1, 5], [4, 0, 2]), + (3, 2, 128, 512, [510, -1, 700], [4, 1, 300]), + (3, 2, 128, 1024, [1020, -1, 1530], [8, 1, 520]), + (3, 2, 128, 2048, [2040, -1, 3000], [16, 1, 900]), + (8, 2, 128, 128, [224, 542, 34, 41, 54, 57, 65, 0], [432, 84, 977, 93, 23, 89, 31, 555]), + (8, 2, 128, 128, [772, 974, 3232, 43, 77, 7633, 888, 1], [1, 1, 1, 1, 1, 1, 1, 1]), + ( + 8, + 2, + 128, + 512, + [224, 542, 34, 41, 54, 57, 65, 0], + [432, 84, 977, 93, 23, 89, 31, 555], + ), + ( + 8, + 2, + 128, + 1024, + [900, 1500, 34, 41, 54, 57, 65, 0], + [700, 600, 977, 93, 23, 89, 31, 555], + ), + ( + 8, + 2, + 128, + 2048, + [1800, 2500, 34, 41, 54, 57, 65, 0], + [900, 1200, 977, 93, 23, 89, 31, 555], + ), + ( + 8, + 2, + 128, + 512, + [772, 974, 3232, 43, 77, 7633, 888, 1], + [1, 1, 1, 1, 1, 1, 1, 1], + ), + ( + 8, + 2, + 128, + 1024, + [1023, 1024, 3232, 43, 77, 7633, 888, 1], + [1, 1, 1, 1, 1, 1, 1, 1], + ), + ( + 8, + 2, + 128, + 2048, + [2047, 2048, 3232, 43, 77, 7633, 888, 1], + [1, 1, 1, 1, 1, 1, 1, 1], + ), + ], +) +@auto_switch_platform() +@bypass_not_implemented +def test_store_paged_kv_c8(batch_size, kv_heads, head_dim, block_size, context_kv_lens_val, q_lens_val): + case = _build_store_paged_kv_case( + batch_size, + kv_heads, + head_dim, + block_size, + context_kv_lens_val, + q_lens_val, + device=get_torch_device(), + ) + cache_scale = torch.randn((2, kv_heads, head_dim), dtype=torch.float, device=get_torch_device()) + key_scale = cache_scale[0] + value_scale = cache_scale[1] + + store_paged_kv_c8_ref = MojoStorePagedKVCacheC8._registry.get("torch")() + store_paged_kv_c8 = MojoStorePagedKVCacheC8() + if type(store_paged_kv_c8_ref) is type(store_paged_kv_c8): + raise NotImplementedError("both operands resolve to the same implementation, skipping comparison.") + + k_cache_ref, v_cache_ref = store_paged_kv_c8_ref( + case["key_states"], + case["value_states"], + case["k_cache"].clone().to(torch.int8), + 
def gen_args(
    batch_size,
    max_context_len,
    head_num_q,
    head_num_kv,
    cache_mem_len,
    head_size,
    group_size,
    block_size,
    use_seq_offset,
    dtype,
    quant_mode,
    quant_bit,
    pad_head_size=0,
    has_value=True,
    context_strided=False,
):
    """Build randomized inputs for the paged-cache dequantization test.

    Returns a 12-element list:
    ``[key, value, key_cache, value_cache, key_cache_scale, value_cache_scale,
    context_lens, max_context_len, context_seq_offset_or_None, block_tables,
    quant_mode, quant_bit]``.
    ``key``/``value`` are head slices of one shared random context tensor; the
    value entries are ``None`` when ``has_value`` is False, and the sequence
    offset entry is ``None`` when ``use_seq_offset`` is False.
    """
    # Argument sanity checks
    assert cache_mem_len >= max_context_len, "cache_mem_len should greater then or equal to max_context_len."
    assert head_size % group_size == 0, "head_size should be a multiply of groupwise."

    device = get_torch_device()
    kv_start = head_num_q
    kv_mid = head_num_q + head_num_kv
    total_heads = head_num_q + head_num_kv * 2
    max_seq_offset = cache_mem_len - max_context_len
    blocks_per_seq = int(math.ceil(max_context_len / block_size))
    # Total block count is rounded down to a multiple of 4 (see the int32 -> int8 view below).
    total_blocks = int(math.ceil(cache_mem_len / block_size)) * batch_size // 4 * 4

    # Distinct physical block ids for every (sequence, logical block) slot.
    sampled_ids = random.sample(range(total_blocks), batch_size * blocks_per_seq)
    block_tables = torch.tensor(sampled_ids, dtype=torch.int32, device=device).view(batch_size, blocks_per_seq)

    # Per-sequence lengths and optional random padding between sequences.
    context_lens = torch.randint(size=[batch_size], low=1, high=max_context_len + 1, dtype=torch.int32, device=device)
    if use_seq_offset:
        context_paddings = torch.randint(size=[batch_size], low=0, high=max_seq_offset, dtype=torch.int32, device=device)
    else:
        context_paddings = torch.zeros_like(context_lens)
    cu_context_lens = torch.cumsum(context_lens + context_paddings, dim=-1)
    total_seqlen = cu_context_lens[-1]
    context_seq_offset = torch.zeros([batch_size], dtype=torch.int32, device=device)
    context_seq_offset[1:] = cu_context_lens[:-1]

    # Shared context tensor; the strided variant is allocated head-major and transposed.
    width = max(pad_head_size, head_size)
    if context_strided:
        context = torch.randn([total_heads, total_seqlen, width], dtype=dtype, device=device)
        context = context.transpose(0, 1)
    else:
        context = torch.randn([total_seqlen, total_heads, width], dtype=dtype, device=device)
    key = context[..., kv_start:kv_mid, :head_size]

    # Random int32 words reinterpreted as int8 bytes: total_blocks // 4 int32 blocks
    # become total_blocks int8 blocks after the two views.
    planes = 2 if has_value else 1
    raw_cache = torch.randint(
        size=(planes, total_blocks // 4, head_num_kv, block_size, head_size),
        low=-128,
        high=127,
        dtype=torch.int32,
        device=device,
    )
    cache = raw_cache.view(torch.int8).view(planes, total_blocks, head_num_kv, block_size, head_size)

    # quant_mode == 0 produces one scale per (head, channel); any other mode
    # produces one scale per (block, head, token).
    if quant_mode == 0:
        scale_shape = (planes, head_num_kv, head_size)
    else:
        scale_shape = (planes, total_blocks, head_num_kv, block_size)
    cache_scale = torch.randn(scale_shape, dtype=torch.float, device=device)

    key_cache = cache[0]
    key_cache_scale = cache_scale[0]
    if has_value:
        value = context[..., kv_mid : head_num_q + 2 * head_num_kv, :head_size]
        value_cache = cache[1]
        value_cache_scale = cache_scale[1]
    else:
        value = None
        value_cache = None
        value_cache_scale = None

    return [
        key,
        value,
        key_cache,
        value_cache,
        key_cache_scale,
        value_cache_scale,
        context_lens,
        max_context_len,
        context_seq_offset if use_seq_offset else None,
        block_tables,
        quant_mode,
        quant_bit,
    ]
cache_mem_len, + head_size, + head_size, + block_size, + use_seq_offset, + dtype, + quant_mode, + quant_bit, + pad_head_size, + has_value, + context_strided, + ) + ( + key, + value, + key_cache, + value_cache, + key_cache_scale, + value_cache_scale, + context_lengths, + max_context_len, + context_seq_offset, + block_tables, + quant_mode, + quant_bit, + ) = args + key_ref = key.clone() + value_ref = value.clone() + + dequant_from_paged_kv_cache_ref = MojoDequantFromPagedKVCache._registry.get("torch")() + dequant_from_paged_kv_cache = MojoDequantFromPagedKVCache() + if type(dequant_from_paged_kv_cache_ref) is type(dequant_from_paged_kv_cache): + raise NotImplementedError("both operands resolve to the same implementation, skipping comparison.") + + key, value = dequant_from_paged_kv_cache( + key=key, + value=value, + key_cache=key_cache, + key_cache_scale=key_cache_scale, + value_cache=value_cache, + value_cache_scale=value_cache_scale, + context_lengths=context_lengths, + max_context_len=max_context_len, + context_seq_offset=context_seq_offset, + block_tables=block_tables, + ) + + key_ref, value_ref = dequant_from_paged_kv_cache_ref( + key=key_ref, + value=value_ref, + key_cache=key_cache, + key_cache_scale=key_cache_scale, + value_cache=value_cache, + value_cache_scale=value_cache_scale, + context_lengths=context_lengths, + max_context_len=max_context_len, + context_seq_offset=context_seq_offset, + block_tables=block_tables, + ) + + assert_close(key, key_ref) + assert_close(value, value_ref) diff --git a/mojo_opset/tests/accuracy/operators/test_normalization.py b/mojo_opset/tests/accuracy/operators/test_normalization.py index d44fd192..455cfb85 100644 --- a/mojo_opset/tests/accuracy/operators/test_normalization.py +++ b/mojo_opset/tests/accuracy/operators/test_normalization.py @@ -16,6 +16,8 @@ from mojo_opset import MojoRMSNormQuant from mojo_opset.experimental import MojoChannelRMSNorm from mojo_opset.experimental import MojoGroupLayerNorm +from 
mojo_opset.experimental import MojoRMSNormInplace +from mojo_opset.experimental import MojoGroupRMSNormInplace torch.manual_seed(43) @@ -61,6 +63,51 @@ def test_rmsnorm(shape, dtype, eps): rmsnorm.forward_diff_with(rmsnorm_ref, x, atol=atol, rtol=rtol) +@pytest.mark.parametrize( + "shape", + [ + (32, 1024), + (64, 8192), + (57, 7338), + (2, 256), + (7762, 18688), + ], +) +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("eps", [1e-5]) +@pytest.mark.parametrize("inplace", [True]) +@bypass_not_implemented +def test_rmsnorm_inplace(shape, dtype, eps, inplace): + x = torch.randn(size=shape, dtype=dtype) + x_ref = x.clone() + weight = torch.randn(size=(shape[-1],), dtype=dtype) + rmsnorm = MojoRMSNormInplace(eps=eps, norm_size=shape[-1], inplace=inplace, device=x.device, dtype=x.dtype) + + rmsnorm_ref = ( + MojoRMSNormInplace._registry.get("torch")( + eps=eps, + norm_size=weight.size(0), + inplace=inplace, + ) + .to(x.device) + .to(weight.dtype) + ) + + with torch.no_grad(): + rmsnorm.weight.copy_(weight.to(torch.float32)) + rmsnorm_ref.weight.copy_(weight.to(torch.float32)) + + rmsnorm(x) + rmsnorm_ref(x_ref) + + if x.dtype == torch.float32: + atol, rtol = 1e-5, 1e-6 + else: + atol, rtol = 2e-1, 2e-2 + + torch.testing.assert_close(x, x_ref, atol=atol, rtol=rtol) + + @pytest.mark.parametrize( "shape", [ @@ -129,7 +176,7 @@ def test_grouprmsnorm(bsz, group_dims, hidden_size, dtype, eps): rmsnorm_ref = ( MojoGroupRMSNorm._registry.get("torch")( - num_groups = len(group_dims), + num_groups = len(group_dims), eps=eps, norm_size=hidden_size, ) @@ -148,6 +195,60 @@ def test_grouprmsnorm(bsz, group_dims, hidden_size, dtype, eps): rmsnorm.forward_diff_with(rmsnorm_ref, x_groups, atol=atol, rtol=rtol) + +@pytest.mark.parametrize( + "bsz,group_dims,hidden_size", + [ + (1024, (16, 4), 96), + (798, (16, 4, 8, 2), 128), + (8000, (48, 8, 16, 4), 128), + (17, (3, 5), 128), + (33, (2, 7, 1), 128), + (65, (4, 4, 4, 4), 128), + (129, (1, 3, 5, 7), 128), + (257, (6, 2), 
192), + (513, (8, 8, 8), 256), + (1025, (12, 6, 3, 1), 128), + (2049, (5, 9, 7, 3), 64), + ], +) +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("eps", [1e-5]) +@pytest.mark.parametrize("inplace", [True]) +@bypass_not_implemented +def test_grouprmsnorm_inplace(bsz, group_dims, hidden_size, dtype, eps, inplace): + x = torch.randn(size=(bsz, sum(group_dims), hidden_size), dtype=dtype) + x_ref = x.clone() + x_groups = torch.split(x, group_dims, dim=1) + weight = torch.randn(size=(len(group_dims), hidden_size), dtype=dtype) + rmsnorm = MojoGroupRMSNormInplace(num_groups = len(group_dims), eps=eps, norm_size=hidden_size, inplace=inplace, device=x.device, dtype=x.dtype) + + rmsnorm_ref = ( + MojoGroupRMSNormInplace._registry.get("torch")( + num_groups = len(group_dims), + eps=eps, + norm_size=hidden_size, + inplace=inplace, + ) + .to(x.device) + .to(weight.dtype) + ) + + with torch.no_grad(): + rmsnorm.weight.copy_(weight.to(torch.float32)) + rmsnorm_ref.weight.copy_(weight.to(torch.float32)) + + rmsnorm(x) + rmsnorm_ref(x_ref) + + if x.dtype == torch.float32: + atol, rtol = 1e-5, 1e-6 + else: + atol, rtol = 3e-2, 6e-3 + + torch.testing.assert_close(x, x_ref, atol=atol, rtol=rtol) + + @pytest.mark.parametrize( "bsz,group_dims,hidden_size", [ @@ -421,8 +522,8 @@ def test_residual_add_rmsnorm_quant(shape, dtype, norm_pos): op_ref, x, residual, - atol=(2, 1e-3, 1e-3), - rtol=(0, 1e-3, 1e-3), + atol=(2, 1e-2, 1e-2), + rtol=(0, 1e-2, 1e-2), ) diff --git a/mojo_opset/tests/accuracy/operators/test_position_embedding.py b/mojo_opset/tests/accuracy/operators/test_position_embedding.py index 56a834f6..ea0aa9f5 100644 --- a/mojo_opset/tests/accuracy/operators/test_position_embedding.py +++ b/mojo_opset/tests/accuracy/operators/test_position_embedding.py @@ -8,6 +8,7 @@ from mojo_opset import MojoVisionRotaryEmbedding2D from mojo_opset.experimental import MojoGridRoPE from mojo_opset.experimental import MojoRelativeEmbedding +from mojo_opset.experimental import 
MojoMRoPEInplace from mojo_opset.tests.utils import bypass_not_implemented from mojo_opset.utils.platform import get_torch_device @@ -27,19 +28,19 @@ } -def compute_cos_sin_cache(head_dim, rotary_dim, max_position, base=10000.0): +def compute_cos_sin_cache(head_dim, rotary_dim, max_position, device, base=10000.0): inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim // 2, 2, dtype=torch.float32) / rotary_dim)) t = torch.arange(max_position, dtype=inv_freq.dtype) freqs = torch.einsum("i,j->ij", t, inv_freq) freqs = freqs.repeat_interleave(2, dim=-1) - return freqs.cos(), freqs.sin() + return freqs.cos().to(device), freqs.sin().to(device) def prepare_mrope_test_inputs(num_tokens, n_qh, n_kh, head_dim, mrope_section, device, dtype=torch.float32): rotary_dim = sum(mrope_section) * 2 positions = torch.randint(0, 1000, (3, num_tokens), device=device, dtype=torch.long) - cos_cache, sin_cache = compute_cos_sin_cache(head_dim, rotary_dim, 4000, base=10000.0) + cos_cache, sin_cache = compute_cos_sin_cache(head_dim, rotary_dim, 4000, device, base=10000.0) half_rotary_dim = rotary_dim // 2 cos_3d = torch.zeros(3, num_tokens, half_rotary_dim, device=device, dtype=torch.float32) @@ -290,6 +291,50 @@ def test_mrope_qwen_models( mrope.forward_diff_with(mrope_ref, query, key, cos_table, sin_table, mrope_section_out, is_interleaved, head_dim=head_dim) +@pytest.mark.parametrize("num_tokens", [1, 32, 128]) +@pytest.mark.parametrize("n_qh, n_kh, mrope_section, is_interleaved, model_name", [ + (28, 4, [16, 24, 24], False, "Qwen2-VL-7B"), + (40, 8, [16, 24, 24], False, "Qwen2.5-VL-32B"), + (16, 8, [24, 20, 20], True, "Qwen3-VL-2B"), + (32, 8, [24, 20, 20], True, "Qwen3-VL-8B"), +]) +@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("inplace", [True]) +@bypass_not_implemented +def test_mrope_qwen_models_inplace( + num_tokens, + n_qh, + n_kh, + head_dim, + mrope_section, + is_interleaved, + model_name, 
+ dtype, + inplace, +): + device = get_torch_device() + query, key, cos_table, sin_table, mrope_section_out = prepare_mrope_test_inputs( + num_tokens, n_qh, n_kh, head_dim, mrope_section, device, dtype=dtype + ) + query_ref = query.clone() + key_ref = key.clone() + + mrope = MojoMRoPEInplace(inplace=inplace) + mrope_ref = MojoMRoPEInplace._registry.get("torch")(inplace=inplace) + + mrope(query, key, cos_table, sin_table, mrope_section_out, is_interleaved, head_dim=head_dim) + mrope_ref(query_ref, key_ref, cos_table, sin_table, mrope_section_out, is_interleaved, head_dim=head_dim) + + if dtype == torch.float32: + atol, rtol = 1e-5, 1e-6 + else: + atol, rtol = 3e-2, 6e-3 + + torch.testing.assert_close(query, query_ref, atol=atol, rtol=rtol) + torch.testing.assert_close(key, key_ref, atol=atol, rtol=rtol) + + @pytest.mark.parametrize("num_tokens", [16, 64]) @pytest.mark.parametrize("n_qh, n_kh", [ (32, 4), @@ -323,6 +368,53 @@ def test_mrope_partial_rotation( mrope.forward_diff_with(mrope_ref, query, key, cos_table, sin_table, mrope_section_out, is_interleaved, head_dim=head_dim) +@pytest.mark.parametrize("num_tokens", [16, 64]) +@pytest.mark.parametrize("n_qh, n_kh", [ + (32, 4), + (64, 8), +]) +@pytest.mark.parametrize("head_dim, mrope_section, description", [ + (128, [8, 12, 12], "partial_rotation_50pct"), + (128, [12, 18, 18], "partial_rotation_75pct"), + (96, [8, 12, 12], "small_head_full_rotation"), +]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("is_interleaved", [False]) +@pytest.mark.parametrize("inplace", [True]) +@bypass_not_implemented +def test_mrope_partial_rotation_inplace( + num_tokens, + n_qh, + n_kh, + head_dim, + mrope_section, + description, + dtype, + is_interleaved, + inplace, +): + device = get_torch_device() + query, key, cos_table, sin_table, mrope_section_out = prepare_mrope_test_inputs( + num_tokens, n_qh, n_kh, head_dim, mrope_section, device, dtype=dtype + ) + query_ref = query.clone() + 
key_ref = key.clone() + + mrope = MojoMRoPEInplace(inplace=inplace) + mrope_ref = MojoMRoPEInplace._registry.get("torch")(inplace=inplace) + + mrope(query, key, cos_table, sin_table, mrope_section_out, is_interleaved, head_dim=head_dim) + mrope_ref(query_ref, key_ref, cos_table, sin_table, mrope_section_out, is_interleaved, head_dim=head_dim) + + if dtype == torch.float32: + atol, rtol = 1e-5, 1e-6 + else: + atol, rtol = 3e-2, 6e-3 + + torch.testing.assert_close(query, query_ref, atol=atol, rtol=rtol) + torch.testing.assert_close(key, key_ref, atol=atol, rtol=rtol) + + @pytest.mark.parametrize("num_buckets", [32, 64]) @pytest.mark.parametrize("num_heads", [8, 16]) @pytest.mark.parametrize("bidirectional", [True, False])