Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
859e8ed
add ringdown wf models to output individual modes
acorreia61201 Apr 3, 2025
7dcbd42
change approx names
acorreia61201 Apr 4, 2025
98f3cec
add ggn model with phase marginalization (first pass)
acorreia61201 Apr 9, 2025
7e364eb
calculate and set maxL phase from marginalization
acorreia61201 Apr 9, 2025
c8a1703
add support for new ringdown models in multimode generators; restruct…
acorreia61201 Apr 10, 2025
f3ab787
second pass (fix errors to work with ringdown waveforms)
acorreia61201 Apr 10, 2025
107cb63
fixed angles; outputting values without errors for logL and phase
acorreia61201 Apr 10, 2025
d3bbbf6
add support for td mode plugins; fix/clean up phase marg likelihood
acorreia61201 Apr 15, 2025
506195a
fix likelihood evaluation in marg phase model; add controls to read i…
acorreia61201 May 8, 2025
0aa9a84
switch to overwhitening in likelihood calcs; rebase to master
acorreia61201 May 15, 2025
b7d27d5
comment debug statements
acorreia61201 May 15, 2025
ba2cef6
slice wfs during inner products to be consistent with other models
acorreia61201 May 20, 2025
50dcb76
fixed bug in ringdown modal wf calls; make get_waveforms more consist…
acorreia61201 May 20, 2025
52c7aa2
remove freq domain ringdown approximants from waveform_modes
acorreia61201 May 20, 2025
24adb95
rewrite to marginalize with respect to relative phases
acorreia61201 Aug 21, 2025
daa856d
move lognl to likelihood function
acorreia61201 Aug 22, 2025
707f1e0
returning correct phase; fixing log likelihood
acorreia61201 Aug 25, 2025
be5557a
change likelihood; agrees with ungated marg phase model
acorreia61201 Aug 25, 2025
04eba16
fixed normalization
acorreia61201 Aug 26, 2025
c07a4dd
remove debugging for pe testing
acorreia61201 Aug 27, 2025
178077e
fix bug in gate time calculation
acorreia61201 Sep 2, 2025
e85b034
marginalize directly in method instead of calling from tools
acorreia61201 Sep 9, 2025
1577a43
split cross term in likelihood calculation
acorreia61201 Sep 25, 2025
3368c65
add missing placeholder
acorreia61201 Sep 29, 2025
4f00be0
bug in likelihood calc
acorreia61201 Sep 30, 2025
d3d7589
tweaks while debugging
acorreia61201 Oct 16, 2025
372dc35
first pass of numerical marginalization
acorreia61201 Nov 10, 2025
dcb551a
move cos/sin wfs to generator class; fix relative phase shift for sin…
acorreia61201 Nov 11, 2025
7ee143d
bugfix to generator; add docs
acorreia61201 Nov 12, 2025
48b90db
rebase and update plugins to master
acorreia61201 Nov 17, 2025
133660d
remove unnecessary commits
acorreia61201 Nov 21, 2025
e5c2b38
isolate waveform changes from marg phase pr
acorreia61201 Nov 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 33 additions & 17 deletions pycbc/inference/models/gated_gaussian_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def _set_covfit(self, det):
def logdet_fit(self, cov, p):
"""Construct a linear regression from a sample of truncated covariance
matrices.

Returns the sample points used for linear fit generation as well as the
linear fit parameters.
"""
Expand Down Expand Up @@ -197,7 +197,7 @@ def logdet_fit(self, cov, p):
x = numpy.vstack([sample_sizes, numpy.ones(len(sample_sizes))]).T
m, b = numpy.linalg.lstsq(x, sample_dets, rcond=None)[0]
return (sample_sizes, sample_dets), (m, b)

@BaseGaussianNoise.normalize.setter
def normalize(self, normalize):
"""Clears the current stats if the normalization state is changed.
Expand Down Expand Up @@ -226,11 +226,11 @@ def gate_indices(self, det):
gt = gate_start + window
lindex, rindex = ts.get_gate_indices(gt, window)
return lindex, rindex

def det_lognorm(self, det, start_index=None, end_index=None):
"""Calculate the normalization term from the truncated covariance
matrix.

Determinant is estimated using a linear fit to logdet vs truncated
matrix size.
"""
Expand Down Expand Up @@ -414,14 +414,14 @@ def _get_gate_times(self, gatestart, gateend, ra, dec):
Parameters
----------
gatestart : float
Geocentric start time of the gate.
Start time of the gate.
gateend : float
Geocentric end time of the gate.
End time of the gate.
ra : float
Right ascension of the signal.
dec : float
Declination of the signal.

Returns
-------
dict :
Expand All @@ -432,10 +432,9 @@ def _get_gate_times(self, gatestart, gateend, ra, dec):
thisdet = Detector(det)
# account for the time delay between the waveforms of the
# different detectors
gatestartdelay = gatestart + thisdet.time_delay_from_earth_center(
ra, dec, gatestart)
gateenddelay = gateend + thisdet.time_delay_from_earth_center(
ra, dec, gateend)
refdet = self.current_params.get('tc_ref_frame', 'geocentric')
gatestartdelay = thisdet.arrival_time(gatestart, ra, dec, refdet)
gateenddelay = thisdet.arrival_time(gateend, ra, dec, refdet)
dgatedelay = gateenddelay - gatestartdelay
gatetimes[det] = (gatestartdelay, dgatedelay)
return gatetimes
Expand Down Expand Up @@ -672,14 +671,31 @@ def _loglikelihood(self):
rr = 4 * invpsd.delta_f * rtilde[slc].inner(gated_rtilde[slc]).real
logl += norm - 0.5*rr
return float(logl)


@property
def _extra_stats(self):
"""Adds ``loglr``, plus ``cplx_loglr`` and ``optimal_snrsq`` in each
detector."""
return ['loglr', 'maxl_phase'] + ['{}_optimal_snrsq'.format(det) for det in self._data]

def _nowaveform_loglr(self):
"""Convenience function to set loglr values if no waveform generated.
"""
setattr(self._current_stats, 'loglikelihood', -numpy.inf)
# maxl phase doesn't exist, so set it to nan
setattr(self._current_stats, 'maxl_phase', numpy.nan)
for det in self._data:
# snr can't be < 0 by definition, so return 0
setattr(self._current_stats, '{}_optimal_snrsq'.format(det), 0.)
return -numpy.inf

@property
def multi_signal_support(self):
""" The list of classes that this model supports in a multi-signal
likelihood
"""
return [type(self)]

def multi_loglikelihood(self, models):
""" Calculate a multi-model (signal) likelihood
"""
Expand Down Expand Up @@ -805,7 +821,7 @@ def get_gated_waveforms(self):
pols.append(h)
out[det] = tuple(pols)
return out

def get_gate_times_hmeco(self):
"""Gets the time to apply a gate based on the current sky position.
Returns
Expand Down Expand Up @@ -931,14 +947,14 @@ def _loglikelihood(self):
# compute the marginalized log likelihood
marglogl = special.logsumexp(loglr) + lognl - numpy.log(len(self.pol))
return float(marglogl)

@property
def multi_signal_support(self):
""" The list of classes that this model supports in a multi-signal
likelihood
"""
return [type(self)]

@catch_waveform_error
def multi_loglikelihood(self, models):
""" Calculate a multi-model (signal) likelihood
Expand All @@ -964,4 +980,4 @@ def multi_loglikelihood(self, models):
for x in wfs]))

self._current_wfs = combine
return self._loglikelihood()
return self._loglikelihood()
50 changes: 40 additions & 10 deletions pycbc/waveform/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,13 @@ class FDomainMassSpinRingdownGenerator(BaseGenerator):

"""
def __init__(self, variable_args=(), **frozen_params):
super(FDomainMassSpinRingdownGenerator, self).__init__(ringdown.get_fd_from_final_mass_spin,
if frozen_params['approximant'] == 'FdQNMfromFinalMassSpin':
approximant = ringdown.get_fd_from_final_mass_spin
elif frozen_params['approximant'] == 'FdModesfromFinalMassSpin':
approximant = ringdown.get_fd_modes_from_final_mass_spin
else:
raise ValueError(f"Invalid approximant name: {frozen_params['approximant']}")
super(FDomainMassSpinRingdownGenerator, self).__init__(approximant,
variable_args=variable_args, **frozen_params)


Expand Down Expand Up @@ -371,7 +377,13 @@ class FDomainFreqTauRingdownGenerator(BaseGenerator):

"""
def __init__(self, variable_args=(), **frozen_params):
super(FDomainFreqTauRingdownGenerator, self).__init__(ringdown.get_fd_from_freqtau,
if frozen_params['approximant'] == 'FdQNMfromFreqTau':
approximant = ringdown.get_fd_from_freqtau
elif frozen_params['approximant'] == 'FdModesfromFreqTau':
approximant = ringdown.get_fd_modes_from_freqtau
else:
raise ValueError(f"Invalid approximant name: {frozen_params['approximant']}")
super(FDomainFreqTauRingdownGenerator, self).__init__(approximant,
variable_args=variable_args, **frozen_params)


Expand Down Expand Up @@ -399,7 +411,13 @@ class TDomainMassSpinRingdownGenerator(BaseGenerator):

"""
def __init__(self, variable_args=(), **frozen_params):
super(TDomainMassSpinRingdownGenerator, self).__init__(ringdown.get_td_from_final_mass_spin,
if frozen_params['approximant'] == 'TdQNMfromFinalMassSpin':
approximant = ringdown.get_td_from_final_mass_spin
elif frozen_params['approximant'] == 'TdModesfromFinalMassSpin':
approximant = ringdown.get_td_modes_from_final_mass_spin
else:
raise ValueError(f"Invalid approximant name: {frozen_params['approximant']}")
super(TDomainMassSpinRingdownGenerator, self).__init__(approximant,
variable_args=variable_args, **frozen_params)


Expand Down Expand Up @@ -427,7 +445,13 @@ class TDomainFreqTauRingdownGenerator(BaseGenerator):

"""
def __init__(self, variable_args=(), **frozen_params):
super(TDomainFreqTauRingdownGenerator, self).__init__(ringdown.get_td_from_freqtau,
if frozen_params['approximant'] == 'TdQNMfromFreqTau':
approximant = ringdown.get_td_from_freqtau
elif frozen_params['approximant'] == 'TdModesfromFreqTau':
approximant = ringdown.get_td_modes_from_freqtau
else:
raise ValueError(f"Invalid approximant name: {frozen_params['approximant']}")
super(TDomainFreqTauRingdownGenerator, self).__init__(approximant,
variable_args=variable_args, **frozen_params)


Expand Down Expand Up @@ -1220,32 +1244,38 @@ def get_td_generator(approximant, modes=False):
if modes:
return TDomainCBCModesGenerator
return TDomainCBCGenerator

if approximant in waveform_modes._mode_waveform_td:
return TDomainCBCModesGenerator

if approximant in ringdown.ringdown_td_approximants:
if approximant == 'TdQNMfromFinalMassSpin':
if approximant in ['TdQNMfromFinalMassSpin', 'TdModesfromFinalMassSpin']:
return TDomainMassSpinRingdownGenerator
return TDomainFreqTauRingdownGenerator

if approximant in supernovae.supernovae_td_approximants:
return TDomainSupernovaeGenerator

raise ValueError(f"No time-domain generator found for "
"approximant: {approximant}")
raise ValueError(f"No time-domain generator found for"
"approximant: {approximant}")

def get_fd_generator(approximant, modes=False):
"""Returns the frequency-domain generator for the given approximant."""
if approximant in waveform.fd_approximants():
if modes:
return FDomainCBCModesGenerator
return FDomainCBCGenerator

if approximant in waveform_modes._mode_waveform_fd:
return FDomainCBCModesGenerator

if approximant in ringdown.ringdown_fd_approximants:
if approximant == 'FdQNMfromFinalMassSpin':
if approximant in ['FdQNMfromFinalMassSpin', 'FdModesfromFinalMassSpin']:
return FDomainMassSpinRingdownGenerator
return FDomainFreqTauRingdownGenerator

raise ValueError(f"No frequency-domain generator found for "
"approximant: {approximant}")
raise ValueError(f"No frequency-domain generator found for"
"approximant: {approximant}")

def select_waveform_generator(approximant, domain=None):
"""Returns the single-IFO generator for the approximant.
Expand Down
13 changes: 11 additions & 2 deletions pycbc/waveform/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


def add_custom_waveform(approximant, function, domain,
sequence=False, has_det_response=False,
sequence=False, has_det_response=False, modes=False,
force=False,):
""" Make custom waveform available to pycbc

Expand All @@ -23,14 +23,18 @@ def add_custom_waveform(approximant, function, domain,
"""
from pycbc.waveform.waveform import (cpu_fd, cpu_td, fd_sequence,
fd_det, fd_det_sequence)
from pycbc.waveform.waveform_modes import _mode_waveform_td

used = RuntimeError("Can't load plugin waveform {}, the name is"
" already in use.".format(approximant))

if domain == 'time':
if not force and (approximant in cpu_td):
raise used
cpu_td[approximant] = function
if modes:
_mode_waveform_td[approximant] = function
else:
cpu_td[approximant] = function
elif domain == 'frequency':
if sequence:
if not has_det_response:
Expand Down Expand Up @@ -121,6 +125,11 @@ def retrieve_waveform_plugins():
for plugin in entry_points(group='pycbc.waveform.td'):
add_custom_waveform(plugin.name, plugin.load(), 'time')

# Check for td modal waveforms
for plugin in entry_points(group='pycbc.waveform.td_modes'):
add_custom_waveform(plugin.name, plugin.load(), 'time',
modes=True)

# Check for waveform length estimates
for plugin in entry_points(group='pycbc.waveform.length'):
add_length_estimator(plugin.name, plugin.load())
Expand Down
Loading
Loading