diff --git a/src/error_align/path_to_alignment.py b/src/error_align/path_to_alignment.py index 97c6796..c53c8be 100644 --- a/src/error_align/path_to_alignment.py +++ b/src/error_align/path_to_alignment.py @@ -37,6 +37,8 @@ def get_insert_alignment( op_type=OpType.INSERT, hyp_slice=hyp_slice, hyp=subgraph_metadata.hyp_raw[hyp_slice], + left_compound=subgraph_metadata.hyp_idx_map[start_hyp_idx] >= 0, + right_compound=subgraph_metadata.hyp_idx_map[end_hyp_idx - 1] >= 0, ) diff --git a/tests/test_default.py b/tests/test_default.py index b728a20..19e8239 100644 --- a/tests/test_default.py +++ b/tests/test_default.py @@ -71,6 +71,23 @@ def test_error_align_identical() -> None: assert alignment.op_type == OpType.MATCH +def test_partial_substitution_and_insertion() -> None: + """Test error alignment for partial substitutions and insertions with compound markers.""" + + ref = "test" + hyp = "testpartial" + + alignments = error_align(ref, hyp) + + assert len(alignments) == 2 + assert alignments[0].op_type == OpType.SUBSTITUTE + assert alignments[0].left_compound is False + assert alignments[0].right_compound is True + assert alignments[1].op_type == OpType.INSERT + assert alignments[1].left_compound is True + assert alignments[1].right_compound is False + + def test_categorize_char() -> None: """Test character categorization."""