diff --git a/bin/bank/pycbc_brute_bank b/bin/bank/pycbc_brute_bank index 779bef4819d..0e274c1c994 100755 --- a/bin/bank/pycbc_brute_bank +++ b/bin/bank/pycbc_brute_bank @@ -21,6 +21,8 @@ """ import numpy import logging +import signal +import os import argparse import numpy.random from scipy.stats import gaussian_kde @@ -53,6 +55,8 @@ parser.add_argument('--approximant', required=False, parser.add_argument('--minimal-match', default=0.97, type=float) parser.add_argument('--buffer-length', default=2, type=float, help='size of waveform buffer in seconds') +parser.add_argument('--use-td-waveform', action='store_true', + help='Generate waveform in the time domain (default is frequency domain).') parser.add_argument('--full-resolution-buffer-length', default=None, type=float, help='Size of the waveform buffer in seconds for generating time-domain signals at full resolution before conversion to the frequency domain.') parser.add_argument('--max-signal-length', type= float, @@ -81,6 +85,8 @@ parser.add_argument('--tau0-end', type=float) parser.add_argument('--tau0-cutoff-frequency', type=float, default=15.0) parser.add_argument('--nprocesses', type=int, default=1, help='Number of processes to use for waveform generation parallelization. If not given then only a single core will be used.') +parser.add_argument('--parallel-check', action='store_true', help="Do bank checking parallel, note that this means that proposals WILL NOT be checked against each other.") +parser.add_argument('--max-connections', type=int, help="Maximum number of matches to store with each template", default=numpy.inf) pycbc.psd.insert_psd_option_group(parser) parser.add_argument('--use-trimmed-buffer', action='store_true', help=('When specified, the match calculation will use only the first and last 100 samples ' @@ -134,6 +140,7 @@ class TriangleBank(object): def __init__(self, p=None): self.waveforms = p if p is not None else [] self.tbins = {} + self.max_matches = [] def __len__(self): return len(self.waveforms) @@ -229,23 +236,33 @@ class TriangleBank(object): while 1: j = inc.pop() if j is None: - hp.matches = matches[r] - hp.indices = r + msort = matches[r].argsort() + + msorted = matches[r][msort] + rsorted = r[msort] + keep = numpy.ones(len(msorted), dtype=bool) + if args.max_connections < len(keep): + #keep[args.max_connections//2: -args.max_connections//2] = False + keep[args.max_connections:] = False + + hp.matches = msorted[keep].copy() + hp.indices = rsorted[keep].copy() + logging.info("TADD MaxMatch:%0.3f Size:%i " "AfterSigma:%i AfterTau0:%i Matches:%i" % (mmax, len(self), msig, mtau, mnum)) + hp.max_match = mmax return False hc = self[j] # Defensive initialization if matches/indices are missing if not hasattr(hc, 'matches'): - hc.matches = numpy.empty(len(self)) + hc.matches = numpy.empty(2) hc.matches[:] = numpy.nan if not hasattr(hc, 'indices'): - hc.indices = numpy.arange(len(self)) - + hc.indices = numpy.arange(2) m = hp.gen.match(hp, hc) matches[j] = m @@ -265,7 +282,9 @@ class TriangleBank(object): if m > mmax: mmax = m - def check_params(self, gen, params, threshold, force_add=False): + def check_params(self, gen, params, threshold, + force_add=False, + parallel_check=False): num_added = 0 total_num = len(tuple(params.values())[0]) waveform_cache = [] @@ -273,26 +292,29 @@ class TriangleBank(object): pool = pycbc.pool.choose_pool(args.nprocesses) for return_wf in pool.imap_unordered( wf_wrapper, - ({k: params[k][idx] for k in params} for idx in range(total_num)) + (({k: params[k][idx] for k in params}, parallel_check, threshold) for idx in range(total_num)) ): waveform_cache += [return_wf] + pool.close_pool() del pool for hp in waveform_cache: if hp is not None: - hp.gen = gen - hp.threshold = threshold - if hp not in self: - num_added += 1 - self.insert(hp) - elif force_add: + if hp.checked is None: + hp.gen = gen + hp.checked = hp not in self + + if hp.checked: + self.max_matches.append(hp.max_match) + + if hp.checked or force_add: num_added += 1 - self.insert(hp) + self.insert(hp) else: logging.info("Waveform generation failed!") continue - + return bank, num_added / total_num def decimate_frequency_domain(template, target_df): @@ -326,6 +348,24 @@ def decimate_frequency_domain(template, target_df): decimated_template = pycbc.types.FrequencySeries(decimated_signal, delta_f=target_df) return decimated_template +def handle_exit_signal(signum, frame): + logging.warning(f"Signal {signum} received. Triggering emergency save...") + save_bank() + # Force exit so the script doesn't try to resume the loops + os._exit(0) + +def save_bank(): + logging.info("Saving current bank to %s", args.output_file) + with HFile(args.output_file, 'w') as o: + o.attrs['minimal_match'] = args.minimal_match + for k in bank.keys(): + val = bank.key(k) + if val.dtype.char == 'U': + val = val.astype('bytes') + o[k] = val + o['max_matches'] = numpy.array(bank.max_matches) + logging.info("Save complete.") + class GenUniformWaveform(object): def __init__(self, buffer_length, sample_rate, f_lower): self.f_lower = f_lower @@ -363,17 +403,25 @@ class GenUniformWaveform(object): if hasattr(kwds['approximant'], 'decode'): kwds['approximant'] = kwds['approximant'].decode() + if args.full_resolution_buffer_length is not None: + buff_len = args.full_resolution_buffer_length + else: + buff_len = 1.0 / self.delta_f + if kwds['approximant'] in pycbc.waveform.fd_approximants(): - if args.full_resolution_buffer_length is not None: - # Generate the frequency-domain waveform at full frequency resolution - high_hp, high_hc = pycbc.waveform.get_fd_waveform(delta_f=1 / args.full_resolution_buffer_length, + + # Optionally generate time-domain waveform + if args.use_td_waveform: + hp, hc = pycbc.waveform.get_td_waveform(delta_t=1.0 / args.sample_rate, **kwds) - # Decimate the generated signal to a reduced frequency resolution - hp = decimate_frequency_domain(high_hp, 1 / args.buffer_length) - hc = decimate_frequency_domain(high_hc, 1 / args.buffer_length) + + hp = hp.to_frequencyseries(delta_f = 1.0 / buff_len) + hc = hc.to_frequencyseries(delta_f = 1.0 / buff_len) + else: - hp, hc = pycbc.waveform.get_fd_waveform(delta_f=self.delta_f, + hp, hc = pycbc.waveform.get_fd_waveform(delta_f = 1.0 / buff_len, **kwds) + if args.use_cross: hp = hc @@ -387,6 +435,10 @@ class GenUniformWaveform(object): delta_f=self.delta_f, delta_t=dt, **kwds) + if args.full_resolution_buffer_length is not None: + # Decimate the generated signal to a reduced frequency resolution + hp = decimate_frequency_domain(hp, 1 / args.buffer_length) + hp.resize(self.flen) hp = hp.astype(numpy.complex64) hp[self.kmin:-1] *= self.w @@ -416,9 +468,24 @@ gen = GenUniformWaveform(args.buffer_length, args.sample_rate, args.low_frequency_cutoff) bank = TriangleBank() -def wf_wrapper(p): +def silent_child_exit(signum, frame): + # The child workers execute this branch + # os._exit(0) ensures they die instantly and quietly + os._exit(0) + +def wf_wrapper(args): + p, parallel_check, threshold = args + + signal.signal(signal.SIGINT, silent_child_exit) + signal.signal(signal.SIGTERM, silent_child_exit) try: hp = gen.generate(**p) + hp.checked = None + hp.threshold = threshold + if parallel_check: + hp.gen = gen + hp.checked = hp not in bank + hp.gen = None return hp except Exception as e: print(e) @@ -431,7 +498,6 @@ if args.input_file: force_add=args.keep_entire_input_file) def draw(rtype): - if rtype == 'uniform': if args.input_config is None: params = {name: numpy.random.uniform(pmin, pmax, size=size) @@ -534,6 +600,9 @@ def cdraw(rtype, ts, te): return p +signal.signal(signal.SIGINT, handle_exit_signal) +signal.signal(signal.SIGTERM, handle_exit_signal) + tau0s = args.tau0_start tau0e = tau0s + args.tau0_crawl @@ -553,7 +622,8 @@ while tau0s < args.tau0_end: break blen = len(bank) - bank, uconv = bank.check_params(gen, params, args.minimal_match) + bank, uconv = bank.check_params(gen, params, args.minimal_match, + parallel_check=args.parallel_check) logging.info("%s: Round (U): %s Size: %s conv: %s added: %s", region, r, len(bank), uconv, len(bank) - blen) if r > 10: @@ -565,9 +635,13 @@ while tau0s < args.tau0_end: kloop += 1 params = cdraw('kde', tau0s, tau0e) blen = len(bank) - bank, kconv = bank.check_params(gen, params, args.minimal_match) - logging.info("%s: Round (K) (%s): %s Size: %s conv: %s added: %s", - region, kloop, r, len(bank), kconv, len(bank) - blen) + bank, kconv = bank.check_params(gen, params, args.minimal_match, + parallel_check=args.parallel_check) + + trail_matches = numpy.array(bank.max_matches[int(len(bank.max_matches)*0.9):]) + ave = numpy.mean(trail_matches) + logging.info("%s: Round (K) (%s): %s Size: %s conv: %0.4f added: %s Trail Ave: %0.4f Trail Min: %0.4f", + region, kloop, r, len(bank), kconv, len(bank) - blen, ave, trail_matches.min()) if uconv: @@ -587,10 +661,4 @@ while tau0s < args.tau0_end: tau0s += args.tau0_crawl / 2 tau0e += args.tau0_crawl / 2 -o = HFile(args.output_file, 'w') -o.attrs['minimal_match'] = args.minimal_match -for k in bank.keys(): - val = bank.key(k) - if val.dtype.char == 'U': - val = val.astype('bytes') - o[k] = val +save_bank()