
Commit bb8ee6f

test

1 parent b28fe7c commit bb8ee6f

2 files changed: +66 additions, -29 deletions

torchtitan/experiments/gpt_oss/infra/parallelize.py

Lines changed: 5 additions & 5 deletions
@@ -249,11 +249,11 @@ def apply_non_moe_tp(
 
         # shard attention.sinks across heads
         # TODO(jianiw): Fix the sink implementation
-        attn = transformer_block.attention
-        attn.register_parameter(
-            "sinks",
-            nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Shard(0)])),
-        )
+        # attn = transformer_block.attention
+        # attn.register_parameter(
+        #     "sinks",
+        #     nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Replicate()])),
+        # )
 
     if enable_async_tp:
         from torch.distributed._symmetric_memory import enable_symm_mem_for_group
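
Note: the block above is commented out and, inside the comment, switches the sinks placement from Shard(0) to Replicate(). A minimal standalone sketch of that DTensor pattern, assuming a tp_mesh built elsewhere (e.g. via init_device_mesh) and an attention module with a 1-D per-head sinks parameter; names follow this diff, not necessarily the repo's final wiring:

import torch.nn as nn
from torch.distributed.tensor import Replicate, distribute_tensor

def replicate_sinks(attn: nn.Module, tp_mesh) -> None:
    # Keep a full copy of the per-head sink logits on every TP rank instead of
    # sharding them; simpler while the sink math is still under a TODO.
    attn.register_parameter(
        "sinks",
        nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Replicate()])),
    )
    # Sharding across heads would use [Shard(0)] instead, matching the
    # head-sharded wq/wk/wv projections.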

torchtitan/experiments/gpt_oss/model/model.py

Lines changed: 61 additions & 24 deletions
@@ -87,6 +87,7 @@ def apply_rotary_emb(
     xk_out = (xk * cos) + (rotate_half(xk) * sin)
     return xq_out.type_as(xq), xk_out.type_as(xk)
 
+
 def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
     """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
     bs, slen, n_kv_heads, head_dim = x.shape
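
Note: repeat_kv's docstring points at torch.repeat_interleave(x, dim=2, repeats=n_rep); the body is not fully shown in this hunk, but it is presumably the usual expand/reshape GQA helper. A self-contained sketch of that pattern, checked against repeat_interleave (illustrative, not necessarily the exact body in model.py):

import torch

def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (bs, slen, n_kv_heads, head_dim) -> (bs, slen, n_kv_heads * n_rep, head_dim)
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

x = torch.randn(2, 5, 4, 8)
assert torch.equal(repeat_kv_sketch(x, 3), torch.repeat_interleave(x, dim=2, repeats=3))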
@@ -109,7 +110,7 @@ def eager_attention_forward(
     scaling: float,
     dropout: float = 0.0,
     **kwargs,
-):
+):
     key_values = key.transpose(2, 3)  # When TP is enabled, key should be shard()
     print(f"key_values : {key_values.placements} {key_values.shape}")
     print(f"query : {query.placements} {query.shape}")
@@ -145,32 +146,45 @@ def eager_attention_forward(
     attn_output = torch.matmul(attn_weights, value)
     return attn_output
 
+
 class Attention(nn.Module):
     """
     Multi-head attention (MLA) module.
     """
 
-    def __init__(self, model_args: GptOssModelArgs, use_sliding_attention: bool = False):
+    def __init__(
+        self, model_args: GptOssModelArgs, use_sliding_attention: bool = False
+    ):
         super().__init__()
 
-        self.sliding_window = model_args.sliding_window if use_sliding_attention else None
+        self.sliding_window = (
+            model_args.sliding_window if use_sliding_attention else None
+        )
         self.head_dim = model_args.head_dim
         self.n_heads = model_args.num_attention_heads
         self.n_kv_heads = model_args.num_key_value_heads
 
         self.n_rep = self.n_heads // self.n_kv_heads
 
         self.wq = nn.Linear(
-            model_args.hidden_size, model_args.num_attention_heads * model_args.head_dim, bias=True
+            model_args.hidden_size,
+            model_args.num_attention_heads * model_args.head_dim,
+            bias=True,
         )
         self.wk = nn.Linear(
-            model_args.hidden_size, model_args.num_key_value_heads * model_args.head_dim, bias=True
+            model_args.hidden_size,
+            model_args.num_key_value_heads * model_args.head_dim,
+            bias=True,
        )
         self.wv = nn.Linear(
-            model_args.hidden_size, model_args.num_key_value_heads * model_args.head_dim, bias=True
+            model_args.hidden_size,
+            model_args.num_key_value_heads * model_args.head_dim,
+            bias=True,
         )
         self.wo = nn.Linear(
-            model_args.num_attention_heads * model_args.head_dim, model_args.hidden_size, bias=True
+            model_args.num_attention_heads * model_args.head_dim,
+            model_args.hidden_size,
+            bias=True,
         )
         self.sinks = nn.Parameter(torch.empty(model_args.num_attention_heads))
 
@@ -179,9 +193,15 @@ def __init__(self, model_args: GptOssModelArgs, use_sliding_attention: bool = Fa
         if self.use_flex_attn:
             # Only apply sliding window to every other layer
             if use_sliding_attention:
-                self.attn = build_attention(use_flex_attn=True, attn_mask_type="sliding_window", sliding_window=self.sliding_window)
+                self.attn = build_attention(
+                    use_flex_attn=True,
+                    attn_mask_type="sliding_window",
+                    sliding_window=self.sliding_window,
+                )
             else:
-                self.attn = build_attention(use_flex_attn=True, attn_mask_type=model_args.attn_mask_type)
+                self.attn = build_attention(
+                    use_flex_attn=True, attn_mask_type=model_args.attn_mask_type
+                )
         else:
             # NOTE: sampling with FlexAttn seems broken; use TorchAttn if needed
             self.attn = eager_attention_forward
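
Note: build_attention is torchtitan's wrapper; with use_flex_attn=True and attn_mask_type="sliding_window" it presumably builds a causal sliding-window block mask for FlexAttention. A hedged sketch of that idea using the public torch.nn.attention.flex_attention API (PyTorch 2.5+); WINDOW and the mask_mod name are illustrative and not torchtitan's internals:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

WINDOW = 128  # assumed sliding_window value, for illustration only

def sliding_window_causal_mod(b, h, q_idx, kv_idx):
    # attend only to non-future tokens within the last WINDOW positions
    return (q_idx >= kv_idx) & (q_idx - kv_idx < WINDOW)

q = k = v = torch.randn(1, 8, 256, 64)  # (B, H, S, D)
block_mask = create_block_mask(
    sliding_window_causal_mod, B=None, H=None, Q_LEN=256, KV_LEN=256, device="cpu"
)
out = flex_attention(q, k, v, block_mask=block_mask)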
@@ -219,32 +239,39 @@ def forward(
         v = values.transpose(1, 2).contiguous()
 
         if self.use_flex_attn:
-            # FlexAttention
+            # FlexAttention
             output, lse = self.attn(
-                q, k, v,
+                q,
+                k,
+                v,
                 scale=None,
-                return_lse=True,
+                return_lse=False,
             )
 
             # Apply attention sink rescaling: rescale by σ(lse - w[h])
-            # This is mathematically equivalent to concatenating learnable sink weights
-            sink_scale = torch.sigmoid(lse - self.sink.view(1, -1, 1)).unsqueeze(
+            # This is mathematically equivalent to concatenating learnable sink weights
+            sink_scale = torch.sigmoid(lse - self.sinks.view(1, -1, 1)).unsqueeze(
                 -1
             )  # [B,H,S,1]
-            output = output * sink_scale
+            output = output * sink_scale.to(output.dtype)
 
         else:
             # eager attention forward
             output = self.attn(
-                q, k, v, self.sinks,
+                q,
+                k,
+                v,
+                self.sinks,
                 attention_mask=self.sliding_window_causal(seqlen, x.device),
                 scaling=self.head_dim**-0.5,
                 dropout=0.0,
             )
-        output = output.transpose(1, 2).contiguous()  # (B, H, T, D) -> (B, T, H, D)
+        output = output.transpose(1, 2).contiguous()  # (B, H, T, D) -> (B, T, H, D)
 
         # Reshape and project output
-        output = output.reshape(bsz, seqlen, -1).contiguous()  # (bsz, seqlen, n_heads * v_head_dim)
+        output = output.reshape(
+            bsz, seqlen, -1
+        ).contiguous()  # (bsz, seqlen, n_heads * v_head_dim)
         output = self.wo(output)  # (bsz, seqlen, dim)
         return output
 
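
Note: the comment in this hunk claims that rescaling the softmax output by σ(lse - w[h]) is equivalent to appending a learnable sink logit w[h] before the softmax and then dropping the sink column: with lse = logsumexp(scores), each real-token weight becomes exp(z_i) / (Σ exp(z_j) + exp(w)) = softmax(z)_i * σ(lse - w). A small numeric check of that identity for one head and one query (illustrative only, independent of the model code):

import torch

torch.manual_seed(0)
scores = torch.randn(6)      # attention logits for one query
sink = torch.tensor(0.7)     # learnable sink logit w[h]
v = torch.randn(6, 4)        # value vectors

# (a) concatenate the sink logit, softmax, then drop the sink column
p_with_sink = torch.softmax(torch.cat([scores, sink[None]]), dim=-1)[:-1]
out_concat = p_with_sink @ v

# (b) plain softmax output rescaled by sigmoid(lse - sink)
lse = torch.logsumexp(scores, dim=-1)
out_rescaled = torch.sigmoid(lse - sink) * (torch.softmax(scores, dim=-1) @ v)

assert torch.allclose(out_concat, out_rescaled, atol=1e-6)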
@@ -263,7 +290,7 @@ def init_weights(self, init_std: float):
     # TODO: statically init the mask using train.seq_len
     def sliding_window_causal(self, seqlen, device):
         i = torch.arange(seqlen, device=device)
-        q_idx = i[:, None]
+        q_idx = i[:, None]
         kv_idx = i[None, :]
 
         causal_mask = q_idx >= kv_idx
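
Note: this hunk ends right after causal_mask; the full sliding_window_causal method presumably combines the causal constraint with a window bound. A standalone sketch of such a boolean mask, assuming the window test is q_idx - kv_idx < window (the rest of the method is not shown in this diff):

import torch

def sliding_window_causal_mask(seqlen: int, window: int, device=None) -> torch.Tensor:
    i = torch.arange(seqlen, device=device)
    q_idx, kv_idx = i[:, None], i[None, :]
    causal_mask = q_idx >= kv_idx          # no attention to future positions
    window_mask = q_idx - kv_idx < window  # keep only the most recent `window` keys
    return causal_mask & window_mask       # True = attend, False = masked out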
@@ -282,11 +309,17 @@ def __init__(self, layer_id: int, model_args: GptOssModelArgs):
 
         super().__init__()
         use_sliding_attention = layer_id % 2 == 0
-        self.attention = Attention(model_args, use_sliding_attention=use_sliding_attention)
-        self.attention_norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps)
+        self.attention = Attention(
+            model_args, use_sliding_attention=use_sliding_attention
+        )
+        self.attention_norm = nn.RMSNorm(
+            model_args.hidden_size, eps=model_args.norm_eps
+        )
         self.ffn_norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps)
 
-        self.moe = GptOssMoE(model_args, dim=model_args.hidden_size, hidden_dim=model_args.moe_inter_dim)
+        self.moe = GptOssMoE(
+            model_args, dim=model_args.hidden_size, hidden_dim=model_args.moe_inter_dim
+        )
         self.moe_enabled = True  # for composability with load balancing
 
         self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5
@@ -323,14 +356,18 @@ def __init__(self, model_args: GptOssModelArgs):
         super().__init__()
         self.model_args = model_args
         self.max_seq_len = model_args.max_seq_len
-        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.hidden_size)
+        self.tok_embeddings = nn.Embedding(
+            model_args.vocab_size, model_args.hidden_size
+        )
         self.register_buffer(
             "rope_cache", self._precompute_rope_cache(), persistent=False
         )
 
         self.layers = torch.nn.ModuleDict()
         for layer_id in range(model_args.num_hidden_layers):
-            self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args).to(torch.bfloat16)
+            self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args).to(
+                torch.bfloat16
+            )
 
         self.norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps)
         self.output = nn.Linear(
