@@ -144,7 +144,6 @@ def windowed_tree_stat(ts, stat, windows, span_normalise=True):
144
144
return A
145
145
146
146
147
- # Timewindows test
148
147
def naive_branch_general_stat (
149
148
ts , w , f , windows = None , time_windows = None , polarised = False , span_normalise = True
150
149
):
@@ -153,6 +152,9 @@ def naive_branch_general_stat(
153
152
drop_time_windows = time_windows is None
154
153
if time_windows is None :
155
154
time_windows = [0.0 , np .inf ]
155
+ else :
156
+ if time_windows [0 ] != 0 :
157
+ time_windows = [0 ] + time_windows
156
158
n , k = w .shape
157
159
tw = len (time_windows ) - 1
158
160
# hack to determine m
@@ -180,7 +182,7 @@ def naive_branch_general_stat(
180
182
for u in tree .nodes ()
181
183
)
182
184
sigma [tree .index , j , :] = s * tree .span
183
- for j in range (1 , len ( time_windows ) - 1 ):
185
+ for j in range (1 , tw ):
184
186
sigma [:, j , :] = sigma [:, j , :] - sigma [:, j - 1 , :]
185
187
if isinstance (windows , str ) and windows == "trees" :
186
188
# need to average across the windows
@@ -191,48 +193,11 @@ def naive_branch_general_stat(
191
193
else :
192
194
out = windowed_tree_stat (ts , sigma , windows , span_normalise = span_normalise )
193
195
if drop_time_windows :
194
- # beware: this assumes the first dimension is windows
195
- assert out .shape [1 ] == 1
196
+ assert out .shape [1 ] == 3
196
197
out = out [:, 0 ]
197
198
return out
198
199
199
200
200
- # Previous version without tw
201
- # def naive_branch_general_stat(
202
- # ts, w, f, windows=None, polarised=False, span_normalise=True
203
- # ):
204
- # if windows is None:
205
- # windows = [0.0, ts.sequence_length]
206
- # n, k = w.shape
207
- # # hack to determine m
208
- # m = len(f(w[0]))
209
- # total = np.sum(w, axis=0)
210
-
211
- # sigma = np.zeros((ts.num_trees, m))
212
- # for tree in ts.trees():
213
- # x = np.zeros((ts.num_nodes, k))
214
- # x[ts.samples()] = w
215
- # for u in tree.nodes(order="postorder"):
216
- # for v in tree.children(u):
217
- # x[u] += x[v]
218
- # if polarised:
219
- # s = sum(tree.branch_length(u) * f(x[u]) for u in tree.nodes())
220
- # else:
221
- # s = sum(
222
- # tree.branch_length(u) * (f(x[u]) + f(total - x[u]))
223
- # for u in tree.nodes()
224
- # )
225
- # sigma[tree.index] = s * tree.span
226
- # if isinstance(windows, str) and windows == "trees":
227
- # # need to average across the windows
228
- # if span_normalise:
229
- # for j, tree in enumerate(ts.trees()):
230
- # sigma[j] /= tree.span
231
- # return sigma
232
- # else:
233
- # return windowed_tree_stat(ts, sigma, windows, span_normalise=span_normalise)
234
-
235
-
236
201
def branch_general_stat (
237
202
ts , sample_weights , summary_func , windows = None , polarised = False , span_normalise = True
238
203
):
@@ -313,7 +278,6 @@ def polarised_summary(u):
313
278
# for the next tree
314
279
break
315
280
316
- # print("window_index:", window_index, windows.shape)
317
281
assert window_index == windows .shape [0 ] - 1
318
282
if span_normalise :
319
283
for j in range (num_windows ):
@@ -3644,14 +3608,19 @@ def naive_branch_allele_frequency_spectrum(
3644
3608
drop_windows = windows is None
3645
3609
if windows is None :
3646
3610
windows = [0.0 , ts .sequence_length ]
3611
+ else :
3612
+ if windows [0 ] != 0 :
3613
+ windows = [0 ] + windows
3647
3614
drop_time_windows = time_windows is None
3648
3615
if time_windows is None :
3649
3616
time_windows = [0.0 , np .inf ]
3617
+ else :
3618
+ if time_windows [0 ] != 0 :
3619
+ time_windows = [0 ] + time_windows
3650
3620
windows = ts .parse_windows (windows )
3651
3621
num_windows = len (windows ) - 1
3652
3622
num_time_windows = len (time_windows ) - 1
3653
3623
out_dim = [1 + len (sample_set ) for sample_set in sample_sets ]
3654
- out = np .zeros ([num_windows ] + out_dim )
3655
3624
out = np .zeros ([num_windows ] + [num_time_windows ] + out_dim )
3656
3625
for j in range (num_windows ):
3657
3626
begin = windows [j ]
@@ -3689,15 +3658,11 @@ def naive_branch_allele_frequency_spectrum(
3689
3658
out [j , k , :] = S
3690
3659
3691
3660
if drop_time_windows :
3692
- # beware: this assumes the first dimension is windows
3693
- assert out .shape [1 ] == 1
3661
+ assert out .ndim == 2 + len (out_dim )
3694
3662
out = out [:, 0 ]
3695
3663
elif drop_windows :
3696
- # drop windows dim if only using time windows
3664
+ assert out . shape [ 0 ] == 1
3697
3665
out = out [0 ]
3698
- # assert out.shape[0] == 1
3699
- # Warning: when using Windows and TimeWindows,
3700
- # the output has three dimensions
3701
3666
return out
3702
3667
3703
3668
@@ -3756,14 +3721,14 @@ def branch_allele_frequency_spectrum(
3756
3721
last_update = np .zeros (ts .num_nodes )
3757
3722
window_index = 0
3758
3723
parent = np .zeros (ts .num_nodes , dtype = np .int32 ) - 1
3759
- branch_length = np .zeros (ts .num_nodes )
3724
+ # branch_length = np.zeros(ts.num_nodes)
3760
3725
tree_index = 0
3761
3726
3762
- def update_result (window_index , u , right , time_windows ):
3727
+ def update_result (window_index , u , right ):
3763
3728
for k_tw , _ in enumerate (time_windows [:- 1 ]):
3764
3729
if 0 < count [u , - 1 ] < ts .num_samples :
3765
- # interval between child and parent inside the window
3766
- t_v = branch_length [ u ] + time [u ]
3730
+ # t_v = branch_length[u] + time[u]
3731
+ t_v = time [parent [ u ] ]
3767
3732
tw_branch_length = min (time_windows [k_tw + 1 ], t_v ) - max (
3768
3733
time_windows [0 ], time [u ]
3769
3734
)
@@ -3779,21 +3744,21 @@ def update_result(window_index, u, right, time_windows):
3779
3744
for edge in edges_out :
3780
3745
u = edge .child
3781
3746
v = edge .parent
3782
- update_result (window_index , u , t_left , time_windows )
3747
+ update_result (window_index , u , t_left )
3783
3748
while v != - 1 :
3784
- update_result (window_index , v , t_left , time_windows )
3749
+ update_result (window_index , v , t_left )
3785
3750
count [v ] -= count [u ]
3786
3751
v = parent [v ]
3787
3752
parent [u ] = - 1
3788
- branch_length [u ] = 0
3753
+ # branch_length[u] = 0
3789
3754
3790
3755
for edge in edges_in :
3791
3756
u = edge .child
3792
3757
v = edge .parent
3793
3758
parent [u ] = v
3794
- branch_length [u ] = time [v ] - time [u ]
3759
+ # branch_length[u] = time[v] - time[u]
3795
3760
while v != - 1 :
3796
- update_result (window_index , v , t_left , time_windows )
3761
+ update_result (window_index , v , t_left )
3797
3762
count [v ] += count [u ]
3798
3763
v = parent [v ]
3799
3764
@@ -3812,7 +3777,7 @@ def update_result(window_index, u, right, time_windows):
3812
3777
# non-zero branches, but this would add a O(log n) cost to each edge
3813
3778
# insertion and removal and a lot of complexity to the C implementation.
3814
3779
for u in range (ts .num_nodes ):
3815
- update_result (window_index , u , w_right , time_windows )
3780
+ update_result (window_index , u , w_right )
3816
3781
window_index += 1
3817
3782
tree_index += 1
3818
3783
@@ -3822,13 +3787,12 @@ def update_result(window_index, u, right, time_windows):
3822
3787
result [j ] /= windows [j + 1 ] - windows [j ]
3823
3788
3824
3789
if drop_time_windows :
3825
- # beware: this assumes the first dimension is windows
3790
+ assert result . ndim == 2 + len ( out_dim )
3826
3791
assert result .shape [1 ] == 1
3827
3792
result = result [:, 0 ]
3828
3793
elif drop_windows :
3829
- # drop windows dim if only using time windows
3794
+ assert result . shape [ 0 ] == 1
3830
3795
result = result [0 ]
3831
- # assert out.shape[0] == 1
3832
3796
return result
3833
3797
3834
3798
@@ -6952,6 +6916,9 @@ def test_afs_branch(self):
6952
6916
self .assertArrayAlmostEqual (sfs1_tws , sfs1_tws_opti )
6953
6917
# test if time windows and windows obtained with naive version
6954
6918
# and opti are equal
6955
- # Warning: when using Windows and TimeWindows,
6956
- # the output has three dimensions
6957
6919
self .assertArrayAlmostEqual (sfs1_w_tw , sfs1_w_tw_opti )
6920
+ # dimension tests
6921
+ assert sfs1_tws .ndim == sfs1_w .ndim
6922
+ # output axes: axis 0 = windows; axis 1 = time_windows;
6923
+ # axis 2 onward = one axis per sample set
6924
+ assert sfs1_w_tw .ndim == 3
0 commit comments