Skip to content

Commit c914b8b

Browse files
benjeffery authored and jeromekelleher committed
Add sample_nodes_by_ploidy function
1 parent f255bf5 commit c914b8b

File tree

3 files changed

+113
-0
lines changed

3 files changed

+113
-0
lines changed

python/CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
**Features**
66

7+
- Add ``TreeSequence.sample_nodes_by_ploidy`` method to return the sample nodes
8+
in a tree sequence, grouped by a ploidy value.
9+
(:user:`benjeffery`, :pr:`3157`)
10+
711
- Add ``TreeSequence.individuals_nodes`` attribute to return the nodes
812
associated with each individual as a numpy array.
913
(:user:`benjeffery`, :pr:`3153`)

python/tests/test_highlevel.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5584,3 +5584,87 @@ def test_mixed_sample_status(self):
55845584
expected = np.array([[0, 1]])
55855585
assert result.shape == (1, 2)
55865586
assert_array_equal(result, expected)
5587+
5588+
5589+
class TestSampleNodesByPloidy:
    """
    Tests for ``TreeSequence.sample_nodes_by_ploidy``: grouping of sample
    nodes into rows of length ``ploidy``, and the errors raised when the
    samples cannot be grouped.
    """

    @staticmethod
    def _ts_from_flags(flags_per_node):
        # Build an edgeless tree sequence with one node (at time 0) per
        # entry of ``flags_per_node``, in the order given.
        tables = tskit.TableCollection(sequence_length=100)
        for node_flags in flags_per_node:
            tables.nodes.add_row(flags=node_flags, time=0)
        return tables.tree_sequence()

    @pytest.mark.parametrize(
        "n_samples,ploidy,expected",
        [
            # Basic diploid
            (6, 2, np.array([[0, 1], [2, 3], [4, 5]])),
            # Triploid
            (9, 3, np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])),
            # Ploidy of 1
            (5, 1, np.array([[0], [1], [2], [3], [4]])),
            # Ploidy equals number of samples
            (4, 4, np.array([[0, 1, 2, 3]])),
        ],
    )
    def test_various_ploidy_scenarios(self, n_samples, ploidy, expected):
        ts = self._ts_from_flags([tskit.NODE_IS_SAMPLE] * n_samples)

        grouped = ts.sample_nodes_by_ploidy(ploidy)
        assert grouped.shape == (n_samples // ploidy, ploidy)
        assert_array_equal(grouped, expected)

    def test_mixed_sample_status(self):
        # Samples are nodes 0, 2, 4 and 5; nodes 1 and 3 are non-samples
        # and must be skipped when grouping.
        is_sample = tskit.NODE_IS_SAMPLE
        ts = self._ts_from_flags(
            [is_sample, 0, is_sample, 0, is_sample, is_sample]
        )

        grouped = ts.sample_nodes_by_ploidy(2)
        assert grouped.shape == (2, 2)
        assert_array_equal(grouped, np.array([[0, 2], [4, 5]]))

    def test_no_sample_nodes(self):
        ts = self._ts_from_flags([0, 0])

        with pytest.raises(ValueError, match="No sample nodes in tree sequence"):
            ts.sample_nodes_by_ploidy(2)

    def test_not_multiple_of_ploidy(self):
        # Five samples cannot be split into groups of two.
        ts = self._ts_from_flags([tskit.NODE_IS_SAMPLE] * 5)

        with pytest.raises(ValueError, match="not a multiple of ploidy"):
            ts.sample_nodes_by_ploidy(2)

    def test_with_existing_individuals(self):
        tables = tskit.TableCollection(sequence_length=100)
        for _ in range(2):
            tables.individuals.add_row(flags=0, location=(0, 0), metadata=b"")
        # Add nodes with individual references but in a different order
        for individual_id in (1, 0, 1, 0):
            tables.nodes.add_row(
                flags=tskit.NODE_IS_SAMPLE, time=0, individual=individual_id
            )

        ts = tables.tree_sequence()
        grouped = ts.sample_nodes_by_ploidy(2)
        # Grouping follows node table order, not the individual assignments.
        assert_array_equal(grouped, np.array([[0, 1], [2, 3]]))
        assert not np.array_equal(grouped, ts.individuals_nodes)

    def test_different_node_flags(self):
        OTHER_FLAG1 = 1 << 1
        tables = tskit.TableCollection(sequence_length=100)
        for node_flags in (
            tskit.NODE_IS_SAMPLE,
            OTHER_FLAG1,
            tskit.NODE_IS_SAMPLE | OTHER_FLAG1,
        ):
            tables.nodes.add_row(flags=node_flags, time=0)
        tables.nodes.add_row()
        ts = tables.tree_sequence()

        # Only the nodes carrying the sample bit (0 and 2) are returned,
        # whatever other flag bits are set.
        grouped = ts.sample_nodes_by_ploidy(2)
        assert grouped.shape == (1, 2)
        assert_array_equal(grouped, np.array([[0, 2]]))

python/tskit/trees.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10495,6 +10495,31 @@ def ld_matrix(
1049510495
mode=mode,
1049610496
)
1049710497

10498+
def sample_nodes_by_ploidy(self, ploidy):
    """
    Returns a 2D array of sample node IDs, where each row has length
    ``ploidy``. This is useful when individuals are not defined in the tree
    sequence so ``TreeSequence.individuals_nodes`` cannot be used. The
    samples are placed in the array in the order in which they are found in
    the node table. The number of sample nodes must be a multiple of ploidy.

    :param int ploidy: The number of samples per individual. Must be at
        least 1.
    :return: A 2D array of node IDs, where each row has length ``ploidy``
        and there is one row per group of ``ploidy`` consecutive samples.
    :rtype: numpy.ndarray
    :raises ValueError: If ``ploidy`` is less than 1, if there are no
        sample nodes in the tree sequence, or if the number of sample
        nodes is not a multiple of ``ploidy``.
    """
    # Validate ploidy before using it as a divisor: zero would otherwise
    # escape as a ZeroDivisionError from the modulo below, and a negative
    # value would pass negative dimensions to reshape.
    if ploidy < 1:
        raise ValueError(f"ploidy must be at least 1, got {ploidy}")
    # Indexes of the nodes whose flags have the sample bit set, in node
    # table order.
    sample_node_ids = np.flatnonzero(self.nodes_flags & tskit.NODE_IS_SAMPLE)
    num_samples = len(sample_node_ids)
    if num_samples == 0:
        raise ValueError("No sample nodes in tree sequence")
    if num_samples % ploidy != 0:
        raise ValueError(
            f"Number of sample nodes {num_samples} is not a multiple "
            f"of ploidy {ploidy}"
        )
    # One row per group of ``ploidy`` consecutive sample nodes.
    num_individuals = num_samples // ploidy
    return sample_node_ids.reshape((num_individuals, ploidy))
10522+
1049810523
############################################
1049910524
#
1050010525
# Deprecated APIs. These are either already unsupported, or will be unsupported in a

0 commit comments

Comments
 (0)