From e5326bad781badb6193125f017c2c3e540f31641 Mon Sep 17 00:00:00 2001 From: Nick Thompson Date: Thu, 8 Feb 2024 10:51:10 -0800 Subject: [PATCH] [CI SKIP][ci skip] Multivariate normal distribution --- .../multivariate_normal_distribution.hpp | 117 ++++++++++++++ test/multivariate_normal_test.cpp | 145 ++++++++++++++++++ 2 files changed, 262 insertions(+) create mode 100644 include/boost/random/multivariate_normal_distribution.hpp create mode 100644 test/multivariate_normal_test.cpp diff --git a/include/boost/random/multivariate_normal_distribution.hpp b/include/boost/random/multivariate_normal_distribution.hpp new file mode 100644 index 000000000..477b1c76f --- /dev/null +++ b/include/boost/random/multivariate_normal_distribution.hpp @@ -0,0 +1,117 @@ +/* + * Copyright Nick Thompson, 2024 + * Use, modification and distribution are subject to the + * Boost Software License, Version 1.0. (See accompanying file + * LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) + */ +#ifndef BOOST_RANDOM_MULTIVARIATE_NORMAL_HPP +#define BOOST_RANDOM_MULTIVARIATE_NORMAL_HPP +#include +#include +#include +#include +#include +#include +#if __has_include() +#include +#else +#error "The Eigen library is required for the operation of this class" +#endif + +namespace boost::random { + +// This is super useful functionality, but nonetheless it must be shunted off into a dark corner of the library +// because even today there is no standard matrix class and no standard way to do a Cholesky decomposition. +// Hence a more public place in the library just puts users in dependency hell. 
+/**
+ * Distribution object that draws random vectors from the multivariate normal
+ * N(mean, covariance).
+ *
+ * The covariance matrix is factorized once, in the constructor, with Eigen's pivoted
+ * LDL^T decomposition, which Eigen documents as covariance = P^T L D L^T P.
+ * A sample is then produced as
+ *
+ *     x = mean + P^T L D^{1/2} z,
+ *
+ * where z holds independent standard normal variates, so that
+ * Cov(x) = P^T L D L^T P = covariance.
+ *
+ * RandomAccessContainer must provide value_type, size() and operator[]. If it also
+ * provides resize() (as reported by detail::has_resize_v) the value-returning
+ * operator() sizes its result to match the mean.
+ *
+ * NOTE(review): every template parameter/argument list in this patch was stripped from
+ * the text this file was recovered from; the lists below (RandomAccessContainer, URNG,
+ * the dynamic Eigen matrix/vector types) are reconstructed from usage -- confirm
+ * against the original patch.
+ */
+template <class RandomAccessContainer>
+class multivariate_normal_distribution {
+public:
+    using Real = typename RandomAccessContainer::value_type;
+
+    /**
+     * Stores a copy of the mean and factorizes the covariance matrix.
+     *
+     * @param mean              Mean vector of the distribution.
+     * @param covariance_matrix Square matrix with as many columns as mean has entries.
+     * @throws std::domain_error if the matrix is not square, if its dimension differs
+     *         from mean.size(), if Eigen reports that the factorization failed, or if
+     *         the factorization shows the matrix is not positive semi-definite.
+     */
+    multivariate_normal_distribution(RandomAccessContainer const & mean,
+                                     Eigen::Matrix<Real, Eigen::Dynamic, Eigen::Dynamic> const & covariance_matrix) : m_{mean} {
+        if (covariance_matrix.rows() != covariance_matrix.cols()) {
+            std::ostringstream oss;
+            oss << __FILE__ << ":" << __LINE__ << ":" << __func__;
+            oss << ": The covariance matrix must be square, but received a (" << covariance_matrix.rows() << ", " << covariance_matrix.cols() << ") matrix.";
+            throw std::domain_error(oss.str());
+        }
+        // Eigen indices are signed; convert explicitly rather than comparing signed with unsigned.
+        if (static_cast<Eigen::Index>(mean.size()) != covariance_matrix.cols()) {
+            std::ostringstream oss;
+            oss << __FILE__ << ":" << __LINE__ << ":" << __func__;
+            oss << ": The mean has dimension " << mean.size() << " but the covariance matrix has " << covariance_matrix.cols() << " columns.";
+            throw std::domain_error(oss.str());
+        }
+        ldlt_.compute(covariance_matrix);
+        if (ldlt_.info() != Eigen::Success) {
+            std::ostringstream oss;
+            oss << __FILE__ << ":" << __LINE__ << ":" << __func__;
+            switch (ldlt_.info()) {
+            case Eigen::NumericalIssue:
+                oss << ": Eigen reported a numerical issue while computing the LDL^T decomposition of the covariance matrix.";
+                break;
+            case Eigen::InvalidInput:
+                oss << ": Invalid input detected from Eigen.";
+                break;
+            case Eigen::NoConvergence:
+                oss << ": Iterative procedure did not converge.";
+                break;
+            default:
+                // Never continue with a factorization Eigen did not report as successful.
+                oss << ": The LDL^T decomposition of the covariance matrix failed.";
+                break;
+            }
+            throw std::domain_error(oss.str());
+        }
+        if (!ldlt_.isPositive()) {
+            std::ostringstream oss;
+            oss << __FILE__ << ":" << __LINE__ << ":" << __func__;
+            oss << ": The covariance matrix provided is not positive semi-definite.";
+            throw std::domain_error(oss.str());
+        }
+    }
+
+    /**
+     * Draws one sample and returns it by value.
+     *
+     * @param g Uniform random bit generator; advanced by the draw.
+     * @return A container holding the sample.
+     */
+    template <class URNG>
+    RandomAccessContainer operator()(URNG& g) const {
+        RandomAccessContainer x;
+        // Fixed-size containers have no resize(); resizable ones start out empty.
+        if constexpr (detail::has_resize_v<RandomAccessContainer>) {
+            x.resize(m_.size());
+        }
+        (*this)(x, g);
+        return x;
+    }
+
+    /**
+     * Draws one sample into a caller-provided container.
+     *
+     * @param x Output container; must already have the same size as the mean.
+     * @param g Uniform random bit generator; advanced by the draw.
+     * @throws std::domain_error if x.size() differs from the size of the mean.
+     */
+    template <class URNG>
+    void operator()(RandomAccessContainer& x, URNG& g) const {
+        // Bring std::sqrt into scope here, where it is used, so the overload matching Real is chosen.
+        using std::sqrt;
+        if (x.size() != m_.size()) {
+            std::ostringstream oss;
+            oss << __FILE__ << ":" << __LINE__ << ":" << __func__;
+            oss << ": Must provide a vector of the same length as the mean.";
+            throw std::domain_error(oss.str());
+        }
+
+        Eigen::Index const n = static_cast<Eigen::Index>(m_.size());
+        std::normal_distribution<Real> dis(0, 1);
+        auto const & D = ldlt_.vectorD();
+        // u = D^{1/2} z, with z a standard normal random vector drawn entry by entry:
+        Eigen::Vector<Real, Eigen::Dynamic> u(n);
+        for (Eigen::Index i = 0; i < n; ++i) {
+            u[i] = sqrt(D[i])*dis(g);
+        }
+        // Now apply L:
+        u = ldlt_.matrixL()*u;
+        // Undo the pivoting. Because covariance = P^T L D L^T P, the factor that maps
+        // standard normals onto the target covariance is P^T L D^{1/2}, i.e. the
+        // transpose (inverse) of the stored transpositions, not P itself.
+        u = ldlt_.transpositionsP().transpose()*u;
+        for (size_t i = 0; i < x.size(); ++i) {
+            x[i] = u[static_cast<Eigen::Index>(i)] + m_[i];
+        }
+    }
+
+private:
+    // Mean vector, copied at construction.
+    RandomAccessContainer m_;
+    // Pivoted LDL^T factorization of the covariance matrix.
+    Eigen::LDLT<Eigen::Matrix<Real, Eigen::Dynamic, Eigen::Dynamic>> ldlt_;
+};
+
+} // namespace boost::random
+#endif
diff --git a/test/multivariate_normal_test.cpp b/test/multivariate_normal_test.cpp
new file mode 100644
index 000000000..999b54996
--- /dev/null
+++ b/test/multivariate_normal_test.cpp
@@ -0,0 
+1,145 @@
+/*
+ * Copyright Nick Thompson, 2024
+ * Use, modification and distribution are subject to the
+ * Boost Software License, Version 1.0. (See accompanying file
+ * LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
+ */
+
+#include "math_unit_test.hpp"
+// NOTE(review): the six header names that followed were stripped from the text this
+// patch was recovered from. The list below is reconstructed from the identifiers this
+// file uses -- confirm against the original patch.
+#include <algorithm>
+#include <array>
+#include <cmath>
+#include <random>
+#include <vector>
+#include <boost/random/multivariate_normal_distribution.hpp>
+
+using std::abs;
+using boost::random::multivariate_normal_distribution;
+
+/**
+ * Exercises multivariate_normal_distribution for one scalar type:
+ *   1. with an identity covariance the empirical mean approaches the requested mean;
+ *   2. with a zero covariance every sample equals the mean exactly;
+ *   3. with a single unit entry on the diagonal only that component may vary,
+ *      wherever it sits (this exercises the pivoting of the factorization);
+ *   4. scaling the covariance by 16 scales the samples drawn from an identically
+ *      seeded generator by 4.
+ *
+ * NOTE(review): the template parameter/argument lists in this file were stripped from
+ * the recovered text; `Real`, the container types and the dynamic matrix type are
+ * reconstructed from usage -- confirm against the original patch.
+ */
+template <class Real> void test_multivariate_normal() {
+    using Eigen::Dynamic;
+    using Eigen::Matrix;
+    using MatrixType = Matrix<Real, Dynamic, Dynamic>;
+    constexpr const size_t n = 7;
+
+    MatrixType C = MatrixType::Identity(n, n);
+    std::mt19937_64 mt(12345);
+    std::array<Real, n> mean;
+    std::uniform_real_distribution<Real> dis(-1, 1);
+    std::generate(mean.begin(), mean.end(), [&]() { return dis(mt); });
+    auto mvn = multivariate_normal_distribution(mean, C);
+    std::array<Real, n> x;
+    std::array<Real, n> empirical_means;
+    empirical_means.fill(0);
+
+    // Accumulate exactly `samples` draws, so that the division below yields the mean.
+    size_t samples = 2048;
+    for (size_t i = 0; i < samples; ++i) {
+        x = mvn(mt);
+        for (size_t j = 0; j < n; ++j) {
+            empirical_means[j] += x[j];
+        }
+    }
+
+    for (size_t j = 0; j < n; ++j) {
+        empirical_means[j] /= samples;
+        CHECK_ABSOLUTE_ERROR(mean[j], empirical_means[j], 0.05);
+    }
+
+    // Draws `draws` samples from the current mvn and requires every component other
+    // than `free_index` to equal the mean exactly. Pass n to check all components.
+    auto expect_mean_except = [&](size_t free_index, size_t draws) {
+        for (size_t k = 0; k < draws; ++k) {
+            x = mvn(mt);
+            for (size_t j = 0; j < n; ++j) {
+                if (j != free_index) {
+                    CHECK_EQUAL(mean[j], x[j]);
+                }
+            }
+        }
+    };
+
+    // Exhibits why we need to use the LDL^T decomposition:
+    C = MatrixType::Zero(n, n);
+    mvn = multivariate_normal_distribution(mean, C);
+    expect_mean_except(n, 10);
+
+    // Test that we're applying the permutation matrix correctly:
+    C = MatrixType::Zero(n, n);
+    C(0,0) = 1;
+    mvn = multivariate_normal_distribution(mean, C);
+    expect_mean_except(0, 3);
+
+    // All but the last entry must be identical to the mean:
+    C(0,0) = 0;
+    C(n-1,n-1) = 1;
+    mvn = multivariate_normal_distribution(mean, C);
+    expect_mean_except(n-1, 3);
+
+    C(0,0) = 0;
+    C(1,1) = 1;
+    C(n-1,n-1) = 0;
+    mvn = multivariate_normal_distribution(mean, C);
+    expect_mean_except(1, 10);
+
+    C(1,1) = 0;
+    C(n-2,n-2) = 1;
+    mvn = multivariate_normal_distribution(mean, C);
+    expect_mean_except(n-2, 3);
+
+    // Scaling test: If C->kC for some constant k, then A->sqrt(k)A.
+    // First we build a random positive semidefinite matrix:
+    MatrixType C1 = MatrixType::Random(n, n);
+    C = C1.transpose()*C1;
+    // Set the mean to 0:
+    for (auto & m : mean) {
+        m = 0;
+    }
+    samples = 1;
+    std::vector<std::array<Real, n>> x1(samples);
+    mt.seed(12859);
+    mvn = multivariate_normal_distribution(mean, C);
+    for (size_t i = 0; i < x1.size(); ++i) {
+        x1[i] = mvn(mt);
+    }
+    // Now scale C:
+    C *= 16;
+    // Set the seed back to the original:
+    mt.seed(12859);
+    std::vector<std::array<Real, n>> x2(samples);
+    mvn = multivariate_normal_distribution(mean, C);
+    for (size_t i = 0; i < x2.size(); ++i) {
+        x2[i] = mvn(mt);
+    }
+    // Now x2 = 4*x1 is expected:
+    for (size_t i = 0; i < x1.size(); ++i) {
+        for (size_t j = 0; j < n; ++j) {
+            CHECK_ULP_CLOSE(4*x1[i][j], x2[i][j], 2);
+        }
+    }
+
+}
+
+// Runs the test for two scalar types and reports the accumulated failures.
+// NOTE(review): the explicit template arguments of the two calls were stripped from
+// the recovered text; float and double are assumed -- confirm against the original patch.
+int main() {
+    test_multivariate_normal<float>();
+    test_multivariate_normal<double>();
+    return boost::math::test::report_errors();
+}