MIST: Mutual Information via Supervised Training
MIST is a framework for fully data-driven mutual information (MI) estimation. It leverages neural networks trained on large meta-datasets of distributions to learn flexible, differentiable MI estimators that generalize across sample sizes, dimensions, and modalities. The framework supports uncertainty quantification via quantile regression and provides fast, well-calibrated inference suitable for integration into modern ML pipelines.
This repository contains the reference implementation for the preprint "MIST: Mutual Information via Supervised Training". It includes scripts to reproduce our experiments, as well as tools for training and evaluating MIST-style MI estimators.
Install with pip

```bash
pip install mist-statinf
```

Install with conda

```bash
conda env create -f environment.yml
conda activate mist-statinf
```

Install from source

Alternatively, you can clone the latest version of the repository and install it directly from the source code:

```bash
pip install -e .
```
[2025/12] 🚀 MIST and MIST-QR are now available on 🤗 Hugging Face! For a quick and easy way to integrate our models into your projects, check out the usage guides on the model pages.
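As an illustrative sketch of loading from the Hub, a checkpoint could be fetched and passed to the built-in wrapper as below. The repo id and file name are placeholders, not the actual Hugging Face identifiers; see the model pages for those.

```python
from huggingface_hub import hf_hub_download

from mist_statinf import MISTQuickEstimator

# Placeholder repo id and file name; see the Hugging Face model pages
# for the actual identifiers.
ckpt = hf_hub_download(repo_id="<org>/mist", filename="weights.ckpt")
mist = MISTQuickEstimator(loss="mse", checkpoint=ckpt)
```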
Alternatively, you can use our built-in inference wrapper. If you want to evaluate MI or obtain confidence intervals on your own data using the MIST or MIST-QR models described in the paper, use the MISTQuickEstimator.
```python
from mist_statinf import MISTQuickEstimator

X, Y = <your data>

mist = MISTQuickEstimator(
    loss="mse",
    checkpoint="checkpoints/mist/weights.ckpt",
)

mi = mist.estimate_point(X, Y)
print("MIST estimate:", mi)
```

```python
from mist_statinf import MISTQuickEstimator

X, Y = <your data>

mist_qr = MISTQuickEstimator(
    loss="qr",
    checkpoint="checkpoints/mist_qr/weights.ckpt",
)

mi_median = mist_qr.estimate_point(X, Y)
print("Median MI:", mi_median)

mi_q90 = mist_qr.estimate_point(X, Y, tau=0.90)
print("q90 MI estimate:", mi_q90)

# --- fast quantile-based uncertainty interval ---
interval = mist_qr.estimate_interval_qr(X, Y, lower=0.05, upper=0.95)
print(interval)
```

By default, MISTQuickEstimator loads the pretrained models used in the paper from the package's checkpoints/ directory, using the architecture defined in configs/inference/quickstart.yaml.
You can override both the checkpoint and the architecture if you have your own trained models.
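As a quick sanity check, you can run the estimator on synthetic data with a known ground truth. The sketch below assumes MISTQuickEstimator accepts arrays of shape (n_samples, dim); the checkpoint path matches the quickstart above, and the closed-form MI of a correlated Gaussian pair serves as the reference value.

```python
import numpy as np

from mist_statinf import MISTQuickEstimator

# A bivariate Gaussian with correlation rho has closed-form MI:
# I(X; Y) = -0.5 * log(1 - rho^2)  (in nats)
rng = np.random.default_rng(0)
rho, n = 0.8, 5_000
X = rng.standard_normal((n, 1))
Y = rho * X + np.sqrt(1 - rho**2) * rng.standard_normal((n, 1))
true_mi = -0.5 * np.log(1 - rho**2)  # ~0.51 nats

mist = MISTQuickEstimator(loss="mse", checkpoint="checkpoints/mist/weights.ckpt")
print(f"ground truth: {true_mi:.3f}, MIST estimate: {mist.estimate_point(X, Y):.3f}")
```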
If you want to reproduce the experiments from the paper, we recommend evaluating our trained estimators on the provided test sets (M_test and M_test_extended).
Since the test sets take a considerable amount of storage space, we publish them separately on Zenodo.
Before running inference, download the desired subset (either M_test or M_test_extended).
Below we show an example using M_test, as it is significantly smaller.
```bash
mist-statinf get-data --preset m_test_imd --dir data/test_imd_data
mist-statinf get-data --preset m_test_oomd --dir data/test_oomd_data
```

The simplest way to run inference on these datasets is:

```bash
mist-statinf infer configs/inference/mist_inference.yaml "checkpoints/mist/"
```

NOTE: The file mist_inference.yaml allows you to configure the evaluation mode (bootstrap or QCQR calibration), select the specific test subset, and specify which quantiles to compute.
Below we show the results we obtained on M_test:
If you want to reproduce the full training pipeline from the paper — possibly with your own modifications — we recommend following the workflow below.
```bash
mist-statinf generate configs/data_generation/train.yaml  # the same for test and val
```

The generated datasets and their corresponding configuration files will appear under data/train_data, etc.
```bash
mist-statinf train configs/train/mist_train.yaml
```

Inside the training config you can switch between MSE training and QCQR training.
After training, logs, configs, and the saved model checkpoint will be stored under: logs/mist_train/run_YYYYmmdd-HHMMSS
```bash
mist-statinf baselines configs/inference/baselines.yaml
```

Baseline results, logs, and configs will be saved to: logs/bmi_baselines.
```bash
mist-statinf infer configs/inference/mist_inference.yaml "logs/mist_train/run_YYYYmmdd-HHMMSS"
```

This will produce CSV predictions and a JSON summary in the same run directory: logs/mist_train/run_YYYYmmdd-HHMMSS.
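If you want to post-process these outputs in Python, a minimal sketch follows; the file names predictions.csv and summary.json are assumptions for illustration, so check the run directory for the actual names.

```python
import json

import pandas as pd

# Hypothetical output names; inspect the run directory for the real ones.
run_dir = "logs/mist_train/run_YYYYmmdd-HHMMSS"
preds = pd.read_csv(f"{run_dir}/predictions.csv")  # assumed file name
with open(f"{run_dir}/summary.json") as f:         # assumed file name
    summary = json.load(f)

print(preds.head())
print(summary)
```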
```bash
mist-statinf tune logs/mist_train/run_YYYYmmdd-HHMMSS --model-type MSE --n-trials 30
```

This performs a hyperparameter search (via Optuna) starting from a given training run.
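For intuition, here is a minimal, self-contained sketch of the kind of Optuna loop such a search runs; the parameter names and the dummy objective are hypothetical stand-ins, not the actual tuning schema.

```python
import optuna

def objective(trial: optuna.Trial) -> float:
    # Hypothetical search space; the real one is defined by the tune command.
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    dropout = trial.suggest_float("dropout", 0.0, 0.5)
    # Stand-in for retraining and evaluating the estimator with these settings.
    return (lr - 1e-3) ** 2 + dropout

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=30)
print("Best parameters:", study.best_params)
```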
If you use MIST or MIST-QR in your work, please cite:
```bibtex
@misc{gritsai2025mistmutualinformationsupervised,
      title={MIST: Mutual Information Via Supervised Training},
      author={German Gritsai and Megan Richards and Maxime Méloux and Kyunghyun Cho and Maxime Peyrard},
      year={2025},
      eprint={2511.18945},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2511.18945},
}
```

German Gritsai, Megan Richards, Maxime Méloux, Kyunghyun Cho, Maxime Peyrard.

