Skip to content

Commit c689d52

Browse files
committed
WIP
1 parent 589e2d9 commit c689d52

File tree

1 file changed

+28
-8
lines changed

1 file changed

+28
-8
lines changed

python/tests/test_beagle.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -398,29 +398,49 @@ def compute_state_probability_matrix_equation1(fm, bm, ref_h, query_h, rho, mu):
398398
return sm
399399

400400

401-
def interpolate_allele_probabilities_equation1(sm, ref_h, genotyped_pos, imputed_pos):
401+
def interpolate_allele_probabilities_equation1(
402+
sm, ref_h, query_h, genotyped_pos, imputed_pos
403+
):
402404
"""
403405
Compute the interpolated allele probabilities following Equation 1 of BB2016.
404406
407+
This function takes the output of `compute_state_probability_matrix_equation1`.
408+
405409
Assume all biallelic sites.
406410
407411
:param numpy.ndarray sm: HMM state probability matrix.
408412
:param numpy.ndarray ref_h: Reference haplotypes.
413+
:param numpy.ndarray query_h: One query haplotype.
409414
:param numpy.ndarray genotyped_pos: Site positions at genotyped markers.
410415
:param numpy.ndarray imputed_pos: Site positions at imputed markers.
411416
:return: Interpolated allele probabilities.
412417
:rtype: numpy.ndarray
413418
"""
414419
h = ref_h.shape[1]
415-
# m = len(genotyped_pos)
420+
m = len(genotyped_pos)
416421
x = len(imputed_pos)
417-
p = np.array((x, 2), dtype=np.float64)
422+
assert sm.shape == (m, h)
423+
assert ref_h.shape == (m + x, h)
424+
assert len(query_h) == m + x
418425
weights = get_weights(genotyped_pos, imputed_pos)
419-
# Compute probabilities of allele a at imputed markers
420-
a = 0
421-
for i in np.arange(x):
422-
for j in np.arange(h):
423-
p[i, a] = weights[i] * sm[i, j] + (1 - weights[i]) * sm[i + 1, j]
426+
assert len(weights) == x
427+
p = np.zeros((x, 2), dtype=np.float64)
428+
429+
def _compute_allele_probabilities(a):
    """
    Accumulate into `p[:, a]` the probability of allele `a` at imputed markers.

    Sites where `query_h` is -1 are treated as imputed markers; every other
    site is treated as a genotyped marker. For each imputed marker, the state
    probabilities of the two flanking genotyped markers are combined using
    the interpolation weight of that imputed marker, and summed over the
    reference haplotypes carrying allele `a` at that site.

    The enclosing function asserts that `sm` has one row per genotyped
    marker (shape `(m, h)`) and that `weights` has one entry per imputed
    marker (length `x`), so neither may be indexed by the combined site
    index `i`, which runs over all `m + x` sites.

    :param int a: Allele (0 or 1) whose probability is accumulated.
    """
    k = 0  # Index among imputed markers: row of `p`, entry of `weights`.
    g = 0  # Number of genotyped markers passed so far: locates rows of `sm`.
    for i in range(m + x):
        if query_h[i] != -1:
            g += 1
            continue
        # Rows of `sm` for the genotyped markers flanking imputed site `i`.
        # Clamp at both ends so an imputed site lying before the first or
        # after the last genotyped marker uses the nearest marker twice
        # instead of wrapping around (index -1) or running off the end.
        left = max(g - 1, 0)
        right = min(g, m - 1)
        # NOTE(review): assumes `weights[k]` is the weight of the left
        # flanking marker, as in the previous `w * sm[i] + (1 - w) * sm[i + 1]`
        # expression; confirm against `get_weights`.
        w = weights[k]
        for j in range(h):
            if ref_h[i, j] == a:
                p[k, a] += w * sm[left, j]
                p[k, a] += (1 - w) * sm[right, j]
        k += 1
441+
442+
_compute_allele_probabilities(a=0)
443+
_compute_allele_probabilities(a=1)
424444
return p
425445

426446

0 commit comments

Comments
 (0)