
Question on Riemannian probability path with geodesic interpolation #75

@johannespitz


Based on my understanding, assuming a linear time schedule, the optimization target dx_t = u_t(x | x_1) should simply be manifold.logmap(x_0, x_1).
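For reference, here are the quantities involved, written with a schedule κ_t (alpha_t in the code). This assumes that geodesic(...) returns the constant-speed path s ↦ exp_{x_0}(s log_{x_0}(x_1)) and that x_1 is not in the cut locus of x_0:

```latex
\begin{aligned}
x_t &= \exp_{x_0}\!\big(\kappa_t \, \log_{x_0}(x_1)\big), \\
\dot x_t &= \frac{d}{dt}\, x_t = \frac{\dot\kappa_t}{1-\kappa_t}\, \log_{x_t}(x_1), \qquad \kappa_t < 1.
\end{aligned}
```

With κ_t = t, this reduces to log_{x_0}(x_1) only at t = 0. For t > 0, the velocity is a tangent vector at x_t rather than at x_0.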
However, this repository implements:

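        def cond_u(x_0, x_1, t):
            # Geodesic on self.manifold from x_0 (parameter 0) to x_1 (parameter 1).
            path = geodesic(self.manifold, x_0, x_1)
            # Forward-mode Jacobian-vector product with respect to t,
            # using a tangent of ones.
            x_t, dx_t = jvp(
                lambda t: path(self.scheduler(t).alpha_t),
                (t,),
                (torch.ones_like(t).to(t),),
            )
            return x_t, dx_t

        # vmap maps cond_u over the batch dimension of x_0, x_1 and t.
        x_t, dx_t = vmap(cond_u)(x_0, x_1, t)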
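The difference can be checked numerically with a self-contained sketch on the unit sphere S^2. The helpers sphere_expmap, sphere_logmap and geodesic are hypothetical stand-ins written for this example and are not the repository's Manifold API. A linear schedule (alpha_t = t) is assumed. The sketch compares the jvp velocity with logmap(x_0, x_1) and with logmap(x_t, x_1) / (1 - t):

```python
import torch
from torch.func import jvp


def sphere_expmap(x, u):
    # Exponential map on the unit sphere: move from x along the great circle
    # in direction u for arc length |u| (assumes u != 0).
    n = torch.sqrt((u * u).sum(-1, keepdim=True))
    return torch.cos(n) * x + torch.sin(n) * u / n


def sphere_logmap(x, y):
    # Logarithm map on the unit sphere: tangent vector at x pointing to y whose
    # length is the geodesic distance (assumes y != +-x).
    cos_theta = (x * y).sum(-1, keepdim=True).clamp(-1.0, 1.0)
    theta = torch.arccos(cos_theta)
    w = y - cos_theta * x  # component of y orthogonal to x
    return theta * w / torch.sqrt((w * w).sum(-1, keepdim=True))


def geodesic(x0, x1):
    # Hypothetical constant-speed geodesic with path(0) = x0 and path(1) = x1.
    v = sphere_logmap(x0, x1)
    return lambda s: sphere_expmap(x0, s * v)


x0 = torch.tensor([1.0, 0.0, 0.0])
x1 = torch.tensor([0.0, 1.0, 0.0])
t = torch.tensor(0.3)

path = geodesic(x0, x1)
# Linear schedule alpha_t = t, so the tangent d(alpha_t)/dt is 1.
x_t, dx_t = jvp(path, (t,), (torch.ones_like(t),))

naive = sphere_logmap(x0, x1)  # tangent vector at x0, not at x_t
remaining = sphere_logmap(x_t, x1) / (1 - t)  # tangent vector at x_t

print("jvp dx_t:              ", dx_t)
print("logmap(x_0, x_1):      ", naive)
print("logmap(x_t, x_1)/(1-t):", remaining)
print("<dx_t, x_t> =", torch.dot(dx_t, x_t).item())
```

With x_0 and x_1 a quarter circle apart, the jvp velocity should match logmap(x_t, x_1) / (1 - t) and be orthogonal to x_t. By contrast, logmap(x_0, x_1) is tangent at x_0 and in general not tangent at x_t.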

I checked, and the two targets give different results; your (correct) target works better. Interestingly, manifold.logmap(x_0, x_1) also works in my basic test setup, just not quite as well.

Could someone help me understand why computing dx_t is more involved, and what exactly the jvp computes?
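On the jvp itself: torch.func.jvp(f, (t,), (v,)) returns the pair (f(t), J_f(t) v). With a scalar t and v = 1, the second entry is df/dt, and by the chain rule it includes the schedule's derivative. The toy sketch below uses hypothetical stand-ins for path and alpha: a scalar path s**3 and a schedule alpha_t = t**2. It is batched with vmap in the same way as the quoted code:

```python
import torch
from torch.func import jvp, vmap


def alpha(t):
    # Hypothetical non-linear schedule, standing in for self.scheduler(t).alpha_t.
    return t ** 2


def path(s):
    # Hypothetical scalar "path"; any differentiable function of s works here.
    return s ** 3


def cond_u(t):
    # jvp returns (f(t), J_f(t) @ tangent); with tangent = 1 this is (f(t), df/dt).
    return jvp(lambda t: path(alpha(t)), (t,), (torch.ones_like(t),))


t = torch.tensor([0.25, 0.5, 1.0])
x_t, dx_t = vmap(cond_u)(t)  # cond_u sees one 0-dim t per batch element
print(x_t, dx_t)
```

Here x_t equals t**6 and dx_t equals 6 * t**5, which is path'(alpha_t) * alpha'(t). With a linear schedule, the alpha'(t) factor is 1 and only the path's own derivative remains.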
