Skip to content

Commit 46b800e

Browse files
author
Nicholas Carlotti
committed
<EXP> Odd pairs experiment
1 parent d51b507 commit 46b800e

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

testing_led.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,16 @@ def main():
1919
augmentations=args.augmentations, only_visible_robots=True,
2020
sample_count=args.sample_count, sample_count_seed=args.sample_count_seed,
2121
compute_led_visibility=True)
22-
dataloader = DataLoader(ds, batch_size = 64, shuffle = False)
22+
dataloader = DataLoader(ds, batch_size = 32, shuffle = False)
2323

2424

2525
if args.checkpoint_id:
2626
model, run_id = load_model_mlflow(experiment_id=args.experiment_id, mlflow_run_name=args.run_name, checkpoint_idx=args.checkpoint_id,
27-
model_task=args.task, return_run_id=True)
27+
model_kwargs=args.task, return_run_id=True)
2828
using_mlflow = True
2929
else:
3030
using_mlflow = False
31-
model = load_model_raw(args.checkpoint_path, model_task=args.task)
31+
model = load_model_raw(args.checkpoint_path, model_kwargs={'task' : args.task, 'led_inference' : args.led_inference})
3232

3333
model = model.to(args.device)
3434
model.eval()
@@ -44,7 +44,8 @@ def main():
4444
# print(batch["led_visibility_mask"][0])
4545
# plt.show()
4646
outs = model(image)
47-
led_preds.extend(model.predict_leds_with_gt_pos(batch, image))
47+
# breakpoint()
48+
led_preds.extend(model.predict_leds(outs, batch))
4849
led_trues.extend(batch['led_mask'])
4950
led_visibility.extend(batch['led_visibility_mask'])
5051

0 commit comments

Comments
 (0)