@@ -87,6 +87,7 @@ def apply_rotary_emb(
     xk_out = (xk * cos) + (rotate_half(xk) * sin)
     return xq_out.type_as(xq), xk_out.type_as(xk)
 
+
 def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
     """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
     bs, slen, n_kv_heads, head_dim = x.shape
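For context, this hunk shows only the head of `repeat_kv`; its docstring says it mirrors `torch.repeat_interleave(x, dim=2, repeats=n_rep)`. A minimal sketch of the usual expand/reshape equivalent, assuming the `(bs, slen, n_kv_heads, head_dim)` layout above (the helper name and body are illustrative, not lines from this commit):

```python
import torch


def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Repeat each KV head n_rep times along the head axis, matching
    # torch.repeat_interleave(x, dim=2, repeats=n_rep).
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
```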
@@ -109,7 +110,7 @@ def eager_attention_forward(
     scaling: float,
     dropout: float = 0.0,
     **kwargs,
-):
+):
     key_values = key.transpose(2, 3)  # When TP is enabled, key should be shard()
     print(f"key_values : {key_values.placements} {key_values.shape}")
     print(f"query : {query.placements} {query.shape}")
@@ -145,32 +146,45 @@ def eager_attention_forward(
     attn_output = torch.matmul(attn_weights, value)
     return attn_output
 
+
 class Attention(nn.Module):
     """
     Multi-head attention (MLA) module.
     """
 
-    def __init__(self, model_args: GptOssModelArgs, use_sliding_attention: bool = False):
+    def __init__(
+        self, model_args: GptOssModelArgs, use_sliding_attention: bool = False
+    ):
         super().__init__()
 
-        self.sliding_window = model_args.sliding_window if use_sliding_attention else None
+        self.sliding_window = (
+            model_args.sliding_window if use_sliding_attention else None
+        )
         self.head_dim = model_args.head_dim
         self.n_heads = model_args.num_attention_heads
         self.n_kv_heads = model_args.num_key_value_heads
 
         self.n_rep = self.n_heads // self.n_kv_heads
 
         self.wq = nn.Linear(
-            model_args.hidden_size, model_args.num_attention_heads * model_args.head_dim, bias=True
+            model_args.hidden_size,
+            model_args.num_attention_heads * model_args.head_dim,
+            bias=True,
         )
         self.wk = nn.Linear(
-            model_args.hidden_size, model_args.num_key_value_heads * model_args.head_dim, bias=True
+            model_args.hidden_size,
+            model_args.num_key_value_heads * model_args.head_dim,
+            bias=True,
         )
         self.wv = nn.Linear(
-            model_args.hidden_size, model_args.num_key_value_heads * model_args.head_dim, bias=True
+            model_args.hidden_size,
+            model_args.num_key_value_heads * model_args.head_dim,
+            bias=True,
         )
         self.wo = nn.Linear(
-            model_args.num_attention_heads * model_args.head_dim, model_args.hidden_size, bias=True
+            model_args.num_attention_heads * model_args.head_dim,
+            model_args.hidden_size,
+            bias=True,
        )
         self.sinks = nn.Parameter(torch.empty(model_args.num_attention_heads))
 
@@ -179,9 +193,15 @@ def __init__(self, model_args: GptOssModelArgs, use_sliding_attention: bool = Fa
         if self.use_flex_attn:
             # Only apply sliding window to every other layer
             if use_sliding_attention:
-                self.attn = build_attention(use_flex_attn=True, attn_mask_type="sliding_window", sliding_window=self.sliding_window)
+                self.attn = build_attention(
+                    use_flex_attn=True,
+                    attn_mask_type="sliding_window",
+                    sliding_window=self.sliding_window,
+                )
             else:
-                self.attn = build_attention(use_flex_attn=True, attn_mask_type=model_args.attn_mask_type)
+                self.attn = build_attention(
+                    use_flex_attn=True, attn_mask_type=model_args.attn_mask_type
+                )
         else:
             # NOTE: sampling with FlexAttn seems broken; use TorchAttn if needed
             self.attn = eager_attention_forward
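`build_attention` itself is not part of this diff. As context only, a "sliding_window" mask for FlexAttention is typically expressed as a `mask_mod` handed to `create_block_mask`; the sketch below assumes PyTorch's `torch.nn.attention.flex_attention` API and a placeholder window size, none of which comes from this commit:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

SLIDING_WINDOW = 128  # placeholder; the module above uses model_args.sliding_window


def sliding_window_causal_mod(b, h, q_idx, kv_idx):
    # Allow keys that are causal and within the last SLIDING_WINDOW positions.
    return (q_idx >= kv_idx) & (q_idx - kv_idx < SLIDING_WINDOW)


# Example usage (q/k/v shaped (B, H, S, D)):
# block_mask = create_block_mask(sliding_window_causal_mod, B=None, H=None,
#                                Q_LEN=seq_len, KV_LEN=seq_len)
# out = flex_attention(q, k, v, block_mask=block_mask)
```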
@@ -219,32 +239,39 @@ def forward(
         v = values.transpose(1, 2).contiguous()
 
         if self.use_flex_attn:
-            # FlexAttention
+            # FlexAttention
             output, lse = self.attn(
-                q, k, v,
+                q,
+                k,
+                v,
                 scale=None,
-                return_lse=True,
+                return_lse=True,  # lse is needed for the sink rescaling below
             )
 
             # Apply attention sink rescaling: rescale by σ(lse - w[h])
-            # This is mathematically equivalent to concatenating learnable sink weights
-            sink_scale = torch.sigmoid(lse - self.sink.view(1, -1, 1)).unsqueeze(
+            # This is mathematically equivalent to concatenating learnable sink weights
+            sink_scale = torch.sigmoid(lse - self.sinks.view(1, -1, 1)).unsqueeze(
                 -1
             )  # [B,H,S,1]
-            output = output * sink_scale
+            output = output * sink_scale.to(output.dtype)
 
         else:
             # eager attention forward
             output = self.attn(
-                q, k, v, self.sinks,
+                q,
+                k,
+                v,
+                self.sinks,
                 attention_mask=self.sliding_window_causal(seqlen, x.device),
                 scaling=self.head_dim**-0.5,
                 dropout=0.0,
             )
-            output = output.transpose(1, 2).contiguous()  # (B, H, T, D) -> (B, T, H, D)
+            output = output.transpose(1, 2).contiguous()  # (B, H, T, D) -> (B, T, H, D)
 
         # Reshape and project output
-        output = output.reshape(bsz, seqlen, -1).contiguous()  # (bsz, seqlen, n_heads * v_head_dim)
+        output = output.reshape(
+            bsz, seqlen, -1
+        ).contiguous()  # (bsz, seqlen, n_heads * v_head_dim)
         output = self.wo(output)  # (bsz, seqlen, dim)
         return output
 
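One note on the flex path above: the comment asserts that multiplying by σ(lse − w[h]) is equivalent to appending a learnable sink logit to the softmax and discarding its column. A self-contained numerical check of that identity (all names below are illustrative, not taken from this commit):

```python
import torch

scores = torch.randn(5)    # attention logits for one (query, head) pair
sink = torch.tensor(0.7)   # stand-in for one learnable sink logit w[h]

# Softmax over [scores, sink] with the sink column dropped ...
with_sink = torch.softmax(torch.cat([scores, sink[None]]), dim=-1)[:-1]
# ... equals softmax(scores) rescaled by sigmoid(logsumexp(scores) - sink).
rescaled = torch.softmax(scores, dim=-1) * torch.sigmoid(
    torch.logsumexp(scores, dim=-1) - sink
)
assert torch.allclose(with_sink, rescaled, atol=1e-6)
```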
@@ -263,7 +290,7 @@ def init_weights(self, init_std: float):
     # TODO: statically init the mask using train.seq_len
     def sliding_window_causal(self, seqlen, device):
         i = torch.arange(seqlen, device=device)
-        q_idx = i[:, None]
+        q_idx = i[:, None]
         kv_idx = i[None, :]
 
         causal_mask = q_idx >= kv_idx
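The hunk cuts off before the window term of the mask is applied. For context, a dense boolean sliding-window causal mask built from the `q_idx`/`kv_idx` grids above typically looks like the sketch below; the window comparison and the helper name are assumptions, not lines from this commit:

```python
import torch


def sliding_window_causal_sketch(seqlen: int, window: int, device=None) -> torch.Tensor:
    i = torch.arange(seqlen, device=device)
    q_idx, kv_idx = i[:, None], i[None, :]
    causal_mask = q_idx >= kv_idx          # no attention to future positions
    window_mask = q_idx - kv_idx < window  # only the most recent `window` keys
    return causal_mask & window_mask       # (seqlen, seqlen) boolean mask


# e.g. sliding_window_causal_sketch(4, 2): position 3 can attend to positions 2 and 3 only.
```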
@@ -282,11 +309,17 @@ def __init__(self, layer_id: int, model_args: GptOssModelArgs):
 
         super().__init__()
         use_sliding_attention = layer_id % 2 == 0
-        self.attention = Attention(model_args, use_sliding_attention=use_sliding_attention)
-        self.attention_norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps)
+        self.attention = Attention(
+            model_args, use_sliding_attention=use_sliding_attention
+        )
+        self.attention_norm = nn.RMSNorm(
+            model_args.hidden_size, eps=model_args.norm_eps
+        )
         self.ffn_norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps)
 
-        self.moe = GptOssMoE(model_args, dim=model_args.hidden_size, hidden_dim=model_args.moe_inter_dim)
+        self.moe = GptOssMoE(
+            model_args, dim=model_args.hidden_size, hidden_dim=model_args.moe_inter_dim
+        )
         self.moe_enabled = True  # for composability with load balancing
 
         self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5
@@ -323,14 +356,18 @@ def __init__(self, model_args: GptOssModelArgs):
         super().__init__()
         self.model_args = model_args
         self.max_seq_len = model_args.max_seq_len
-        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.hidden_size)
+        self.tok_embeddings = nn.Embedding(
+            model_args.vocab_size, model_args.hidden_size
+        )
         self.register_buffer(
             "rope_cache", self._precompute_rope_cache(), persistent=False
         )
 
         self.layers = torch.nn.ModuleDict()
         for layer_id in range(model_args.num_hidden_layers):
-            self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args).to(torch.bfloat16)
+            self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args).to(
+                torch.bfloat16
+            )
 
         self.norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps)
         self.output = nn.Linear(