
Commit 70e5920

[CP][RFC] Enable FlexCP for llama3 with parallelize_module
Similar to #1696, but this PR uses parallelize_module in the same way TP/SP do. It also requires pytorch/pytorch#162542.
1 parent bd3850b commit 70e5920
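
In outline, the change wraps the compiled flex_attention call in a small nn.Module so that context parallelism can be attached with parallelize_module, the same entry point used for TP/SP, instead of the context-manager approach in #1696. The sketch below pulls the pieces of the three-file diff together; it is a minimal illustration, assuming the experimental _ContextParallel / _DispatchMode APIs from pytorch/pytorch#162542 are available and the world mesh has a "cp" dimension. The helper name apply_flex_cp is hypothetical; the module path block.attention.sdpa.attention_fn_wrapper comes from the diff below.

import torch
from torch.distributed.tensor.experimental import _attention as dtensor_attention
from torch.distributed.tensor.experimental._attention import _ContextParallel, _DispatchMode
from torch.distributed.tensor.parallel import parallelize_module
from torch.nn.attention.flex_attention import flex_attention


class FlexAttentionWrapper(torch.nn.Module):
    # One compiled flex_attention shared by every attention layer.
    _flex_attn = torch.compile(flex_attention, mode="max-autotune-no-cudagraphs")

    def forward(self, *args, **kwargs):
        return FlexAttentionWrapper._flex_attn(*args, **kwargs)


def apply_flex_cp(model, cp_mesh):
    # Route CP through the module wrapper, then register each layer's wrapper
    # with the CP plan, mirroring how TP/SP plans are applied.
    dtensor_attention._dispatch_mode = _DispatchMode.MODULE_WRAPPER
    for block in model.layers.values():
        parallelize_module(
            module=block.attention.sdpa.attention_fn_wrapper,
            device_mesh=cp_mesh,
            parallelize_plan=_ContextParallel(
                seq_dim=2,  # q/k/v are [batch, heads, seq_len, head_dim]
                attention_type=_ContextParallel.AttentionType.FLEX,
            ),
        )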

File tree

3 files changed: +41, -15 lines

torchtitan/models/attention.py

Lines changed: 28 additions & 6 deletions
@@ -15,6 +15,7 @@
 from torch.nn.attention import sdpa_kernel, SDPBackend
 from torch.nn.attention.flex_attention import (
     _mask_mod_signature,
+    AuxOutput,
     BlockMask,
     create_block_mask,
     flex_attention,

@@ -28,6 +29,26 @@
 FLEX_ATTN_MASK_T = tuple[str, int | None]


+class FlexAttentionWrapper(torch.nn.Module):
+    _flex_attn: ClassVar[Callable] = torch.compile(
+        flex_attention, mode="max-autotune-no-cudagraphs"
+    )
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, *args: object, **kwargs: object) -> (
+        torch.Tensor | tuple[torch.Tensor, torch.Tensor]
+        | tuple[torch.Tensor, AuxOutput]
+    ):
+        # 1. _flex_attn has to be a class variable; otherwise there will
+        #    be multiple compiled flex_attention instances, which can be slow.
+        # 2. `self._flex_attn` is not correct: `self` will be passed in
+        #    as the first argument, which will cause an error.
+        #    `FlexAttentionWrapper._flex_attn` is correct.
+        return FlexAttentionWrapper._flex_attn(*args, **kwargs)
+
+
 class FlexAttention(torch.nn.Module):
     """FlexAttention module that uses torch.nn.attention.flex_attention.

@@ -46,11 +67,6 @@ class FlexAttention(torch.nn.Module):
         to the keys within the same block.
     """

-    # We registered flex_attention related attributes as class variables as we
-    # need to amortize the cost of compilation.
-    flex_attn: ClassVar[Callable] = torch.compile(
-        flex_attention, mode="max-autotune-no-cudagraphs"
-    )
     compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask)
     used_attn_mask_types: ClassVar[set[FLEX_ATTN_MASK_T]] = set()
     # Attention mask type to the created BlockMask.

@@ -71,6 +87,7 @@ def __init__(
             raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.")
         self.attn_mask_type = attn_mask_type
         self.fixed_block_size = fixed_block_size
+        self.attention_fn_wrapper = FlexAttentionWrapper()

         FlexAttention.used_attn_mask_types.add(self.mask_key)

@@ -86,7 +103,7 @@ def forward(
         scale: float | None = None,
     ) -> torch.Tensor:
         block_mask = FlexAttention.block_masks[self.mask_key]
-        return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale)
+        return self.attention_fn_wrapper(q, k, v, block_mask=block_mask, scale=scale)

     @staticmethod
     def _get_causal_mask_mod() -> _mask_mod_signature:

@@ -251,6 +268,11 @@ def init_attention_mask(
     # while we continue debugging accuracy issues. However, we want to evaluate
     # the user experience with CP enabled.
     if cp_mesh is not None:
+        from torch.distributed.tensor.experimental._attention import _DispatchMode
+
+        torch.distributed.tensor.experimental._attention._dispatch_mode = (
+            _DispatchMode.MODULE_WRAPPER
+        )
         FlexAttention.compiled_create_block_mask = functools.partial(
             create_cp_block_mask, device_mesh=cp_mesh
         )
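
A note on the inline comments in FlexAttentionWrapper.forward: looking the compiled function up through `self` would pass `self` as the first argument (ordinary Python method binding on a function-valued class attribute), which is why the wrapper calls it through the class instead. A tiny, hypothetical illustration of that binding pitfall (not code from this PR):

def add(a, b):
    return a + b


class Holder:
    _fn = add  # plain function stored as a class attribute

    def bad(self, a, b):
        # self._fn binds like a method: this calls add(self, a, b) -> TypeError.
        return self._fn(a, b)

    def good(self, a, b):
        # Looking it up on the class avoids binding, as FlexAttentionWrapper does.
        return Holder._fn(a, b)


h = Holder()
print(h.good(1, 2))  # 3
# h.bad(1, 2) raises: add() takes 2 positional arguments but 3 were given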

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 13 additions & 2 deletions
@@ -14,6 +14,8 @@
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
 from torch.distributed.tensor import Replicate, Shard
+
+from torch.distributed.tensor.experimental._attention import _ContextParallel
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,

@@ -67,8 +69,6 @@ def parallelize_llama(
     """

     use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")

     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters

@@ -90,6 +90,17 @@
         )
         maybe_enable_async_tp(job_config, world_mesh["tp"])

+    if parallel_dims.cp_enabled:
+        for block in model.layers.values():
+            parallelize_module(
+                module=block.attention.sdpa.attention_fn_wrapper,
+                device_mesh=world_mesh["cp"],
+                parallelize_plan=_ContextParallel(
+                    seq_dim=2,
+                    attention_type=_ContextParallel.AttentionType.FLEX,
+                ),
+            )
+
     model_compile_enabled = (
         job_config.compile.enable and "model" in job_config.compile.components
     )
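
A note on seq_dim=2: flex_attention takes q/k/v in [batch, heads, seq_len, head_dim] layout, so dimension 2 is the sequence dimension that the _ContextParallel plan shards across the "cp" mesh. A small shape sketch with illustrative numbers, assuming that layout:

import torch

bs, n_heads, seq_len, head_dim = 2, 8, 8192, 64
cp_degree = 4  # size of the "cp" mesh dimension

q = torch.randn(bs, n_heads, seq_len, head_dim)
# dim 0: batch, dim 1: heads, dim 2: sequence, dim 3: head_dim
# With _ContextParallel(seq_dim=2, ...), each CP rank holds a
# seq_len // cp_degree slice of the sequence dimension locally.
local_seq_len = q.shape[2] // cp_degree
assert local_seq_len == 2048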

torchtitan/models/llama3/model/args.py

Lines changed: 0 additions & 7 deletions
@@ -45,13 +45,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         )
         self.max_seq_len = seq_len

-        if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
-            raise NotImplementedError(
-                "CP support for FlexAttention is still in progress."
-            )
-
-        self.max_seq_len = seq_len
-
     def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
         nparams = sum(p.numel() for p in model.parameters())
         nparams_embedding = sum(
