@@ -46,6 +46,17 @@ def cov_to_text(cov: List[int]) -> str:
46
46
return "\n " .join ([c ["content" ] for c in cov ])
47
47
48
48
49
def conv_to_text_webshop(cov: List[dict]) -> str:
    """Render a webshop conversation as newline-joined text.

    Every message except the last contributes one line: messages with
    role "user" (environment observations) are collapsed to the literal
    placeholder "Observation: ..." to keep the prompt short, while all
    other messages keep their full content. The final message's content
    is always kept verbatim, so the most recent observation or action
    stays visible.

    Args:
        cov: Conversation messages, each a dict with "role" and
            "content" keys. (The previous annotation ``List[int]`` was
            incorrect — elements are indexed as dicts.)

    Returns:
        The rendered trajectory, one message per line. An empty
        conversation yields "" (previously this raised IndexError on
        ``cov[-1]``).
    """
    if not cov:
        return ""
    msg = []
    for c in cov[:-1]:
        if c["role"] == "user":
            # Collapse intermediate observations; only the latest matters.
            msg.append("Observation: ...")
        else:
            msg.append(c["content"])
    # Keep the most recent message in full.
    msg.append(cov[-1]["content"])
    return "\n".join(msg)
58
+
59
+
49
60
def format_input (task_desc , current_traj , rules ):
50
61
return (
51
62
"The objectve:\n "
@@ -138,42 +149,90 @@ def call_model(
138
149
logger .warning ("[Knowledge] in preds" )
139
150
if task not in add_knowledge_task :
140
151
add_knowledge_task .append (task .task_id )
141
- if len (messages ) == 3 :
142
- # add task knowledge to the prompt
143
- prompt = ""
144
- game_file = task .game_file
145
- for k , v in alfworld_prompt .items ():
146
- if k in game_file :
147
- prompt = v
148
- break
149
- if prompt == "" :
150
- logger .warning (f"Game file { game_file } not in alfworld_prompt" )
151
- return pred_text
152
- knowledge_augmented_input = (
153
- input + f"[Knowledge]<knowledge>{ prompt } </knowledge>\n "
154
- )
155
- logger .warning (f"knowledge_augmented_input: { knowledge_augmented_input } " )
156
- inputs = tokenizer (knowledge_augmented_input , return_tensors = "pt" ).to ("cuda" )
157
- generated_ids = model .generate (
158
- inputs .input_ids ,
159
- max_new_tokens = 1024 ,
160
- do_sample = False ,
161
- pad_token_id = tokenizer .pad_token_id ,
162
- stopping_criteria = stopping_criteria ,
163
- )
164
- generated_ids = [
165
- output_ids [len (input_ids ) :]
166
- for input_ids , output_ids in zip (inputs .input_ids , generated_ids )
167
- ]
168
- response = tokenizer .batch_decode (generated_ids , skip_special_tokens = False )[0 ]
169
- pred_text = response .split (end_token )[0 ]
170
- pred_text = f"[Knowledge]<knowledge>{ prompt } </knowledge>\n " + pred_text
171
- else :
172
- # call select knowledge agent to select knowledge
152
+
153
+ if args .exp_config == "alfworld" :
154
+ if len (messages ) == 3 :
155
+ # add task knowledge to the prompt
156
+ prompt = ""
157
+ game_file = task .game_file
158
+ for k , v in alfworld_prompt .items ():
159
+ if k in game_file :
160
+ prompt = v
161
+ break
162
+ if prompt == "" :
163
+ logger .warning (f"Game file { game_file } not in alfworld_prompt" )
164
+ return pred_text
165
+ knowledge_augmented_input = (
166
+ input + f"[Knowledge]<knowledge>{ prompt } </knowledge>\n "
167
+ )
168
+ logger .warning (f"knowledge_augmented_input: { knowledge_augmented_input } " )
169
+ inputs = tokenizer (knowledge_augmented_input , return_tensors = "pt" ).to ("cuda" )
170
+ generated_ids = model .generate (
171
+ inputs .input_ids ,
172
+ max_new_tokens = 1024 ,
173
+ do_sample = False ,
174
+ pad_token_id = tokenizer .pad_token_id ,
175
+ stopping_criteria = stopping_criteria ,
176
+ )
177
+ generated_ids = [
178
+ output_ids [len (input_ids ) :]
179
+ for input_ids , output_ids in zip (inputs .input_ids , generated_ids )
180
+ ]
181
+ response = tokenizer .batch_decode (generated_ids , skip_special_tokens = False )[0 ]
182
+ pred_text = response .split (end_token )[0 ]
183
+ pred_text = f"[Knowledge]<knowledge>{ prompt } </knowledge>\n " + pred_text
184
+ else :
185
+ # call select knowledge agent to select knowledge
186
+ with open (args .select_knowledge_inst ) as f :
187
+ prompt = f .read ()
188
+ task_desc = messages [2 ]["content" ]
189
+ current_traj = cov_to_text (messages [3 :])
190
+ rule_data = json .load (open (args .knowledge_base_path ))["all_rules" ]
191
+ rules = []
192
+ for k , v in rule_data .items ():
193
+ rules .append (v ["rule" ])
194
+ rules = rules .__str__ ()
195
+ cur_task = format_input (task_desc , current_traj , rules )
196
+ _ , knowledge_prompt = prompt_without_icl (prompt , cur_task )
197
+ logger .warning (knowledge_prompt )
198
+ select_knowledge_output = select_knowledge_agent (knowledge_prompt )
199
+ logger .warning (f"select_knowledge_output: { select_knowledge_output } " )
200
+ if "[Chosen Rule]:" in select_knowledge_output :
201
+ rule = select_knowledge_output .split ("[Chosen Rule]:" )[1 ].strip ()
202
+ else :
203
+ rule = ""
204
+ knowledge_augmented_input = (
205
+ input + f"[Knowledge]<knowledge>{ rule } </knowledge>\n "
206
+ )
207
+ logger .warning (f"[Knowledge]<knowledge>{ rule } </knowledge>\n " )
208
+ inputs = tokenizer (knowledge_augmented_input , return_tensors = "pt" ).to ("cuda" )
209
+ generated_ids = model .generate (
210
+ inputs .input_ids ,
211
+ max_new_tokens = 1024 ,
212
+ do_sample = False ,
213
+ pad_token_id = tokenizer .pad_token_id ,
214
+ stopping_criteria = stopping_criteria ,
215
+ )
216
+ generated_ids = [
217
+ output_ids [len (input_ids ) :]
218
+ for input_ids , output_ids in zip (inputs .input_ids , generated_ids )
219
+ ]
220
+ response = tokenizer .batch_decode (generated_ids , skip_special_tokens = False )[
221
+ 0
222
+ ]
223
+ pred_text = response .split (end_token )[0 ]
224
+ pred_text = f"[Knowledge]<knowledge>{ rule } </knowledge>\n " + pred_text
225
+
226
+ elif args .exp_config == "webshop" :
173
227
with open (args .select_knowledge_inst ) as f :
174
228
prompt = f .read ()
175
- task_desc = messages [2 ]["content" ]
176
- current_traj = cov_to_text (messages [3 :])
229
+ task_desc = (
230
+ messages [2 ]["content" ]
231
+ .split ("Instruction: [SEP]" )[1 ]
232
+ .split ("[SEP] Search" )[0 ]
233
+ .strip ()
234
+ )
235
+ current_traj = conv_to_text_webshop (messages [3 :])
177
236
rule_data = json .load (open (args .knowledge_base_path ))["all_rules" ]
178
237
rules = []
179
238
for k , v in rule_data .items ():
@@ -188,10 +247,11 @@ def call_model(
188
247
rule = select_knowledge_output .split ("[Chosen Rule]:" )[1 ].strip ()
189
248
else :
190
249
rule = ""
250
+
191
251
knowledge_augmented_input = (
192
252
input + f"[Knowledge]<knowledge>{ rule } </knowledge>\n "
193
253
)
194
- logger .warning (f"[Knowledge]<knowledge>{ rule } </knowledge>\n " )
254
+ logger .warning (f"knowledge: \n [Knowledge]<knowledge>{ rule } </knowledge>\n " )
195
255
inputs = tokenizer (knowledge_augmented_input , return_tensors = "pt" ).to ("cuda" )
196
256
generated_ids = model .generate (
197
257
inputs .input_ids ,
0 commit comments