Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 35 additions & 17 deletions megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2254,47 +2254,65 @@ def fused_apply_rotary_pos_emb(
freqs: torch.Tensor,
transpose_output_memory: bool = False,
interleaved: bool = False,
start_positions: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Apply rotary positional embedding to input tensor T in `sbhd` format."""
if transpose_output_memory:
warnings.warn(
"transpose_output_memory is not supported by TE's fused RoPE and will be ignored."
)

conditional_kwargs = {}
if is_te_min_version("2.10.0.dev0"):
conditional_kwargs["start_positions"] = start_positions
else:
if start_positions is not None:
raise ValueError(
"Only TE >= 2.10.0.dev0 supports offset RoPE application with "
"`start_positions` argument."
)

if is_te_min_version("2.3.0"):
return apply_rotary_pos_emb(
t, freqs, tensor_format="sbhd", interleaved=interleaved, fused=True
)
conditional_kwargs["interleaved"] = interleaved
else:
if interleaved:
raise ValueError("Only TE >= 2.3.0 supports interleaved fused RoPE.")

return apply_rotary_pos_emb(t, freqs, tensor_format="sbhd", fused=True)
return apply_rotary_pos_emb(
t, freqs, tensor_format="sbhd", fused=True, **conditional_kwargs
)

def fused_apply_rotary_pos_emb_thd(
    t: torch.Tensor,
    cu_seqlens: torch.Tensor,
    freqs: torch.Tensor,
    cp_size: int = 1,
    cp_rank: int = 0,
    start_positions: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Apply rotary positional embedding to input tensor T in `thd` format with CP support.

    Args:
        t: Input tensor in `thd` (packed) layout to rotate.
        cu_seqlens: Cumulative sequence lengths delimiting the packed
            sequences in ``t``.
        freqs: Rotary frequency tensor.
        cp_size: Context-parallel world size. Values > 1 require
            TE >= 1.12.0.
        cp_rank: Context-parallel rank of this process.
        start_positions: Optional per-sequence start offsets for RoPE.
            Requires TE >= 2.10.0.dev0.

    Returns:
        The rotated tensor, same layout as ``t``.

    Raises:
        ValueError: If ``start_positions`` or context parallelism is
            requested but the installed TE version does not support it.
    """
    # Only forward kwargs the installed TE version understands; older TE
    # versions reject unknown keyword arguments.
    conditional_kwargs = {}

    if is_te_min_version("2.10.0.dev0"):
        conditional_kwargs["start_positions"] = start_positions
    elif start_positions is not None:
        raise ValueError(
            "Only TE >= 2.10.0.dev0 supports offset RoPE application with "
            "`start_positions` argument."
        )

    if is_te_min_version("1.12.0", check_equality=True):
        conditional_kwargs["cp_size"] = cp_size
        conditional_kwargs["cp_rank"] = cp_rank
    elif cp_size > 1:
        raise ValueError("Only TE >= 1.12.0 supports CP RoPE application for THD format.")

    return apply_rotary_pos_emb(
        t, freqs, tensor_format="thd", fused=True, cu_seqlens=cu_seqlens, **conditional_kwargs
    )

except ImportError:
pass
Expand Down
31 changes: 31 additions & 0 deletions tests/unit_tests/transformer/test_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@

import pytest
import torch
from packaging.version import Version as PkgVersion
from pytest_mock import mocker

from megatron.core.extensions.transformer_engine import (
fused_apply_rotary_pos_emb,
fused_apply_rotary_pos_emb_thd,
)
from megatron.core.models.common.embeddings import apply_rotary_pos_emb
from megatron.core.models.common.embeddings.rotary_pos_embedding import (
MultimodalRotaryEmbedding,
Expand Down Expand Up @@ -94,6 +100,31 @@ def test_cpu_forward(self):
assert output.dtype == torch.float32
assert output.device.type == 'cuda'

@pytest.mark.internal
def test_transformer_engine_version_less_than_2_10(self, mocker):
    """`start_positions` must raise a clear error on TE versions < 2.10.0.dev0.

    Covers both the `sbhd` (fused_apply_rotary_pos_emb) and `thd`
    (fused_apply_rotary_pos_emb_thd) entry points.
    """
    expected_msg = (
        "Only TE >= 2.10.0.dev0 supports offset RoPE application with "
        "`start_positions` argument."
    )

    # Patch BEFORE entering the raises-blocks: if patching itself failed
    # inside `pytest.raises`, its exception would spuriously pass the test.
    mocker.patch("megatron.core.utils.get_te_version", return_value=PkgVersion("2.9"))

    # sbhd path.
    t = torch.randn(64, 1, 1, 8)
    freqs = torch.randn(64, 1, 1, 8)
    with pytest.raises(ValueError) as exc_info:
        fused_apply_rotary_pos_emb(t, freqs, start_positions=torch.tensor([0, 1, 2, 3]))
    assert str(exc_info.value) == expected_msg

    # thd path.
    t = torch.randn(64, 1, 8)
    freqs = torch.randn(64, 1, 1, 8)
    cu_seqlens = torch.tensor([0, 64])
    with pytest.raises(ValueError) as exc_info_thd:
        fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs, start_positions=torch.tensor([0]))
    assert str(exc_info_thd.value) == expected_msg


class TestQKVRotaryEmbedding:
def setup_method(self):
Expand Down
Loading