@@ -33,11 +33,47 @@ class TransformerConfig:
    # for RoPE
    rope_theta: float
    max_seq_len: int
+    rotary_emb: bool = False  # config option controlling whether rotary embeddings (RoPE) are used
+
    @property
    def max_tokens(self):
        return self.tokens_per_block * self.max_blocks


+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
+    freqs = torch.outer(t, freqs)
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+    return freqs_cis
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+    ndim = x.ndim
+    # print(f"freqs_cis shape: {freqs_cis.shape}, x shape: {x.shape}")
+    assert 0 <= 1 < ndim
+    shape = [d if i == 2 or i == ndim - 1 or i == 0 else 1 for i, d in enumerate(x.shape)]
+
+    return freqs_cis.view(*shape)
+
+
+def apply_rotary_emb(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+    try:
+        freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+    except:
+        print('We are at the reset timestep!')
+    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
+    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
+    return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
class Transformer(nn.Module):
    """
    Transformer model class.
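For orientation (illustration only, not part of the diff): a minimal sketch of how the three helpers above compose, assuming query/key tensors laid out as (batch, num_heads, seq_len, head_dim) — the layout implied by reshape_for_broadcast keeping dims 0, 2, and -1 — with all concrete sizes and start positions invented, and with the functions from the hunk above assumed to be in scope.

import torch

# Invented sizes: batch of 2, 4 heads, 6 tokens, head_dim 8.
B, H, T, D = 2, 4, 6, 8

# Complex rotation table for all positions; end is chosen larger than T,
# mirroring the max_seq_len * 2 used in __init__ below.
freqs_cis_full = precompute_freqs_cis(dim=D, end=16)  # (16, D // 2), complex64

# One slice per sample, stacked exactly as the forward pass does.
start_pos = [0, 3]
freqs_cis = torch.stack([freqs_cis_full[p: p + T] for p in start_pos])  # (B, T, D // 2)

xq = torch.randn(B, H, T, D)
xk = torch.randn(B, H, T, D)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)
assert xq_rot.shape == (B, H, T, D) and xk_rot.shape == (B, H, T, D)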
@@ -59,11 +95,14 @@ def __init__(self, config: TransformerConfig) -> None:
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_layers)])
        self.ln_f = nn.LayerNorm(config.embed_dim)

-        self.freqs_cis = precompute_freqs_cis(
-            self.config.embed_dim // self.config.num_heads,
-            self.config.max_seq_len * 2,
-            self.config.rope_theta,
-        )
+        # Register the table as a buffer so device placement is handled automatically.
+        if self.config.rotary_emb:
+            freqs_cis = precompute_freqs_cis(
+                self.config.embed_dim // self.config.num_heads,
+                self.config.max_seq_len * 2,
+                self.config.rope_theta,
+            )
+            self.register_buffer("freqs_cis", freqs_cis)

    def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues:
        """
@@ -93,24 +132,31 @@ def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues
            - torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim).
        """
        seqlen = sequences.shape[1]
-        self.freqs_cis = self.freqs_cis.to(sequences.device)

-        # freqs_cis = self.freqs_cis[start_pos: start_pos + seqlen]
-
-        # If the start position is greater than the predefined maximum sequence length, wrap around
-        start_pos = torch.tensor(np.array(start_pos))
-        if len(start_pos.shape) > 1:
-            # TODO: train start pos [0]
-            start_pos = torch.remainder(start_pos, self.config.max_seq_len)[:, 0]
+        # If RoPE is enabled, slice freqs_cis according to the per-sample start positions.
+        if self.config.rotary_emb:
+            # Fix: if start_pos is a scalar, expand it to one value per sample in the batch.
+            # The *2 is needed because step_index only counts observations, while the sequence interleaves obs and act.
+            if isinstance(start_pos, int) or isinstance(start_pos, float):
+                start_pos_tensor = torch.full((sequences.shape[0],), int(start_pos), device=sequences.device) * 2
+            else:
+                # start_pos_tensor = torch.as_tensor(start_pos, device=sequences.device)
+                try:
+                    start_pos_tensor = torch.as_tensor([x.item() for x in start_pos], device=sequences.device)
+                except Exception as e:
+                    # print(e)
+                    start_pos_tensor = torch.as_tensor(
+                        [x.reshape(-1)[0].item() for x in start_pos],  # flatten each entry and take its first element
+                        device=sequences.device
+                    ) * 2
+            # For each sample, take the freqs_cis slice that corresponds to its start position.
+            start_pos_tensor = torch.remainder(start_pos_tensor, self.config.max_seq_len)
+            # Convert the per-sample start positions to a list.
+            start_pos_list = start_pos_tensor.tolist()
+            freqs_cis_slices = [self.freqs_cis[int(pos): int(pos) + seqlen] for pos in start_pos_list]
+            freqs_cis = torch.stack(freqs_cis_slices)
        else:
-            start_pos = torch.remainder(start_pos, self.config.max_seq_len)
-
-        start_pos_list = torch.unbind(start_pos)
-        try:
-            freqs_cis_slices = [self.freqs_cis[int(pos.item()): int(pos.item()) + seqlen] for pos in start_pos_list]
-        except:
-            print('debug')
-        freqs_cis = torch.stack(freqs_cis_slices).squeeze(1)
+            freqs_cis = None

        assert past_keys_values is None or len(past_keys_values) == len(self.blocks)
        x = self.drop(sequences)
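A small worked example of the start-position arithmetic above (illustration only, not part of the diff), with invented numbers and a dummy table standing in for self.freqs_cis: the step index is doubled because each environment step contributes an obs token and an act token, the result is wrapped with remainder against max_seq_len, and the table built with end = max_seq_len * 2 leaves room for a slice that starts just below max_seq_len.

import torch

# Invented numbers: max_seq_len 8, 4 tokens per forward call, head_dim 8,
# and two samples whose observation step indices are 1 and 3.
max_seq_len, seqlen, head_dim = 8, 4, 8
step_index = torch.tensor([1, 3])

# Observation step k sits at token 2k in the interleaved obs/act sequence.
start_pos = torch.remainder(step_index * 2, max_seq_len)   # tensor([2, 6])

# Dummy table with max_seq_len * 2 rows, like the buffer registered in __init__.
table = torch.polar(torch.ones(max_seq_len * 2, head_dim // 2),
                    torch.zeros(max_seq_len * 2, head_dim // 2))
freqs_cis = torch.stack([table[int(p): int(p) + seqlen] for p in start_pos])
print(freqs_cis.shape)  # torch.Size([2, 4, 4]): one (seqlen, head_dim // 2) slice per sample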
@@ -181,42 +227,6 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None
        return x


-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
-    freqs = torch.outer(t, freqs)
-    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
-    return freqs_cis
-
-
-def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
-    ndim = x.ndim
-    # print(f"freqs_cis shape: {freqs_cis.shape}, x shape: {x.shape}")
-    assert 0 <= 1 < ndim
-    # assert freqs_cis.shape == (x.shape[2], x.shape[-1])
-    # shape = [d if i == 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-    # TODO: check
-    shape = [d if i == 2 or i == ndim - 1 or i == 0 else 1 for i, d in enumerate(x.shape)]
-
-    return freqs_cis.view(*shape)
-
-
-def apply_rotary_emb(
-    xq: torch.Tensor,
-    xk: torch.Tensor,
-    freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
-    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-    try:
-        freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
-    except:
-        print('We are at the reset timestep!')
-    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
-    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
-    return xq_out.type_as(xq), xk_out.type_as(xk)
-
-
class SelfAttention(nn.Module):
    """
    Implements self-attention mechanism for transformers.