-
Notifications
You must be signed in to change notification settings - Fork 8
Version 2 #68
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
…y will come back to this
…pe on key for train test split
…other density estimators not just MAFs as in version 1
|
MAF have proved tricky to set up. I've moved the relevant code to a separate branch so that it doesn't block version 2. I will continue developing it and release later. I have been finding that MAFs typically collapse to the identity transformation or to predicting very tight distributions that don't look like the target. They need careful regularisation and initialisation I think to prevent these two things happening, and this was previously all taken care off in the TensorFlow's RealNVP and NICE are much easier to train. |
…in the stats calcualtions
Description
This PR is designed to update
margarineto use JAX and include several different normalising flows.Impact on issues
Fixes #67 and #29.
#8 is no longer relevant.
I'm leaving error calculation on marginal statistics up to the user, closing #28. Although I will add a discussion to the documentation.
Key changes
BaseDensityEstimatorthat all density estimators inherit from. This means that each density estimator has an expected set of methods.clusterMAFclass has been rewritten into theclusterclass inmargarine/estimators/clustered.py. It takes advantage of the common API for each density estimator to allow users to build Piecewise NFs with any other implemented NF architecture (e.g. users can now build RealNVP PNFs, NICE PNFs, and even Piecewise KDEs).margarine/base/, density estimators are kept inmargarine/estimators/and utilities are kept inmargarine/utils/.jax.scipy.statsdoesn't have one and it is needed for Piecewise Normalising Flows.__call__function for KDE. To transform samples from the unit hypercube on to the KDE you need conditional inverse transform sampling and this needs to be reimplemented in JAX.Changes yet to be implemented
Additional features to be added
These features are things I would like to add either in this PR or in the future (in which case they will be added to a roadmap when this PR is merged).
Checklist:
python -m pytest)