2 changes: 2 additions & 0 deletions pycbc/inference/models/__init__.py
@@ -32,6 +32,7 @@
from .marginalized_gaussian_noise import MarginalizedPolarization
from .marginalized_gaussian_noise import MarginalizedHMPolPhase
from .marginalized_gaussian_noise import MarginalizedTime
from .marginalized_gaussian_noise import MarginalizedHMPhase
from .brute_marg import BruteParallelGaussianMarginalize
from .brute_marg import BruteLISASkyModesMarginalize
from .gated_gaussian_noise import (GatedGaussianNoise, GatedGaussianMargPol)
@@ -198,6 +199,7 @@ def read_from_config(cp, **kwargs):
MarginalizedPolarization,
MarginalizedHMPolPhase,
MarginalizedTime,
MarginalizedHMPhase,
BruteParallelGaussianMarginalize,
BruteLISASkyModesMarginalize,
GatedGaussianNoise,
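With the two added lines above, the new model class is exposed at the package level and included in the list used by `read_from_config`, so it can be selected by its `name` in a `pycbc_inference` configuration. A quick sanity check, assuming the patch is installed:

```python
# confirm the class is importable and registered under its config name
from pycbc.inference.models import MarginalizedHMPhase

print(MarginalizedHMPhase.name)  # -> 'marginalized_hm_phase'
```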
151 changes: 148 additions & 3 deletions pycbc/inference/models/marginalized_gaussian_noise.py
@@ -22,13 +22,14 @@
import logging
import itertools
import numpy
from scipy import special

from pycbc.waveform import generator
from pycbc.waveform import generator, get_fd_waveform
from pycbc.waveform.utils import apply_fd_time_shift
from pycbc.filter.matchedfilter import overlap_cplx
from pycbc.detector import Detector
from .gaussian_noise import (BaseGaussianNoise,
create_waveform_generator,
GaussianNoise, catch_waveform_error)
from .tools import marginalize_likelihood, DistMarg
from .tools import (marginalize_likelihood, DistMarg, hm_phase_marginalize,
                    hm_phase_peaks)


class MarginalizedPhaseGaussianNoise(GaussianNoise):
@@ -806,3 +807,147 @@ def _loglr(self, return_unmarginalized=False):
setattr(self._current_stats, 'maxl_polarization', self.pol[idx])
setattr(self._current_stats, 'maxl_phase', self.phase[idx])
return float(lr_total)

class MarginalizedHMPhase(BaseGaussianNoise):
    r"""Gaussian-noise model marginalized over the coalescence phase in the
    presence of higher-order modes.

    Each (l, m) mode is generated separately at zero reference phase and
    projected onto the detectors; the inner products ``<s|h_m>`` and
    ``<h_m|h_n>`` are then passed to
    :func:`~pycbc.inference.models.tools.hm_phase_marginalize`. The keyword
    options select how the marginalization is done: ``numerical`` uses
    quadrature, ``grid`` averages over ``grid_points`` phase values, and the
    default is the analytic expansion around the phase peaks, optionally with
    ``first_order_correction`` terms, an ``offset`` applied to the peak
    locations, and ``dominant_mode_peak`` to locate the peaks from the (2, 2)
    mode only. All other keyword arguments are passed to
    :py:class:`BaseGaussianNoise`.
    """

    name = 'marginalized_hm_phase'

    def __init__(self, variable_params, data, low_frequency_cutoff,
                 sample_rate=2048, numerical=False, grid=False,
                 grid_points=2000, first_order_correction=False, offset=0,
                 dominant_mode_peak=False, **kwargs):
        super(MarginalizedHMPhase, self).__init__(variable_params, data,
                                                  low_frequency_cutoff,
                                                  **kwargs)
        sample_rate = float(sample_rate)
        self.numerical = numerical
        self.grid = grid
        self.grid_points = grid_points
        self.first_order_correction = first_order_correction
        self.offset = offset
        self.dominant_mode_peak = dominant_mode_peak
df = data[self.detectors[0]].delta_f
self.df = df
self.sample_rate = sample_rate
flen = int(round(sample_rate / self.df) / 2 + 1)
self.flen = flen
self.shm = {}
self.hmhn = {}
        # Extract the mode array from the static params
        # TODO: clean up mode_array handling
        p = self.static_params.copy()
        if 'mode_array' in p:
            if isinstance(p['mode_array'], float):
                self._mode_array = [(2, 2)]
            else:
                # e.g. '22 33 44' -> [(2, 2), (3, 3), (4, 4)]
                self._mode_array = [(int(lm[0]), int(lm[1]))
                                    for lm in p['mode_array'].split()]
            _ = p.pop('mode_array')
        else:
            raise ValueError("A 'mode_array' must be given in static_params")

# Remove coa_phase from static params since it will be marginalized
if 'coa_phase' in p:
_ = p.pop('coa_phase')

self.static_waveform_params = p

# Initialize detector objects
self.det = {}
for ifo in self.data:
self.det[ifo] = Detector(ifo)
# Resize data to consistent length
self.data[ifo].resize(flen)

    def _loglr(self):
        # Get the current parameters; copy so the pops below do not alter
        # the stored current_params
        params = self.current_params.copy()
        if 'mode_array' in params:
            _ = params.pop('mode_array')
        if 'coa_phase' in params:
            _ = params.pop('coa_phase')
        # Generate the polarizations of each mode at zero reference phase
        # TODO: can this be replaced by waveform_generator.generate(**params)?
        hplm, hclm = {}, {}
        for lm in self._mode_array:
            # mode_array expects a list of (l, m) tuples
            hp, hc = get_fd_waveform(delta_f=self.df,
                                     mode_array=[lm],
                                     coa_phase=0, **params)
            hp.resize(self.flen)
            hc.resize(self.flen)
            hplm[lm] = hp
            hclm[lm] = hc


        # Group the mode polarizations by m
        hpm, hcm = {}, {}
        for (l, m), hplm_val in hplm.items():
            hpm[m] = hpm.get(m, 0) + hplm_val
        for (l, m), hclm_val in hclm.items():
            hcm[m] = hcm.get(m, 0) + hclm_val

        # Detector-frame signals (could this use the waveform generator?)
        h = {}
        shm = {}
        hmhn = {}
for ifo in self.data:
h[ifo] = {}
# Calculate antenna pattern and time shifts
flow = self.kmin[ifo] * self.df
fhigh = self.kmax[ifo] * self.df
fp, fc = self.det[ifo].antenna_pattern(params['ra'], params['dec'],
params['polarization'], params['tc'])

time_delay = self.det[ifo].time_delay_from_earth_center(params['ra'],params['dec'],
params['tc'])
## This is tc in det frame
time_shift = time_delay + params['tc']

## Apply projections for each mode and time shift
for m in hpm:
h_unshifted = fp*hpm[m] + fc*hcm[m]
## Set epoch to data epoch
h_unshifted._epoch = self.data[ifo].start_time
h[ifo][m] = apply_fd_time_shift(h_unshifted, time_shift)

# Calculate inner products shm = <s|hm>
shm[ifo] = {}
for m in h[ifo]:
shm[ifo][m] = overlap_cplx(
self.data[ifo],
h[ifo][m],
psd=self.psds[ifo],
low_frequency_cutoff=flow,
high_frequency_cutoff=fhigh,
normalized=False)
# Calculate inner products hmhn = <hm|hn>
hmhn[ifo] = {}
            # sort the keys so each pair has n >= m; pairs with n < m would
            # otherwise be dropped by the marginalization's (n - m) bookkeeping
            for m, n in itertools.combinations_with_replacement(
                    sorted(h[ifo].keys()), 2):
hmhn[ifo][(m,n)] = overlap_cplx(
h[ifo][m],
h[ifo][n],
psd=self.psds[ifo],
low_frequency_cutoff=flow,
high_frequency_cutoff=fhigh,
normalized=False
)
## TODO : Move these calculations in the above loop
shm_total = {}
hmhn_total = {}
for ifo in shm.keys():
for m in shm[ifo].keys():
if m in shm_total:
shm_total[m] += shm[ifo][m]
else:
shm_total[m] = shm[ifo][m]
for ifo in hmhn.keys():
for (m,n) in hmhn[ifo].keys():
if (m,n) in hmhn_total:
hmhn_total[(m,n)] += hmhn[ifo][(m,n)]
else:
hmhn_total[(m,n)] = hmhn[ifo][(m,n)]
self.shm = shm_total
self.hmhn = hmhn_total
        self.peaks = hm_phase_peaks(shm_total, hmhn_total,
                                    self.dominant_mode_peak)
        return hm_phase_marginalize(shm_total, hmhn_total, self.numerical,
                                    self.grid, self.grid_points,
                                    self.first_order_correction, self.offset,
                                    self.dominant_mode_peak)
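For reference, here is a standalone sketch (not part of the patch) of the plumbing the class performs for a single detector: generate each mode at zero reference phase, compute complex overlaps of the data against each per-m piece, and hand the results to `hm_phase_marginalize`. The approximant, source parameters, the use of the plus polarization only (no antenna response), and the absence of PSD weighting are illustrative assumptions.

```python
import itertools
import numpy
from pycbc.waveform import get_fd_waveform
from pycbc.filter.matchedfilter import overlap_cplx
from pycbc.inference.models.tools import hm_phase_marginalize

params = dict(approximant='IMRPhenomXHM', mass1=38., mass2=12.,
              inclination=0.9, distance=400., f_lower=20., delta_f=0.25)
modes = [(2, 2), (3, 3), (4, 4)]

# per-mode plus polarizations at zero reference phase, standing in for h_m
hm = {m: get_fd_waveform(mode_array=[(l, m)], coa_phase=0., **params)[0]
      for (l, m) in modes}
# "data": the same source with all modes, generated at some reference phase
s = get_fd_waveform(mode_array=modes, coa_phase=1.3, **params)[0]

# put everything on a common frequency grid; rescale to order-unity
# amplitudes since no PSD weighting is applied in this toy example
flen = max(len(x) for x in list(hm.values()) + [s])
for x in list(hm.values()) + [s]:
    x.resize(flen)
    x *= 1e21

kw = dict(psd=None, low_frequency_cutoff=20., normalized=False)
shm = {m: overlap_cplx(s, hm[m], **kw) for m in hm}
hmhn = {(m, n): overlap_cplx(hm[m], hm[n], **kw)
        for m, n in itertools.combinations_with_replacement(sorted(hm), 2)}

print(hm_phase_marginalize(shm, hmhn, numerical=True, grid=False,
                           grid_points=0, first_order_correction=False,
                           offset=0, dominant_mode_peak=False))
```

In the model itself the antenna response, the time shift to the detector frame, and the PSD weighting all enter through `overlap_cplx`, but the dictionary structure handed to the marginalization is the same.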
186 changes: 183 additions & 3 deletions pycbc/inference/models/tools.py
@@ -4,15 +4,15 @@
import logging
import warnings
from distutils.util import strtobool

import time
import numpy
import numpy.random
import tqdm

from scipy.special import logsumexp, i0e
from scipy.special import logsumexp, i0e, factorial
from scipy.interpolate import RectBivariateSpline, interp1d
from pycbc.distributions import JointDistribution

from scipy.integrate import quad
from pycbc.detector import Detector


@@ -966,3 +966,183 @@ def marginalize_likelihood(sh, hh,
if return_peak:
return vloglr, maxv, maxl
return vloglr

def hm_phase_marginalize(shm, hmhn, numerical, grid, grid_points,
                         first_order_correction, offset,
                         dominant_mode_peak):
    """Return the log likelihood ratio marginalized over the coalescence
    phase, given the inner products between the modes.

    Parameters
    ----------
    shm : dict
        Complex inner products ``<s|h_m>``, keyed by m.
    hmhn : dict
        Complex inner products ``<h_m|h_n>``, keyed by (m, n) with ``n >= m``.
    numerical, grid : bool
        Marginalize by numerical quadrature or by averaging over a grid of
        ``grid_points`` phase values; if both are False, the analytic
        expansion around the phase peaks is used.
    first_order_correction : bool
        Include the first-derivative terms in the analytic correction factors.
    offset : float
        Phase offset added to the peak locations before the expansion.
    dominant_mode_peak : bool
        Locate the phase peaks using only the dominant (2, 2) mode.
    """
    _m_max = max(shm.keys())
    z = {p: 0 for p in range(1, _m_max + 1)}
    for p in shm:
        z[p] = -shm[p]
    hmhm = 0
    for (m, n) in hmhn:
        p_val = n - m
        if p_val in z:
            z[p_val] += hmhn[(m, n)]
        if p_val == 0:
            hmhm += numpy.real(hmhn[(m, n)])

    if grid:
        logging.debug('using grid with %s points', grid_points)
        start_time = time.perf_counter()
        phi = numpy.linspace(0, 2*numpy.pi, grid_points, endpoint=False)
        # subtract |z[2]| for numerical stability; it is added back below
        logl_phi = -sum((z[m].real*numpy.cos(m*phi) - z[m].imag*numpy.sin(m*phi))
                        for m in z) - numpy.abs(z[2])
        # average exp(logl) over the grid, i.e. (1/2pi) * int exp(logl) dphi
        marg_lr = (logsumexp(logl_phi) - numpy.log(grid_points)
                   + numpy.abs(z[2]) - (hmhm/2))
        end_time = time.perf_counter()
        logging.debug('grid marginalization took %s s', end_time - start_time)
        return marg_lr


    if numerical:
        logging.debug('using numerical quadrature')
        start_time = time.perf_counter()

        def integrand(phi):
            l_phi = -sum((z[m].real*numpy.cos(m*phi) - z[m].imag*numpy.sin(m*phi))
                         for m in z) - numpy.abs(z[2])
            return numpy.exp(l_phi)

        quad_int = quad(integrand, 0, 2*numpy.pi)[0]
        marg_lr = (numpy.log(quad_int) - (hmhm/2) - numpy.log(2*numpy.pi)
                   + numpy.abs(z[2]))
        end_time = time.perf_counter()
        logging.debug('numerical marginalization took %s s',
                      end_time - start_time)
        return marg_lr

    else:
        # Analytic approximation: expand the phase likelihood around each of
        # its stationary points and sum the resulting Gaussian-type integrals.
        logging.debug('using analytic approximation')
        start_time = time.perf_counter()
        _roots = hm_phase_peaks(shm, hmhn, dominant_mode_peak)

        peak_vals = []
        correction_factors = []
        for r in (_roots + offset):
            # a[n] are the Taylor coefficients of -logl(phi) about phi = r
            a = {}
            for n in range(7):
                a[n] = sum(
                    p_val**n * (z[p_val].real*numpy.cos(p_val*r + n*numpy.pi/2)
                                - z[p_val].imag*numpy.sin(p_val*r + n*numpy.pi/2))
                    for p_val in z) / factorial(n)

            if a[2] > 0:
                cf = ((a[2]**(-1/2))
                      + (3/4)*(-a[4])*(a[2]**(-5/2))
                      + (15/8)*((0.5*a[3]*a[3])-(a[6]))*(a[2]**(-7/2))
                      + (105/16)*((0.5*a[4]*a[4])+(a[5]*a[3]))*(a[2]**(-9/2))
                      + (945/32)*((0.5*a[5]*a[5])+(a[4]*a[6]))*(a[2]**(-11/2))
                      + (10395/64)*(0.5*a[6]*a[6])*(a[2]**(-13/2)))
                if first_order_correction:
                    # terms involving the first derivative a[1]; these only
                    # matter when an offset moves r away from the true peak
                    cf += ((1/2)*(0.5*a[1]*a[1])*(a[2]**(-3/2))
                           + (3/4)*(a[1]*a[3])*(a[2]**(-5/2))
                           + (15/8)*(a[1]*a[5])*(a[2]**(-7/2)))
                cf *= numpy.sqrt(numpy.pi)
                correction_factors.append(cf)
                peak_vals.append(-a[0])

        peak_vals = numpy.array(peak_vals)
        correction_factors = numpy.array(correction_factors)
        marg_loglr = (logsumexp(peak_vals, b=correction_factors)
                      - numpy.log(2*numpy.pi) - (hmhm/2))
        end_time = time.perf_counter()
        logging.debug('analytic marginalization took %s s',
                      end_time - start_time)
        return marg_loglr

def hm_phase_peaks(shm, hmhn, dominant_mode_peak):
    """Return the maxima and minima over phase of the likelihood in [0, 2*pi).

    If ``dominant_mode_peak`` is True, the peaks are located assuming only the
    dominant (2, 2) mode is present; otherwise the stationary points are found
    with all higher modes included.
    """

    if dominant_mode_peak:
        sp = numpy.zeros(4)
        for n in range(4):
            sp[n] = 0.5*(numpy.arctan(-shm[2].imag/shm[2].real) + n*numpy.pi)
            if sp[n] < 0:
                sp[n] += 2*numpy.pi
        return sp
else:
        _m_max = max(shm.keys())
        z = {p: 0 for p in range(1, _m_max + 1)}
        for p in shm:
            z[p] = -shm[p]
        hmhm = 0
        for (m, n) in hmhn:
            p_val = n - m
            if p_val in z:
                z[p_val] += hmhn[(m, n)]
            if p_val == 0:
                hmhm += numpy.real(hmhn[(m, n)])

        N = len(z)
        w = numpy.zeros(2*N + 1, dtype=complex)
        z_array = numpy.array([z[i] for i in range(1, N+1)])
        w[:N] = 1j * numpy.arange(N, 0, -1) * numpy.conj(z_array[::-1])  # (N-k) * conj(z[N-k])
        w[N] = 0
        w[N+1:] = -1j * numpy.arange(1, N+1) * z_array  # (k-N) * z[k-N]

        F_last_row = -w[:-1]/w[-1]
        size = len(F_last_row)
        # Companion matrix of size (2N, 2N) for the derivative polynomial
        # in u = exp(i*phi)
        F = numpy.diag(numpy.ones(size-1, dtype=complex), k=1)
        F[-1, :] = F_last_row

eigvals, _ = numpy.linalg.eig(F)
all_sp = numpy.angle(eigvals) - 1j*numpy.log(numpy.abs(eigvals))

## Take only the purely real roots
## NOTE : is there a better way of doing this?
threshold = 1e-3
sp = numpy.real(all_sp[numpy.abs(numpy.imag(all_sp)) < threshold])
sp = sp % (2*numpy.pi)
return sp
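As a cross-check of the bookkeeping in `hm_phase_marginalize`, its output can be compared against a direct reconstruction of the phase-dependent log likelihood ratio averaged over a dense phase grid. A minimal sketch, assuming the patch is applied; the synthetic complex numbers stand in for the inner products `<s|h_m>` and `<h_m|h_n>` the model would supply:

```python
import numpy
from scipy.special import logsumexp
from pycbc.inference.models.tools import hm_phase_marginalize

rng = numpy.random.default_rng(7)
modes = [2, 3, 4]
# synthetic complex inner products keyed the same way the model keys them
shm = {m: rng.normal(scale=5) + 1j*rng.normal(scale=5) for m in modes}
hmhn = {(m, n): rng.normal(scale=3) + 1j*rng.normal(scale=3)
        for i, m in enumerate(modes) for n in modes[i:]}
for m in modes:
    # diagonal entries play the role of <h_m|h_m>, so make them real
    hmhn[(m, m)] = abs(hmhn[(m, m)])

# direct reconstruction: loglr(phi) = sum_m Re[<s|h_m> e^{i m phi}]
#   - sum_{m<n} Re[<h_m|h_n> e^{i (n-m) phi}] - 0.5 * sum_m <h_m|h_m>
phi = numpy.linspace(0, 2*numpy.pi, 100000, endpoint=False)
logl = sum((shm[m]*numpy.exp(1j*m*phi)).real for m in modes)
logl -= sum((hmhn[(m, n)]*numpy.exp(1j*(n - m)*phi)).real
            for (m, n) in hmhn if m != n)
logl -= 0.5*sum(hmhn[(m, m)].real for m in modes)
brute = logsumexp(logl) - numpy.log(len(phi))

marg = hm_phase_marginalize(shm, hmhn, numerical=True, grid=False,
                            grid_points=0, first_order_correction=False,
                            offset=0, dominant_mode_peak=False)
print(brute, marg)  # should agree to several decimal places
```

The same synthetic `shm` and `hmhn` can also be fed to `hm_phase_peaks` to check that the companion-matrix roots line up with the sign changes of the numerical derivative of `logl` on the grid.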