
nklkhlr/jax-cfm


JAX-CFM: Conditional Flow Matching in JAX

Overview

This package is a JAX-based implementation of Conditional Flow Matching (CFM), an approach to generative modelling based on continuous normalizing flows. Its API closely mirrors that of the TorchCFM library, so that TorchCFM users can migrate to JAX with minimal friction.
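As a rough illustration of the objective behind CFM, the following sketch implements its simplest variant (independent coupling of source and target samples with a Gaussian probability path) in plain JAX. It does not use JAX-CFM's own API: `sample_conditional_path`, `cfm_loss`, `euler_sample` and the toy `linear_vector_field` are hypothetical names chosen for this example. A model v(t, x) is trained to regress the conditional velocity u_t = x1 - x0 at points x_t = t·x1 + (1 - t)·x0 + σε, and new samples are drawn by integrating dx/dt = v(t, x) from t = 0 to t = 1.

```python
import jax
import jax.numpy as jnp


def sample_conditional_path(key, x0, x1, sigma=0.0):
    """Draw t ~ U(0, 1) per sample and a point x_t on the path from x0 to x1.

    Gaussian path: x_t = t * x1 + (1 - t) * x0 + sigma * eps, eps ~ N(0, I).
    The conditional target velocity along this path is u_t = x1 - x0.
    """
    key_t, key_eps = jax.random.split(key)
    t = jax.random.uniform(key_t, (x0.shape[0], 1))
    eps = jax.random.normal(key_eps, x0.shape)
    xt = t * x1 + (1.0 - t) * x0 + sigma * eps
    ut = x1 - x0
    return t, xt, ut


def cfm_loss(params, vector_field, key, x0, x1, sigma=0.0):
    """Mean squared error between the model velocity v(t, x_t) and u_t."""
    t, xt, ut = sample_conditional_path(key, x0, x1, sigma)
    vt = vector_field(params, t, xt)
    return jnp.mean(jnp.sum((vt - ut) ** 2, axis=-1))


def euler_sample(params, vector_field, x0, n_steps=100):
    """Integrate dx/dt = v(t, x) from t = 0 to t = 1 with fixed-step explicit Euler."""
    dt = 1.0 / n_steps

    def step(x, i):
        t = jnp.full((x.shape[0], 1), i * dt)
        return x + dt * vector_field(params, t, x), None

    x_final, _ = jax.lax.scan(step, x0, jnp.arange(n_steps))
    return x_final


# Hypothetical toy model: an affine map from [x, t] to a velocity.
def linear_vector_field(params, t, x):
    w, b = params
    return jnp.concatenate([x, t], axis=-1) @ w + b


# Usage: one loss/gradient evaluation and one sampling pass on toy 2-D data.
key = jax.random.PRNGKey(0)
k_x0, k_x1, k_w, k_loss = jax.random.split(key, 4)
x0 = jax.random.normal(k_x0, (256, 2))        # source samples: standard Gaussian
x1 = jax.random.normal(k_x1, (256, 2)) + 3.0  # toy target: shifted Gaussian
params = (0.01 * jax.random.normal(k_w, (3, 2)), jnp.zeros(2))
loss, grads = jax.value_and_grad(
    lambda p: cfm_loss(p, linear_vector_field, k_loss, x0, x1, sigma=0.1)
)(params)
samples = euler_sample(params, linear_vector_field, x0)
```

In a real model the affine map would be replaced by a neural network (for example an equinox module), and the gradients would be passed to an optimizer; the loss itself stays the same.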

This repository is under active development, so it may still be incomplete or contain bugs.

Installation

To install JAX-CFM, clone this repository and run `pip install .` in an environment with Python >= 3.10.

If you intend to contribute or to run the examples, consider also installing the optional dependencies (e.g. `pip install .[dev]` or `pip install .[examples]`).

The long-term goal is to publish the package on PyPI.

Dependencies

JAX-CFM relies on ott-jax for optimal transport-related tasks, on equinox for its API design, and on jaxtyping for type annotations and type checking.
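In OT-based variants of CFM, the source and target samples in a minibatch are paired according to an optimal transport plan instead of independently. To give a rough idea of what such a step computes, the sketch below hand-rolls a log-domain Sinkhorn solver in plain JAX and resamples pairs from the resulting entropic plan. It deliberately avoids ott-jax's API and is not JAX-CFM's implementation: `sinkhorn_log_plan` and `ot_pairs` are hypothetical names, and the squared-Euclidean cost, uniform weights and fixed iteration count are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_log_plan(x0, x1, epsilon=0.1, n_iters=200):
    """Log of the entropic OT plan between uniform weights on x0 and x1.

    Log-domain Sinkhorn with squared-Euclidean cost; the plan is
    P_ij = exp((f_i + g_j - C_ij) / epsilon).
    """
    n, m = x0.shape[0], x1.shape[0]
    cost = jnp.sum((x0[:, None, :] - x1[None, :, :]) ** 2, axis=-1)
    log_a = jnp.full(n, -jnp.log(n))  # uniform source weights 1/n
    log_b = jnp.full(m, -jnp.log(m))  # uniform target weights 1/m

    def body(carry, _):
        f, g = carry
        # Alternate dual updates: first match the row marginals, then the column marginals.
        f = epsilon * (log_a - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        return (f, g), None

    init = (jnp.zeros(n), jnp.zeros(m))
    (f, g), _ = jax.lax.scan(body, init, None, length=n_iters)
    return (f[:, None] + g[None, :] - cost) / epsilon


def ot_pairs(key, x0, x1, epsilon=0.1):
    """Resample (x0, x1) pairs with probability proportional to the OT plan."""
    log_plan = sinkhorn_log_plan(x0, x1, epsilon)
    idx = jax.random.categorical(key, log_plan.reshape(-1), shape=(x0.shape[0],))
    i, j = jnp.divmod(idx, x1.shape[0])  # flat index -> (row, column)
    return x0[i], x1[j]


# Usage: couple a minibatch of source and target samples before computing a CFM loss.
key = jax.random.PRNGKey(0)
k0, k1, k_pair = jax.random.split(key, 3)
x0 = jax.random.normal(k0, (64, 2))
x1 = jax.random.normal(k1, (64, 2)) + 3.0
x0_paired, x1_paired = ot_pairs(k_pair, x0, x1, epsilon=0.5)
```

The regularization strength `epsilon` should be chosen relative to the scale of the cost matrix: smaller values give a plan closer to exact OT but need more iterations to converge, which is one reason to rely on a dedicated library such as ott-jax in practice.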
