@@ -19,16 +19,16 @@ def main():
19
19
augmentations = args .augmentations , only_visible_robots = True ,
20
20
sample_count = args .sample_count , sample_count_seed = args .sample_count_seed ,
21
21
compute_led_visibility = True )
22
- dataloader = DataLoader (ds , batch_size = 64 , shuffle = False )
22
+ dataloader = DataLoader (ds , batch_size = 32 , shuffle = False )
23
23
24
24
25
25
if args .checkpoint_id :
26
26
model , run_id = load_model_mlflow (experiment_id = args .experiment_id , mlflow_run_name = args .run_name , checkpoint_idx = args .checkpoint_id ,
27
- model_task = args .task , return_run_id = True )
27
+ model_kwargs = args .task , return_run_id = True )
28
28
using_mlflow = True
29
29
else :
30
30
using_mlflow = False
31
- model = load_model_raw (args .checkpoint_path , model_task = args .task )
31
+ model = load_model_raw (args .checkpoint_path , model_kwargs = { 'task' : args .task , 'led_inference' : args . led_inference } )
32
32
33
33
model = model .to (args .device )
34
34
model .eval ()
@@ -44,7 +44,8 @@ def main():
44
44
# print(batch["led_visibility_mask"][0])
45
45
# plt.show()
46
46
outs = model (image )
47
- led_preds .extend (model .predict_leds_with_gt_pos (batch , image ))
47
+ # breakpoint()
48
+ led_preds .extend (model .predict_leds (outs , batch ))
48
49
led_trues .extend (batch ['led_mask' ])
49
50
led_visibility .extend (batch ['led_visibility_mask' ])
50
51
0 commit comments