Skip to content

Commit 865e3da

Browse files
author
kkacanja
committed
Marginalized total mass model. The model wraps a base model and brute-force marginalizes over total mass by rescaling a single reference waveform to each total mass on a grid.
1 parent 6abd01d commit 865e3da

3 files changed

Lines changed: 225 additions & 7 deletions

File tree

pycbc/inference/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from .marginalized_gaussian_noise import MarginalizedPolarization
3232
from .marginalized_gaussian_noise import MarginalizedHMPolPhase
3333
from .marginalized_gaussian_noise import MarginalizedTime
34+
from .brute_marg import BruteTotalMassMarginalize
3435
from .brute_marg import BruteParallelGaussianMarginalize
3536
from .brute_marg import BruteLISASkyModesMarginalize
3637
from .gated_gaussian_noise import (GatedGaussianNoise, GatedGaussianMargPol)
@@ -198,6 +199,7 @@ def read_from_config(cp, **kwargs):
198199
MarginalizedPolarization,
199200
MarginalizedHMPolPhase,
200201
MarginalizedTime,
202+
BruteTotalMassMarginalize,
201203
BruteParallelGaussianMarginalize,
202204
BruteLISASkyModesMarginalize,
203205
GatedGaussianNoise,

pycbc/inference/models/brute_marg.py

Lines changed: 214 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (C) 2020 Alex Nitz
1+
# Copyright (C) 2020 Alex Nitz
22
# This program is free software; you can redistribute it and/or modify it
33
# under the terms of the GNU General Public License as published by the
44
# Free Software Foundation; either version 3 of the License, or (at your
@@ -24,7 +24,14 @@
2424
from scipy.special import logsumexp
2525

2626
from .gaussian_noise import BaseGaussianNoise
27-
from .tools import draw_sample
27+
from .tools import draw_sample, marginalize_likelihood, DistMarg
28+
29+
30+
from pycbc.waveform import td_taper
31+
from pycbc.waveform import get_waveform_filter_length_in_time
32+
from pycbc.conversions import mass1_from_mtotal_q, mass2_from_mtotal_q
33+
from pycbc.types import TimeSeries, FrequencySeries
34+
from pycbc.distributions import JointDistribution
2835

2936
_model = None
3037
class likelihood_wrapper(object):
@@ -34,10 +41,215 @@ def __init__(self, model):
3441

3542
def __call__(self, params):
    """Update the shared module-level model and evaluate it.

    Parameters
    ----------
    params : dict
        Keyword arguments forwarded to the ``update`` method of the
        module-level ``_model``.

    Returns
    -------
    tuple
        ``(loglr, current_stats)`` as read from ``_model`` after the update.
    """
    shared_model = _model
    shared_model.update(**params)

    log_ratio = shared_model.loglr
    # NOTE(review): the values below are not part of the return value. The
    # attribute reads are kept, in the original order, because they may be
    # properties that cache entries later reported by ``current_stats`` --
    # confirm against the base model before removing them.
    noise_loglike = shared_model.lognl
    total_loglike = noise_loglike + log_ratio
    log_jacobian = shared_model.logjacobian
    log_prior = shared_model.logprior

    return log_ratio, shared_model.current_stats
4054

55+
56+
class BruteTotalMassMarginalize(BaseGaussianNoise, DistMarg):
    """Brute-force marginalization over total mass on top of a base model.

    For every likelihood evaluation a single reference waveform is generated
    at ``fiducial_mtotal``, rescaled to each total mass in ``mtotal_grid``,
    and handed to the base model through the ``custom_waveform`` parameter.
    The returned log likelihood ratio is the log of the mean likelihood
    ratio over the grid.

    Parameters
    ----------
    variable_params : tuple of str
        Names of the sampled parameters. Parameters that are marginalized
        here are removed before the parent class is initialized.
    cores : int or str, optional
        Number of worker processes used to evaluate the grid. Values
        greater than 1 create a ``multiprocessing.Pool``. Default 1.
    base_model : str
        Key of the wrapped model in ``pycbc.inference.models.models``.
    marginalize_mtotal : optional
        If truthy, ``mtotal`` is marginalized over ``mtotal_grid``.
    mtotal_grid : sequence of float or str, optional
        Explicit grid of total masses. A string is split on commas and
        whitespace. If not given, a linear grid of ``mtotal_grid_num``
        points spanning the ``mtotal`` prior bounds is used.
    mtotal_grid_num : int or str, optional
        Number of grid points when ``mtotal_grid`` is not given.
    fiducial_mtotal : float or str, optional
        Total mass at which the reference waveform is generated. Defaults
        to the midpoint of the ``mtotal`` prior bounds when available.
    **kwds
        Forwarded both to the base model and to the parent class. The
        ``prior`` entry is replaced by the reduced prior before the parent
        class is initialized.

    Raises
    ------
    ValueError
        If ``marginalize_mtotal`` is set and no usable grid or fiducial
        total mass can be determined.
    """
    name = "marginalized_mtotal"

    def __init__(self, variable_params,
                 cores=1,
                 base_model=None,
                 marginalize_mtotal=None,
                 mtotal_grid=None,
                 mtotal_grid_num=None,
                 fiducial_mtotal=None,
                 **kwds):
        from pycbc.inference.models import models

        # Additional parameters to drop from the sampled prior; accepted as
        # a comma-separated string or as a list/tuple of names.
        self.marginalize_vector_params = []
        vector_params = kwds.get('marginalize_vector_params')
        if isinstance(vector_params, str):
            vector_params = vector_params.split(',')
        if isinstance(vector_params, (list, tuple)):
            self.marginalize_vector_params = [
                p.strip() for p in vector_params if p.strip()]

        # --- Save the original mtotal prior distribution and bounds before
        # the prior is rebuilt without the marginalized parameters ---
        self.mtotal_prior_dist = None
        self.mtotal_bounds = None
        prior = kwds.get('prior', None)
        if prior and 'mtotal' in prior.bounds:
            for dist in prior.distributions:
                # Match on the distribution's parameter list (as done when
                # the prior is rebuilt below), keeping the original
                # name-based match as a fallback.
                dist_params = getattr(dist, 'params', None) or ()
                if 'mtotal' in dist_params or dist.name == 'mtotal':
                    self.mtotal_prior_dist = dist
                    break
            self.mtotal_bounds = (prior.bounds['mtotal'].min,
                                  prior.bounds['mtotal'].max)

        # --- Setup the mtotal grid using the saved bounds ---
        # Always define the attribute so ``_loglr`` can test it even when
        # mtotal is not being marginalized.
        self.mtotal_grid = None
        if marginalize_mtotal:
            if self.mtotal_bounds is not None:
                min_mtotal, max_mtotal = self.mtotal_bounds
            else:
                min_mtotal = max_mtotal = None
            have_bounds = min_mtotal is not None and max_mtotal is not None

            if fiducial_mtotal is None and have_bounds:
                fiducial_mtotal = 0.5 * (min_mtotal + max_mtotal)
            else:
                fiducial_mtotal = (float(fiducial_mtotal)
                                   if fiducial_mtotal else None)

            if isinstance(mtotal_grid, str):
                # e.g. a value read from a configuration file
                mtotal_grid = mtotal_grid.replace(',', ' ').split()

            if mtotal_grid is None and have_bounds:
                if mtotal_grid_num is None:
                    raise ValueError("Must specify mtotal_grid_num if mtotal_grid is not given")
                self.mtotal_grid = numpy.linspace(min_mtotal, max_mtotal,
                                                  int(mtotal_grid_num))
            elif mtotal_grid is not None:
                self.mtotal_grid = numpy.array(
                    [float(m) for m in mtotal_grid])
            else:
                raise ValueError("Either provide 'mtotal_grid' (a list or array) or specify 'mtotal_grid_num' (the number of points to create a grid over the prior).")

            # Fail here rather than deep inside the first likelihood call.
            if self.mtotal_grid.size == 0:
                raise ValueError("The total mass grid is empty")
            if fiducial_mtotal is None:
                raise ValueError("Must specify fiducial_mtotal when it "
                                 "cannot be derived from the mtotal prior "
                                 "bounds")

        base_model_cls = models[base_model]
        self.model = base_model_cls(variable_params=variable_params, **kwds)
        self.call = likelihood_wrapper(self.model)

        marginalized_params = []
        if marginalize_mtotal:
            marginalized_params.append('mtotal')
        if kwds.get('marginalize_phase'):
            marginalized_params.append('coa_phase')
        if kwds.get('marginalize_distance'):
            marginalized_params.append(
                kwds.get('marginalize_distance_param', 'distance'))
        marginalized_params.extend(self.marginalize_vector_params)

        # ------ Reconstruct the marginalized prior ------
        variable_params = tuple(p for p in variable_params
                                if p not in marginalized_params)

        prior_dists_by_param = {}
        for d in self.model.prior_distribution.distributions:
            if (hasattr(d, 'params')
                    and isinstance(d.params, (list, tuple))
                    and len(d.params) > 0):
                for p_name in d.params:
                    if isinstance(p_name, str):
                        prior_dists_by_param[p_name] = d
            elif isinstance(d.name, str) and d.name in variable_params:
                prior_dists_by_param[d.name] = d

        # Keep the distributions of the remaining parameters. A distribution
        # covering several parameters appears once per parameter in the
        # mapping, so de-duplicate by identity to avoid handing the same
        # distribution to JointDistribution more than once.
        retained_dists = []
        seen_ids = set()
        for p_name, d in prior_dists_by_param.items():
            if p_name in marginalized_params or id(d) in seen_ids:
                continue
            seen_ids.add(id(d))
            retained_dists.append(d)

        self.model.prior_distribution = JointDistribution(variable_params,
                                                          *retained_dists)
        kwds['prior'] = self.model.prior_distribution

        super().__init__(variable_params, **kwds)

        # Set up multiprocessing. ``cores`` may arrive as a string, so
        # convert it before comparing.
        cores = int(cores)
        if cores > 1:
            self.pool = Pool(cores)
            self.mapfunc = self.pool.map
        else:
            self.pool = None
            self.mapfunc = map

        self.marginalize_mtotal = marginalize_mtotal
        self.fiducial_mtotal = fiducial_mtotal

    @property
    def _extra_stats(self):
        """List of the base model's extra stats plus ``maxl_mtotal`` and
        ``maxl_loglr``.

        A copy is returned so the list owned by the base model is never
        modified, and each added name appears only once.
        """
        stats = list(self.model._extra_stats)
        for stat in ('maxl_mtotal', 'maxl_loglr'):
            if stat not in stats:
                stats.append(stat)
        return stats

    def scale_waveform(self, h_plus_ref, h_cross_ref, mtotal_ref, mtotal_new,
                       q, f_lower, approximant):
        """Rescale the amplitude of a reference waveform to a new total mass.

        Parameters
        ----------
        h_plus_ref, h_cross_ref
            Plus and cross polarizations generated at ``mtotal_ref``.
        mtotal_ref : float
            Total mass of the reference waveform.
        mtotal_new : float
            Total mass to rescale to.
        q, f_lower, approximant
            Accepted for interface compatibility; not used.

        Returns
        -------
        tuple
            ``(h_plus, h_cross)``, each multiplied by
            ``mtotal_new / mtotal_ref``.
        """
        # NOTE(review): only the amplitude is rescaled; the sample points
        # (time/frequency axis) of the inputs are left untouched. Confirm
        # that this is the intended scaling.
        amp_scale = mtotal_new / mtotal_ref
        return h_plus_ref * amp_scale, h_cross_ref * amp_scale

    def _loglr(self):
        """Log likelihood ratio marginalized over the total mass grid.

        Returns
        -------
        float
            ``logsumexp`` of the per-grid-point ``loglr`` values minus
            ``log(len(mtotal_grid))``. If the first grid point has
            ``loglr < 0`` that value is returned without evaluating the
            rest of the grid. Without a grid, the base model's ``loglr`` at
            the current parameters is returned.
        """
        import logging

        if self.mtotal_grid is None:
            # No marginalization over mtotal: evaluate the base model at the
            # current parameters instead of returning whatever it last held.
            self.model.update(**self.current_params)
            return self.model.loglr

        # Parameters for the reference waveform at the fiducial total mass.
        ref_params = self.current_params.copy()
        for key in ['ra', 'dec', 'tc', 'polarization', 'distance', 'mtotal']:
            ref_params.pop(key, None)
        ref_params["mtotal"] = self.fiducial_mtotal
        ref_params["mass1"] = mass1_from_mtotal_q(ref_params["mtotal"],
                                                  ref_params["q"])
        ref_params["mass2"] = mass2_from_mtotal_q(ref_params["mtotal"],
                                                  ref_params["q"])

        wfs = self.model.waveform_generator.generate(**ref_params)
        self.reference_waveform = wfs

        approximant = ref_params['approximant']
        f_lower = ref_params['f_lower']
        q = self.current_params["q"]

        # One parameter set per grid point, each carrying its own rescaled
        # waveform for the base model to use.
        params = []
        self.scaled_waveforms = {}
        for mtotal in self.mtotal_grid:
            scaled_waveform = {}
            for det, (hp, hc) in wfs.items():
                scaled_waveform[det] = self.scale_waveform(
                    hp, hc, self.fiducial_mtotal, mtotal, q, f_lower,
                    approximant=approximant
                )
            self.scaled_waveforms[mtotal] = scaled_waveform

            pmod = ref_params.copy()
            pmod.pop("mass1", None)
            pmod.pop("mass2", None)
            pmod.pop("mtotal", None)
            pmod["custom_waveform"] = scaled_waveform
            pmod["rescale_mtotal"] = mtotal
            params.append(pmod)

        # Check the first loglr value to decide whether to evaluate the
        # rest of the grid (skipped when loglr < 0).
        first_val = self.call(params[0])
        first_loglr = first_val[0]

        if first_loglr < 0:
            logging.debug("Early exit: First loglr < 0, skipping full grid "
                          "evaluation.")
            self._current_stats.maxl_mtotal = self.mtotal_grid[0]
            self._current_stats.maxl_loglr = first_loglr
            for stat in first_val[1]:
                setattr(self._current_stats, stat, first_val[1][stat])
            return first_loglr

        # Evaluate the remaining grid points, reusing the first result
        # instead of computing it a second time.
        vals = [first_val]
        vals.extend(self.mapfunc(self.call, params[1:]))

        loglr = numpy.array([v[0] for v in vals])
        maxidx = loglr.argmax()
        maxstats = vals[maxidx][1]

        for stat in maxstats:
            setattr(self._current_stats, stat, maxstats[stat])
        self._current_stats.maxl_mtotal = self.mtotal_grid[maxidx]
        self._current_stats.maxl_loglr = loglr[maxidx]

        # Equal weight for every grid point.
        marg_loglr = logsumexp(loglr) - numpy.log(len(self.mtotal_grid))
        logging.debug("Marginalized loglr: %s", marg_loglr)
        return marg_loglr
252+
41253
class BruteParallelGaussianMarginalize(BaseGaussianNoise):
42254
name = "brute_parallel_gaussian_marginalize"
43255

pycbc/inference/models/marginalized_gaussian_noise.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -271,12 +271,16 @@ def _loglr(self):
271271
from pycbc.filter import matched_filter_core
272272

273273
params = self.current_params
274-
if self.all_ifodata_same_rate_length:
275-
wfs = self.waveform_generator.generate(**params)
274+
# Custom waveform is only being used with the brute marginalized total mass model
275+
if "custom_waveform" in params:
276+
wfs = params["custom_waveform"]
276277
else:
277-
wfs = {}
278-
for det in self.data:
279-
wfs.update(self.waveform_generator[det].generate(**params))
278+
if self.all_ifodata_same_rate_length:
279+
wfs = self.waveform_generator.generate(**params)
280+
else:
281+
wfs = {}
282+
for det in self.data:
283+
wfs.update(self.waveform_generator[det].generate(**params))
280284
sh_total = hh_total = 0.
281285
snr_estimate = {}
282286
cplx_hpd = {}

0 commit comments

Comments
 (0)