Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
102 commits
Select commit Hold shift + click to select a range
b2fb557
adding a base class to build density estimators on top of
htjb Nov 26, 2025
fe025a1
tighter prototyping of the functions
htjb Nov 26, 2025
c31ad44
util functions
htjb Nov 26, 2025
aadb5dc
beginning to lay out a kde class as a starting point for v2
htjb Nov 26, 2025
2ffd42b
adding forward and inverse transformations
htjb Nov 27, 2025
c6120a7
realised I can use tensorflow probability with the jax backend
htjb Nov 27, 2025
02e0198
fiddling with the parameter order in the sampling function
htjb Nov 27, 2025
da3b55a
adding in a log prob function with correction
htjb Nov 27, 2025
fc7ea0e
kde reimplementation
htjb Nov 27, 2025
32a8dec
removing old margarine kde class
htjb Nov 27, 2025
8210575
hmmm implementing the conditional inverse transform sampling is trick…
htjb Nov 27, 2025
ce08c15
adding a jax test train split function
htjb Nov 27, 2025
cbc25d0
removing now redundatn processing code
htjb Nov 27, 2025
2596472
rearranging the file structure
htjb Nov 27, 2025
be6ea2a
organising utils into seperate file
htjb Nov 27, 2025
26d7204
an implementation of kmeans in jax
htjb Nov 27, 2025
e885e8f
jax implementation of the silhouette score
htjb Nov 27, 2025
2157315
making sure that the bounds function returns (min, max) and fixing ty…
htjb Nov 27, 2025
28f60e1
adding a generic cluster code to build piecewise estimators from any …
htjb Nov 27, 2025
545645b
better flagging of pycache and ds_store
htjb Nov 27, 2025
7198547
make base class and abstract class with required methods
htjb Nov 28, 2025
ac0b2a2
doesnt need to inherit from the base class
htjb Nov 28, 2025
3911fc7
dont need to init base class here
htjb Nov 28, 2025
752e22f
implementation of the NICE normalising flow
htjb Nov 28, 2025
c790283
bug fix in log prob calculation for kde
htjb Nov 28, 2025
f701de6
depends on flax
htjb Nov 28, 2025
89f1364
better variable name for hidden size
htjb Nov 28, 2025
a62c5f6
updating the log_prob_under_nice function name so it doesnt clash wit…
htjb Nov 28, 2025
916d38a
better reporting of the loss
htjb Nov 28, 2025
32cb9db
realNVP implementation
htjb Nov 28, 2025
e537711
correcting a bug in the bounds estimation
htjb Nov 28, 2025
d82bdbb
calcualte theta ranges seperately for each cluster so that they are w…
htjb Nov 28, 2025
53a9e3c
some doc string improvements
htjb Nov 28, 2025
356cfda
adding a more detailed discription of the origin of the approximate b…
htjb Nov 28, 2025
f44529a
new margarine strap line
htjb Nov 28, 2025
ab5f17e
bumping version number because im excited
htjb Nov 28, 2025
b5b8d6b
fiddling with the pyproject.toml
htjb Nov 28, 2025
c9af1c9
merging
htjb Nov 28, 2025
1475ab6
missing comma
htjb Nov 29, 2025
af61154
need tfp nightly for compatibility with latest jax
htjb Nov 29, 2025
2fe86ba
laying out new statistics code
htjb Nov 29, 2025
ded5d79
code to calcualte kl and bmd
htjb Nov 29, 2025
7a41ff4
jit the log prob functions
htjb Nov 29, 2025
5512bcf
removing old maf class so i can rewrite
htjb Nov 30, 2025
2814a4a
starting to lay out the base of the maf class
htjb Nov 30, 2025
4c16802
adding a to do list briefly so i can keep track
htjb Nov 30, 2025
be26f45
build the mades, masks and inverse and forward passes
htjb Dec 1, 2025
7f20c84
adding in activations after first layers
htjb Dec 1, 2025
547c375
pretty sure this is set up correctly
htjb Dec 1, 2025
7659e96
okay i think this maf is pretty well set up and optimized
htjb Dec 1, 2025
a660b1a
updating the todo
htjb Dec 1, 2025
c5c7b48
modifying the clustered_distribution example
htjb Dec 2, 2025
a6d67be
removing some old code
htjb Dec 2, 2025
39bf3e1
removing old test files
htjb Dec 2, 2025
38d1b51
importing the base from the correct file
htjb Dec 2, 2025
5640fcb
more stable kernel initialisation
htjb Dec 2, 2025
7ef9233
removing ndims in base class
htjb Dec 2, 2025
febb392
adding tests for kde and some of the util functions
htjb Dec 2, 2025
220385e
better name for utils test and formatting on tests
htjb Dec 2, 2025
f713279
tighter constraints on kl and bmd accuracy (I think something wrong w…
htjb Dec 2, 2025
b251d85
testing on a more straight forward problem
htjb Dec 2, 2025
69d6521
i think the latest versions of flax are only tested on 3.11 and above
htjb Dec 2, 2025
bc78dd9
trying to use stochasticity of training to set atol and rtol on kl an…
htjb Dec 2, 2025
a1512c3
playing with the tests. most pass now just trouble with maf training …
htjb Dec 3, 2025
75babff
type in instance check in loglike
htjb Dec 4, 2025
bc9b154
bug fixing the permutations in the maf but still needs some tinkering
htjb Dec 4, 2025
6f526e3
better initialisation of final layer in the mades and some tinkering
htjb Dec 5, 2025
9ce214a
removing maf stuff from v2 branch after splitting into maf branch
htjb Dec 5, 2025
6d24c89
removing maf tests from v2 branch
htjb Dec 5, 2025
c535866
testing the importance sampling functionality
htjb Dec 5, 2025
dafb13f
batched training for nice and realnvp
htjb Dec 5, 2025
8fc3cc7
rough test code for cluster class
htjb Dec 5, 2025
93f1e2c
setting batch size for better performance on toy problems
htjb Dec 9, 2025
9fb9e33
specifying max cluster size
htjb Dec 9, 2025
a009009
seperate theta bounds for each flow causes issues when evaluating log…
htjb Dec 9, 2025
9d3257e
playing with the flow settings but some inf is creeping in somewhere …
htjb Dec 9, 2025
e31a247
hmmm seems the inf is the true kl
htjb Dec 10, 2025
eb7f48c
problem was basically set up wrong so the true kl and bmd was wrong
htjb Dec 10, 2025
f4b5c03
removing the old tutorials notebook
htjb Dec 10, 2025
81edf5a
removing old example multimodal distribution code
htjb Dec 10, 2025
95ab95b
fiddling with training parameters and batch size
htjb Dec 10, 2025
bee8cb6
dealing with conflict with amster branch
htjb Dec 11, 2025
cb0db8d
jax friendly inverse and forward pass and jit of trainingset
htjb Dec 11, 2025
1fbf124
trying to get the cluster tests working
htjb Dec 15, 2025
f01ef84
better type hinting on base estimator
htjb Dec 16, 2025
4939ad8
allow users to set the number of clusters
htjb Dec 16, 2025
a2f9ebd
easier test case for clusters and a larger realnvp network
htjb Dec 16, 2025
408a17f
workign on save and load functions but something isnt quite right
htjb Dec 17, 2025
adff989
return an array from approximate bounds rather than a tuple
htjb Dec 17, 2025
06840d0
a working save and load function for the NICE model
htjb Dec 17, 2025
7b1a21b
surpress some orbax warnings and add save and load tests for nice
htjb Dec 17, 2025
3109b4c
addign a functioning save and load to the realnvp class
htjb Dec 17, 2025
4c56bd7
a test for saving and loading of realnvp
htjb Dec 17, 2025
58275d1
custome extensions on save fiels
htjb Dec 18, 2025
2b21364
more consistent save and load functions for kdes
htjb Dec 18, 2025
92b5088
test save and load of kdes and filepath fixes
htjb Dec 18, 2025
6bc66c4
saving and passign theta ranges
htjb Dec 18, 2025
ebb8c23
working version of save for all estimators and cluster save test
htjb Dec 18, 2025
5a563b1
bug fix in kde save function
htjb Dec 18, 2025
e7d4af0
writing bijective transformation for kde with some help from an llm
htjb Dec 18, 2025
4bf3767
jitting fucntions in __call__ for kde
htjb Dec 18, 2025
f40e8a2
removing the todo file
htjb Jan 5, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13"]
python-version: ["3.11", "3.12", "3.13"]

steps:
- uses: actions/checkout@v2
Expand Down
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
*/__pycache__
*/.DS_Store
**/__pycache__
**/.DS_Store
margarine.egg-info/
.pytest_cache/
.ruff_cache/
Expand Down
4 changes: 2 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
================================================================
margarine: Posterior Sampling and Marginal Bayesian Statistics
margarine: you won't believe it's not your posterior samples!
================================================================

Introduction
------------

:margarine: Marginal Bayesian Statistics
:Authors: Harry T.J. Bevins
:Version: 1.4.2
:Version: 2.0.0
:Homepage: https://github.com/htjb/margarine
:Documentation: https://margarine.readthedocs.io/

Expand Down
2 changes: 1 addition & 1 deletion margarine/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.4.2"
__version__ = "2.0.0"
102 changes: 102 additions & 0 deletions margarine/base/baseflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""Base density estimator for margarine package.

Defines a base class for density estimators with common interface methods
including:
- train
- sample
- __call__
- log_prob
- log_like
- save
- load
"""

from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp


class BaseDensityEstimator(ABC):
"""Base class for density estimators in the margarine package."""

@abstractmethod
def train(self) -> None:
"""Train the density estimator on the provided data."""
raise NotImplementedError("Train method must be implemented.")

@abstractmethod
def sample(self, key: jnp.ndarray, num_samples: int) -> jnp.ndarray:
"""Generate samples from the density estimator.

Args:
key: JAX random key for sampling.
num_samples: Number of samples to generate.

Returns:
jnp.ndarray: Generated samples as a JAX array.
"""
u = jax.random.uniform(key, shape=(num_samples, self.theta.shape[1]))
raise self(u)

def __call__(self, u: jnp.ndarray) -> jnp.ndarray:
"""Evaluate the density estimator at given points.

Args:
u: Samples from the unit hypercube.

Returns:
jnp.ndarray: samples from the density estimator.
"""
raise NotImplementedError("Call method must be implemented.")

@abstractmethod
def log_prob(self, x: jnp.ndarray) -> jnp.ndarray:
"""Compute the log-probability of given samples.

Args:
x: Samples for which to compute the log-probability.

Returns:
jnp.ndarray: Log-probabilities of the samples.
"""
raise NotImplementedError("log_prob method must be implemented.")

@abstractmethod
def log_like(
self,
x: jnp.ndarray,
logevidence: float,
prior_density: jnp.ndarray,
) -> jnp.ndarray:
"""Compute the log-likelihood of given samples.

Args:
x: Samples for which to compute the log-likelihood.
logevidence: Log-evidence value.
prior_density: Prior density estimator or densities.

Returns:
jnp.ndarray: Log-likelihoods of the samples.
"""
return NotImplementedError("log_like method must be implemented.")

def save(self, filepath: str) -> None:
"""Save the density estimator to a file.

Args:
filepath: Path to the file where the estimator will be saved.
"""
raise NotImplementedError("save method must be implemented.")

@classmethod
def load(cls, filepath: str) -> "BaseDensityEstimator":
"""Load a density estimator from a file.

Args:
filepath: Path to the file from which to load the estimator.

Returns:
BaseDensityEstimator: Loaded density estimator instance.
"""
raise NotImplementedError("load method must be implemented.")
Loading