diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index bf5b228b6c9..f852da4ffc3 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -2254,21 +2254,33 @@ def fused_apply_rotary_pos_emb( freqs: torch.Tensor, transpose_output_memory: bool = False, interleaved: bool = False, + start_positions: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Apply rotary positional embedding to input tensor T in `sbhd` format.""" if transpose_output_memory: warnings.warn( "transpose_output_memory is not supported by TE's fused RoPE and will be ignored." ) + + conditional_kwargs = {} + if is_te_min_version("2.10.0.dev0"): + conditional_kwargs["start_positions"] = start_positions + else: + if start_positions is not None: + raise ValueError( + "Only TE >= 2.10.0.dev0 supports offset RoPE application with " + "`start_positions` argument." + ) + if is_te_min_version("2.3.0"): - return apply_rotary_pos_emb( - t, freqs, tensor_format="sbhd", interleaved=interleaved, fused=True - ) + conditional_kwargs["interleaved"] = interleaved else: if interleaved: raise ValueError("Only TE >= 2.3.0 supports interleaved fused RoPE.") - return apply_rotary_pos_emb(t, freqs, tensor_format="sbhd", fused=True) + return apply_rotary_pos_emb( + t, freqs, tensor_format="sbhd", fused=True, **conditional_kwargs + ) def fused_apply_rotary_pos_emb_thd( t: torch.Tensor, @@ -2276,25 +2288,31 @@ def fused_apply_rotary_pos_emb_thd( freqs: torch.Tensor, cp_size: int = 1, cp_rank: int = 0, + start_positions: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Apply rotary positional embedding to input tensor T in `thd` format with CP support. """ + conditional_kwargs = {} + if is_te_min_version("2.10.0.dev0"): + conditional_kwargs["start_positions"] = start_positions + else: + if start_positions is not None: + raise ValueError( + "Only TE >= 2.10.0.dev0 supports offset RoPE application with " + "`start_positions` argument." + ) + if is_te_min_version("1.12.0", check_equality=True): - return apply_rotary_pos_emb( - t, - freqs, - tensor_format="thd", - fused=True, - cu_seqlens=cu_seqlens, - cp_size=cp_size, - cp_rank=cp_rank, - ) + conditional_kwargs["cp_size"] = cp_size + conditional_kwargs["cp_rank"] = cp_rank else: - assert cp_size == 1, "Only TE >= 1.12 supports RoPE fusion for THD format with CP." - return apply_rotary_pos_emb( - t, freqs, tensor_format="thd", fused=True, cu_seqlens=cu_seqlens - ) + if cp_size > 1: + raise ValueError("Only TE >= 1.12.0 supports CP RoPE application for THD format.") + + return apply_rotary_pos_emb( + t, freqs, tensor_format="thd", fused=True, cu_seqlens=cu_seqlens, **conditional_kwargs + ) except ImportError: pass diff --git a/tests/unit_tests/transformer/test_rope.py b/tests/unit_tests/transformer/test_rope.py index e33c6101b91..f188786f6f2 100644 --- a/tests/unit_tests/transformer/test_rope.py +++ b/tests/unit_tests/transformer/test_rope.py @@ -2,7 +2,13 @@ import pytest import torch +from packaging.version import Version as PkgVersion +from pytest_mock import mocker +from megatron.core.extensions.transformer_engine import ( + fused_apply_rotary_pos_emb, + fused_apply_rotary_pos_emb_thd, +) from megatron.core.models.common.embeddings import apply_rotary_pos_emb from megatron.core.models.common.embeddings.rotary_pos_embedding import ( MultimodalRotaryEmbedding, @@ -94,6 +100,31 @@ def test_cpu_forward(self): assert output.dtype == torch.float32 assert output.device.type == 'cuda' + @pytest.mark.internal + def test_transformer_engine_version_less_than_2_10(self, mocker): + with pytest.raises(Exception) as exc_info: + mocker.patch("megatron.core.utils.get_te_version", return_value=PkgVersion("2.9")) + t = torch.randn(64, 1, 1, 8) + freqs = torch.randn(64, 1, 1, 8) + fused_apply_rotary_pos_emb(t, freqs, start_positions=torch.tensor([0, 1, 2, 3])) + + assert str(exc_info.value) == ( + "Only TE >= 2.10.0.dev0 supports offset RoPE application with " + "`start_positions` argument." + ) + + with pytest.raises(Exception) as exc_info_thd: + mocker.patch("megatron.core.utils.get_te_version", return_value=PkgVersion("2.9")) + t = torch.randn(64, 1, 8) + freqs = torch.randn(64, 1, 1, 8) + cu_seqlens = torch.tensor([0, 64]) + fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs, start_positions=torch.tensor([0])) + + assert str(exc_info_thd.value) == ( + "Only TE >= 2.10.0.dev0 supports offset RoPE application with " + "`start_positions` argument." + ) + class TestQKVRotaryEmbedding: def setup_method(self):