Skip to content

Commit 07c6ab4

Browse files
Allow skipping covariance computation in find_MAP (#578)
* Allow skipping covariance computation in `find_MAP` * Set `compute_covariance` to False by default * Rename `compute_covariance` to `compute_hessian`
1 parent c7f9d5a commit 07c6ab4

File tree

4 files changed

+28
-12
lines changed

4 files changed

+28
-12
lines changed

pymc_extras/inference/laplace_approx/find_map.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def find_MAP(
198198
include_transformed: bool = True,
199199
gradient_backend: GradientBackend = "pytensor",
200200
compile_kwargs: dict | None = None,
201+
compute_hessian: bool = False,
201202
**optimizer_kwargs,
202203
) -> (
203204
dict[str, np.ndarray]
@@ -239,6 +240,10 @@ def find_MAP(
239240
Whether to include transformed variable values in the returned dictionary. Defaults to True.
240241
gradient_backend: str, default "pytensor"
241242
Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
243+
compute_hessian: bool
244+
If True, the inverse Hessian matrix at the optimum will be computed and included in the returned
245+
InferenceData object. This is needed for the Laplace approximation, but can be computationally expensive for
246+
high-dimensional problems. Defaults to False.
242247
compile_kwargs: dict, optional
243248
Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
244249
**optimizer_kwargs
@@ -316,14 +321,17 @@ def find_MAP(
316321
**optimizer_kwargs,
317322
)
318323

319-
H_inv = _compute_inverse_hessian(
320-
optimizer_result=optimizer_result,
321-
optimal_point=None,
322-
f_fused=f_fused,
323-
f_hessp=f_hessp,
324-
use_hess=use_hess,
325-
method=method,
326-
)
324+
if compute_hessian:
325+
H_inv = _compute_inverse_hessian(
326+
optimizer_result=optimizer_result,
327+
optimal_point=None,
328+
f_fused=f_fused,
329+
f_hessp=f_hessp,
330+
use_hess=use_hess,
331+
method=method,
332+
)
333+
else:
334+
H_inv = None
327335

328336
raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
329337
unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed=True)

pymc_extras/inference/laplace_approx/idata.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,10 @@ def map_results_to_inference_data(
136136

137137

138138
def add_fit_to_inference_data(
139-
idata: az.InferenceData, mu: RaveledVars, H_inv: np.ndarray, model: pm.Model | None = None
139+
idata: az.InferenceData,
140+
mu: RaveledVars,
141+
H_inv: np.ndarray | None,
142+
model: pm.Model | None = None,
140143
) -> az.InferenceData:
141144
"""
142145
Add the mean vector and covariance matrix of the Laplace approximation to an InferenceData object.
@@ -147,7 +150,7 @@ def add_fit_to_inference_data(
147150
An InferenceData object containing the approximated posterior samples.
148151
mu: RaveledVars
149152
The MAP estimate of the model parameters.
150-
H_inv: np.ndarray
153+
H_inv: np.ndarray, optional
151154
The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
152155
model: Model, optional
153156
A PyMC model. If None, the model is taken from the current model context.

pymc_extras/inference/laplace_approx/laplace.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,7 @@ def fit_laplace(
389389
include_transformed=include_transformed,
390390
gradient_backend=gradient_backend,
391391
compile_kwargs=compile_kwargs,
392+
compute_hessian=True,
392393
**optimizer_kwargs,
393394
)
394395

tests/inference/laplace_approx/test_find_map.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ def compute_z(x):
133133
],
134134
)
135135
@pytest.mark.parametrize(
136-
"backend, gradient_backend, include_transformed",
137-
[("jax", "jax", True), ("jax", "pytensor", False)],
136+
"backend, gradient_backend, include_transformed, compute_hessian",
137+
[("jax", "jax", True, True), ("jax", "pytensor", False, False)],
138138
ids=str,
139139
)
140140
def test_find_MAP(
@@ -145,6 +145,7 @@ def test_find_MAP(
145145
backend,
146146
gradient_backend: GradientBackend,
147147
include_transformed,
148+
compute_hessian,
148149
rng,
149150
):
150151
pytest.importorskip("jax")
@@ -164,6 +165,7 @@ def test_find_MAP(
164165
include_transformed=include_transformed,
165166
compile_kwargs={"mode": backend.upper()},
166167
maxiter=5,
168+
compute_hessian=compute_hessian,
167169
)
168170

169171
assert hasattr(idata, "posterior")
@@ -184,6 +186,8 @@ def test_find_MAP(
184186
else:
185187
assert not hasattr(idata, "unconstrained_posterior")
186188

189+
assert ("covariance_matrix" in idata.fit) == compute_hessian
190+
187191

188192
def test_find_map_outside_model_context():
189193
"""

0 commit comments

Comments
 (0)