As I understand it, with a linear time schedule the optimization target dx_t = u_t(x | x_1) should simply be manifold.logmap(x_0, x_1).
However, this repository implements:

    def cond_u(x_0, x_1, t):
        path = geodesic(self.manifold, x_0, x_1)
        x_t, dx_t = jvp(
            lambda t: path(self.scheduler(t).alpha_t),
            (t,),
            (torch.ones_like(t).to(t),),
        )
        return x_t, dx_t

    x_t, dx_t = vmap(cond_u)(x_0, x_1, t)

I checked, and the results differ: your (correct) target works better. Interestingly, manifold.logmap(x_0, x_1) also works in my basic test setup, but not quite as well.
Could someone help me understand why computing dx_t is more complicated, and what exactly the jvp computes?
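
For reference, here is a minimal sketch of the comparison on the unit sphere. The helpers `sphere_logmap` and `geodesic` and the linear schedule `alpha` are hand-written assumptions for illustration, not the repository's manifold or scheduler classes. On this example the jvp returns the time derivative of the geodesic at x_t, which is a tangent vector at x_t, whereas logmap(x_0, x_1) is a tangent vector at x_0. The two have the same length but agree as vectors only at t = 0 (or on a flat manifold).

```python
import torch
from torch.func import jvp


def sphere_logmap(x, y):
    # Log map on the unit sphere (hypothetical helper): tangent vector at x
    # pointing toward y, with length equal to the geodesic distance.
    cos = (x * y).sum().clamp(-1.0, 1.0)
    theta = torch.acos(cos)
    u = y - cos * x  # component of y orthogonal to x
    return theta * u / u.norm()


def geodesic(x0, x1):
    # Returns s -> point on the great circle from x0 (s=0) to x1 (s=1).
    v = sphere_logmap(x0, x1)
    theta = v.norm()
    e = v / theta
    return lambda s: torch.cos(s * theta) * x0 + torch.sin(s * theta) * e


def alpha(t):
    # Assumed linear schedule; a nonlinear one would add a chain-rule factor.
    return t


x0 = torch.tensor([1.0, 0.0, 0.0], dtype=torch.float64)
x1 = torch.tensor([0.0, 1.0, 0.0], dtype=torch.float64)
t = torch.tensor(0.3, dtype=torch.float64)

path = geodesic(x0, x1)

# jvp evaluates the function at t and, in the same pass, its directional
# derivative along the tangent 1.0, i.e. d/dt path(alpha(t)).
x_t, dx_t = jvp(lambda s: path(alpha(s)), (t,), (torch.ones_like(t),))

naive = sphere_logmap(x0, x1)                 # tangent at x0
at_xt = sphere_logmap(x_t, x1) / (1.0 - t)    # tangent at x_t

print("dx_t               :", dx_t)
print("logmap(x0, x1)     :", naive)
print("logmap(x_t,x1)/(1-t):", at_xt)
print("<x_t, dx_t>        :", (x_t * dx_t).sum())   # tangency check at x_t
print("<x_t, logmap(x0,x1)>:", (x_t * naive).sum())
```

In this sketch dx_t matches logmap(x_t, x1) / (1 - t), so under these assumptions the jvp target is the x_0-based log map transported along the geodesic to x_t. In Euclidean space the transport is the identity, which would explain why the simpler target still trains reasonably in a basic setup.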