Skip to content

Conversation

@htjb
Copy link
Owner

@htjb htjb commented Dec 1, 2025

Description

This PR is designed to update margarine to use JAX and include several different normalising flows.

Impact on issues

Fixes #67 and #29.
#8 is no longer relevant.
I'm leaving error calculation on marginal statistics up to the user, closing #28. Although I will add a discussion to the documentation.

Key changes

  • Added NICE and RealNVP density estimators.
  • The code base is now written in JAX.
  • Density estimators are now written in JAX and flax.nnx.
  • Added an abstract BaseDensityEstimator that all density estimators inherit from. This means that each density estimator has an expected set of methods.
  • The implementation of Piecewise Normalising Flows in the old clusterMAF class has been rewritten into the cluster class in margarine/estimators/clustered.py. It takes advantage of the common API for each density estimator to allow users to build Piecewise NFs with any other implemented NF architecture (e.g. users can now build RealNVP PNFs, NICE PNFs, and even Piecewise KDEs).
  • Restructured the files so that the base class is kept in margarine/base/, density estimators are kept in margarine/estimators/ and utilities are kept in margarine/utils/.
  • Added a JAX Kmeans implementation since jax.scipy.stats doesn't have one and it is needed for Piecewise Normalising Flows.
  • New tests have been written.
  • Added load and save functions for NICE, RealNVP, cluster and KDE.
  • Added __call__ function for KDE. To transform samples from the unit hypercube on to the KDE you need conditional inverse transform sampling and this needs to be reimplemented in JAX.

Changes yet to be implemented

  • Add a discussion of errors on marginal statistics to the documentation.
  • New documentation and tutorials need to be written because of breaking changes in the API.
  • I think that the code needs some optimization to make the most of JAX still.

Additional features to be added

These features are things I would like to add either in this PR or in the future (in which case they will be added to a roadmap when this PR is merged).

  • I would like to make the flows conditional so that they can be used for SBI and to build $\beta$-flows (see here).
  • floZ style evidence calculations.
  • Add Neural Spline Flows.
  • Add Masked Autoregressive Flows
  • Explore addition of diffusion models as alternative density estimators.

Checklist:

  • I have performed a self-review of my own code
  • New and existing unit tests pass locally with my changes (python -m pytest)
  • I have added tests that prove my fix is effective or that my feature works
  • I have appropriately incremented the semantic version number in both README.rst and margarine/_version.py

htjb added 30 commits November 26, 2025 16:04
…other density estimators not just MAFs as in version 1
@htjb
Copy link
Owner Author

htjb commented Dec 5, 2025

MAF have proved tricky to set up. I've moved the relevant code to a separate branch so that it doesn't block version 2. I will continue developing it and release later. I have been finding that MAFs typically collapse to the identity transformation or to predicting very tight distributions that don't look like the target. They need careful regularisation and initialisation I think to prevent these two things happening, and this was previously all taken care off in the TensorFlow's tfb.MaskedAutoregressiveFlow code.

RealNVP and NICE are much easier to train.

htjb added 29 commits December 9, 2025 15:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

log_like calculation in clusterMAF

2 participants