Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 103 additions & 35 deletions bin/bank/pycbc_brute_bank
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
"""
import numpy
import logging
import signal
import os
import argparse
import numpy.random
from scipy.stats import gaussian_kde
Expand Down Expand Up @@ -53,6 +55,8 @@ parser.add_argument('--approximant', required=False,
parser.add_argument('--minimal-match', default=0.97, type=float)
parser.add_argument('--buffer-length', default=2, type=float,
help='size of waveform buffer in seconds')
parser.add_argument('--use-td-waveform', action='store_true',
help='Generate waveform in the time domain (default is frequency domain).')
parser.add_argument('--full-resolution-buffer-length', default=None, type=float,
help='Size of the waveform buffer in seconds for generating time-domain signals at full resolution before conversion to the frequency domain.')
parser.add_argument('--max-signal-length', type= float,
Expand Down Expand Up @@ -81,6 +85,8 @@ parser.add_argument('--tau0-end', type=float)
parser.add_argument('--tau0-cutoff-frequency', type=float, default=15.0)
parser.add_argument('--nprocesses', type=int, default=1,
help='Number of processes to use for waveform generation parallelization. If not given then only a single core will be used.')
parser.add_argument('--parallel-check', action='store_true', help="Do bank checking parallel, note that this means that proposals WILL NOT be checked against each other.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this line be wrapped to match maximum line length requirements?

parser.add_argument('--max-connections', type=int, help="Maximum number of matches to store with each template", default=numpy.inf)
pycbc.psd.insert_psd_option_group(parser)
parser.add_argument('--use-trimmed-buffer', action='store_true',
help=('When specified, the match calculation will use only the first and last 100 samples '
Expand Down Expand Up @@ -134,6 +140,7 @@ class TriangleBank(object):
    def __init__(self, p=None):
        """Initialize the template bank.

        Parameters
        ----------
        p : list, optional
            Initial list of waveform templates. If None, start empty.
        """
        self.waveforms = p if p is not None else []
        # presumably tau0 bins for fast neighbour lookup — TODO confirm
        self.tbins = {}
        # Maximum match recorded for each template accepted into the bank
        # (appended in check_params, written out by save_bank).
        self.max_matches = []

    def __len__(self):
        """Return the number of templates currently in the bank."""
        return len(self.waveforms)
Expand Down Expand Up @@ -229,23 +236,33 @@ class TriangleBank(object):
while 1:
j = inc.pop()
if j is None:
hp.matches = matches[r]
hp.indices = r
msort = matches[r].argsort()

msorted = matches[r][msort]
rsorted = r[msort]
keep = numpy.ones(len(msorted), dtype=bool)
if args.max_connections < len(keep):
#keep[args.max_connections//2: -args.max_connections//2] = False
keep[args.max_connections:] = False

hp.matches = msorted[keep].copy()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this array copy not now increase the memory footprint, particularly in the max_connections infinite / default case?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually it's the opposite, otherwise loose reference could keep something in memory when we don't really need it anymore.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A further clarification may help here. If you just return a view, even if that's all you have, the full array is always kept. I explicitly want the reduced form to be what is stored and not the full array, hence explicitly asking for a new array that contains a copy of what the view pointed to. This means that the original values can be cleaned up by python's garbage collector.

hp.indices = rsorted[keep].copy()

logging.info("TADD MaxMatch:%0.3f Size:%i "
"AfterSigma:%i AfterTau0:%i Matches:%i"
% (mmax, len(self), msig, mtau, mnum))
hp.max_match = mmax
return False

hc = self[j]

# Defensive initialization if matches/indices are missing
if not hasattr(hc, 'matches'):
hc.matches = numpy.empty(len(self))
hc.matches = numpy.empty(2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Im not sure I understand this change

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In truth, I'd like to remove these lines. I can't really figure out why they are needed. The only guess I have is that perhaps there is some boundary issue that this is masking. However, to the point, there is no reason to add these with the full size of the existing bank. Since the information added is meaningless, it ends up being a waste of memory. Putting it at 2 was simply a lazy way to force the correct dimensions.

hc.matches[:] = numpy.nan

if not hasattr(hc, 'indices'):
hc.indices = numpy.arange(len(self))

hc.indices = numpy.arange(2)

m = hp.gen.match(hp, hc)
matches[j] = m
Expand All @@ -265,34 +282,39 @@ class TriangleBank(object):
if m > mmax:
mmax = m

def check_params(self, gen, params, threshold, force_add=False):
def check_params(self, gen, params, threshold,
force_add=False,
parallel_check=False):
num_added = 0
total_num = len(tuple(params.values())[0])
waveform_cache = []

pool = pycbc.pool.choose_pool(args.nprocesses)
for return_wf in pool.imap_unordered(
wf_wrapper,
({k: params[k][idx] for k in params} for idx in range(total_num))
(({k: params[k][idx] for k in params}, parallel_check, threshold) for idx in range(total_num))
):
waveform_cache += [return_wf]

pool.close_pool()
del pool

for hp in waveform_cache:
if hp is not None:
hp.gen = gen
hp.threshold = threshold
if hp not in self:
num_added += 1
self.insert(hp)
elif force_add:
if hp.checked is None:
hp.gen = gen
hp.checked = hp not in self

if hp.checked:
self.max_matches.append(hp.max_match)

if hp.checked or force_add:
num_added += 1
self.insert(hp)
self.insert(hp)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whitespace

else:
logging.info("Waveform generation failed!")
continue

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whitespace

return bank, num_added / total_num

def decimate_frequency_domain(template, target_df):
Expand Down Expand Up @@ -326,6 +348,24 @@ def decimate_frequency_domain(template, target_df):
decimated_template = pycbc.types.FrequencySeries(decimated_signal, delta_f=target_df)
return decimated_template

def handle_exit_signal(signum, frame):
    """Signal handler for the parent process: save progress, then exit.

    Parameters
    ----------
    signum : int
        The signal number received (e.g. SIGINT or SIGTERM).
    frame : frame
        The interrupted stack frame (unused).
    """
    # Use lazy %-style logging arguments for consistency with the rest of
    # this file and to avoid formatting when the message is filtered out.
    logging.warning("Signal %s received. Triggering emergency save...", signum)
    save_bank()
    # Force exit so the script doesn't try to resume the loops
    os._exit(0)

def save_bank():
    """Write the current global ``bank`` to the HDF5 output file.

    Stores every parameter column of the bank (unicode columns are
    converted to bytes for HDF5 compatibility), the ``minimal_match``
    attribute, and the per-template ``max_matches`` array.
    """
    logging.info("Saving current bank to %s", args.output_file)
    with HFile(args.output_file, 'w') as outf:
        outf.attrs['minimal_match'] = args.minimal_match
        for name in bank.keys():
            column = bank.key(name)
            # HDF5 cannot store numpy unicode arrays directly.
            if column.dtype.char == 'U':
                column = column.astype('bytes')
            outf[name] = column
        outf['max_matches'] = numpy.array(bank.max_matches)
    logging.info("Save complete.")

class GenUniformWaveform(object):
def __init__(self, buffer_length, sample_rate, f_lower):
self.f_lower = f_lower
Expand Down Expand Up @@ -363,17 +403,25 @@ class GenUniformWaveform(object):
if hasattr(kwds['approximant'], 'decode'):
kwds['approximant'] = kwds['approximant'].decode()

if args.full_resolution_buffer_length is not None:
buff_len = args.full_resolution_buffer_length
else:
buff_len = 1.0 / self.delta_f

if kwds['approximant'] in pycbc.waveform.fd_approximants():
if args.full_resolution_buffer_length is not None:
# Generate the frequency-domain waveform at full frequency resolution
high_hp, high_hc = pycbc.waveform.get_fd_waveform(delta_f=1 / args.full_resolution_buffer_length,

# Optionally generate time-domain waveform
if args.use_td_waveform:
hp, hc = pycbc.waveform.get_td_waveform(delta_t=1.0 / args.sample_rate,
**kwds)
# Decimate the generated signal to a reduced frequency resolution
hp = decimate_frequency_domain(high_hp, 1 / args.buffer_length)
hc = decimate_frequency_domain(high_hc, 1 / args.buffer_length)

hp = hp.to_frequencyseries(delta_f = 1.0 / buff_len)
hc = hc.to_frequencyseries(delta_f = 1.0 / buff_len)

else:
hp, hc = pycbc.waveform.get_fd_waveform(delta_f=self.delta_f,
hp, hc = pycbc.waveform.get_fd_waveform(delta_f = 1.0 / buff_len,
**kwds)

if args.use_cross:
hp = hc

Expand All @@ -387,6 +435,10 @@ class GenUniformWaveform(object):
delta_f=self.delta_f, delta_t=dt,
**kwds)

if args.full_resolution_buffer_length is not None:
# Decimate the generated signal to a reduced frequency resolution
hp = decimate_frequency_domain(hp, 1 / args.buffer_length)

hp.resize(self.flen)
hp = hp.astype(numpy.complex64)
hp[self.kmin:-1] *= self.w
Expand Down Expand Up @@ -416,9 +468,24 @@ gen = GenUniformWaveform(args.buffer_length,
args.sample_rate, args.low_frequency_cutoff)
bank = TriangleBank()

def wf_wrapper(p):
def silent_child_exit(signum, frame):
    """Signal handler installed in pool worker processes.

    Terminates the worker immediately and without a traceback so that an
    interrupt delivered to the process group does not produce noisy output
    from every child; the parent's handler performs the actual save.
    """
    # The child workers execute this branch
    # os._exit(0) ensures they die instantly and quietly
    os._exit(0)

def wf_wrapper(args):
p, parallel_check, threshold = args

signal.signal(signal.SIGINT, silent_child_exit)
signal.signal(signal.SIGTERM, silent_child_exit)
try:
hp = gen.generate(**p)
hp.checked = None
hp.threshold = threshold
if parallel_check:
hp.gen = gen
hp.checked = hp not in bank
hp.gen = None
return hp
except Exception as e:
print(e)
Expand All @@ -431,7 +498,6 @@ if args.input_file:
force_add=args.keep_entire_input_file)

def draw(rtype):

if rtype == 'uniform':
if args.input_config is None:
params = {name: numpy.random.uniform(pmin, pmax, size=size)
Expand Down Expand Up @@ -534,6 +600,9 @@ def cdraw(rtype, ts, te):

return p

signal.signal(signal.SIGINT, handle_exit_signal)
signal.signal(signal.SIGTERM, handle_exit_signal)

tau0s = args.tau0_start
tau0e = tau0s + args.tau0_crawl

Expand All @@ -553,7 +622,8 @@ while tau0s < args.tau0_end:
break

blen = len(bank)
bank, uconv = bank.check_params(gen, params, args.minimal_match)
bank, uconv = bank.check_params(gen, params, args.minimal_match,
parallel_check=args.parallel_check)
logging.info("%s: Round (U): %s Size: %s conv: %s added: %s",
region, r, len(bank), uconv, len(bank) - blen)
if r > 10:
Expand All @@ -565,9 +635,13 @@ while tau0s < args.tau0_end:
kloop += 1
params = cdraw('kde', tau0s, tau0e)
blen = len(bank)
bank, kconv = bank.check_params(gen, params, args.minimal_match)
logging.info("%s: Round (K) (%s): %s Size: %s conv: %s added: %s",
region, kloop, r, len(bank), kconv, len(bank) - blen)
bank, kconv = bank.check_params(gen, params, args.minimal_match,
parallel_check=args.parallel_check)

trail_matches = numpy.array(bank.max_matches[int(len(bank.max_matches)*0.9):])
ave = numpy.mean(trail_matches)
logging.info("%s: Round (K) (%s): %s Size: %s conv: %0.4f added: %s Trail Ave: %0.4f Trail Min: %0.4f",
region, kloop, r, len(bank), kconv, len(bank) - blen, ave, trail_matches.min())


if uconv:
Expand All @@ -587,10 +661,4 @@ while tau0s < args.tau0_end:
tau0s += args.tau0_crawl / 2
tau0e += args.tau0_crawl / 2

o = HFile(args.output_file, 'w')
o.attrs['minimal_match'] = args.minimal_match
for k in bank.keys():
val = bank.key(k)
if val.dtype.char == 'U':
val = val.astype('bytes')
o[k] = val
save_bank()
Loading