-
Notifications
You must be signed in to change notification settings - Fork 376
Improvements to pycbc_brute_bank #5281
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
709f1c7
7fa5c1c
70d763e
727729c
c92c179
1557bd0
f4af8df
42a2c26
1e05aad
10b1d03
550781b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,8 @@ | |
| """ | ||
| import numpy | ||
| import logging | ||
| import signal | ||
| import os | ||
| import argparse | ||
| import numpy.random | ||
| from scipy.stats import gaussian_kde | ||
|
|
@@ -53,6 +55,8 @@ parser.add_argument('--approximant', required=False, | |
| parser.add_argument('--minimal-match', default=0.97, type=float) | ||
| parser.add_argument('--buffer-length', default=2, type=float, | ||
| help='size of waveform buffer in seconds') | ||
| parser.add_argument('--use-td-waveform', action='store_true', | ||
| help='Generate waveform in the time domain (default is frequency domain).') | ||
| parser.add_argument('--full-resolution-buffer-length', default=None, type=float, | ||
| help='Size of the waveform buffer in seconds for generating time-domain signals at full resolution before conversion to the frequency domain.') | ||
| parser.add_argument('--max-signal-length', type= float, | ||
|
|
@@ -81,6 +85,8 @@ parser.add_argument('--tau0-end', type=float) | |
| parser.add_argument('--tau0-cutoff-frequency', type=float, default=15.0) | ||
| parser.add_argument('--nprocesses', type=int, default=1, | ||
| help='Number of processes to use for waveform generation parallelization. If not given then only a single core will be used.') | ||
| parser.add_argument('--parallel-check', action='store_true', help="Do bank checking parallel, note that this means that proposals WILL NOT be checked against each other.") | ||
| parser.add_argument('--max-connections', type=int, help="Maximum number of matches to store with each template", default=numpy.inf) | ||
| pycbc.psd.insert_psd_option_group(parser) | ||
| parser.add_argument('--use-trimmed-buffer', action='store_true', | ||
| help=('When specified, the match calculation will use only the first and last 100 samples ' | ||
|
|
@@ -134,6 +140,7 @@ class TriangleBank(object): | |
| def __init__(self, p=None): | ||
| self.waveforms = p if p is not None else [] | ||
| self.tbins = {} | ||
| self.max_matches = [] | ||
|
|
||
| def __len__(self): | ||
| return len(self.waveforms) | ||
|
|
@@ -229,23 +236,33 @@ class TriangleBank(object): | |
| while 1: | ||
| j = inc.pop() | ||
| if j is None: | ||
| hp.matches = matches[r] | ||
| hp.indices = r | ||
| msort = matches[r].argsort() | ||
|
|
||
| msorted = matches[r][msort] | ||
| rsorted = r[msort] | ||
| keep = numpy.ones(len(msorted), dtype=bool) | ||
| if args.max_connections < len(keep): | ||
| #keep[args.max_connections//2: -args.max_connections//2] = False | ||
| keep[args.max_connections:] = False | ||
|
|
||
| hp.matches = msorted[keep].copy() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this array copy not now increase the memory footprint, particularly in the max_connections infinite / default case?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Actually it's the opposite: otherwise a loose reference could keep something in memory when we don't really need it anymore.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A further clarification may help here. If you just return a view, even if that's all you have, the full array is always kept. I explicitly want the reduced form to be what is stored and not the full array, hence explicitly asking for a new array that contains a copy of what the view pointed to. This means that the original values can be cleaned up by python's garbage collector. |
||
| hp.indices = rsorted[keep].copy() | ||
|
|
||
| logging.info("TADD MaxMatch:%0.3f Size:%i " | ||
| "AfterSigma:%i AfterTau0:%i Matches:%i" | ||
| % (mmax, len(self), msig, mtau, mnum)) | ||
| hp.max_match = mmax | ||
| return False | ||
|
|
||
| hc = self[j] | ||
|
|
||
| # Defensive initialization if matches/indices are missing | ||
| if not hasattr(hc, 'matches'): | ||
| hc.matches = numpy.empty(len(self)) | ||
| hc.matches = numpy.empty(2) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not sure I understand this change.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In truth, I'd like to remove these lines. I can't really figure out why they are needed. The only guess I have is that perhaps there is some boundary issue that this is masking. However, to the point, there is no reason to add these with the full size of the existing bank. Since the information added is meaningless, it ends up being a waste of memory. Putting it at 2 was simply a lazy way to force the correct dimensions. |
||
| hc.matches[:] = numpy.nan | ||
|
|
||
| if not hasattr(hc, 'indices'): | ||
| hc.indices = numpy.arange(len(self)) | ||
|
|
||
| hc.indices = numpy.arange(2) | ||
|
|
||
| m = hp.gen.match(hp, hc) | ||
| matches[j] = m | ||
|
|
@@ -265,34 +282,39 @@ class TriangleBank(object): | |
| if m > mmax: | ||
| mmax = m | ||
|
|
||
| def check_params(self, gen, params, threshold, force_add=False): | ||
| def check_params(self, gen, params, threshold, | ||
| force_add=False, | ||
| parallel_check=False): | ||
| num_added = 0 | ||
| total_num = len(tuple(params.values())[0]) | ||
| waveform_cache = [] | ||
|
|
||
| pool = pycbc.pool.choose_pool(args.nprocesses) | ||
| for return_wf in pool.imap_unordered( | ||
| wf_wrapper, | ||
| ({k: params[k][idx] for k in params} for idx in range(total_num)) | ||
| (({k: params[k][idx] for k in params}, parallel_check, threshold) for idx in range(total_num)) | ||
| ): | ||
| waveform_cache += [return_wf] | ||
|
|
||
| pool.close_pool() | ||
| del pool | ||
|
|
||
| for hp in waveform_cache: | ||
| if hp is not None: | ||
| hp.gen = gen | ||
| hp.threshold = threshold | ||
| if hp not in self: | ||
| num_added += 1 | ||
| self.insert(hp) | ||
| elif force_add: | ||
| if hp.checked is None: | ||
| hp.gen = gen | ||
| hp.checked = hp not in self | ||
|
|
||
| if hp.checked: | ||
| self.max_matches.append(hp.max_match) | ||
|
|
||
| if hp.checked or force_add: | ||
| num_added += 1 | ||
| self.insert(hp) | ||
| self.insert(hp) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. whitespace |
||
| else: | ||
| logging.info("Waveform generation failed!") | ||
| continue | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. whitespace |
||
| return bank, num_added / total_num | ||
|
|
||
| def decimate_frequency_domain(template, target_df): | ||
|
|
@@ -326,6 +348,24 @@ def decimate_frequency_domain(template, target_df): | |
| decimated_template = pycbc.types.FrequencySeries(decimated_signal, delta_f=target_df) | ||
| return decimated_template | ||
|
|
||
def handle_exit_signal(signum, frame):
    """Signal handler: save the in-progress bank, then hard-exit.

    Registered for SIGINT/SIGTERM so an interrupted run still leaves a
    usable bank file on disk.

    Parameters
    ----------
    signum : int
        Number of the delivered signal.
    frame : frame or None
        Current stack frame (unused; required by the signal-handler API).
    """
    # Lazy %-style args instead of an f-string: the message is only
    # formatted if the record is actually emitted.
    logging.warning("Signal %s received. Triggering emergency save...", signum)
    save_bank()
    # Force exit so the script doesn't try to resume the loops
    os._exit(0)
|
|
||
def save_bank():
    """Write the current template bank, plus the max-match history, to the
    HDF output file named by ``args.output_file``."""
    logging.info("Saving current bank to %s", args.output_file)
    with HFile(args.output_file, 'w') as out:
        out.attrs['minimal_match'] = args.minimal_match
        for key in bank.keys():
            data = bank.key(key)
            # HDF5 cannot store numpy unicode ('U') arrays directly,
            # so convert them to bytes before writing.
            if data.dtype.char == 'U':
                data = data.astype('bytes')
            out[key] = data
        out['max_matches'] = numpy.array(bank.max_matches)
    logging.info("Save complete.")
|
|
||
| class GenUniformWaveform(object): | ||
| def __init__(self, buffer_length, sample_rate, f_lower): | ||
| self.f_lower = f_lower | ||
|
|
@@ -363,17 +403,25 @@ class GenUniformWaveform(object): | |
| if hasattr(kwds['approximant'], 'decode'): | ||
| kwds['approximant'] = kwds['approximant'].decode() | ||
|
|
||
| if args.full_resolution_buffer_length is not None: | ||
| buff_len = args.full_resolution_buffer_length | ||
| else: | ||
| buff_len = 1.0 / self.delta_f | ||
|
|
||
| if kwds['approximant'] in pycbc.waveform.fd_approximants(): | ||
| if args.full_resolution_buffer_length is not None: | ||
| # Generate the frequency-domain waveform at full frequency resolution | ||
| high_hp, high_hc = pycbc.waveform.get_fd_waveform(delta_f=1 / args.full_resolution_buffer_length, | ||
|
|
||
| # Optionally generate time-domain waveform | ||
| if args.use_td_waveform: | ||
| hp, hc = pycbc.waveform.get_td_waveform(delta_t=1.0 / args.sample_rate, | ||
| **kwds) | ||
| # Decimate the generated signal to a reduced frequency resolution | ||
| hp = decimate_frequency_domain(high_hp, 1 / args.buffer_length) | ||
| hc = decimate_frequency_domain(high_hc, 1 / args.buffer_length) | ||
|
|
||
| hp = hp.to_frequencyseries(delta_f = 1.0 / buff_len) | ||
| hc = hc.to_frequencyseries(delta_f = 1.0 / buff_len) | ||
|
|
||
| else: | ||
| hp, hc = pycbc.waveform.get_fd_waveform(delta_f=self.delta_f, | ||
| hp, hc = pycbc.waveform.get_fd_waveform(delta_f = 1.0 / buff_len, | ||
| **kwds) | ||
|
|
||
| if args.use_cross: | ||
| hp = hc | ||
|
|
||
|
|
@@ -387,6 +435,10 @@ class GenUniformWaveform(object): | |
| delta_f=self.delta_f, delta_t=dt, | ||
| **kwds) | ||
|
|
||
| if args.full_resolution_buffer_length is not None: | ||
| # Decimate the generated signal to a reduced frequency resolution | ||
| hp = decimate_frequency_domain(hp, 1 / args.buffer_length) | ||
|
|
||
| hp.resize(self.flen) | ||
| hp = hp.astype(numpy.complex64) | ||
| hp[self.kmin:-1] *= self.w | ||
|
|
@@ -416,9 +468,24 @@ gen = GenUniformWaveform(args.buffer_length, | |
| args.sample_rate, args.low_frequency_cutoff) | ||
| bank = TriangleBank() | ||
|
|
||
| def wf_wrapper(p): | ||
def silent_child_exit(signum, frame):
    """Signal handler installed in pool worker processes.

    Workers must not run the parent's emergency-save handler; on
    SIGINT/SIGTERM they simply terminate immediately.

    Parameters
    ----------
    signum : int
        Number of the delivered signal (unused).
    frame : frame or None
        Current stack frame (unused; required by the signal-handler API).
    """
    # The child workers execute this branch
    # os._exit(0) ensures they die instantly and quietly
    os._exit(0)
|
|
||
def wf_wrapper(task):
    """Pool worker: generate one waveform from a proposal.

    Parameters
    ----------
    task : tuple
        ``(params_dict, parallel_check, threshold)`` as produced by
        ``check_params``.  (Renamed from ``args``, which shadowed the
        module-level argparse namespace; the pool only calls this
        positionally, so the rename is safe.)

    Returns
    -------
    hp or None
        The generated waveform with bookkeeping attributes attached, or
        None when generation failed.
    """
    p, parallel_check, threshold = task

    # Workers must die quietly on interrupt so the parent alone runs the
    # emergency-save handler, without a traceback from every child.
    signal.signal(signal.SIGINT, silent_child_exit)
    signal.signal(signal.SIGTERM, silent_child_exit)
    try:
        hp = gen.generate(**p)
        hp.checked = None
        hp.threshold = threshold
        if parallel_check:
            # Check against this worker's snapshot of the bank.  Proposals
            # in the same batch are NOT checked against each other here.
            hp.gen = gen
            hp.checked = hp not in bank
            # Drop the generator reference before the result is pickled
            # back to the parent process.
            hp.gen = None
        return hp
    except Exception:
        # Some parameter draws legitimately fail to generate; the caller
        # treats a None return as "generation failed".  Log with the full
        # traceback instead of print(e), which wrote to stdout and lost it.
        logging.exception("Waveform generation failed for params %s", p)
        return None
|
|
@@ -431,7 +498,6 @@ if args.input_file: | |
| force_add=args.keep_entire_input_file) | ||
|
|
||
| def draw(rtype): | ||
|
|
||
| if rtype == 'uniform': | ||
| if args.input_config is None: | ||
| params = {name: numpy.random.uniform(pmin, pmax, size=size) | ||
|
|
@@ -534,6 +600,9 @@ def cdraw(rtype, ts, te): | |
|
|
||
| return p | ||
|
|
||
| signal.signal(signal.SIGINT, handle_exit_signal) | ||
| signal.signal(signal.SIGTERM, handle_exit_signal) | ||
|
|
||
| tau0s = args.tau0_start | ||
| tau0e = tau0s + args.tau0_crawl | ||
|
|
||
|
|
@@ -553,7 +622,8 @@ while tau0s < args.tau0_end: | |
| break | ||
|
|
||
| blen = len(bank) | ||
| bank, uconv = bank.check_params(gen, params, args.minimal_match) | ||
| bank, uconv = bank.check_params(gen, params, args.minimal_match, | ||
| parallel_check=args.parallel_check) | ||
| logging.info("%s: Round (U): %s Size: %s conv: %s added: %s", | ||
| region, r, len(bank), uconv, len(bank) - blen) | ||
| if r > 10: | ||
|
|
@@ -565,9 +635,13 @@ while tau0s < args.tau0_end: | |
| kloop += 1 | ||
| params = cdraw('kde', tau0s, tau0e) | ||
| blen = len(bank) | ||
| bank, kconv = bank.check_params(gen, params, args.minimal_match) | ||
| logging.info("%s: Round (K) (%s): %s Size: %s conv: %s added: %s", | ||
| region, kloop, r, len(bank), kconv, len(bank) - blen) | ||
| bank, kconv = bank.check_params(gen, params, args.minimal_match, | ||
| parallel_check=args.parallel_check) | ||
|
|
||
| trail_matches = numpy.array(bank.max_matches[int(len(bank.max_matches)*0.9):]) | ||
| ave = numpy.mean(trail_matches) | ||
| logging.info("%s: Round (K) (%s): %s Size: %s conv: %0.4f added: %s Trail Ave: %0.4f Trail Min: %0.4f", | ||
| region, kloop, r, len(bank), kconv, len(bank) - blen, ave, trail_matches.min()) | ||
|
|
||
|
|
||
| if uconv: | ||
|
|
@@ -587,10 +661,4 @@ while tau0s < args.tau0_end: | |
| tau0s += args.tau0_crawl / 2 | ||
| tau0e += args.tau0_crawl / 2 | ||
|
|
||
| o = HFile(args.output_file, 'w') | ||
| o.attrs['minimal_match'] = args.minimal_match | ||
| for k in bank.keys(): | ||
| val = bank.key(k) | ||
| if val.dtype.char == 'U': | ||
| val = val.astype('bytes') | ||
| o[k] = val | ||
| save_bank() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could this line be wrapped to match maximum line length requirements?