Description
Hi, and thanks a lot for your amazing work on this library.
Describe the bug
After upgrading from version 1.0.9 to 1.0.10, calling GeodesicProbPath.sample() raises a RuntimeError when the inputs have more than two dimensions including the batch dimension, e.g. shape (batch_size, D1, D2):
RuntimeError: einsum(): the number of subscripts in the equation (1) does not match the number of dimensions (2) for operand 0 and no ellipsis was given
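This message comes from PyTorch's own `torch.einsum` argument check. An equation that names one subscript for an operand was given an operand with two dimensions. The counts (1 subscript, 2 dimensions) are consistent with an equation written for flat per-sample vectors receiving a single (16, 4) sample, but that is an inference, not a confirmed diagnosis. The sketch below assumes only PyTorch and does not reproduce the library's internal code. It triggers the same check and shows how an ellipsis (`...`) lets one equation accept any number of leading dimensions.

```python
import torch

# Two tensors shaped like one batch element from the report: (16, 4).
a = torch.randn(16, 4)
b = torch.randn(16, 4)

# An equation written for 1-D vectors ("i,i->") fails on 2-D operands
# with the same kind of RuntimeError as in the report.
try:
    torch.einsum("i,i->", a, b)
except RuntimeError as err:
    print(err)

# With an ellipsis, einsum contracts only the last axis and keeps all
# leading axes, so the same equation works for any input rank.
row_dots = torch.einsum("...i,...i->...", a, b)
print(row_dots.shape)  # torch.Size([16])
```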
To Reproduce
from flow_matching.path import GeodesicProbPath, PathSample
from flow_matching.path.scheduler import CondOTScheduler
from flow_matching.utils.manifolds import Euclidean
import torch
batch_size = 128
data_dim = (16, 4)
x_0 = torch.randn((batch_size, *data_dim)) # (128, 16, 4)
x_1 = torch.randn((batch_size, *data_dim)) # (128, 16, 4)
t = torch.linspace(0, 1, batch_size) # (128,)
manifold = Euclidean()
scheduler = CondOTScheduler()
path = GeodesicProbPath(scheduler, manifold)
sample: PathSample = path.sample(x_0=x_0, x_1=x_1, t=t)
# RuntimeError: einsum(): the number of subscripts in the equation (1) does not match the number of dimensions (2) for operand 0 and no ellipsis was given
Expected behavior
The function should support data with arbitrary trailing dimensions (e.g., (batch_size, D1, D2)), not only 2D inputs of shape (batch_size, D).
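Until this is fixed, one possible workaround is to flatten the trailing dimensions before sampling and restore them afterwards. The sketch below is hypothetical. `sample_flattened` and `euclidean_condot_reference` are illustrative helpers, not part of flow_matching. It assumes that 2D inputs of shape (batch_size, D) still work in 1.0.10 and that the returned PathSample exposes `x_t` and `dx_t`. Flattening is only safe for manifolds that act coordinate-wise, such as Euclidean. On a manifold like a sphere, where the last axis defines each point, it would silently change the geometry. The reference function spells out the expected behavior for Euclidean with CondOTScheduler, assuming the geodesic is traversed at rate t: x_t = x_0 + t (x_1 - x_0) and dx_t = x_1 - x_0, with t broadcast over all trailing dimensions.

```python
import torch


def euclidean_condot_reference(x_0: torch.Tensor, x_1: torch.Tensor, t: torch.Tensor):
    """Hypothetical closed form for Euclidean + CondOT: x_t = x_0 + t * (x_1 - x_0).

    t has shape (batch_size,) and is reshaped to (batch_size, 1, ..., 1) so it
    broadcasts over every trailing dimension of x_0 and x_1.
    """
    t_b = t.reshape(-1, *([1] * (x_0.dim() - 1)))
    x_t = x_0 + t_b * (x_1 - x_0)
    dx_t = x_1 - x_0  # time derivative of x_t, constant in t
    return x_t, dx_t


def sample_flattened(path, x_0: torch.Tensor, x_1: torch.Tensor, t: torch.Tensor):
    """Hypothetical workaround: sample on (batch_size, D) inputs, then unflatten.

    Only valid for manifolds that treat coordinates independently (e.g.
    Euclidean). Assumes path.sample returns an object with x_t and dx_t.
    """
    shape = x_0.shape
    batch_size = shape[0]
    sample = path.sample(
        x_0=x_0.reshape(batch_size, -1),  # (128, 16, 4) -> (128, 64)
        x_1=x_1.reshape(batch_size, -1),
        t=t,
    )
    # Restore the original trailing dimensions on the outputs.
    return sample.x_t.reshape(shape), sample.dx_t.reshape(shape)


# Usage with the objects from the reproduction script (not executed here):
# x_t, dx_t = sample_flattened(path, x_0, x_1, t)
```

Because the Euclidean interpolation is element-wise, flattening and unflattening commute with it. The result should therefore match the reference function exactly, which makes the reference a convenient check once 1.0.10 handles the original shape directly.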
Thank you in advance
Lukas
vonHartz