Skip to content

Commit d7b5d6e

Browse files
committed
add the webshop inference of the knowself_eval file
1 parent 439032e commit d7b5d6e

File tree

1 file changed

+95
-35
lines changed

1 file changed

+95
-35
lines changed

eval_agent/knowself_eval.py

Lines changed: 95 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,17 @@ def cov_to_text(cov: List[int]) -> str:
4646
return "\n".join([c["content"] for c in cov])
4747

4848

49+
def conv_to_text_webshop(cov: List[Dict[str, str]]) -> str:
    """Render a WebShop conversation as plain text for the knowledge prompt.

    All turns except the last are summarized: user turns (environment
    observations) are collapsed to the placeholder "Observation: ..." to keep
    the prompt short, while assistant turns keep their full content. The final
    turn is always included verbatim, since it is the current observation the
    agent must act on.

    Args:
        cov: Ordered list of chat messages, each a dict with "role" and
            "content" keys. Must be non-empty (the last element is read).

    Returns:
        The rendered trajectory, one turn per line, joined with newlines.
    """
    msg = []
    for c in cov[:-1]:
        if c["role"] == "user":
            # Observations are long product listings; elide them upstream of
            # the final turn so only actions remain readable.
            msg.append("Observation: ...")
        else:
            msg.append(c["content"])
    # The most recent message is kept in full regardless of role.
    msg.append(cov[-1]["content"])
    return "\n".join(msg)
58+
59+
4960
def format_input(task_desc, current_traj, rules):
5061
return (
5162
"The objectve:\n"
@@ -138,42 +149,90 @@ def call_model(
138149
logger.warning("[Knowledge] in preds")
139150
if task not in add_knowledge_task:
140151
add_knowledge_task.append(task.task_id)
141-
if len(messages) == 3:
142-
# add task knowledge to the prompt
143-
prompt = ""
144-
game_file = task.game_file
145-
for k, v in alfworld_prompt.items():
146-
if k in game_file:
147-
prompt = v
148-
break
149-
if prompt == "":
150-
logger.warning(f"Game file {game_file} not in alfworld_prompt")
151-
return pred_text
152-
knowledge_augmented_input = (
153-
input + f"[Knowledge]<knowledge>{prompt}</knowledge>\n"
154-
)
155-
logger.warning(f"knowledge_augmented_input: {knowledge_augmented_input}")
156-
inputs = tokenizer(knowledge_augmented_input, return_tensors="pt").to("cuda")
157-
generated_ids = model.generate(
158-
inputs.input_ids,
159-
max_new_tokens=1024,
160-
do_sample=False,
161-
pad_token_id=tokenizer.pad_token_id,
162-
stopping_criteria=stopping_criteria,
163-
)
164-
generated_ids = [
165-
output_ids[len(input_ids) :]
166-
for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
167-
]
168-
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0]
169-
pred_text = response.split(end_token)[0]
170-
pred_text = f"[Knowledge]<knowledge>{prompt}</knowledge>\n" + pred_text
171-
else:
172-
# call select knowledge agent to select knowledge
152+
153+
if args.exp_config == "alfworld":
154+
if len(messages) == 3:
155+
# add task knowledge to the prompt
156+
prompt = ""
157+
game_file = task.game_file
158+
for k, v in alfworld_prompt.items():
159+
if k in game_file:
160+
prompt = v
161+
break
162+
if prompt == "":
163+
logger.warning(f"Game file {game_file} not in alfworld_prompt")
164+
return pred_text
165+
knowledge_augmented_input = (
166+
input + f"[Knowledge]<knowledge>{prompt}</knowledge>\n"
167+
)
168+
logger.warning(f"knowledge_augmented_input: {knowledge_augmented_input}")
169+
inputs = tokenizer(knowledge_augmented_input, return_tensors="pt").to("cuda")
170+
generated_ids = model.generate(
171+
inputs.input_ids,
172+
max_new_tokens=1024,
173+
do_sample=False,
174+
pad_token_id=tokenizer.pad_token_id,
175+
stopping_criteria=stopping_criteria,
176+
)
177+
generated_ids = [
178+
output_ids[len(input_ids) :]
179+
for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
180+
]
181+
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0]
182+
pred_text = response.split(end_token)[0]
183+
pred_text = f"[Knowledge]<knowledge>{prompt}</knowledge>\n" + pred_text
184+
else:
185+
# call select knowledge agent to select knowledge
186+
with open(args.select_knowledge_inst) as f:
187+
prompt = f.read()
188+
task_desc = messages[2]["content"]
189+
current_traj = cov_to_text(messages[3:])
190+
rule_data = json.load(open(args.knowledge_base_path))["all_rules"]
191+
rules = []
192+
for k, v in rule_data.items():
193+
rules.append(v["rule"])
194+
rules = rules.__str__()
195+
cur_task = format_input(task_desc, current_traj, rules)
196+
_, knowledge_prompt = prompt_without_icl(prompt, cur_task)
197+
logger.warning(knowledge_prompt)
198+
select_knowledge_output = select_knowledge_agent(knowledge_prompt)
199+
logger.warning(f"select_knowledge_output: {select_knowledge_output}")
200+
if "[Chosen Rule]:" in select_knowledge_output:
201+
rule = select_knowledge_output.split("[Chosen Rule]:")[1].strip()
202+
else:
203+
rule = ""
204+
knowledge_augmented_input = (
205+
input + f"[Knowledge]<knowledge>{rule}</knowledge>\n"
206+
)
207+
logger.warning(f"[Knowledge]<knowledge>{rule}</knowledge>\n")
208+
inputs = tokenizer(knowledge_augmented_input, return_tensors="pt").to("cuda")
209+
generated_ids = model.generate(
210+
inputs.input_ids,
211+
max_new_tokens=1024,
212+
do_sample=False,
213+
pad_token_id=tokenizer.pad_token_id,
214+
stopping_criteria=stopping_criteria,
215+
)
216+
generated_ids = [
217+
output_ids[len(input_ids) :]
218+
for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
219+
]
220+
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[
221+
0
222+
]
223+
pred_text = response.split(end_token)[0]
224+
pred_text = f"[Knowledge]<knowledge>{rule}</knowledge>\n" + pred_text
225+
226+
elif args.exp_config == "webshop":
173227
with open(args.select_knowledge_inst) as f:
174228
prompt = f.read()
175-
task_desc = messages[2]["content"]
176-
current_traj = cov_to_text(messages[3:])
229+
task_desc = (
230+
messages[2]["content"]
231+
.split("Instruction: [SEP]")[1]
232+
.split("[SEP] Search")[0]
233+
.strip()
234+
)
235+
current_traj = conv_to_text_webshop(messages[3:])
177236
rule_data = json.load(open(args.knowledge_base_path))["all_rules"]
178237
rules = []
179238
for k, v in rule_data.items():
@@ -188,10 +247,11 @@ def call_model(
188247
rule = select_knowledge_output.split("[Chosen Rule]:")[1].strip()
189248
else:
190249
rule = ""
250+
191251
knowledge_augmented_input = (
192252
input + f"[Knowledge]<knowledge>{rule}</knowledge>\n"
193253
)
194-
logger.warning(f"[Knowledge]<knowledge>{rule}</knowledge>\n")
254+
logger.warning(f"knowledge:\n[Knowledge]<knowledge>{rule}</knowledge>\n")
195255
inputs = tokenizer(knowledge_augmented_input, return_tensors="pt").to("cuda")
196256
generated_ids = model.generate(
197257
inputs.input_ids,

0 commit comments

Comments
 (0)