diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index a40817e7a9..95b5ddafe2 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -23,6 +23,7 @@ */ #include "testlib.h" +#include #include #include @@ -865,6 +866,84 @@ verify_one_way_stat_func_errors(tsk_treeseq_t *ts, one_way_sample_stat_method *m CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); } +// Temporary definition for time_windows in tsk_treeseq_allele_frequency_spectrum +typedef int one_way_sample_stat_method_tw(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_windows, const double *windows, + tsk_size_t num_time_windows, const double *time_windows, tsk_flags_t options, + double *result); + +// Temporary duplicate for time-windows-having methods +static void +verify_one_way_stat_func_errors_tw( + tsk_treeseq_t *ts, one_way_sample_stat_method_tw *method) +{ + int ret; + tsk_id_t num_nodes = (tsk_id_t) tsk_treeseq_get_num_nodes(ts); + tsk_id_t samples[] = { 0, 1, 2, 3 }; + tsk_size_t sample_set_sizes = 4; + double windows[] = { 0, 0, 0 }; + double time_windows[] = { -1, 0.5, INFINITY }; + double result; + + ret = method(ts, 0, &sample_set_sizes, samples, 0, NULL, 0, NULL, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_SAMPLE_SETS); + + samples[0] = TSK_NULL; + ret = method(ts, 1, &sample_set_sizes, samples, 0, NULL, 0, NULL, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + samples[0] = -10; + ret = method(ts, 1, &sample_set_sizes, samples, 0, NULL, 0, NULL, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + samples[0] = num_nodes; + ret = method(ts, 1, &sample_set_sizes, samples, 0, NULL, 0, NULL, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + samples[0] = num_nodes + 1; + ret = method(ts, 1, &sample_set_sizes, samples, 0, NULL, 0, NULL, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + + 
samples[0] = num_nodes - 1; + ret = method(ts, 1, &sample_set_sizes, samples, 0, NULL, 0, NULL, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_SAMPLES); + + samples[0] = 1; + ret = method(ts, 1, &sample_set_sizes, samples, 0, NULL, 0, NULL, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE); + + samples[0] = 0; + sample_set_sizes = 0; + ret = method(ts, 1, &sample_set_sizes, samples, 0, NULL, 0, NULL, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_EMPTY_SAMPLE_SET); + + sample_set_sizes = 4; + /* Window errors */ + ret = method(ts, 1, &sample_set_sizes, samples, 0, windows, 0, NULL, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NUM_WINDOWS); + + ret = method(ts, 1, &sample_set_sizes, samples, 2, windows, 0, NULL, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + /* Time window errors */ + ret = method( + ts, 1, &sample_set_sizes, samples, 0, NULL, 0, time_windows, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_TIME_WINDOWS_DIM); + + ret = method( + ts, 1, &sample_set_sizes, samples, 0, NULL, 2, time_windows, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_TIME_WINDOWS); + + time_windows[0] = 0.1; + ret = method( + ts, 1, &sample_set_sizes, samples, 0, NULL, 2, time_windows, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_TIME_WINDOWS); + + time_windows[0] = 0; + time_windows[1] = 0; + ret = method( + ts, 1, &sample_set_sizes, samples, 0, NULL, 2, time_windows, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_TIME_WINDOWS); +} + static void verify_two_way_stat_func_errors( tsk_treeseq_t *ts, general_sample_stat_method *method, tsk_flags_t options) @@ -1203,23 +1282,24 @@ verify_afs(tsk_treeseq_t *ts) sample_set_sizes[0] = n - 2; sample_set_sizes[1] = 2; ret = tsk_treeseq_allele_frequency_spectrum( - ts, 2, sample_set_sizes, samples, 0, NULL, 0, result); + ts, 2, sample_set_sizes, samples, 0, NULL, 0, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = 
tsk_treeseq_allele_frequency_spectrum( - ts, 2, sample_set_sizes, samples, 0, NULL, TSK_STAT_POLARISED, result); + ts, 2, sample_set_sizes, samples, 0, NULL, 0, NULL, TSK_STAT_POLARISED, result); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_treeseq_allele_frequency_spectrum(ts, 2, sample_set_sizes, samples, 0, - NULL, TSK_STAT_POLARISED | TSK_STAT_SPAN_NORMALISE, result); + NULL, 0, NULL, TSK_STAT_POLARISED | TSK_STAT_SPAN_NORMALISE, result); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_treeseq_allele_frequency_spectrum(ts, 2, sample_set_sizes, samples, 0, - NULL, TSK_STAT_BRANCH | TSK_STAT_POLARISED | TSK_STAT_SPAN_NORMALISE, result); + NULL, 0, NULL, TSK_STAT_BRANCH | TSK_STAT_POLARISED | TSK_STAT_SPAN_NORMALISE, + result); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_treeseq_allele_frequency_spectrum(ts, 2, sample_set_sizes, samples, 0, - NULL, TSK_STAT_BRANCH | TSK_STAT_SPAN_NORMALISE, result); + NULL, 0, NULL, TSK_STAT_BRANCH | TSK_STAT_SPAN_NORMALISE, result); CU_ASSERT_EQUAL_FATAL(ret, 0); free(result); @@ -2413,21 +2493,26 @@ test_paper_ex_afs_errors(void) tsk_size_t sample_set_sizes[] = { 2, 2 }; tsk_id_t samples[] = { 0, 1, 2, 3 }; double result[10]; /* not thinking too hard about the actual value needed */ + double time_windows[] = { 0, 1 }; int ret; tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); - verify_one_way_stat_func_errors(&ts, tsk_treeseq_allele_frequency_spectrum); + verify_one_way_stat_func_errors_tw(&ts, tsk_treeseq_allele_frequency_spectrum); ret = tsk_treeseq_allele_frequency_spectrum( - &ts, 2, sample_set_sizes, samples, 0, NULL, TSK_STAT_NODE, result); + &ts, 2, sample_set_sizes, samples, 0, NULL, 0, NULL, TSK_STAT_NODE, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); ret = tsk_treeseq_allele_frequency_spectrum(&ts, 2, sample_set_sizes, samples, 0, - NULL, TSK_STAT_BRANCH | TSK_STAT_SITE, result); + NULL, 0, NULL, TSK_STAT_BRANCH | 
TSK_STAT_SITE, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MULTIPLE_STAT_MODES); + ret = tsk_treeseq_allele_frequency_spectrum(&ts, 2, sample_set_sizes, samples, 0, + NULL, 1, time_windows, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); + tsk_treeseq_free(&ts); } @@ -2445,14 +2530,14 @@ test_paper_ex_afs(void) /* we have two singletons and one tripleton */ ret = tsk_treeseq_allele_frequency_spectrum( - &ts, 1, sample_set_sizes, samples, 0, NULL, 0, result); + &ts, 1, sample_set_sizes, samples, 0, NULL, 0, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(result[0], 0); CU_ASSERT_EQUAL_FATAL(result[1], 3.0); CU_ASSERT_EQUAL_FATAL(result[2], 0); ret = tsk_treeseq_allele_frequency_spectrum( - &ts, 1, sample_set_sizes, samples, 0, NULL, TSK_STAT_POLARISED, result); + &ts, 1, sample_set_sizes, samples, 0, NULL, 0, NULL, TSK_STAT_POLARISED, result); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(result[0], 0); CU_ASSERT_EQUAL_FATAL(result[1], 2.0); diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index dcf4d2a69d..e210365ae7 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -8182,13 +8182,13 @@ test_time_uncalibrated(void) CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_treeseq_allele_frequency_spectrum( - &ts2, 2, sample_set_sizes, samples, 0, NULL, TSK_STAT_SITE, result); + &ts2, 2, sample_set_sizes, samples, 0, NULL, 0, NULL, TSK_STAT_SITE, result); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_treeseq_allele_frequency_spectrum( - &ts2, 2, sample_set_sizes, samples, 0, NULL, TSK_STAT_BRANCH, result); + &ts2, 2, sample_set_sizes, samples, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_TIME_UNCALIBRATED); ret = tsk_treeseq_allele_frequency_spectrum(&ts2, 2, sample_set_sizes, samples, 0, - NULL, TSK_STAT_BRANCH | TSK_STAT_ALLOW_TIME_UNCALIBRATED, result); + NULL, 0, NULL, TSK_STAT_BRANCH | TSK_STAT_ALLOW_TIME_UNCALIBRATED, result); CU_ASSERT_EQUAL_FATAL(ret, 0); 
sigma = tsk_calloc(tsk_treeseq_get_num_nodes(&ts2), sizeof(double)); diff --git a/c/tskit/core.c b/c/tskit/core.c index 4be44f6242..684c0a6c00 100644 --- a/c/tskit/core.c +++ b/c/tskit/core.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2024 Tskit Developers + * Copyright (c) 2019-2025 Tskit Developers * Copyright (c) 2015-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -512,7 +512,7 @@ tsk_strerror_internal(int err) "(TSK_ERR_BAD_SAMPLE_PAIR_TIMES)"; break; case TSK_ERR_BAD_TIME_WINDOWS: - ret = "Time windows must be strictly increasing and end at infinity. " + ret = "Time windows must be strictly increasing. " "(TSK_ERR_BAD_TIME_WINDOWS)"; break; case TSK_ERR_BAD_NODE_TIME_WINDOW: diff --git a/c/tskit/trees.c b/c/tskit/trees.c index fd15aa3ab7..5c7ced76d1 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2024 Tskit Developers + * Copyright (c) 2019-2025 Tskit Developers * Copyright (c) 2015-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -23,6 +23,7 @@ * SOFTWARE. 
*/ +#include "tskit/core.h" #include #include #include @@ -1232,6 +1233,35 @@ tsk_treeseq_check_windows(const tsk_treeseq_t *self, tsk_size_t num_windows, return ret; } +static int +tsk_treeseq_check_time_windows(tsk_size_t num_windows, const double *windows) +{ + int ret = TSK_ERR_BAD_TIME_WINDOWS; + tsk_size_t j; + + if (num_windows < 1) { + ret = TSK_ERR_BAD_TIME_WINDOWS_DIM; + goto out; + } + + if (windows[0] < 0.0) { + goto out; + } + + if (windows[0] != 0.0) { + goto out; + } + + for (j = 0; j < num_windows; j++) { + if (windows[j] >= windows[j + 1]) { + goto out; + } + } + ret = 0; +out: + return ret; +} + /* TODO make these functions more consistent in how the arguments are ordered */ static inline void @@ -3486,35 +3516,51 @@ tsk_treeseq_site_allele_frequency_spectrum(const tsk_treeseq_t *self, static int TSK_WARN_UNUSED tsk_treeseq_update_branch_afs(const tsk_treeseq_t *self, tsk_id_t u, double right, - const double *restrict branch_length, double *restrict last_update, - const double *counts, tsk_size_t num_sample_sets, tsk_size_t window_index, - const tsk_size_t *result_dims, tsk_flags_t options, double *result) + double *restrict last_update, const double *restrict time, tsk_id_t *restrict parent, + const double *time_windows, const double *counts, tsk_size_t num_sample_sets, + tsk_size_t num_time_windows, tsk_size_t window_index, const tsk_size_t *result_dims, + tsk_flags_t options, double *result) { int ret = 0; tsk_size_t afs_size; tsk_size_t k; + tsk_size_t time_window_index; double *afs; tsk_size_t *coordinate = tsk_malloc(num_sample_sets * sizeof(*coordinate)); bool polarised = !!(options & TSK_STAT_POLARISED); const double *count_row = GET_2D_ROW(counts, num_sample_sets + 1, u); - double x = (right - last_update[u]) * branch_length[u]; + double x = 0; + double t_v = 0; + double tw_branch_length = 0; const tsk_size_t all_samples = (tsk_size_t) count_row[num_sample_sets]; - if (coordinate == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } - - if (0 < 
all_samples && all_samples < self->num_samples) { - afs_size = result_dims[num_sample_sets]; - afs = result + afs_size * window_index; - for (k = 0; k < num_sample_sets; k++) { - coordinate[k] = (tsk_size_t) count_row[k]; - } - if (!polarised) { - fold(coordinate, result_dims, num_sample_sets); + if (parent[u] != TSK_NULL) { + t_v = time[parent[u]]; + if (0 < all_samples && all_samples < self->num_samples) { + time_window_index = 0; + while (time_window_index < num_time_windows + && time_windows[time_window_index] < t_v) { + afs_size = result_dims[num_sample_sets]; + afs = result + + afs_size * (window_index * num_time_windows + time_window_index); + for (k = 0; k < num_sample_sets; k++) { + coordinate[k] = (tsk_size_t) count_row[k]; + } + if (!polarised) { + fold(coordinate, result_dims, num_sample_sets); + } + tw_branch_length /* window/branch overlap; clamped to zero when the window lies entirely below time[u] */ + = TSK_MAX(0.0, TSK_MIN(time_windows[time_window_index + 1], t_v) + - TSK_MAX(time_windows[time_window_index], time[u])); + x = (right - last_update[u]) * tw_branch_length; + increment_nd_array_value( + afs, num_sample_sets, result_dims, coordinate, x); + time_window_index++; + } } - increment_nd_array_value(afs, num_sample_sets, result_dims, coordinate, x); } last_update[u] = right; out: @@ -3525,8 +3571,8 @@ tsk_treeseq_update_branch_afs(const tsk_treeseq_t *self, tsk_id_t u, double righ static int tsk_treeseq_branch_allele_frequency_spectrum(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, double *counts, tsk_size_t num_windows, - const double *windows, const tsk_size_t *result_dims, tsk_flags_t options, - double *result) + tsk_size_t num_time_windows, const double *windows, const double *time_windows, + const tsk_size_t *result_dims, tsk_flags_t options, double *result) { int ret = 0; tsk_id_t u, v; @@ -3571,16 +3617,16 @@ tsk_treeseq_branch_allele_frequency_spectrum(const tsk_treeseq_t *self, tk++; u = edge_child[h]; v = edge_parent[h]; - ret = tsk_treeseq_update_branch_afs(self, u, t_left, branch_length, - last_update, counts, 
num_sample_sets, window_index, result_dims, options, - result); + ret = tsk_treeseq_update_branch_afs(self, u, t_left, last_update, node_time, + parent, time_windows, counts, num_sample_sets, num_time_windows, + window_index, result_dims, options, result); if (ret != 0) { goto out; } while (v != TSK_NULL) { - ret = tsk_treeseq_update_branch_afs(self, v, t_left, branch_length, - last_update, counts, num_sample_sets, window_index, result_dims, - options, result); + ret = tsk_treeseq_update_branch_afs(self, v, t_left, last_update, + node_time, parent, time_windows, counts, num_sample_sets, + num_time_windows, window_index, result_dims, options, result); if (ret != 0) { goto out; } @@ -3599,9 +3645,9 @@ tsk_treeseq_branch_allele_frequency_spectrum(const tsk_treeseq_t *self, parent[u] = v; branch_length[u] = node_time[v] - node_time[u]; while (v != TSK_NULL) { - ret = tsk_treeseq_update_branch_afs(self, v, t_left, branch_length, - last_update, counts, num_sample_sets, window_index, result_dims, - options, result); + ret = tsk_treeseq_update_branch_afs(self, v, t_left, last_update, + node_time, parent, time_windows, counts, num_sample_sets, + num_time_windows, window_index, result_dims, options, result); if (ret != 0) { goto out; } @@ -3623,9 +3669,9 @@ tsk_treeseq_branch_allele_frequency_spectrum(const tsk_treeseq_t *self, /* Flush the contributions of all nodes to the current window */ for (u = 0; u < (tsk_id_t) num_nodes; u++) { tsk_bug_assert(last_update[u] < w_right); - ret = tsk_treeseq_update_branch_afs(self, u, w_right, branch_length, - last_update, counts, num_sample_sets, window_index, result_dims, - options, result); + ret = tsk_treeseq_update_branch_afs(self, u, w_right, last_update, + node_time, parent, time_windows, counts, num_sample_sets, + num_time_windows, window_index, result_dims, options, result); if (ret != 0) { goto out; } @@ -3653,13 +3699,15 @@ int tsk_treeseq_allele_frequency_spectrum(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const 
tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_windows, const double *windows, - tsk_flags_t options, double *result) + tsk_size_t num_time_windows, const double *time_windows, tsk_flags_t options, + double *result) { int ret = 0; bool stat_site = !!(options & TSK_STAT_SITE); bool stat_branch = !!(options & TSK_STAT_BRANCH); bool stat_node = !!(options & TSK_STAT_NODE); const double default_windows[] = { 0, self->tables->sequence_length }; + const double default_time_windows[] = { 0, INFINITY }; const tsk_size_t num_nodes = self->tables->nodes.num_rows; const tsk_size_t K = num_sample_sets + 1; tsk_size_t j, k, l, afs_size; @@ -3669,7 +3717,6 @@ tsk_treeseq_allele_frequency_spectrum(const tsk_treeseq_t *self, * reuse code from the general_stats code paths. */ double *counts = NULL; double *count_row; - if (stat_node) { ret = TSK_ERR_UNSUPPORTED_STAT_MODE; goto out; @@ -3693,6 +3740,21 @@ tsk_treeseq_allele_frequency_spectrum(const tsk_treeseq_t *self, goto out; } } + if (time_windows == NULL) { + num_time_windows = 1; + time_windows = default_time_windows; + } else { + ret = tsk_treeseq_check_time_windows(num_time_windows, time_windows); + if (ret != 0) { + goto out; + } + if (stat_site + && tsk_memcmp(time_windows, default_time_windows, 2 * sizeof(const double)) + != 0) { + ret = TSK_ERR_UNSUPPORTED_STAT_MODE; + goto out; + } + } ret = tsk_treeseq_check_sample_sets( self, num_sample_sets, sample_set_sizes, sample_sets); if (ret != 0) { @@ -3728,15 +3790,16 @@ tsk_treeseq_allele_frequency_spectrum(const tsk_treeseq_t *self, count_row[num_sample_sets] = 1; } result_dims[num_sample_sets] = (tsk_size_t) afs_size; + tsk_memset(result, 0, num_windows * num_time_windows * afs_size * sizeof(*result)); - tsk_memset(result, 0, num_windows * afs_size * sizeof(*result)); if (stat_site) { ret = tsk_treeseq_site_allele_frequency_spectrum(self, num_sample_sets, sample_set_sizes, counts, num_windows, windows, result_dims, options, result); } else { ret 
= tsk_treeseq_branch_allele_frequency_spectrum(self, num_sample_sets, counts, - num_windows, windows, result_dims, options, result); + num_windows, num_time_windows, windows, time_windows, result_dims, options, + result); } if (options & TSK_STAT_SPAN_NORMALISE) { diff --git a/c/tskit/trees.h b/c/tskit/trees.h index bef944fff3..68c97e386b 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1056,7 +1056,8 @@ int tsk_treeseq_Y1(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, int tsk_treeseq_allele_frequency_spectrum(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_windows, const double *windows, - tsk_flags_t options, double *result); + tsk_size_t num_time_windows, const double *time_windows, tsk_flags_t options, + double *result); typedef int general_sample_stat_method(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 8663bd8695..aa7a3bebd1 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2024 Tskit Developers + * Copyright (c) 2019-2025 Tskit Developers * Copyright (c) 2015-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -9439,29 +9439,30 @@ TreeSequence_allele_frequency_spectrum( TreeSequence *self, PyObject *args, PyObject *kwds) { PyObject *ret = NULL; - static char *kwlist[] = { "sample_set_sizes", "sample_sets", "windows", "mode", - "span_normalise", "polarised", NULL }; + static char *kwlist[] = { "sample_set_sizes", "sample_sets", "windows", + "time_windows", "mode", "span_normalise", "polarised", NULL }; PyObject *sample_set_sizes = NULL; PyObject *sample_sets = NULL; PyObject *windows = NULL; + PyObject *time_windows = NULL; char *mode = NULL; PyArrayObject *sample_set_sizes_array = NULL; PyArrayObject *sample_sets_array = 
NULL; PyArrayObject *windows_array = NULL; + PyArrayObject *time_windows_array = NULL; PyArrayObject *result_array = NULL; tsk_size_t *sizes; npy_intp *shape = NULL; - tsk_size_t k, num_windows, num_sample_sets; + tsk_size_t k, num_windows, num_time_windows, num_sample_sets; tsk_flags_t options = 0; int polarised = 0; int span_normalise = 1; int err; - if (TreeSequence_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|sii", kwlist, &sample_set_sizes, - &sample_sets, &windows, &mode, &span_normalise, &polarised)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOO|sii", kwlist, &sample_set_sizes, + &sample_sets, &windows, &time_windows, &mode, &span_normalise, &polarised)) { goto out; } if (parse_stats_mode(mode, &options) != 0) { @@ -9481,24 +9482,28 @@ TreeSequence_allele_frequency_spectrum( if (parse_windows(windows, &windows_array, &num_windows) != 0) { goto out; } - - shape = PyMem_Malloc((num_sample_sets + 1) * sizeof(*shape)); + if (parse_windows(time_windows, &time_windows_array, &num_time_windows) != 0) { + goto out; + } + shape = PyMem_Malloc((num_sample_sets + 1 + 1) * sizeof(*shape)); if (shape == NULL) { goto out; } sizes = PyArray_DATA(sample_set_sizes_array); shape[0] = num_windows; + shape[1] = num_time_windows; for (k = 0; k < num_sample_sets; k++) { - shape[k + 1] = 1 + sizes[k]; + shape[k + 1 + 1] = 1 + sizes[k]; } - result_array - = (PyArrayObject *) PyArray_SimpleNew(1 + num_sample_sets, shape, NPY_FLOAT64); + result_array = (PyArrayObject *) PyArray_SimpleNew( + 1 + 1 + num_sample_sets, shape, NPY_FLOAT64); if (result_array == NULL) { goto out; } err = tsk_treeseq_allele_frequency_spectrum(self->tree_sequence, num_sample_sets, PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), - num_windows, PyArray_DATA(windows_array), options, PyArray_DATA(result_array)); + num_windows, PyArray_DATA(windows_array), num_time_windows, + PyArray_DATA(time_windows_array), options, 
PyArray_DATA(result_array)); if (err != 0) { handle_library_error(err); goto out; @@ -9510,6 +9515,7 @@ TreeSequence_allele_frequency_spectrum( Py_XDECREF(sample_set_sizes_array); Py_XDECREF(sample_sets_array); Py_XDECREF(windows_array); + Py_XDECREF(time_windows_array); Py_XDECREF(result_array); return ret; } diff --git a/python/tests/ibd.py b/python/tests/ibd.py index 6db35201b7..bfeaf743b6 100644 --- a/python/tests/ibd.py +++ b/python/tests/ibd.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2024 Tskit Developers +# Copyright (c) 2020-2025 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/python/tests/simplify.py b/python/tests/simplify.py index 9e416c3e78..fe0771abb3 100644 --- a/python/tests/simplify.py +++ b/python/tests/simplify.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2019-2024 Tskit Developers +# Copyright (c) 2019-2025 Tskit Developers # Copyright (c) 2015-2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy diff --git a/python/tests/test_avl_tree.py b/python/tests/test_avl_tree.py index 2afaa0615e..1b2ad95131 100644 --- a/python/tests/test_avl_tree.py +++ b/python/tests/test_avl_tree.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2021-2024 Tskit Developers +# Copyright (c) 2021-2025 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/python/tests/test_balance_metrics.py b/python/tests/test_balance_metrics.py index eed20477b2..cebe498222 100644 --- a/python/tests/test_balance_metrics.py +++ b/python/tests/test_balance_metrics.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2024 Tskit Developers +# Copyright (c) 2025 Tskit Developers # # Permission is hereby granted, free of charge, to any person 
obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/python/tests/test_cli.py b/python/tests/test_cli.py index 67b3890c2e..01cb980e1e 100644 --- a/python/tests/test_cli.py +++ b/python/tests/test_cli.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2024 Tskit Developers +# Copyright (c) 2018-2025 Tskit Developers # Copyright (c) 2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy diff --git a/python/tests/test_coalrate.py b/python/tests/test_coalrate.py index f21bcea627..e1b0ce4b87 100644 --- a/python/tests/test_coalrate.py +++ b/python/tests/test_coalrate.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2024 Tskit Developers +# Copyright (c) 2025 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/python/tests/test_distance_metrics.py b/python/tests/test_distance_metrics.py index c08351962f..006f648a2c 100644 --- a/python/tests/test_distance_metrics.py +++ b/python/tests/test_distance_metrics.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2024 Tskit Developers +# Copyright (c) 2025 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/python/tests/test_divmat.py b/python/tests/test_divmat.py index ac66f43f1e..a52f8b1ab9 100644 --- a/python/tests/test_divmat.py +++ b/python/tests/test_divmat.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2023-2024 Tskit Developers +# Copyright (c) 2023-2025 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/python/tests/test_drawing.py b/python/tests/test_drawing.py index 
8da436dab3..19415af268 100644 --- a/python/tests/test_drawing.py +++ b/python/tests/test_drawing.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2024 Tskit Developers +# Copyright (c) 2018-2025 Tskit Developers # Copyright (C) 2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index 4c7d5fc18a..6b02381edf 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2019-2024 Tskit Developers +# Copyright (c) 2019-2025 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 5929daa6ae..fd2f11d263 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2024 Tskit Developers +# Copyright (c) 2018-2025 Tskit Developers # Copyright (c) 2015-2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy diff --git a/python/tests/test_intervals.py b/python/tests/test_intervals.py index e23f221e62..cb4e484774 100644 --- a/python/tests/test_intervals.py +++ b/python/tests/test_intervals.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2023-2024 Tskit Developers +# Copyright (c) 2023-2025 Tskit Developers # Copyright (C) 2020-2022 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 3926767f8b..13437d1252 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2023-2024 Tskit Developers +# Copyright (c) 
2023-2025 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index a26f9242c4..4a99d994b4 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2024 Tskit Developers +# Copyright (c) 2018-2025 Tskit Developers # Copyright (c) 2015-2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -2239,33 +2239,57 @@ def test_basic_example(self): ts = self.get_example_tree_sequence() n = ts.get_num_samples() result = ts.allele_frequency_spectrum( - [n], ts.get_samples(), [0, ts.get_sequence_length()] + [n], + ts.get_samples(), + [0, ts.get_sequence_length()], + mode="branch", + time_windows=[0, np.inf], ) - assert result.shape == (1, n + 1) + assert result.shape == (1, 1, n + 1) result = ts.allele_frequency_spectrum( - [n], ts.get_samples(), [0, ts.get_sequence_length()], polarised=True + [n], + ts.get_samples(), + [0, ts.get_sequence_length()], + mode="branch", + time_windows=[0, np.inf], + polarised=True, ) - assert result.shape == (1, n + 1) + assert result.shape == (1, 1, n + 1) def test_output_dims(self): ts = self.get_example_tree_sequence() samples = ts.get_samples() L = ts.get_sequence_length() n = len(samples) + time_windows = [0, np.inf] - for mode in ["site", "branch"]: + for mode in ["branch"]: for s in [[n], [n - 2, 2], [n - 4, 2, 2], [1] * n]: s = np.array(s, dtype=np.uint32) windows = [0, L] for windows in [[0, L], [0, L / 2, L], np.linspace(0, L, num=10)]: jafs = ts.allele_frequency_spectrum( - s, samples, windows, mode=mode, polarised=True + s, + samples, + windows, + mode=mode, + time_windows=time_windows, + polarised=True, + ) + assert jafs.shape == tuple( + [len(windows) - 1] + [len(time_windows) - 1] + list(s + 1) ) - assert 
jafs.shape == tuple([len(windows) - 1] + list(s + 1)) jafs = ts.allele_frequency_spectrum( - s, samples, windows, mode=mode, polarised=False + s, + samples, + windows, + mode=mode, + time_windows=time_windows, + polarised=False, + ) + assert jafs.shape == tuple( + [len(windows) - 1] + [len(time_windows) - 1] + list(s + 1) ) - assert jafs.shape == tuple([len(windows) - 1] + list(s + 1)) def test_node_mode_not_supported(self): ts = self.get_example_tree_sequence() @@ -2275,8 +2299,162 @@ def test_node_mode_not_supported(self): ts.get_samples(), [0, ts.get_sequence_length()], mode="node", + time_windows=[0, np.inf], ) + def test_polarised(self): + """ + Temporary duplicate from class OneWaySampleStatsMixin + used to provide the time_windows argument. + """ + # TODO move this to the top level. + ts, method = self.get_method() + samples = ts.get_samples() + n = len(samples) + windows = [0, ts.get_sequence_length()] + method( + [n], + samples, + windows, + time_windows=[0, np.inf], + mode="branch", + polarised=True, + ) + method( + [n], + samples, + windows, + time_windows=[0, np.inf], + mode="branch", + polarised=False, + ) + + def test_polarisation(self): + ts, f, params = self.get_example() + with pytest.raises(TypeError): + f(polarised="sdf", time_windows=[0, np.inf], mode="branch", **params) + x1 = f(polarised=False, time_windows=[0, np.inf], mode="branch", **params) + x2 = f(polarised=True, time_windows=[0, np.inf], mode="branch", **params) + # Basic check just to run both code paths + assert x1.shape == x2.shape + + def test_mode_errors(self): + _, f, params = self.get_example() + for bad_mode in ["", "not a mode", "SITE", "x" * 8192]: + with pytest.raises(ValueError): + f(mode=bad_mode, time_windows=[0, np.inf], **params) + + for bad_type in [123, {}, None, [[]]]: + with pytest.raises(TypeError): + f(mode=bad_type, time_windows=[0, np.inf], **params) + + def test_window_errors(self): + ts, f, params = self.get_example() + del params["windows"] + for bad_array in 
["asdf", None, [[[[]], [[]]]], np.zeros((10, 3, 4))]: + with pytest.raises(ValueError): + f(windows=bad_array, time_windows=[0, np.inf], mode="branch", **params) + + for bad_windows in [[], [0]]: + with pytest.raises(ValueError): + f( + windows=bad_windows, + time_windows=[0, np.inf], + mode="branch", + **params, + ) + L = ts.get_sequence_length() + bad_windows = [ + [L, 0], + [0.1, L], + [-1, L], + [0, L + 0.1], + [0, 0.1, 0.1, L], + [0, -1, L], + [0, 0.1, 0.05, 0.2, L], + ] + for bad_window in bad_windows: + with pytest.raises(_tskit.LibraryError): + f(windows=bad_window, time_windows=[0, np.inf], mode="branch", **params) + + def test_time_window_errors(self): + ts, f, params = self.get_example() + + for bad_time_windows in [[], [0]]: + with pytest.raises(ValueError, match="must have at least 2"): + f( + time_windows=bad_time_windows, + mode="branch", + **params, + ) + bad_time_windows = [ + [-1, np.inf], + [0, 0, np.inf], + [0, 10, 5, np.inf], + [0, np.inf, np.inf], + ] + for bad_time_window in bad_time_windows: + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_BAD_TIME_WINDOWS"): + f(time_windows=bad_time_window, mode="branch", **params) + + def test_windows_output(self): + ts, f, params = self.get_example() + del params["windows"] + for num_windows in range(1, 10): + windows = np.linspace(0, ts.get_sequence_length(), num=num_windows + 1) + assert windows.shape[0] == num_windows + 1 + sigma = f( + windows=windows, time_windows=[0, np.inf], mode="branch", **params + ) + assert sigma.shape[0] == num_windows + + def test_bad_sample_sets(self): + ts, f, params = self.get_example() + del params["sample_set_sizes"] + del params["sample_sets"] + + with pytest.raises(_tskit.LibraryError): + f( + sample_sets=[], + sample_set_sizes=[], + time_windows=[0, np.inf], + mode="branch", + **params, + ) + + n = ts.get_num_samples() + samples = ts.get_samples() + for bad_set_sizes in [[], [1], [n - 1], [n + 1], [n - 3, 1, 1], [1, n - 2]]: + with pytest.raises(ValueError): 
+ f( + sample_set_sizes=bad_set_sizes, + sample_sets=samples, + time_windows=[0, np.inf], + mode="branch", + **params, + ) + + N = ts.get_num_nodes() + for bad_node in [-1, N, N + 1, -N]: + with pytest.raises(_tskit.LibraryError): + f( + sample_set_sizes=[2], + sample_sets=[0, bad_node], + time_windows=[0, np.inf], + mode="branch", + **params, + ) + + for bad_sample in [n, n + 1, N - 1]: + with pytest.raises(_tskit.LibraryError): + f( + sample_set_sizes=[2], + sample_sets=[0, bad_sample], + time_windows=[0, np.inf], + mode="branch", + **params, + ) + class TwoWaySampleStatsMixin(SampleSetMixin): """ diff --git a/python/tests/test_metadata.py b/python/tests/test_metadata.py index 7387f73fad..86bd6d3ca3 100644 --- a/python/tests/test_metadata.py +++ b/python/tests/test_metadata.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2024 Tskit Developers +# Copyright (c) 2018-2025 Tskit Developers # Copyright (c) 2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy diff --git a/python/tests/test_provenance.py b/python/tests/test_provenance.py index 7b61308f8a..97fb193866 100644 --- a/python/tests/test_provenance.py +++ b/python/tests/test_provenance.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2024 Tskit Developers +# Copyright (c) 2018-2025 Tskit Developers # Copyright (C) 2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy diff --git a/python/tests/test_relatedness_vector.py b/python/tests/test_relatedness_vector.py index f765c75c9f..6241138261 100644 --- a/python/tests/test_relatedness_vector.py +++ b/python/tests/test_relatedness_vector.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2024 Tskit Developers +# Copyright (c) 2025 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git 
a/python/tests/test_table_transforms.py b/python/tests/test_table_transforms.py index 35285fc4dc..acb104c6d5 100644 --- a/python/tests/test_table_transforms.py +++ b/python/tests/test_table_transforms.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2022-2024 Tskit Developers +# Copyright (c) 2022-2025 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index 3b1bd50bad..f25b446540 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2024 Tskit Developers +# Copyright (c) 2018-2025 Tskit Developers # Copyright (c) 2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy diff --git a/python/tests/test_text_formats.py b/python/tests/test_text_formats.py index 92874af6fd..4287266dc8 100644 --- a/python/tests/test_text_formats.py +++ b/python/tests/test_text_formats.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2021-2024 Tskit Developers +# Copyright (c) 2021-2025 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index 47be14e4bd..6806103374 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2024 Tskit Developers +# Copyright (c) 2018-2025 Tskit Developers # Copyright (c) 2016-2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy diff --git a/python/tests/test_tree_stats.py b/python/tests/test_tree_stats.py index de9cd1d2b8..5bf8f988ea 100644 --- a/python/tests/test_tree_stats.py +++ 
b/python/tests/test_tree_stats.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2024 Tskit Developers +# Copyright (c) 2018-2025 Tskit Developers # Copyright (C) 2016 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -40,6 +40,7 @@ import tskit import tskit.exceptions as exceptions + np.random.seed(5) @@ -145,40 +146,59 @@ def windowed_tree_stat(ts, stat, windows, span_normalise=True): def naive_branch_general_stat( - ts, w, f, windows=None, polarised=False, span_normalise=True + ts, w, f, windows=None, time_windows=None, polarised=False, span_normalise=True ): # NOTE: does not behave correctly for unpolarised stats # with non-ancestral material. if windows is None: windows = [0.0, ts.sequence_length] + drop_time_windows = time_windows is None + if time_windows is None: + time_windows = [0.0, np.inf] + else: + if time_windows[0] != 0: + time_windows = [0] + time_windows n, k = w.shape + tw = len(time_windows) - 1 # hack to determine m m = len(f(w[0])) total = np.sum(w, axis=0) - sigma = np.zeros((ts.num_trees, m)) - for tree in ts.trees(): - x = np.zeros((ts.num_nodes, k)) - x[ts.samples()] = w - for u in tree.nodes(order="postorder"): - for v in tree.children(u): - x[u] += x[v] - if polarised: - s = sum(tree.branch_length(u) * f(x[u]) for u in tree.nodes()) + sigma = np.zeros((ts.num_trees, tw, m)) + for j, upper_time in enumerate(time_windows[1:]): + if np.isfinite(upper_time): + decap_ts = ts.decapitate(upper_time) else: - s = sum( - tree.branch_length(u) * (f(x[u]) + f(total - x[u])) - for u in tree.nodes() - ) - sigma[tree.index] = s * tree.span + decap_ts = ts + assert np.all(list(ts.samples()) == list(decap_ts.samples())) + for tree in decap_ts.trees(): + x = np.zeros((decap_ts.num_nodes, k)) + x[decap_ts.samples()] = w + for u in tree.nodes(order="postorder"): + for v in tree.children(u): + x[u] += x[v] + if polarised: + s = sum(tree.branch_length(u) * f(x[u]) for u in tree.nodes()) + else: + s 
= sum( + tree.branch_length(u) * (f(x[u]) + f(total - x[u])) + for u in tree.nodes() + ) + sigma[tree.index, j, :] = s * tree.span + for j in range(1, tw): + sigma[:, j, :] = sigma[:, j, :] - sigma[:, j - 1, :] if isinstance(windows, str) and windows == "trees": # need to average across the windows if span_normalise: for j, tree in enumerate(ts.trees()): sigma[j] /= tree.span - return sigma + out = sigma else: - return windowed_tree_stat(ts, sigma, windows, span_normalise=span_normalise) + out = windowed_tree_stat(ts, sigma, windows, span_normalise=span_normalise) + if drop_time_windows: + assert out.ndim == 3 + out = out[:, 0] + return out def branch_general_stat( @@ -261,7 +281,6 @@ def polarised_summary(u): # for the next tree break - # print("window_index:", window_index, windows.shape) assert window_index == windows.shape[0] - 1 if span_normalise: for j in range(num_windows): @@ -569,6 +588,59 @@ def cumsum_f(self, ts): def sum_f(self, ts, k=1): return lambda x: np.array([sum(x) * (sum(x) < 2 * ts.num_samples)] * k) + def four_taxa_test_case(self): + # + # 1.0 7 + # 0.7 / \ 6 + # / \ / \ + # 0.5 / 5 5 / 5 + # / / \ / \__ / / \ + # 0.4 / 8 \ 8 4 / 8 \ + # / / \ \ / \ / \ / / \ \ + # 0.0 0 1 3 2 1 3 0 2 0 1 3 2 + # (0.0, 0.2), (0.2, 0.8), (0.8, 2.5) + + nodes = io.StringIO( + """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 1 0 + 4 0 0.4 + 5 0 0.5 + 6 0 0.7 + 7 0 1.0 + 8 0 0.4 + """ + ) + edges = io.StringIO( + """\ + left right parent child + 0.0 2.5 8 1,3 + 0.2 0.8 4 0,2 + 0.0 0.2 5 8,2 + 0.2 0.8 5 8,4 + 0.8 2.5 5 8,2 + 0.8 2.5 6 0,5 + 0.0 0.2 7 0,5 + """ + ) + sites = io.StringIO( + """\ + id position ancestral_state + """ + ) + mutations = io.StringIO( + """\ + site node derived_state parent + """ + ) + ts = tskit.load_text( + nodes=nodes, edges=edges, sites=sites, mutations=mutations, strict=False + ) + return ts + class TopologyExamplesMixin: """ @@ -3732,7 +3804,12 @@ def test_site_folded(self): def naive_site_allele_frequency_spectrum( - ts, 
sample_sets, windows=None, polarised=False, span_normalise=True + ts, + sample_sets, + windows=None, + time_windows=None, + polarised=False, + span_normalise=True, ): """ The joint allele frequency spectrum for sites. @@ -3787,47 +3864,84 @@ def naive_site_allele_frequency_spectrum( def naive_branch_allele_frequency_spectrum( - ts, sample_sets, windows=None, polarised=False, span_normalise=True + ts, + sample_sets, + windows=None, + time_windows=None, + polarised=False, + span_normalise=True, ): """ The joint allele frequency spectrum for branches. """ + drop_windows = windows is None + if windows is None: + windows = [0.0, ts.sequence_length] + else: + if windows[0] != 0: + windows = [0] + windows + drop_time_windows = time_windows is None + if time_windows is None: + time_windows = [0.0, np.inf] + else: + if time_windows[0] != 0: + time_windows = [0] + time_windows windows = ts.parse_windows(windows) num_windows = len(windows) - 1 + num_time_windows = len(time_windows) - 1 out_dim = [1 + len(sample_set) for sample_set in sample_sets] - out = np.zeros([num_windows] + out_dim) + out = np.zeros([num_windows] + [num_time_windows] + out_dim) for j in range(num_windows): begin = windows[j] end = windows[j + 1] - S = np.zeros(out_dim) - trees = [ - next(ts.trees(tracked_samples=sample_set)) for sample_set in sample_sets - ] - t = trees[0] - while True: - tr_len = min(end, t.interval.right) - max(begin, t.interval.left) - if tr_len > 0: - for node in t.nodes(): - if 0 < t.num_samples(node) < ts.num_samples: - x = [tree.num_tracked_samples(node) for tree in trees] - # Note x must be a tuple for indexing to work - if not polarised: - x = fold(x, out_dim) - S[tuple(x)] += t.branch_length(node) * tr_len - - # Advance the trees - more = [tree.next() for tree in trees] - assert len(set(more)) == 1 - if not more[0]: - break - if span_normalise: - S /= end - begin - out[j, :] = S + for k, upper_time in enumerate(time_windows[1:]): + S = np.zeros(out_dim) + if 
np.isfinite(upper_time): + decap_ts = ts.decapitate(upper_time) + else: + decap_ts = ts + assert np.all(list(ts.samples()) == list(decap_ts.samples())) + trees = [ + next(decap_ts.trees(tracked_samples=sample_set)) + for sample_set in sample_sets + ] + t = trees[0] + while True: + tr_len = min(end, t.interval.right) - max(begin, t.interval.left) + if tr_len > 0: + for node in t.nodes(): + if 0 < t.num_samples(node) < decap_ts.num_samples: + x = [tree.num_tracked_samples(node) for tree in trees] + if not polarised: + x = fold(x, out_dim) + # Note x must be a tuple for indexing to work + S[tuple(x)] += t.branch_length(node) * tr_len + + # Advance the trees + more = [tree.next() for tree in trees] + assert len(set(more)) == 1 + if not more[0]: + break + if span_normalise: + S /= end - begin + out[j, k, :] = S - sum(out[j, 0:k, :]) + if drop_time_windows: + assert out.ndim == 2 + len(out_dim) + out = out[:, 0] + elif drop_windows: + assert out.shape[0] == 1 + out = out[0] return out def naive_allele_frequency_spectrum( - ts, sample_sets, windows=None, polarised=False, mode="site", span_normalise=True + ts, + sample_sets, + windows=None, + time_windows=None, + polarised=False, + mode="site", + span_normalise=True, ): """ Naive definition of the generalised site frequency spectrum. @@ -3840,25 +3954,31 @@ def naive_allele_frequency_spectrum( ts, sample_sets, windows=windows, + time_windows=time_windows, polarised=polarised, span_normalise=span_normalise, ) def branch_allele_frequency_spectrum( - ts, sample_sets, windows, polarised=False, span_normalise=True + ts, sample_sets, windows, time_windows=None, polarised=False, span_normalise=True ): """ Efficient implementation of the algorithm used as the basis for the underlying C version. 
""" num_sample_sets = len(sample_sets) + drop_windows = windows is None windows = ts.parse_windows(windows) + drop_time_windows = time_windows is None + if time_windows is None: + time_windows = [0.0, np.inf] num_windows = windows.shape[0] - 1 + num_time_windows = len(time_windows) - 1 out_dim = [1 + len(sample_set) for sample_set in sample_sets] time = ts.tables.nodes.time - result = np.zeros([num_windows] + out_dim) + result = np.zeros([num_windows] + [num_time_windows] + out_dim) # Number of nodes in sample_set j ancestral to each node u. count = np.zeros((ts.num_nodes, num_sample_sets + 1), dtype=np.uint32) for j in range(num_sample_sets): @@ -3869,17 +3989,30 @@ def branch_allele_frequency_spectrum( last_update = np.zeros(ts.num_nodes) window_index = 0 parent = np.zeros(ts.num_nodes, dtype=np.int32) - 1 - branch_length = np.zeros(ts.num_nodes) + # branch_length = np.zeros(ts.num_nodes) tree_index = 0 def update_result(window_index, u, right): - if 0 < count[u, -1] < ts.num_samples: - x = (right - last_update[u]) * branch_length[u] - c = count[u, :num_sample_sets] - if not polarised: - c = fold(c, out_dim) - index = tuple([window_index] + list(c)) - result[index] += x + if parent[u] != -1: + t_v = time[parent[u]] + if 0 < count[u, -1] < ts.num_samples: + time_window_index = 0 + while ( + time_window_index < num_time_windows + and time_windows[time_window_index] < t_v + ): + assert parent[u] != -1 + tw_branch_length = abs( + min(time_windows[time_window_index + 1], t_v) + - max(time_windows[time_window_index], time[u]) + ) + x = (right - last_update[u]) * tw_branch_length + c = count[u, :num_sample_sets] + if not polarised: + c = fold(c, out_dim) + index = tuple([window_index] + [time_window_index] + list(c)) + result[index] += x + time_window_index += 1 last_update[u] = right for (t_left, t_right), edges_out, edges_in in ts.edge_diffs(): @@ -3892,17 +4025,17 @@ def update_result(window_index, u, right): count[v] -= count[u] v = parent[v] parent[u] = -1 - 
branch_length[u] = 0 + # branch_length[u] = 0 for edge in edges_in: u = edge.child v = edge.parent - parent[u] = v - branch_length[u] = time[v] - time[u] + # branch_length[u] = time[v] - time[u] while v != -1: update_result(window_index, v, t_left) count[v] += count[u] v = parent[v] + parent[u] = edge.parent # Update the windows while window_index < num_windows and windows[window_index + 1] <= t_right: @@ -3927,6 +4060,14 @@ def update_result(window_index, u, right): if span_normalise: for j in range(num_windows): result[j] /= windows[j + 1] - windows[j] + + if drop_time_windows: + assert result.ndim == 2 + len(out_dim) + assert result.shape[1] == 1 + result = result[:, 0] + elif drop_windows: + assert result.shape[0] == 1 + result = result[0] return result @@ -6147,6 +6288,7 @@ def test_case_four_taxa(self): # (0.0, 0.2), (0.2, 0.8), (0.8, 2.5) # f4(0, 1, 2, 3): (0 -> 1)(2 -> 3) + ts = self.four_taxa_test_case() branch_true_f4_0123 = (0.1 * 0.2 + (0.1 + 0.1) * 0.6 + 0.1 * 1.7) / 2.5 windows = [0.0, 0.4, 2.5] branch_true_f4_0123_windowed = np.array( @@ -6179,46 +6321,6 @@ def test_case_four_taxa(self): ] ) - nodes = io.StringIO( - """\ - id is_sample time - 0 1 0 - 1 1 0 - 2 1 0 - 3 1 0 - 4 0 0.4 - 5 0 0.5 - 6 0 0.7 - 7 0 1.0 - 8 0 0.4 - """ - ) - edges = io.StringIO( - """\ - left right parent child - 0.0 2.5 8 1,3 - 0.2 0.8 4 0,2 - 0.0 0.2 5 8,2 - 0.2 0.8 5 8,4 - 0.8 2.5 5 8,2 - 0.8 2.5 6 0,5 - 0.0 0.2 7 0,5 - """ - ) - sites = io.StringIO( - """\ - id position ancestral_state - """ - ) - mutations = io.StringIO( - """\ - site node derived_state parent - """ - ) - ts = tskit.load_text( - nodes=nodes, edges=edges, sites=sites, mutations=mutations, strict=False - ) - mode = "branch" A = [[0], [1], [2], [3]] self.assertAlmostEqual(branch_true_f4_0123, f4(ts, A, mode=mode)[0][0]) @@ -6562,20 +6664,20 @@ def test_multi_way_indexes_not_scalar_stat(self): ) assert x.shape == (1,) - def test_afs_default_windows(self): - ts = self.get_example_ts() - n = ts.num_samples - A 
= ts.samples()[:4] - B = ts.samples()[6:] - for mode in ["site", "branch"]: - x = ts.allele_frequency_spectrum(mode=mode) - # x is a 1D numpy array with n + 1 values - assert x.shape == (n + 1,) - self.assertArrayEqual( - x, ts.allele_frequency_spectrum([ts.samples()], mode=mode) - ) - x = ts.allele_frequency_spectrum([A, B], mode=mode) - assert x.shape == (len(A) + 1, len(B) + 1) + # def test_afs_default_windows(self): + # ts = self.get_example_ts() + # n = ts.num_samples + # A = ts.samples()[:4] + # B = ts.samples()[6:] + # for mode in ["site", "branch"]: + # x = ts.allele_frequency_spectrum(mode=mode) + # # x is a 1D numpy array with n + 1 values + # assert x.shape == (n + 1,) + # self.assertArrayEqual( + # x, ts.allele_frequency_spectrum([ts.samples()], mode=mode) + # ) + # x = ts.allele_frequency_spectrum([A, B], mode=mode) + # assert x.shape == (len(A) + 1, len(B) + 1) def test_afs_windows(self): ts = self.get_example_ts() @@ -6893,3 +6995,314 @@ def f_too_long(_): output_dim=1, strict=False, ) + + +class TestTimeWindows(TestBranchAlleleFrequencySpectrum): + def test_four_taxa_test_case(self): + # 1.00┊ 7 ┊ ┊ ┊ + # ┊ ┏━┻━┓ ┊ ┊ ┊ + # 0.70┊ ┃ ┃ ┊ ┊ 6 ┊ + # ┊ ┃ ┃ ┊ ┊ ┏━┻━┓ ┊ + # 0.50┊ ┃ 5 ┊ 5 ┊ ┃ 5 ┊ + # ┊ ┃ ┏┻━┓ ┊ ┏━┻━┓ ┊ ┃ ┏┻━┓ ┊ + # 0.40┊ ┃ 8 ┃ ┊ 4 8 ┊ ┃ 8 ┃ ┊ + # ┊ ┃ ┏┻┓ ┃ ┊ ┏┻┓ ┏┻┓ ┊ ┃ ┏┻┓ ┃ ┊ + # 0.00┊ 0 1 3 2 ┊ 0 2 1 3 ┊ 0 1 3 2 ┊ + # 0.00 0.20 0.80 2.50 + ts = self.four_taxa_test_case() + true_x = np.array( + [ + [ + [ + 0.2 * (1 + 0.5 + 0.4) + + (0.8 - 0.2) * (1 + 0.8) + + (2.5 - 0.8) * (1.0 + 0.5 + 0.4) + ], + [0.2 * 1.0 + 0 + (2.5 - 0.8) * 0.4], + ] + ] + ) + + n = ts.num_samples + + def f(x): + return (x > 0) * (1 - x / n) + + W = np.ones((ts.num_samples, 1)) + x = naive_branch_general_stat( + ts, W, f, time_windows=[0, 0.5, 2.0], span_normalise=False + ) + self.assertArrayAlmostEqual(x, true_x) + + def test_bad_time_windows(self): + time_windows = [-1] + ts = self.four_taxa_test_case() + # make a badly formatted time_windows array + assert ( + 
branch_allele_frequency_spectrum( + ts, + sample_sets=[[0, 1, 2, 3]], + time_windows=time_windows, + windows=None, + polarised=True, + span_normalise=False, + ).size + == 0 + ) + + def test_drop_dimension(self): + ts = self.four_taxa_test_case() + sample_set = [0, 1, 2, 3] + for tw in [None, [0, 0.5, 1, np.inf]]: + x = ts.allele_frequency_spectrum( + sample_sets=[sample_set], + time_windows=tw, + mode="branch", + ) + y = ts.allele_frequency_spectrum( + sample_sets=[sample_set], + mode="branch", + ) + assert x.shape[-1] == y.shape[-1] + assert x.shape[-1] == len(sample_set) + 1 + assert np.all(x[0] == y[:0]) + + def test_afs_branch(self): + """Tests for the Allele Frequency Spectrum stat + using time windows under branch mode. + """ + # + # 1.0 7 + # 0.7 / \ 6 + # / \ / \ + # 0.5 / 5 5 / 5 + # / / \ / \__ / / \ + # 0.4 / 8 \ 8 4 / 8 \ + # / / \ \ / \ / \ / / \ \ + # 0.0 0 1 3 2 1 3 0 2 0 1 3 2 + # (0.0, 0.2), (0.2, 0.8), (0.8, 2.5) + + ts = self.four_taxa_test_case() + + self.mode = "branch" + + sfs1 = naive_allele_frequency_spectrum( + ts, + sample_sets=[[0, 1, 2, 3]], + mode=self.mode, + polarised=True, + span_normalise=False, + ) + + sfs1_w = naive_allele_frequency_spectrum( + ts, + sample_sets=[[0, 1, 2, 3]], + windows=[0, 0.2, 0.80, 2.5], + mode=self.mode, + polarised=True, + span_normalise=False, + ) + + sfs1_w_opti = branch_allele_frequency_spectrum( + ts, + sample_sets=[[0, 1, 2, 3]], + windows=[0, 0.2, 0.8, 2.5], + polarised=True, + span_normalise=False, + ) + + sfs1_w_tw = naive_allele_frequency_spectrum( + ts, + sample_sets=[[0, 1, 2, 3]], + windows=[0, 0.2, 0.8, 2.5], + time_windows=[0, 0.5, 0.8, 1], + mode=self.mode, + polarised=True, + span_normalise=False, + ) + + sfs1_w_tw_opti = branch_allele_frequency_spectrum( + ts, + sample_sets=[[0, 1, 2, 3]], + time_windows=[0, 0.5, 0.8, 1], + windows=[0, 0.2, 0.8, 2.5], + polarised=True, + span_normalise=False, + ) + + sfs1_tws = naive_allele_frequency_spectrum( + ts, + sample_sets=[[0, 1, 2, 3]], + 
time_windows=[0, 0.5, 0.8, 1], + mode=self.mode, + polarised=True, + span_normalise=False, + ) + + sfs1_tws_opti = branch_allele_frequency_spectrum( + ts, + sample_sets=[[0, 1, 2, 3]], + time_windows=[0, 0.5, 0.8, 1], + windows=None, + polarised=True, + span_normalise=False, + ) + + sfs1_tw_05 = naive_allele_frequency_spectrum( + ts, + sample_sets=[[0, 1, 2, 3]], + time_windows=[0, 0.5], + mode=self.mode, + polarised=True, + span_normalise=False, + ) + ts_decap_05 = ts.decapitate(0.5) + sfs1_decap_05 = naive_allele_frequency_spectrum( + ts_decap_05, + sample_sets=[[0, 1, 2, 3]], + mode=self.mode, + polarised=True, + span_normalise=False, + ) + + # non-naive version + sfs1_opti_decap_05 = branch_allele_frequency_spectrum( + ts_decap_05, + sample_sets=[[0, 1, 2, 3]], + windows=None, + polarised=True, + span_normalise=False, + ) + + sfs1_opti_tw_05 = branch_allele_frequency_spectrum( + ts, + sample_sets=[[0, 1, 2, 3]], + time_windows=[0, 0.5], + windows=None, + polarised=True, + span_normalise=False, + ) + + sfs = [ + [ + # bin 0 + 0, + # bin 1 + (0.4 + 0.4 + 0.5 + 1) * 0.2 + + (0.4 + 0.4 + 0.4 + 0.4) * 0.6 + + (0.7 + 0.4 + 0.4 + 0.5) * 1.7, + # bin 2 + (0.1) * 0.2 + (0.1 + 0.1) * 0.6 + (0.1) * 1.7, + # bin 3 + (0.5) * 0.2 + (0.2) * 1.7, + # bin 4 + 0, + ] + ] + + sfs_w_02 = [ + [ + # bin 0 + 0, + # bin 1 + (0.4 + 0.4 + 0.5 + 1) * 0.2, + # bin 2 + (0.1) * 0.2, + # bin 3 + (0.5) * 0.2, + # bin 4 + 0, + ] + ] + + sfs_w_08 = [ + [ + # bin 0 + 0, + # bin 1 + (0.4 + 0.4 + 0.4 + 0.4) * 0.6, + # bin 2 + (0.1 + 0.1) * 0.6, + # bin 3 + 0, + # bin 4 + 0, + ] + ] + + sfs_w_25 = [ + [ + # bin 0 + 0, + # bin 1 + (0.7 + 0.4 + 0.4 + 0.5) * 1.7, + # bin 2 + (0.1) * 1.7, + # bin 3 + (0.2) * 1.7, + # bin 4 + 0, + ] + ] + + sfs_05 = [ + [ + # bin 0 + 0, + # bin 1 + (0.5 + 0.4 + 0.4 + 0.5) * 0.2 + + (0.4 + 0.4 + 0.4 + 0.4) * 0.6 + + (0.4 + 0.4 + 0.5 + 0.5) * 1.7, + # bin 2 + (0.1) * 0.2 + (0.1 + 0.1) * 0.6 + (0.1) * 1.7, + # bin 3, no nodes are present anymore at time 0.5 + 0, + # bin 4 
+ 0, + ] + ] + # try on simple example computed "by hand" + # if the modified function still works as expected + self.assertArrayAlmostEqual(sfs, sfs1) + # check if windows work as expected with naive version + # Window 0.0 to 0.2 + self.assertArrayAlmostEqual(sfs1_w[0], sfs_w_02[0]) + # Window 0.2 to 0.8 + self.assertArrayAlmostEqual(sfs1_w[1], sfs_w_08[0]) + # Window 0.8 to 2.5 + self.assertArrayAlmostEqual(sfs1_w[2], sfs_w_25[0]) + # try if time_windows before t=0.5 is + # like sfs on a previously decapitated ts before t=0.5 + self.assertArrayAlmostEqual(sfs1_tw_05, sfs1_decap_05) + # try if sfs by hand before t=0.5 is + # equivalent to time_windows before t=0.5 + self.assertArrayAlmostEqual(sfs_05, sfs1_tw_05) + # try if the SFSs of a ts decapitated before t=0.5 are + # not altered by naive and non-naive afs without + # any time_windows parameter + self.assertArrayAlmostEqual(sfs1_decap_05, sfs1_opti_decap_05) + # non-naive version + # Window 0.0 to 0.2 + self.assertArrayAlmostEqual(sfs1_w_opti[0], sfs_w_02[0]) + # Window 0.2 to 0.8 + self.assertArrayAlmostEqual(sfs1_w_opti[1], sfs_w_08[0]) + # Window 0.8 to 2.5 + self.assertArrayAlmostEqual(sfs1_w_opti[2], sfs_w_25[0]) + # try if sfs by hand before t=0.5 is + # equivalent to time_windows before t=0.5 + self.assertArrayAlmostEqual(sfs_05, sfs1_opti_tw_05) + # try if time_windows before t=0.5 is + # like sfs on a previously decapitated ts before t=0.5 + self.assertArrayAlmostEqual(sfs1_opti_decap_05, sfs1_opti_tw_05) + # test if time windows obtained with naive version and opti + # are equal + self.assertArrayAlmostEqual(sfs1_tws, sfs1_tws_opti) + # test if time windows and windows obtained with naive version + # and opti are equal + self.assertArrayAlmostEqual(sfs1_w_tw, sfs1_w_tw_opti) + # dimension tests + assert sfs1_tws.ndim == sfs1_w.ndim + # dimensions are dim1: windows ; dim2: time_windows ; + # dim3-or-more: num_sample_sets + assert sfs1_w_tw.ndim == 3 diff --git a/python/tests/test_util.py
b/python/tests/test_util.py index a10b292db4..d03d314b8f 100644 --- a/python/tests/test_util.py +++ b/python/tests/test_util.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2024 Tskit Developers +# Copyright (c) 2018-2025 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/python/tests/test_vcf.py b/python/tests/test_vcf.py index a1ce300702..a378abecd2 100644 --- a/python/tests/test_vcf.py +++ b/python/tests/test_vcf.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2024 Tskit Developers +# Copyright (c) 2018-2025 Tskit Developers # Copyright (c) 2016 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy diff --git a/python/tests/test_version.py b/python/tests/test_version.py index 65a8ab5b13..f90d0455a5 100644 --- a/python/tests/test_version.py +++ b/python/tests/test_version.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2024 Tskit Developers +# Copyright (c) 2020-2025 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index 091e1d0df5..eadc47bef6 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2024 Tskit Developers +# Copyright (c) 2018-2025 Tskit Developers # Copyright (C) 2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy diff --git a/python/tskit/trees.py b/python/tskit/trees.py index a4f72393b7..6b6cb9cb3a 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7658,14 +7658,35 @@ def parse_windows(self, windows): ) return np.array(windows) + def parse_time_windows(self, time_windows): + """Return time_windows as an array, defaulting to [0, inf] when None.""" + if
time_windows is None: + time_windows = [0.0, math.inf] + return np.array(time_windows) + def __run_windowed_stat(self, windows, method, *args, **kwargs): - strip_dim = windows is None + strip_win = windows is None windows = self.parse_windows(windows) stat = method(*args, **kwargs, windows=windows) - if strip_dim: + if strip_win: stat = stat[0] return stat + # only for temporary tw version + def __run_windowed_stat_tw(self, windows, time_windows, method, *args, **kwargs): + strip_win = windows is None + strip_timewin = time_windows is None + windows = self.parse_windows(windows) + time_windows = self.parse_time_windows(time_windows) + stat = method(*args, **kwargs, windows=windows, time_windows=time_windows) + if strip_win and strip_timewin: # must be tested first, else strip_win shadows it + stat = stat[0, 0, :] + elif strip_win: + stat = stat[0, :, :] + elif strip_timewin: + stat = stat[:, 0, :] + return stat + def __one_way_sample_set_stat( self, ll_method, @@ -7710,10 +7731,63 @@ def __one_way_sample_set_stat( ) if drop_dimension: stat = stat.reshape(stat.shape[:-1]) + # TODO: Write test for this if stat.shape == () and windows is None: stat = stat[()] return stat + # only for temporary tw version + def __one_way_sample_set_stat_tw( + self, + ll_method, + sample_sets, + windows=None, + time_windows=None, + mode=None, + span_normalise=True, + polarised=False, + ): + if sample_sets is None: + sample_sets = self.samples() + + # First try to convert to a 1D numpy array. If it is, then we strip off + # the corresponding dimension from the output.
+ drop_dimension = False + try: + sample_sets = np.array(sample_sets, dtype=np.uint64) + except ValueError: + pass + else: + # If we've successfully converted sample_sets to a 1D numpy array + # of integers then drop the dimension + if len(sample_sets.shape) == 1: + sample_sets = [sample_sets] + drop_dimension = True + + sample_set_sizes = np.array( + [len(sample_set) for sample_set in sample_sets], dtype=np.uint32 + ) + if np.any(sample_set_sizes == 0): + raise ValueError("Sample sets must contain at least one element") + + flattened = util.safe_np_int_cast(np.hstack(sample_sets), np.int32) + stat = self.__run_windowed_stat_tw( + windows, + time_windows, + ll_method, + sample_set_sizes, + flattened, + mode=mode, + span_normalise=span_normalise, + polarised=polarised, + ) + if drop_dimension: + stat = stat.reshape(stat.shape[:-1]) + # TODO: Write test for this + if stat.shape == () and windows is None and time_windows is None: + stat = stat[()] + return stat + def parse_sites(self, sites): row_sites, col_sites = None, None if sites is not None: @@ -8922,6 +8996,7 @@ def allele_frequency_spectrum( self, sample_sets=None, windows=None, + time_windows=None, mode="site", span_normalise=True, polarised=False, @@ -9016,10 +9091,11 @@ def allele_frequency_spectrum( """ if sample_sets is None: sample_sets = [self.samples()] - return self.__one_way_sample_set_stat( + return self.__one_way_sample_set_stat_tw( self._ll_tree_sequence.allele_frequency_spectrum, sample_sets, windows=windows, + time_windows=time_windows, mode=mode, span_normalise=span_normalise, polarised=polarised,