diff --git a/main_cuda.py b/main_cuda.py index 2911b85..4928763 100644 --- a/main_cuda.py +++ b/main_cuda.py @@ -210,7 +210,7 @@ def is_save_iter(i): args.coef_speed * loss_speed + \ args.coef_v_pred * loss_v_pred + \ args.coef_collide * loss_collide + \ - args.coef_ground_affinity + loss_ground_affinity + args.coef_ground_affinity * loss_ground_affinity if torch.isnan(loss): print("loss is nan, exiting...")