@@ -5584,3 +5584,87 @@ def test_mixed_sample_status(self):
5584
5584
expected = np .array ([[0 , 1 ]])
5585
5585
assert result .shape == (1 , 2 )
5586
5586
assert_array_equal (result , expected )
5587
+
5588
+
5589
+ class TestSampleNodesByPloidy :
5590
+ @pytest .mark .parametrize (
5591
+ "n_samples,ploidy,expected" ,
5592
+ [
5593
+ (6 , 2 , np .array ([[0 , 1 ], [2 , 3 ], [4 , 5 ]])), # Basic diploid
5594
+ (9 , 3 , np .array ([[0 , 1 , 2 ], [3 , 4 , 5 ], [6 , 7 , 8 ]])), # Triploid
5595
+ (5 , 1 , np .array ([[0 ], [1 ], [2 ], [3 ], [4 ]])), # Ploidy of 1
5596
+ (4 , 4 , np .array ([[0 , 1 , 2 , 3 ]])), # Ploidy equals number of samples
5597
+ ],
5598
+ )
5599
+ def test_various_ploidy_scenarios (self , n_samples , ploidy , expected ):
5600
+ tables = tskit .TableCollection (sequence_length = 100 )
5601
+ for _ in range (n_samples ):
5602
+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
5603
+ ts = tables .tree_sequence ()
5604
+
5605
+ result = ts .sample_nodes_by_ploidy (ploidy )
5606
+ expected_shape = (n_samples // ploidy , ploidy )
5607
+ assert result .shape == expected_shape
5608
+ assert_array_equal (result , expected )
5609
+
5610
+ def test_mixed_sample_status (self ):
5611
+ tables = tskit .TableCollection (sequence_length = 100 )
5612
+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
5613
+ tables .nodes .add_row (flags = 0 , time = 0 )
5614
+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
5615
+ tables .nodes .add_row (flags = 0 , time = 0 )
5616
+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
5617
+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
5618
+ ts = tables .tree_sequence ()
5619
+
5620
+ result = ts .sample_nodes_by_ploidy (2 )
5621
+ assert result .shape == (2 , 2 )
5622
+ expected = np .array ([[0 , 2 ], [4 , 5 ]])
5623
+ assert_array_equal (result , expected )
5624
+
5625
+ def test_no_sample_nodes (self ):
5626
+ tables = tskit .TableCollection (sequence_length = 100 )
5627
+ tables .nodes .add_row (flags = 0 , time = 0 )
5628
+ tables .nodes .add_row (flags = 0 , time = 0 )
5629
+ ts = tables .tree_sequence ()
5630
+
5631
+ with pytest .raises (ValueError , match = "No sample nodes in tree sequence" ):
5632
+ ts .sample_nodes_by_ploidy (2 )
5633
+
5634
+ def test_not_multiple_of_ploidy (self ):
5635
+ tables = tskit .TableCollection (sequence_length = 100 )
5636
+ for _ in range (5 ):
5637
+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
5638
+ ts = tables .tree_sequence ()
5639
+
5640
+ with pytest .raises (ValueError , match = "not a multiple of ploidy" ):
5641
+ ts .sample_nodes_by_ploidy (2 )
5642
+
5643
+ def test_with_existing_individuals (self ):
5644
+ tables = tskit .TableCollection (sequence_length = 100 )
5645
+ tables .individuals .add_row (flags = 0 , location = (0 , 0 ), metadata = b"" )
5646
+ tables .individuals .add_row (flags = 0 , location = (0 , 0 ), metadata = b"" )
5647
+ # Add nodes with individual references but in a different order
5648
+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 , individual = 1 )
5649
+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 , individual = 0 )
5650
+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 , individual = 1 )
5651
+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 , individual = 0 )
5652
+
5653
+ ts = tables .tree_sequence ()
5654
+ result = ts .sample_nodes_by_ploidy (2 )
5655
+ expected = np .array ([[0 , 1 ], [2 , 3 ]])
5656
+ assert_array_equal (result , expected )
5657
+ ind_nodes = ts .individuals_nodes
5658
+ assert not np .array_equal (result , ind_nodes )
5659
+
5660
+ def test_different_node_flags (self ):
5661
+ tables = tskit .TableCollection (sequence_length = 100 )
5662
+ OTHER_FLAG1 = 1 << 1
5663
+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
5664
+ tables .nodes .add_row (flags = OTHER_FLAG1 , time = 0 )
5665
+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE | OTHER_FLAG1 , time = 0 )
5666
+ tables .nodes .add_row ()
5667
+ ts = tables .tree_sequence ()
5668
+ result = ts .sample_nodes_by_ploidy (2 )
5669
+ assert result .shape == (1 , 2 )
5670
+ assert_array_equal (result , np .array ([[0 , 2 ]]))
0 commit comments