diff --git a/bin/pycbc_inspiral b/bin/pycbc_inspiral index 24193310714..716f65fa841 100644 --- a/bin/pycbc_inspiral +++ b/bin/pycbc_inspiral @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/home/xangma/miniconda3/envs/pycbcgpu/bin/python3.11 # Copyright (C) 2014 Alex Nitz # @@ -16,14 +16,15 @@ # with this program; if not, write to the Free Software Foundation, Inc., # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +import time +t_very_start = time.time() + import sys import os import copy import logging import argparse -import numpy import itertools -import time from pycbc.pool import BroadcastPool as Pool import pycbc @@ -33,8 +34,16 @@ from pycbc.filter import MatchedFilterControl, make_frequency_series, qtransform from pycbc.types import TimeSeries, FrequencySeries, zeros, float32, complex64 import pycbc.opt import pycbc.inject - +from pycbc.scheme import CUPYScheme last_progress_update = -1.0 +import numpy + +try: + import cupy as cp +except ImportError: + cp = None + +print(f"PROFILE: Import time: {time.time() - t_very_start:.2f} s", flush=True) def update_progress(p,u,n): """ updates a file 'progress.txt' with a value 0 .. 1.0 when enough (filtering) progress was made @@ -203,6 +212,8 @@ parser.add_argument("--multiprocessing-nprocesses", type=int, "Used in conjunction with the option" "--finalize-events-template-rate which should be set" "to a multiple of the number of processes.") +parser.add_argument("--gpu-batch-size", type=int, default=1, + help="Number of templates to process in parallel on GPU") # Add options groups psd.insert_psd_option_group(parser) @@ -234,73 +245,39 @@ gwstrain = strain.from_cli(opt, dyn_range_fac=DYN_RANGE_FAC, strain_segments = strain.StrainSegments.from_cli(opt, gwstrain) -def template_triggers(t_num): - """ Get the triggers for a specific template +def template_triggers(t_nums): + """ Get the triggers for a template or batch of templates + + This is a wrapper that routes to GPU-optimized or CPU code paths. 
+ + Parameters + ---------- + t_nums : int or list + Either a single template index or a list of template indices """ - template = None - tparam = None - out_vals_all = [] - for s_num, stilde in enumerate(segments): - # Filter check checks the 'inj_filter_rejector' options to - # determine whether - # to filter this template/segment if injections are present. - if not inj_filter_rejector.template_segment_checker( - bank, t_num, stilde): - continue - if template is None: - template = bank[t_num] - tparam = template.params - - if opt.update_progress: - update_progress((t_num + (s_num / float(len(segments))) ) / len(bank), - opt.update_progress, opt.update_progress_file) - logging.info("Filtering template %d/%d segment %d/%d" % - (t_num + 1, len(bank), s_num + 1, len(segments))) - - sigmasq = template.sigmasq(stilde.psd) - snr, norm, corr, idx, snrv = \ - matched_filter.matched_filter_and_cluster(s_num, - sigmasq, - cluster_window, - epoch=stilde._epoch) - if not len(idx): - continue - - out_vals = out_vals_ref.copy() - out_vals['bank_chisq'], out_vals['bank_chisq_dof'] = \ - bank_chisq.values(template, stilde.psd, stilde, snrv, norm, - idx+stilde.analyze.start) - - out_vals['chisq'], out_vals['chisq_dof'] = \ - power_chisq.values(corr, snrv, norm, stilde.psd, - idx+stilde.analyze.start, template) - - out_vals['sg_chisq'] = sg_chisq.values(stilde, template, stilde.psd, - snrv, norm, - out_vals['chisq'], - out_vals['chisq_dof'], - idx+stilde.analyze.start) - - out_vals['cont_chisq'], _ = \ - autochisq.values(snr, idx+stilde.analyze.start, template, - stilde.psd, norm, stilde=stilde, - low_frequency_cutoff=flow) - - idx += stilde.cumulative_index - - out_vals['time_index'] = idx - out_vals['snr'] = snrv * norm - out_vals['sigmasq'] = numpy.zeros(len(snrv), dtype=float32) + sigmasq - if opt.psdvar_short_segment is not None: - out_vals['psd_var_val'] = \ - pycbc.psd.find_trigger_value(psd_var, - out_vals['time_index'], - opt.gps_start_time, opt.sample_rate) - #print(idx, 
out_vals['time_index']) - - out_vals_all.append(copy.deepcopy(out_vals)) - #print(out_vals_all) - return out_vals_all, tparam + # Handle single template case for backward compatibility + if isinstance(t_nums, (int, numpy.integer)): + t_nums = [t_nums] + + # Check if we're using CUPY scheme for batch processing + using_cupy = isinstance(ctx, pycbc.scheme.CUPYScheme) + + if using_cupy and len(t_nums) > 1: + # Use GPU-optimized batched implementation + from pycbc.filter.inspiral_utils import template_triggers_gpu_batched + + return template_triggers_gpu_batched( + t_nums, bank, segments, matched_filter, + power_chisq, sg_chisq, inj_filter_rejector, + cluster_window, out_vals_ref, opt, + psd_var=psd_var if opt.psdvar_short_segment is not None else None + ) + else: + # Fall back to single-template CPU processing + raise NotImplementedError( + "CPU/single-template processing not yet re-implemented. " + "This will be restored later." + ) with ctx: if opt.fft_backends == 'fftw': @@ -394,7 +371,12 @@ with ctx: opt, names, [out_types[n] for n in names], psd=segments[0].psd, gating_info=gwstrain.gating_info, q_trans=q_trans) - template_mem = zeros(tlen, dtype = complex64) + # Process in batches for GPU + batch_size = opt.gpu_batch_size + if opt.processing_scheme == 'cupy': + template_mem = [zeros(tlen, dtype=complex64) for _ in range(batch_size)] + else: + template_mem = zeros(tlen, dtype=complex64) cluster_window = int(opt.cluster_window * gwstrain.sample_rate) if opt.cluster_window == 0.0: @@ -409,15 +391,15 @@ with ctx: if opt.multiprocessing_nprocesses: ncores *= opt.multiprocessing_nprocesses - matched_filter = MatchedFilterControl(opt.low_frequency_cutoff, None, - opt.snr_threshold, tlen, delta_f, complex64, - segments, template_mem, use_cluster, - downsample_factor=opt.downsample_factor, - upsample_threshold=opt.upsample_threshold, - upsample_method=opt.upsample_method, - gpu_callback_method=opt.gpu_callback_method, - cluster_function=opt.cluster_function) + 
opt.snr_threshold, tlen, delta_f, complex64, + segments, template_mem, use_cluster, + downsample_factor=opt.downsample_factor, + upsample_threshold=opt.upsample_threshold, + upsample_method=opt.upsample_method, + gpu_callback_method=opt.gpu_callback_method, + cluster_function=opt.cluster_function, + batch_size=opt.gpu_batch_size) bank_chisq = vetoes.SingleDetBankVeto(opt.bank_veto_bank_file, flen, delta_f, flow, complex64, @@ -455,54 +437,141 @@ with ctx: logging.info("Template bank size after thinning: %s", len(bank)) tsetup = time.time() - tstart + logging.info("PROFILE: Setup time (data loading, PSD, etc): %.2f sec", tsetup) tcheckpoint = time.time() tanalyze = list(range(tnum_start, len(bank))) n = opt.finalize_events_template_rate n = 1 if n is None else n - tchunks = [tanalyze[i:i + n] for i in range(0, len(tanalyze), n)] - mmap = map if opt.multiprocessing_nprocesses: mmap = Pool(opt.multiprocessing_nprocesses).map - for tchunk in tchunks: - data = list(mmap(template_triggers, tchunk)) - - for elem in data: - out_vals_all, tparam = elem - if len(out_vals_all) > 0: - event_mgr.new_template(tmplt=tparam) + # Start timing the main loop + tloop_start = time.time() + templates_processed = 0 - for edata in out_vals_all: - event_mgr.add_template_events(names, [edata[n] for n in names]) - - event_mgr.cluster_template_events("time_index", "snr", cluster_window) - event_mgr.finalize_template_events() + if isinstance(ctx, CUPYScheme): + # Create batches of templates, each of size batch_size + tchunks = [tanalyze[i:i + batch_size] for i in range(0, len(tanalyze), batch_size)] + else: + # Original chunking for non-CUPY schemes + tchunks = [tanalyze[i:i + n] for i in range(0, len(tanalyze), n)] + for tchunk in tchunks: + chunk_start = time.time() + if isinstance(ctx, CUPYScheme): + data = [template_triggers(tchunk)] + else: + data = list(mmap(template_triggers, tchunk)) + + data_time = time.time() - chunk_start + print(f"PROFILE: template_triggers call time: 
{data_time:.4f} s") + + # Process events with new output format + # FIXME: Doesn't work yet! + #event_proc_start = time.time() + #for elem in data: + # out_vals, tparams = elem + + # # Check if we have any triggers + # if len(out_vals['template_id']) == 0: + # continue + + # # Get unique template IDs and process each template + # # Convert to numpy for easier manipulation + # template_ids_array = numpy.array(out_vals['template_id']) + # unique_template_ids = numpy.unique(template_ids_array) + + # for template_id in unique_template_ids: + # template_id_int = int(template_id) + + # # Find the index in tparams (which corresponds to position in tchunk) + # # template_id is the global template index, find its position in the batch + # batch_idx = tchunk.index(template_id_int) + # tparam = tparams[batch_idx] + + # # Mask for this template's triggers + # mask = template_ids_array == template_id + + # # Extract events for this template + # template_out_vals = {} + # for key in out_vals.keys(): + # if key == 'template_id': + # continue # Skip template_id, it's not in names + # # Convert to numpy array, apply mask, convert back to PyCBC array + # data_array = numpy.array(out_vals[key]) + # template_out_vals[key] = out_vals[key].__class__(data_array[mask], copy=False) + + # # Fill in missing keys with None + # for key in names: + # if key not in template_out_vals: + # template_out_vals[key] = None + + # # Add to event manager + # event_mgr.new_template(tmplt=tparam) + # event_mgr.add_template_events(names, [template_out_vals[n] for n in names]) + # event_mgr.cluster_template_events("time_index", "snr", cluster_window) + # event_mgr.finalize_template_events() + + #event_proc_time = time.time() - event_proc_start + #print(f"PROFILE: Event processing time: {event_proc_time:.4f} s") + if opt.finalize_events_template_rate is not None: - event_mgr.consolidate_events(opt, gwstrain=gwstrain) - - if opt.checkpoint_interval and \ - (time.time() - tcheckpoint > opt.checkpoint_interval): 
- event_mgr.save_state(max(tchunk), opt.output + '.checkpoint') - tcheckpoint = time.time() - - if opt.checkpoint_exit_maxtime and \ - (time.time() - tstart > opt.checkpoint_exit_maxtime): - event_mgr.save_state(max(tchunk), opt.output + '.checkpoint') - sys.exit(opt.checkpoint_exit_code) - + event_mgr.consolidate_events(opt, gwstrain=gwstrain) + + chunk_time = time.time() - chunk_start + print(f"PROFILE: Total chunk time: {chunk_time:.4f} s") + avg_template_time = chunk_time / len(tchunk) if len(tchunk) > 0 else 0 + logging.info("Processed %d templates in %.2f sec (%.2f sec/template)", + len(tchunk), chunk_time, avg_template_time) + + if opt.finalize_events_template_rate is not None: + event_mgr.consolidate_events(opt, gwstrain=gwstrain) + + if opt.checkpoint_interval and \ + (time.time() - tcheckpoint > opt.checkpoint_interval): + event_mgr.save_state(max(tchunk), opt.output + '.checkpoint') + tcheckpoint = time.time() + + if opt.checkpoint_exit_maxtime and \ + (time.time() - tstart > opt.checkpoint_exit_maxtime): + event_mgr.save_state(max(tchunk), opt.output + '.checkpoint') + sys.exit(opt.checkpoint_exit_code) + + # End timing of main loop + tloop_end = time.time() + tloop_total = tloop_end - tloop_start + avg_template_time = tloop_total / templates_processed if templates_processed > 0 else 0 + logging.info("PROFILE: Main loop completed in %.2f sec", tloop_total) + logging.info("PROFILE: Processed %d templates", templates_processed) + logging.info("PROFILE: Average time per template: %.4f sec", avg_template_time) + +finalize_start = time.time() event_mgr.consolidate_events(opt, gwstrain=gwstrain) event_mgr.finalize_events() +finalize_time = time.time() - finalize_start +logging.info("PROFILE: Event finalization time: %.2f sec", finalize_time) + +# Print detailed GPU profiling if using batched GPU mode +if hasattr(opt, 'processing_scheme') and opt.processing_scheme == 'cupy': + try: + from pycbc.filter.inspiral_utils import print_profile + print_profile() + 
except (ImportError, AttributeError): + pass + logging.info("Outputting %s triggers" % str(len(event_mgr.events))) tstop = time.time() run_time = tstop - tstart event_mgr.save_performance(ncores, len(segments), len(bank), run_time, tsetup) +write_start = time.time() logging.info("Writing out triggers") event_mgr.write_events(opt.output) +write_time = time.time() - write_start +logging.info("PROFILE: Write events time: %.2f sec", write_time) if opt.fft_backends == 'fftw': if opt.fftw_output_float_wisdom_file: diff --git a/examples/inspiral/BANK_SPLIT0.hdf b/examples/inspiral/BANK_SPLIT0.hdf new file mode 100644 index 00000000000..88003336b82 Binary files /dev/null and b/examples/inspiral/BANK_SPLIT0.hdf differ diff --git a/examples/inspiral/run_cupy.sh b/examples/inspiral/run_cupy.sh new file mode 100755 index 00000000000..74918905b96 --- /dev/null +++ b/examples/inspiral/run_cupy.sh @@ -0,0 +1,64 @@ + +# Select bank file: use argument or default to small bank +BANK_FILE=${1:-BANK_SPLIT0.hdf} + +# Uncomment if you want profile information + +#nvprof --print-callstack --log-file profout.txt \ + +#/home/ian.harry/.conda/envs/env_lisa_premerger/bin/../nsight-compute/2024.1.1/ncu -o profile \ + +#python -m cProfile -o output_cupy.pstats \ + +#/home/ian.harry/nsight-systems-2024.7.1/bin/nsys profile \ +# --trace cuda,osrt,nvtx \ +# --cuda-memory-usage true \ +# --force-overwrite true \ +# --output profile_run_v1 \ +# --python-sampling=true \ +#`which python` -m nvtx -- ~/.conda/envs/env_lisa_premerger/bin/pycbc_inspiral \ +#python -m cProfile -o output_cupy.pstats `which pycbc_inspiral` \ +#time ~/.conda/envs/env_lisa_premerger/bin/pycbc_inspiral \ + +`which pycbc_inspiral` \ +--frame-files DATA_FILE.gwf \ +--sample-rate 2048 \ +--sgchisq-snr-threshold 6.0 \ +--sgchisq-locations "mtotal>40:20-30,20-45,20-60,20-75,20-90,20-105,20-120" \ +--segment-end-pad 16 \ +--low-frequency-cutoff 30 \ +--pad-data 8 \ +--cluster-window 1 \ +--cluster-function symmetric \ 
+--injection-window 4.5 \ +--segment-start-pad 112 \ +--psd-segment-stride 8 \ +--psd-inverse-length 16 \ +--filter-inj-only \ +--psd-segment-length 16 \ +--snr-threshold 5.5 \ +--segment-length 256 \ +--autogating-width 0.25 \ +--autogating-threshold 100 \ +--autogating-cluster 0.5 \ +--autogating-taper 0.25 \ +--newsnr-threshold 5 \ +--psd-estimation median \ +--strain-high-pass 20 \ +--order -1 \ +--chisq-bins "1.75*(get_freq('fSEOBNRv2Peak',params.mass1,params.mass2,params.spin1z,params.spin2z)-60.)**0.5" \ +--channel-name H1:LOSC-STRAIN \ +--gps-start-time 1126259078 \ +--gps-end-time 1126259846 \ +--output H1-INSPIRAL_GPU-OUT.hdf \ +--approximant "SPAtmplt" \ +--processing-scheme cupy \ +--bank-file ${BANK_FILE} \ +--gpu-batch-size 128 \ +--chisq-snr-threshold 5.25 \ +--verbose +#--verbose 2> inspiral_$1_$2_$3.log +# Uncomment above if you want logging + +# Uncomment for profile pngs +#gprof2dot -f pstats output_$1_$2_$3.pstats | dot -Tpng -o $1_$2_$3.png diff --git a/pycbc/events/eventmgr.py b/pycbc/events/eventmgr.py index 3cde84c3c4a..e883db68751 100644 --- a/pycbc/events/eventmgr.py +++ b/pycbc/events/eventmgr.py @@ -30,6 +30,7 @@ import logging import pickle import numpy +import cupy import h5py from pycbc.types import Array @@ -104,7 +105,7 @@ def __new__(cls, *args, **kwargs): # # will work? Is there a better way? class _BaseThresholdCluster(object): - def threshold_and_cluster(self, threshold, window): + def threshold_and_cluster(self, threshold, window, batch_idx=None): """ Threshold and cluster the memory specified at instantiation with the threshold and window size specified at creation. @@ -116,6 +117,9 @@ def threshold_and_cluster(self, threshold, window): to return when thresholding and clustering. window : uint32 The size (in number of samples) of the window over which to cluster + batch_idx : int, optional + For batched operations, which template in the batch to process. + If None, assumes single template operation. 
Returns: -------- @@ -345,6 +349,8 @@ def add_template_events(self, columns, vectors): if v is not None: if isinstance(v, Array): new_events[c] = v.numpy() + elif isinstance(v, cupy.ndarray): + new_events[c] = cupy.asnumpy(v) else: new_events[c] = v self.template_events = numpy.append(self.template_events, new_events) diff --git a/pycbc/events/threshold_cupy.py b/pycbc/events/threshold_cupy.py index c20fe18aa13..db47b9858fe 100644 --- a/pycbc/events/threshold_cupy.py +++ b/pycbc/events/threshold_cupy.py @@ -34,11 +34,13 @@ # https://stackoverflow.com/questions/77798014/cupy-rawkernel-cuda-error-not-found-named-symbol-not-found-cupy tkernel1 = mako.template.Template(""" -extern "C" __global__ void threshold_and_cluster(float2* in, float2* outv, int* outl, int window, float threshold){ - int s = window * blockIdx.x; +extern "C" __global__ void threshold_and_cluster(float2* in, float2* outv, int* outl, int window, float* thresholds, int series_length, int analyse_start) { + int batch_idx = blockIdx.y; // Batch index + int s = batch_idx * series_length + window * blockIdx.x + analyse_start; int e = s + window; + float threshold = thresholds[batch_idx]; - // shared memory for chuck size candidates + // Shared memory remains unchanged, but it now processes series per batch index __shared__ float svr[${chunk}]; __shared__ float svi[${chunk}]; __shared__ int sl[${chunk}]; @@ -69,7 +71,7 @@ svr[threadIdx.x] = mvr; svi[threadIdx.x] = mvi; sl[threadIdx.x] = ml; - + __syncthreads(); if (threadIdx.x < 32){ @@ -122,11 +124,11 @@ if (svv[0] > threshold){ tl = idx[0]; - outv[blockIdx.x].x = svr[tl]; - outv[blockIdx.x].y = svi[tl]; - outl[blockIdx.x] = sl[tl]; - } else{ - outl[blockIdx.x] = -1; + outv[batch_idx*${blockmemsize} + blockIdx.x].x = svr[tl]; + outv[batch_idx*${blockmemsize} + blockIdx.x].y = svi[tl]; + outl[batch_idx*${blockmemsize} + blockIdx.x] = sl[tl] % ${slen} - analyse_start; + } else { + outl[batch_idx*${blockmemsize} + blockIdx.x] = -1; } } } @@ -134,12 +136,14 
@@ """) tkernel2 = mako.template.Template(""" -extern "C" __global__ void threshold_and_cluster2(float2* outv, int* outl, float threshold, int window){ +extern "C" __global__ void threshold_and_cluster2(float2* outv, int* outl, float* thresholds, int window){ __shared__ int loc[${blocks}]; __shared__ float val[${blocks}]; - + int i = threadIdx.x; - + int posi = i % ${blockmemsize}; + float threshold = thresholds[i / ${blockmemsize}]; + int l = outl[i]; loc[i] = l; @@ -150,13 +154,13 @@ // Check right - if ( (i < (${blocks} - 1)) && (val[i + 1] > val[i]) ){ + if ( (posi < (${blocksize} - 1)) && (val[i + 1] > val[i]) ){ outl[i] = -1; return; } // Check left - if ( (i > 0) && (val[i - 1] > val[i]) ){ + if ( (posi > 0) && (val[i - 1] > val[i]) ){ outl[i] = -1; return; } @@ -164,7 +168,7 @@ """) @functools.lru_cache(maxsize=None) -def get_tkernel(slen, window): +def get_tkernel(slen, alen, window, block_mem_size=None, batch_size=None): if window < 32: raise ValueError("GPU threshold kernel does not support a window smaller than 32 samples") @@ -177,83 +181,223 @@ def get_tkernel(slen, window): else: nt = 1024 - nb = int(cp.ceil(slen / float(window))) + nb = int(cp.ceil(alen / float(window))) if nb > 1024: raise ValueError("More than 1024 blocks not supported yet") + if block_mem_size is not None: + blocks = block_mem_size * batch_size + else: + blocks = nb + block_mem_size = nb + fn = cp.RawKernel( - tkernel1.render(chunk=nt), + tkernel1.render(chunk=nt, slen=slen, blockmemsize=block_mem_size), 'threshold_and_cluster', backend='nvcc' ) fn2 = cp.RawKernel( - tkernel2.render(blocks=nb), + tkernel2.render(blocks=nb, blocksize=nb, blockmemsize=block_mem_size), 'threshold_and_cluster2', backend='nvcc' ) return (fn, fn2), nt, nb -def threshold_and_cluster(series, threshold, window): +def threshold_and_cluster(series_batch, threshold, window): + raise NotImplementedError("Needs writing properly") + # Not sure this function is accessed easily. 
However, right now, it has not + # been properly written. A starting point is here though! global val global loc - if val is None: - val = cp.zeros(4096*256, dtype=cp.complex64) - if loc is None: - loc = cp.zeros(4096*256, cp.int32) - - outl = loc - outv = val - slen = len(series) - series = series.data - (fn, fn2), nt, nb = get_tkernel(slen, window) + batch_size, series_length = series_batch.shape + + if val is None or val.size < batch_size * 4096 * 256: + val = cp.zeros((batch_size, 4096 * 256), dtype=cp.complex64) + if loc is None or loc.size < batch_size * 4096 * 256: + loc = cp.zeros((batch_size, 4096 * 256), dtype=cp.int32) + + outl = loc[:, :series_length] + outv = val[:, :series_length] + (fn, fn2), nt, nb = get_tkernel(series_length, window) threshold = cp.float32(threshold * threshold) window = cp.int32(window) - cl = loc[0:nb] - cv = val[0:nb] - - fn((nb,), (nt,), (series.data, outv, outl, window, threshold)) - fn2((1,), (nb,), (outv, outl, threshold, window)) - w = (cl != -1) - return cv[w], cl[w] + # Launch kernel with batch dimension + grid = (nb, batch_size, 1) + block = (nt, 1, 1) + + fn(grid, block, (series_batch.data, outv, outl, window, threshold, series_length)) + fn2(grid, block, (outv, outl, threshold, window)) + + results = [] + for batch_idx in range(batch_size): + w = (outl[batch_idx] != -1) + results.append((outv[batch_idx][w], outl[batch_idx][w])) + return results + +class FastFilter: + # Compile kernels only once as class variables + size_kernel = cp.RawKernel(r''' + extern "C" __global__ + void compute_sizes(const float2* cv, const int* cl, + int* sizes, const int nb, + const int snum) { + int row = blockDim.x * blockIdx.x + threadIdx.x; + if (row >= snum) return; + + int count = 0; + int row_offset = row * nb; + for (int col = 0; col < nb; col++) { + if (cl[row_offset + col] != -1) { + count++; + } + } + sizes[row] = count; + } + ''', 'compute_sizes') + + filter_kernel = cp.RawKernel(r''' + extern "C" __global__ + void 
filter_arrays(const float2* cv, const int* cl, + float2* cv_out, int* cl_out, + const int* sizes, const int* offsets, + const int nb, const int snum) { + int row = blockDim.x * blockIdx.x + threadIdx.x; + if (row >= snum) return; + + int out_idx = offsets[row]; + int row_offset = row * nb; + for (int col = 0; col < nb; col++) { + if (cl[row_offset + col] != -1) { + cv_out[out_idx].x = cv[row_offset + col].x; + cv_out[out_idx].y = cv[row_offset + col].y; + cl_out[out_idx] = cl[row_offset + col]; + out_idx++; + } + } + } + ''', 'filter_arrays') + + def __init__(self, snum, nb): + self.snum = snum + self.nb = nb + self.sizes = cp.empty(snum, dtype=cp.int32) + self.offsets = cp.empty(snum, dtype=cp.int32) + + def filter_arrays(self, cv, cl): + # Input handling + cv = cp.asarray(cv, dtype=cp.complex64) + cl = cp.asarray(cl, dtype=cp.int32) + + if cv.ndim == 2: + if cv.shape != (self.snum, self.nb): + raise ValueError(f"Expected shape ({self.snum}, {self.nb}), got {cv.shape}") + cv = cv.reshape(-1) + + if cl.ndim == 2: + if cl.shape != (self.snum, self.nb): + raise ValueError(f"Expected shape ({self.snum}, {self.nb}), got {cl.shape}") + cl = cl.reshape(-1) + + # Compute sizes + n_threads = (self.snum + 1023) // 1024 + blocks = (self.snum + n_threads - 1) // n_threads + + self.size_kernel((blocks,), (n_threads,), + (cv, cl, self.sizes, self.nb, self.snum)) + + # Compute offsets + cp.cumsum(self.sizes[:-1], out=self.offsets[1:]) + self.offsets[0] = 0 + + # Allocate output arrays + total_size = int(self.sizes.sum()) + cv_out = cp.empty(total_size, dtype=cp.complex64) + cl_out = cp.empty(total_size, dtype=cp.int32) + + + # Run filter kernel + # This is the slowest part of the function, everything else is negligible + self.filter_kernel((blocks,), (n_threads,), + (cv, cl, cv_out, cl_out, + self.sizes, self.offsets, + self.nb, self.snum)) + # Create template mapping array + template_map = cp.empty(total_size, dtype=cp.int32) + + # Split into separate arrays + cv_filtered = [] 
+ cl_filtered = [] + start = 0 + for idx, size in enumerate(self.sizes): + size = int(size) + if size > 0: + cv_filtered.append(cv_out[start:start + size]) + cl_filtered.append(cl_out[start:start + size]) + template_map[start:start + size] = idx + else: + cv_filtered.append(cp.empty(0, dtype=cp.complex64)) + cl_filtered.append(cp.empty(0, dtype=cp.int32)) + start += size + return cv_filtered, cl_filtered, cv_out, cl_out, self.sizes, template_map class CUDAThresholdCluster(_BaseThresholdCluster): - def __init__(self, series): - self.series = series + def __init__(self, series_batch, analyse_slice): + self.series_batch = series_batch + # This value is hardcoded as it is the longest length currently + # supported. Memory usage for this is tiny anyway so no need to be + # shorter. + self.batch_mem_size = 1024 + self.batch_size, self.series_length = series_batch.shape + self.analyse_start = cp.int32(analyse_slice.start) + self.analyse_end = cp.int32(analyse_slice.stop) + self.analyse_len = self.analyse_end = self.analyse_start global val global loc - if val is None: - val = cp.zeros(4096*256, dtype=cp.complex64) - if loc is None: - loc = cp.zeros(4096*256, cp.int32) + if val is None or val.size < self.batch_size * self.batch_mem_size: + val = cp.zeros((self.batch_size, self.batch_mem_size), dtype=cp.complex64) + if loc is None or loc.size < self.batch_size * self.batch_mem_size: + loc = cp.zeros((self.batch_size, self.batch_mem_size), cp.int32) self.outl = loc self.outv = val - self.slen = len(series) + # This is kind of hardcoded here, sorry. We maybe should pass window in here. 
+ nb = int(cp.ceil(self.analyse_len / float(2048))) + self.fast_filter = FastFilter(snum=self.batch_size, nb=nb) def threshold_and_cluster(self, threshold, window): - threshold = cp.float32(threshold * threshold) + threshold = threshold * threshold + threshold = cp.asarray(threshold, dtype=cp.float32) window = cp.int32(window) - (fn, fn2), nt, nb = get_tkernel(self.slen, window) - cl = loc[0:nb] - cv = val[0:nb] - - fn( - (nt, 1, 1), - (nb, 1), - (self.series.data, self.outv, self.outl, window, threshold) - ) - fn2( - (nb, 1, 1), - (1, 1), - (self.outv, self.outl, threshold, window) + (fn, fn2), nt, nb = get_tkernel( + self.series_length, + self.analyse_len, + window, + block_mem_size=self.batch_mem_size, + batch_size=self.batch_size ) - w = (cl != -1) - return cp.asnumpy(cv[w]), cp.asnumpy(cl[w]) - -def _threshold_cluster_factory(series): + grid = (nb, self.batch_size, 1) + block = (nt, 1, 1) + + fn(grid, block, (self.series_batch.data, self.outv, self.outl, window, threshold, self.series_length, self.analyse_start)) + fn2(grid, block, (self.outv, self.outl, threshold, window)) + + results = [] + # for batch_idx in range(self.batch_size): + # cl = self.outl[batch_idx][:nb] # Clustered locations for this batch + # cv = self.outv[batch_idx][:nb] # Clustered values for this batch + # w = (cl != -1) # Valid locations + # results.append((cv[w], cl[w])) + + # slightly faster version + cv_filtered, cl_filtered, cv_out, cl_out, sizes, template_map = \ + self.fast_filter.filter_arrays(self.outv[:, :nb], self.outl[:, :nb]) + results = list(zip(cv_filtered, cl_filtered)) + return results, cv_out, cl_out, sizes, template_map + +def _threshold_cluster_factory(*args, **kwargs): return CUDAThresholdCluster diff --git a/pycbc/fft/cupyfft.py b/pycbc/fft/cupyfft.py index c45d7173608..80c73d35a0c 100644 --- a/pycbc/fft/cupyfft.py +++ b/pycbc/fft/cupyfft.py @@ -26,6 +26,7 @@ for the PyCBC package. 
""" +from pycbc.types import Array import logging import cupy.fft from .core import _check_fft_args @@ -86,3 +87,79 @@ def __init__(self, invec, outvec, nbatch=1, size=None): def execute(self): ifft(self.invec, self.outvec, self.prec, self.itype, self.otype) + +def batch_fft(invecs, outvecs, _, itype, otype): + """Batched FFT operation for multiple templates""" + if itype == 'complex' and otype == 'complex': + # Reshape input for batch operation + batch_size = len(invecs) + fft_size = len(invecs[0]) + batch_data = cupy.stack([v.data for v in invecs]) + + # Perform batched FFT + result = cupy.fft.fft(batch_data) + + # Copy results back to output vectors + for i, outvec in enumerate(outvecs): + outvec.data[:] = result[i] + + elif itype == 'real' and otype == 'complex': + batch_size = len(invecs) + fft_size = len(invecs[0]) + batch_data = cupy.stack([v.data for v in invecs]) + + result = cupy.fft.rfft(batch_data) + + for i, outvec in enumerate(outvecs): + outvec.data[:] = result[i] + else: + raise ValueError(_INV_FFT_MSG.format("FFT", itype, otype)) + +def batch_ifft(invecs, outvecs, _, itype, otype): + """Batched IFFT operation for multiple templates""" + if itype == 'complex' and otype == 'complex': + # Stack the input arrays directly + # batch_data = cupy.stack([v.data for v in invecs]) + batch_data = invecs + + # Perform batch IFFT + result = cupy.fft.ifft(batch_data) + + # Copy results back efficiently using cupy.copyto + # for i, outvec in enumerate(outvecs): + # cupy.copyto(outvec.data, result[i]) + # outvec *= len(outvec) + cupy.copyto(outvecs, result) # Copy all results if outvecs_data is a pre-allocated contiguous array + outvecs *= outvecs.shape[-1] + + elif itype == 'complex' and otype == 'real': + batch_data = cupy.stack([v.data for v in invecs]) + result = cupy.fft.irfft(batch_data) + + for i, outvec in enumerate(outvecs): + cupy.copyto(outvec.data, result[i]) + outvec *= len(outvec) + else: + raise ValueError(_INV_FFT_MSG.format("IFFT", itype, otype)) + 
+class BatchFFT(_BaseFFT): + """Class for performing batched FFTs via the cupy interface""" + def __init__(self, invecs, outvecs, batch_size): + self.invecs = invecs + self.outvecs = outvecs + self.batch_size = batch_size + self.prec, self.itype, self.otype = _check_fft_args(invecs[0], outvecs[0]) + + def execute(self): + batch_fft(self.invecs, self.outvecs, self.prec, self.itype, self.otype) + +class BatchIFFT(_BaseIFFT): + """Class for performing batched IFFTs via the cupy interface""" + def __init__(self, invecs, outvecs, batch_size): + self.invecs = invecs + self.outvecs = outvecs + self.batch_size = batch_size + self.prec, self.itype, self.otype = _check_fft_args(Array(invecs[0]), Array(outvecs[0])) + + def execute(self): + batch_ifft(self.invecs, self.outvecs, self.prec, self.itype, self.otype) \ No newline at end of file diff --git a/pycbc/filter/inspiral_utils.py b/pycbc/filter/inspiral_utils.py new file mode 100644 index 00000000000..98121837237 --- /dev/null +++ b/pycbc/filter/inspiral_utils.py @@ -0,0 +1,753 @@ +"""GPU-optimized utilities for inspiral search + +This module contains GPU-specific optimized implementations for inspiral +template matching, designed to work directly with CuPy arrays and batched +operations. 
+""" + +import cupy as cp +import numpy as np +from pycbc.types import Array, FrequencySeries, float32, complex64 +from pycbc import DYN_RANGE_FAC +import time + +# Global timing accumulators +_profile_times = { + 'template_generation': 0.0, + 'sigmasq_computation': 0.0, + 'matched_filter': 0.0, + 'snr_normalization': 0.0, + 'threshold_cluster': 0.0, + 'chisq_computation': 0.0, + 'sg_chisq': 0.0, + 'newsnr_cut': 0.0, + 'output_preparation': 0.0, + 'total_function': 0.0, + 'gpu_sync': 0.0 +} +_profile_counts = { + 'batches_processed': 0, + 'segments_processed': 0, + 'templates_processed': 0, + 'triggers_found': 0 +} + +def reset_profile(): + """Reset profiling counters""" + global _profile_times, _profile_counts + for key in _profile_times: + _profile_times[key] = 0.0 + for key in _profile_counts: + _profile_counts[key] = 0 + +def print_profile(): + """Print detailed profiling information""" + import logging + logger = logging.getLogger('py.pycbc') + + total = _profile_times['total_function'] + if total == 0: + logger.info("No profiling data collected") + return + + logger.info("=" * 80) + logger.info("GPU BATCHED INSPIRAL DETAILED PROFILE") + logger.info("=" * 80) + logger.info(f"Batches processed: {_profile_counts['batches_processed']}") + logger.info(f"Segments processed: {_profile_counts['segments_processed']}") + logger.info(f"Templates processed: {_profile_counts['templates_processed']}") + logger.info(f"Triggers found: {_profile_counts['triggers_found']}") + logger.info("") + logger.info(f"{'Operation':<30} {'Time (s)':<12} {'% Total':<10} {'Avg/call (ms)':<15}") + logger.info("-" * 80) + + items = [ + ('Template Generation', 'template_generation', _profile_counts['batches_processed']), + ('Sigmasq Computation', 'sigmasq_computation', _profile_counts['segments_processed']), + ('Matched Filtering', 'matched_filter', _profile_counts['segments_processed']), + ('SNR Normalization', 'snr_normalization', _profile_counts['segments_processed']), + ('Threshold & 
Cluster', 'threshold_cluster', _profile_counts['segments_processed']), + ('Chi-squared', 'chisq_computation', _profile_counts['segments_processed']), + ('SG Chi-squared', 'sg_chisq', _profile_counts['segments_processed']), + ('NewSNR Cut', 'newsnr_cut', _profile_counts['segments_processed']), + ('Output Preparation', 'output_preparation', _profile_counts['segments_processed']), + ('GPU Synchronization', 'gpu_sync', _profile_counts['segments_processed']), + ] + + for name, key, count in items: + t = _profile_times[key] + pct = 100.0 * t / total if total > 0 else 0 + avg = 1000.0 * t / count if count > 0 else 0 + logger.info(f"{name:<30} {t:<12.4f} {pct:<10.2f} {avg:<15.2f}") + + logger.info("-" * 80) + logger.info(f"{'TOTAL':<30} {total:<12.4f} {100.0:<10.2f}") + logger.info("=" * 80) + + +def template_triggers_gpu_batched(t_nums, bank, segments, matched_filter, + power_chisq, sg_chisq, inj_filter_rejector, + cluster_window, out_vals_ref, opt, psd_var=None): + """GPU-optimized batched template processing for CUPY scheme + + This function is specifically designed for GPU processing with batched + templates. It works directly with CuPy arrays and calls GPU kernels + without unnecessary PyCBC Array wrapper overhead. 
+ + Assumptions: + - Using CUPYScheme + - Processing multiple templates in batch (len(t_nums) > 1) + - All templates are SPAtmplt approximant + - batch_size evenly divides number of templates + + Parameters + ---------- + t_nums : list of int + Template indices to process in this batch + bank : FilterBank + Template bank + segments : list + List of frequency-domain data segments + matched_filter : MatchedFilterControl + Matched filter engine + power_chisq : SingleDetPowerChisq + Power chi-squared veto + sg_chisq : SingleDetSGChisq + Sine-Gaussian chi-squared veto + inj_filter_rejector : InjFilterRejector + Injection filter rejector + cluster_window : int + Clustering window in samples + out_vals_ref : dict + Reference dictionary for output values + opt : Namespace + Command-line options + psd_var : array-like, optional + PSD variation data + + Returns + ------- + out_vals : dict + Dictionary with trigger data arrays: + - 'template_id': template indices (from t_nums) + - 'time_index': time indices + - 'snr': SNR values + - 'chisq': chi-squared values + - 'chisq_dof': chi-squared degrees of freedom + - 'sigmasq': sigmasq values + tparams : list of dict + Template parameters for each template + """ + import copy + + t_start_func = time.time() + _profile_counts['batches_processed'] += 1 + _profile_counts['templates_processed'] += len(t_nums) + + num_templates = len(t_nums) + tparams = [] + + # Accumulators for all triggers across all segments + all_template_ids = [] + all_time_indices = [] + all_snrs = [] + all_chisqs = [] + all_chisq_dofs = [] + all_sigmasqs = [] + + # Get template data as 2D cupy array (batch_size x freq_length) + # This calls directly into batched template generation, bypassing PyCBC array overhead + t0 = time.time() + htilde_batch, templates, kmin_array, kmax_array = _generate_templates_gpu(t_nums, bank) + cp.cuda.Stream.null.synchronize() + _profile_times['template_generation'] += time.time() - t0 + + tparams = [bank.table[i] for i in t_nums] + + 
# Get filter parameters for each template (kept for compatibility, but kmin/kmax are better) + template_flow = cp.array([t.f_lower for t in templates], dtype=cp.float32) + + # Process each segment + for s_num, stilde in enumerate(segments): + _profile_counts['segments_processed'] += 1 + # Skip if any template in batch should be rejected + if not all(inj_filter_rejector.template_segment_checker(bank, t_num, stilde) + for t_num in t_nums): + continue + + # Get segment data as cupy array + stilde_data = stilde._data # Underlying cupy array + psd_data = stilde.psd._data # Underlying cupy array + + # Compute sigmasq for all templates + t0 = time.time() + sigmasqs = _compute_sigmasqs_gpu(htilde_batch, psd_data, kmin_array, kmax_array, stilde.delta_f) + cp.cuda.Stream.null.synchronize() + _profile_times['sigmasq_computation'] += time.time() - t0 + + # Batched matched filtering - direct kernel call + # Note: stilde is already overwhitened (divided by PSD) + t0 = time.time() + snr_batch, corr_batch = _matched_filter_gpu(htilde_batch, stilde_data, + matched_filter.kmin, + matched_filter.kmax) + cp.cuda.Stream.null.synchronize() + _profile_times['matched_filter'] += time.time() - t0 + + # The SNR normalization for overwhitened data + # SNR = IFFT(conj(h) * s) / sqrt(sigmasq / (4 * delta_f)) + # Simplifying: SNR = IFFT_result * sqrt(4 * delta_f) / sqrt(sigmasq) + # But let's match what PyCBC does: just divide by sqrt(sigmasq) + t0 = time.time() + snr_batch = snr_batch / cp.sqrt(sigmasqs[:, cp.newaxis]) + cp.cuda.Stream.null.synchronize() + _profile_times['snr_normalization'] += time.time() - t0 + + # Threshold and cluster - operates on 2D arrays + t0 = time.time() + trigger_indices, trigger_snrs = _threshold_and_cluster_gpu( + snr_batch, opt.snr_threshold, cluster_window, stilde.analyze) + cp.cuda.Stream.null.synchronize() + _profile_times['threshold_cluster'] += time.time() - t0 + + # Compute chi-squared for triggers + t0 = time.time() + chisqs_batch, chisq_dofs_batch = 
_compute_chisq_gpu( + corr_batch, trigger_indices, trigger_snrs, sigmasqs, + stilde.psd, power_chisq, templates, t_nums) + cp.cuda.Stream.null.synchronize() + _profile_times['chisq_computation'] += time.time() - t0 + + # Extract results for each template + t0 = time.time() + + # Early exit if no triggers + n_triggers_total = len(trigger_indices) + if n_triggers_total == 0: + _profile_times['output_preparation'] += time.time() - t0 + continue + + # Move constant calculations outside the loop + time_offset = stilde.cumulative_index - stilde.analyze.start + + # Apply time offset to all triggers at once (before splitting by template) + trigger_indices[:, 1] += time_offset + + # Get template ID for each trigger (this is the index within the batch, 0 to num_templates-1) + template_batch_ids = trigger_indices[:, 0] + + # Total triggers for profiling + _profile_counts['triggers_found'] += n_triggers_total + + # Map batch indices to actual template IDs from t_nums + # template_batch_ids are indices 0, 1, 2, ... 
num_templates-1 + # We need to convert them to actual template indices from t_nums + template_ids_actual = cp.array([t_nums[i] for i in template_batch_ids.get()], dtype=cp.int32) + + # Collect arrays for this segment + time_indices = trigger_indices[:, 1] + + # Get sigmasq for each trigger + sigmasq_vals = sigmasqs[template_batch_ids] + + # Append to accumulators + all_template_ids.append(template_ids_actual) + all_time_indices.append(time_indices) + all_snrs.append(trigger_snrs) + all_chisqs.append(chisqs_batch) + all_chisq_dofs.append(chisq_dofs_batch) + all_sigmasqs.append(sigmasq_vals) + + _profile_times['output_preparation'] += time.time() - t0 + + # Concatenate all results + if len(all_template_ids) > 0: + out_vals = { + 'template_id': Array(cp.concatenate(all_template_ids), copy=False), + 'time_index': Array(cp.concatenate(all_time_indices), copy=False), + 'snr': Array(cp.concatenate(all_snrs), copy=False), + 'chisq': Array(cp.concatenate(all_chisqs), copy=False), + 'chisq_dof': Array(cp.concatenate(all_chisq_dofs), copy=False), + 'sigmasq': Array(cp.concatenate(all_sigmasqs), copy=False) + } + else: + # No triggers found + out_vals = { + 'template_id': Array(cp.array([], dtype=cp.int32), copy=False), + 'time_index': Array(cp.array([], dtype=cp.int32), copy=False), + 'snr': Array(cp.array([], dtype=cp.complex64), copy=False), + 'chisq': Array(cp.array([], dtype=cp.float32), copy=False), + 'chisq_dof': Array(cp.array([], dtype=cp.int32), copy=False), + 'sigmasq': Array(cp.array([], dtype=cp.float32), copy=False) + } + + _profile_times['total_function'] += time.time() - t_start_func + return out_vals, tparams + + +def _generate_templates_gpu(t_nums, bank): + """Generate batch of templates directly as 2D cupy array + + Extracted from bank.__getitem__ to avoid PyCBC array overhead. + Calls spa_tmplt_batch directly. 
+ + Parameters + ---------- + t_nums : list of int + Template indices + bank : FilterBank + Template bank + + Returns + ------- + htilde_batch : cupy.ndarray + 2D array of templates (num_templates x freq_length) + templates : list + List of FrequencySeries template objects + """ + from pycbc.waveform.spa_tmplt import spa_tmplt_batch + from pycbc.waveform.bank import find_variable_start_frequency + from pycbc.types import zeros + import logging + import types + + # Prepare common parameters + distance = 1.0 / DYN_RANGE_FAC + common_kwds = { + 'delta_f': bank.delta_f, + 'f_lower': bank.f_lower, + 'distance': distance, + **bank.extra_args + } + + # Prepare per-template parameters + templates_params = [] + f_end_list = [] + f_low_list = [] + + for idx in t_nums: + f_end = bank.end_frequency(idx) + if f_end is None or f_end >= (bank.filter_length * bank.delta_f): + f_end = (bank.filter_length - 1) * bank.delta_f + + f_low = find_variable_start_frequency('SPAtmplt', + bank.table[idx], + bank.f_lower, + bank.max_template_length) + + params = { + 'mass1': bank.table[idx].mass1, + 'mass2': bank.table[idx].mass2, + 'spin1z': bank.table[idx].spin1z, + 'spin2z': bank.table[idx].spin2z, + 'distance': distance + } + + templates_params.append(params) + f_end_list.append(f_end) + f_low_list.append(f_low) + + # Log batch generation + logging.info('Generating templates %s-%s (%d templates) in batch from %s Hz' % + (t_nums[0], t_nums[-1], len(t_nums), min(f_low_list))) + + # Generate all templates in one batch + htilde_batch, htilde_list, kmin_array, kmax_array = spa_tmplt_batch(templates_params, bank.filter_length, **common_kwds) + + # Process each template and extract data + for i, (idx, htilde) in enumerate(zip(t_nums, htilde_list)): + template_duration = htilde.chirp_length if hasattr(htilde, 'chirp_length') else None + ttotal = htilde.length_in_time if hasattr(htilde, 'length_in_time') else None + + bank.table[idx].template_duration = template_duration + + htilde = 
htilde.astype(bank.dtype) + htilde.f_lower = f_low_list[i] + htilde.min_f_lower = bank.min_f_lower + htilde.end_idx = int(f_end_list[i] / htilde.delta_f) + htilde.params = bank.table[idx] + htilde.chirp_length = template_duration + htilde.length_in_time = ttotal + htilde.approximant = 'SPAtmplt' + htilde.end_frequency = f_end_list[i] + + # Update the list with processed template + htilde_list[i] = htilde + + # Extract underlying cupy data into 2D array + + return htilde_batch, htilde_list, kmin_array, kmax_array + + +def _compute_sigmasqs_gpu(htilde_batch, psd_data, kmin_array, kmax_array, delta_f): + """Compute sigmasq for batch of templates using fully batched GPU computation + + Computes: sigmasq[i] = 4 * delta_f * sum(|h[i,k]|^2 / S[k]) for k in [kmin[i], kmax[i]) + + Uses broadcasting to create a mask and compute all sigmasqs in parallel. + + Parameters + ---------- + htilde_batch : cupy.ndarray + 2D array of templates (num_templates x freq_length) + psd_data : cupy.ndarray + PSD array (freq_length,) + kmin_array : cupy.ndarray + Start frequency index for each template (num_templates,) + kmax_array : cupy.ndarray + End frequency index for each template (num_templates,) + delta_f : float + Frequency spacing + + Returns + ------- + cupy.ndarray + Array of sigmasq values for each template (num_templates,) + """ + num_templates = htilde_batch.shape[0] + freq_length = min(htilde_batch.shape[1], len(psd_data)) + + # Compute |h|^2 / PSD for all templates and frequencies at once + # Shape: (num_templates, freq_length) + power_spectrum = (htilde_batch[:, :freq_length].real**2 + htilde_batch[:, :freq_length].imag**2) / psd_data[:freq_length] + + # Create frequency index array: shape (freq_length,) + k_indices = cp.arange(freq_length, dtype=cp.int64) + + # Broadcast to create mask: shape (num_templates, freq_length) + # mask[i, k] = True if kmin[i] <= k < kmax[i] + kmin_broadcast = kmin_array[:, cp.newaxis] # shape: (num_templates, 1) + kmax_broadcast = kmax_array[:, 
cp.newaxis]  # shape: (num_templates, 1)
+    mask = (k_indices >= kmin_broadcast) & (k_indices < kmax_broadcast)  # shape: (num_templates, freq_length)
+
+    # Apply mask and sum across frequency axis
+    # This computes the sum for all templates in parallel
+    sigmasqs = cp.sum(power_spectrum * mask, axis=1)  # shape: (num_templates,)
+
+    # Apply the 4 * delta_f factor
+    sigmasqs *= 4.0 * delta_f
+
+    return sigmasqs
+
+
+def _matched_filter_gpu(htilde_batch, stilde_data, kmin, kmax):
+    """Batched matched filtering on GPU
+
+    Parameters
+    ----------
+    htilde_batch : cupy.ndarray
+        2D array of templates (num_templates x freq_length)
+    stilde_data : cupy.ndarray
+        Strain data in frequency domain
+    kmin : int
+        Minimum frequency index
+    kmax : int
+        Maximum frequency index
+
+    Returns
+    -------
+    snr_batch : cupy.ndarray
+        2D array of SNR time series (num_templates x time_length)
+    corr_batch : cupy.ndarray
+        2D array of correlation in frequency domain (num_templates x freq_length)
+    """
+    num_templates = htilde_batch.shape[0]
+    freq_length = htilde_batch.shape[1]
+
+    # Allocate output arrays
+    corr_batch = cp.zeros((num_templates, (freq_length-1)*2), dtype=cp.complex64)
+
+    # Ensure stilde_data length matches
+    stilde_len = min(len(stilde_data), freq_length)  # NOTE(review): unused — either apply it to the slice below or remove it
+
+    # Batched correlation: corr = conj(h) * s
+    corr_batch[:, kmin:kmax] = cp.conj(htilde_batch[:, kmin:kmax]) * stilde_data[kmin:kmax]
+
+    # Batched IFFT to get SNR time series
+    snr_batch = cp.fft.ifft(corr_batch, axis=1)
+    delta_f = 1./256  # FIXME(review): hard-coded frequency spacing; derive from the segment's actual delta_f
+    norm = (4.0 * delta_f)
+    snr_batch *= norm * 524288  # FIXME(review): 524288 (=2**19) is a hard-coded FFT length; use (freq_length-1)*2
+
+    return snr_batch, corr_batch
+
+
+def _normalize_snr_gpu(snr_batch, sigmasqs):
+    """Normalize SNR by sigmasq
+
+    Parameters
+    ----------
+    snr_batch : cupy.ndarray
+        2D array of SNR time series
+    sigmasqs : cupy.ndarray
+        Array of sigmasq values
+
+    Returns
+    -------
+    cupy.ndarray
+        Normalized SNR time series
+    """
+    # Reshape sigmasqs for broadcasting
+    norm = cp.sqrt(sigmasqs[:, cp.newaxis])
+    return snr_batch / norm
+
+ 
+def _threshold_and_cluster_gpu(snr_batch, threshold, window, analyze_segment): + """Threshold and cluster SNR time series using CUDA kernels + + Parameters + ---------- + snr_batch : cupy.ndarray + 2D array of SNR time series (num_templates x time_length) + threshold : float or cupy.ndarray + SNR threshold (can be array with one value per template) + window : int + Clustering window in samples + analyze_segment : slice + Segment to analyze + + Returns + ------- + trigger_indices : cupy.ndarray + 2D array of trigger indices (num_triggers x 2) [template_idx, time_idx] + trigger_snrs : cupy.ndarray + Array of trigger SNR values + """ + from pycbc.events.threshold_cupy import get_tkernel, FastFilter + + batch_size, series_length = snr_batch.shape + analyse_start = cp.int32(analyze_segment.start) + analyse_end = cp.int32(analyze_segment.stop) + analyse_len = analyse_end - analyse_start + batch_mem_size = 1024 + + # Allocate output arrays (don't use global cache to avoid shape mismatches) + outv = cp.zeros((batch_size, batch_mem_size), dtype=cp.complex64) + outl = cp.zeros((batch_size, batch_mem_size), dtype=cp.int32) + + # Convert threshold to array if needed + if not isinstance(threshold, cp.ndarray): + threshold = cp.full(batch_size, threshold, dtype=cp.float32) + + # Square the threshold for kernel + threshold_sq = threshold * threshold + threshold_sq = cp.asarray(threshold_sq, dtype=cp.float32) + window = cp.int32(window) + + # Get kernels + (fn, fn2), nt, nb = get_tkernel( + series_length, + analyse_len, + window, + block_mem_size=batch_mem_size, + batch_size=batch_size + ) + + grid = (nb, batch_size, 1) + block = (nt, 1, 1) + + # Run threshold and cluster kernels + fn(grid, block, (snr_batch, outv, outl, window, threshold_sq, series_length, analyse_start)) + fn2(grid, block, (outv, outl, threshold_sq, window)) + + # Filter results using FastFilter + fast_filter = FastFilter(snum=batch_size, nb=nb) + cv_filtered, cl_filtered, cv_out, cl_out, sizes, template_map = 
\ + fast_filter.filter_arrays(outv[:, :nb], outl[:, :nb]) + + # Convert results to format expected by rest of pipeline + num_triggers = int(sizes.sum()) + + if num_triggers == 0: + return cp.array([], dtype=cp.int32).reshape(0, 2), cp.array([], dtype=cp.complex64) + + # Time indices from kernel are relative to analyze_segment.start + # Add analyze_segment.start to make them relative to full snr_batch + time_indices = cl_out[:num_triggers] + analyze_segment.start + + # Create trigger indices array [template_idx, time_idx] + trigger_indices = cp.stack([template_map[:num_triggers], time_indices], axis=1) + trigger_snrs = cv_out[:num_triggers] + + return trigger_indices, trigger_snrs + + +def _compute_chisq_gpu(corr_batch, trigger_indices, trigger_snrs, sigmasqs, + psd, power_chisq, templates, t_nums): + """Compute chi-squared for triggers using batched GPU computation + + Parameters + ---------- + corr_batch : cupy.ndarray + 2D array of correlations (num_templates x freq_length) + trigger_indices : cupy.ndarray + Trigger indices (num_triggers x 2) [template_idx, time_idx] + trigger_snrs : cupy.ndarray + Trigger SNR values (complex, unnormalized) + sigmasqs : cupy.ndarray + Template sigmasq values + psd : FrequencySeries + Power spectral density + power_chisq : SingleDetPowerChisq + Power chisq calculator + templates : list + List of template waveforms + t_nums : list + Template indices in the batch + + Returns + ------- + chisqs : cupy.ndarray + Chi-squared values + chisq_dofs : cupy.ndarray + Chi-squared degrees of freedom + """ + import logging + logger = logging.getLogger('py.pycbc') + + t_total_start = time.time() + + num_triggers = len(trigger_indices) + + if num_triggers == 0: + return cp.array([], dtype=cp.float32), cp.array([], dtype=cp.int32) + + if not power_chisq.do: + # Chi-squared disabled, return dummy values + template_idxs = trigger_indices[:, 0] + chisq_dofs = sigmasqs[template_idxs] * 0 + 10 + chisqs = chisq_dofs.astype(cp.float32) + return chisqs, 
chisq_dofs.astype(cp.int32)
+
+    t_setup_start = time.time()
+    # Extract data needed for values_batch
+    template_map = trigger_indices[:, 0]  # Which template each trigger belongs to
+    time_indices = trigger_indices[:, 1]  # Time index for each trigger
+
+    # The chisq computation in values_batch expects:
+    # - snrvs: unnormalized complex SNR values (from IFFT, no 4*df*N factor)
+    # - snr_norms: normalization factor such that abs(snrvs * snr_norms) gives true SNR
+    #
+    # Our pipeline:
+    # - matched filter outputs: (4*df*N) * IFFT(conj(h)*s)
+    # - normalized SNR: (4*df*N) * IFFT / sqrt(sigmasq)
+    # - trigger_snrs contain: (4*df*N) * IFFT / sqrt(sigmasq)
+    #
+    # For chisq, we need:
+    # - unnormalized_snrs = IFFT (no scaling)
+    # - snr_norms = (4*df*N) / sqrt(sigmasq)
+    #
+    # So: unnormalized_snrs * snr_norms = IFFT * (4*df*N) / sqrt(sigmasq) = trigger_snrs ✓
+
+    delta_f = psd.delta_f
+    norm_factor = 4.0 * delta_f
+
+    # Un-normalize to get raw IFFT values
+    unnormalized_snrs = trigger_snrs * cp.sqrt(sigmasqs[template_map, cp.newaxis]).squeeze() / norm_factor
+
+    # Normalization factor: norm_factor / sqrt(sigmasq), i.e. (4*delta_f) / sqrt(sigmasq)
+    snr_norms = norm_factor / cp.sqrt(sigmasqs)
+
+    # Count triggers per template
+    num_templates = len(templates)
+    sizes = cp.array([cp.sum(template_map == i) for i in range(num_templates)], dtype=cp.int32)  # NOTE(review): one cp.sum kernel per template; cp.bincount(template_map, minlength=num_templates) would do this in one pass
+
+    t_setup = time.time() - t_setup_start
+
+    t_template_prep_start = time.time()
+    # Call the batched chisq computation
+    # Note: corr_batch has full FFT length (power of 2), which shift_sum_batch requires
+    # The chisq code will only use the relevant frequency bins anyway
+
+    # Ensure template lengths agree with the PSD (RFFT) length before the chisq call
+    if len(templates) > 0:
+        psd_len = len(psd)
+
+        # Fix template lengths to match PSD (RFFT length)
+        # Templates have full FFT length but chisq code expects RFFT length
+        for template in templates:
+            if len(template) != psd_len:
+                # Resize the template data to RFFT length
+                template.resize(psd_len)
+
+    t_template_prep = time.time() - 
t_template_prep_start + + t_values_batch_start = time.time() + chisq_list, chisq_dof_list = power_chisq.values_batch( + corr_batch, unnormalized_snrs, snr_norms, sizes, template_map, psd, time_indices, templates + ) + cp.cuda.Stream.null.synchronize() + t_values_batch = time.time() - t_values_batch_start + + + t_postprocess_start = time.time() + # Check if chisq was actually computed (or disabled) + if chisq_list is None: + # Chi-squared was disabled, return dummy values + chisq_dofs = cp.full(num_triggers, 10, dtype=cp.int32) + chisqs = chisq_dofs.astype(cp.float32) + return chisqs, chisq_dofs + + # Concatenate results back into single arrays + chisqs = cp.concatenate(chisq_list) if len(chisq_list) > 0 else cp.array([], dtype=cp.float32) + chisq_dofs = cp.concatenate(chisq_dof_list) if len(chisq_dof_list) > 0 else cp.array([], dtype=cp.int32) + t_postprocess = time.time() - t_postprocess_start + + t_total = time.time() - t_total_start + + logger.info(f" Chisq breakdown - setup: {t_setup:.4f}s, template_prep: {t_template_prep:.4f}s, " + f"values_batch: {t_values_batch:.4f}s, postprocess: {t_postprocess:.4f}s, " + f"total: {t_total:.4f}s, triggers: {num_triggers}") + + return chisqs, chisq_dofs + + +def _compute_sigmasqs_batch(templates, psd): + """Compute sigmasq for a batch of SPAtmplt templates efficiently + + Parameters + ---------- + templates : list + List of template waveforms + psd : FrequencySeries + Power spectral density + + Returns + ------- + list + List of sigmasq values for each template + """ + import pycbc.waveform + from pycbc import DYN_RANGE_FAC + + # Ensure the sigmasq_vec is computed for SPAtmplt + if not hasattr(psd, 'sigmasq_vec'): + psd.sigmasq_vec = {} + + if 'SPAtmplt' not in psd.sigmasq_vec: + psd.sigmasq_vec['SPAtmplt'] = \ + pycbc.waveform.get_waveform_filter_norm( + 'SPAtmplt', + psd, + len(psd), + psd.delta_f, + templates[0].min_f_lower + ) + + curr_sigmasq = psd.sigmasq_vec['SPAtmplt'] + sigmasqs = [] + + for template in templates: + 
# Compute sigma_scale if not already done + if not hasattr(template, 'sigma_scale'): + amp_norm = pycbc.waveform.get_template_amplitude_norm( + template.params, approximant='SPAtmplt') + amp_norm = 1 if amp_norm is None else amp_norm + template.sigma_scale = (DYN_RANGE_FAC * amp_norm) ** 2.0 + + kmin = int(template.f_lower / psd.delta_f) + sigmasq = template.sigma_scale * \ + (curr_sigmasq[template.end_idx-1] - curr_sigmasq[kmin]) + sigmasqs.append(sigmasq) + + return sigmasqs diff --git a/pycbc/filter/matchedfilter.py b/pycbc/filter/matchedfilter.py index e974ab2eb39..e1444334f7b 100644 --- a/pycbc/filter/matchedfilter.py +++ b/pycbc/filter/matchedfilter.py @@ -29,6 +29,7 @@ import logging from math import sqrt import numpy +import cupy from pycbc.types import TimeSeries, FrequencySeries, zeros, Array from pycbc.types import complex_same_precision_as, real_same_precision_as @@ -37,6 +38,7 @@ from pycbc import events from pycbc.events import ranking import pycbc +import time logger = logging.getLogger('pycbc.filter.matchedfilter') @@ -84,6 +86,12 @@ def _correlate_factory(x, y, z): raise ValueError(err_msg) +@pycbc.scheme.schemed(BACKEND_PREFIX) +def _batch_correlate_factory(x, y, z, batch_size, kmin, kmax): + err_msg = "This class is a stub that should be overridden using the " + err_msg += "scheme. You shouldn't be seeing this error!" + raise ValueError(err_msg) + class Correlator(object): """ Create a correlator engine @@ -104,6 +112,26 @@ def __new__(cls, *args, **kwargs): real_cls = _correlate_factory(*args, **kwargs) return real_cls(*args, **kwargs) # pylint:disable=not-callable +class CupyBatchCorrelator(object): + """ Create a correlator engine + + Parameters + --------- + x : complex64 + Input pycbc.types.Array (or subclass); it will be conjugated + y : complex64 + Input pycbc.types.Array (or subclass); it will not be conjugated + z : complex64 + Output pycbc.types.Array (or subclass). 
+ It will contain conj(x) * y, element by element + + The addresses in memory of the data of all three parameter vectors + must be the same modulo pycbc.PYCBC_ALIGNMENT + """ + def __new__(cls, *args, **kwargs): + real_cls = _batch_correlate_factory(*args, **kwargs) + return real_cls(*args, **kwargs) # pylint:disable=not-callable + # The class below should serve as the parent for all schemed classes. # The intention is that this class serves simply as the location for @@ -130,7 +158,8 @@ class MatchedFilterControl(object): def __init__(self, low_frequency_cutoff, high_frequency_cutoff, snr_threshold, tlen, delta_f, dtype, segment_list, template_output, use_cluster, downsample_factor=1, upsample_threshold=1, upsample_method='pruned_fft', - gpu_callback_method='none', cluster_function='symmetric'): + gpu_callback_method='none', cluster_function='symmetric', + batch_size=32): """ Create a matched filter engine. Parameters @@ -163,8 +192,7 @@ def __init__(self, low_frequency_cutoff, high_frequency_cutoff, snr_threshold, t to the windows before and after it, and only kept as a trigger if larger than both. """ - # Assuming analysis time is constant across templates and segments, also - # delta_f is constant across segments. 
+ # Store all the input parameters self.tlen = tlen self.flen = self.tlen / 2 + 1 self.delta_f = delta_f @@ -174,45 +202,20 @@ def __init__(self, low_frequency_cutoff, high_frequency_cutoff, snr_threshold, t self.flow = low_frequency_cutoff self.fhigh = high_frequency_cutoff self.gpu_callback_method = gpu_callback_method + self.batch_size = batch_size + if cluster_function not in ['symmetric', 'findchirp']: raise ValueError("MatchedFilter: 'cluster_function' must be either 'symmetric' or 'findchirp'") self.cluster_function = cluster_function self.segments = segment_list self.htilde = template_output - if downsample_factor == 1: - self.snr_mem = zeros(self.tlen, dtype=self.dtype) - self.corr_mem = zeros(self.tlen, dtype=self.dtype) - - if use_cluster and (cluster_function == 'symmetric'): - self.matched_filter_and_cluster = self.full_matched_filter_and_cluster_symm - # setup the threasholding/clustering operations for each segment - self.threshold_and_clusterers = [] - for seg in self.segments: - thresh = events.ThresholdCluster(self.snr_mem[seg.analyze]) - self.threshold_and_clusterers.append(thresh) - elif use_cluster and (cluster_function == 'findchirp'): - self.matched_filter_and_cluster = self.full_matched_filter_and_cluster_fc - else: - self.matched_filter_and_cluster = self.full_matched_filter_thresh_only - - # Assuming analysis time is constant across templates and segments, also - # delta_f is constant across segments. 
- self.kmin, self.kmax = get_cutoff_indices(self.flow, self.fhigh, - self.delta_f, self.tlen) - - # Set up the correlation operations for each analysis segment - corr_slice = slice(self.kmin, self.kmax) - self.correlators = [] - for seg in self.segments: - corr = Correlator(self.htilde[corr_slice], - seg[corr_slice], - self.corr_mem[corr_slice]) - self.correlators.append(corr) - - # setup up the ifft we will do - self.ifft = IFFT(self.corr_mem, self.snr_mem) + # Detect if we're using CUPY scheme + import pycbc.scheme + self.using_cupy = isinstance(pycbc.scheme.mgr.state, pycbc.scheme.CUPYScheme) + if downsample_factor == 1: + self.setup_filtering(use_cluster) elif downsample_factor >= 1: self.matched_filter_and_cluster = self.hierarchical_matched_filter_and_cluster self.downsample_factor = downsample_factor @@ -240,6 +243,91 @@ def __init__(self, low_frequency_cutoff, high_frequency_cutoff, snr_threshold, t else: raise ValueError("Invalid downsample factor") + def setup_filtering(self, use_cluster): + """Set up the matched filtering based on scheme""" + # Initialize memory for SNR and correlation + if self.using_cupy: + import cupy + # For CUPY, allocate memory for batch processing + self.snr_mem = cupy.zeros((self.batch_size, self.tlen), dtype=self.dtype) + self.corr_mem = cupy.zeros((self.batch_size, self.tlen), dtype=self.dtype) + else: + # For other schemes, single template processing + self.snr_mem = zeros(self.tlen, dtype=self.dtype) + self.corr_mem = zeros(self.tlen, dtype=self.dtype) + + # Set up the matched filter function based on clustering options + if use_cluster and (self.cluster_function == 'symmetric'): + self.matched_filter_and_cluster = self.full_matched_filter_and_cluster_symm + if self.using_cupy: + self.setup_cupy_clustering() + self.matched_filter_and_cluster = self.full_matched_filter_and_cluster_symm_batched + else: + self.setup_standard_clustering() + elif use_cluster and (self.cluster_function == 'findchirp'): + 
self.matched_filter_and_cluster = self.full_matched_filter_and_cluster_fc + else: + self.matched_filter_and_cluster = self.full_matched_filter_thresh_only + + # Set up frequency cutoff indices + self.kmin, self.kmax = get_cutoff_indices(self.flow, self.fhigh, + self.delta_f, self.tlen) + + # Set up correlators + if self.using_cupy: + self.setup_cupy_correlators() + else: + self.setup_standard_correlators() + + def setup_cupy_clustering(self): + """Set up clustering for CUPY scheme""" + self.threshold_and_clusterers = [] + for seg in self.segments: + # Need to modify threshold clustering for batch operations + thresh = events.ThresholdCluster(self.snr_mem, seg.analyze) + self.threshold_and_clusterers.append(thresh) + + def setup_standard_clustering(self): + """Set up clustering for standard schemes""" + self.threshold_and_clusterers = [] + for seg in self.segments: + thresh = events.ThresholdCluster(self.snr_mem[seg.analyze]) + self.threshold_and_clusterers.append(thresh) + + def setup_cupy_correlators(self): + """Set up batch correlators for CUPY scheme""" + corr_slice = slice(self.kmin, self.kmax) + self.correlators = [] + htilde_views = [htilde_t.data for htilde_t in self.htilde] + for seg in self.segments: + # Create batch correlator that can handle multiple templates + corr = CupyBatchCorrelator( + htilde_views, + seg, + self.corr_mem, + self.batch_size, + self.kmin, + self.kmax, + ) + self.correlators.append(corr) + + # Set up batch IFFT operations + from pycbc.fft.cupyfft import BatchIFFT + self.ifft = BatchIFFT(self.corr_mem, self.snr_mem, self.batch_size) + + def setup_standard_correlators(self): + """Set up standard correlators for non-CUPY schemes""" + corr_slice = slice(self.kmin, self.kmax) + self.correlators = [] + for seg in self.segments: + corr = Correlator(self.htilde[corr_slice], + seg[corr_slice], + self.corr_mem[corr_slice]) + self.correlators.append(corr) + + # setup up the ifft we will do + self.ifft = IFFT(self.corr_mem, self.snr_mem) + def 
full_matched_filter_and_cluster_symm(self, segnum, template_norm, window, epoch=None): """ Returns the complex snr timeseries, normalization of the complex snr, the correlation vector frequency series, the list of indices of the @@ -277,103 +365,229 @@ def full_matched_filter_and_cluster_symm(self, segnum, template_norm, window, ep snrv, idx = self.threshold_and_clusterers[segnum].threshold_and_cluster(self.snr_threshold / norm, window) if len(idx) == 0: + print("NO TRIGGERS") return [], [], [], [], [] - logger.info("%d points above threshold", len(idx)) + logging.info("%s points above threshold" % str(len(idx))) snr = TimeSeries(self.snr_mem, epoch=epoch, delta_t=self.delta_t, copy=False) corr = FrequencySeries(self.corr_mem, delta_f=self.delta_f, copy=False) return snr, norm, corr, idx, snrv - def full_matched_filter_and_cluster_fc(self, segnum, template_norm, window, epoch=None): + def full_matched_filter_and_cluster_symm_batched(self, segnum, template_norms, window, epoch=None): """ Returns the complex snr timeseries, normalization of the complex snr, the correlation vector frequency series, the list of indices of the - triggers, and the snr values at the trigger locations. Returns empty - lists for these for points that are not above the threshold. - - Calculated the matched filter, threshold, and cluster. + triggers, and the snr values at the trigger locations for a batch of templates. Parameters ---------- segnum : int Index into the list of segments at MatchedFilterControl construction against which to filter. - template_norm : float - The htilde, template normalization factor. + template_norms : float or list + The htilde, template normalization factor, for each template in batch. window : int Size of the window over which to cluster triggers, in samples Returns ------- - snr : TimeSeries - A time series containing the complex snr. - norm : float - The normalization of the complex snr. 
- correlation: FrequencySeries - A frequency series containing the correlation vector. - idx : Array - List of indices of the triggers. - snrv : Array - The snr values at the trigger locations. + snrs : list of TimeSeries + Time series containing the complex snr for each template. + norms : list of float + The normalization of the complex snr for each template. + correlations: list of FrequencySeries + Frequency series containing the correlation vectors. + idxs : list of Array + List of indices of the triggers for each template. + snrvs : list of Array + The snr values at the trigger locations for each template. """ - norm = (4.0 * self.delta_f) / sqrt(template_norm) - self.correlators[segnum].correlate() - self.ifft.execute() - idx, snrv = events.threshold(self.snr_mem[self.segments[segnum].analyze], - self.snr_threshold / norm) - idx, snrv = events.cluster_reduce(idx, snrv, window) - - if len(idx) == 0: - return [], [], [], [], [] + # Handle both single template and batch cases + if not isinstance(template_norms, (list, numpy.ndarray)): + template_norms = [template_norms] + + if self.using_cupy: + # Batch processing on GPU + t_start = time.time() + norms = cupy.asarray([(4.0 * self.delta_f) / cupy.sqrt(template_norm) + for template_norm in template_norms]) + print("Time to calculate norms: ", time.time() - t_start) + + # Do batched correlation + t_start = time.time() + self.correlators[segnum].correlate() + print("Time to correlate: ", time.time() - t_start) + + # Do batched IFFT + t_start = time.time() + self.ifft.execute() + print("Time to IFFT: ", time.time() - t_start) + + t_start = time.time() + # Process results for each template in batch + # Threshold and cluster + batched_results, snrv_arr, idxv_arr, sizes, template_map = \ + self.threshold_and_clusterers[segnum].threshold_and_cluster( + self.snr_threshold / norms, window + ) + + # Aggregate results across batches + all_idx = [] + all_snrv = [] + for snrv_batch, idx_batch in batched_results: + if 
len(idx_batch) == 0: + all_snrv.append([]) + all_idx.append([]) + continue + all_snrv.append(snrv_batch) + all_idx.append(idx_batch) + + logger.info("%d points above threshold across all batches", sum(len(batch) for batch in all_idx)) + + # Create time series and frequency series + snr = [TimeSeries(snr, epoch=epoch, delta_t=self.delta_t, copy=False) for snr in self.snr_mem] + corr = [FrequencySeries(corr, delta_f=self.delta_f, copy=False) for corr in self.corr_mem] + + return snr, norms, corr, all_idx, all_snrv, snrv_arr, idxv_arr, sizes, template_map + + def full_matched_filter_and_cluster_fc(self, segnum, template_norms, window, epoch=None): + """FindChirp clustering version of the matched filter + + Similar to full_matched_filter_and_cluster_symm but uses findchirp clustering + """ + if not isinstance(template_norms, (list, numpy.ndarray)): + template_norms = [template_norms] + + if self.using_cupy: + norms = [(4.0 * self.delta_f) / numpy.sqrt(template_norm) + for template_norm in template_norms] + + self.correlators[segnum].correlate() + self.ifft.execute() + + snrs, cors, idxs, snrvs = [], [], [], [] + for i, norm in enumerate(norms): + # For each template in batch get triggers above threshold + idx, snrv = events.threshold( + self.snr_mem[i][self.segments[segnum].analyze], + self.snr_threshold / norm + ) + + # Apply findchirp clustering + idx, snrv = events.cluster_reduce(idx, snrv, window) + + if len(idx) == 0: + snrs.append([]) + cors.append([]) + idxs.append([]) + snrvs.append([]) + continue + + logging.info("%d points above threshold", len(idx)) + + snr = TimeSeries(self.snr_mem[i], epoch=epoch, + delta_t=self.delta_t, copy=False) + corr = FrequencySeries(self.corr_mem[i], delta_f=self.delta_f, + copy=False) + + snrs.append(snr) + cors.append(corr) + idxs.append(idx) + snrvs.append(snrv) + + return snrs, norms, cors, idxs, snrvs + + else: + # Original single template processing + norm = (4.0 * self.delta_f) / numpy.sqrt(template_norms[0]) + + 
self.correlators[segnum].correlate() + self.ifft.execute() + + idx, snrv = events.threshold( + self.snr_mem[self.segments[segnum].analyze], + self.snr_threshold / norm + ) + idx, snrv = events.cluster_reduce(idx, snrv, window) - logger.info("%d points above threshold", len(idx)) + if len(idx) == 0: + return [], [], [], [], [] - snr = TimeSeries(self.snr_mem, epoch=epoch, delta_t=self.delta_t, copy=False) - corr = FrequencySeries(self.corr_mem, delta_f=self.delta_f, copy=False) - return snr, norm, corr, idx, snrv + logging.info("%d points above threshold", len(idx)) - def full_matched_filter_thresh_only(self, segnum, template_norm, window=None, epoch=None): - """ Returns the complex snr timeseries, normalization of the complex snr, - the correlation vector frequency series, the list of indices of the - triggers, and the snr values at the trigger locations. Returns empty - lists for these for points that are not above the threshold. + snr = TimeSeries(self.snr_mem, epoch=epoch, + delta_t=self.delta_t, copy=False) + corr = FrequencySeries(self.corr_mem, delta_f=self.delta_f, + copy=False) + + return [snr], [norm], [corr], [idx], [snrv] - Calculated the matched filter, threshold, and cluster. 
+ def full_matched_filter_thresh_only(self, segnum, template_norms, window=None, epoch=None): + """Thresholding-only version of the matched filter + + Similar to above functions but only applies threshold, no clustering + """ + if not isinstance(template_norms, (list, numpy.ndarray)): + template_norms = [template_norms] + + if self.using_cupy: + norms = [(4.0 * self.delta_f) / numpy.sqrt(template_norm) + for template_norm in template_norms] + + self.correlators[segnum].correlate() + self.ifft.execute() + + snrs, cors, idxs, snrvs = [], [], [], [] + for i, norm in enumerate(norms): + idx, snrv = events.threshold_only( + self.snr_mem[i][self.segments[segnum].analyze], + self.snr_threshold / norm + ) + + if len(idx) == 0: + snrs.append([]) + cors.append([]) + idxs.append([]) + snrvs.append([]) + continue + + logging.info("%d points above threshold", len(idx)) + + snr = TimeSeries(self.snr_mem[i], epoch=epoch, + delta_t=self.delta_t, copy=False) + corr = FrequencySeries(self.corr_mem[i], delta_f=self.delta_f, + copy=False) + + snrs.append(snr) + cors.append(corr) + idxs.append(idx) + snrvs.append(snrv) + + return snrs, norms, cors, idxs, snrvs + + else: + # Original single template processing + norm = (4.0 * self.delta_f) / numpy.sqrt(template_norms[0]) + + self.correlators[segnum].correlate() + self.ifft.execute() + + idx, snrv = events.threshold_only( + self.snr_mem[self.segments[segnum].analyze], + self.snr_threshold / norm + ) - Parameters - ---------- - segnum : int - Index into the list of segments at MatchedFilterControl construction - against which to filter. - template_norm : float - The htilde, template normalization factor. - window : int - Size of the window over which to cluster triggers, in samples. - This is IGNORED by this function, and provided only for API compatibility. + if len(idx) == 0: + return [], [], [], [], [] - Returns - ------- - snr : TimeSeries - A time series containing the complex snr. 
- norm : float - The normalization of the complex snr. - correlation: FrequencySeries - A frequency series containing the correlation vector. - idx : Array - List of indices of the triggers. - snrv : Array - The snr values at the trigger locations. - """ - norm = (4.0 * self.delta_f) / sqrt(template_norm) - self.correlators[segnum].correlate() - self.ifft.execute() - idx, snrv = events.threshold_only(self.snr_mem[self.segments[segnum].analyze], - self.snr_threshold / norm) - logger.info("%d points above threshold", len(idx)) + logging.info("%d points above threshold", len(idx)) - snr = TimeSeries(self.snr_mem, epoch=epoch, delta_t=self.delta_t, copy=False) - corr = FrequencySeries(self.corr_mem, delta_f=self.delta_f, copy=False) - return snr, norm, corr, idx, snrv + snr = TimeSeries(self.snr_mem, epoch=epoch, + delta_t=self.delta_t, copy=False) + corr = FrequencySeries(self.corr_mem, delta_f=self.delta_f, + copy=False) + + return [snr], [norm], [corr], [idx], [snrv] def hierarchical_matched_filter_and_cluster(self, segnum, template_norm, window): """ Returns the complex snr timeseries, normalization of the complex snr, diff --git a/pycbc/filter/matchedfilter_cupy.py b/pycbc/filter/matchedfilter_cupy.py index 1a5a01a65fa..fe551ef7e09 100644 --- a/pycbc/filter/matchedfilter_cupy.py +++ b/pycbc/filter/matchedfilter_cupy.py @@ -51,4 +51,53 @@ def correlate(self): def _correlate_factory(x, y, z): return CUPYCorrelator +batched_correlate_kernel = cp.ElementwiseKernel( + "raw X x, raw Y y, int32 stride, int32 batch_size", + "raw Z z", + """ + int batch_idx = i / stride; // Which template in the batch + int elem_idx = i % stride; // Which element within the template + if (batch_idx < batch_size) { + z[i] = conj(x[batch_idx * stride + elem_idx]) * y[elem_idx]; + } + """, + "batched_correlate_kernel", + loop_prep="int _ind_size = stride * batch_size;" +) + +def batched_correlate(templates, data, out, batch_size): + """Parallel correlation for multiple templates""" + stride = 
len(data) + batched_correlate_kernel( + templates.data, + data.data, + stride, + batch_size, + out.data + ) + +class CUPYBatchCorrelator(_BaseCorrelator): + def __init__(self, xs, y, zs, batch_size, kmin, kmax): + # Concatenate template data into contiguous array + self.x = xs + self.y = y._data + self.z = zs + self.batch_size = batch_size + self.kmin = kmin + self.kmax = kmax + + def correlate(self): + # batched_correlate_kernel( + # cp.asarray(self.x), + # self.y, + # len(self.y), + # self.batch_size, + # self.z, + # self.kmin, + # self.kmax, + # ) + self.z[:,self.kmin:self.kmax] = cp.conj(cp.asarray(self.x)[:, self.kmin:self.kmax]) * self.y[self.kmin:self.kmax] + +def _batch_correlate_factory(xs, y, zs, batch_size, kmin, kmax): + return CUPYBatchCorrelator \ No newline at end of file diff --git a/pycbc/inference/models/__init__.py b/pycbc/inference/models/__init__.py index 7f6cba47ec8..8367b80fbda 100644 --- a/pycbc/inference/models/__init__.py +++ b/pycbc/inference/models/__init__.py @@ -38,6 +38,7 @@ from .relbin import Relative, RelativeTime, RelativeTimeDom from .hierarchical import (HierarchicalModel, MultiSignalModel, JointPrimaryMarginalizedModel) +from .lisa_pre_merger import LISAPreMergerModel # Used to manage a model instance across multiple cores or MPI @@ -208,6 +209,7 @@ def read_from_config(cp, **kwargs): HierarchicalModel, MultiSignalModel, RelativeTimeDom, + LISAPreMergerModel, JointPrimaryMarginalizedModel, )} diff --git a/pycbc/inference/models/brute_marg.py b/pycbc/inference/models/brute_marg.py index 0b419838dd6..ecbac42dc1f 100644 --- a/pycbc/inference/models/brute_marg.py +++ b/pycbc/inference/models/brute_marg.py @@ -24,6 +24,7 @@ from scipy.special import logsumexp from .gaussian_noise import BaseGaussianNoise +from .base import BaseModel from .tools import draw_sample _model = None @@ -38,6 +39,20 @@ def __call__(self, params): loglr = _model.loglr return loglr, _model.current_stats + +_model = None +class likelihood_wrapper2(object): + 
def __init__(self, model): + global _model + _model = model + + def __call__(self, params): + global _model + _model.update(**params) + loglr = _model._loglikelihood() + return loglr, _model.current_stats + + class BruteParallelGaussianMarginalize(BaseGaussianNoise): name = "brute_parallel_gaussian_marginalize" @@ -51,6 +66,8 @@ def __init__(self, variable_params, from pycbc.inference.models import models self.model = models[base_model](variable_params, **kwds) + # EW CODE SWAP + # self.call = likelihood_wrapper2(self.model) self.call = likelihood_wrapper(self.model) # size of pool for each likelihood call @@ -95,7 +112,8 @@ def _loglr(self): # calculate the marginal loglr and return return logsumexp(loglr) - numpy.log(len(self.phase)) - +# EW CODE SWAP +#class BruteLISASkyModesMarginalize(BaseModel): class BruteLISASkyModesMarginalize(BaseGaussianNoise): name = "brute_lisa_sky_modes_marginalize" @@ -140,6 +158,8 @@ def _extra_stats(self): stats = self.model._extra_stats return stats + # EW CODE SWAP + # def _loglikelihood(self): def _loglr(self): params = [] for sym_num in range(self.num_sky_modes): diff --git a/pycbc/inference/models/lisa_pre_merger.py b/pycbc/inference/models/lisa_pre_merger.py new file mode 100644 index 00000000000..a7c886aeb20 --- /dev/null +++ b/pycbc/inference/models/lisa_pre_merger.py @@ -0,0 +1,213 @@ +# Copyright (C) 2018 Collin Capano +# This program is free software; you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by the +# Free Software Foundation; either version 3 of the License, or (at your +# option) any later version. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General +# Public License for more details. 
+# +# You should have received a copy of the GNU General Public License along +# with this program; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + +""" +This modules provides models that have analytic solutions for the +log likelihood. +""" + +import copy +import logging + +import pycbc.types + +from .base import BaseModel + +import pycbc.psd + +from pycbc.waveform.pre_merger_waveform import ( + pre_process_data_lisa_pre_merger, + generate_waveform_lisa_pre_merger, +) +from pycbc.psd.lisa_pre_merger import generate_pre_merger_psds +from pycbc.waveform.waveform import parse_mode_array +from pycbc.waveform.utils import apply_fseries_time_shift +from .tools import marginalize_likelihood + + +class LISAPreMergerModel(BaseModel): + r"""Model for pre-merger inference in LISA. + + Parameters + ---------- + variable_params : (tuple of) string(s) + A tuple of parameter names that will be varied. + static_params: Dict[str: Any] + Dictionary of static parameters used for waveform generation. + psd_file : str + Path to the PSD file. Uses the same PSD file for LISA_A and LISA_E + channels. + **kwargs : + All other keyword arguments are passed to ``BaseModel``. 
+ """ + name = "lisa_pre_merger" + + def __init__( + self, + variable_params, + static_params=None, + psd_file=None, + **kwargs + ): + # Pop relevant values from kwargs + cutoff_time = int(kwargs.pop('cutoff_time')) + kernel_length = int(kwargs.pop('kernel_length')) + window_length = int(kwargs.pop('window_length')) + extra_forward_zeroes = int(kwargs.pop('extra_forward_zeroes')) + tlen = int(kwargs.pop('tlen')) + sample_rate = float(kwargs.pop('sample_rate')) + data_file = kwargs.pop('data_file') + + # set up base likelihood parameters + super().__init__(variable_params, **kwargs) + + self.static_params = parse_mode_array(static_params) + + if psd_file is None: + raise ValueError("Must specify a PSD file!") + + # Zero phase PSDs for whitening + # Only store the frequency-domain PSDs + logging.info("Generating pre-merger PSDs") + self.whitening_psds = {} + self.whitening_psds['LISA_A'] = generate_pre_merger_psds( + psd_file, + sample_rate=sample_rate, + duration=tlen, + kernel_length=kernel_length, + )["FD"] + self.whitening_psds['LISA_E'] = generate_pre_merger_psds( + psd_file, + sample_rate=sample_rate, + duration=tlen, + kernel_length=kernel_length, + )["FD"] + + # Store data for doing likelihoods. 
+ self.kernel_length = kernel_length + self.window_length = window_length + self.sample_rate = sample_rate + self.cutoff_time = cutoff_time + self.extra_forward_zeroes = extra_forward_zeroes + self.tlen = tlen + + # Load the data from the file + data = {} + for channel in ["LISA_A", "LISA_E"]: + data[channel] = pycbc.types.timeseries.load_timeseries( + data_file, + group=f"/{channel}", + ) + + # Pre-process the pre-merger data + # Returns time-domain data + # Uses UIDs: 4235, 4236 + logging.info("Pre-processing pre-merger data") + pre_merger_data = pre_process_data_lisa_pre_merger( + data, + sample_rate=sample_rate, + psds_for_whitening=self.whitening_psds, + window_length=0, + cutoff_time=self.cutoff_time, + forward_zeroes=self.kernel_length, + ) + + self.lisa_a_strain = pre_merger_data["LISA_A"] + self.lisa_e_strain = pre_merger_data["LISA_E"] + + # Frequency-domain data for computing log-likelihood + self.lisa_a_strain_fd = pycbc.strain.strain.execute_cached_fft( + self.lisa_a_strain, + copy_output=True, + uid=3223965 + ) + self.lisa_e_strain_fd = pycbc.strain.strain.execute_cached_fft( + self.lisa_e_strain, + copy_output=True, + uid=3223967 + ) + # Data epoch + self._epoch = self.lisa_a_strain_fd._epoch + + def get_waveforms(self, params): + """Generate the waveforms given the parameters. + + Note: `params` should already include the static parameters. 
+ """ + dt = params["tc"] - self._epoch + # Time between the end of the data and the time of coalescence + dt_end = params.get("cutoff_deltat", self.tlen - dt) + # Actual time between the end of the data and the cutoff time + # since cutoff time is specified relative to the merger + cutoff_time = self.cutoff_time - dt_end + # Additional zeros at the beginning of the data, these: + # - manually specified zeroes + # - kernel length zeros + # - zeros that will be wrapped around when the data is shifted + forward_zeroes = ( + self.extra_forward_zeroes + + self.kernel_length + + int(dt_end * self.sample_rate) + ) + # Generate the pre-merger waveform + # These waveforms are whitened + # Uses UIDs: 1235(0), 1236(0) + ws = generate_waveform_lisa_pre_merger( + params, + psds_for_whitening=self.whitening_psds, + window_length=self.window_length, + sample_rate=self.sample_rate, + cutoff_time=cutoff_time, + forward_zeroes=forward_zeroes, + ) + + wf = {} + # Adjust epoch to match data and shift merger to the + # correct time. + # Can safely set copy=False since ws won't be used again. 
+ for channel in ws.keys(): + wf[channel] = apply_fseries_time_shift( + ws[channel], float(dt), copy=False, + ) + wf[channel]._epoch = self._epoch + return wf + + def _loglikelihood(self): + """Compute the pre-merger log-likelihood.""" + cparams = copy.deepcopy(self.static_params) + cparams.update(self.current_params) + + # Generate the waveforms + wforms = self.get_waveforms(cparams) + + # Compute for each channel + snr_A = pycbc.filter.overlap_cplx( + wforms["LISA_A"], + self.lisa_a_strain_fd, + normalized=False, + ) + snr_E = pycbc.filter.overlap_cplx( + wforms["LISA_E"], + self.lisa_e_strain_fd, + normalized=False, + ) + # Compute for each channel + a_norm = pycbc.filter.sigmasq(wforms["LISA_A"]) + e_norm = pycbc.filter.sigmasq(wforms["LISA_E"]) + + hs = snr_A + snr_E + hh = (a_norm + e_norm) + + return marginalize_likelihood(complex(hs), float(hh), phase=False) diff --git a/pycbc/psd/lisa_pre_merger.py b/pycbc/psd/lisa_pre_merger.py new file mode 100644 index 00000000000..3d9aeca6529 --- /dev/null +++ b/pycbc/psd/lisa_pre_merger.py @@ -0,0 +1,438 @@ +import lal +from lal import LIGOTimeGPS +import math +import numpy as np +from typing import Optional, Tuple + +import pycbc.fft +import pycbc.psd +import pycbc.types + + +# The GSTLal FIR minimal phase routine +class PSDFirKernel(object): + def __init__(self): + self.revplan = None + self.fwdplan = None + self.target_phase = None + self.target_phase_mask = None + + def set_phase( + self, + psd: lal.REAL8FrequencySeries, + f_low: float = 10.0, + m1: float = 1.4, + m2: float = 1.4 + ) -> None: + """ + Compute the phase response of zero-latency whitening filter + given a reference PSD. + """ + raise NotImplementedError( + "`PSDFirKernel.set_phase` is not implemented!" + ) + kernel, latency, sample_rate = self.psd_to_linear_phase_whitening_fir_kernel(psd) + kernel, phase = self.linear_phase_fir_kernel_to_minimum_phase_whitening_fir_kernel(kernel, sample_rate) + + # get merger model for SNR = 1. 
+ f_psd = psd.f0 + np.arange(len(psd.data.data)) * psd.deltaF + horizon_distance = HorizonDistance(f_low, f_psd[-1], psd.deltaF, m1, m2) + f_model, model= horizon_distance(psd, 1.)[1] + + # find the range of frequency bins covered by the merger + # model + kmin, kmax = f_psd.searchsorted(f_model[0]), f_psd.searchsorted(f_model[-1]) + 1 + + # compute SNR=1 model's (d SNR^2 / df) spectral density + unit_snr2_density = np.zeros_like(phase) + unit_snr2_density[kmin:kmax] = model / psd.data.data[kmin:kmax] + + # integrate across each frequency bin, converting to + # snr^2/bin. NOTE: this step is here for the record, but + # is commented out because it has no effect on the result + # given the renormalization that occurs next. + #unit_snr2_density *= psd.deltaF + + # take 16th root, then normalize so max=1. why? I don't + # know, just feels good, on the whole. + unit_snr2_density = unit_snr2_density**(1./16) + unit_snr2_density /= unit_snr2_density.max() + + # record phase vector and SNR^2 density vector + self.target_phase = phase + self.target_phase_mask = unit_snr2_density + + def psd_to_linear_phase_whitening_fir_kernel( + self, + psd: lal.REAL8FrequencySeries, + invert: bool = True, + nyquist: Optional[float] = None + ) -> Tuple[np.ndarray, int, int]: + """ + Compute an acausal finite impulse-response filter kernel + from a power spectral density conforming to the LAL + normalization convention, such that if colored Gaussian + random noise with the given PSD is fed into an FIR filter + using the kernel the filter's output will be zero-mean + unit-variance Gaussian random noise. The PSD must be + provided as a lal.REAL8FrequencySeries object. + + The phase response of this filter is 0, just like whitening + done in the frequency domain. + + Parameters + ---------- + psd : lal.REAL8FrequencySeries + The reference PSD. + invert : bool + Whether to invert the kernel. Defaults to True. + nyquist : float, optional + Whether to change the Nyquist frequency. 
Disabled by default. + + Returns + ------- + numpy.ndarray + Array containing the filter kernel. + int + Filter latency in samples + int + The sample rate in Hz. + """ + # + # this could be relaxed with some work + # + + assert psd.f0 == 0.0 + + # + # extract the PSD bins and determine sample rate for kernel + # + + data = psd.data.data / 2 + sample_rate = 2 * (psd.f0 + (len(data) - 1) * psd.deltaF) + + # + # remove LAL normalization + # + + data *= sample_rate + + # + # change Nyquist frequency if requested. round to nearest + # available bin + # + + if nyquist is not None: + i = int(round((nyquist - psd.f0) / psd.deltaF)) + assert i < len(data) + data = data[:i + 1] + sample_rate = 2 * int(round(psd.f0 + (len(data) - 1) * psd.deltaF)) + + # + # compute the FIR kernel. it always has an odd number of + # samples and no DC offset. + # + + data[0] = data[-1] = 0.0 + if invert: + data_nonzeros = (data != 0.) + data[data_nonzeros] = 1./data[data_nonzeros] + # repack data: data[0], data[1], 0, data[2], 0, .... 
+ tmp = np.zeros((2 * len(data) - 1,), dtype = data.dtype) + tmp[len(data)-1:] = data + #tmp[:len(data)] = data + data = tmp + + kernel_fseries = lal.CreateCOMPLEX16FrequencySeries( + name = "double sided psd", + epoch = LIGOTimeGPS(0), + f0 = 0.0, + deltaF = psd.deltaF, + length = len(data), + sampleUnits = lal.Unit("strain s") + ) + + kernel_tseries = lal.CreateCOMPLEX16TimeSeries( + name = "timeseries of whitening kernel", + epoch = LIGOTimeGPS(0.), + f0 = 0., + deltaT = 1.0 / sample_rate, + length = len(data), + sampleUnits = lal.Unit("strain") + ) + + # FIXME check for change in length + if self.revplan is None: + self.revplan = lal.CreateReverseCOMPLEX16FFTPlan(len(data), 1) + + kernel_fseries.data.data = np.sqrt(data) + 0.j + lal.COMPLEX16FreqTimeFFT(kernel_tseries, kernel_fseries, self.revplan) + kernel = kernel_tseries.data.data.real + kernel = np.roll(kernel, (len(data) - 1) // 2) / sample_rate * 2 + + # + # apply a Tukey window whose flat bit is 50% of the kernel. + # preserve the FIR kernel's square magnitude + # + + norm_before = np.dot(kernel, kernel) + kernel *= lal.CreateTukeyREAL8Window(len(data), .5).data.data + kernel *= math.sqrt(norm_before / np.dot(kernel, kernel)) + + # + # the kernel's latency + # + + latency = (len(data) - 1) // 2 + + # + # done + # + + return kernel, latency, sample_rate + + def linear_phase_fir_kernel_to_minimum_phase_whitening_fir_kernel( + self, + linear_phase_kernel: np.ndarray, + sample_rate: int + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Compute the minimum-phase response filter (zero latency) + associated with a linear-phase response filter (latency + equal to half the filter length). + + From "Design of Optimal Minimum-Phase Digital FIR Filters + Using Discrete Hilbert Transforms", IEEE Trans. Signal + Processing, vol. 48, pp. 1491-1495, May 2000. + + Parameters + ----------- + linear_phase_kernel : numpy.ndarray + The kernel to compute the minimum-phase kernel with. + sample_rate : int + The sample rate in Hz. 
+ + Returns + ------- + numpy.ndarray + The filter kernel. + numpy.ndarray + The phase response. + """ + # + # compute abs of FFT of kernel + # + + # FIXME check for change in length + if self.fwdplan is None: + self.fwdplan = lal.CreateForwardCOMPLEX16FFTPlan(len(linear_phase_kernel), 1) + if self.revplan is None: + self.revplan = lal.CreateReverseCOMPLEX16FFTPlan(len(linear_phase_kernel), 1) + + deltaT = 1. / sample_rate + deltaF = 1. / (len(linear_phase_kernel) * deltaT) + working_length = len(linear_phase_kernel) + + kernel_tseries = lal.CreateCOMPLEX16TimeSeries( + name = "timeseries of whitening kernel", + epoch = LIGOTimeGPS(0.), + f0 = 0., + deltaT = deltaT, + length = working_length, + sampleUnits = lal.Unit("strain") + ) + + kernel_tseries.data.data = linear_phase_kernel + + absX = lal.CreateCOMPLEX16FrequencySeries( + name = "absX", + epoch = LIGOTimeGPS(0), + f0 = 0.0, + deltaF = deltaF, + length = working_length, + sampleUnits = lal.Unit("strain s") + ) + + logabsX = lal.CreateCOMPLEX16FrequencySeries( + name = "absX", + epoch = LIGOTimeGPS(0), + f0 = 0.0, + deltaF = deltaF, + length = working_length, + sampleUnits = lal.Unit("strain s") + ) + + cepstrum = lal.CreateCOMPLEX16TimeSeries( + name = "cepstrum", + epoch = LIGOTimeGPS(0.), + f0 = 0., + deltaT = deltaT, + length = working_length, + sampleUnits = lal.Unit("strain") + ) + + theta = lal.CreateCOMPLEX16FrequencySeries( + name = "theta", + epoch = LIGOTimeGPS(0), + f0 = 0.0, + deltaF = deltaF, + length = working_length, + sampleUnits = lal.Unit("strain s") + ) + + min_phase_kernel = lal.CreateCOMPLEX16TimeSeries( + name = "min phase kernel", + epoch = LIGOTimeGPS(0.), + f0 = 0., + deltaT = deltaT, + length = working_length, + sampleUnits = lal.Unit("strain") + ) + + lal.COMPLEX16TimeFreqFFT(absX, kernel_tseries, self.fwdplan) + absX.data.data[:] = abs(absX.data.data) + + # + # compute the cepstrum of the kernel (i.e., the iFFT of the + # log of the abs of the FFT of the kernel) + # + + 
logabsX.data.data[:] = np.log(absX.data.data) + lal.COMPLEX16FreqTimeFFT(cepstrum, logabsX, self.revplan) + + # + # multiply cepstrum by sgn + # + + cepstrum.data.data[0] = 0. + cepstrum.data.data[working_length // 2] = 0. + cepstrum.data.data[working_length // 2 + 1:] = -cepstrum.data.data[working_length // 2 + 1:] + + # + # compute theta + # + + lal.COMPLEX16TimeFreqFFT(theta, cepstrum, self.fwdplan) + + # + # compute the gain and phase of the zero-phase + # approximation relative to the original linear-phase + # filter + # + + theta_data = theta.data.data[working_length // 2:] + #gain = np.exp(theta_data.real) + phase = -theta_data.imag + + # + # apply optional masked phase adjustment + # + + if self.target_phase is not None: + # compute phase adjustment for +ve frequencies + phase_adjustment = (self.target_phase - phase) * self.target_phase_mask + + # combine with phase adjustment for -ve frequencies + phase_adjustment = np.concatenate((phase_adjustment[1:][-1::-1].conj(), phase_adjustment)) + + # apply adjustment. phase adjustment is what we + # wish to add to the phase. theta's imaginary + # component contains the negative of the phase, so + # we need to add -phase to theta's imaginary + # component + theta.data.data += -1.j * phase_adjustment + + # report adjusted phase + #phase = -theta.data.data[working_length // 2:].imag + + # + # compute minimum phase kernel + # + + absX.data.data *= np.exp(theta.data.data) + lal.COMPLEX16FreqTimeFFT(min_phase_kernel, absX, self.revplan) + + kernel = min_phase_kernel.data.data.real + + # + # this kernel needs to be reversed to follow conventions + # used with the audiofirfilter and lal_firbank elements + # + + kernel = kernel[-1::-1] + + # + # done + # + + return kernel, phase + + +def generate_pre_merger_psds( + psd_file, + duration, + sample_rate, + kernel_length=17280, +): + """Generate the time- and frequency-domain pre-merger PSDs + + Parameters + ---------- + psd_file : str + Path to the PSD file. 
+ sample_rate : float + The sample rate. + duration : float + Duration in seconds. + kernel_length : int + Length of the whitening kernel in samples. + + Returns + ------- + dict + A dictionary contain the PSDs. The keys denote frequency-domain (FD) + and time-domain (TD). + """ + flen = int(duration * sample_rate) // 2 + 1 + delta_f = 1 / duration + delta_t = 1 / sample_rate + td_psd_length = int(duration * sample_rate) + + psd = pycbc.psd.from_txt( + psd_file, flen, delta_f, delta_f, is_asd_file=False + ) + + psd_kern = PSDFirKernel() + + pdf_lal = psd.lal() + + first_psd_kern, latency, sample_rate = \ + psd_kern.psd_to_linear_phase_whitening_fir_kernel(pdf_lal) + zero_phase_kern, phase = \ + psd_kern.linear_phase_fir_kernel_to_minimum_phase_whitening_fir_kernel(first_psd_kern, sample_rate) + zero_phase_kern = zero_phase_kern * sample_rate**(1.5) / 2**0.5 + + # Time domain pycbc PSD + zero_phase_kern_pycbc = pycbc.types.TimeSeries( + pycbc.types.zeros(td_psd_length), + delta_t=delta_t, + ) + filter_data = pycbc.types.Array(zero_phase_kern[-kernel_length:]) + zero_phase_kern_pycbc.data[-kernel_length:] = filter_data.data[:] + zero_phase_kern_pycbc.data[0] = zero_phase_kern[0] + + # Frequency domain pycbc PSD + zero_phase_kern_pycbc_fd = pycbc.types.FrequencySeries( + pycbc.types.zeros( + len(zero_phase_kern_pycbc) //2 + 1, + dtype=np.complex128, + ), + delta_f=delta_f + ) + pycbc.fft.fft(zero_phase_kern_pycbc, zero_phase_kern_pycbc_fd) + + zero_phase_kern_pycbc_td = zero_phase_kern_pycbc + return { + "TD": zero_phase_kern_pycbc_td, + "FD": zero_phase_kern_pycbc_fd, + } diff --git a/pycbc/scheme.py b/pycbc/scheme.py index 0a9e6740e1e..f47cbcdf87e 100644 --- a/pycbc/scheme.py +++ b/pycbc/scheme.py @@ -118,14 +118,18 @@ def __init__(self, device_num=0): class CUPYScheme(Scheme): """Scheme for using CUPY""" - def __init__(self, device_num=None): - import cupy # Fail now if cupy is not there. 
+ def __init__(self, device_num=0): + import cupy import cupy.cuda + Scheme.__init__(self) self.device_num = device_num - self.cuda_device = cupy.cuda.Device(self.device_num) + self.cuda_device = cupy.cuda.Device(device_num) + def __enter__(self): super().__enter__() - self.cuda_device.__enter__() + self.cuda_device.use() + from pycbc.types.array import update_scheme_types + update_scheme_types() logging.warn( "You are using the CUPY GPU backend for PyCBC. This backend is " "still only a prototype. It may be useful for your application " @@ -133,10 +137,10 @@ def __enter__(self): "output. Please do contribute to the effort to develop this " "further." ) - def __exit__(self, *args): super().__exit__(*args) - self.cuda_device.__exit__(*args) + from pycbc.types.array import update_scheme_types + update_scheme_types() class CPUScheme(Scheme): diff --git a/pycbc/types/array.py b/pycbc/types/array.py index 9959a3700d4..eddacf2702c 100644 --- a/pycbc/types/array.py +++ b/pycbc/types/array.py @@ -35,10 +35,51 @@ import h5py import lal as _lal import numpy as _numpy -from numpy import float32, float64, complex64, complex128, ones -from numpy.linalg import norm - +import cupy as _cupy import pycbc.scheme as _scheme + +# Define a function to get the appropriate numeric types based on scheme +def _get_scheme_types(): + if _scheme.current_prefix() == 'cupy': + import cupy + return { + 'float32': cupy.float32, + 'float64': cupy.float64, + 'complex64': cupy.complex64, + 'complex128': cupy.complex128, + 'ones': cupy.ones, + 'norm': cupy.linalg.norm + } + else: + return { + 'float32': _numpy.float32, + 'float64': _numpy.float64, + 'complex64': _numpy.complex64, + 'complex128': _numpy.complex128, + 'ones': _numpy.ones, + 'norm': _numpy.linalg.norm + } + +# Get initial types but allow for later updates +_types = _get_scheme_types() +float32 = _types['float32'] +float64 = _types['float64'] +complex64 = _types['complex64'] +complex128 = _types['complex128'] +ones = _types['ones'] +norm 
= _types['norm'] + +def update_scheme_types(): + """Update numeric types based on current scheme""" + global float32, float64, complex64, complex128, ones, norm + _types = _get_scheme_types() + float32 = _types['float32'] + float64 = _types['float64'] + complex64 = _types['complex64'] + complex128 = _types['complex128'] + ones = _types['ones'] + norm = _types['norm'] + from pycbc.scheme import schemed, cpuonly from pycbc.opt import LimitedSizeDict @@ -46,7 +87,9 @@ # we should restrict any functions that do not allow an # array of uint32 integers _ALLOWED_DTYPES = [_numpy.float32, _numpy.float64, _numpy.complex64, - _numpy.complex128, _numpy.uint32, _numpy.int32, int] + _numpy.complex128, _numpy.uint32, _numpy.int32, int, + _cupy.float32, _cupy.float64, _cupy.complex64, + _cupy.complex128, _cupy.uint32, _cupy.int32] try: _ALLOWED_SCALARS = [int, long, float, complex] + _ALLOWED_DTYPES except NameError: @@ -54,8 +97,12 @@ def _convert_to_scheme(ary): if not isinstance(ary._scheme, _scheme.mgr.state.__class__): - converted_array = Array(ary, dtype=ary._data.dtype) - ary._data = converted_array._data + if isinstance(_scheme.mgr.state, _scheme.CUPYScheme): + # Convert to cupy array + ary._data = _cupy.asarray(ary._data) + else: + # Convert to numpy array for CPU schemes + ary._data = _numpy.asarray(ary._data) ary._scheme = _scheme.mgr.state def _convert(func): diff --git a/pycbc/types/array_cupy.py b/pycbc/types/array_cupy.py index a70a4f530e8..64ddf5b5ef8 100644 --- a/pycbc/types/array_cupy.py +++ b/pycbc/types/array_cupy.py @@ -106,7 +106,7 @@ def squared_norm(self): return (self.data.real**2 + self.data.imag**2) def numpy(self): - return cp.asnumpy(self.data) + return cp.asnumpy(self._data) def _copy(self, self_ref, other_ref): self_ref[:] = other_ref[:] @@ -124,17 +124,16 @@ def clear(self): self[:] = 0 def _scheme_matches_base_array(array): + """Check if the array is already a CuPy array""" if isinstance(array, cp.ndarray): return True - else: - return False + 
return False def _to_device(array): + """Convert input to CuPy array""" return cp.asarray(array) -def numpy(self): - return cp.asnumpy(self._data) - def _copy_base_array(array): + """Copy a CuPy array""" return array.copy() diff --git a/pycbc/types/frequencyseries.py b/pycbc/types/frequencyseries.py index d0d51e64af9..8a25ca13235 100644 --- a/pycbc/types/frequencyseries.py +++ b/pycbc/types/frequencyseries.py @@ -622,9 +622,9 @@ def load_frequencyseries(path, group=None): raise ValueError('Path must end with .npy, .hdf, or .txt') delta_f = (data[-1][0] - data[0][0]) / (len(data) - 1) - if data.ndim == 2: + if data.shape[1] == 2: return FrequencySeries(data[:,1], delta_f=delta_f, epoch=None) - elif data.ndim == 3: + elif data.shape[1] == 3: return FrequencySeries(data[:,1] + 1j*data[:,2], delta_f=delta_f, epoch=None) diff --git a/pycbc/vetoes/chisq.py b/pycbc/vetoes/chisq.py index 838643aa406..28cfd7941e3 100644 --- a/pycbc/vetoes/chisq.py +++ b/pycbc/vetoes/chisq.py @@ -399,6 +399,107 @@ def values(self, corr, snrv, snr_norm, psd, indices, template): else: return None, None + def values_batch(self, corrs, snrvs, snr_norms, sizes, template_map, psd, indices, templates): + """ Calculate the chisq at points given by indices for a batch of templates. 
+ + Parameters + ---------- + corrs : cupy array + 2D array or list of correlation vectors for all templates + snrvs : cupy array + 1D concatenated array of SNR values at trigger points + snr_norms : cupy array + 1D array of normalization factors for each template + sizes : cupy array + Number of triggers for each template + template_map : cupy array + Maps each trigger to its template index + psd : FrequencySeries + Power spectral density + indices : cupy array + 1D concatenated array of trigger indices + templates : list + List of template objects + + Returns + ------- + chisq_list : list of Arrays + Chisq values for each template + chisq_dof_list : list of Arrays + Degrees of freedom for each template + """ + if self.do: + import cupy as cp + from .chisq_cupy import shift_sum_batch + + if not self.snr_threshold: + raise NotImplementedError("Batched chisq currently requires snr_threshold to be set") + + # Filter triggers by SNR threshold + num_triggers = len(indices) + if self.snr_threshold: + # Apply SNR threshold per trigger using template_map to get correct norm + snr_norms_per_trigger = snr_norms[template_map] + above = cp.abs(snrvs * snr_norms_per_trigger) > self.snr_threshold + num_above = int(above.sum()) + logging.info('%s above chisq activation threshold' % num_above) + + # Filter arrays to only above-threshold triggers + above_indices = indices[above] + above_snrvs = snrvs[above] + above_template_map = template_map[above] + else: + num_above = num_triggers + above_indices = indices + above_snrvs = snrvs + above_template_map = template_map + + chisq_out = cp.zeros(num_triggers, dtype=numpy.float32) + chisq_dof_out = cp.zeros(num_triggers, dtype=numpy.int32) + chisq_list = [] + chisq_dof_list = [] + + if num_above > 0: + # Get bins for all templates + bins = [self.cached_chisq_bins(template, psd) for template in templates] + bin_lengths = cp.array([len(cbin) for cbin in bins], dtype=cp.uint32) + dof = (bin_lengths - 1) * 2 - 2 + + # Compute chisq for all 
above-threshold triggers at once + chisq = shift_sum_batch(corrs, above_indices, bins, bin_lengths, above_template_map) + + # Compute full chisq values + # chisq = (chisq * num_bins - |snr|^2) * norm^2 + # For batched case, need to apply correct normalization per trigger + num_bins = bin_lengths[above_template_map] - 1 + snr_norms_per_trigger = snr_norms[above_template_map] + snr_mag_sq = (above_snrvs.conj() * above_snrvs).real + chisq_computed = (chisq * num_bins - snr_mag_sq) * (snr_norms_per_trigger ** 2.0) + + # Fill in only the above-threshold triggers + if self.snr_threshold: + chisq_out[above] = chisq_computed + chisq_dof_out[above] = dof[above_template_map] + else: + chisq_out[:] = chisq_computed + chisq_dof_out[:] = dof[above_template_map] + + # Split results by template + start = 0 + for idx, size in enumerate(sizes): + size = int(size) + if size > 0: + chisq_list.append(chisq_out[start:start + size]) + chisq_dof_list.append(chisq_dof_out[start:start + size]) + else: + chisq_list.append(cp.empty(0, dtype=numpy.float32)) + chisq_dof_list.append(cp.empty(0, dtype=numpy.int32)) + start += size + + return chisq_list, chisq_dof_list + else: + return None, None + class SingleDetSkyMaxPowerChisq(SingleDetPowerChisq): """Class that handles precomputation and memory management for efficiently diff --git a/pycbc/vetoes/chisq_cupy.py b/pycbc/vetoes/chisq_cupy.py index 51b0b7e7912..59be9270af5 100644 --- a/pycbc/vetoes/chisq_cupy.py +++ b/pycbc/vetoes/chisq_cupy.py @@ -307,10 +307,10 @@ def get_cached_pow2(N): def shift_sum(corr, points, bins): kmin, kmax, bv = get_cached_bin_layout(bins) nb = len(kmin) - N = numpy.uint32(len(corr)) + N = cp.uint32(len(corr)) is_pow2 = get_cached_pow2(N) - nbins = numpy.uint32(len(bins) - 1) - outc = cp.zeros((len(points), nbins), dtype=numpy.complex64) + nbins = cp.uint32(len(bins) - 1) + outc = cp.zeros((len(points), nbins), dtype=cp.complex64) outp = outc.reshape(nbins * len(points)) np = len(points) @@ -328,7 +328,7 @@ def 
shift_sum(corr, points, bins): elif np == 1: outp, lpoints, np = shift_sum_points_pow2(1, cargs) else: - phase = [numpy.float32(p * 2.0 * numpy.pi / N) for p in points] + phase = [cp.float32(p * 2.0 * cp.pi / N) for p in points] while np > 0: cargs = (corr, outp, phase, np, nb, N, kmin, kmax, bv, nbins) @@ -341,5 +341,415 @@ def shift_sum(corr, points, bins): elif np == 1: outp, phase, np = shift_sum_points(1, cargs) # pylint:disable=no-value-for-parameter - return cp.asnumpy((outc.conj() * outc).sum(axis=1).real) + return (outc.conj() * outc).sum(axis=1).real + +# Batched chisq implementation - OPTIMIZED for coalescing and reduced atomics +chisqkernel_pow2_batch = Template(""" +#include +extern "C" __global__ void power_chisq_at_points_pow2_batch( + float2* corr, // 2D: num_templates X N + float2* outc, // 1D: num_points X max_bin_length + unsigned int N, // Scalar + uint32_t* points, // 1D: num_points + uint32_t* kmin, // 2D: num_templates X max_bin_size + uint32_t* kmax, // 2D: num_templates X max_bin_size + uint32_t* bv, // 2D: num_templates X max_bin_size + uint32_t* nbins, // 1D: num_templates (per-template bin counts) + uint32_t* mapping, // 1D: num_points + uint32_t* nb, // 1D: num_templates + unsigned int max_nbins, // scalar: max number of bins across templates + unsigned int num_points, // scalar + unsigned int num_templates, // scalar + unsigned int max_bin_size // scalar +) +{ + const unsigned int pnum = blockIdx.y; + const unsigned int binnum = blockIdx.x; + const float twopi = ${TWOPI}; + const unsigned long long NN = (unsigned long long) N; + + // Early exit check + if (pnum >= num_points) return; + + const unsigned int tempnum = mapping[pnum]; + if (binnum >= nb[tempnum]) return; + + // Load bin parameters (all threads load same values, gets cached in L1) + const unsigned int idx_base = tempnum * max_bin_size + binnum; + const unsigned int s = kmin[idx_base]; + const unsigned int e = kmax[idx_base]; + + if (s >= e) return; // Empty bin + + const 
unsigned int bin_idx = bv[idx_base]; + const unsigned long long point = (unsigned long long) points[pnum]; + const unsigned int corr_base = tempnum * N; + + // Each thread accumulates independently (no shared memory needed initially) + float accum_x = 0.0f; + float accum_y = 0.0f; + + // Main loop - fully coalesced memory access + const float phase_mult = twopi / ((float) N); + #pragma unroll 4 + for (unsigned int i = s + threadIdx.x; i < e; i += blockDim.x){ + // Coalesced load + const float2 qt = corr[corr_base + i]; + + // Compute phase + const unsigned int k = (unsigned int)((point * (unsigned long long)i) & (NN-1)); + float re, im; + __sincosf(phase_mult * k, &im, &re); + + // Accumulate + accum_x += re * qt.x - im * qt.y; + accum_y += im * qt.x + re * qt.y; + } + + // Use warp shuffle reduction for first stage (no shared memory) + #pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + accum_x += __shfl_down_sync(0xffffffff, accum_x, offset); + accum_y += __shfl_down_sync(0xffffffff, accum_y, offset); + } + + // Shared memory only for cross-warp reduction + __shared__ float2 warp_sums[16]; // Up to 512 threads = 16 warps + const unsigned int warp_id = threadIdx.x / 32; + const unsigned int lane_id = threadIdx.x % 32; + + if (lane_id == 0) { + warp_sums[warp_id].x = accum_x; + warp_sums[warp_id].y = accum_y; + } + __syncthreads(); + + // Final reduction by first warp + if (warp_id == 0) { + float2 sum; + if (lane_id < (${NT} / 32)) { + sum = warp_sums[lane_id]; + } else { + sum.x = 0.0f; + sum.y = 0.0f; + } + + #pragma unroll + for (int offset = 8; offset > 0; offset >>= 1) { + sum.x += __shfl_down_sync(0xffffffff, sum.x, offset); + sum.y += __shfl_down_sync(0xffffffff, sum.y, offset); + } + + if (lane_id == 0) { + const unsigned int out_idx = pnum * max_nbins + bin_idx; + atomicAdd(&outc[out_idx].x, sum.x); + atomicAdd(&outc[out_idx].y, sum.y); + } + } +} +""") + + +@functools.lru_cache(maxsize=None) +def get_pchisq_fn_pow2_batch(): + nt = 512 + 
fn = cp.RawKernel( + chisqkernel_pow2_batch.render(NT=nt, **LALARGS), + f'power_chisq_at_points_pow2_batch', + backend='nvcc', + options=('--use_fast_math',) # Enable fast math for better performance + ) + return fn, nt + + +# Batched chisq for non-power-of-2 FFT lengths +chisqkernel_batch = Template(""" +#include +extern "C" __global__ void power_chisq_at_points_batch( + float2* corr, // 2D: num_templates X N + float2* outc, // 1D: num_points X max_bin_length + unsigned int N, // Scalar + float* phases, // 1D: num_points (phase multiplier for each point) + uint32_t* kmin, // 2D: num_templates X max_bin_size + uint32_t* kmax, // 2D: num_templates X max_bin_size + uint32_t* bv, // 2D: num_templates X max_bin_size + uint32_t* nbins, // 1D: num_templates (per-template bin counts) + uint32_t* mapping, // 1D: num_points + uint32_t* nb, // 1D: num_templates + unsigned int max_nbins, // scalar: max number of bins across templates + unsigned int num_points, // scalar + unsigned int num_templates, // scalar + unsigned int max_bin_size // scalar +) +{ + const unsigned int pnum = blockIdx.y; + const unsigned int binnum = blockIdx.x; + + // Early exit check + if (pnum >= num_points) return; + + const unsigned int tempnum = mapping[pnum]; + if (binnum >= nb[tempnum]) return; + + // Load bin parameters + const unsigned int idx_base = tempnum * max_bin_size + binnum; + const unsigned int s = kmin[idx_base]; + const unsigned int e = kmax[idx_base]; + + if (s >= e) return; // Empty bin + + const unsigned int bin_idx = bv[idx_base]; + const float phase = phases[pnum]; + const unsigned int corr_base = tempnum * N; + + // Each thread accumulates independently + float accum_x = 0.0f; + float accum_y = 0.0f; + + // Main loop - coalesced memory access + #pragma unroll 4 + for (unsigned int i = s + threadIdx.x; i < e; i += blockDim.x){ + const float2 qt = corr[corr_base + i]; + + // Compute phase using sincosf for non-power-of-2 + float re, im; + sincosf(phase * i, &im, &re); + + // 
Accumulate + accum_x += re * qt.x - im * qt.y; + accum_y += im * qt.x + re * qt.y; + } + + // Warp shuffle reduction + #pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + accum_x += __shfl_down_sync(0xffffffff, accum_x, offset); + accum_y += __shfl_down_sync(0xffffffff, accum_y, offset); + } + + // Shared memory for cross-warp reduction + __shared__ float2 warp_sums[16]; + const unsigned int warp_id = threadIdx.x / 32; + const unsigned int lane_id = threadIdx.x % 32; + + if (lane_id == 0) { + warp_sums[warp_id].x = accum_x; + warp_sums[warp_id].y = accum_y; + } + __syncthreads(); + + // Final reduction by first warp + if (warp_id == 0) { + float2 sum; + if (lane_id < (${NT} / 32)) { + sum = warp_sums[lane_id]; + } else { + sum.x = 0.0f; + sum.y = 0.0f; + } + + #pragma unroll + for (int offset = 8; offset > 0; offset >>= 1) { + sum.x += __shfl_down_sync(0xffffffff, sum.x, offset); + sum.y += __shfl_down_sync(0xffffffff, sum.y, offset); + } + + if (lane_id == 0) { + const unsigned int out_idx = pnum * max_nbins + bin_idx; + atomicAdd(&outc[out_idx].x, sum.x); + atomicAdd(&outc[out_idx].y, sum.y); + } + } +} +""") + + +@functools.lru_cache(maxsize=None) +def get_pchisq_fn_batch(): + nt = 512 + fn = cp.RawKernel( + chisqkernel_batch.render(NT=nt, **LALARGS), + f'power_chisq_at_points_batch', + backend='nvcc', + options=('--use_fast_math',) + ) + return fn, nt + + +_bin_layout_cache = {} + +def get_cached_bin_layout_batch(bins, bin_lengths): + """Get or compute cached bin layout arrays. + + Returns pre-allocated GPU arrays bv, kmin, kmax, nb that are cached + based on the bin configuration. 
+ """ + import time + t_cache_key_start = time.time() + # Create cache key from bins - OPTIMIZED + # Instead of converting to nested tuples, use array hashes which is much faster + # Hash each bin array and combine them + import cupy as cp + cache_key = tuple( + (len(b), int(b[0]) if len(b) > 0 else 0, int(b[-1]) if len(b) > 0 else 0, + hash(b.data.tobytes()) if hasattr(b, 'data') else hash(bytes(b))) + for b in bins + ) + t_cache_key = time.time() - t_cache_key_start + + if cache_key in _bin_layout_cache: + print(f"CACHE HIT: Reusing bin layout for {len(bins)} templates (cache_key: {t_cache_key:.4f}s)", flush=True) + return _bin_layout_cache[cache_key] + + print(f"CACHE MISS: Computing bin layout for {len(bins)} templates (cache_key: {t_cache_key:.4f}s)", flush=True) + + t_alloc_start = time.time() + # Convert bins list to batch-friendly format + # OPTIMIZED: Single pass - build ranges and track max values simultaneously + BS = 4096 + + # Pre-compute all bin_ranges for all templates + all_bin_ranges = [] + max_num_ranges = 0 + max_nbins = 0 + + for cbins in bins: + bin_ranges = [] + for i in range(len(cbins)-1): + s, e = int(cbins[i]), int(cbins[i+1]) + if (e - s) < BS: + bin_ranges.append((i, s, e)) + else: + # Calculate chunks without creating arrays + num_chunks = ((e - s) + BS//2 - 1) // (BS//2) + for j in range(num_chunks): + chunk_s = s + j * (BS//2) + chunk_e = min(s + (j + 1) * (BS//2), e) + bin_ranges.append((i, chunk_s, chunk_e)) + all_bin_ranges.append(bin_ranges) + max_num_ranges = max(max_num_ranges, len(bin_ranges)) + max_nbins = max(max_nbins, len(cbins) - 1) + + max_bin_size = max_num_ranges + + bv = cp.zeros([len(bins), max_bin_size], dtype=cp.uint32) + kmin = cp.zeros([len(bins), max_bin_size], dtype=cp.uint32) + kmax = cp.zeros([len(bins), max_bin_size], dtype=cp.uint32) + nb = cp.zeros(len(bins), dtype=cp.uint32) + t_alloc = time.time() - t_alloc_start + + t_loop_start = time.time() + for idx1, bin_ranges in enumerate(all_bin_ranges): + # Bulk 
copy to GPU (much faster than individual assignments) + if bin_ranges: + bin_arr = cp.array(bin_ranges, dtype=cp.uint32) + bv[idx1, :len(bin_ranges)] = bin_arr[:, 0] + kmin[idx1, :len(bin_ranges)] = bin_arr[:, 1] + kmax[idx1, :len(bin_ranges)] = bin_arr[:, 2] + + nb[idx1] = len(bin_ranges) + t_loop = time.time() - t_loop_start + + print(f" Bin layout breakdown - alloc: {t_alloc:.4f}s, loop: {t_loop:.4f}s", flush=True) + + # No need to compute max values - we already have them from the first pass + max_nb_val = max_num_ranges + max_nbins_val = max_nbins + + print(f" Max values from CPU: max_nb={max_nb_val}, max_nbins={max_nbins_val}", flush=True) + + # Cache the result with precomputed max values + result = (bv, kmin, kmax, nb, max_nb_val, max_nbins_val) + _bin_layout_cache[cache_key] = result + return result + + +def shift_sum_batch(corr, points, bins, bin_lengths, mapping): + """Compute chisq shift-sum for batched templates. + + Parameters + ---------- + corr : cupy array + 2D array of correlation data (num_templates x frequency_length) + points : cupy array + 1D array of time indices for triggers + bins : list of cupy arrays + List of bin edges for each template + bin_lengths : cupy array + Number of bins for each template + mapping : cupy array + Maps each point index to its template index + + Returns + ------- + cupy array + Chisq values for each point + """ + import time + t_layout_start = time.time() + # Get cached bin layout (avoids recomputation) - now includes precomputed max values + bv, kmin, kmax, nb, max_nb_val, max_nbins_val = get_cached_bin_layout_batch(bins, bin_lengths) + t_layout = time.time() - t_layout_start + + t_setup_start = time.time() + N = cp.uint32(len(corr[0])) + is_pow2 = get_cached_pow2(int(N)) + nbins = bin_lengths - 1 + + # Allocate output with maximum possible size (uniform row width) + # Use precomputed max_nbins_val to avoid GPU->CPU sync + outc = cp.zeros((len(points), max_nbins_val), dtype=cp.complex64) + t_setup = time.time() - 
t_setup_start + + t_kernel_start = time.time() + if is_pow2: + fn, nt = get_pchisq_fn_pow2_batch() + # Flatten corr array if needed + if corr.ndim == 1: + # Single template case - reshape + corr_flat = corr + else: + corr_flat = corr.reshape(-1) + + max_bin_size = bv.shape[1] + args = (corr_flat, outc.reshape(-1), N, points, kmin.reshape(-1), + kmax.reshape(-1), bv.reshape(-1), nbins, mapping, nb, + cp.uint32(max_nbins_val), cp.uint32(len(points)), cp.uint32(len(bins)), cp.uint32(max_bin_size)) + + # Use precomputed max_nb_val to avoid GPU->CPU sync + grid_size = (max_nb_val, len(points)) + fn(grid_size, (nt,), args) + else: + # Non-power-of-2 case using sincosf + fn, nt = get_pchisq_fn_batch() + + # Flatten corr array if needed + if corr.ndim == 1: + corr_flat = corr + else: + corr_flat = corr.reshape(-1) + + # Calculate phase multipliers for each point + phases = cp.float32(2 * cp.pi / float(N)) * points.astype(cp.float32) + + max_bin_size = bv.shape[1] + args = (corr_flat, outc.reshape(-1), N, phases, kmin.reshape(-1), + kmax.reshape(-1), bv.reshape(-1), nbins, mapping, nb, + cp.uint32(max_nbins_val), cp.uint32(len(points)), cp.uint32(len(bins)), cp.uint32(max_bin_size)) + + # Use precomputed max_nb_val to avoid GPU->CPU sync + grid_size = (max_nb_val, len(points)) + fn(grid_size, (nt,), args) + + cp.cuda.Stream.null.synchronize() + t_kernel = time.time() - t_kernel_start + + t_reduce_start = time.time() + result = (outc.conj() * outc).sum(axis=1).real + cp.cuda.Stream.null.synchronize() + t_reduce = time.time() - t_reduce_start + + print(f" shift_sum_batch - layout: {t_layout:.4f}s, setup: {t_setup:.4f}s, kernel: {t_kernel:.4f}s, reduce: {t_reduce:.4f}s, triggers: {len(points)}", flush=True) + + return result diff --git a/pycbc/waveform/bank.py b/pycbc/waveform/bank.py index 2c1d152e08a..930b51f418b 100644 --- a/pycbc/waveform/bank.py +++ b/pycbc/waveform/bank.py @@ -786,7 +786,197 @@ def generate_with_delta_f_and_max_freq(self, t_num, max_freq, delta_f, 
distance=1./DYN_RANGE_FAC, delta_t=1./(2.*max_freq)) return htilde + def __getitem__(self, index): + """Get template(s) at given index(es) + + Parameters + ---------- + index : int or list + Either a single template index or a list of indices + + Returns + ------- + htilde : FrequencySeries or list of FrequencySeries + The template(s) at the requested index(es) + """ + # Handle single index case as before + if isinstance(index, int): + return self._get_single_template(index) + + # Handle list of indices for batch loading + templates = [] + + # Check if we can use batched template generation + # This is beneficial for SPAtmplt on GPU (CuPy scheme) + use_batched_gen = False + if len(index) > 1: + # Check if all templates use SPAtmplt and are not compressed + approxs = [self.approximant(idx) for idx in index] + all_spatmplt = all(a == 'SPAtmplt' for a in approxs) + no_compression = not (self.has_compressed_waveforms and self.enable_compressed_waveforms) + + # Check if we're using CuPy (GPU scheme) + try: + from pycbc.scheme import CUPYScheme + from pycbc.types.array import _convert + # Try to detect if we're in a CuPy context + test_arr = zeros(1, dtype=self.dtype) + is_cupy = hasattr(test_arr._data, 'device') + use_batched_gen = all_spatmplt and no_compression and is_cupy + except ImportError: + pass + + if use_batched_gen: + # Batched generation for SPAtmplt on GPU + from pycbc.waveform.spa_tmplt import spa_tmplt_batch + + # Use pre-allocated output memory if provided + if self.out is None: + tempout = [zeros(self.filter_length, dtype=self.dtype) for _ in range(len(index))] + else: + tempout = self.out + + # Prepare common parameters + distance = 1.0 / DYN_RANGE_FAC + common_kwds = { + 'delta_f': self.delta_f, + 'f_lower': self.f_lower, + 'distance': distance, + **self.extra_args + } + + # Prepare per-template parameters + templates_params = [] + f_end_list = [] + f_low_list = [] + + for idx in index: + f_end = self.end_frequency(idx) + if f_end is None or f_end >= 
(self.filter_length * self.delta_f): + f_end = (self.filter_length-1) * self.delta_f + + f_low = find_variable_start_frequency('SPAtmplt', + self.table[idx], + self.f_lower, + self.max_template_length) + + params = { + 'mass1': self.table[idx].mass1, + 'mass2': self.table[idx].mass2, + 'spin1z': self.table[idx].spin1z, + 'spin2z': self.table[idx].spin2z, + 'distance': distance + } + + templates_params.append(params) + f_end_list.append(f_end) + f_low_list.append(f_low) + + # Log single message for the entire batch + logging.info('Generating templates %s-%s (%d templates) in batch from %s Hz' % + (index[0], index[-1], len(index), min(f_low_list))) + + # Don't add f_final to common_kwds - let each template calculate its own fstop + # based on its masses + + # Generate all templates in one batch, writing into tempout memory + htilde_list = spa_tmplt_batch(templates_params, self.filter_length, tempout, **common_kwds) + + # Process each template + for i, (idx, htilde) in enumerate(zip(index, htilde_list)): + template_duration = htilde.chirp_length if hasattr(htilde, 'chirp_length') else None + ttotal = htilde.length_in_time if hasattr(htilde, 'length_in_time') else None + + self.table[idx].template_duration = template_duration + + htilde = htilde.astype(self.dtype) + htilde.f_lower = f_low_list[i] + htilde.min_f_lower = self.min_f_lower + htilde.end_idx = int(f_end_list[i] / htilde.delta_f) + htilde.params = self.table[idx] + htilde.chirp_length = template_duration + htilde.length_in_time = ttotal + htilde.approximant = 'SPAtmplt' + htilde.end_frequency = f_end_list[i] + + # Add sigmasq method + htilde.sigmasq = types.MethodType(sigma_cached, htilde) + htilde._sigmasq = {} + + templates.append(htilde) + else: + # Original sequential generation + if self.out is None: + tempout = [zeros(self.filter_length, dtype=self.dtype)] + else: + tempout = self.out + for i, idx in enumerate(index): + approximant = self.approximant(idx) + f_end = self.end_frequency(idx) + if f_end is 
None or f_end >= (self.filter_length * self.delta_f): + f_end = (self.filter_length-1) * self.delta_f + + # Find start frequency + f_low = find_variable_start_frequency(approximant, + self.table[idx], + self.f_lower, + self.max_template_length) + logging.info('%s: generating %s from %s Hz' % + (idx, approximant, f_low)) + + # Clear storage memory + poke = tempout[i].data + tempout[i].clear() + + # Get waveform filter + distance = 1.0 / DYN_RANGE_FAC + if self.has_compressed_waveforms and self.enable_compressed_waveforms: + htilde = self.get_decompressed_waveform(tempout[i], idx, + f_lower=f_low, + approximant=approximant, + df=None) + else: + htilde = pycbc.waveform.get_waveform_filter( + tempout[i][0:self.filter_length], + self.table[idx], + approximant=approximant, + f_lower=f_low, + f_final=f_end, + delta_f=self.delta_f, + delta_t=self.delta_t, + distance=distance, + **self.extra_args) + + # Handle duration info + ttotal = template_duration = None + if hasattr(htilde, 'length_in_time'): + ttotal = htilde.length_in_time + if hasattr(htilde, 'chirp_length'): + template_duration = htilde.chirp_length + + self.table[idx].template_duration = template_duration + + htilde = htilde.astype(self.dtype) + htilde.f_lower = f_low + htilde.min_f_lower = self.min_f_lower + htilde.end_idx = int(f_end / htilde.delta_f) + htilde.params = self.table[idx] + htilde.chirp_length = template_duration + htilde.length_in_time = ttotal + htilde.approximant = approximant + htilde.end_frequency = f_end + + # Add sigmasq method + htilde.sigmasq = types.MethodType(sigma_cached, htilde) + htilde._sigmasq = {} + + templates.append(htilde) + + return templates + + + def _get_single_template(self, index): # Make new memory for templates if we aren't given output memory if self.out is None: tempout = zeros(self.filter_length, dtype=self.dtype) diff --git a/pycbc/waveform/pre_merger_waveform.py b/pycbc/waveform/pre_merger_waveform.py new file mode 100644 index 00000000000..540dfd1f413 --- /dev/null 
+++ b/pycbc/waveform/pre_merger_waveform.py @@ -0,0 +1,311 @@ +from functools import cache +from scipy import signal + +import pycbc.fft +import pycbc.noise +import pycbc.strain +import pycbc.waveform +import pycbc.types + + +@cache +def get_window(window_length): + if window_length: + return pycbc.types.Array( + signal.windows.hann(window_length * 2 + 1)[:window_length] + ) + else: + return None + + +def apply_pre_merger_kernel( + f_series, + whitening_psd, + window, + window_length, + nfz, + nctf, + uid, + copy_output=False, +): + """Helper function to apply the pre-merger kernel. + + Parameters + ---------- + f_series : pycbc.types.FrequencySeries + Frequency series to apply the kernel to. + whitening_psd : pycbc.types.FrequencySeries + PSD for whitening the data in the frequency-domain. + window : numpy.ndarray + Window array. + window_length : int + Pre-computed length of the window in samples. + nfz : int + Number of forward zeroes. + nctf : int + Number of samples to zero at the end of the data. + uid : int + UID for computing the iFFTs. + + Returns + ------- + pycbc.types.TimeSeries + Whitened time series. + """ + # Whiten data + f_series.data[:] = f_series.data[:] * (whitening_psd.data[:]).conj() + + # FD to TD to apply zeroes + tout_ww = pycbc.strain.strain.execute_cached_ifft( + f_series, + copy_output=copy_output, + uid=uid, + ) + # Zero initial data + tout_ww.data[:nfz] = 0 + if window is not None: + # Apply window + tout_ww.data[nfz:nfz+window_length] *= window.data + # Zero data from cutoff + tout_ww.data[-nctf:] = 0 + return tout_ww + + +def generate_data_lisa_pre_merger( + waveform_params, + psds_for_datagen, + sample_rate, + seed=137, + zero_noise=False, + no_signal=False, + duration=None, +): + """Generate pre-merger LISA data. + + UIDs used for FFTs: 4235(0), 4236(0) + + Parameters + ---------- + waveform_params : dict + Dictionary of waveform parameters + psds_for_datagen : dict + PSDs for data generation. 
+ sample_rate : float + Sampling rate in Hz. + seed : int + Random seed used for generating the noise. + zero_noise : bool + If true, the noise will be set to zero. + no_signal : bool + If true, the signal will not be added to data and only noise will + be returned. + duration : float, optional + If specified, the waveform will be truncated to match the specified + duration. + + Returns + ------- + Dict[str: pycbc.types.TimeSeries] + Dictionary containing the time-domain data for each channel. + """ + # Generate injection + outs = pycbc.waveform.get_fd_det_waveform( + ifos=['LISA_A','LISA_E','LISA_T'], + **waveform_params + ) + + # Shift waveform so the merger is not at the end of the data + outs['LISA_A'] = outs['LISA_A'].cyclic_time_shift(-waveform_params['additional_end_data']) + outs['LISA_E'] = outs['LISA_E'].cyclic_time_shift(-waveform_params['additional_end_data']) + + # FS waveform to TD + tout_A = outs['LISA_A'].to_timeseries() + tout_E = outs['LISA_E'].to_timeseries() + + # Generate TD noise from the original PSDs + strain_w_A = pycbc.noise.noise_from_psd( + len(tout_A), + tout_A.delta_t, + psds_for_datagen['LISA_A'], + seed=seed, + ) + strain_w_E = pycbc.noise.noise_from_psd( + len(tout_E), + tout_E.delta_t, + psds_for_datagen['LISA_E'], + seed=seed + 1, + ) + + # We need to make sure the noise times match the signal + strain_w_A._epoch = tout_A._epoch + strain_w_E._epoch = tout_E._epoch + + # If zero noise, set noise to zero + if zero_noise: + strain_w_A *= 0.0 + strain_w_E *= 0.0 + + # Only add signal if no_signal=False + if not no_signal: + strain_w_A[:] += tout_A[:] + strain_w_E[:] += tout_E[:] + + # If duration is specified, discard the extra data + if duration is not None: + if duration > tout_A.duration: + raise RuntimeError( + "Specified duration is longer than the generated waveform" + ) + nkeep = int(duration * sample_rate) + # New start time will be nkeep sample time + new_epoch = strain_w_A.sample_times[-nkeep] + strain_w_A = 
pycbc.types.TimeSeries( + strain_w_A.data[-nkeep:], + delta_t=strain_w_A.delta_t, + ) + strain_w_E = pycbc.types.TimeSeries( + strain_w_E.data[-nkeep:], + delta_t=strain_w_E.delta_t, + ) + # Set the start time so that the GPS time is still correct + strain_w_A.start_time = new_epoch + strain_w_E.start_time = new_epoch + + return { + "LISA_A": strain_w_A, + "LISA_E": strain_w_E, + } + + +def pre_process_data_lisa_pre_merger( + data, + sample_rate, + psds_for_whitening, + window_length, + cutoff_time, + forward_zeroes=0, +): + """Pre-process the pre-merger data. + + The data is truncated, windowed and whitened. + + data : dict + Dictionary containing time-domain data. + sample_rate : float + Sampling rate in Hz. + psds_for_whitening : dict + PSDs for whitening. + window_length : int + Length of the hann window use to taper the start of the data. + cutoff_time : float + Time (in seconds) from the end of the waveform to cutoff. + forward_zeroes : float + Number of samples to set to zero at the start of the waveform. If used, + the window will be applied starting after the zeroes. + + Returns + ------- + Dict[str: pycbc.types.TimeSeries] + Dictionary containing the time-domain data for each channel. 
+ """ + window = get_window(window_length) + + # Number of samples to zero + nctf = int(cutoff_time * sample_rate) + + # Apply pre-merger kernel to both channels + # Function needs frequency series + strain_ww = {} + strain_ww["LISA_A"] = apply_pre_merger_kernel( + data["LISA_A"].to_frequencyseries(), + whitening_psd=psds_for_whitening["LISA_A"], + window=window, + window_length=window_length, + nfz=forward_zeroes, + nctf=nctf, + uid=4235, + copy_output=True, + ) + strain_ww["LISA_E"] = apply_pre_merger_kernel( + data["LISA_E"].to_frequencyseries(), + whitening_psd=psds_for_whitening["LISA_E"], + window=window, + window_length=window_length, + nfz=forward_zeroes, + nctf=nctf, + uid=4236, + copy_output=True, + ) + return strain_ww + + +def generate_waveform_lisa_pre_merger( + waveform_params, + psds_for_whitening, + sample_rate, + window_length, + cutoff_time, + forward_zeroes=0, +): + """Generate a pre-merger LISA waveform. + + UIDs used for FFTs: 1234(0), 1235(0), 1236(0), 1237(0) + + Parameters + ---------- + waveform_params: dict + A dictionary of waveform parameters that will be passed to the waveform + generator. + psds_for_whitening: dict[str: FrequencySeries] + Power spectral denisities for whitening in the frequency-domain. + sample_rate : float + Sampling rate. + window_length : int + Length (in samples) of time-domain window applied to the start of the + waveform. + cutoff_time: float + Time (in seconds) from the end of the waveform to cutoff. + forward_zeroes : int + Number of samples to set to zero at the start of the waveform. If used, + the window will be applied starting after the zeroes. 
+ """ + window = get_window(window_length) + nctf = int(cutoff_time * sample_rate) + + outs = pycbc.waveform.get_fd_det_waveform( + ifos=['LISA_A','LISA_E'], **waveform_params + ) + + # Apply pre-merger kernel + tout_A_ww = apply_pre_merger_kernel( + outs["LISA_A"], + whitening_psd=psds_for_whitening["LISA_A"], + window=window, + window_length=window_length, + nfz=forward_zeroes, + nctf=nctf, + uid=1235, + ) + tout_E_ww = apply_pre_merger_kernel( + outs["LISA_E"], + whitening_psd=psds_for_whitening["LISA_E"], + window=window, + window_length=window_length, + nfz=forward_zeroes, + nctf=nctf, + uid=12350, + ) + + # Back to FD for search/inference + fouts_ww = {} + fouts_ww["LISA_A"] = pycbc.strain.strain.execute_cached_fft( + tout_A_ww, + copy_output=False, + uid=1236, + ) + fouts_ww["LISA_E"] = pycbc.strain.strain.execute_cached_fft( + tout_E_ww, + copy_output=False, + uid=12360, + ) + return fouts_ww diff --git a/pycbc/waveform/spa_tmplt.py b/pycbc/waveform/spa_tmplt.py index 5c7f71c0593..8d8a42b0aa5 100644 --- a/pycbc/waveform/spa_tmplt.py +++ b/pycbc/waveform/spa_tmplt.py @@ -264,3 +264,171 @@ def spa_tmplt(**kwds): amp_factor, kwds['sample_points'], htilde) return htilde + + +def spa_tmplt_batch(templates_params, filter_length, **common_kwds): + """ + Generate multiple TaylorF2 templates in a single batched operation. + + Parameters + ---------- + templates_params : list of dict + List of parameter dictionaries, one per template. Each should contain + mass1, mass2, spin1z, spin2z, and optionally distance. + filter_length : int + Length of the output frequency series + **common_kwds : dict + Common parameters for all templates (delta_f, f_lower, etc.) 
+
+ Returns
+ -------
+ tuple of (cupy.ndarray, list of FrequencySeries, cupy.ndarray, cupy.ndarray)
+ Batched template array on the GPU, per-template FrequencySeries views, and the kmin/kmax index arrays
+ """
+ import cupy as cp
+ from pycbc.types import FrequencySeries, zeros
+ from pycbc.types.array import Array
+ from .spa_tmplt_cupy import spa_tmplt_engine_batch
+ from .waveform import get_waveform_filter_length_in_time
+ import lalsimulation
+
+ num_templates = len(templates_params)
+ if num_templates == 0:
+ return []
+
+ # Extract common parameters
+ delta_f = common_kwds['delta_f']
+ f_lower = common_kwds['f_lower']
+ phase_order = int(common_kwds.get('phase_order', -1))
+ spin_order = int(common_kwds.get('spin_order', -1))
+
+ # Pre-allocate arrays for parameters
+ kmin_list = []
+ kmax_list = []
+ piM_list = []
+ pfaN_list = []
+ pfa2_list = []
+ pfa3_list = []
+ pfa4_list = []
+ pfa5_list = []
+ pfl5_list = []
+ pfa6_list = []
+ pfl6_list = []
+ pfa7_list = []
+ amp_list = []
+ phase_order_list = []
+
+ lal_pars = lal.CreateDict()
+ if phase_order != -1:
+ lalsimulation.SimInspiralWaveformParamsInsertPNPhaseOrder(lal_pars, phase_order)
+ if spin_order != -1:
+ lalsimulation.SimInspiralWaveformParamsInsertPNSpinOrder(lal_pars, spin_order)
+
+ # Compute PN coefficients for each template
+ for params in templates_params:
+ mass1 = params['mass1']
+ mass2 = params['mass2']
+ s1z = params['spin1z']
+ s2z = params['spin2z']
+ distance = params.get('distance', common_kwds.get('distance', 1.0))
+
+ amp_factor = spa_amplitude_factor(mass1=mass1, mass2=mass2) / distance
+
+ # Calculate PN terms
+ phasing = lalsimulation.SimInspiralTaylorF2AlignedPhasing(
+ float(mass1), float(mass2), float(s1z), float(s2z), lal_pars)
+
+ pfaN = phasing.v[0]
+ pfa2 = phasing.v[2] / pfaN
+ pfa3 = phasing.v[3] / pfaN
+ pfa4 = phasing.v[4] / pfaN
+ pfa5 = phasing.v[5] / pfaN
+ pfa6 = (phasing.v[6] - phasing.vlogv[6] * log(4)) / pfaN
+ pfa7 = phasing.v[7] / pfaN
+ pfl5 = phasing.vlogv[5] / pfaN
+ pfl6 = phasing.vlogv[6] / pfaN
+
+ piM = lal.PI * (mass1 + mass2) * lal.MTSUN_SI
+
+ kmin = int(f_lower / 
float(delta_f)) + + # Get max frequency + if 'f_final' in common_kwds and common_kwds['f_final'] > 0.: + fstop = common_kwds['f_final'] + else: + vISCO = 1. / sqrt(6.) + fstop = vISCO * vISCO * vISCO / piM + + kmax = int(fstop / delta_f) + # Ensure kmax doesn't exceed filter_length - 1 (final point must be 0) + if kmax >= filter_length: + kmax = filter_length - 1 + + kmin_list.append(kmin) + kmax_list.append(kmax) + piM_list.append(piM) + pfaN_list.append(pfaN) + pfa2_list.append(pfa2) + pfa3_list.append(pfa3) + pfa4_list.append(pfa4) + pfa5_list.append(pfa5) + pfl5_list.append(pfl5) + pfa6_list.append(pfa6) + pfl6_list.append(pfl6) + pfa7_list.append(pfa7) + amp_list.append(amp_factor) + phase_order_list.append(phase_order) + + # Transfer parameters to GPU + kmin_gpu = cp.asarray(kmin_list, dtype=cp.int64) + kmax_gpu = cp.asarray(kmax_list, dtype=cp.int64) + phase_order_gpu = cp.asarray(phase_order_list, dtype=cp.int64) + piM_gpu = cp.asarray(piM_list, dtype=cp.float32) + pfaN_gpu = cp.asarray(pfaN_list, dtype=cp.float32) + pfa2_gpu = cp.asarray(pfa2_list, dtype=cp.float32) + pfa3_gpu = cp.asarray(pfa3_list, dtype=cp.float32) + pfa4_gpu = cp.asarray(pfa4_list, dtype=cp.float32) + pfa5_gpu = cp.asarray(pfa5_list, dtype=cp.float32) + pfl5_gpu = cp.asarray(pfl5_list, dtype=cp.float32) + pfa6_gpu = cp.asarray(pfa6_list, dtype=cp.float32) + pfl6_gpu = cp.asarray(pfl6_list, dtype=cp.float32) + pfa7_gpu = cp.asarray(pfa7_list, dtype=cp.float32) + amp_gpu = cp.asarray(amp_list, dtype=cp.float32) + + # Allocate output array on GPU + htilde_batch_gpu = cp.zeros(num_templates * filter_length, dtype=cp.complex64) + + # Call batched kernel + spa_tmplt_engine_batch( + htilde_batch_gpu, kmin_gpu, kmax_gpu, phase_order_gpu, + delta_f, piM_gpu, pfaN_gpu, + pfa2_gpu, pfa3_gpu, pfa4_gpu, pfa5_gpu, pfl5_gpu, + pfa6_gpu, pfl6_gpu, pfa7_gpu, amp_gpu, + num_templates, filter_length + ) + + # Extract individual templates from batch result + templates = [] + for i, params in 
enumerate(templates_params): + # Copy the relevant portion from GPU batch to this template + # Each template occupies filter_length elements in the batch array + start_idx = i * filter_length + end_idx = (i+1) * filter_length + # Extract the slice from GPU and assign to the PyCBC array + htilde = FrequencySeries(htilde_batch_gpu[start_idx:end_idx], delta_f=delta_f, copy=False) + + # Set metadata + htilde.chirp_length = get_waveform_filter_length_in_time( + mass1=params['mass1'], mass2=params['mass2'], + spin1z=params['spin1z'], spin2z=params['spin2z'], + f_lower=f_lower, phase_order=phase_order, + approximant='SPAtmplt' + ) + htilde.length_in_time = htilde.chirp_length + + templates.append(htilde) + + htilde_batch_gpu = htilde_batch_gpu.reshape(num_templates, filter_length) + + return htilde_batch_gpu, templates, kmin_gpu, kmax_gpu + diff --git a/pycbc/waveform/spa_tmplt_cupy.py b/pycbc/waveform/spa_tmplt_cupy.py index 63d429eff47..d0dec430a03 100644 --- a/pycbc/waveform/spa_tmplt_cupy.py +++ b/pycbc/waveform/spa_tmplt_cupy.py @@ -98,3 +98,131 @@ def spa_tmplt_engine(htilde, kmin, phase_order, delta_f, piM, pfaN, pfa2, pfa3, pfa4, pfa5, pfl5, pfa6, pfl6, pfa7, amp_factor, htilde.data) + + +# Batched version - processes multiple templates in one kernel call +taylorf2_batch_text = mako.template.Template(""" + // Determine which template and frequency bin we're computing + int template_idx = i / freq_length; + int freq_idx = i % freq_length; + + if (template_idx >= num_templates) return; + + // Check if this frequency index is valid for this template + // freq_idx corresponds to the index within the template's frequency range + const float f = (freq_idx) * delta_f; + + // Set output to zero if outside the template's frequency range + if (freq_idx < kmin[template_idx]) { + htilde.real(0.0f); + htilde.imag(0.0f); + return; + } + + if (freq_idx >= kmax[template_idx]) { + htilde.real(0.0f); + htilde.imag(0.0f); + return; + } + + const float amp2 = amp[template_idx] * 
__powf(f, -7.0/6.0); + const float v = __powf(piM[template_idx] * f, 1.0/3.0); + const float v2 = v * v; + const float v3 = v2 * v; + const float v4 = v2 * v2; + const float v5 = v2 * v3; + const float v6 = v3 * v3; + const float v7 = v3 * v4; + float phasing = 0.; + + float LAL_TWOPI = ${TWOPI}; + float LAL_PI_4 = ${PI_4}; + float log4 = ${LN4}; + float logv = __logf(v); + + int po = phase_order[template_idx]; + + switch (po) + { + case -1: + case 7: + phasing += pfa7[template_idx] * v7; + case 6: + phasing += (pfa6[template_idx] + pfl6[template_idx] * (logv + log4)) * v6; + case 5: + phasing += (pfa5[template_idx] + pfl5[template_idx] * logv) * v5; + case 4: + phasing += pfa4[template_idx] * v4; + case 3: + phasing += pfa3[template_idx] * v3; + case 2: + phasing += pfa2[template_idx] * v2; + case 0: + phasing += 1.; + break; + default: + break; + } + phasing *= pfaN[template_idx] / v5; + phasing -= LAL_PI_4; + phasing -= int(phasing / (LAL_TWOPI)) * LAL_TWOPI; + + float pcos; + float psin; + __sincosf(phasing, &psin, &pcos); + + htilde.real(pcos * amp2); + htilde.imag(-psin * amp2); +""").render(TWOPI=lal.TWOPI, PI_4=lal.PI_4, LN4=2*lal.LN2) + + +taylorf2_batch_kernel = cp.ElementwiseKernel( + """ + raw int64 kmin, raw int64 kmax, raw int64 phase_order, float32 delta_f, + raw float32 piM, raw float32 pfaN, raw float32 pfa2, raw float32 pfa3, + raw float32 pfa4, raw float32 pfa5, raw float32 pfl5, raw float32 pfa6, + raw float32 pfl6, raw float32 pfa7, raw float32 amp, + int32 num_templates, int32 freq_length + """, + "complex64 htilde", + taylorf2_batch_text, + "taylorf2_batch_kernel", +) + + +def spa_tmplt_engine_batch(htilde_batch, kmin_arr, kmax_arr, phase_order_arr, + delta_f, piM_arr, pfaN_arr, + pfa2_arr, pfa3_arr, pfa4_arr, pfa5_arr, pfl5_arr, + pfa6_arr, pfl6_arr, pfa7_arr, amp_arr, + num_templates, freq_length): + """ + Calculate spa tmplt phase for multiple templates in a single kernel call. 
+ + Parameters + ---------- + htilde_batch : cupy array + Output array of shape (num_templates * freq_length,) + kmin_arr : cupy array of int64 + Starting frequency index for each template + kmax_arr : cupy array of int64 + Ending frequency index for each template + phase_order_arr : cupy array of int64 + Phase order for each template + delta_f : float + Frequency spacing + piM_arr, pfaN_arr, pfa*_arr, pfl*_arr : cupy arrays of float32 + PN coefficients for each template + amp_arr : cupy array of float32 + Amplitude factor for each template + num_templates : int + Number of templates to generate + freq_length : int + Maximum frequency span across all templates + """ + taylorf2_batch_kernel(kmin_arr, kmax_arr, phase_order_arr, + delta_f, + piM_arr, pfaN_arr, + pfa2_arr, pfa3_arr, pfa4_arr, pfa5_arr, pfl5_arr, + pfa6_arr, pfl6_arr, pfa7_arr, amp_arr, + num_templates, freq_length, + htilde_batch) diff --git a/requirements.txt b/requirements.txt index 709649b789b..2d97dfd86ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,7 +30,7 @@ urllib3 # need to pin until pegasus for further upstream # addresses incompatibility between old flask/jinja2 and latest markupsafe -markupsafe <= 2.0.1 +# markupsafe <= 2.0.1 # Requirements for ligoxml access needed by some workflows python-ligo-lw >= 1.8.1