From c3e83ad3c03c0de8771a2db7c2a66c23a1486f9f Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Wed, 26 Nov 2025 14:45:30 -0800 Subject: [PATCH 1/2] feat: Apply RoPE embedding with sequence offsets Signed-off-by: Sudhakar Singh --- .../core/extensions/transformer_engine.py | 52 +++++++++++++------ tests/unit_tests/transformer/test_rope.py | 30 +++++++++++ 2 files changed, 65 insertions(+), 17 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index e95409e08e9..8269b502a74 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -1943,21 +1943,33 @@ def fused_apply_rotary_pos_emb( freqs: torch.Tensor, transpose_output_memory: bool = False, interleaved: bool = False, + start_positions: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Apply rotary positional embedding to input tensor T in `sbhd` format.""" if transpose_output_memory: warnings.warn( "transpose_output_memory is not supported by TE's fused RoPE and will be ignored." ) + + conditional_kwargs = {} + if is_te_min_version("2.10.0.dev0"): + conditional_kwargs["start_positions"] = start_positions + else: + if start_positions is not None: + raise ValueError( + "Only TE >= 2.10.0.dev0 supports offset RoPE application with " + "`start_positions` argument." + ) + if is_te_min_version("2.3.0"): - return apply_rotary_pos_emb( - t, freqs, tensor_format="sbhd", interleaved=interleaved, fused=True - ) + conditional_kwargs["interleaved"] = interleaved else: if interleaved: raise ValueError("Only TE >= 2.3.0 supports interleaved fused RoPE.") - return apply_rotary_pos_emb(t, freqs, tensor_format="sbhd", fused=True) + return apply_rotary_pos_emb( + t, freqs, tensor_format="sbhd", fused=True, **conditional_kwargs + ) def fused_apply_rotary_pos_emb_thd( t: torch.Tensor, @@ -1965,25 +1977,31 @@ def fused_apply_rotary_pos_emb_thd( freqs: torch.Tensor, cp_size: int = 1, cp_rank: int = 0, + start_positions: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Apply rotary positional embedding to input tensor T in `thd` format with CP support. """ + conditional_kwargs = {} + if is_te_min_version("2.10.0.dev0"): + conditional_kwargs["start_positions"] = start_positions + else: + if start_positions is not None: + raise ValueError( + "Only TE >= 2.10.0.dev0 supports offset RoPE application with " + "`start_positions` argument." + ) + if is_te_min_version("1.12.0", check_equality=True): - return apply_rotary_pos_emb( - t, - freqs, - tensor_format="thd", - fused=True, - cu_seqlens=cu_seqlens, - cp_size=cp_size, - cp_rank=cp_rank, - ) + conditional_kwargs["cp_size"] = cp_size + conditional_kwargs["cp_rank"] = cp_rank else: - assert cp_size == 1, "Only TE >= 1.12 supports RoPE fusion for THD format with CP." - return apply_rotary_pos_emb( - t, freqs, tensor_format="thd", fused=True, cu_seqlens=cu_seqlens - ) + if cp_size > 1: + raise ValueError("Only TE >= 1.12.0 supports CP RoPE application for THD format.") + + return apply_rotary_pos_emb( + t, freqs, tensor_format="thd", fused=True, cu_seqlens=cu_seqlens, **conditional_kwargs + ) except ImportError: pass diff --git a/tests/unit_tests/transformer/test_rope.py b/tests/unit_tests/transformer/test_rope.py index e33c6101b91..a6c6fdec995 100644 --- a/tests/unit_tests/transformer/test_rope.py +++ b/tests/unit_tests/transformer/test_rope.py @@ -2,6 +2,8 @@ import pytest import torch +from packaging.version import Version as PkgVersion +from pytest_mock import mocker from megatron.core.models.common.embeddings import apply_rotary_pos_emb from megatron.core.models.common.embeddings.rotary_pos_embedding import ( @@ -10,6 +12,10 @@ ) from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.extensions.transformer_engine import ( + fused_apply_rotary_pos_emb, + fused_apply_rotary_pos_emb_thd, +) try: from transformer_engine.pytorch.attention.rope import apply_fused_qkv_rotary_pos_emb @@ -94,6 +100,30 @@ def test_cpu_forward(self): assert output.dtype == torch.float32 assert output.device.type == 'cuda' + @pytest.mark.internal + def test_transformer_engine_version_less_than_2_10(self, mocker): + with pytest.raises(Exception) as exc_info: + mocker.patch("megatron.core.utils.get_te_version", return_value=PkgVersion("2.9")) + t = torch.randn(64, 1, 1, 8) + freqs = torch.randn(64, 1, 1, 8) + fused_apply_rotary_pos_emb(t, freqs, start_positions=torch.tensor([0, 1, 2, 3])) + + assert str(exc_info.value) == ( + "Only TE >= 2.10.0.dev0 supports offset RoPE application with " + "`start_positions` argument." + ) + + with pytest.raises(Exception) as exc_info_thd: + mocker.patch("megatron.core.utils.get_te_version", return_value=PkgVersion("2.9")) + t = torch.randn(64, 1, 8) + freqs = torch.randn(64, 1, 1, 8) + cu_seqlens = torch.tensor([0, 64]) + fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs, start_positions=torch.tensor([0,])) + + assert str(exc_info_thd.value) == ( + "Only TE >= 2.10.0.dev0 supports offset RoPE application with " + "`start_positions` argument." + ) class TestQKVRotaryEmbedding: def setup_method(self): From f0461ab06bf97267d25abf30b73686e31d3969f0 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Wed, 26 Nov 2025 15:50:47 -0800 Subject: [PATCH 2/2] apply autoformatter.sh checks Signed-off-by: Sudhakar Singh --- tests/unit_tests/transformer/test_rope.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/unit_tests/transformer/test_rope.py b/tests/unit_tests/transformer/test_rope.py index a6c6fdec995..f188786f6f2 100644 --- a/tests/unit_tests/transformer/test_rope.py +++ b/tests/unit_tests/transformer/test_rope.py @@ -5,6 +5,10 @@ from packaging.version import Version as PkgVersion from pytest_mock import mocker +from megatron.core.extensions.transformer_engine import ( + fused_apply_rotary_pos_emb, + fused_apply_rotary_pos_emb_thd, +) from megatron.core.models.common.embeddings import apply_rotary_pos_emb from megatron.core.models.common.embeddings.rotary_pos_embedding import ( MultimodalRotaryEmbedding, @@ -12,10 +16,6 @@ ) from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.extensions.transformer_engine import ( - fused_apply_rotary_pos_emb, - fused_apply_rotary_pos_emb_thd, -) try: from transformer_engine.pytorch.attention.rope import apply_fused_qkv_rotary_pos_emb @@ -118,13 +118,14 @@ def test_transformer_engine_version_less_than_2_10(self, mocker): t = torch.randn(64, 1, 8) freqs = torch.randn(64, 1, 1, 8) cu_seqlens = torch.tensor([0, 64]) - fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs, start_positions=torch.tensor([0,])) + fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs, start_positions=torch.tensor([0])) assert str(exc_info_thd.value) == ( "Only TE >= 2.10.0.dev0 supports offset RoPE application with " "`start_positions` argument." ) + class TestQKVRotaryEmbedding: def setup_method(self): Utils.initialize_model_parallel(1, 1)