diff --git a/pycbc/inference/models/__init__.py b/pycbc/inference/models/__init__.py index 05f6c2b8766..78f5ce6c69b 100644 --- a/pycbc/inference/models/__init__.py +++ b/pycbc/inference/models/__init__.py @@ -32,6 +32,7 @@ from .marginalized_gaussian_noise import MarginalizedPolarization from .marginalized_gaussian_noise import MarginalizedHMPolPhase from .marginalized_gaussian_noise import MarginalizedTime +from .marginalized_gaussian_noise import MarginalizedHMPhase from .brute_marg import BruteParallelGaussianMarginalize from .brute_marg import BruteLISASkyModesMarginalize from .gated_gaussian_noise import (GatedGaussianNoise, GatedGaussianMargPol) @@ -198,6 +199,7 @@ def read_from_config(cp, **kwargs): MarginalizedPolarization, MarginalizedHMPolPhase, MarginalizedTime, + MarginalizedHMPhase, BruteParallelGaussianMarginalize, BruteLISASkyModesMarginalize, GatedGaussianNoise, diff --git a/pycbc/inference/models/marginalized_gaussian_noise.py b/pycbc/inference/models/marginalized_gaussian_noise.py index f720a039a9b..c67dd389837 100644 --- a/pycbc/inference/models/marginalized_gaussian_noise.py +++ b/pycbc/inference/models/marginalized_gaussian_noise.py @@ -22,13 +22,14 @@ import logging import numpy from scipy import special - -from pycbc.waveform import generator +from pycbc.waveform import generator, get_fd_waveform +from pycbc.waveform.utils import apply_fd_time_shift +from pycbc.filter.matchedfilter import overlap_cplx from pycbc.detector import Detector from .gaussian_noise import (BaseGaussianNoise, create_waveform_generator, GaussianNoise, catch_waveform_error) -from .tools import marginalize_likelihood, DistMarg +from .tools import marginalize_likelihood, DistMarg, hm_phase_marginalize,hm_phase_peaks class MarginalizedPhaseGaussianNoise(GaussianNoise): @@ -806,3 +807,147 @@ def _loglr(self, return_unmarginalized=False): setattr(self._current_stats, 'maxl_polarization', self.pol[idx]) setattr(self._current_stats, 'maxl_phase', self.phase[idx]) return float(lr_total) + +class MarginalizedHMPhase(BaseGaussianNoise): + + name = 'marginalized_hm_phase' + + def __init__(self, variable_params, data, low_frequency_cutoff, + sample_rate = 2048, numerical = False, grid = False, + grid_points = 2000,first_order_correction = False,offset = 0, + dominant_mode_peak = False, + **kwargs): + super(MarginalizedHMPhase,self).__init__(variable_params, data, + low_frequency_cutoff, + **kwargs) + sample_rate = float(sample_rate) + self.numerical = numerical + self.grid = grid + self.grid_points = grid_points + self.first_order_correction = first_order_correction + self.offset = offset + self.dominant_mode_peak = dominant_mode_peak + df = data[self.detectors[0]].delta_f + self.df = df + self.sample_rate = sample_rate + flen = int(round(sample_rate / self.df) / 2 + 1) + self.flen = flen + self.shm = {} + self.hmhn = {} + # Extract mode array from static params + ## TODO : Fix junk in handling mode_arrays. + p = self.static_params.copy() + if 'mode_array' in p: + if isinstance(p['mode_array'],float): + self._mode_array = [(2,2)] + else: + self._mode_array = [(int(lm[0]),int(lm[1])) for lm in p['mode_array'].split()] + _ = p.pop('mode_array') + else: + raise ValueError('Provide Mode array') + + # Remove coa_phase from static params since it will be marginalized + if 'coa_phase' in p: + _ = p.pop('coa_phase') + + self.static_waveform_params = p + + # Initialize detector objects + self.det = {} + for ifo in self.data: + self.det[ifo] = Detector(ifo) + # Resize data to consistent length + self.data[ifo].resize(flen) + + def _loglr(self): + # Get current parameters + params = self.current_params + if 'mode_array' in params: + _ = params.pop('mode_array') + if 'coa_phase' in params: + _ = params.pop('coa_phase') + # Calculate waveforms for each mode + ## TODO : Can replace all this calculation by waveform_generator.generate(**current_params)? + hplm, hclm = {}, {} + for lm in self._mode_array: + hp, hc = get_fd_waveform(delta_f = self.df, + mode_array = lm, + coa_phase = 0, **params) + hp.resize(self.flen) + hc.resize(self.flen) + hplm[lm] = hp + hclm[lm] = hc + + + hpm, hcm = {}, {} + for (l,m), hplm_val in hplm.items(): + hpm[m] = hpm.get(m, 0) + hplm_val + for (l,m), hclm_val in hclm.items(): + hcm[m] = hcm.get(m, 0) + hclm_val + + ## Detector frame signals ( can replace by generator)? + h = {} + shm = {} + hmhn = {} + for ifo in self.data: + h[ifo] = {} + # Calculate antenna pattern and time shifts + flow = self.kmin[ifo] * self.df + fhigh = self.kmax[ifo] * self.df + fp, fc = self.det[ifo].antenna_pattern(params['ra'], params['dec'], + params['polarization'], params['tc']) + + time_delay = self.det[ifo].time_delay_from_earth_center(params['ra'],params['dec'], + params['tc']) + ## This is tc in det frame + time_shift = time_delay + params['tc'] + + ## Apply projections for each mode and time shift + for m in hpm: + h_unshifted = fp*hpm[m] + fc*hcm[m] + ## Set epoch to data epoch + h_unshifted._epoch = self.data[ifo].start_time + h[ifo][m] = apply_fd_time_shift(h_unshifted, time_shift) + + # Calculate inner products shm = + shm[ifo] = {} + for m in h[ifo]: + shm[ifo][m] = overlap_cplx( + self.data[ifo], + h[ifo][m], + psd=self.psds[ifo], + low_frequency_cutoff=flow, + high_frequency_cutoff=fhigh, + normalized=False) + # Calculate inner products hmhn = + hmhn[ifo] = {} + for m,n in itertools.combinations_with_replacement(h[ifo].keys(), 2): + hmhn[ifo][(m,n)] = overlap_cplx( + h[ifo][m], + h[ifo][n], + psd=self.psds[ifo], + low_frequency_cutoff=flow, + high_frequency_cutoff=fhigh, + normalized=False + ) + ## TODO : Move these calculations in the above loop + shm_total = {} + hmhn_total = {} + for ifo in shm.keys(): + for m in shm[ifo].keys(): + if m in shm_total: + shm_total[m] += shm[ifo][m] + else: + shm_total[m] = shm[ifo][m] + for ifo in hmhn.keys(): + for (m,n) in hmhn[ifo].keys(): + if (m,n) in hmhn_total: + hmhn_total[(m,n)] += hmhn[ifo][(m,n)] + else: + hmhn_total[(m,n)] = hmhn[ifo][(m,n)] + self.shm = shm_total + self.hmhn = hmhn_total + self.peaks = hm_phase_peaks(shm_total,hmhn_total,self.dominant_mode_peak) + return hm_phase_marginalize(shm_total,hmhn_total, self.numerical,self.grid, + self.grid_points,self.first_order_correction,self.offset, + self.dominant_mode_peak) diff --git a/pycbc/inference/models/tools.py b/pycbc/inference/models/tools.py index 03d931da0f1..2eaee248be7 100644 --- a/pycbc/inference/models/tools.py +++ b/pycbc/inference/models/tools.py @@ -4,15 +4,15 @@ import logging import warnings from distutils.util import strtobool - +import time import numpy import numpy.random import tqdm -from scipy.special import logsumexp, i0e +from scipy.special import logsumexp, i0e, factorial from scipy.interpolate import RectBivariateSpline, interp1d from pycbc.distributions import JointDistribution - +from scipy.integrate import quad from pycbc.detector import Detector @@ -966,3 +966,183 @@ def marginalize_likelihood(sh, hh, if return_peak: return vloglr, maxv, maxl return vloglr + +def hm_phase_marginalize(shm,hmhn,numerical,grid,grid_points, + first_order_correction,offset, + dominant_mode_peak): + ''' + returns the likelihood marginalized over the phase provided the + inner products between each modes + + Inputs: + shm : A dictionary of + + hmhn : A dictionary of + + ''' + _m_max = max(shm.keys()) + z = {p:0 for p in range(1,_m_max+1)} + for p in shm: + z[p] = -shm[p] + hmhm = 0 + for (m,n) in hmhn: + p_val = n-m + if p_val in z: + z[p_val] += hmhn[(m,n)] + if p_val == 0: + hmhm += numpy.real(hmhn[(m,n)]) + + if grid: + print(f'using grid with {grid_points} points') + start_time = time.perf_counter() + phi = numpy.linspace(0,2*np.pi,grid_points) + logl_phi = (-sum((z[m].real*numpy.cos(m*phi)-z[m].imag*numpy.sin(m*phi)) + for m in z)-numpy.abs(z[2])) - numpy.abs(z[2]) + marg_lr = logsumexp(logl_phi) - numpy.log(2*numpy.pi) + numpy.abs(z[2]) - (hmhm/2) + end_time = time.perf_counter() + print(end_time-start_time) + return marg_lr + + + if numerical: + print('using numerical') + start_time = time.perf_counter() + def integrand(phi): + l_phi = (-sum((z[m].real*numpy.cos(m*phi)-z[m].imag*numpy.sin(m*phi)) + for m in z)-numpy.abs(z[2])) + return numpy.exp(l_phi) + quad_int = quad(integrand,0,2*numpy.pi)[0] + marg_lr = numpy.log(quad_int) - (hmhm/2) - numpy.log(2*numpy.pi)+ numpy.abs(z[2]) + end_time = time.perf_counter() + print(end_time-start_time) + return marg_lr + + else: + print('using analytic approximation') + start_time = time.perf_counter() + _roots = hm_phase_peaks(shm,hmhn,dominant_mode_peak) + + ##Calculate marg_lhood + peak_vals = [] + correction_factors = [] + if first_order_correction: + for r in (_roots+offset): + a = {} + for n in range(7): + a[n] = 0 + for p_val in z: + a[n] += (p_val**(n) * z[p_val].real * numpy.cos(p_val*r + (n*numpy.pi/2)) + - p_val**(n) * z[p_val].imag * numpy.sin(p_val*r + (n*numpy.pi/2))) + a[n] = a[n]/factorial(n) + + if a[2] > 0: + cf = numpy.sqrt(numpy.pi)*( + (a[2]**(-1/2)) + +(1/2)*(0.5*a[1]*a[1])*(a[2]**(-3/2)) + +(3/4)*((a[1]*a[3])-(a[4]))*(a[2]**(-5/2)) + +(15/8)*((0.5*a[3]*a[3])+(a[1]*a[5])-(a[6]))*(a[2]**(-7/2)) + +(105/16)*((0.5*a[4]*a[4])+(a[5]*a[3]))*(a[2]**(-9/2)) + +(945/32)*((0.5*a[5]*a[5])+(a[4]*a[6]))*(a[2]**(-11/2)) + +(10395/64)*(0.5*a[6]*a[6])*(a[2]**(-13/2)) + ) + + correction_factors.append(cf) + peak_vals.append(-a[0]) + + peak_vals = numpy.array(peak_vals) + correction_factors = numpy.array(correction_factors) + + + marg_loglr = (logsumexp(peak_vals,b=correction_factors) + -numpy.log(2*numpy.pi) - (hmhm/2)) + end_time = time.perf_counter() + print(end_time-start_time) + return marg_loglr + else: + for r in (_roots+offset): + a = {} + for n in range(7): + a[n] = 0 + for p_val in z: + a[n] += (p_val**(n) * z[p_val].real * numpy.cos(p_val*r + (n*numpy.pi/2)) + - p_val**(n) * z[p_val].imag * numpy.sin(p_val*r + (n*numpy.pi/2))) + a[n] = a[n]/factorial(n) + + if a[2] > 0: + cf = numpy.sqrt(numpy.pi)*( + (a[2]**(-1/2)) + +(3/4)*(-a[4])*(a[2]**(-5/2)) + +(15/8)*((0.5*a[3]*a[3])-(a[6]))*(a[2]**(-7/2)) + +(105/16)*((0.5*a[4]*a[4])+(a[5]*a[3]))*(a[2]**(-9/2)) + +(945/32)*((0.5*a[5]*a[5])+(a[4]*a[6]))*(a[2]**(-11/2)) + +(10395/64)*(0.5*a[6]*a[6])*(a[2]**(-13/2)) + ) + + correction_factors.append(cf) + peak_vals.append(-a[0]) + + peak_vals = numpy.array(peak_vals) + correction_factors = numpy.array(correction_factors) + + + marg_loglr = (logsumexp(peak_vals,b=correction_factors) + -numpy.log(2*numpy.pi) - (hmhm/2)) + end_time = time.perf_counter() + print(end_time-start_time) + return marg_loglr + +def hm_phase_peaks(shm,hmhn,dominant_mode_peak): + """ + Returns the maximas and minimas of the likelihood in phase + within (0,2pi) + + if dominant_mode_peak = True, then it returns the peaks assuming + only the dominant (2,2) mode + + else, returns the maximas and minimas assuming all higher modes + are present + + """ + + if dominant_mode_peak: + sp = numpy.zeros(4) + for i,n in enumerate(range(0,4)): + sp[i] = 0.5*(numpy.arctan(-shm[2].imag/shm[2].real) + n*numpy.pi) + if sp[i] < 0 : + sp[i] += 2*numpy.pi + return sp + else: + _m_max = max(shm.keys()) + z = {p:0 for p in range(1,_m_max+1)} + for p in shm: + z[p] = -shm[p] + hmhm = 0 + for (m,n) in hmhn: + p_val = n-m + if p_val in z: + z[p_val] += hmhn[(m,n)] + if p_val == 0: + hmhm += numpy.real(hmhn[(m,n)]) + + N = len(z) + w = numpy.zeros(2*N + 1, dtype=complex) + z_array = numpy.array([z[i] for i in range(1, N+1)]) + w[:N] = 1j * numpy.arange(N, 0, -1) * numpy.conj(z_array[::-1]) # (N-k) * conj(z[N-k]) + w[N] = 0 + w[N+1:] = -1j*numpy.arange(1, N+1) * z_array # (k-N) * z[k-N] + + F_last_row = -w[:-1]/w[-1] + size = len(F_last_row) + ## Create off diagonal matrix of size (2N-1,2N) + F = numpy.diag(numpy.ones(size-1, dtype=complex), k=1) + F[-1, :] = F_last_row + + eigvals, _ = numpy.linalg.eig(F) + all_sp = numpy.angle(eigvals) - 1j*numpy.log(numpy.abs(eigvals)) + + ## Take only the purely real roots + ## NOTE : is there a better way of doing this? + threshold = 1e-3 + sp = numpy.real(all_sp[numpy.abs(numpy.imag(all_sp)) < threshold]) + sp = sp % (2*numpy.pi) + return sp \ No newline at end of file