
Invalid t shape for data with ndim > 1 in GeodesicProbPath.sample() #73

@LukasSchweizer


Hi, and thanks a lot for your amazing work on this library.

Describe the bug
After upgrading from version 1.0.9 to 1.0.10, calling GeodesicProbPath.sample() raises a RuntimeError when the data has more than one non-batch dimension, i.e. inputs with more than two dimensions including the batch dimension, such as (batch_size, 16, 4):

RuntimeError: einsum(): the number of subscripts in the equation (1) does not match the number of dimensions (2) for operand 0 and no ellipsis was given

To Reproduce

from flow_matching.path import GeodesicProbPath, PathSample
from flow_matching.path.scheduler import CondOTScheduler
from flow_matching.utils.manifolds import Euclidean
import torch

batch_size = 128
data_dim = (16, 4)

x_0 = torch.randn((batch_size, *data_dim))      # (128, 16, 4)
x_1 = torch.randn((batch_size, *data_dim))      # (128, 16, 4)
t = torch.linspace(0, 1, batch_size)             # (128,)

manifold = Euclidean()
scheduler = CondOTScheduler()
path = GeodesicProbPath(scheduler, manifold)

sample: PathSample = path.sample(x_0=x_0, x_1=x_1, t=t)
# RuntimeError: einsum(): the number of subscripts in the equation
#   (1) does not match the number of dimensions (2) for operand 0 and no ellipsis was given
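Until this is fixed, a possible stopgap for the Euclidean case above is to flatten the trailing dimensions to (batch_size, 64), sample, and reshape back. The sketch below is hypothetical: `sample_with_flat_event` and the stand-in `flat_only_linear_sampler` are illustrative names, not flow_matching API. It assumes that 2-D inputs of shape (batch_size, D) still work in 1.0.10 (not verified here), and it is only sound when the manifold acts on the flattened vector exactly as on the original tensor, as with Euclidean(); it is not valid in general, e.g. for a sphere defined over the last dimension.

```python
import torch


def sample_with_flat_event(sample_fn, x_0: torch.Tensor, x_1: torch.Tensor, t: torch.Tensor):
    """Hypothetical workaround: run a 2-D-only sampler on flattened inputs.

    Reshapes (B, D1, D2, ...) to (B, D1*D2*...), calls sample_fn, then restores
    the original shape. sample_fn(x_0, x_1, t) must return (x_t, dx_t), each (B, D).
    Only sound if the manifold treats the flattened vector like the original tensor.
    """
    batch_size, *event_shape = x_1.shape
    x_t, dx_t = sample_fn(x_0.reshape(batch_size, -1), x_1.reshape(batch_size, -1), t)
    return x_t.reshape(batch_size, *event_shape), dx_t.reshape(batch_size, *event_shape)


def flat_only_linear_sampler(x_0: torch.Tensor, x_1: torch.Tensor, t: torch.Tensor):
    """Stand-in that, like the failing call, only accepts (B, D) inputs."""
    if x_1.dim() != 2:
        raise RuntimeError(f"expected 2-D input, got {x_1.dim()}-D")
    t_col = t.reshape(-1, 1)  # (B,) -> (B, 1) so t broadcasts over D
    return (1 - t_col) * x_0 + t_col * x_1, x_1 - x_0


# With the library, sample_fn could wrap path.sample (assumption, untested on 1.0.10):
#   def path_sample_fn(a, b, s):
#       out = path.sample(x_0=a, x_1=b, t=s)
#       return out.x_t, out.dx_t

x_0 = torch.randn(128, 16, 4)
x_1 = torch.randn(128, 16, 4)
t = torch.linspace(0, 1, 128)
x_t, dx_t = sample_with_flat_event(flat_only_linear_sampler, x_0, x_1, t)
print(x_t.shape)  # torch.Size([128, 16, 4])
```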

Expected behavior
GeodesicProbPath.sample() should support data with arbitrary trailing dimensions (e.g., (batch_size, D1, D2)), as it did in 1.0.9, not just 2D inputs of shape (batch_size, D).
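For reference, a minimal sketch of what the call in the reproduction would be expected to return for Euclidean() with CondOTScheduler(). It assumes the Euclidean geodesic is the straight line x_0 + a (x_1 - x_0) and that CondOTScheduler uses alpha_t = t, so x_t = (1 - t) x_0 + t x_1 and dx_t = x_1 - x_0, with t of shape (batch_size,) broadcast over every trailing dimension. `reference_euclidean_condot_sample` is a hypothetical helper, not part of flow_matching.

```python
import torch


def reference_euclidean_condot_sample(x_0: torch.Tensor, x_1: torch.Tensor, t: torch.Tensor):
    """Hypothetical reference for the expected result (not library code).

    Assumes the Euclidean geodesic x_0 + a * (x_1 - x_0) with a = alpha_t = t,
    giving x_t = (1 - t) * x_0 + t * x_1 and dx_t = x_1 - x_0.
    """
    if t.shape != (x_1.shape[0],):
        raise ValueError(f"t must have shape (batch_size,), got {tuple(t.shape)}")
    # Reshape t from (B,) to (B, 1, ..., 1) so it broadcasts over any trailing dims.
    t_b = t.reshape(-1, *([1] * (x_1.dim() - 1)))
    x_t = (1 - t_b) * x_0 + t_b * x_1
    dx_t = x_1 - x_0
    return x_t, dx_t


x_0 = torch.randn(128, 16, 4)
x_1 = torch.randn(128, 16, 4)
t = torch.linspace(0, 1, 128)
x_t, dx_t = reference_euclidean_condot_sample(x_0, x_1, t)
print(x_t.shape, dx_t.shape)  # torch.Size([128, 16, 4]) torch.Size([128, 16, 4])
```

The same helper works unchanged for 2D inputs, since the reshape of t adapts to however many trailing dimensions the data has.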

Thank you in advance
Lukas
