@@ -50,11 +50,9 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
 def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+    # https://github.com/meta-llama/llama3/blob/main/llama/model.py#L61
     ndim = x.ndim
-    # print(f"freqs_cis shape: {freqs_cis.shape}, x shape: {x.shape}")
-    assert 0 <= 1 < ndim
-    shape = [d if i == 2 or i == ndim - 1 or i == 0 else 1 for i, d in enumerate(x.shape)]
-
+    shape = [d if i == ndim - 1 or i == 2 or i == 0 else 1 for i, d in enumerate(x.shape)]
     return freqs_cis.view(*shape)
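The rewritten index test keeps dims 0, 2, and ndim - 1, which matches a per-sample freqs_cis rather than Llama's shared one. A minimal sketch of the resulting broadcast, assuming x is laid out as (B, num_heads, seq_len, head_dim // 2) after view_as_complex and freqs_cis is stacked per sample as (B, seq_len, head_dim // 2); the shapes are illustrative, not taken from the repo:

```python
import torch

B, num_heads, seq_len, pairs = 2, 4, 8, 32
x = torch.randn(B, num_heads, seq_len, pairs, dtype=torch.complex64)
freqs_cis = torch.randn(B, seq_len, pairs, dtype=torch.complex64)

ndim = x.ndim
# Keep dims 0 (batch), 2 (seq), and ndim - 1 (pairs); collapse the head dim to 1.
shape = [d if i == ndim - 1 or i == 2 or i == 0 else 1 for i, d in enumerate(x.shape)]
print(shape)                       # [2, 1, 8, 32]
broadcastable = freqs_cis.view(*shape)
print((x * broadcastable).shape)   # (2, 4, 8, 32): every head rotated identically
```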
@@ -66,7 +64,9 @@ def apply_rotary_emb(
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
     try:
+        # print(f"freqs_cis shape: {freqs_cis.shape}, xq_ shape: {xq_.shape}")
         freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+        # print(f"new freqs_cis shape: {freqs_cis.shape}")
     except Exception as e:
         print(e)
         print('We are at the reset timestep!')
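The try/except only guards the reset timestep, where freqs_cis and xq_ can momentarily disagree in shape; the rotation itself is the standard Llama-style complex multiply. A self-contained sketch with toy shapes:

```python
import torch

seq_len, head_dim = 8, 64
xq = torch.randn(1, 1, seq_len, head_dim)

# Standard RoPE frequency table: one unit-magnitude complex number per (position, pair).
freqs = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(seq_len).float()
freqs_cis = torch.polar(torch.ones(seq_len, head_dim // 2), torch.outer(t, freqs))

# Pair up the last dim, view as complex, rotate, and unpack.
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
print(xq_out.shape)  # (1, 1, 8, 64): same shape, positions now phase-encoded
```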
@@ -137,25 +137,31 @@ def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues
         # If RoPE is enabled, slice freqs_cis accordingly
         if self.config.rotary_emb:
             # Fix: if start_pos is a scalar, expand it to a tensor of the current batch size filled with that value
-            # The *2 is because timestep only counts obs, while the sequence interleaves obs and act
+            # ========== The *2 is because timestep only counts obs, while the sequence interleaves obs and act ==========
             if isinstance(start_pos, int) or isinstance(start_pos, float):
                 start_pos_tensor = torch.full((sequences.shape[0],), int(start_pos), device=sequences.device) * 2
             else:
                 # start_pos_tensor = torch.as_tensor(start_pos, device=sequences.device)
                 try:
-                    start_pos_tensor = torch.as_tensor([x.item() for x in start_pos], device=sequences.device)
+                    start_pos_tensor = torch.as_tensor([x.item() for x in start_pos], device=sequences.device) * 2
                 except Exception as e:
                     # print(e)
                     start_pos_tensor = torch.as_tensor(
                         [x.reshape(-1)[0].item() for x in start_pos],  # force-flatten, then take the first element
                         device=sequences.device
                     ) * 2
+
             # For each sample, take the corresponding slice of freqs_cis based on start_pos
             start_pos_tensor = torch.remainder(start_pos_tensor, self.config.max_seq_len)
             # Convert each sample's start_pos to a list
             start_pos_list = start_pos_tensor.tolist()
             freqs_cis_slices = [self.freqs_cis[int(pos): int(pos) + seqlen] for pos in start_pos_list]
             freqs_cis = torch.stack(freqs_cis_slices)
+
+            if freqs_cis.ndim == 3 and freqs_cis.shape[1] == 1:
+                # Convert shape [seq_len, 1, num_pairs] to [seq_len, num_pairs]
+                freqs_cis = freqs_cis.squeeze(1)
+            # print(f'165 freqs_cis.shape:{freqs_cis.shape}')
         else:
             freqs_cis = None
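A toy walk-through of the position handling above, with made-up numbers: timesteps count observations only, so they are doubled to index the interleaved (obs, act) token sequence, then wrapped with remainder before slicing one freqs_cis window per sample:

```python
import torch

max_seq_len, num_pairs, seqlen = 16, 4, 6
freqs_cis_table = torch.randn(max_seq_len, num_pairs, dtype=torch.complex64)

start_pos = torch.tensor([3, 9])            # per-sample obs timesteps
start_pos_tensor = start_pos * 2            # obs+act interleaving: token positions 6 and 18
start_pos_tensor = torch.remainder(start_pos_tensor, max_seq_len)  # wrap: 6 and 2

slices = [freqs_cis_table[int(p): int(p) + seqlen] for p in start_pos_tensor.tolist()]
print(torch.stack(slices).shape)  # (2, 6, 4): one freqs_cis slice per sample
```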
@@ -307,8 +313,8 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None,
             for i in range(B):
                 mask[i] = self.mask[L:L + T, :L + T].clone()
                 mask[i, :, :(L - valid_context_lengths[i])] = 0  # Set invalid parts to 0.
-            # Adjust mask dimensions to match the last two dimensions of att.
-            # (B, T, L + T) -> (B, 1, T, L + T) -> (B, num_heads, T, L + T)
+            # Adjust mask dimensions to match the last two dimensions of att.
+            # (B, T, L + T) -> (B, 1, T, L + T) -> (B, num_heads, T, L + T)
             mask = mask.unsqueeze(1).expand(-1, att.size(1), -1, -1)
         else:
             # mask.shape: (T, L + T)
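A sketch of the mask broadcast this hunk documents, with illustrative sizes; the masked_fill at the end is an assumption about how att consumes the mask, not code from the repo:

```python
import torch

B, num_heads, T, L = 2, 4, 3, 5
att = torch.randn(B, num_heads, T, L + T)

mask = torch.ones(B, T, L + T)
mask[0, :, :2] = 0  # e.g. first sample has 2 invalid cached positions

# (B, T, L + T) -> (B, 1, T, L + T) -> (B, num_heads, T, L + T)
mask = mask.unsqueeze(1).expand(-1, att.size(1), -1, -1)
att = att.masked_fill(mask == 0, float('-inf'))
print(att.shape)  # (2, 4, 3, 8)
```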