Skip to content

Commit 6d73824

Browse files
committed
Improve AFS branch mode
1 parent 026ce76 commit 6d73824

File tree

1 file changed

+30
-63
lines changed

1 file changed

+30
-63
lines changed

python/tests/test_tree_stats.py

Lines changed: 30 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@ def windowed_tree_stat(ts, stat, windows, span_normalise=True):
144144
return A
145145

146146

147-
# Timewindows test
148147
def naive_branch_general_stat(
149148
ts, w, f, windows=None, time_windows=None, polarised=False, span_normalise=True
150149
):
@@ -153,6 +152,9 @@ def naive_branch_general_stat(
153152
drop_time_windows = time_windows is None
154153
if time_windows is None:
155154
time_windows = [0.0, np.inf]
155+
else:
156+
if time_windows[0] != 0:
157+
time_windows = [0] + time_windows
156158
n, k = w.shape
157159
tw = len(time_windows) - 1
158160
# hack to determine m
@@ -180,7 +182,7 @@ def naive_branch_general_stat(
180182
for u in tree.nodes()
181183
)
182184
sigma[tree.index, j, :] = s * tree.span
183-
for j in range(1, len(time_windows) - 1):
185+
for j in range(1, tw):
184186
sigma[:, j, :] = sigma[:, j, :] - sigma[:, j - 1, :]
185187
if isinstance(windows, str) and windows == "trees":
186188
# need to average across the windows
@@ -191,48 +193,11 @@ def naive_branch_general_stat(
191193
else:
192194
out = windowed_tree_stat(ts, sigma, windows, span_normalise=span_normalise)
193195
if drop_time_windows:
194-
# beware: this assumes the first dimension is windows
195-
assert out.shape[1] == 1
196+
assert out.shape[1] == 3
196197
out = out[:, 0]
197198
return out
198199

199200

200-
# Previous version without tw
201-
# def naive_branch_general_stat(
202-
# ts, w, f, windows=None, polarised=False, span_normalise=True
203-
# ):
204-
# if windows is None:
205-
# windows = [0.0, ts.sequence_length]
206-
# n, k = w.shape
207-
# # hack to determine m
208-
# m = len(f(w[0]))
209-
# total = np.sum(w, axis=0)
210-
211-
# sigma = np.zeros((ts.num_trees, m))
212-
# for tree in ts.trees():
213-
# x = np.zeros((ts.num_nodes, k))
214-
# x[ts.samples()] = w
215-
# for u in tree.nodes(order="postorder"):
216-
# for v in tree.children(u):
217-
# x[u] += x[v]
218-
# if polarised:
219-
# s = sum(tree.branch_length(u) * f(x[u]) for u in tree.nodes())
220-
# else:
221-
# s = sum(
222-
# tree.branch_length(u) * (f(x[u]) + f(total - x[u]))
223-
# for u in tree.nodes()
224-
# )
225-
# sigma[tree.index] = s * tree.span
226-
# if isinstance(windows, str) and windows == "trees":
227-
# # need to average across the windows
228-
# if span_normalise:
229-
# for j, tree in enumerate(ts.trees()):
230-
# sigma[j] /= tree.span
231-
# return sigma
232-
# else:
233-
# return windowed_tree_stat(ts, sigma, windows, span_normalise=span_normalise)
234-
235-
236201
def branch_general_stat(
237202
ts, sample_weights, summary_func, windows=None, polarised=False, span_normalise=True
238203
):
@@ -313,7 +278,6 @@ def polarised_summary(u):
313278
# for the next tree
314279
break
315280

316-
# print("window_index:", window_index, windows.shape)
317281
assert window_index == windows.shape[0] - 1
318282
if span_normalise:
319283
for j in range(num_windows):
@@ -3644,14 +3608,19 @@ def naive_branch_allele_frequency_spectrum(
36443608
drop_windows = windows is None
36453609
if windows is None:
36463610
windows = [0.0, ts.sequence_length]
3611+
else:
3612+
if windows[0] != 0:
3613+
windows = [0] + windows
36473614
drop_time_windows = time_windows is None
36483615
if time_windows is None:
36493616
time_windows = [0.0, np.inf]
3617+
else:
3618+
if time_windows[0] != 0:
3619+
time_windows = [0] + time_windows
36503620
windows = ts.parse_windows(windows)
36513621
num_windows = len(windows) - 1
36523622
num_time_windows = len(time_windows) - 1
36533623
out_dim = [1 + len(sample_set) for sample_set in sample_sets]
3654-
out = np.zeros([num_windows] + out_dim)
36553624
out = np.zeros([num_windows] + [num_time_windows] + out_dim)
36563625
for j in range(num_windows):
36573626
begin = windows[j]
@@ -3689,15 +3658,11 @@ def naive_branch_allele_frequency_spectrum(
36893658
out[j, k, :] = S
36903659

36913660
if drop_time_windows:
3692-
# beware: this assumes the first dimension is windows
3693-
assert out.shape[1] == 1
3661+
assert out.ndim == 2 + len(out_dim)
36943662
out = out[:, 0]
36953663
elif drop_windows:
3696-
# drop windows dim if only using time windows
3664+
assert out.shape[0] == 1
36973665
out = out[0]
3698-
# assert out.shape[0] == 1
3699-
# Warning: when using Windows and TimeWindows,
3700-
# the output has three dimensions
37013666
return out
37023667

37033668

@@ -3756,14 +3721,14 @@ def branch_allele_frequency_spectrum(
37563721
last_update = np.zeros(ts.num_nodes)
37573722
window_index = 0
37583723
parent = np.zeros(ts.num_nodes, dtype=np.int32) - 1
3759-
branch_length = np.zeros(ts.num_nodes)
3724+
# branch_length = np.zeros(ts.num_nodes)
37603725
tree_index = 0
37613726

3762-
def update_result(window_index, u, right, time_windows):
3727+
def update_result(window_index, u, right):
37633728
for k_tw, _ in enumerate(time_windows[:-1]):
37643729
if 0 < count[u, -1] < ts.num_samples:
3765-
# interval between child and parent inside the window
3766-
t_v = branch_length[u] + time[u]
3730+
# t_v = branch_length[u] + time[u]
3731+
t_v = time[parent[u]]
37673732
tw_branch_length = min(time_windows[k_tw + 1], t_v) - max(
37683733
time_windows[0], time[u]
37693734
)
@@ -3779,21 +3744,21 @@ def update_result(window_index, u, right, time_windows):
37793744
for edge in edges_out:
37803745
u = edge.child
37813746
v = edge.parent
3782-
update_result(window_index, u, t_left, time_windows)
3747+
update_result(window_index, u, t_left)
37833748
while v != -1:
3784-
update_result(window_index, v, t_left, time_windows)
3749+
update_result(window_index, v, t_left)
37853750
count[v] -= count[u]
37863751
v = parent[v]
37873752
parent[u] = -1
3788-
branch_length[u] = 0
3753+
# branch_length[u] = 0
37893754

37903755
for edge in edges_in:
37913756
u = edge.child
37923757
v = edge.parent
37933758
parent[u] = v
3794-
branch_length[u] = time[v] - time[u]
3759+
# branch_length[u] = time[v] - time[u]
37953760
while v != -1:
3796-
update_result(window_index, v, t_left, time_windows)
3761+
update_result(window_index, v, t_left)
37973762
count[v] += count[u]
37983763
v = parent[v]
37993764

@@ -3812,7 +3777,7 @@ def update_result(window_index, u, right, time_windows):
38123777
# non-zero branches, but this would add a O(log n) cost to each edge
38133778
# insertion and removal and a lot of complexity to the C implementation.
38143779
for u in range(ts.num_nodes):
3815-
update_result(window_index, u, w_right, time_windows)
3780+
update_result(window_index, u, w_right)
38163781
window_index += 1
38173782
tree_index += 1
38183783

@@ -3822,13 +3787,12 @@ def update_result(window_index, u, right, time_windows):
38223787
result[j] /= windows[j + 1] - windows[j]
38233788

38243789
if drop_time_windows:
3825-
# beware: this assumes the first dimension is windows
3790+
assert result.ndim == 2 + len(out_dim)
38263791
assert result.shape[1] == 1
38273792
result = result[:, 0]
38283793
elif drop_windows:
3829-
# drop windows dim if only using time windows
3794+
assert result.shape[0] == 1
38303795
result = result[0]
3831-
# assert out.shape[0] == 1
38323796
return result
38333797

38343798

@@ -6952,6 +6916,9 @@ def test_afs_branch(self):
69526916
self.assertArrayAlmostEqual(sfs1_tws, sfs1_tws_opti)
69536917
# test if time windows and windows obtained with naive version
69546918
# and opti are equal
6955-
# Warning: when using Windows and TimeWindows,
6956-
# the output has three dimensions
69576919
self.assertArrayAlmostEqual(sfs1_w_tw, sfs1_w_tw_opti)
6920+
# dimension tests
6921+
assert sfs1_tws.ndim == sfs1_w.ndim
6922+
# dimensions are dim1: windows ; dim2: time_windows ;
6923+
# dim3-or-more: num_sample_sets
6924+
assert sfs1_w_tw.ndim == 3

0 commit comments

Comments
 (0)