45 commits
270956e  Temp playing (spxiwh, Jul 7, 2023)
bfc9ac8  Some more hacks to do LISA EW PE (spxiwh, Jul 26, 2023)
ce796ab  Missing new files!! (spxiwh, Jul 26, 2023)
9b45879  Some changes (spxiwh, Aug 16, 2023)
0058ee1  Fixing early warning PE (spxiwh, Oct 6, 2023)
0610738  One more fix (spxiwh, Oct 6, 2023)
42e365e  Some pokes in brute_marg (spxiwh, Oct 6, 2023)
ed4c514  phase marg and modes for LISA EW likelihood (mj-will, Dec 13, 2023)
32fb2f7  add option to specify PSD path (mj-will, Jan 16, 2024)
a09661c  update for BBHx plugin and to enable the use of `extra_forward_zeroes` (mj-will, Feb 28, 2024)
7340fe4  clean up early warning code (mj-will, Apr 11, 2024)
a19fa3d  fix psd duration (mj-will, Apr 11, 2024)
a68d702  add PSD lower frequency cutoff (mj-will, Apr 15, 2024)
4dbd198  manually cut psd if low freq cutoff is specified (mj-will, May 9, 2024)
0ec6cc3  remove mode array and phenomd config (mj-will, May 9, 2024)
d8e492c  handle case where psd low freq cutoff is not specified (mj-will, May 14, 2024)
4f94ccb  remove incorrect PSD lower frequency cutoff (mj-will, May 15, 2024)
4d6a751  consistency updates for LISA pre-merger work (mj-will, Jun 6, 2024)
e6f23ee  Fix time shift in inference (#3) (mj-will, Jun 7, 2024)
c763c57  changes to kernel length (mj-will, Jun 21, 2024)
053dcb0  avoid extra FFTs (mj-will, Jun 21, 2024)
daece7f  update comments (mj-will, Jun 21, 2024)
cd1e2b4  add cutoff_deltat parameter (mj-will, Jun 25, 2024)
f268a36  Update doc-string (mj-will, Jun 25, 2024)
eaaa0c2  make cutoff_deltat optional (mj-will, Jul 3, 2024)
1414a23  fix calcluation of cutoff_deltat (#6) (mj-will, Nov 8, 2024)
b6d568e  Some tweaks to premerger likelihood to allow cupy (spxiwh, Nov 19, 2024)
0f44fbc  Force these to floats/complex (spxiwh, Nov 20, 2024)
30dc864  batching partially working (xangma, Dec 9, 2024)
81ed763  corr and ifft results populated and matching (xangma, Dec 9, 2024)
c10adc1  batched threshold cupy is wrong, but it does run ... (xangma, Dec 9, 2024)
045a65d  CHanges to make threshold work and CPU work (spxiwh, Dec 16, 2024)
0b6b8c8  Merge pull request #1 from spxiwh/gpu_dev (xangma, Dec 16, 2024)
3464a92  fixes so results now match cpu results (further verification needed) (xangma, Dec 16, 2024)
f91cb5a  Filter the -1s out with cupy kernel (xangma, Dec 17, 2024)
2494d6f  type changes in chisq (xangma, Dec 17, 2024)
db15eb6  Avoid memory copies in threshold (spxiwh, Dec 17, 2024)
bcb2703  Merge pull request #2 from spxiwh/gpu_dev (xangma, Dec 18, 2024)
dfc144d  Latest modifications with copilot help (spxiwh, Nov 20, 2025)
6b4a2e0  Here we break pycbc_inspiral to do a focused GPU implementation (spxiwh, Dec 9, 2025)
83e6508  Continuing to work on inspiral_utils.py (spxiwh, Dec 10, 2025)
293dbbd  Fixing the waveform generation to avoid moving stuff around (spxiwh, Dec 10, 2025)
f5fe3b5  Checkpoint at faster output + sigmasq ... Chisq next! (spxiwh, Dec 11, 2025)
8af5398  Some profiling code (spxiwh, Jan 9, 2026)
8991665  Add test script (spxiwh, Jan 9, 2026)
279 changes: 174 additions & 105 deletions bin/pycbc_inspiral
@@ -1,4 +1,4 @@
#!/usr/bin/env python
#!/home/xangma/miniconda3/envs/pycbcgpu/bin/python3.11

# Copyright (C) 2014 Alex Nitz
#
@@ -16,14 +16,15 @@
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.

import time
t_very_start = time.time()

import sys
import os
import copy
import logging
import argparse
import numpy
import itertools
import time
from pycbc.pool import BroadcastPool as Pool

import pycbc
@@ -33,8 +34,16 @@ from pycbc.filter import MatchedFilterControl, make_frequency_series, qtransform
from pycbc.types import TimeSeries, FrequencySeries, zeros, float32, complex64
import pycbc.opt
import pycbc.inject

from pycbc.scheme import CUPYScheme
last_progress_update = -1.0
import numpy

try:
import cupy as cp
except ImportError:
cp = None

print(f"PROFILE: Import time: {time.time() - t_very_start:.2f} s", flush=True)

def update_progress(p,u,n):
""" updates a file 'progress.txt' with a value 0 .. 1.0 when enough (filtering) progress was made
@@ -203,6 +212,8 @@ parser.add_argument("--multiprocessing-nprocesses", type=int,
"Used in conjunction with the option"
"--finalize-events-template-rate which should be set"
"to a multiple of the number of processes.")
parser.add_argument("--gpu-batch-size", type=int, default=1,
help="Number of templates to process in parallel on GPU")

# Add options groups
psd.insert_psd_option_group(parser)
@@ -234,73 +245,39 @@ gwstrain = strain.from_cli(opt, dyn_range_fac=DYN_RANGE_FAC,

strain_segments = strain.StrainSegments.from_cli(opt, gwstrain)

def template_triggers(t_num):
""" Get the triggers for a specific template
def template_triggers(t_nums):
""" Get the triggers for a template or batch of templates

This is a wrapper that routes to GPU-optimized or CPU code paths.

Parameters
----------
t_nums : int or list
Either a single template index or a list of template indices
"""
template = None
tparam = None
out_vals_all = []
for s_num, stilde in enumerate(segments):
# Filter check checks the 'inj_filter_rejector' options to
# determine whether
# to filter this template/segment if injections are present.
if not inj_filter_rejector.template_segment_checker(
bank, t_num, stilde):
continue
if template is None:
template = bank[t_num]
tparam = template.params

if opt.update_progress:
update_progress((t_num + (s_num / float(len(segments))) ) / len(bank),
opt.update_progress, opt.update_progress_file)
logging.info("Filtering template %d/%d segment %d/%d" %
(t_num + 1, len(bank), s_num + 1, len(segments)))

sigmasq = template.sigmasq(stilde.psd)
snr, norm, corr, idx, snrv = \
matched_filter.matched_filter_and_cluster(s_num,
sigmasq,
cluster_window,
epoch=stilde._epoch)
if not len(idx):
continue

out_vals = out_vals_ref.copy()
out_vals['bank_chisq'], out_vals['bank_chisq_dof'] = \
bank_chisq.values(template, stilde.psd, stilde, snrv, norm,
idx+stilde.analyze.start)

out_vals['chisq'], out_vals['chisq_dof'] = \
power_chisq.values(corr, snrv, norm, stilde.psd,
idx+stilde.analyze.start, template)

out_vals['sg_chisq'] = sg_chisq.values(stilde, template, stilde.psd,
snrv, norm,
out_vals['chisq'],
out_vals['chisq_dof'],
idx+stilde.analyze.start)

out_vals['cont_chisq'], _ = \
autochisq.values(snr, idx+stilde.analyze.start, template,
stilde.psd, norm, stilde=stilde,
low_frequency_cutoff=flow)

idx += stilde.cumulative_index

out_vals['time_index'] = idx
out_vals['snr'] = snrv * norm
out_vals['sigmasq'] = numpy.zeros(len(snrv), dtype=float32) + sigmasq
if opt.psdvar_short_segment is not None:
out_vals['psd_var_val'] = \
pycbc.psd.find_trigger_value(psd_var,
out_vals['time_index'],
opt.gps_start_time, opt.sample_rate)
#print(idx, out_vals['time_index'])

out_vals_all.append(copy.deepcopy(out_vals))
#print(out_vals_all)
return out_vals_all, tparam
# Handle single template case for backward compatibility
if isinstance(t_nums, (int, numpy.integer)):
t_nums = [t_nums]

# Check if we're using CUPY scheme for batch processing
using_cupy = isinstance(ctx, pycbc.scheme.CUPYScheme)

if using_cupy and len(t_nums) > 1:
# Use GPU-optimized batched implementation
from pycbc.filter.inspiral_utils import template_triggers_gpu_batched

return template_triggers_gpu_batched(
t_nums, bank, segments, matched_filter,
power_chisq, sg_chisq, inj_filter_rejector,
cluster_window, out_vals_ref, opt,
psd_var=psd_var if opt.psdvar_short_segment is not None else None
)
else:
# Fall back to single-template CPU processing
raise NotImplementedError(
"CPU/single-template processing not yet re-implemented. "
"This will be restored later."
)
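
As a usage note, the rewritten `template_triggers` accepts either a single template index or a batch, normalising a bare int to a one-element list so existing call sites keep working. A hedged sketch of the call forms (assuming a CUPY scheme is active and the batched helper imports cleanly):

    # Batched GPU path: one call filters several templates at once.
    out_vals, tparams = template_triggers([0, 1, 2, 3])

    # A bare int is still accepted and becomes [7] internally, but the
    # single-template/CPU path currently raises NotImplementedError.
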

with ctx:
if opt.fft_backends == 'fftw':
@@ -394,7 +371,12 @@ with ctx:
opt, names, [out_types[n] for n in names], psd=segments[0].psd,
gating_info=gwstrain.gating_info, q_trans=q_trans)

template_mem = zeros(tlen, dtype = complex64)
# Process in batches for GPU
batch_size = opt.gpu_batch_size
if opt.processing_scheme == 'cupy':
template_mem = [zeros(tlen, dtype=complex64) for _ in range(batch_size)]
else:
template_mem = zeros(tlen, dtype=complex64)
cluster_window = int(opt.cluster_window * gwstrain.sample_rate)

if opt.cluster_window == 0.0:
@@ -409,15 +391,15 @@
if opt.multiprocessing_nprocesses:
ncores *= opt.multiprocessing_nprocesses


matched_filter = MatchedFilterControl(opt.low_frequency_cutoff, None,
opt.snr_threshold, tlen, delta_f, complex64,
segments, template_mem, use_cluster,
downsample_factor=opt.downsample_factor,
upsample_threshold=opt.upsample_threshold,
upsample_method=opt.upsample_method,
gpu_callback_method=opt.gpu_callback_method,
cluster_function=opt.cluster_function)
opt.snr_threshold, tlen, delta_f, complex64,
segments, template_mem, use_cluster,
downsample_factor=opt.downsample_factor,
upsample_threshold=opt.upsample_threshold,
upsample_method=opt.upsample_method,
gpu_callback_method=opt.gpu_callback_method,
cluster_function=opt.cluster_function,
batch_size=opt.gpu_batch_size)

bank_chisq = vetoes.SingleDetBankVeto(opt.bank_veto_bank_file,
flen, delta_f, flow, complex64,
@@ -455,54 +437,141 @@ with ctx:
logging.info("Template bank size after thinning: %s", len(bank))

tsetup = time.time() - tstart
logging.info("PROFILE: Setup time (data loading, PSD, etc): %.2f sec", tsetup)
tcheckpoint = time.time()

tanalyze = list(range(tnum_start, len(bank)))
n = opt.finalize_events_template_rate
n = 1 if n is None else n
tchunks = [tanalyze[i:i + n] for i in range(0, len(tanalyze), n)]

mmap = map
if opt.multiprocessing_nprocesses:
mmap = Pool(opt.multiprocessing_nprocesses).map

for tchunk in tchunks:
data = list(mmap(template_triggers, tchunk))

for elem in data:
out_vals_all, tparam = elem
if len(out_vals_all) > 0:
event_mgr.new_template(tmplt=tparam)
# Start timing the main loop
tloop_start = time.time()
templates_processed = 0

for edata in out_vals_all:
event_mgr.add_template_events(names, [edata[n] for n in names])

event_mgr.cluster_template_events("time_index", "snr", cluster_window)
event_mgr.finalize_template_events()
if isinstance(ctx, CUPYScheme):
# Create batches of templates, each of size batch_size
tchunks = [tanalyze[i:i + batch_size] for i in range(0, len(tanalyze), batch_size)]
else:
# Original chunking for non-CUPY schemes
tchunks = [tanalyze[i:i + n] for i in range(0, len(tanalyze), n)]
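
For concreteness, a worked example of the two chunking modes (illustrative values only):

    # 10 templates, GPU batch size 4 -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
    tanalyze_demo = list(range(10))
    gpu_chunks = [tanalyze_demo[i:i + 4] for i in range(0, len(tanalyze_demo), 4)]

    # Non-CUPY default (n = 1) -> [[0], [1], ..., [9]], one template per chunk
    cpu_chunks = [tanalyze_demo[i:i + 1] for i in range(0, len(tanalyze_demo), 1)]
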

for tchunk in tchunks:
chunk_start = time.time()
if isinstance(ctx, CUPYScheme):
data = [template_triggers(tchunk)]
else:
data = list(mmap(template_triggers, tchunk))

data_time = time.time() - chunk_start
print(f"PROFILE: template_triggers call time: {data_time:.4f} s")

# Process events with new output format
# FIXME: Doesn't work yet!
#event_proc_start = time.time()
#for elem in data:
# out_vals, tparams = elem

# # Check if we have any triggers
# if len(out_vals['template_id']) == 0:
# continue

# # Get unique template IDs and process each template
# # Convert to numpy for easier manipulation
# template_ids_array = numpy.array(out_vals['template_id'])
# unique_template_ids = numpy.unique(template_ids_array)

# for template_id in unique_template_ids:
# template_id_int = int(template_id)

# # Find the index in tparams (which corresponds to position in tchunk)
# # template_id is the global template index, find its position in the batch
# batch_idx = tchunk.index(template_id_int)
# tparam = tparams[batch_idx]

# # Mask for this template's triggers
# mask = template_ids_array == template_id

# # Extract events for this template
# template_out_vals = {}
# for key in out_vals.keys():
# if key == 'template_id':
# continue # Skip template_id, it's not in names
# # Convert to numpy array, apply mask, convert back to PyCBC array
# data_array = numpy.array(out_vals[key])
# template_out_vals[key] = out_vals[key].__class__(data_array[mask], copy=False)

# # Fill in missing keys with None
# for key in names:
# if key not in template_out_vals:
# template_out_vals[key] = None

# # Add to event manager
# event_mgr.new_template(tmplt=tparam)
# event_mgr.add_template_events(names, [template_out_vals[n] for n in names])
# event_mgr.cluster_template_events("time_index", "snr", cluster_window)
# event_mgr.finalize_template_events()

#event_proc_time = time.time() - event_proc_start
#print(f"PROFILE: Event processing time: {event_proc_time:.4f} s")
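
The commented-out block above regroups the flat batched output by template: the batched helper tags every trigger with a `template_id`, and `numpy.unique` plus a boolean mask recovers each template's triggers for the event manager. A standalone sketch of that step (array contents are illustrative):

    import numpy

    # Flat trigger columns for a batch covering templates 3 and 5.
    template_ids = numpy.array([3, 5, 3, 3, 5])
    snr = numpy.array([8.1, 9.2, 7.5, 10.0, 8.8])

    for template_id in numpy.unique(template_ids):
        mask = template_ids == template_id
        # snr[mask] (and the other masked columns) belong to this template.
        print(int(template_id), snr[mask])
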

if opt.finalize_events_template_rate is not None:
event_mgr.consolidate_events(opt, gwstrain=gwstrain)

if opt.checkpoint_interval and \
(time.time() - tcheckpoint > opt.checkpoint_interval):
event_mgr.save_state(max(tchunk), opt.output + '.checkpoint')
tcheckpoint = time.time()

if opt.checkpoint_exit_maxtime and \
(time.time() - tstart > opt.checkpoint_exit_maxtime):
event_mgr.save_state(max(tchunk), opt.output + '.checkpoint')
sys.exit(opt.checkpoint_exit_code)

event_mgr.consolidate_events(opt, gwstrain=gwstrain)

chunk_time = time.time() - chunk_start
print(f"PROFILE: Total chunk time: {chunk_time:.4f} s")
avg_template_time = chunk_time / len(tchunk) if len(tchunk) > 0 else 0
logging.info("Processed %d templates in %.2f sec (%.2f sec/template)",
len(tchunk), chunk_time, avg_template_time)

if opt.finalize_events_template_rate is not None:
event_mgr.consolidate_events(opt, gwstrain=gwstrain)

if opt.checkpoint_interval and \
(time.time() - tcheckpoint > opt.checkpoint_interval):
event_mgr.save_state(max(tchunk), opt.output + '.checkpoint')
tcheckpoint = time.time()

if opt.checkpoint_exit_maxtime and \
(time.time() - tstart > opt.checkpoint_exit_maxtime):
event_mgr.save_state(max(tchunk), opt.output + '.checkpoint')
sys.exit(opt.checkpoint_exit_code)

# End timing of main loop
tloop_end = time.time()
tloop_total = tloop_end - tloop_start
avg_template_time = tloop_total / templates_processed if templates_processed > 0 else 0
logging.info("PROFILE: Main loop completed in %.2f sec", tloop_total)
logging.info("PROFILE: Processed %d templates", templates_processed)
logging.info("PROFILE: Average time per template: %.4f sec", avg_template_time)

finalize_start = time.time()
event_mgr.consolidate_events(opt, gwstrain=gwstrain)
event_mgr.finalize_events()
finalize_time = time.time() - finalize_start
logging.info("PROFILE: Event finalization time: %.2f sec", finalize_time)

# Print detailed GPU profiling if using batched GPU mode
if hasattr(opt, 'processing_scheme') and opt.processing_scheme == 'cupy':
try:
from pycbc.filter.inspiral_utils import print_profile
print_profile()
except (ImportError, AttributeError):
pass

logging.info("Outputting %s triggers" % str(len(event_mgr.events)))

tstop = time.time()
run_time = tstop - tstart
event_mgr.save_performance(ncores, len(segments), len(bank), run_time, tsetup)

write_start = time.time()
logging.info("Writing out triggers")
event_mgr.write_events(opt.output)
write_time = time.time() - write_start
logging.info("PROFILE: Write events time: %.2f sec", write_time)

if opt.fft_backends == 'fftw':
if opt.fftw_output_float_wisdom_file:
Binary file added examples/inspiral/BANK_SPLIT0.hdf
Binary file not shown.
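
The `print_profile` call near the end of the script imports from `pycbc.filter.inspiral_utils`, which this diff does not include; the try/except means the report is silently skipped if the module or function is missing. Purely as an assumption about its shape, a minimal accumulator consistent with the PROFILE output format used throughout the script:

    import time
    from collections import defaultdict

    # Assumed structure only: the real inspiral_utils module is not shown here.
    _timings = defaultdict(float)

    def record(label, t_start):
        # Accumulate elapsed wall-clock time under a label.
        _timings[label] += time.time() - t_start

    def print_profile():
        # Report accumulated timings in the script's PROFILE format.
        for label, total in sorted(_timings.items()):
            print(f"PROFILE: {label}: {total:.4f} s", flush=True)
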