1- # Copyright (C) 2020 Alex Nitz
1+ # Copyright (C) 2020 Alex Nitz
22# This program is free software; you can redistribute it and/or modify it
33# under the terms of the GNU General Public License as published by the
44# Free Software Foundation; either version 3 of the License, or (at your
2424from scipy .special import logsumexp
2525
2626from .gaussian_noise import BaseGaussianNoise
27- from .tools import draw_sample
27+ from .tools import draw_sample , marginalize_likelihood , DistMarg
28+
29+
30+ from pycbc .waveform import td_taper
31+ from pycbc .waveform import get_waveform_filter_length_in_time
32+ from pycbc .conversions import mass1_from_mtotal_q , mass2_from_mtotal_q
33+ from pycbc .types import TimeSeries , FrequencySeries
34+ from pycbc .distributions import JointDistribution
2835
2936_model = None
3037class likelihood_wrapper (object ):
@@ -34,10 +41,215 @@ def __init__(self, model):
3441
3542 def __call__ (self , params ):
3643 global _model
44+
3745 _model .update (** params )
46+
3847 loglr = _model .loglr
48+ lognl = _model .lognl
49+ loglikelihood = lognl + loglr
50+ logjacobian = _model .logjacobian
51+ logprior = _model .logprior
52+
3953 return loglr , _model .current_stats
4054
55+
56+ class BruteTotalMassMarginalize (BaseGaussianNoise , DistMarg ):
57+ name = "marginalized_mtotal"
58+
59+ def __init__ (self , variable_params ,
60+ cores = 1 ,
61+ base_model = None ,
62+ marginalize_mtotal = None ,
63+ mtotal_grid = None ,
64+ mtotal_grid_num = None ,
65+ fiducial_mtotal = None ,
66+ ** kwds ):
67+ from pycbc .inference .models import models
68+
69+ self .marginalize_vector_params = []
70+ if 'marginalize_vector_params' in kwds :
71+ if isinstance (kwds ['marginalize_vector_params' ], str ) and kwds ['marginalize_vector_params' ]:
72+ self .marginalize_vector_params = [p .strip () for p in kwds ['marginalize_vector_params' ].split (',' ) if p .strip ()]
73+ elif isinstance (kwds ['marginalize_vector_params' ], (list , tuple )):
74+ self .marginalize_vector_params = [p .strip () for p in kwds ['marginalize_vector_params' ] if p .strip ()]
75+
76+ # --- Save the original mtotal prior distribution and bounds before deletion ---
77+ self .mtotal_prior_dist = None
78+ self .mtotal_bounds = None
79+ prior = kwds .get ('prior' , None )
80+ if prior and 'mtotal' in prior .bounds :
81+ for dist in prior .distributions :
82+ if dist .name == 'mtotal' :
83+ self .mtotal_prior_dist = dist
84+ break
85+ self .mtotal_bounds = (prior .bounds ['mtotal' ].min , prior .bounds ['mtotal' ].max )
86+
87+ # --- Setup bounds for mtotal marginalization using the saved bounds ---
88+ if marginalize_mtotal :
89+ prior = kwds .get ('prior' , {})
90+ if prior :
91+ prior_bounds = prior .bounds
92+ min_mtotal = prior_bounds ['mtotal' ].min
93+ max_mtotal = prior_bounds ['mtotal' ].max
94+ else :
95+ min_mtotal = max_mtotal = None
96+
97+ if fiducial_mtotal is None and min_mtotal is not None and max_mtotal is not None :
98+ fiducial_mtotal = 0.5 * (min_mtotal + max_mtotal )
99+ else :
100+ fiducial_mtotal = float (fiducial_mtotal ) if fiducial_mtotal else None
101+
102+ if mtotal_grid is None and min_mtotal is not None and max_mtotal is not None :
103+ if mtotal_grid_num is None :
104+ raise ValueError ("Must specify mtotal_grid_num if mtotal_grid is not given" )
105+ self .mtotal_grid = numpy .linspace (min_mtotal , max_mtotal , int (mtotal_grid_num ))
106+ elif mtotal_grid is not None :
107+ self .mtotal_grid = numpy .array ([float (m ) for m in mtotal_grid ])
108+ else :
109+ raise ValueError ("Either provide 'mtotal_grid' (a list or array) or specify 'mtotal_grid_num' (the number of points to create a grid over the prior)." )
110+
111+ base_model_cls = models [base_model ]
112+ self .model = base_model_cls (variable_params = variable_params , ** kwds )
113+ self .call = likelihood_wrapper (self .model )
114+
115+ marginalized_params = []
116+
117+ if marginalize_mtotal :
118+ marginalized_params .append ('mtotal' )
119+ if 'marginalize_phase' in kwds and kwds ['marginalize_phase' ]:
120+ marginalized_params .append ('coa_phase' )
121+ if 'marginalize_distance' in kwds and kwds ['marginalize_distance' ]:
122+ marginalized_params .append (kwds .get ('marginalize_distance_param' , 'distance' ))
123+
124+ # ------ Reconstruct the marginalized prior ------
125+ marginalized_params .extend (self .marginalize_vector_params )
126+
127+ variable_params = tuple (p for p in variable_params if p not in marginalized_params )
128+
129+ prior_dists_by_param = {}
130+ for d in self .model .prior_distribution .distributions :
131+ if hasattr (d , 'params' ) and isinstance (d .params , (list , tuple )) and len (d .params ) > 0 :
132+ for p_name in d .params :
133+ if isinstance (p_name , str ):
134+ prior_dists_by_param [p_name ] = d
135+ else :
136+ if isinstance (d .name , str ) and d .name in variable_params :
137+ prior_dists_by_param [d .name ] = d
138+
139+ current_prior_dists = prior_dists_by_param .copy ()
140+ for param in marginalized_params :
141+ if param in current_prior_dists :
142+ del current_prior_dists [param ]
143+
144+ variable_params = tuple (p for p in variable_params if p not in marginalized_params )
145+
146+ self .model .prior_distribution = JointDistribution (variable_params , * current_prior_dists .values ())
147+ kwds ['prior' ] = self .model .prior_distribution
148+
149+ super ().__init__ (variable_params , ** kwds )
150+
151+ # Set up multiprocessing
152+ if cores > 1 :
153+ self .pool = Pool (int (cores ))
154+ self .mapfunc = self .pool .map
155+ else :
156+ self .pool = None
157+ self .mapfunc = map
158+
159+ self .marginalize_mtotal = marginalize_mtotal
160+ self .fiducial_mtotal = fiducial_mtotal
161+
162+ @property
163+ def _extra_stats (self ):
164+ stats = self .model ._extra_stats
165+ stats .append ('maxl_mtotal' )
166+ if 'maxl_loglr' not in stats :
167+ stats .append ('maxl_loglr' )
168+ return stats
169+
170+ def scale_waveform (self , h_plus_ref , h_cross_ref , mtotal_ref , mtotal_new , q , f_lower , approximant ):
171+ # Compute scaling factors
172+ time_scale = mtotal_new / mtotal_ref
173+ amp_scale = mtotal_new / mtotal_ref
174+
175+ # Rescale amplitude
176+ h_plus_scaled = h_plus_ref * amp_scale
177+ h_cross_scaled = h_cross_ref * amp_scale
178+
179+ return h_plus_scaled , h_cross_scaled
180+
181+ def _loglr (self ):
182+ if self .mtotal_grid is not None :
183+ ref_params = self .current_params .copy ()
184+ for key in ['ra' , 'dec' , 'tc' , 'polarization' , 'distance' ,'mtotal' ]:
185+ ref_params .pop (key , None )
186+ ref_params ["mtotal" ] = self .fiducial_mtotal
187+ ref_params ["mass1" ] = mass1_from_mtotal_q (ref_params ["mtotal" ], ref_params ["q" ])
188+ ref_params ["mass2" ] = mass2_from_mtotal_q (ref_params ["mtotal" ], ref_params ["q" ])
189+
190+ wfs = self .model .waveform_generator .generate (** ref_params )
191+
192+ self .reference_waveform = wfs
193+
194+ approximant = ref_params ['approximant' ]
195+ f_lower = ref_params ['f_lower' ]
196+ q = self .current_params ["q" ]
197+
198+ params = []
199+ self .scaled_waveforms = {}
200+
201+ params = []
202+
203+ for mtotal in self .mtotal_grid :
204+ scaled_waveform = {}
205+ for det , (hp , hc ) in wfs .items ():
206+ scaled_hp , scaled_hc = self .scale_waveform (
207+ hp , hc , self .fiducial_mtotal , mtotal , q , f_lower ,
208+ approximant = approximant
209+ )
210+ scaled_waveform [det ] = (scaled_hp , scaled_hc )
211+
212+ self .scaled_waveforms [mtotal ] = scaled_waveform
213+
214+ pmod = ref_params .copy ()
215+ pmod .pop ("mass1" , None )
216+ pmod .pop ("mass2" , None )
217+ pmod .pop ("mtotal" , None )
218+ pmod ["custom_waveform" ] = self .scaled_waveforms [mtotal ]
219+ pmod ["rescale_mtotal" ] = mtotal
220+ params .append (pmod )
221+
222+ # Check first loglr value to decide whether to proceed scaling the waveform (skips calculation for loglr <0)
223+ first_val = self .call (params [0 ])
224+ first_loglr = first_val [0 ]
225+
226+ if first_loglr < 0 :
227+ print ("Early exit: First loglr < 0, skipping full grid evaluation." )
228+ self ._current_stats .maxl_mtotal = self .mtotal_grid [0 ]
229+ self ._current_stats .maxl_loglr = first_loglr
230+ for stat in first_val [1 ]:
231+ setattr (self ._current_stats , stat , first_val [1 ][stat ])
232+ return first_loglr
233+
234+ # Proceed with full marginalization if the first is acceptable
235+ params_full = params
236+ vals = list (self .mapfunc (self .call , params_full ))
237+
238+ loglr = numpy .array ([v [0 ] for v in vals ])
239+ maxidx = loglr .argmax ()
240+ maxstats = vals [maxidx ][1 ]
241+ max_mtotal = self .mtotal_grid [maxidx ]
242+
243+ for stat in maxstats :
244+ setattr (self ._current_stats , stat , maxstats [stat ])
245+ self ._current_stats .maxl_mtotal = max_mtotal
246+ self ._current_stats .maxl_loglr = loglr [maxidx ]
247+ print (logsumexp (loglr ) - numpy .log (len (self .mtotal_grid )))
248+ return logsumexp (loglr ) - numpy .log (len (self .mtotal_grid ))
249+ else :
250+ return self .model .loglr
251+
252+
41253class BruteParallelGaussianMarginalize (BaseGaussianNoise ):
42254 name = "brute_parallel_gaussian_marginalize"
43255
0 commit comments