1
1
import numpy as np
2
+ import numpy .testing as nt
2
3
import pytest
3
4
import tsinfer
4
5
import tskit
8
9
import util
9
10
10
11
12
class TestSolveNumMismatches:
    """Pins ``sc2ts.solve_num_mismatches`` to fixed reference values."""

    @pytest.mark.parametrize(
        ("k", "expected_rho"),
        [
            (2, 0.02295918),
            (3, 0.00327988),
            (4, 0.00046855),
            (1000, 0),
        ],
    )
    def test_examples(self, k, expected_rho):
        """Check the first mu and rho returned for ``k`` with ``num_sites=2``.

        The first mu entry must equal 0.125 exactly; the first rho entry is
        compared to ``expected_rho`` with ``numpy.testing.assert_almost_equal``.
        """
        solved = sc2ts.solve_num_mismatches(k, num_sites=2)
        rho, mu = solved
        first_mu = mu[0]
        assert first_mu == 0.125
        first_rho = rho[0]
        nt.assert_almost_equal(first_rho, expected_rho)
22
+
23
+
11
24
class TestInitialTs :
12
25
def test_reference_sequence (self ):
13
26
ts = sc2ts .initial_ts ()
@@ -612,13 +625,13 @@ def test_node_mutation_counts(self, fx_ts_map, date):
612
625
"2020-02-03" : {"nodes" : 36 , "mutations" : 42 },
613
626
"2020-02-04" : {"nodes" : 41 , "mutations" : 48 },
614
627
"2020-02-05" : {"nodes" : 42 , "mutations" : 48 },
615
- "2020-02-06" : {"nodes" : 49 , "mutations" : 51 },
616
- "2020-02-07" : {"nodes" : 51 , "mutations" : 57 },
617
- "2020-02-08" : {"nodes" : 57 , "mutations" : 58 },
618
- "2020-02-09" : {"nodes" : 59 , "mutations" : 61 },
619
- "2020-02-10" : {"nodes" : 60 , "mutations" : 65 },
620
- "2020-02-11" : {"nodes" : 62 , "mutations" : 66 },
621
- "2020-02-13" : {"nodes" : 66 , "mutations" : 68 },
628
+ "2020-02-06" : {"nodes" : 48 , "mutations" : 51 },
629
+ "2020-02-07" : {"nodes" : 50 , "mutations" : 57 },
630
+ "2020-02-08" : {"nodes" : 56 , "mutations" : 58 },
631
+ "2020-02-09" : {"nodes" : 58 , "mutations" : 61 },
632
+ "2020-02-10" : {"nodes" : 59 , "mutations" : 65 },
633
+ "2020-02-11" : {"nodes" : 61 , "mutations" : 66 },
634
+ "2020-02-13" : {"nodes" : 65 , "mutations" : 68 },
622
635
}
623
636
assert ts .num_nodes == expected [date ]["nodes" ]
624
637
assert ts .num_mutations == expected [date ]["mutations" ]
@@ -631,9 +644,9 @@ def test_node_mutation_counts(self, fx_ts_map, date):
631
644
(13 , "SRR11597132" , 10 ),
632
645
(16 , "SRR11597177" , 10 ),
633
646
(41 , "SRR11597156" , 10 ),
634
- (57 , "SRR11597216" , 1 ),
635
- (60 , "SRR11597207" , 40 ),
636
- (62 , "ERR4205570" , 58 ),
647
+ (56 , "SRR11597216" , 1 ),
648
+ (59 , "SRR11597207" , 40 ),
649
+ (61 , "ERR4205570" , 57 ),
637
650
],
638
651
)
639
652
def test_exact_matches (self , fx_ts_map , node , strain , parent ):
@@ -693,10 +706,9 @@ class TestMatchingDetails:
693
706
# assert s.path[0].parent == 37
694
707
695
708
@pytest .mark .parametrize (
696
- ("strain" , "parent" ), [("SRR11597207" , 40 ), ("ERR4205570" , 58 )]
709
+ ("strain" , "parent" ), [("SRR11597207" , 40 ), ("ERR4205570" , 57 )]
697
710
)
698
711
@pytest .mark .parametrize ("num_mismatches" , [2 , 3 , 4 ])
699
- @pytest .mark .parametrize ("precision" , [0 , 1 , 2 , 12 ])
700
712
def test_exact_matches (
701
713
self ,
702
714
fx_ts_map ,
@@ -705,17 +717,18 @@ def test_exact_matches(
705
717
strain ,
706
718
parent ,
707
719
num_mismatches ,
708
- precision ,
709
720
):
710
721
ts = fx_ts_map ["2020-02-10" ]
711
722
samples = sc2ts .preprocess (
712
723
[fx_metadata_db [strain ]], ts , "2020-02-20" , fx_alignment_store
713
724
)
725
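+ # FIXME: mu is hard-coded here; presumably it should be obtained from sc2ts.solve_num_mismatches (TestSolveNumMismatches expects mu[0] == 0.125) - confirm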
726
+ mu = 0.125
714
727
sc2ts .match_tsinfer (
715
728
samples = samples ,
716
729
ts = ts ,
717
730
num_mismatches = num_mismatches ,
718
- precision = precision ,
731
+ likelihood_threshold = mu ** num_mismatches - 1e-12 ,
719
732
num_threads = 0 ,
720
733
)
721
734
s = samples [0 ]
@@ -725,10 +738,10 @@ def test_exact_matches(
725
738
726
739
@pytest .mark .parametrize (
727
740
("strain" , "parent" , "position" , "derived_state" ),
728
- [("SRR11597218" , 10 , 289 , "T" ), ("ERR4206593" , 58 , 26994 , "T" )],
741
+ [("SRR11597218" , 10 , 289 , "T" ), ("ERR4206593" , 57 , 26994 , "T" )],
729
742
)
730
743
@pytest .mark .parametrize ("num_mismatches" , [2 , 3 , 4 ])
731
- @pytest .mark .parametrize ("precision" , [0 , 1 , 2 , 12 ])
744
+ # @pytest.mark.parametrize("precision", [0, 1, 2, 12])
732
745
def test_one_mismatch (
733
746
self ,
734
747
fx_ts_map ,
@@ -739,7 +752,6 @@ def test_one_mismatch(
739
752
position ,
740
753
derived_state ,
741
754
num_mismatches ,
742
- precision ,
743
755
):
744
756
ts = fx_ts_map ["2020-02-10" ]
745
757
samples = sc2ts .preprocess (
@@ -749,7 +761,8 @@ def test_one_mismatch(
749
761
samples = samples ,
750
762
ts = ts ,
751
763
num_mismatches = num_mismatches ,
752
- precision = precision ,
764
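+ # FIXME: hard-coded threshold just below 0.125; presumably it should be derived from mu like the `mu ** k - 1e-12` thresholds in the sibling tests - confirm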
765
+ likelihood_threshold = 0.12499999 ,
753
766
num_threads = 0 ,
754
767
)
755
768
s = samples [0 ]
@@ -760,30 +773,27 @@ def test_one_mismatch(
760
773
assert s .path [0 ].parent == parent
761
774
762
775
@pytest.mark.parametrize("num_mismatches", [2, 3, 4])
def test_two_mismatches(
    self,
    fx_ts_map,
    fx_alignment_store,
    fx_metadata_db,
    num_mismatches,
):
    """Match strain ERR4204459 against the 2020-02-10 tree sequence.

    The likelihood threshold is set just under ``mu ** 2``. For every
    parametrized ``num_mismatches`` the sample is expected to have a single
    path segment whose parent is node 5, and exactly two mutations.
    """
    strain_name = "ERR4204459"
    base_ts = fx_ts_map["2020-02-10"]
    samples = sc2ts.preprocess(
        [fx_metadata_db[strain_name]],
        base_ts,
        "2020-02-20",
        fx_alignment_store,
    )
    mu = 0.125
    # Subtract a tiny epsilon so the threshold sits strictly below mu ** 2.
    threshold = mu**2 - 1e-12
    sc2ts.match_tsinfer(
        samples=samples,
        ts=base_ts,
        num_mismatches=num_mismatches,
        likelihood_threshold=threshold,
        num_threads=0,
    )
    matched = samples[0]
    assert len(matched.path) == 1
    assert matched.path[0].parent == 5
    assert len(matched.mutations) == 2
788
- # assert s.mutations[0].site_position == position
789
- # assert s.mutations[0].derived_state == derived_state
0 commit comments