diff --git a/configs/moe_config.py b/configs/moe_config.py
index 9d783abb..5d7509e4 100644
--- a/configs/moe_config.py
+++ b/configs/moe_config.py
@@ -16,6 +16,7 @@ class MoEModelConfig:
     v_dim: int | None = 128
     batch_size: int = 24
     max_steps: int = 1000
+    use_mem_efficient_attention: bool = False

     # Training parameters
     gradient_accumulation_steps: int = 4
diff --git a/models/components.py b/models/components.py
index 266d621f..89262cf6 100644
--- a/models/components.py
+++ b/models/components.py
@@ -2,18 +2,17 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from typing import Tuple, Optional
-
+from xformers.ops import SwiGLU

 class Expert(nn.Module):
     """Single expert network (essentially a FeedForward layer)"""
     def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
         super().__init__()
-        self.linear1 = nn.Linear(d_model, d_ff, bias=False)
-        self.linear2 = nn.Linear(d_ff, d_model, bias=False)
+        self.ffn = SwiGLU(d_model, d_ff, d_model, bias=False)
         self.dropout = nn.Dropout(dropout)

     def forward(self, x):
-        return self.linear2(self.dropout(F.silu(self.linear1(x))))
+        return self.dropout(self.ffn(x))


 class TopKRouter(nn.Module):
diff --git a/models/layers.py b/models/layers.py
index 3b4fba8c..d1f1f8ef 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -3,12 +3,15 @@
 import torch.nn.functional as F
 from torchtune.modules import RotaryPositionalEmbeddings
 from .components import MixtureOfExperts
+from xformers.ops import memory_efficient_attention, LowerTriangularMask


 class Rotary(nn.Module):
     def __init__(self, dim: int, max_seq_len: int):
         super().__init__()
-        self.rope = RotaryPositionalEmbeddings(dim=dim, max_seq_len=max_seq_len, base=10000)
+        self.rope = RotaryPositionalEmbeddings(
+            dim=dim, max_seq_len=max_seq_len, base=10000
+        )

     def forward(self, x_BTHD: torch.Tensor):
         # x_BTHD shape: [B, T, H, D] - need to convert to [B, T, H, D] for torchtune
@@ -18,8 +21,16 @@ def forward(self, x_BTHD: torch.Tensor):


 class MultiHeadAttention(nn.Module):
-    def __init__(self, d_model: int, n_heads: int, max_seq_len: int, dropout: float = 0.1):
+    def __init__(
+        self,
+        d_model: int,
+        n_heads: int,
+        max_seq_len: int,
+        dropout: float = 0.1,
+        use_mem_atten: bool = False,
+    ):
         super().__init__()
+        self.mem_atten = use_mem_atten
         self.d_model = d_model
         self.n_heads = n_heads
         self.d_k = d_model // n_heads
@@ -37,18 +48,28 @@ def forward(self, x):

         qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.n_heads, self.d_k)
         qkv = qkv.permute(2, 0, 3, 1, 4)
-        Q, K, V = qkv[0], qkv[1], qkv[2] # [B, H, T, D]
+        Q, K, V = qkv[0], qkv[1], qkv[2]  # [B, H, T, D]

         # Q = self.rotary(Q)
         # K = self.rotary(K)
         # Apply RoPE on [B, T, H, D]
         Q = self.rotary(Q.transpose(1, 2)).transpose(1, 2)
         K = self.rotary(K.transpose(1, 2)).transpose(1, 2)
-
-        attn_output = F.scaled_dot_product_attention(
-            Q, K, V, is_causal=True, dropout_p=self.dropout if self.training else 0.0
+        if self.mem_atten:
+            attn_output = memory_efficient_attention(
+                Q, K, V, attn_bias=LowerTriangularMask(), p=self.dropout
+            )
+        else:
+            attn_output = F.scaled_dot_product_attention(
+                Q,
+                K,
+                V,
+                is_causal=True,
+                dropout_p=self.dropout if self.training else 0.0,
+            )
+        attn_output = attn_output.transpose(1, 2).reshape(
+            batch_size, seq_len, self.d_model
         )
-        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
         # attn_output = attn_output.transpose(1, 2).reshape(B, T, self.d_model)

         return self.w_o(attn_output)
@@ -64,8 +85,10 @@ def __init__(
         v_dim: int,
         max_seq_len: int,
         dropout: float = 0.1,
+        use_mem_atten: bool = False,
     ):
         super().__init__()
+        self.mem_atten = use_mem_atten
         self.d_model = d_model
         self.n_heads = n_heads
         self.qk_dim = qk_rope_dim + qk_nope_dim
@@ -107,9 +130,18 @@ def forward(self, x: torch.Tensor):
         k_nope, v = torch.split(kv, (self.qk_nope_dim, self.v_dim), dim=-1)
         k = torch.cat([k_nope, k_rope.expand(-1, -1, self.n_heads, -1)], dim=-1)

-        attn_output = F.scaled_dot_product_attention(
-            q, k, v, is_causal=True, dropout_p=self.dropout if self.training else 0.0
-        )
+        if self.mem_atten:
+            attn_output = memory_efficient_attention(
+                q, k, v, attn_bias=LowerTriangularMask(), p=self.dropout
+            )
+        else:
+            attn_output = F.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                is_causal=True,
+                dropout_p=self.dropout if self.training else 0.0,
+            )
         attn_output = (
             attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
         )
@@ -133,6 +165,7 @@ def __init__(
         num_experts: int = 8,
         top_k: int = 2,
         dropout: float = 0.1,
+        use_mem_atten: bool = False,
     ):
         super().__init__()

@@ -147,9 +180,12 @@ def __init__(
                 v_dim,
                 max_seq_len,
                 dropout,
+                use_mem_atten,
             )
         else:
-            self.attention = MultiHeadAttention(d_model, n_heads, max_seq_len, dropout)
+            self.attention = MultiHeadAttention(
+                d_model, n_heads, max_seq_len, dropout, use_mem_atten
+            )

         # MoE layer
         self.feed_forward = MixtureOfExperts(d_model, d_ff, num_experts, top_k, dropout)
diff --git a/models/moe_llm.py b/models/moe_llm.py
index ecb6fd97..e71bd91a 100644
--- a/models/moe_llm.py
+++ b/models/moe_llm.py
@@ -1,7 +1,6 @@
 import torch
 import torch.nn as nn
 import math
-from typing import Optional

 from configs.moe_config import MoEModelConfig
 from models.layers import MoETransformerBlock
@@ -33,6 +32,7 @@ def __init__(self, config: MoEModelConfig):
                     config.num_experts,
                     config.expert_top_k,
                     config.dropout,
+                    config.use_mem_efficient_attention,
                 )
                 for i in range(config.n_layers)
             ]
diff --git a/requirements.txt b/requirements.txt
index f28baa7c..744206f9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,7 @@ transformers
 torchtune
 torchao
 matplotlib
+xformers
 # lm-eval

 # Single T4 GPU training
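
For reference, the two attention backends toggled by use_mem_atten consume different tensor layouts: F.scaled_dot_product_attention expects [batch, heads, seq, head_dim], while xformers' memory_efficient_attention expects [batch, seq, heads, head_dim]. The sketch below is a minimal equivalence check for the causal, no-dropout case; it assumes a CUDA device, fp16 tensors, and xformers installed, and the shapes are illustrative rather than taken from MoEModelConfig.

import torch
import torch.nn.functional as F
from xformers.ops import memory_efficient_attention, LowerTriangularMask

# Illustrative shapes: batch, seq_len, heads, head_dim.
B, T, H, D = 2, 64, 8, 32
q = torch.randn(B, T, H, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, T, H, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, T, H, D, device="cuda", dtype=torch.float16)

# SDPA consumes [B, H, T, D]; dropout is omitted so both outputs are deterministic.
ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
)  # [B, H, T, D]

# memory_efficient_attention consumes [B, T, H, D] directly, with a causal bias.
out = memory_efficient_attention(q, k, v, attn_bias=LowerTriangularMask())  # [B, T, H, D]

assert torch.allclose(ref, out.transpose(1, 2), atol=1e-2)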