Description
In flow_matching/tests/solver/test_ode_solver.py, at line 23:
    class ConstantVelocityModel(ModelWrapper):
        def __init__(self):
            super().__init__(None)
            self.a = torch.nn.Parameter(torch.tensor(1.0))

        def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
            return x * 0.0 + self.a

For readability, I suggest changing

    return x * 0.0 + self.a

to

    return torch.ones_like(x) * self.a
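The following is a minimal sketch, assuming only that PyTorch is installed, that checks the two expressions behave the same. The helpers `velocity_original` and `velocity_suggested` are hypothetical standalone functions that mirror the body of `forward`; they are not part of the flow_matching test file. They compare output shape, dtype, values, and the gradient with respect to the parameter `a`. The sketch also shows the one behavioral difference: `x * 0.0` propagates NaN or inf from `x`, while `torch.ones_like(x)` ignores the values of `x`.

```python
import torch


def velocity_original(x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
    # Current form: multiply by zero to get a tensor shaped like x, then add a.
    return x * 0.0 + a


def velocity_suggested(x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
    # Suggested form: build a ones tensor shaped like x and scale it by a.
    return torch.ones_like(x) * a


x = torch.randn(4, 3)
a = torch.nn.Parameter(torch.tensor(1.0))

out_orig = velocity_original(x, a)
out_new = velocity_suggested(x, a)

# Same shape, dtype and values for finite inputs.
assert out_orig.shape == out_new.shape == x.shape
assert out_orig.dtype == out_new.dtype
assert torch.equal(out_orig, out_new)

# Same gradient with respect to a: each of the 12 entries contributes 1.
(grad_orig,) = torch.autograd.grad(out_orig.sum(), a)
(grad_new,) = torch.autograd.grad(out_new.sum(), a)
assert torch.equal(grad_orig, grad_new)
print(grad_new.item())  # 12.0

# One behavioral difference: non-finite inputs propagate through x * 0.0.
x_bad = torch.tensor([float("nan"), float("inf"), 1.0])
bad_orig = velocity_original(x_bad, a.detach())
bad_new = velocity_suggested(x_bad, a.detach())
print(bad_orig)  # tensor([nan, nan, 1.])
print(bad_new)   # tensor([1., 1., 1.])
```

As long as the test feeds finite inputs, the change is purely cosmetic. If anything, the `ones_like` form is also slightly more robust, because the constant velocity no longer depends on the values of `x`.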