Skip to content

Commit 7a3b625

Browse files
Experiment with alternative HMM implementation
Requires tskit-dev/tsinfer#959
1 parent dffa62c commit 7a3b625

File tree

3 files changed

+56
-55
lines changed

3 files changed

+56
-55
lines changed

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ authors = [
88
]
99
requires-python = ">=3.9"
1010
dependencies = [
11-
"tsinfer==0.3.3", # https://github.com/jeromekelleher/sc2ts/issues/201
11+
# "tsinfer==0.3.3", # https://github.com/jeromekelleher/sc2ts/issues/201
12+
# FIXME
13+
"tsinfer @ git+https://github.com/jeromekelleher/tsinfer.git@experimental-hmm",
1214
"pyfaidx",
1315
"tskit>=0.5.3",
1416
"tszip",

sc2ts/inference.py

Lines changed: 20 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -402,46 +402,45 @@ def match_samples(
402402
show_progress=False,
403403
num_threads=None,
404404
):
405-
# First pass, compute the matches at precision=0.
406405
run_batch = samples
407406

408-
# Values based on https://github.com/jeromekelleher/sc2ts/issues/242,
409-
# but somewhat arbitrary.
410-
for precision, cost_threshold in [(0, 1), (1, 2), (2, 3)]:
411-
logger.info(f"Running batch of {len(run_batch)} at p={precision}")
407+
mu = 0.125 ## FIXME
408+
for k in range(num_mismatches):
409+
# To catch k mismatches we need a likelihood threshold of mu**k
410+
likelihood_threshold = mu**k - 1e-15
411+
logger.info(f"Running match={k} batch of {len(run_batch)} at threshold={likelihood_threshold}")
412412
match_tsinfer(
413413
samples=run_batch,
414414
ts=base_ts,
415415
num_mismatches=num_mismatches,
416-
precision=precision,
416+
likelihood_threshold=likelihood_threshold,
417417
num_threads=num_threads,
418418
show_progress=show_progress,
419419
)
420420

421421
exceeding_threshold = []
422422
for sample in run_batch:
423423
cost = sample.get_hmm_cost(num_mismatches)
424-
logger.debug(f"HMM@p={precision}: hmm_cost={cost} {sample.summary()}")
425-
if cost > cost_threshold:
424+
logger.debug(f"HMM@k={k}: hmm_cost={cost} {sample.summary()}")
425+
if cost > k + 1:
426426
sample.path.clear()
427427
sample.mutations.clear()
428428
exceeding_threshold.append(sample)
429429

430430
num_matches_found = len(run_batch) - len(exceeding_threshold)
431431
logger.info(
432-
f"{num_matches_found} final matches for found p={precision}; "
432+
f"{num_matches_found} final matches found at k={k}; "
433433
f"{len(exceeding_threshold)} remain"
434434
)
435435
run_batch = exceeding_threshold
436436

437-
precision = 6
438-
logger.info(f"Running final batch of {len(run_batch)} at p={precision}")
437+
logger.info(f"Running final batch of {len(run_batch)} at full precision")
439438
match_tsinfer(
440439
samples=run_batch,
441440
ts=base_ts,
442441
num_mismatches=num_mismatches,
443-
precision=precision,
444442
num_threads=num_threads,
443+
likelihood_threshold=1e-200,
445444
show_progress=show_progress,
446445
)
447446
for sample in run_batch:
@@ -798,36 +797,26 @@ def add_matching_results(
798797
return ts # , excluded_samples, added_samples
799798

800799

801-
def solve_num_mismatches(ts, k):
800+
def solve_num_mismatches(k, num_sites, mu=0.125):
802801
"""
803802
Return the low-level LS parameters corresponding to accepting
804803
k mismatches in favour of a single recombination.
805804
806805
NOTE! This is NOT taking into account the spatial distance along
807806
the genome, and so is not a very good model in some ways.
808807
"""
809-
# We can match against any node in tsinfer
810-
m = ts.num_sites
811-
n = ts.num_nodes
812808
# values of k <= 1 are not relevant for SC2 and lead to awkward corner cases
813809
assert k > 1
814810

815-
# NOTE: the magnitude of mu matters because it puts a limit
816-
# on how low we can push the HMM precision. We should be able to solve
817-
# for the optimal value of this parameter such that the magnitude of the
818-
# values within the HMM are as large as possible (so that we can truncate
819-
# usefully).
820-
# mu = 1e-2
821-
mu = 0.125
822-
denom = (1 - mu) ** k + (n - 1) * mu**k
823-
r = n * mu**k / denom
811+
denom = (1 - mu) ** k
812+
r = mu**k / denom
824813

825814
# Add a little bit of extra mass for recombination so that we deterministically
826815
# choose to recombine over k mutations
827816
# NOTE: the magnitude of this value will depend also on mu, see above.
828-
r += r * 0.01
829-
ls_recomb = np.full(m - 1, r)
830-
ls_mismatch = np.full(m, mu)
817+
r += r * 0.125
818+
ls_recomb = np.full(num_sites - 1, r)
819+
ls_mismatch = np.full(num_sites, mu)
831820
return ls_recomb, ls_mismatch
832821

833822

@@ -1268,7 +1257,7 @@ def match_tsinfer(
12681257
ts,
12691258
*,
12701259
num_mismatches,
1271-
precision=None,
1260+
likelihood_threshold=None,
12721261
num_threads=0,
12731262
show_progress=False,
12741263
mirror_coordinates=False,
@@ -1284,7 +1273,7 @@ def match_tsinfer(
12841273
sd = convert_tsinfer_sample_data(ts, genotypes)
12851274

12861275
L = int(ts.sequence_length)
1287-
ls_recomb, ls_mismatch = solve_num_mismatches(ts, num_mismatches)
1276+
ls_recomb, ls_mismatch = solve_num_mismatches(num_mismatches, ts.num_sites)
12881277
pm = tsinfer.inference._get_progress_monitor(
12891278
show_progress,
12901279
generate_ancestors=False,
@@ -1309,7 +1298,7 @@ def match_tsinfer(
13091298
mismatch=ls_mismatch,
13101299
progress_monitor=pm,
13111300
num_threads=num_threads,
1312-
precision=precision,
1301+
likelihood_threshold=likelihood_threshold
13131302
)
13141303
results = manager.run_match(np.arange(sd.num_samples))
13151304

tests/test_inference.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import numpy.testing as nt
23
import pytest
34
import tsinfer
45
import tskit
@@ -8,6 +9,18 @@
89
import util
910

1011

12+
class TestSolveNumMismatches:
13+
14+
@pytest.mark.parametrize(
15+
["k", "expected_rho"],
16+
[(2, 0.02295918), (3, 0.00327988), (4, 0.00046855), (1000, 0)],
17+
)
18+
def test_examples(self, k, expected_rho):
19+
rho, mu = sc2ts.solve_num_mismatches(k, num_sites=2)
20+
assert mu[0] == 0.125
21+
nt.assert_almost_equal(rho[0], expected_rho)
22+
23+
1124
class TestInitialTs:
1225
def test_reference_sequence(self):
1326
ts = sc2ts.initial_ts()
@@ -612,13 +625,13 @@ def test_node_mutation_counts(self, fx_ts_map, date):
612625
"2020-02-03": {"nodes": 36, "mutations": 42},
613626
"2020-02-04": {"nodes": 41, "mutations": 48},
614627
"2020-02-05": {"nodes": 42, "mutations": 48},
615-
"2020-02-06": {"nodes": 49, "mutations": 51},
616-
"2020-02-07": {"nodes": 51, "mutations": 57},
617-
"2020-02-08": {"nodes": 57, "mutations": 58},
618-
"2020-02-09": {"nodes": 59, "mutations": 61},
619-
"2020-02-10": {"nodes": 60, "mutations": 65},
620-
"2020-02-11": {"nodes": 62, "mutations": 66},
621-
"2020-02-13": {"nodes": 66, "mutations": 68},
628+
"2020-02-06": {"nodes": 48, "mutations": 51},
629+
"2020-02-07": {"nodes": 50, "mutations": 57},
630+
"2020-02-08": {"nodes": 56, "mutations": 58},
631+
"2020-02-09": {"nodes": 58, "mutations": 61},
632+
"2020-02-10": {"nodes": 59, "mutations": 65},
633+
"2020-02-11": {"nodes": 61, "mutations": 66},
634+
"2020-02-13": {"nodes": 65, "mutations": 68},
622635
}
623636
assert ts.num_nodes == expected[date]["nodes"]
624637
assert ts.num_mutations == expected[date]["mutations"]
@@ -631,9 +644,9 @@ def test_node_mutation_counts(self, fx_ts_map, date):
631644
(13, "SRR11597132", 10),
632645
(16, "SRR11597177", 10),
633646
(41, "SRR11597156", 10),
634-
(57, "SRR11597216", 1),
635-
(60, "SRR11597207", 40),
636-
(62, "ERR4205570", 58),
647+
(56, "SRR11597216", 1),
648+
(59, "SRR11597207", 40),
649+
(61, "ERR4205570", 57),
637650
],
638651
)
639652
def test_exact_matches(self, fx_ts_map, node, strain, parent):
@@ -693,10 +706,9 @@ class TestMatchingDetails:
693706
# assert s.path[0].parent == 37
694707

695708
@pytest.mark.parametrize(
696-
("strain", "parent"), [("SRR11597207", 40), ("ERR4205570", 58)]
709+
("strain", "parent"), [("SRR11597207", 40), ("ERR4205570", 57)]
697710
)
698711
@pytest.mark.parametrize("num_mismatches", [2, 3, 4])
699-
@pytest.mark.parametrize("precision", [0, 1, 2, 12])
700712
def test_exact_matches(
701713
self,
702714
fx_ts_map,
@@ -705,17 +717,18 @@ def test_exact_matches(
705717
strain,
706718
parent,
707719
num_mismatches,
708-
precision,
709720
):
710721
ts = fx_ts_map["2020-02-10"]
711722
samples = sc2ts.preprocess(
712723
[fx_metadata_db[strain]], ts, "2020-02-20", fx_alignment_store
713724
)
725+
# FIXME
726+
mu = 0.125
714727
sc2ts.match_tsinfer(
715728
samples=samples,
716729
ts=ts,
717730
num_mismatches=num_mismatches,
718-
precision=precision,
731+
likelihood_threshold = mu**num_mismatches - 1e-12,
719732
num_threads=0,
720733
)
721734
s = samples[0]
@@ -725,10 +738,10 @@ def test_exact_matches(
725738

726739
@pytest.mark.parametrize(
727740
("strain", "parent", "position", "derived_state"),
728-
[("SRR11597218", 10, 289, "T"), ("ERR4206593", 58, 26994, "T")],
741+
[("SRR11597218", 10, 289, "T"), ("ERR4206593", 57, 26994, "T")],
729742
)
730743
@pytest.mark.parametrize("num_mismatches", [2, 3, 4])
731-
@pytest.mark.parametrize("precision", [0, 1, 2, 12])
744+
# @pytest.mark.parametrize("precision", [0, 1, 2, 12])
732745
def test_one_mismatch(
733746
self,
734747
fx_ts_map,
@@ -739,7 +752,6 @@ def test_one_mismatch(
739752
position,
740753
derived_state,
741754
num_mismatches,
742-
precision,
743755
):
744756
ts = fx_ts_map["2020-02-10"]
745757
samples = sc2ts.preprocess(
@@ -749,7 +761,8 @@ def test_one_mismatch(
749761
samples=samples,
750762
ts=ts,
751763
num_mismatches=num_mismatches,
752-
precision=precision,
764+
# FIXME
765+
likelihood_threshold=0.12499999,
753766
num_threads=0,
754767
)
755768
s = samples[0]
@@ -760,30 +773,27 @@ def test_one_mismatch(
760773
assert s.path[0].parent == parent
761774

762775
@pytest.mark.parametrize("num_mismatches", [2, 3, 4])
763-
@pytest.mark.parametrize("precision", [0, 1, 2, 12])
764776
def test_two_mismatches(
765777
self,
766778
fx_ts_map,
767779
fx_alignment_store,
768780
fx_metadata_db,
769781
num_mismatches,
770-
precision,
771782
):
772783
strain = "ERR4204459"
773784
ts = fx_ts_map["2020-02-10"]
774785
samples = sc2ts.preprocess(
775786
[fx_metadata_db[strain]], ts, "2020-02-20", fx_alignment_store
776787
)
788+
mu = 0.125
777789
sc2ts.match_tsinfer(
778790
samples=samples,
779791
ts=ts,
780792
num_mismatches=num_mismatches,
781-
precision=precision,
793+
likelihood_threshold=mu**2 - 1e-12,
782794
num_threads=0,
783795
)
784796
s = samples[0]
785797
assert len(s.path) == 1
786798
assert s.path[0].parent == 5
787799
assert len(s.mutations) == 2
788-
# assert s.mutations[0].site_position == position
789-
# assert s.mutations[0].derived_state == derived_state

0 commit comments

Comments
 (0)