Skip to content

Commit cf966cd

Browse files
committed
FEAT: make geometry.pyx threadsafe and enable freethreaded build
1 parent 1eb60cf commit cf966cd

3 files changed

Lines changed: 90 additions & 32 deletions

File tree

bilby_cython/geometry.pyx

Lines changed: 60 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,13 @@ from .time import greenwich_mean_sidereal_time
66
cdef double CC = 299792458.0
77

88

9-
cpdef time_delay_geocentric(np.ndarray detector1, np.ndarray detector2, double ra, double dec, double time):
9+
cpdef time_delay_geocentric(
10+
np.ndarray[np.float64_t, ndim=1] detector1,
11+
np.ndarray[np.float64_t, ndim=1] detector2,
12+
double ra,
13+
double dec,
14+
double time,
15+
):
1016
"""
1117
Calculate time delay between two detectors in geocentric coordinates based on XLALArrivaTimeDiff in TimeDelay.c
1218
@@ -53,7 +59,12 @@ cpdef time_delay_geocentric(np.ndarray detector1, np.ndarray detector2, double r
5359
_GEOCENTER = np.zeros(3, dtype=float)
5460

5561

56-
cpdef time_delay_from_geocenter(np.ndarray detector1, double ra, double dec, double time):
62+
cpdef time_delay_from_geocenter(
63+
np.ndarray[np.float64_t, ndim=1] detector1,
64+
double ra,
65+
double dec,
66+
double time,
67+
):
5768
"""
5869
Calculate time delay between a detectors and the geocenter
5970
based on XLALArrivalTimeDiff in TimeDelay.c
@@ -78,7 +89,14 @@ cpdef time_delay_from_geocenter(np.ndarray detector1, double ra, double dec, dou
7889
return time_delay_geocentric(detector1, _GEOCENTER, ra, dec, time)
7990

8091

81-
cdef _vectors_for_polarization_tensor(double phi, double theta, double psi):
92+
cdef _vectors_for_polarization_tensor(
93+
double phi,
94+
double theta,
95+
double psi,
96+
double[:] omega_view,
97+
double[:] m_view,
98+
double[:] n_view,
99+
):
82100
r"""
83101
Compute the three vectors that can be used to construct the different
84102
population modes.
@@ -129,27 +147,25 @@ cdef _vectors_for_polarization_tensor(double phi, double theta, double psi):
129147
omega_view[2] = m_view[0] * n_view[1] - m_view[1] * n_view[0]
130148

131149

132-
m = np.zeros(3)
133-
n = np.zeros(3)
134-
omega = np.zeros(3)
135-
cdef double[:] m_view = m
136-
cdef double[:] n_view = n
137-
cdef double[:] omega_view = omega
138-
139-
140-
cpdef _polarization_tensor(double[:, :] output_view, str mode):
150+
cpdef _polarization_tensor(
151+
double[:, :] output_view,
152+
str mode,
153+
double[:] omega_view,
154+
double[:] m_view,
155+
double[:] n_view,
156+
):
141157
if mode == 'plus':
142-
_plus(output_view)
158+
_plus(output_view, m_view, n_view)
143159
elif mode == 'cross':
144-
_cross(output_view)
160+
_cross(output_view, m_view, n_view)
145161
elif mode == 'breathing':
146-
_breathing(output_view)
162+
_breathing(output_view, m_view, n_view)
147163
elif mode == 'longitudinal':
148-
_longitudinal(output_view)
164+
_longitudinal(output_view, omega_view)
149165
elif mode == 'x':
150-
_x(output_view)
166+
_x(output_view, omega_view, m_view)
151167
elif mode == 'y':
152-
_y(output_view)
168+
_y(output_view, omega_view, n_view)
153169
else:
154170
raise ValueError("{} not a polarization mode!".format(mode))
155171

@@ -182,15 +198,21 @@ cpdef get_polarization_tensor(double ra, double dec, double time, double psi, st
182198
183199
"""
184200
cdef double gmst, phi, theta
185-
output = np.zeros((3, 3))
201+
output = np.empty((3, 3))
186202
cdef double[:, :] output_view = output
203+
omega = np.empty(3)
204+
m = np.empty(3)
205+
n = np.empty(3)
206+
cdef double[:] omega_view = omega
207+
cdef double[:] m_view = m
208+
cdef double[:] n_view = n
187209

188210
gmst = fmod(greenwich_mean_sidereal_time(time), 2 * pi)
189211
phi = ra - gmst
190212
theta = pi / 2 - dec
191-
_vectors_for_polarization_tensor(phi, theta, psi)
213+
_vectors_for_polarization_tensor(phi, theta, psi, omega_view, m_view, n_view)
192214

193-
_polarization_tensor(output_view, mode)
215+
_polarization_tensor(output_view, mode, omega_view, m_view, n_view)
194216

195217
return output
196218

@@ -225,22 +247,28 @@ cpdef get_polarization_tensor_multiple_modes(double ra, double dec, double time,
225247
"""
226248
cdef double gmst, phi, theta
227249
cdef double[:, :] output_view
250+
omega = np.empty(3)
251+
m = np.empty(3)
252+
n = np.empty(3)
253+
cdef double[:] omega_view
254+
cdef double[:] m_view
255+
cdef double[:] n_view
228256
output = list()
229257

230258
gmst = fmod(greenwich_mean_sidereal_time(time), 2 * pi)
231259
phi = ra - gmst
232260
theta = pi / 2 - dec
233-
_vectors_for_polarization_tensor(phi, theta, psi)
261+
_vectors_for_polarization_tensor(phi, theta, psi, omega_view, m_view, n_view)
234262

235263
for mode in modes:
236264
tensor = np.zeros((3, 3))
237265
output_view = tensor
238-
_polarization_tensor(output_view, mode)
266+
_polarization_tensor(output_view, mode, omega_view, m_view, n_view)
239267
output.append(tensor)
240268
return output
241269

242270

243-
cdef _plus(double[:, :] output):
271+
cdef _plus(double[:, :] output, double[:] m_view, double[:] n_view):
244272
cdef int ii, jj
245273

246274
for ii in range(3):
@@ -250,7 +278,7 @@ cdef _plus(double[:, :] output):
250278
output[jj][ii] = output[ii][jj]
251279

252280

253-
cdef _breathing(double[:, :] output):
281+
cdef _breathing(double[:, :] output, double[:] m_view, double[:] n_view):
254282
cdef int ii, jj
255283

256284
for ii in range(3):
@@ -260,7 +288,7 @@ cdef _breathing(double[:, :] output):
260288
output[jj][ii] = output[ii][jj]
261289

262290

263-
cdef _longitudinal(double[:, :] output):
291+
cdef _longitudinal(double[:, :] output, double[:] omega_view):
264292
cdef int ii, jj
265293

266294
for ii in range(3):
@@ -280,19 +308,19 @@ cdef _symmetric_response(double[:, :] output, double[:] input_1, double[:] input
280308
output[jj][ii] = output[ii][jj]
281309

282310

283-
cdef _cross(double[:, :] output):
311+
cdef _cross(double[:, :] output, double[:] m_view, double[:] n_view):
284312
_symmetric_response(output, m_view, n_view)
285313

286314

287-
cdef _x(double[:, :] output):
315+
cdef _x(double[:, :] output, double[:] omega_view, double[:] m_view):
288316
_symmetric_response(output, m_view, omega_view)
289317

290318

291-
cdef _y(double[:, :] output):
319+
cdef _y(double[:, :] output, double[:] omega_view, double[:] n_view):
292320
_symmetric_response(output, n_view, omega_view)
293321

294322

295-
cpdef three_by_three_matrix_contraction(np.ndarray x, np.ndarray y):
323+
cpdef three_by_three_matrix_contraction(np.ndarray[np.float64_t, ndim=2] x, np.ndarray[np.float64_t, ndim=2] y):
296324
"""
297325
Doubly contract two 3x3 input matrices following Einstein summation.
298326
@@ -323,7 +351,7 @@ cpdef three_by_three_matrix_contraction(np.ndarray x, np.ndarray y):
323351
return output
324352

325353

326-
cpdef detector_tensor(np.ndarray x, np.ndarray y):
354+
cpdef detector_tensor(np.ndarray[np.float64_t, ndim=1] x, np.ndarray[np.float64_t, ndim=1] y):
327355
r"""
328356
Compute the detector tensor given the two unit arm vectors.
329357
@@ -469,7 +497,7 @@ cpdef rotation_matrix_from_delta(delta_x):
469497
return rotation
470498

471499

472-
cpdef zenith_azimuth_to_theta_phi(double zenith, double azimuth, np.ndarray delta_x):
500+
cpdef zenith_azimuth_to_theta_phi(double zenith, double azimuth, np.ndarray[np.float64_t, ndim=1] delta_x):
473501
"""
474502
Convert from the 'detector frame' to the Earth frame.
475503
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import concurrent.futures
2+
import numpy as np
3+
from bilby_cython import geometry
4+
5+
6+
def test_polarization_tensor_threadsafe():
7+
"""
8+
A basic test of thread safety for the polarization tensor calculation.
9+
Previously, this was not thread safe due to the use of global variables
10+
to store intermediate results.
11+
"""
12+
13+
def dummy_func(val):
14+
return geometry.get_polarization_tensor(*val, "plus")
15+
16+
values = np.random.uniform(0, 1, (10000, 4))
17+
18+
truths = np.array([geometry.get_polarization_tensor(*val, "plus") for val in values])
19+
20+
results = truths.copy()
21+
with concurrent.futures.ThreadPoolExecutor(max_workers=15) as executor:
22+
jobs = {executor.submit(dummy_func, val): ii for ii, val in enumerate(values)}
23+
for job in concurrent.futures.as_completed(jobs):
24+
results[jobs[job]] = job.result()
25+
26+
assert np.allclose(truths, results)

setup.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
class LazyImportBuildExtCmd(build_ext):
99
def finalize_options(self):
1010
from Cython.Build import cythonize
11+
from Cython.Compiler.Version import version as cython_version
12+
from packaging.version import Version
1113

1214
compiler_directives = dict(
1315
language_level=3,
@@ -22,6 +24,8 @@ def finalize_options(self):
2224
annotate = True
2325
else:
2426
annotate = False
27+
if Version(cython_version) >= Version("3.1.0a1"):
28+
compiler_directives["freethreading_compatible"] = True
2529
self.distribution.ext_modules = cythonize(
2630
self.distribution.ext_modules,
2731
compiler_directives=compiler_directives,

0 commit comments

Comments
 (0)