diff --git a/pori_python/graphkb/match.py b/pori_python/graphkb/match.py index c646f49..f735408 100644 --- a/pori_python/graphkb/match.py +++ b/pori_python/graphkb/match.py @@ -31,7 +31,12 @@ looks_like_rid, stringifyVariant, ) -from .vocab import get_equivalent_terms, get_term_tree, get_terms_set +from .vocab import ( + get_equivalent_terms, + get_term_by_name, + get_term_tree, + get_terms_set, +) FEATURES_CACHE: Set[str] = set() @@ -275,7 +280,55 @@ def positions_overlap( return start is None or pos == start +def equivalent_types( + conn: GraphKBConnection, + type1: str, + type2: str, + strict: bool = False, +) -> bool: + """ + Compare 2 variant types to determine if they should match + + Args: + conn: the graphkb connection object + type1: type from the observed variant we want to match to the DB + type2: type from the DB variant + strict: whether or not only the specific-to-generic ones are considered. + By default (false), not only specific types can match more generic ones, + but generic types can also match more specific ones. 
+ + Returns: + bool: True if the types can be matched + """ + + # Convert rid to displayName if needed + if looks_like_rid(type1): + type1 = conn.get_records_by_id([type1])[0]['displayName'] + if looks_like_rid(type2): + type2 = conn.get_records_by_id([type2])[0]['displayName'] + + # Get type terms from observed variant + terms1 = [] + if strict: + try: + terms1.append(get_term_by_name(conn, type1)['@rid']) + except: + pass + else: + terms1 = get_terms_set(conn, [type1]) + + # Get type terms from DB variant + terms2 = get_terms_set(conn, [type2]) + + # Check for intersect + if len(terms2.intersection(terms1)) == 0: + return False + + return True + + def compare_positional_variants( + conn: GraphKBConnection, variant: Union[PositionalVariant, ParsedVariant], reference_variant: Union[PositionalVariant, ParsedVariant], generic: bool = True, @@ -378,6 +431,11 @@ def compare_positional_variants( elif len(variant["refSeq"]) != len(reference_variant["refSeq"]): # type: ignore return False + # Equivalent types + if variant.get('type') and reference_variant.get('type'): + if not equivalent_types(conn, variant["type"], reference_variant["type"]): + return False + return True @@ -598,10 +656,14 @@ def match_positional_variant( ): # TODO: Check if variant and reference_variant should be interchanged if compare_positional_variants( - variant=parsed, reference_variant=cast(PositionalVariant, row), generic=True + conn, + variant=parsed, + reference_variant=cast(PositionalVariant, row), + generic=True, ): filtered_similarAndGeneric.append(row) if compare_positional_variants( + conn, variant=parsed, reference_variant=cast(PositionalVariant, row), generic=False, # Similar variants only diff --git a/tests/test_graphkb/test_match.py b/tests/test_graphkb/test_match.py index 9ab4515..42ff7f6 100644 --- a/tests/test_graphkb/test_match.py +++ b/tests/test_graphkb/test_match.py @@ -269,14 +269,14 @@ def test_known_increased_expression(self, conn): class TestComparePositionalVariants: def 
test_nonspecific_altseq(self): assert match.compare_positional_variants( - {"break1Start": {"pos": 1}}, {"break1Start": {"pos": 1}} + conn, {"break1Start": {"pos": 1}}, {"break1Start": {"pos": 1}} ) # null matches anything assert match.compare_positional_variants( - {"break1Start": {"pos": 1}, "untemplatedSeq": "T"}, {"break1Start": {"pos": 1}} + conn, {"break1Start": {"pos": 1}, "untemplatedSeq": "T"}, {"break1Start": {"pos": 1}} ) assert match.compare_positional_variants( - {"break1Start": {"pos": 1}}, {"break1Start": {"pos": 1}, "untemplatedSeq": "T"} + conn, {"break1Start": {"pos": 1}}, {"break1Start": {"pos": 1}, "untemplatedSeq": "T"} ) @pytest.mark.parametrize("seq1", ["T", "X", "?"]) @@ -284,16 +284,19 @@ def test_nonspecific_altseq(self): def test_ambiguous_altseq(self, seq1, seq2): # ambiguous AA matches anything the same length assert match.compare_positional_variants( + conn, {"break1Start": {"pos": 1}, "untemplatedSeq": seq1}, {"break1Start": {"pos": 1}, "untemplatedSeq": seq2}, ) def test_altseq_length_mismatch(self): assert not match.compare_positional_variants( + conn, {"break1Start": {"pos": 1}, "untemplatedSeq": "??"}, {"break1Start": {"pos": 1}, "untemplatedSeq": "T"}, ) assert not match.compare_positional_variants( + conn, {"break1Start": {"pos": 1}, "untemplatedSeq": "?"}, {"break1Start": {"pos": 1}, "untemplatedSeq": "TT"}, ) @@ -301,10 +304,10 @@ def test_altseq_length_mismatch(self): def test_nonspecific_refseq(self): # null matches anything assert match.compare_positional_variants( - {"break1Start": {"pos": 1}, "refSeq": "T"}, {"break1Start": {"pos": 1}} + conn, {"break1Start": {"pos": 1}, "refSeq": "T"}, {"break1Start": {"pos": 1}} ) assert match.compare_positional_variants( - {"break1Start": {"pos": 1}}, {"break1Start": {"pos": 1}, "refSeq": "T"} + conn, {"break1Start": {"pos": 1}}, {"break1Start": {"pos": 1}, "refSeq": "T"} ) @pytest.mark.parametrize("seq1", ["T", "X", "?"]) @@ -312,37 +315,49 @@ def test_nonspecific_refseq(self): def 
test_ambiguous_refseq(self, seq1, seq2): # ambiguous AA matches anything the same length assert match.compare_positional_variants( - {"break1Start": {"pos": 1}, "refSeq": seq1}, {"break1Start": {"pos": 1}, "refSeq": seq2} + conn, + {"break1Start": {"pos": 1}, "refSeq": seq1}, + {"break1Start": {"pos": 1}, "refSeq": seq2}, ) def test_refseq_length_mismatch(self): assert not match.compare_positional_variants( - {"break1Start": {"pos": 1}, "refSeq": "??"}, {"break1Start": {"pos": 1}, "refSeq": "T"} + conn, + {"break1Start": {"pos": 1}, "refSeq": "??"}, + {"break1Start": {"pos": 1}, "refSeq": "T"}, ) assert not match.compare_positional_variants( - {"break1Start": {"pos": 1}, "refSeq": "?"}, {"break1Start": {"pos": 1}, "refSeq": "TT"} + conn, + {"break1Start": {"pos": 1}, "refSeq": "?"}, + {"break1Start": {"pos": 1}, "refSeq": "TT"}, ) def test_diff_altseq(self): assert not match.compare_positional_variants( + conn, {"break1Start": {"pos": 1}, "untemplatedSeq": "M"}, {"break1Start": {"pos": 1}, "untemplatedSeq": "R"}, ) def test_same_altseq_matches(self): assert match.compare_positional_variants( + conn, {"break1Start": {"pos": 1}, "untemplatedSeq": "R"}, {"break1Start": {"pos": 1}, "untemplatedSeq": "R"}, ) def test_diff_refseq(self): assert not match.compare_positional_variants( - {"break1Start": {"pos": 1}, "refSeq": "M"}, {"break1Start": {"pos": 1}, "refSeq": "R"} + conn, + {"break1Start": {"pos": 1}, "refSeq": "M"}, + {"break1Start": {"pos": 1}, "refSeq": "R"}, ) def test_same_refseq_matches(self): assert match.compare_positional_variants( - {"break1Start": {"pos": 1}, "refSeq": "R"}, {"break1Start": {"pos": 1}, "refSeq": "R"} + conn, + {"break1Start": {"pos": 1}, "refSeq": "R"}, + {"break1Start": {"pos": 1}, "refSeq": "R"}, ) def test_range_vs_sub(self): @@ -364,8 +379,8 @@ def test_range_vs_sub(self): "refSeq": "G", "untemplatedSeq": "VV", } - assert not match.compare_positional_variants(sub, range_variant) - assert not 
match.compare_positional_variants(range_variant, sub) + assert not match.compare_positional_variants(conn, sub, range_variant) + assert not match.compare_positional_variants(conn, range_variant, sub) class TestMatchPositionalVariant: