@@ -180,6 +180,11 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
180
180
self .act_embedding_table = nn .Embedding (config .action_space_size , self .obs_act_embed_dim , device = self .device )
181
181
print (f"self.act_embedding_table.weight.device: { self .act_embedding_table .weight .device } " )
182
182
183
+ print (f'=' * 20 )
184
+ print (f"self.obs_act_embed_dim:{ self .obs_act_embed_dim } " )
185
+ print (f'=' * 20 )
186
+
187
+
183
188
# if self.num_experts_in_moe_head == -1:
184
189
assert self .num_experts_in_moe_head > 0
185
190
if self .use_normal_head :
@@ -647,10 +652,12 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu
647
652
if self .task_embed_option == "add_task_embed" :
648
653
obs_embeddings = obs_embeddings + self .task_embeddings
649
654
elif self .task_embed_option == "concat_task_embed" :
655
+
650
656
# print(f'=='*20)
651
657
# print(f'obs_embeddings.shape:{obs_embeddings.shape}')
652
658
# print(f'self.task_embeddings.shape:{self.task_embeddings.shape}')
653
659
# print(f'=='*20)
660
+
654
661
if is_init_infer :
655
662
# 注意只有在inference时,只有在is_init_infer时拼接task embeddings,recurr_infer中已经在init_infer中增加了task embeddings的信息了
656
663
# Expand task embeddings to match the sequence shape
@@ -862,21 +869,73 @@ def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps, ta
862
869
- 1 )
863
870
864
871
num_steps = int (obs_embeddings .size (1 ) * (obs_embeddings .size (2 ) + 1 ))
865
- # act_embeddings = self.act_embedding_table[task_id](act_tokens)
866
872
act_embeddings = self .act_embedding_table (act_tokens )
867
873
868
874
B , L , K , E = obs_embeddings .size ()
869
- obs_act_embeddings = torch .empty (B , L * (K + 1 ), E , device = self .device )
875
+ if self .task_embed_option == "concat_task_embed" :
876
+ # B, L*2, E
877
+ obs_act_embeddings = torch .empty (B , L * (K + 1 ), self .config .embed_dim , device = self .device )
878
+ else :
879
+ # B, L*2, E
880
+ obs_act_embeddings = torch .empty (B , L * (K + 1 ), self .config .embed_dim , device = self .device )
881
+
882
+ if self .task_embed_option == "concat_task_embed" :
883
+ # Expand task embeddings to match the sequence shape
884
+ task_emb_expanded = self .task_embeddings .view (1 , 1 , - 1 ).expand (B , 1 , - 1 )
885
+
870
886
871
887
for i in range (L ):
872
- # obs = obs_embeddings[:, i, :, :]
873
- obs = obs_embeddings [:, i , :, :] + self .task_embeddings # Shape: (B, K, E) TODO: task_embeddings
888
+ if self .task_embed_option == "add_task_embed" :
889
+ obs = obs_embeddings [:, i , :, :] + self .task_embeddings # Shape: (B, K, E) TODO: task_embeddings
890
+ elif self .task_embed_option == "concat_task_embed" :
891
+ obs = torch .cat ([obs_embeddings [:, i , :, :], task_emb_expanded ], dim = - 1 )
892
+ else :
893
+ obs = obs_embeddings [:, i , :, :] # Shape: (B, K, E)
894
+
874
895
act = act_embeddings [:, i , 0 , :].unsqueeze (1 )
896
+ if self .task_embed_option == "concat_task_embed" :
897
+ act = torch .cat ([act , task_emb_expanded ], dim = - 1 )
898
+
875
899
obs_act = torch .cat ([obs , act ], dim = 1 )
900
+ # print(f'obs_act.shape:{obs_act.shape}')
901
+
876
902
obs_act_embeddings [:, i * (K + 1 ):(i + 1 ) * (K + 1 ), :] = obs_act
877
903
878
904
return obs_act_embeddings + self .pos_emb (prev_steps + torch .arange (num_steps , device = self .device )), num_steps
879
905
906
+
907
+ #@profile
908
+ # def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps, task_id=0):
909
+ # """
910
+ # Process combined observation embeddings and action tokens.
911
+
912
+ # Arguments:
913
+ # - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing combined observation embeddings and action tokens.
914
+ # - prev_steps (:obj:`torch.Tensor`): Previous steps.
915
+ # Returns:
916
+ # - torch.Tensor: Combined observation and action embeddings with position information added.
917
+ # """
918
+ # obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens']
919
+ # if len(obs_embeddings.shape) == 3:
920
+ # obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens,
921
+ # -1)
922
+
923
+ # num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1))
924
+ # # act_embeddings = self.act_embedding_table[task_id](act_tokens)
925
+ # act_embeddings = self.act_embedding_table(act_tokens)
926
+
927
+ # B, L, K, E = obs_embeddings.size()
928
+ # obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=self.device)
929
+
930
+ # for i in range(L):
931
+ # # obs = obs_embeddings[:, i, :, :]
932
+ # obs = obs_embeddings[:, i, :, :] + self.task_embeddings # Shape: (B, K, E) TODO: task_embeddings
933
+ # act = act_embeddings[:, i, 0, :].unsqueeze(1)
934
+ # obs_act = torch.cat([obs, act], dim=1)
935
+ # obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act
936
+
937
+ # return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps
938
+
880
939
#@profile
881
940
def _transformer_pass (self , sequences , past_keys_values , kvcache_independent , valid_context_lengths , task_id = 0 ):
882
941
"""
0 commit comments