diff --git a/.coveragerc b/.coveragerc index 69e15dd..3f6a94f 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,6 +1,6 @@ [run] -source = src -omit = ./tests/* -concurrency = multiprocessing -parallel = true -sigterm = true +branch = True +source = error_align + +[report] +omit = src/error_align/baselines/* \ No newline at end of file diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 458d12a..19588ac 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,5 +1 @@ -* @corticph/machine-learning-r-d - -# No owners of dependencies to allow @github-actions[bot] to approve and auto-merge dependabot PRs. -pyproject.toml -poetry.lock +* @borgholt diff --git a/.github/assets/logo.svg b/.github/assets/logo.svg new file mode 100644 index 0000000..381357e --- /dev/null +++ b/.github/assets/logo.svg @@ -0,0 +1,57 @@ +[57 lines of SVG markup stripped in extraction; the asset renders the "ErrorAlign" logo] diff --git a/.github/dependabot.yml b/.github/dependabot.yml deleted file mode 100644 index 5300f8d..0000000 --- a/.github/dependabot.yml +++ /dev/null @@ -1,35 +0,0 @@ -version: 2 - -registries: - corti-github: - type: git - url: https://github.com - username: x-access-token - password: ${{ secrets.DEPENDABOT_TOKEN }} - -updates: - # Maintain dependencies for GitHub Actions - - package-ecosystem: "github-actions" - directory: "/" - labels: - - "dependabot" - schedule: - interval: "weekly" - - # Maintain dependencies for Python - - package-ecosystem: "pip" - open-pull-requests-limit: 1 - registries: "*" # Use all, i.e. the ecosystem default and corti-github, registries - directory: "/" - labels: - - "dependabot" - schedule: - interval: "weekly" - allow: - - dependency-type: "all" # Allow both direct and indirect updates for all packages - commit-message: - prefix: "pip prod" - prefix-development: "pip dev" - include: "scope" - # https://docs.github.com/en/code-security/supply-chain-security/keeping-your-dependencies-updated-automatically/configuration-options-for-dependency-updates#insecure-external-code-execution - insecure-external-code-execution: allow diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 319b529..6f14ea5 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -5,18 +5,6 @@ Do not leave this blank. This PR [adds/removes/fixes/replaces] the [feature/bug/etc] because [reason] by doing [x]. --> -## Related Issue(s) - - - -Related to RD-XXXX - -Closes RD-XXXX - diff --git a/src/error_align/baselines/power/normalize.py b/src/error_align/baselines/power/normalize.py new file mode 100644 --- /dev/null +++ b/src/error_align/baselines/power/normalize.py +class TextToNumEng(object): [leading word lists (ones, tens, twenties, thousands) truncated in extraction ... 10 to the power 60) + # tested with Python24 vegaseat 07dec2006 + @staticmethod + def int2word(n): + """ + convert an integer number n into a string of english words + """ + # break the number into groups of 3 digits using slicing + # each group representing hundred, thousand, million, billion, ...
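+ # e.g. n = 1234567 -> n3 = [567, 234, 1] (least-significant group first)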
+ n3 = [] + + # create numeric string + ns = str(n) + for k in range(3, 33, 3): + r = ns[-k:] + q = len(ns) - k + # break if end of ns has been reached + if q < -2: + break + else: + if q >= 0: + n3.append(int(r[:3])) + elif q >= -1: + n3.append(int(r[:2])) + elif q >= -2: + n3.append(int(r[:1])) + + + #print n3 # test + + # break each group of 3 digits into + # ones, tens/twenties, hundreds + # and form a string + nw = "" + for i, x in enumerate(n3): + b1 = x % 10 + b2 = (x % 100)//10 + b3 = (x % 1000)//100 + #print b1, b2, b3 # test + if x == 0: + continue # skip + else: + t = TextToNumEng.thousands[i] + if b2 == 0: + nw = TextToNumEng.ones[b1] + t + nw + elif b2 == 1: + nw = TextToNumEng.tens[b1] + t + nw + elif b2 > 1: + nw = TextToNumEng.twenties[b2] + TextToNumEng.ones[b1] + t + nw + if b3 > 0: + nw = TextToNumEng.ones[b3] + "hundred " + nw + return nw + +class NumToTextEng(object): + units = ["", "one", "two", "three", "four", "five", + "six", "seven", "eight", "nine"] + teens = ["", "eleven", "twelve", "thirteen", "fourteen", + "fifteen", "sixteen", "seventeen", "eighteen", "nineteen"] + tens = ["", "ten", "twenty", "thirty", "forty", + "fifty", "sixty", "seventy", "eighty", "ninety"] + thousands = ["","thousand", "million", "billion", "trillion", + "quadrillion", "quintillion", "sextillion", "septillion", "octillion", + "nonillion", "decillion", "undecillion", "duodecillion", "tredecillion", + "quattuordecillion", "quindecillion", "sexdecillion", "septendecillion", "octodecillion", + "novemdecillion", "vigintillion"] + + @staticmethod + def convertTryYear(num): + first = None + second = None + + num_str = str(num) + if len(num_str) > 4 or len(num_str) == 0: + raise ValueError("Invalid date: {0}".format(num)) + elif len(num_str) >= 3: + # Check if it could be a year + if len(num_str) == 3: + first = int(num_str[0]) + second = int(num_str[1:]) + elif len(num_str) == 4: + first = int(num_str[0:2]) + second = int(num_str[2:]) + + if not (first and second): + raise ValueError("Invalid date: {0:d}".format(num)) + + return "{0} {1}".format(NumToTextEng.convert(first), NumToTextEng.convert(second)) + else: + return NumToTextEng.convert(num) + + + @staticmethod + def convert(num): + if not isinstance(num, int): + raise TypeError + + words = [] + if num == 0: + words.append("zero") + else: + numStr = "%d" % num + numStrLen = len(numStr) + groups = (numStrLen + 2) // 3 + numStr = numStr.zfill(groups * 3) + for i in range(0, groups*3, 3): + h = int(numStr[i]) + t = int(numStr[i+1]) + u = int(numStr[i+2]) + g = groups - (i // 3 + 1) + + if h >= 1: + words.append(NumToTextEng.units[h]) + words.append("hundred") + + if t > 1: + words.append(NumToTextEng.tens[t]) + if u >= 1: + words.append(NumToTextEng.units[u]) + elif t == 1: + if u >= 1: + words.append(NumToTextEng.teens[u]) + else: + words.append(NumToTextEng.tens[t]) + else: + if u >= 1: + words.append(NumToTextEng.units[u]) + + if g >= 1 and (h + t + u) > 0: + words.append(NumToTextEng.thousands[g]) + return ' '.join(words) \ No newline at end of file diff --git a/src/error_align/baselines/power/power/__init__.py b/src/error_align/baselines/power/power/__init__.py new file mode 100644 index 0000000..4c4a18b --- /dev/null +++ b/src/error_align/baselines/power/power/__init__.py @@ -0,0 +1,3 @@ +# from align_labeler import AlignLabeler +from error_align.baselines.power.power.levenshtein import ExpandedAlignment, Levenshtein +from error_align.baselines.power.power.phonemes import Phonemes diff --git
a/src/error_align/baselines/power/power/aligner.py b/src/error_align/baselines/power/power/aligner.py new file mode 100644 index 0000000..22b239b --- /dev/null +++ b/src/error_align/baselines/power/power/aligner.py @@ -0,0 +1,467 @@ +from __future__ import division + +from error_align.baselines.power.power.levenshtein import AlignLabels, ExpandedAlignment, Levenshtein +from error_align.baselines.power.power.phonemes import Phonemes +from error_align.baselines.power.power.pronounce import PronouncerBase, PronouncerLex, PronouncerType + + +class TokType: + WordBoundary = 1 + SyllableBoundary = 2 + Phoneme = 3 + Empty = 0 + + @staticmethod + def checkAnnotation(tok): + tt = TokType.Phoneme + if not tok: + tt = TokType.Empty + elif tok == '|': + tt = TokType.WordBoundary + elif tok == '#': + tt = TokType.SyllableBoundary + return tt + +class CharToWordAligner: + def __init__(self, ref, hyp, lowercase=False): + self.refwords = ref.strip() + self.hypwords = hyp.strip() + self.ref = ref + self.hyp = hyp + self.lowercase = lowercase + self.char_align = None + self.word_align = None + + def charAlign(self): + ref_chars = [x for x in self.ref] + [' '] + hyp_chars = [x for x in self.hyp] + [' '] + + lev = Levenshtein.align(ref_chars, hyp_chars, lowercase=self.lowercase, reserve_list=set([' '])) + lev.editops() + self.char_align = lev.expandAlign() + return self.char_align + + def charAlignToWordAlign(self): + if not self.char_align: + raise Exception("char_align is None") + + ref_word_align = [] + hyp_word_align = [] + align_word = [] + + tmp_ref_word = [] + tmp_hyp_word = [] + + for i in range(len(self.char_align.align)): + ref_char = self.char_align.s1[i] + hyp_char = self.char_align.s2[i] + align_char = self.char_align.align[i] + + # check if both words are completed + # There are a few ways this could happen: + if ((align_char == AlignLabels.correct and ref_char == ' ') or + (align_char == AlignLabels.deletion and ref_char == ' ') or + (align_char == AlignLabels.insertion and hyp_char == ' ')): + + ref_word = ''.join(tmp_ref_word) + hyp_word = ''.join(tmp_hyp_word) + + if ref_word or hyp_word: + ref_word_align.append(ref_word) + hyp_word_align.append(hyp_word) + tmp_ref_word = [] + tmp_hyp_word = [] + + # Check align type + if ref_word and hyp_word: + if ref_word == hyp_word: + align_word.append(AlignLabels.correct) + else: + align_word.append(AlignLabels.substitution) + elif ref_word: + align_word.append(AlignLabels.deletion) + else: + align_word.append(AlignLabels.insertion) + continue + + # Read current chars and check if one of the words is complete + if ref_char == ' ': + if len(tmp_ref_word) > 1: + # Probably a D + ref_word = ''.join(tmp_ref_word) + ref_word_align.append(ref_word) + hyp_word_align.append('') + tmp_ref_word = [] + align_word.append(AlignLabels.deletion) + else: + tmp_ref_word.append(ref_char) + + if hyp_char == ' ': + if len(tmp_hyp_word) > 1: + # Probably an I + hyp_word = ''.join(tmp_hyp_word) + ref_word_align.append('') + hyp_word_align.append(hyp_word) + tmp_hyp_word = [] + align_word.append(AlignLabels.insertion) + else: + tmp_hyp_word.append(hyp_char) + + self.word_align = ExpandedAlignment(ref_word_align, hyp_word_align, align_word, lowercase=self.lowercase) + return self.word_align + +class PowerAligner: + # Exclusive tokens that can only align to themselves; not other members in this set.
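+ # ('|' marks a word boundary and '#' a syllable boundary; see TokType above.)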
+ reserve_list = set(['|', '#']) + + # R-sounds + r_set = set.union(set('r'), Phonemes.r_vowels) + exclusive_sets = [Phonemes.vowels, Phonemes.consonants, r_set] + + phoneDistPenalty = 0.25 + phoneDistPenaltySet = set(['|']) + + def __init__(self, ref, hyp, lowercase=False, verbose=False, + pronounce_type=PronouncerType.Lexicon, + lexicon=None, + word_align_weights=Levenshtein.wordAlignWeights): + if not ref: + raise Exception("No reference file.\nref: {0}\nhyp: {1}".format(ref, hyp)) + + if pronounce_type == PronouncerType.Lexicon: + self.pronouncer = PronouncerLex(lexicon) + else: + self.pronouncer = PronouncerBase() + + self.ref = [x for x in ref.strip().split() if x] + self.hyp = [x for x in hyp.strip().split() if x] + self.refwords = ' '.join(self.ref) + self.hypwords = ' '.join(self.hyp) + + self.lowercase = lowercase + self.verbose = verbose + + # Perform word alignment + lev = Levenshtein.align(self.ref, self.hyp, lowercase=self.lowercase, weights=word_align_weights) + lev.editops() + self.wer_alignment = lev.expandAlign() + self.wer, self.wer_components = self.wer_alignment.error_rate() + + # Used for POWER alignment + self.power_alignment = None + self.power = None + self.power_components = None + + # Used to find potential error regions + self.split_regions = None + self.error_indexes = None + self.phonetic_alignments = None + self.phonetic_lev = None + + def align(self): + # Find the error regions that may need to be realigned + self.split_regions, self.error_indexes = self.wer_alignment.split_error_regions() + self.phonetic_alignments = [None] * len(self.split_regions) + + for error_index in self.error_indexes: + seg = self.split_regions[error_index] + ref_words = seg.s1_tokens() + hyp_words = seg.s2_tokens() + ref_phones = self.pronouncer.pronounce(ref_words) + hyp_phones = self.pronouncer.pronounce(hyp_words) + + power_seg_alignment, self.phonetic_alignments[error_index] = PowerAligner.phoneAlignToWordAlign(ref_words, hyp_words, + ref_phones, hyp_phones) + + # Replace the error region at the current index. + self.split_regions[error_index] = power_seg_alignment + + # Merge the alignment segments back together.
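+ # Seed the merged alignment with the first segment, then append the remaining segments in order.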
+ self.power_alignment = ExpandedAlignment(self.split_regions[0].s1, self.split_regions[0].s2, + self.split_regions[0].align, + self.split_regions[0].s1_map, self.split_regions[0].s2_map, lowercase=self.lowercase) + for i in range(1, len(self.split_regions)): + self.power_alignment.append_alignment(self.split_regions[i]) + + # Get the alignment score + self.power, self.power_components = self.power_alignment.error_rate() + + assert self.hypwords == self.power_alignment.s2_string(), "hyp mismatch:\n{0}\n{1}".format(self.hypwords, self.power_alignment.s2_string()) + assert self.refwords == self.power_alignment.s1_string(), "ref mismatch:\n{0}\n{1}".format(self.refwords, self.power_alignment.s1_string()) + + # TODO: Make this simpler (and maybe recursive) + @classmethod + def phoneAlignToWordAlign(cls, ref_words, hyp_words, ref_phones, hyp_phones, break_on_syllables=True): + ref_word_span = (0, len(ref_words)) + hyp_word_span = (0, len(hyp_words)) + + # Perform Levenshtein Alignment + lev = Levenshtein.align(ref=ref_phones, + hyp=hyp_phones, + reserve_list=PowerAligner.reserve_list, + exclusive_sets=PowerAligner.exclusive_sets, + weights=Levenshtein.wordAlignWeights) #, + #dist_penalty=PowerAligner.phoneDistPenalty, dist_penalty_set=Levenshtein.wordAlignWeights) + phone_align = lev.expandAlignCompact() + + worklist = list() + worklist.append((ref_word_span, hyp_word_span, phone_align)) + + full_reference = list() + full_hypothesis = list() + full_alignment = list() + full_phone_align = list() + + while worklist: + # Take the next set of sequence boundaries off the worklist + ref_word_span, hyp_word_span, phone_align = worklist.pop() + ref_word_index, ref_word_limit = ref_word_span + hyp_word_index, hyp_word_limit = hyp_word_span + + # TODO: Currently only checking in the forward direction + ref_word_builder = [] # Temp storage of words in alignment span + hyp_word_builder = [] + + ref_word_iter = enumerate(ref_words[ref_word_span[0]:ref_word_span[1]]) # Iterates through the surface words + hyp_word_iter = enumerate(hyp_words[hyp_word_span[0]:hyp_word_span[1]]) + + ref_aligned = [] # Finalized alignments + hyp_aligned = [] + alignment = [] # Finalized alignment labels + + ref_extra_syllable_word_index = None # Used for marking words mapping to extra syllables in alignment. + hyp_extra_syllable_word_index = None + ref_syllable_count = 0 + hyp_syllable_count = 0 + + ref_word_started = False # Indicates whether a word is already accounted for in the alignment when a phoneme is reached. + hyp_word_started = False + + advance_worklist = False + commit_alignment = False + + for i in range(len(phone_align.align)): + ref_type = TokType.checkAnnotation(phone_align.s1[i]) + hyp_type = TokType.checkAnnotation(phone_align.s2[i]) + + # Check if word boundaries are reached, both on ref and hyp -- or the case where no more symbols can be read. + if (i == len(phone_align.align) - 1) or (ref_type == TokType.WordBoundary and ref_type == hyp_type): + align_tok = None + # Only write outputs if either the ref or the hyp has scanned some words.
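+ # (ref words only -> deletion; hyp words only -> insertion; both -> correct or substitution)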
+ if ref_word_builder: + if hyp_word_builder: + align_tok = AlignLabels.substitution if ref_word_builder != hyp_word_builder else AlignLabels.correct + else: + align_tok = AlignLabels.deletion + elif hyp_word_builder: + align_tok = AlignLabels.insertion + + if align_tok: + # Add the remainder to the worklist + ref_word_span_next = (ref_word_index + len(ref_word_builder), ref_word_limit) + hyp_word_span_next = (hyp_word_index + len(hyp_word_builder), hyp_word_limit) + phone_align_next = phone_align.subsequence(i, phone_align.length(), preserve_index=False) + worklist.append((ref_word_span_next, hyp_word_span_next, phone_align_next)) + + # "Commit" the current alignment + if align_tok in (AlignLabels.correct, AlignLabels.substitution): + alignment.append(align_tok) + + # Check for syllable conflicts + if not break_on_syllables or ref_extra_syllable_word_index is None: + ref_aligned.append(' '.join(ref_word_builder)) + ref_syllable_count = 0 + hyp_syllable_count = 0 + else: + ref_aligned.append(' '.join(ref_word_builder[0:ref_extra_syllable_word_index])) + # The remaining words are deletions + for word in ref_word_builder[ref_extra_syllable_word_index:]: + alignment.append(AlignLabels.deletion) + ref_aligned.append(word) + hyp_aligned.append('') + ref_syllable_count = 0 + + if not break_on_syllables or hyp_extra_syllable_word_index is None: + hyp_aligned.append(' '.join(hyp_word_builder)) + ref_syllable_count = 0 + hyp_syllable_count = 0 + else: + hyp_aligned.append(' '.join(hyp_word_builder[0:hyp_extra_syllable_word_index])) + # The remaining words are insertions + for word in hyp_word_builder[hyp_extra_syllable_word_index:]: + alignment.append(AlignLabels.insertion) + ref_aligned.append('') + hyp_aligned.append(word) + hyp_syllable_count = 0 + + if align_tok == AlignLabels.substitution: + # Check if you need to rework this alignment. + if len(ref_word_builder) != len(hyp_word_builder): + # Word count mismatch in the alignment span. Is there a possibility that we need to re-align this segment? + ref_word_span_curr = (ref_word_index, ref_word_index + len(ref_word_builder)) + hyp_word_span_curr = (hyp_word_index, hyp_word_index + len(hyp_word_builder)) + phone_align_curr = phone_align.subsequence(0, i+1, preserve_index=False) + + lev = Levenshtein.align( + ref=phone_align_curr.s1_tokens(), + hyp=phone_align_curr.s2_tokens(), + reserve_list=PowerAligner.reserve_list, + exclusive_sets=PowerAligner.exclusive_sets, + weights=Levenshtein.wordAlignWeights) #, + #dist_penalty=PowerAligner.phoneDistPenalty, dist_penalty_set=Levenshtein.wordAlignWeights) + + phone_align_adjusted = lev.expandAlignCompact() + + if phone_align_curr.align != phone_align_adjusted.align: + # Looks like we need to redo the phone-to-word alignment. + worklist.append((ref_word_span_curr, hyp_word_span_curr, phone_align_adjusted)) + else: + commit_alignment = True + else: + commit_alignment = True + + elif align_tok == AlignLabels.deletion: + for word in ref_word_builder: + alignment.append(align_tok) + ref_aligned.append(word) + hyp_aligned.append('') + + commit_alignment = True + ref_syllable_count = 0 + + elif align_tok == AlignLabels.insertion: + for word in hyp_word_builder: + alignment.append(align_tok) + ref_aligned.append('') + hyp_aligned.append(word) + + commit_alignment = True + hyp_syllable_count = 0 + + if commit_alignment: + # Commit the alignment.
+ full_reference.extend(ref_aligned) + full_hypothesis.extend(hyp_aligned) + full_alignment.extend(alignment) + full_phone_align.append(phone_align.subsequence(0, i, preserve_index=False)) + ref_aligned = [] + hyp_aligned = [] + alignment = [] + break + + # Add words if word boundaries are reached. + else: + if ref_type == TokType.WordBoundary: + ref_word_started = False + if hyp_type != TokType.WordBoundary and ref_word_builder and not hyp_word_builder: + # DELETION + # Ref word ended, but no hyp words have been added. Mark the current ref word(s) in the span as deletion errors. + # TODO: Dedupe this logic + for word in ref_word_builder: + alignment.append(AlignLabels.deletion) + ref_aligned.append(word) + hyp_aligned.append('') + ref_syllable_count = 0 + + # Commit the alignment. + full_reference.extend(ref_aligned) + full_hypothesis.extend(hyp_aligned) + full_alignment.extend(alignment) + full_phone_align.append(phone_align.subsequence(0, i, preserve_index=False)) + + # Add the remainder to the worklist + ref_word_span_next = (ref_word_index + len(ref_word_builder), ref_word_limit) + hyp_word_span_next = (hyp_word_index + len(hyp_word_builder), hyp_word_limit) + lev = Levenshtein.align( + ref=[x for x in phone_align.s1[i:] if x], + hyp=[x for x in phone_align.s2 if x], + reserve_list=PowerAligner.reserve_list, + exclusive_sets=PowerAligner.exclusive_sets, + weights=Levenshtein.wordAlignWeights) #, + #dist_penalty=PowerAligner.phoneDistPenalty, dist_penalty_set=Levenshtein.wordAlignWeights) + phone_align_next = lev.expandAlignCompact() + + worklist.append((ref_word_span_next, hyp_word_span_next, phone_align_next)) + break + elif ref_type == TokType.Phoneme and not ref_word_started: + ref_word_started = True + try: + ref_word_item = ref_word_iter.__next__() + ref_word_builder.append(ref_word_item[1]) + except StopIteration: + pass + + if hyp_type == TokType.WordBoundary: + hyp_word_started = False + if ref_type != TokType.WordBoundary and hyp_word_builder and not ref_word_builder: + # INSERTION + # Hyp word ended, but no ref words have been added. Mark the current hyp word(s) in the span as insertion errors. + # TODO: Dedupe this logic + for word in hyp_word_builder: + alignment.append(AlignLabels.insertion) + ref_aligned.append('') + hyp_aligned.append(word) + hyp_syllable_count = 0 + + # Commit the alignment. 
+ full_reference.extend(ref_aligned) + full_hypothesis.extend(hyp_aligned) + full_alignment.extend(alignment) + full_phone_align.append(phone_align.subsequence(0, i, preserve_index=False)) + + # Add the remainder to the worklist + ref_word_span_next = (ref_word_index + len(ref_word_builder), ref_word_limit) + hyp_word_span_next = (hyp_word_index + len(hyp_word_builder), hyp_word_limit) + lev = Levenshtein.align( + ref=[x for x in phone_align.s1 if x], + hyp=[x for x in phone_align.s2[i:] if x], + reserve_list=PowerAligner.reserve_list, + exclusive_sets=PowerAligner.exclusive_sets, + weights=Levenshtein.wordAlignWeights) #, + #dist_penalty=PowerAligner.phoneDistPenalty, dist_penalty_set=Levenshtein.wordAlignWeights) + phone_align_next = lev.expandAlignCompact() + + worklist.append((ref_word_span_next, hyp_word_span_next, phone_align_next)) + break + elif hyp_type == TokType.Phoneme and not hyp_word_started: + hyp_word_started = True + try: + hyp_word_item = hyp_word_iter.__next__() + hyp_word_builder.append(hyp_word_item[1]) + except StopIteration: + pass + + # Check for syllable mismatches + if ref_type == TokType.SyllableBoundary: + ref_syllable_count += 1 + if hyp_type == TokType.SyllableBoundary: + hyp_syllable_count += 1 + + if (ref_type == TokType.SyllableBoundary == hyp_type or ref_syllable_count == hyp_syllable_count): + # No syllable conflicts here! + ref_extra_syllable_word_index = None + hyp_extra_syllable_word_index = None + elif (ref_type == TokType.SyllableBoundary and + ref_extra_syllable_word_index is None and + TokType.checkAnnotation(phone_align.s2[i - 1]) == TokType.WordBoundary): + # Extra syllable in the reference. We only care if the syllable immediately follows a word boundary. + # This indicates that a new ref word is being formed, which is likely a deletion. + ref_extra_syllable_word_index = len(ref_word_builder) - 1 + # print ref_word_builder + # print 'Syllable/word mismatch at', i + # print 'Extra ref word:', ref_word_builder[ref_extra_syllable_word_index] + elif (hyp_type == TokType.SyllableBoundary and + hyp_extra_syllable_word_index is None and + TokType.checkAnnotation(phone_align.s2[i - 1]) == TokType.WordBoundary): + # This time there's an extra syllable in the hypothesis, corresponding to a new hyp word.
+ hyp_extra_syllable_word_index = len(hyp_word_builder) - 1 + # print hyp_word_builder + # print 'Syllable/word mismatch at', i + # print 'Extra hyp word:', hyp_word_builder[hyp_extra_syllable_word_index] + # Concatenate all phoneme alignments + fp_align = full_phone_align[0] + for expand_align in full_phone_align[1:]: + fp_align.append_alignment(expand_align) + + return ExpandedAlignment(full_reference, full_hypothesis, full_alignment), fp_align + diff --git a/src/error_align/baselines/power/power/levenshtein.py b/src/error_align/baselines/power/power/levenshtein.py new file mode 100644 index 0000000..a648ec4 --- /dev/null +++ b/src/error_align/baselines/power/power/levenshtein.py @@ -0,0 +1,576 @@ +from __future__ import division + +import itertools +import re +from collections import Counter, defaultdict, deque + + +class AlignLabels: + correct = 'C' + substitution = 'S' + insertion = 'I' + deletion = 'D' + validOptions = set([correct, substitution, insertion, deletion]) + + +class ExpandedAlignment: + '''Levenshtein-aligned reference and hypothesis, not just edit distance score.''' + + def __init__(self, s1, s2, align, s1_map=None, s2_map=None, lowercase=False): + if not (len(s1) == len(s2) == len(align)): + raise Exception("Length mismatch: align:{0:d}, s1:{1:d}, s2:{2:d}".format( + len(align), len(s1), len(s2))) + if len(align) == 0: + raise Exception("No alignment: strings are empty") + + self.s1 = s1 + self.s2 = s2 + self.align = align + self.s1_map = s1_map + self.s2_map = s2_map + self.lowercase = lowercase + + if not (s1_map and s2_map): + self.recompute_alignment_maps() + + def __str__(self): + widths = [max(len(self.s1[i]), len(self.s2[i])) + for i in range(len(self.s1))] + s1_args = zip(widths, self.s1) + s2_args = zip(widths, self.s2) + align_args = zip(widths, self.align) + + value = 'REF: %s\n' % ' '.join(['%-*s' % x for x in s1_args]) + value += 'HYP: %s\n' % ' '.join(['%-*s' % x for x in s2_args]) + value += 'Eval: %s' % ' '.join(['%-*s' % x for x in align_args]) + + return value + + def s1_string(self): + return ' '.join(self.s1_tokens()) + + def s2_string(self): + return ' '.join(self.s2_tokens()) + + def s1_tokens(self): + return [x for x in self.s1 if x != ''] + + def s2_tokens(self): + return [x for x in self.s2 if x != ''] + + def s1_align_tokens(self, i): + return self.s1[i].split() + + def s2_align_tokens(self, i): + return self.s2[i].split() + + def ref(self): + return self.s1_tokens() + + def hyp(self): + return self.s2_tokens() + + def length(self): + return len(self.align) + + def pos(self, word): + s1_idx = [i for i in range(len(self.s1)) if self.s1[i] == word] + s2_idx = [i for i in range(len(self.s2)) if self.s2[i] == word] + + return s1_idx, s2_idx + + def subsequence(self, i, j, preserve_index=False): + # TODO: Right now we're losing any s1_map and s2_map components for compatibility reasons. Refactoring necessary. + scale = 0 if preserve_index else i + s1_map = [self.s1_map[k] - + scale for k in range(len(self.s1_map)) if i <= self.s1_map[k] < j] + s2_map = [self.s2_map[k] - + scale for k in range(len(self.s2_map)) if i <= self.s2_map[k] < j] + return ExpandedAlignment(self.s1[i:j], self.s2[i:j], self.align[i:j], s1_map, s2_map, lowercase=self.lowercase) + + def split_error_regions(self, error_pattern='[SDI]*S[SDI]+|[SDI]+S[SDI]*'): + ''' + Splits the object into a list of multiple segments. + Some segments are defined as error regions, containing at least one substitution error.
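+ For example, the align string 'CCSDCC' splits into segments 'CC', 'SD', and 'CC', where 'SD' is flagged as an error region.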
+ These error regions may be candidates for realignment to improve alignment precision for downstream tasks + (e.g. phoneme alignment or character alignment). + ''' + split_regions = [] + error_indexes = [] + + p = re.compile(error_pattern) + # Find candidate error regions in the current utterance + # Matches go from left to right + prev_index = 0 + + err_str = ''.join(self.align) + for match in p.finditer(err_str): + i, j = match.span() + if prev_index < i: + # Previous items are rolled up into a 'correct' segment + split_regions.append(self.subsequence(prev_index, i)) + # Get the error region + error_indexes.append(len(split_regions)) + split_regions.append(self.subsequence(i, j)) + prev_index = j + + # Add the trailing segment. + if prev_index < len(self.align): + split_regions.append(self.subsequence(prev_index, len(self.align))) + return split_regions, error_indexes + + def append_alignment(self, expanded_alignment): + ''' + Concatenates a string alignment to the current object. + ''' + map_offset = self.length() + + self.s1 += expanded_alignment.s1 + self.s2 += expanded_alignment.s2 + self.align += expanded_alignment.align + if self.s1_map and expanded_alignment.s1_map: + self.s1_map += [align_pos + + map_offset for align_pos in expanded_alignment.s1_map] + if self.s2_map and expanded_alignment.s2_map: + self.s2_map += [align_pos + + map_offset for align_pos in expanded_alignment.s2_map] + + def recompute_alignment_maps(self): + ''' + Regenerates s1_map and s2_map based on the alignment info. + ''' + self.s1_map = [] + self.s2_map = [] + + for i in range(self.length()): + if self.align[i] in (AlignLabels.correct, AlignLabels.substitution, AlignLabels.deletion): + self.s1_map.extend([i] * len(self.s1[i].split())) + if self.align[i] in (AlignLabels.correct, AlignLabels.substitution, AlignLabels.insertion): + self.s2_map.extend([i] * len(self.s2[i].split())) + + def error_rate(self, cluster_on_ref=False): + ''' + Computes WER or POWER. + self.s1 is considered to be the reference and self.s2 is the hypothesis. + ''' + score_components = {AlignLabels.correct: 0, AlignLabels.substitution: 0, + AlignLabels.deletion: 0, AlignLabels.insertion: 0, 'L': 0} + + for i in range(self.length()): + alignment = self.align[i] + + magnitude = 1 + if alignment != AlignLabels.insertion: + ref_seg_length = len(self.s1[i].split()) + hyp_seg_length = len(self.s2[i].split()) + + if cluster_on_ref: + # NOTE: This relaxes the penalty of errors like + # anatomy -> "and that to me" + # And sharply penalizes + # "and that to me" -> anatomy + # but it might be in better sync with classic WER. + magnitude = ref_seg_length + else: + # NOTE: This causes a mismatch between reference length and the number of substitutions! Is this a problem? + magnitude = max(ref_seg_length, hyp_seg_length) + score_components['L'] += ref_seg_length + + score_components[alignment] += magnitude + + if not self.s1: + # No reference. Error is 100% + error_rate = 1.0 + else: + error_rate = (score_components[AlignLabels.substitution] + score_components[AlignLabels.deletion] + + score_components[AlignLabels.insertion]) / score_components['L'] + return error_rate, score_components + + def confusion_pairs(self): + d = defaultdict(Counter) + for i in range(len(self.align)): + if self.align[i] == AlignLabels.substitution: + s1 = self.s1[i] + s2 = self.s2[i] + if self.lowercase: + s1 = s1.lower() + s2 = s2.lower() + d[s1] += Counter({s2: 1}) + return d + + def alignment_capacity(self): + ''' + Returns the number of word slots occupied by each alignment point.
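+ For example, a slot aligning 'a b' to 'c' contributes (2, 1).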
+ ''' + return [(len(self.s1_align_tokens(i)), len(self.s2_align_tokens(i))) for i in range(self.length())] + + def hyp_oriented_alignment(self, hyp_only=True): + ''' + Returns all alignment tokens. + If an S slot is a multiword alignment, duplicates AlignLabels.substitution by the capacity. + TODO: Move to subclass. + ''' + alignment = [] + ref_align_len, hyp_align_len = zip(*self.alignment_capacity()) + + for i in range(self.length()): + if hyp_only: + # Repeat the slot's label once per hyp token (at least once) + alignment.extend(self.align[i] * max(1, hyp_align_len[i])) + else: + len_diff = ref_align_len[i] - hyp_align_len[i] + if len_diff == 0: + alignment.extend(self.align[i] * hyp_align_len[i]) + elif len_diff < 0: + len_diff = -len_diff + alignment.extend(self.align[i] * ref_align_len[i]) + alignment.extend([AlignLabels.insertion] * len_diff) + else: + alignment.extend(self.align[i] * hyp_align_len[i]) + alignment.extend([AlignLabels.deletion] * len_diff) + return alignment + + +class Levenshtein: + def __init__(self, lowercase=False, tokenMap=None): + self.backMatrix = None + self.distMatrix = None + self.dist = -1 + self.s1 = None + self.s2 = None + self.edits = None + self.lowercase = lowercase + + uniformWeights = {AlignLabels.correct: 0, AlignLabels.substitution: 1, + AlignLabels.deletion: 1, AlignLabels.insertion: 1} + wordAlignWeights = {AlignLabels.correct: 0, AlignLabels.substitution: 4, + AlignLabels.deletion: 3, AlignLabels.insertion: 3} + + @staticmethod + def align(ref, hyp, reserve_list=None, exclusive_sets=None, lowercase=False, weights=None, dist_penalty=0.5, dist_penalty_set=None): + ''' + Creates an alignment with hyp x ref matrix. + reserve_list defines tokens that may never have 'S' alignments. + exclusive_sets defines families of tokens that can have 'S' alignments. Anything outside of exclusive_sets can be aligned to any other nonmember.
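+ For example, with exclusive_sets=[vowels, consonants], a vowel phoneme may substitute another vowel but never a consonant.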
+ ''' + if not weights: + weights = Levenshtein.uniformWeights + lev = Levenshtein(lowercase=lowercase) + lev.s1 = ref + lev.s2 = hyp + + # If using distance penalties: + distPenaltyRef = 0 + distPenaltyHyp = 0 + + if lowercase: + ref = [x.lower() for x in ref] + hyp = [x.lower() for x in hyp] + + #pp = pprint.PrettyPrinter(width=300) + lev.backMatrix = BackTrackMatrix(len(ref), len(hyp), weights) + + # Starts with 1st word in hyp + for index2, char2 in enumerate(hyp): + + if dist_penalty_set and char2 not in dist_penalty_set: + distPenaltyRef += 1 + else: + distPenaltyRef = 0 + + # Loop through columns, corresponding to tokens in ref + for index1, char1 in enumerate(ref): + if dist_penalty_set and char1 not in dist_penalty_set: + distPenaltyHyp += 1 + else: + distPenaltyHyp = 0 + + match_char = AlignLabels.substitution + # Add insert/delete options + insPenalty = lev.backMatrix.getWeight( + index2, index1+1) + weights[AlignLabels.insertion] + delPenalty = lev.backMatrix.getWeight( + index2+1, index1) + weights[AlignLabels.deletion] + + if dist_penalty_set: + insPenalty += (distPenaltyHyp * dist_penalty) * \ + weights[AlignLabels.insertion] + delPenalty += (distPenaltyRef * dist_penalty) * \ + weights[AlignLabels.deletion] + + opts = [insPenalty, # I + delPenalty, # D + ] + + if char1 == char2: + opts.append(lev.backMatrix.getWeight( + index2, index1) + weights[AlignLabels.correct]) # C + match_char = AlignLabels.correct + elif not reserve_list or not (char1 in reserve_list or char2 in reserve_list): + if exclusive_sets: + # Check if char1 and char2 belong to the same exclusive set + no_membership_set = set([-1]) + + # Change this to handle multiple set membership + char1_sets = set( + [k for k in range(len(exclusive_sets)) if char1 in exclusive_sets[k]]) + if not char1_sets: + char1_sets = no_membership_set + char2_sets = set( + [k for k in range(len(exclusive_sets)) if char2 in exclusive_sets[k]]) + if not char2_sets: + char2_sets = no_membership_set + + if set.intersection(char1_sets, char2_sets): + opts.append(lev.backMatrix.getWeight( + index2, index1) + weights[AlignLabels.substitution]) # S + else: + opts.append(lev.backMatrix.getWeight( + index2, index1) + weights[AlignLabels.substitution]) # S + + # Get the locally minimum score and edit operation + minDist = min(opts) + minIndices = [i for i in reversed( + range(len(opts))) if opts[i] == minDist] + + # Build the backtrack + alignLabels = [] + for minIndex in minIndices: + if minIndex == 2: + alignLabels.append(match_char) + elif minIndex == 0: + alignLabels.append(AlignLabels.insertion) + elif minIndex == 1: + alignLabels.append(AlignLabels.deletion) + + lev.backMatrix.addBackTrack( + index2+1, index1+1, alignLabels, minDist) # S/C + + lev.dist = lev.backMatrix.getWeight( + lev.backMatrix.hyplen, lev.backMatrix.reflen) + return lev + + def matchPositions(self, token, token2=None, min_i=None, min_j=None, max_i=None, max_j=None): + if not min_i: + min_i = 0 + if not max_i: + max_i = len(self.s2) + + if not min_j: + min_j = 0 + if not max_j: + max_j = len(self.s1) + + if not token2: + token2 = token + + # Find match positions of token in hyp + hypIdx = [i+1 for i in range(min_i, max_i) if self.s2[i] == token] + refIdx = [j+1 for j in range(min_j, max_j) if self.s1[j] == token] + + return list(itertools.product(*[hypIdx, refIdx])) + + def bestPathsGraph(self, minPos=None, maxPos=None): + """ + Takes all of the best Levenshtein alignment backtrack paths and puts them in a graph.
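+ Nodes are (hyp_index, ref_index) cells of the backtrack matrix.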
+ The graph is weighted by the distance between minPos and maxPos for all paths. + """ + import networkx as nx + if not minPos: + minPos = (0, 0) + if not maxPos: + maxPos = (self.backMatrix.hyplen, self.backMatrix.reflen) + + chart = deque() + chart.appendleft(maxPos) + + G = nx.Graph() + + from time import time + start = time() + while chart: + (i, j) = chart.pop() + + for alignLabel in self.backMatrix.matrix[i][j].backTrackOptions: + child = self.backMatrix.matrix[i][j].getBackTrackOffset( + alignLabel) + prev_i = i + child[1][0] + prev_j = j + child[1][1] + + align = child[0] + rlabel = self.s1[j-1] if prev_j < j else '' + hlabel = self.s2[i-1] if prev_i < i else '' + + # Weight applied based on whether (i & prev_i) or (j & prev_j) are along the hull of minPos or maxPos + weight = 1 + if (i == prev_i and i in (minPos[0], maxPos[0])) or (j == prev_j and j in (minPos[1], maxPos[1])): + weight = 0 + + G.add_edge((i, j), (prev_i, prev_j), weight=weight, + labels=(rlabel, hlabel, align)) + chart.appendleft((prev_i, prev_j)) + + if time() - start > 120: + print("\nWarning: Long computation time\n") + raise AssertionError("Computation took too long") + return G + + def editops(self): + ''' + Records edit distance operations in a compact format, changing s1 to s2. + ''' + i = self.backMatrix.hyplen + j = self.backMatrix.reflen + back = [] + + while i > 0 or j > 0: + op = self.backMatrix.matrix[i][j].getBackTrackOffset() + off_i, off_j = op[1] + i += off_i + j += off_j + back.append((op[0], (i, j))) + + back.reverse() + self.edits = back + return back + + def expandAlign(self): + ''' + Expands the edit operations to actually align the strings. + Also contains maps to track the positions of each character in the strings + to its aligned position. + ''' + if not self.edits: + return None + + s1 = [] + s2 = [] + align = [] + s1_map = [] + s2_map = [] + + for op in self.edits: + a = op[0] + i, j = op[1] + + # Bugfix for empty hypotheses or reference (reference shouldn't happen) + c1 = None + c2 = None + if -1 < j < len(self.s1): + c1 = self.s1[j] + if -1 < i < len(self.s2): + c2 = self.s2[i] + + if a == AlignLabels.deletion: + s1.append(c1) + s2.append('') + s1_map.append(len(align)) + elif a == AlignLabels.insertion: + s1.append('') + s2.append(c2) + s2_map.append(len(align)) + else: + s1.append(c1) + s2.append(c2) + s1_map.append(len(align)) + s2_map.append(len(align)) + + align.append(a) + + return ExpandedAlignment(s1, s2, align, s1_map, s2_map, lowercase=self.lowercase) + + def expandAlignCompact(self, minPos=None, maxPos=None): + """ + Using the backtracking matrix, finds all of the paths with the minimum Levenshtein distance score and stores them in a graph. + Then, it returns the expanded alignment of the shortest path in the graph (which still has the same minimum Lev distance score). + """ + import networkx as nx + minPos = (0, 0) + maxPos = (self.backMatrix.hyplen, self.backMatrix.reflen) + G = self.bestPathsGraph(minPos, maxPos) + path = nx.shortest_path( + G, source=minPos, target=maxPos, weight='weight') + + # Expand the best path into the Levenshtein alignment.
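+ # Each edge's 'labels' attribute is a (ref_token, hyp_token, align_label) triple; unzip the path edges into three parallel lists.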
+ s1_align, s2_align, align = [list(a) for a in zip( + *(G[u][v]['labels'] for (u, v) in zip(path[0:], path[1:])))] + return ExpandedAlignment(s1_align, s2_align, align, lowercase=self.lowercase) + + @staticmethod + def errorRate(s, d, i, reflength): + return (s + d + i) / reflength + + +class BackTrackMatrix: + def __init__(self, reflen, hyplen, weights=Levenshtein.uniformWeights): + self.reflen = reflen + self.hyplen = hyplen + self.weights = weights + + self.matrix = [[None] * (reflen+1) for i in range(hyplen+1)] + self.row_count = hyplen + 1 + self.col_count = reflen + 1 + + # Initialize the top corner. + self.matrix[0][0] = BackTrackSlot(0) + + # Initialize first column. + for i in range(1, self.row_count): + self.addBackTrack(i, 0, AlignLabels.insertion, + i*weights[AlignLabels.insertion]) + + # Initialize first row. + for j in range(1, self.col_count): + self.addBackTrack(0, j, AlignLabels.deletion, + j*weights[AlignLabels.deletion]) + + def addBackTrack(self, i, j, alignLabels, weight=1.0): + self.matrix[i][j] = BackTrackSlot(weight) + self.matrix[i][j].addOptions(alignLabels) + + def backTrackOptions(self, i, j): + return self.matrix[i][j] + + def getWeight(self, i, j): + return self.matrix[i][j].weight + + +class BackTrackSlot: + def __init__(self, weight): + self.weight = weight + self.backTrackOptions = list() + + def __str__(self): + return "({0}, {1})".format(self.weight, ','.join(list(self.backTrackOptions))) + + def iterOptions(self): + return iter(self.backTrackOptions) + + def addOption(self, alignLabel): + if alignLabel not in self.backTrackOptions: + self.backTrackOptions.append(alignLabel) + + def addOptions(self, alignLabels): + self.backTrackOptions.extend( + [x for x in alignLabels if x not in self.backTrackOptions]) + + def getBackTrackOffset(self, alignLabel=None): + if alignLabel: + # Make sure it exists + if alignLabel not in AlignLabels.validOptions: + raise Exception("Invalid backtrack option: %s" % alignLabel) + if alignLabel not in self.backTrackOptions: + raise Exception("Illegal backtrack option: %s" % alignLabel) + else: + # Just arbitrarily grab the first item + alignLabel = self.backTrackOptions[0] + + offset = None + if alignLabel in (AlignLabels.correct, AlignLabels.substitution): + offset = (-1, -1) + elif alignLabel == AlignLabels.deletion: + offset = (0, -1) + else: + offset = (-1, 0) + return (alignLabel, offset) diff --git a/src/error_align/baselines/power/power/phonemes.py b/src/error_align/baselines/power/power/phonemes.py new file mode 100644 index 0000000..5a187ea --- /dev/null +++ b/src/error_align/baselines/power/power/phonemes.py @@ -0,0 +1,16 @@ +class Phonemes(object): + # Vowels + monophthongs = set(['ao', 'aa', 'iy', 'uw', 'eh', 'ih', 'uh', 'ah', 'ax', 'ae']) + diphthongs = set(['ey', 'ay', 'ow', 'aw', 'oy']) + # Add r-colored monophthongs? + r_vowels = set(['er', 'axr']) # The remaining are split into two tokens.
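+ # (i.e. other r-colored vowels are assumed to appear as a vowel phoneme followed by a separate 'r')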
+ vowels = set.union(monophthongs, diphthongs, r_vowels) + + # Consonants + c_stops = set(['p','b','t','d','k','g']) + c_affricates = set(['ch','jh']) + c_fricatives = set(['f','v','th','dh','s','z','sh','zh','hh']) + c_nasals = set(['m','em','n','en','ng','eng']) + c_liquids = set(['l','el','r','dx','nx']) + c_semivowels = set(['y','w','q']) + consonants = set.union(c_stops, c_affricates, c_fricatives, c_nasals, c_liquids, c_semivowels) \ No newline at end of file diff --git a/src/error_align/baselines/power/power/pronounce.py b/src/error_align/baselines/power/power/pronounce.py new file mode 100644 index 0000000..3e47ae4 --- /dev/null +++ b/src/error_align/baselines/power/power/pronounce.py @@ -0,0 +1,85 @@ +''' +Created on Mar 30, 2015 + +@author: Nick Ruiz +Assuming you have a file of word sequences, this processes them through Festival to generate pronunciations. + +''' +import json +from itertools import groupby + +import pyphen + +from error_align.baselines.power.normalize import NumToTextEng, splitHyphens + + +class PronouncerType: + Base = "base" + Lexicon = "lexicon" + +class PronouncerBase(object): + def __init__(self): + pass + def pronounce(self, words): + '''G2P conversion''' + raise NotImplementedError + +class PronouncerLex(PronouncerBase): + '''Lexicon-based pronunciation generator. Looks up words in the lexicon and if they aren't found, uses a hacky alternative. + NOTE: English-only + ''' + def __init__(self, lexicon): + with open(lexicon, 'r') as f: + self.lexicon = json.load(f) + self.fallbackDict = pyphen.Pyphen(lang='en_US') + + def pronounce(self, words): + '''G2P using Pyphen + heuristics''' + wordsLower = (w.lower() for w in words) + prons = [self.lexicon[w] if w in self.lexicon else self.alt_pronounce( + w) for w in wordsLower] + return "| {0} |".format(' | '.join(prons)).split() + + def alt_pronounce(self, word): + '''Alternative ways to pronounce the word. Adds simple digit to word conversion''' + prons = [] + # Split words along hyphens + wordsSplit = splitHyphens(' '.join(word)) + # Instead of one word, we may have many + for myword in wordsSplit: + if myword.isdigit(): + words = NumToTextEng.convert(int(myword)).split() + pron = ' # '.join((self.pyphen_pronounce(w) for w in words)) + else: + pron = self.pyphen_pronounce(myword) + prons.append(pron) + # Although the hyphenated word is now pronounced as multiple "words", + # we treat it again as a single word with multiple syllables + return ' # '.join(prons) + + + def pyphen_pronounce(self, word): + ''' Uses pyphen as a back-off to generate a pseudo-hyphenated pronunciation for words not in the lexicon. ''' + syllables = self.fallbackDict.inserted(word).split('-') + pron = [] + for syl in syllables: + m = 0 + n = len(syl) + sylpron = [] + while m < n: + pronidx = [(n-i, self.lexicon[syl[m:n-i]]) + for i in range(n) if syl[m:n-i] in self.lexicon] + if not pronidx: + break + j, p = pronidx[0] + sylpron.append(p) + m = j + if not sylpron: + # TODO: No dictionary entries found for this word. Error?
+ # Break syl into characters + sylpron = syl + + # Remove duplicate adjacent phonemes + sylpron = list(i for i, x in groupby(sylpron)) + pron.append(' '.join(sylpron)) + return ' # '.join(pron) diff --git a/src/error_align/baselines/power/power/punct.py b/src/error_align/baselines/power/power/punct.py new file mode 100644 index 0000000..1ee38c1 --- /dev/null +++ b/src/error_align/baselines/power/power/punct.py @@ -0,0 +1,166 @@ +import copy +import re + +from error_align.baselines.power.power.aligner import CharToWordAligner +from error_align.baselines.power.power.levenshtein import AlignLabels + + +class PunctInsertOracle(object): + ''' + Insert punctuation on hypothesis based on alignment with reference + ''' + + def __init__(self, params): + ''' + Constructor + ''' + pass + + @staticmethod + def insertPunct(error_alignment, ref_punct): + if not (error_alignment.s1_map and error_alignment.s2_map): + # Don't try to add punct if either the hyp or ref are empty. + return error_alignment + + # Align reference with punct against reference without punct + # Is this needed? Yes, in case there's punctuation that makes an extra token + ref_punct_tokens = ref_punct.split() + ref_nopunct_tokens = error_alignment.s1_string().split() # [x for x in error_alignment.s1 if x] + hyp_nopunct_tokens = error_alignment.s2_string().split() # [x for x in error_alignment.s2 if x] + + ref_punct_string = ' '.join(ref_punct_tokens) + ref_nopunct_string = ' '.join(ref_nopunct_tokens) + c2w_aligner = CharToWordAligner(ref_punct_string, ref_nopunct_string, lowercase=True) + c2w_aligner.charAlign() + ref_punct_align = c2w_aligner.charAlignToWordAlign() + +# lev = Levenshtein.align(ref_punct_tokens, ref_nopunct_tokens, lowercase=True) +# lev.editops() +# ref_punct_align = lev.expandAlign() + + error_alignment_punct = copy.deepcopy(error_alignment) + + try: + err_align_index = None + err_align_label = None + ref_punct_index = -1 + ref_nopunct_index = -1 + for ref_align_index in range(len(ref_punct_align.align)): + # Check the error type. + ref_align_error_type = ref_punct_align.align[ref_align_index] + ref_word_punct = ref_punct_align.s1[ref_align_index] + ref_word_nopunct = ref_punct_align.s2[ref_align_index] + + punct_lhs = "" + punct_rhs = "" + + if ref_align_error_type in (AlignLabels.correct, AlignLabels.substitution): + # There's some punct attached to the word. Split it off! + # Idea #1: use regex to find punct around a word sequence. However, this will fail if there's punctuation in between words + pattern = '^(.*?){0}(.*?)$'.format(ref_word_nopunct) + match = re.search(pattern, ref_word_punct) + + if not match: + # Whoops! something went wrong + raise Exception("Regex match not found\nPattern: {0}\nPunct: {1}\nNo Punct: {2}" + .format(pattern, ref_word_punct, ref_word_nopunct)) + punct_lhs = match.group(1) + punct_rhs = match.group(2) + + if punct_lhs: + # Apply to front of hyp + punct_lhs = punct_lhs.strip() + else: + punct_lhs = "" + if punct_rhs: + # Apply to back of hyp + punct_rhs = punct_rhs.strip() + else: + punct_rhs = "" + + ref_punct_index += 1 + ref_nopunct_index += 1 + + elif ref_align_error_type == AlignLabels.deletion: + # TODO: The extra token on the reference side is probably punct. Check if token is punct + # Determine if punct should attach as punct_lhs or punct_rhs. We'll make that decision later. + punct_lhs = '{0} '.format(ref_word_punct) + punct_rhs = ' {0}'.format(ref_word_punct) + + ref_punct_index += 1 + elif ref_align_error_type == AlignLabels.insertion: + # TODO: uh-oh. This shouldn't happen!
+ ref_nopunct_index += 1 + raise TypeError("Insertion errors shouldn't happen between punct and nopunct") + else: + # Now this is impossible! + raise TypeError("Invalid error type: {0}".format(ref_align_error_type)) + + # Advance error_alignment to the word alignment position containing the nopunct reference word + if ref_nopunct_index >= 0: # TODO: Check if the words have changed. + try: + err_align_index = error_alignment.s1_map[ref_nopunct_index] + err_align_label = error_alignment_punct.align[err_align_index] + except Exception: + print(error_alignment) + print("ref_nopunct_index:", ref_nopunct_index) + print("ref_align_index: ", ref_align_index) + raise + + if punct_lhs or punct_rhs: +# print "Punct: |_{0}_| |_{1}_|".format(punct_lhs, punct_rhs) + if err_align_label == AlignLabels.deletion: + # If there's no token here, apply punct to previous or next + # TODO: This won't work for long strings of D's. You'll need to find the next closest one. + + # Scan backward for the previous non-empty hyp word. Try to apply the punct there. + err_align_index_shift = err_align_index - 1 + + # ... except if the punct was at the end of the last ref word + if ref_nopunct_index == len(ref_nopunct_tokens) - 1: + err_align_index_shift = error_alignment.s2_map[len(hyp_nopunct_tokens) - 1] + + while err_align_index_shift > 0 and not error_alignment.s2[err_align_index_shift]: + err_align_index_shift -= 1 + if err_align_index_shift >= 0: + + # Apply punct to previous position + if ref_align_error_type == AlignLabels.deletion: + punct_lhs = "" + + error_alignment_punct.s2[err_align_index_shift] = "{1}{0}{2}".format(error_alignment_punct.s2[err_align_index_shift], punct_lhs, punct_rhs) + else: + # Scan forward for the next non-empty hyp word. Try to apply the punct there. + err_align_index_shift = err_align_index + 1 + while err_align_index_shift < len(hyp_nopunct_tokens) and not error_alignment_punct.s2[err_align_index_shift]: + err_align_index_shift += 1 + if err_align_index_shift < len(hyp_nopunct_tokens): + + # Apply punct to next position + if ref_align_error_type == "D": + punct_rhs = "" + + error_alignment_punct.s2[err_align_index_shift] = "{1}{0}{2}".format(error_alignment_punct.s2[err_align_index_shift], punct_lhs, punct_rhs) + else: + # Otherwise ignore punct altogether + print("Discarding punct") + else: + # Apply punct on the end of the current word. + if ref_align_error_type == "D": + punct_lhs = "" # TODO: Is this assumption correct? Apply punct to rhs? + + # Force punct at the end of the last reference word to appear at the end of the hyp word + if ref_nopunct_index == len(ref_nopunct_tokens) - 1: + err_align_index_shift = error_alignment.s2_map[len(hyp_nopunct_tokens) - 1] + error_alignment_punct.s2[err_align_index_shift] = "{0}{1}".format(error_alignment_punct.s2[err_align_index_shift], punct_rhs) + punct_rhs = "" + + error_alignment_punct.s2[err_align_index] = "{1}{0}{2}".format(error_alignment_punct.s2[err_align_index], punct_lhs, punct_rhs) + + except Exception: + # TODO: Add exception handling to output the sentences as well as the offending segment. 
+ raise + + return error_alignment_punct + + \ No newline at end of file diff --git a/src/error_align/baselines/power/power/readers.py b/src/error_align/baselines/power/power/readers.py new file mode 100644 index 0000000..512aeb6 --- /dev/null +++ b/src/error_align/baselines/power/power/readers.py @@ -0,0 +1,48 @@ +import json +import sys + +from error_align.baselines.power.power.levenshtein import ExpandedAlignment + + +class AlignmentReaderJson(object): + def __init__(self, filepath): + self.filepath = filepath + + def read_alignments(self): + with open(self.filepath, 'r') as f: + for line in f: + yield AlignmentReaderJson.read_json(line) + + @staticmethod + def read_json(jstr): + in_dict = json.loads(jstr) + if not in_dict: + return None + + ref = [] + hyp = [] + align = [] + ref_map = [] + hyp_map = [] + + for i in range(len(in_dict['alignments'])): + alignment = in_dict['alignments'][i] + ref.append(alignment['ref']) + hyp.append(alignment['hyp']) + align.append(alignment['align']) + if ref[-1]: + ref_map.extend([i for r in ref[-1].split()]) + if hyp[-1]: + hyp_map.extend([i for h in hyp[-1].split()]) + return ExpandedAlignment(ref, hyp, align, ref_map, hyp_map) + + +def main(args): + filename = args[0] + reader = AlignmentReaderJson(filename) + for alignment in reader.read_alignments(): + print(alignment) + print('---') + +if __name__ == '__main__': + main(sys.argv[1:]) \ No newline at end of file diff --git a/src/error_align/baselines/power/power/writers.py b/src/error_align/baselines/power/power/writers.py new file mode 100644 index 0000000..62e0267 --- /dev/null +++ b/src/error_align/baselines/power/power/writers.py @@ -0,0 +1,259 @@ +from __future__ import division + +import abc +import json + +from error_align.baselines.power.power.levenshtein import Levenshtein + + +def CreateWriter(output_type, filepath, hypfile, reffile): + writer = None + if output_type == "sgml": + writer = SgmlWriter(filepath, hypfile, reffile) + elif output_type == "snt": + writer = SntWriter(filepath, hypfile, reffile) + elif output_type == "json": + writer = JsonWriter(filepath, hypfile, reffile) + elif output_type == "align": + writer = AlignWriter(filepath, hypfile, reffile) + else: + raise NotImplementedError("Writer not implemented for %s" % output_type) + return writer + + +class WerWriter(object): + __metaclass__ = abc.ABCMeta + + def __init__(self, filepath, hypfile, reffile): + self.out_file = open(filepath, "w") + self.filepath = filepath + + @abc.abstractmethod + def write(self, segid, score_components, expanded_alignment, phonetic_alignments=None): + return + + def write_blank(self): + return + + def finalize(self): + if self.out_file: + self.out_file.close() + print("File written to {}".format(self.filepath)) + self.out_file = None + + +class AlignWriter(WerWriter): + def __init__(self, filepath, hypfile, reffile): + WerWriter.__init__(self, filepath, hypfile, reffile) + + def write(self, segid, score_components, expanded_alignment, phonetic_alignments=None): + self.out_file.write("{0}\n".format(" ".join(expanded_alignment.alignment_expanded()))) + + def finalize(self): + WerWriter.finalize(self) + + +class SgmlWriter(WerWriter): + def __init__(self, filepath, hypfile, reffile): + WerWriter.__init__(self, filepath, hypfile, reffile) + self.out_file.write( + '\n' % (hypfile, reffile, hypfile) + ) + self.out_file.write('\n') + + def write(self, segid, score_components, expanded_alignment, phonetic_alignments=None): + self.out_file.write('\n' % (segid, len(expanded_alignment.align)))
self.out_file.write( + "%s\n" + % ":".join( + [ + '%s,"%s","%s"' % (expanded_alignment.align[i], expanded_alignment.s1[i], expanded_alignment.s2[i]) + for i in range(len(expanded_alignment.align)) + ] + ) + ) + self.out_file.write("\n") + + def finalize(self): + self.out_file.write("\n") + self.out_file.write("\n") + WerWriter.finalize(self) + + +class JsonWriter(WerWriter): + def __init__(self, filepath, hypfile, reffile): + WerWriter.__init__(self, filepath, hypfile, reffile) + + def write(self, segid, score_components, expanded_alignment, phonetic_alignments=None): + out_dict = dict() + out_dict["id"] = segid + out_dict["errorTypes"] = score_components.copy() + out_dict["errorTypes"]["refLength"] = out_dict["errorTypes"].pop("L") + out_dict["errRate"] = Levenshtein.errorRate( + score_components["S"], score_components["D"], score_components["I"], score_components["L"] + ) + out_dict["alignments"] = [] + for i in range(len(expanded_alignment.align)): + out_dict["alignments"].append( + {"align": expanded_alignment.align[i], "ref": expanded_alignment.s1[i], "hyp": expanded_alignment.s2[i]} + ) + self.out_file.write("%s\n" % json.dumps(out_dict)) + + def write_blank(self): + self.out_file.write("%s\n" % json.dumps({})) + + def finalize(self): + WerWriter.finalize(self) + + +class SntWriter(WerWriter): + def __init__(self, filepath, hypfile, reffile): + WerWriter.__init__(self, filepath, hypfile, reffile) + self.out_file.write("===============================================================================\n") + self.out_file.write(" SENTENCE LEVEL REPORT FOR THE SYSTEM:\n") + self.out_file.write("\tName: %s\n" % hypfile) + self.out_file.write("===============================================================================\n") + self.out_file.write("\n\n") + + def write(self, segid, score_components, expanded_alignment, phonetic_alignments=None): + self.out_file.write("id: (%d)\n" % segid) + labels = ["C", "S", "D", "I"] + counts = [score_components[label] for label in labels] + + # Print score components + self.out_file.write( + "Scores (%s) %s\n" % (" ".join(["#%s" % label for label in labels]), " ".join([str(x) for x in counts])) + ) + + # Print word alignment + self.out_file.write("%s\n" % expanded_alignment) + self.out_file.write("\n") + + # Print phonetic alignments + if phonetic_alignments: + for palign in [p for p in phonetic_alignments if p]: + self.out_file.write("%s\n" % palign) + self.out_file.write("\n") + + # Print statistics + self.out_file.write( + "Correct = {0:4.1%} {1} ({2})\n".format( + score_components["C"] / score_components["L"], score_components["C"], score_components["L"] + ) + ) + self.out_file.write( + "Substitutions = {0:4.1%} {1} ({2})\n".format( + score_components["S"] / score_components["L"], score_components["S"], score_components["L"] + ) + ) + self.out_file.write( + "Deletions = {0:4.1%} {1} ({2})\n".format( + score_components["D"] / score_components["L"], score_components["D"], score_components["L"] + ) + ) + self.out_file.write( + "Insertions = {0:4.1%} {1} ({2})\n".format( + score_components["I"] / score_components["L"], score_components["I"], score_components["L"] + ) + ) + self.out_file.write("\n") + self.out_file.write( + "Errors = {0:4.1%} {1} ({2})\n".format( + Levenshtein.errorRate( + score_components["S"], score_components["D"], score_components["I"], score_components["L"] + ), + score_components["S"] + score_components["D"] + score_components["I"], + score_components["L"], + ) + ) + self.out_file.write("\n") + self.out_file.write( + "Ref. 
+        self.out_file.write(
+            "Hyp. words = {0} ({1})\n".format(
+                len(expanded_alignment.s2_string().split()), score_components["L"]
+            )
+        )
+        self.out_file.write(
+            "Aligned words = {0} ({1})\n".format(
+                score_components["C"] + score_components["S"], score_components["L"]
+            )
+        )
+        self.out_file.write("\n")
+        self.out_file.write("-------------------------------------------------------------------------------\n")
+        self.out_file.write("\n")
+
+    def finalize(self):
+        WerWriter.finalize(self)
+
+
+class ConfusionPairWriter(WerWriter):
+    @staticmethod
+    def write(filepath, hypfile, reffile, conf_dict):
+        with open(filepath, "w") as out_file:
+            out_file.write("System name: %s\n" % hypfile)
+            out_file.write("Ref file : %s\n" % reffile)
+
+            for key in sorted(conf_dict.keys()):
+                for item in sorted(conf_dict[key].keys()):
+                    out_file.write("%s\t==>\t%s\t%d\n" % (key, item, conf_dict[key][item]))
+
+    @staticmethod
+    def write_json(filepath, hypfile, reffile, conf_dict):
+        with open(filepath, "w") as out_file:
+            out_file.write("%s\n" % json.dumps(conf_dict))
+
+
+class CompareWriter:
+    @staticmethod
+    def write_comparison(
+        filepath,
+        hypfile,
+        reffile,
+        linecount,
+        final_power,
+        final_wer,
+        power_score_components,
+        wer_score_components,
+        diff_score,
+        diff_components,
+    ):
+        with open(filepath, "w") as out_file:
+            out_file.write("System name: %s\n" % hypfile)
+            out_file.write("Ref file : %s\n" % reffile)
+            out_file.write("Hyp file : %s\n" % hypfile)
+            out_file.write(
+                """
+,---------------------------------------------------------.
+|{0:^57}|
+|---------------------------------------------------------|
+| Metric | # Snt # Wrd | Corr Sub Del Ins Err |
+|--------+-------------+----------------------------------|
+| POWER | {1:5d} {8:5d} | {9:5d} {10:5d} {11:5d} {12:5d} {13:3.1%} |
+| WER | {1:5d} {2:5d} | {3:5d} {4:5d} {5:5d} {6:5d} {7:3.1%} |
+|=========================================================|
+| Diff | {1:5d} {2:5d} | {14:-5d} {15:-5d} {16:-5d} {17:-5d} {18:-3.1%} |
+`---------------------------------------------------------'
+""".format(
+                    hypfile,
+                    linecount,
+                    wer_score_components["L"],
+                    wer_score_components["C"],
+                    wer_score_components["S"],
+                    wer_score_components["D"],
+                    wer_score_components["I"],
+                    final_wer,
+                    power_score_components["L"],
+                    power_score_components["C"],
+                    power_score_components["S"],
+                    power_score_components["D"],
+                    power_score_components["I"],
+                    final_power,
+                    diff_components["C"],
+                    diff_components["S"],
+                    diff_components["D"],
+                    diff_components["I"],
+                    diff_score,
+                )
+            )
diff --git a/src/error_align/baselines/power_alignment.py b/src/error_align/baselines/power_alignment.py
new file mode 100644
index 0000000..cbbc816
--- /dev/null
+++ b/src/error_align/baselines/power_alignment.py
@@ -0,0 +1,61 @@
+from error_align.baselines.power.power.aligner import PowerAligner as _PowerAligner
+from error_align.utils import Alignment, OpType
+
+
+class PowerAlign:
+    """Phonetically-oriented word error alignment."""
+
+    def __init__(
+        self,
+        ref: str,
+        hyp: str,
+    ):
+        """Initialize the phonetically-oriented word error alignment with reference and hypothesis texts.
+
+        Args:
+            ref (str): The reference sequence/transcript.
+            hyp (str): The hypothesis sequence/transcript.
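+
+        Note: This wrapper passes a hardcoded CMUdict lexicon path to the
+        underlying PowerAligner; adjust it to your local power-asr checkout.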
+ """ + self.aligner = _PowerAligner( + ref=ref, + hyp=hyp, + lowercase=True, + verbose=True, + lexicon="/home/lb/repos/power-asr/lex/cmudict.rep.json", + ) + + def align(self): + """Run the two-pass Power alignment algorithm. + + Returns: + list[Alignment]: A list of Alignment objects. + """ + self.aligner.align() + widths = [ + max(len(self.aligner.power_alignment.s1[i]), len(self.aligner.power_alignment.s2[i])) + for i in range(len(self.aligner.power_alignment.s1)) + ] + s1_args = list(zip(widths, self.aligner.power_alignment.s1)) + s2_args = list(zip(widths, self.aligner.power_alignment.s2)) + align_args = list(zip(widths, self.aligner.power_alignment.align)) + + alignments = [] + for (_, ref_token), (_, hyp_token), (_, align_token) in zip(s1_args, s2_args, align_args): + + if align_token == "C": + op_type = OpType.MATCH + if align_token == "S": + op_type = OpType.SUBSTITUTE + if align_token == "I": + op_type = OpType.INSERT + if align_token == "D": + op_type = OpType.DELETE + + alignment = Alignment( + op_type=op_type, + ref=ref_token, + hyp=hyp_token, + ) + alignments.append(alignment) + + return alignments diff --git a/src/error_align/baselines/rapidfuzz_word_alignment.py b/src/error_align/baselines/rapidfuzz_word_alignment.py new file mode 100644 index 0000000..4bd11d1 --- /dev/null +++ b/src/error_align/baselines/rapidfuzz_word_alignment.py @@ -0,0 +1,100 @@ +from rapidfuzz.distance import Levenshtein + +from error_align.utils import ( + Alignment, + OpType, + basic_normalizer, + basic_tokenizer, +) + +OPS_MAP = { + "match": OpType.MATCH, + "replace": OpType.SUBSTITUTE, + "insert": OpType.INSERT, + "delete": OpType.DELETE, +} + + +class RapidFuzzWordAlign: + """Levenshtein-based word-level alignment.""" + + def __init__( + self, + ref: str, + hyp: str, + tokenizer: callable = basic_tokenizer, + normalizer: callable = basic_normalizer, + ): + """Initialize the Levenshtein-based word-level alignment with reference and hypothesis texts. + + Args: + ref (str): The reference sequence/transcript. + hyp (str): The hypothesis sequence/transcript. + tokenizer (callable): A function to tokenize the sequences. Must be regex-based and return Match objects. + normalizer (callable): A function to normalize the tokens. Defaults to basic_normalizer. + """ + if not isinstance(ref, str): + raise TypeError("Reference sequence must be a string.") + if not isinstance(hyp, str): + raise TypeError("Hypothesis sequence must be a string.") + + self.ref = ref + self.hyp = hyp + self._ref_token_matches = tokenizer(ref) + self._hyp_token_matches = tokenizer(hyp) + self._ref = [normalizer(r.group()) for r in self._ref_token_matches] + self._hyp = [normalizer(h.group()) for h in self._hyp_token_matches] + self._ref_max_idx = len(self._ref) - 1 + self._hyp_max_idx = len(self._hyp) - 1 + self.end_index = (self._hyp_max_idx, self._ref_max_idx) + + def align(self) -> list[Alignment]: + """Extract an arbitrary path from the backtrace graph. + + Returns: + list[Alignment]: A list of Alignment objects. 
+ """ + edit_ops = Levenshtein.editops(self._ref, self._hyp).as_list() + + # Add match segments to editops + ref_edit_idxs = set([op[1] for op in edit_ops if op[0] != "insert"]) + hyp_edit_idxs = set([op[2] for op in edit_ops if op[0] != "delete"]) + ref_match_idxs = [i for i in range(len(self._ref)) if i not in ref_edit_idxs] + hyp_match_idxs = [i for i in range(len(self._hyp)) if i not in hyp_edit_idxs] + assert len(ref_match_idxs) == len(hyp_match_idxs) + for ref_idx, hyp_idx in zip(ref_match_idxs, hyp_match_idxs): + edit_ops.append(("match", ref_idx, hyp_idx)) + edit_ops = sorted(edit_ops, key=lambda x: (x[1], x[2])) + + # Convert to Alignment objects + alignments = [] + for op_type, ref_idx, hyp_idx in edit_ops: + if op_type == "match" or op_type == "replace": + ref_match = self._ref_token_matches[ref_idx] + hyp_match = self._hyp_token_matches[hyp_idx] + alignment = Alignment( + op_type=OPS_MAP[op_type], + ref_slice=slice(*ref_match.span()), + hyp_slice=slice(*hyp_match.span()), + ref=ref_match.group(), + hyp=hyp_match.group(), + ) + elif op_type == "delete": + ref_match = self._ref_token_matches[ref_idx] + alignment = Alignment( + op_type=OPS_MAP[op_type], + ref_slice=slice(*ref_match.span()), + ref=ref_match.group(), + ) + elif op_type == "insert": + hyp_match = self._hyp_token_matches[hyp_idx] + alignment = Alignment( + op_type=OPS_MAP[op_type], + hyp_slice=slice(*hyp_match.span()), + hyp=hyp_match.group(), + ) + else: + raise ValueError(f"Unknown operation type: {op_type}") + alignments.append(alignment) + + return alignments diff --git a/src/error_align/baselines/utils.py b/src/error_align/baselines/utils.py new file mode 100644 index 0000000..4818e1e --- /dev/null +++ b/src/error_align/baselines/utils.py @@ -0,0 +1,93 @@ +import unicodedata + +import regex as re +from num2words import num2words + +from error_align.utils import NUMERIC_TOKEN, basic_normalizer, basic_tokenizer + + +def strip_accents(text: str) -> str: + """Strip accents from the text.""" + normalized = unicodedata.normalize("NFD", text) + return "".join(c for c in normalized if unicodedata.category(c) != "Mn") + + +def normalize_evaluation_segment(segment: str) -> str: + """Normalize a segment by removing accents and converting to lowercase. + + Args: + segment (str): The segment to normalize. + + Returns: + str: The normalized segment. + """ + return re.sub(r"[^a-z0-9]", "", strip_accents(segment.lower())) + + +def convert_numbers_to_words(text: str, lang: str = "en") -> str: + """Convert numeric tokens in the text to their word representation. + + Args: + text (str): The input text containing numeric tokens. + lang (str): The language to use for conversion (default is "en"). + + Returns: + str: The text with numeric tokens converted to words. + """ + if bool(re.match(NUMERIC_TOKEN, text)): + if lang == "en": + # NOTE: num2words doesn't support thousands separators + text_ = text.replace(",", "") + else: + + text_ = text.replace(".", "") + text_ = text_.replace(",", ".") + try: + return num2words(text_, lang=lang) + except Exception: + pass + + return text + + +def clean_text(text: str, lang: str = "en") -> dict: + """ + Cleans the text by removing examples with empty transcriptions. + + Args: + example: The example to clean. + + Returns: + A cleaned version of the example with empty transcriptions removed. + """ + # Remove all tags, e.g., . + text = re.sub(r"<[^>]+>", "", text) + + # Re-contract apostrophes. + text = re.sub(r"(\w) '(\w)", r"\1'\2", text) + + # Get normalized tokens. 
+    normalized_tokens = [basic_normalizer(token.group()) for token in basic_tokenizer(text)]
+
+    # Convert numbers to words.
+    normalized_tokens = [convert_numbers_to_words(token, lang=lang) for token in normalized_tokens]
+
+    return " ".join(normalized_tokens)
+
+
+def clean_example(example: dict, lang="en") -> dict:
+    """Clean the "ref" and "hyp" fields of an example, if present.
+
+    Args:
+        example: The example to clean.
+        lang: The language to use for cleaning.
+
+    Returns:
+        The example with its "ref" and "hyp" fields cleaned.
+    """
+    if "ref" in example:
+        example["ref"] = clean_text(example["ref"], lang=lang)
+    if "hyp" in example:
+        example["hyp"] = clean_text(example["hyp"], lang=lang)
+    return example
diff --git a/src/error_align/edit_distance.py b/src/error_align/edit_distance.py
new file mode 100644
index 0000000..2254737
--- /dev/null
+++ b/src/error_align/edit_distance.py
@@ -0,0 +1,144 @@
+from error_align.utils import DELIMITERS, OP_TYPE_COMBO_MAP_INV, OpType
+
+
+def _get_levenshtein_values(ref_token: str, hyp_token: str):
+    """Compute the Levenshtein costs for insertion, deletion, and the diagonal (substitution or match).
+
+    Args:
+        ref_token (str): The reference token.
+        hyp_token (str): The hypothesis token.
+
+    Returns:
+        tuple: A tuple containing the insertion cost, deletion cost, and diagonal cost.
+
+    """
+    if hyp_token == ref_token:
+        diag_cost = 0
+    else:
+        diag_cost = 1
+
+    return 1, 1, diag_cost
+
+
+def _get_error_align_values(ref_token: str, hyp_token: str):
+    """Compute the error alignment costs for insertion, deletion, and the diagonal (substitution or match).
+
+    Args:
+        ref_token (str): The reference token.
+        hyp_token (str): The hypothesis token.
+
+    Returns:
+        tuple: A tuple containing the insertion cost, deletion cost, and diagonal cost.
+
+    """
+    if hyp_token == ref_token:
+        diag_cost = 0
+    elif hyp_token in DELIMITERS or ref_token in DELIMITERS:
+        diag_cost = 3  # NOTE: Will never be chosen as insert + delete (= 2) is equivalent and cheaper.
+    else:
+        diag_cost = 2
+
+    return 1, 1, diag_cost
+
+
+def compute_distance_matrix(
+    ref: str | list[str],
+    hyp: str | list[str],
+    score_func: callable,
+    backtrace: bool = False,
+    dtype: type = int,
+):
+    """Compute the edit distance score matrix between a reference and a hypothesis sequence
+    using only pure Python lists.
+
+    Args:
+        ref (str or list[str]): The reference sequence/transcript.
+        hyp (str or list[str]): The hypothesis sequence/transcript.
+        score_func (callable): A function that takes two tokens (ref_token, hyp_token) and returns
+            a tuple of (insertion_cost, deletion_cost, diagonal_cost).
+        backtrace (bool): Whether to compute the backtrace matrix.
+        dtype (type): The type used to store scores (int, float, etc.).
+
+    Returns:
+        list[list]: The score matrix.
+        list[list]: The backtrace matrix, if backtrace=True.
+
+    """
+    hyp_dim, ref_dim = len(hyp) + 1, len(ref) + 1
+
+    # Create empty score matrix of zeros and initialize first row and column.
+    score_matrix = [[dtype(0) for _ in range(ref_dim)] for _ in range(hyp_dim)]
+    for j in range(ref_dim):
+        score_matrix[0][j] = dtype(j)
+    for i in range(hyp_dim):
+        score_matrix[i][0] = dtype(i)
+
+    # Create backtrace matrix and operation combination map and initialize first row and column.
+    # Each operation combination is dynamically assigned a unique integer.
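+    # For example, with four operation types the non-empty powerset yields 15
+    # combinations, and OP_TYPE_COMBO_MAP_INV maps each tuple, such as
+    # (OpType.MATCH, OpType.INSERT), to a stable integer id.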
+    if backtrace:
+        backtrace_matrix = [[0 for _ in range(ref_dim)] for _ in range(hyp_dim)]
+        backtrace_matrix[0][0] = OP_TYPE_COMBO_MAP_INV[(OpType.MATCH,)]
+        for j in range(1, ref_dim):
+            backtrace_matrix[0][j] = OP_TYPE_COMBO_MAP_INV[(OpType.DELETE,)]
+        for i in range(1, hyp_dim):
+            backtrace_matrix[i][0] = OP_TYPE_COMBO_MAP_INV[(OpType.INSERT,)]
+
+    # Fill in the score and backtrace matrix.
+    for j in range(1, ref_dim):
+        for i in range(1, hyp_dim):
+            ins_cost, del_cost, diag_cost = score_func(ref[j - 1], hyp[i - 1])
+
+            ins_val = score_matrix[i - 1][j] + ins_cost
+            del_val = score_matrix[i][j - 1] + del_cost
+            diag_val = score_matrix[i - 1][j - 1] + diag_cost
+            new_val = min(ins_val, del_val, diag_val)
+            score_matrix[i][j] = dtype(new_val)
+
+            # Track possible operations (note that the order of operations matters).
+            if backtrace:
+                pos_ops = tuple()
+                if diag_val == new_val and diag_cost == 0:
+                    pos_ops += (OpType.MATCH,)
+                if ins_val == new_val:
+                    pos_ops += (OpType.INSERT,)
+                if del_val == new_val:
+                    pos_ops += (OpType.DELETE,)
+                if diag_val == new_val and diag_cost > 0:
+                    pos_ops += (OpType.SUBSTITUTE,)
+                backtrace_matrix[i][j] = OP_TYPE_COMBO_MAP_INV[pos_ops]
+
+    if backtrace:
+        return score_matrix, backtrace_matrix
+    return score_matrix
+
+
+def compute_levenshtein_distance_matrix(ref: str | list[str], hyp: str | list[str], backtrace: bool = False):
+    """Compute the Levenshtein distance matrix between two sequences.
+
+    Args:
+        ref (str): The reference sequence/transcript.
+        hyp (str): The hypothesis sequence/transcript.
+        backtrace (bool): Whether to compute the backtrace matrix.
+
+    Returns:
+        list[list]: The score matrix.
+        list[list]: The backtrace matrix, if backtrace=True.
+
+    """
+    return compute_distance_matrix(ref, hyp, _get_levenshtein_values, backtrace)
+
+
+def compute_error_align_distance_matrix(ref: str | list[str], hyp: str | list[str], backtrace: bool = False):
+    """Compute the error alignment distance matrix between two sequences.
+
+    Args:
+        ref (str): The reference sequence/transcript.
+        hyp (str): The hypothesis sequence/transcript.
+        backtrace (bool): Whether to compute the backtrace matrix.
+
+    Returns:
+        list[list]: The score matrix.
+        list[list]: The backtrace matrix, if backtrace=True.
+
+    """
+    return compute_distance_matrix(ref, hyp, _get_error_align_values, backtrace)
diff --git a/src/error_align/error_align.py b/src/error_align/error_align.py
new file mode 100644
index 0000000..56a3bd3
--- /dev/null
+++ b/src/error_align/error_align.py
@@ -0,0 +1,495 @@
+from collections import defaultdict
+from typing import Union
+
+import regex as re
+from tqdm import tqdm
+
+from error_align.backtrace_graph import BacktraceGraph
+from error_align.edit_distance import compute_error_align_distance_matrix
+from error_align.utils import (
+    END_DELIMITER,
+    START_DELIMITER,
+    Alignment,
+    OpType,
+    basic_normalizer,
+    basic_tokenizer,
+    categorize_char,
+    ensure_length_preservation,
+    get_manhattan_distance,
+)
+
+
+class ErrorAlign:
+    """Error alignment class that performs a two-pass alignment process."""
+
+    def __init__(
+        self,
+        ref: str,
+        hyp: str,
+        tokenizer: callable = basic_tokenizer,
+        normalizer: callable = basic_normalizer,
+    ):
+        """Initialize the error alignment with reference and hypothesis texts.
+
+        The first pass (backtrace graph extraction) is performed during initialization.
+
+        The second pass (beam search) is performed in the `align` method.
+
+        Args:
+            ref (str): The reference sequence/transcript.
+            hyp (str): The hypothesis sequence/transcript.
+            tokenizer (callable): A function to tokenize the sequences. Must be regex-based and return Match objects.
+            normalizer (callable): A function to normalize the tokens. Defaults to basic_normalizer.
+
+        """
+        if not isinstance(ref, str):
+            raise TypeError("Reference sequence must be a string.")
+        if not isinstance(hyp, str):
+            raise TypeError("Hypothesis sequence must be a string.")
+
+        self.ref = ref
+        self.hyp = hyp
+
+        # Inclusive tokenization: Track the token position in the original text.
+        self._ref_token_matches = tokenizer(ref)
+        self._hyp_token_matches = tokenizer(hyp)
+
+        # Length-preserving normalization: Ensure that the normalizer preserves token length.
+        normalizer = ensure_length_preservation(normalizer)
+        self._ref = "".join([f"<{normalizer(r.group())}>" for r in self._ref_token_matches])
+        self._hyp = "".join([f"<{normalizer(h.group())}>" for h in self._hyp_token_matches])
+
+        # Categorize characters.
+        self._ref_char_types = list(map(categorize_char, self._ref))
+        self._hyp_char_types = list(map(categorize_char, self._hyp))
+
+        # Initialize graph attributes.
+        self._identical_inputs = self._ref == self._hyp
+        self._ref_max_idx = len(self._ref) - 1
+        self._hyp_max_idx = len(self._hyp) - 1
+        self.end_index = (self._hyp_max_idx, self._ref_max_idx)
+
+        # Create index maps for reference and hypothesis sequences.
+        self._ref_index_map = self._create_index_map(self._ref_token_matches)
+        self._hyp_index_map = self._create_index_map(self._hyp_token_matches)
+
+        # First pass: Extract backtrace graph.
+        if not self._identical_inputs:
+            _, backtrace_matrix = compute_error_align_distance_matrix(self._ref, self._hyp, backtrace=True)
+            self._backtrace_graph = BacktraceGraph(backtrace_matrix)
+            self._backtrace_node_set = self._backtrace_graph.get_node_set()
+            self._unambiguous_matches = self._backtrace_graph.get_unambiguous_matches(self._ref)
+        else:
+            self._backtrace_graph = None
+            self._backtrace_node_set = None
+            self._unambiguous_matches = None
+
+    def __repr__(self):
+        ref_preview = self.ref if len(self.ref) < 20 else self.ref[:17] + "..."
+        hyp_preview = self.hyp if len(self.hyp) < 20 else self.hyp[:17] + "..."
+        return f'ErrorAlign(ref="{ref_preview}", hyp="{hyp_preview}")'
+
+    def align(
+        self,
+        beam_size: int = 100,
+        pbar: bool = False,
+        return_path: bool = False,
+    ) -> Union[list[Alignment], "Path"]:
+        """Perform beam search to align reference and hypothesis texts.
+
+        Args:
+            beam_size (int): The size of the beam for beam search. Defaults to 100.
+            pbar (bool): Whether to display a progress bar. Defaults to False.
+            return_path (bool): Whether to return the path object or just the alignments. Defaults to False.
+
+        Returns:
+            list[Alignment] | Path: A list of Alignment objects, or the best Path if return_path=True.
+
+        """
+        # Skip beam search if inputs are identical.
+        if self._identical_inputs:
+            return self._identical_input_alignments()
+
+        # Initialize the beam with a single path starting at the root node.
+        start_path = Path(self)
+        beam = {start_path.pid: start_path}
+        prune_map = defaultdict(lambda: float("inf"))
+        ended = []
+
+        # Set up the progress bar, if enabled.
+        if pbar:
+            total_mdist = self._ref_max_idx + self._hyp_max_idx + 2
+            progress_bar = tqdm(total=total_mdist, desc="Aligning transcripts")
+
+        # Expand candidate paths until all have reached the terminal node.
+        while len(beam) > 0:
+            new_beam = {}
+
+            # Expand each path in the current beam.
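+            # Paths are keyed by pid; prune_map remembers the cheapest cost seen
+            # per pid, so dominated duplicates are dropped before re-insertion.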
+            for path in beam.values():
+                if path.at_end:
+                    ended.append(path)
+                    continue
+
+                # Transition to all child nodes.
+                for new_path in path.expand():
+                    if new_path.pid in prune_map:
+                        if new_path.cost > prune_map[new_path.pid]:
+                            continue
+                    prune_map[new_path.pid] = new_path.cost
+
+                    if new_path.pid not in new_beam or new_path.cost < new_beam[new_path.pid].cost:
+                        new_beam[new_path.pid] = new_path
+
+            # Update the beam with the newly expanded paths.
+            new_beam = list(new_beam.values())
+            new_beam.sort(key=lambda p: p.norm_cost)
+            beam = new_beam[:beam_size]
+
+            # Keep only the best path if it ends at an unambiguous match node.
+            if len(beam) > 0 and beam[0]._at_unambiguous_match_node:
+                beam = beam[:1]
+                prune_map = defaultdict(lambda: float("inf"))
+            beam = {p.pid: p for p in beam}  # Convert to dict for diversity check.
+
+            # Update the progress bar, if enabled.
+            try:
+                worst_path = next(reversed(beam.values()))
+                mdist = get_manhattan_distance(worst_path.index, self.end_index)
+                if pbar:
+                    progress_bar.n = total_mdist - mdist
+                    progress_bar.refresh()
+            except StopIteration:
+                if pbar:
+                    progress_bar.n = total_mdist
+                    progress_bar.refresh()
+
+        # Return the best path or its alignments.
+        ended.sort(key=lambda p: p.cost)
+        if return_path:
+            return ended[0] if len(ended) > 0 else None
+        return ended[0].alignments if len(ended) > 0 else []
+
+    def _create_index_map(self, text_tokens: list[re.Match]) -> list[int]:
+        """Create an index map for the given tokens.
+
+        The 'index_map' is used to map each aligned character back to its original position in the input text.
+
+        NOTE: -1 is used for delimiters (<>) and indicates no match in the source sequence.
+        """
+        index_map = []
+        for match in text_tokens:
+            index_map.extend([-1])  # Start delimiter
+            index_map.extend(list(range(*match.span())))
+            index_map.extend([-1])  # End delimiter
+        return index_map
+
+    def _identical_input_alignments(self) -> list[Alignment]:
+        """Return alignments for identical reference and hypothesis pairs."""
+        assert self._identical_inputs, "Inputs are not identical."
+
+        alignments = []
+        for ref_match, hyp_match in zip(self._ref_token_matches, self._hyp_token_matches, strict=False):
+            ref_slice = slice(*ref_match.span())
+            hyp_slice = slice(*hyp_match.span())
+            ref_token = self.ref[ref_slice]
+            hyp_token = self.hyp[hyp_slice]
+            alignment = Alignment(
+                op_type=OpType.MATCH,
+                ref_slice=ref_slice,
+                hyp_slice=hyp_slice,
+                ref=ref_token,
+                hyp=hyp_token,
+            )
+            alignments.append(alignment)
+        return alignments
+
+
+class Path:
+    """Class to represent a graph path."""
+
+    def __init__(self, src: ErrorAlign):
+        """Initialize a root path for the given ErrorAlign instance."""
+        self.src = src
+        self.ref_idx = -1
+        self.hyp_idx = -1
+        self._closed_cost = 0
+        self._open_cost = 0
+        self._at_unambiguous_match_node = False
+        self._last_end_index = (-1, -1)
+        self._end_indices = tuple()
+        self._alignments = None
+        self._alignments_index = None
+
+    def __repr__(self):
+        return f"Path(({self.ref_idx}, {self.hyp_idx}), score={self.cost})"
+
+    @property
+    def alignments(self) -> list[Alignment]:
+        """Get the alignments of the path."""
+        # Return cached alignments if available and the path has not changed.
+        if self._alignments is not None and self._alignments_index == self.index:
+            return self._alignments
+
+        self._alignments_index = self.index
+        alignments = []
+        start_hyp, start_ref = (0, 0)
+        for (end_hyp, end_ref), score in self._end_indices:
+            end_hyp, end_ref = end_hyp + 1, end_ref + 1
+
+            # Construct DELETE alignment.
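+            # (A segment that consumed reference characters but no hypothesis
+            # characters corresponds to a deletion.)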
+            if start_hyp == end_hyp:
+                assert start_ref < end_ref
+                ref_slice = slice(start_ref, end_ref)
+                ref_slice = self._translate_slice(ref_slice, self.src._ref_index_map)
+                assert ref_slice is not None
+                alignment = Alignment(
+                    op_type=OpType.DELETE,
+                    ref_slice=ref_slice,
+                    ref=self.src.ref[ref_slice],
+                )
+                alignments.append(alignment)
+
+            # Construct INSERT alignment.
+            elif start_ref == end_ref:
+                assert start_hyp < end_hyp
+                hyp_slice = slice(start_hyp, end_hyp)
+                hyp_slice = self._translate_slice(hyp_slice, self.src._hyp_index_map)
+                assert hyp_slice is not None
+                alignment = Alignment(
+                    op_type=OpType.INSERT,
+                    hyp_slice=hyp_slice,
+                    hyp=self.src.hyp[hyp_slice],
+                    left_compound=self.src._hyp_index_map[start_hyp] >= 0,
+                    right_compound=self.src._hyp_index_map[end_hyp - 1] >= 0,
+                )
+                alignments.append(alignment)
+
+            # Construct SUBSTITUTE or MATCH alignment.
+            else:
+                assert start_hyp < end_hyp and start_ref < end_ref
+                hyp_slice = slice(start_hyp, end_hyp)
+                ref_slice = slice(start_ref, end_ref)
+                hyp_slice = self._translate_slice(hyp_slice, self.src._hyp_index_map)
+                ref_slice = self._translate_slice(ref_slice, self.src._ref_index_map)
+                assert hyp_slice is not None and ref_slice is not None
+                is_match_segment = score == 0
+                op_type = OpType.MATCH if is_match_segment else OpType.SUBSTITUTE
+                alignment = Alignment(
+                    op_type=op_type,
+                    ref_slice=ref_slice,
+                    hyp_slice=hyp_slice,
+                    ref=self.src.ref[ref_slice],
+                    hyp=self.src.hyp[hyp_slice],
+                    left_compound=self.src._hyp_index_map[start_hyp] >= 0,
+                    right_compound=self.src._hyp_index_map[end_hyp - 1] >= 0,
+                )
+                alignments.append(alignment)
+
+            start_hyp, start_ref = end_hyp, end_ref
+
+        # Cache the computed alignments.
+        self._alignments = alignments
+
+        return alignments
+
+    @property
+    def pid(self):
+        """Get the ID of the path used for pruning."""
+        return hash((self.index, self._last_end_index))
+
+    @property
+    def cost(self):
+        """Get the cost of the path."""
+        return self._closed_cost + self._open_cost + self._substitution_penalty()
+
+    @property
+    def norm_cost(self):
+        """Get the normalized cost of the path."""
+        if self.cost == 0:
+            return 0
+        return self.cost / (self.ref_idx + self.hyp_idx + 3)  # NOTE: +3 to avoid zero division. Root = (-1,-1).
+
+    @property
+    def index(self):
+        """Get the current node index of the path."""
+        return (self.hyp_idx, self.ref_idx)
+
+    @property
+    def at_end(self):
+        """Check if the path has reached the terminal node."""
+        return self.index == self.src.end_index
+
+    def expand(self):
+        """Expand the path by transitioning to child nodes.
+
+        Yields:
+            Path: The expanded child paths.
+
+        """
+        # Add delete operation.
+        delete_path = self._add_delete()
+        if delete_path is not None:
+            yield delete_path
+
+        # Add insert operation.
+        insert_path = self._add_insert()
+        if insert_path is not None:
+            yield insert_path
+
+        # Add substitution or match operation.
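+        # (Diagonal steps on delimiter characters are only allowed when the
+        # characters match exactly; see _add_substitution_or_match.)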
+        sub_or_match_path = self._add_substitution_or_match()
+        if sub_or_match_path is not None:
+            yield sub_or_match_path
+
+    def _transition_and_shallow_copy(self, ref_step: int, hyp_step: int):
+        """Create a shallow copy of the path, advanced by the given index steps."""
+        new_path = Path(self.src)
+        new_path.ref_idx = self.ref_idx + ref_step
+        new_path.hyp_idx = self.hyp_idx + hyp_step
+        new_path._closed_cost = self._closed_cost
+        new_path._open_cost = self._open_cost
+        new_path._at_unambiguous_match_node = False
+        new_path._last_end_index = self._last_end_index
+        new_path._end_indices = self._end_indices
+
+        return new_path
+
+    def _reset_segment_variables(self, index: tuple[int, int]) -> None:
+        """Apply updates when a segment end is detected."""
+        self._closed_cost += self._open_cost
+        self._closed_cost += self._substitution_penalty(index)
+        self._last_end_index = index
+        self._open_cost = 0
+
+    def _end_insertion_segment(self, index: tuple[int, int]) -> None:
+        """End the current segment if the criteria for an insertion are met."""
+        hyp_slice = slice(self._last_end_index[0] + 1, index[0] + 1)
+        hyp_slice = self._translate_slice(hyp_slice, self.src._hyp_index_map)
+        ref_is_empty = index[1] == self._last_end_index[1]
+        if hyp_slice is not None and ref_is_empty:
+            self._end_indices += ((index, self._open_cost),)
+            self._reset_segment_variables(index)
+
+    def _end_segment(self) -> Union[None, "Path"]:
+        """End the current segment if the criteria for an insertion, a substitution, or a match are met."""
+        hyp_slice = slice(self._last_end_index[0] + 1, self.index[0] + 1)
+        hyp_slice = self._translate_slice(hyp_slice, self.src._hyp_index_map)
+        ref_slice = slice(self._last_end_index[1] + 1, self.index[1] + 1)
+        ref_slice = self._translate_slice(ref_slice, self.src._ref_index_map)
+
+        assert ref_slice is not None
+
+        hyp_is_empty = self.index[0] == self._last_end_index[0]
+        if hyp_is_empty:
+            self._end_indices += ((self.index, self._open_cost),)
+        else:
+            # TODO: Handle edge case where hyp has only covered delimiters.
+            if hyp_slice is None:
+                return None
+
+            is_match_segment = self._open_cost == 0
+            self._at_unambiguous_match_node = is_match_segment and self.index in self.src._unambiguous_matches
+            self._end_indices += ((self.index, self._open_cost),)
+
+        # Update the path score and reset segment attributes.
+        self._reset_segment_variables(self.index)
+        return self
+
+    def _in_backtrace_node_set(self, index) -> bool:
+        """Check whether the given index lies on an optimal path in the backtrace graph."""
+        return index in self.src._backtrace_node_set
+
+    def _add_delete(self) -> Union[None, "Path"]:
+        """Expand the path by adding a delete operation."""
+        # Ensure we are not at the end of the hypothesis sequence.
+        if self.hyp_idx >= self.src._hyp_max_idx:
+            return None
+
+        # Transition and update costs.
+        new_path = self._transition_and_shallow_copy(ref_step=0, hyp_step=1)
+        is_backtrace = self._in_backtrace_node_set(self.index)
+        is_delimiter = self.src._hyp_char_types[new_path.hyp_idx] == 0  # NOTE: 0 indicates delimiter.
+        new_path._open_cost += 1 if is_delimiter else 2
+        new_path._open_cost += 0 if is_backtrace or is_delimiter else 1
+
+        # Check for end-of-segment criteria.
+        if self.src._hyp[new_path.hyp_idx] == END_DELIMITER:
+            new_path._end_insertion_segment(new_path.index)
+
+        return new_path
+
+    def _add_insert(self) -> Union[None, "Path"]:
+        """Expand the path by adding an insert operation."""
+        # Ensure we are not at the end of the reference sequence.
+        if self.ref_idx >= self.src._ref_max_idx:
+            return None
+
+        # Transition and check for end-of-segment criteria.
+        new_path = self._transition_and_shallow_copy(ref_step=1, hyp_step=0)
+        if self.src._ref[new_path.ref_idx] == START_DELIMITER:
+            new_path._end_insertion_segment(self.index)
+
+        # Update costs.
+        is_backtrace = self._in_backtrace_node_set(self.index)
+        is_delimiter = self.src._ref_char_types[new_path.ref_idx] == 0  # NOTE: 0 indicates delimiter.
+        new_path._open_cost += 1 if is_delimiter else 2
+        new_path._open_cost += 0 if is_backtrace or is_delimiter else 1
+
+        # Check for end-of-segment criteria.
+        if self.src._ref[new_path.ref_idx] == END_DELIMITER:
+            new_path = new_path._end_segment()
+
+        return new_path
+
+    def _add_substitution_or_match(self) -> Union[None, "Path"]:
+        """Expand the path by adding a substitution or match operation."""
+        # Ensure we are not at the end of either sequence.
+        if self.ref_idx >= self.src._ref_max_idx or self.hyp_idx >= self.src._hyp_max_idx:
+            return None
+
+        # Transition and ensure that the transition is allowed.
+        new_path = self._transition_and_shallow_copy(ref_step=1, hyp_step=1)
+        is_match = self.src._ref[new_path.ref_idx] == self.src._hyp[new_path.hyp_idx]
+        if not is_match:
+            ref_is_delimiter = self.src._ref_char_types[new_path.ref_idx] == 0  # NOTE: 0 indicates delimiter.
+            hyp_is_delimiter = self.src._hyp_char_types[new_path.hyp_idx] == 0  # NOTE: 0 indicates delimiter.
+            if ref_is_delimiter or hyp_is_delimiter:
+                return None
+
+        # Check for end-of-segment criteria.
+        if self.src._ref[new_path.ref_idx] == START_DELIMITER:
+            new_path._end_insertion_segment(self.index)
+
+        # Update costs, if not a match.
+        if not is_match:
+            is_backtrace = self._in_backtrace_node_set(self.index)
+            is_letter_type_match = (
+                self.src._ref_char_types[new_path.ref_idx] == self.src._hyp_char_types[new_path.hyp_idx]
+            )
+            new_path._open_cost += 2 if is_letter_type_match else 3
+            new_path._open_cost += 0 if is_backtrace else 1
+
+        # Check for end-of-segment criteria.
+        if self.src._ref[new_path.ref_idx] == END_DELIMITER:
+            new_path = new_path._end_segment()
+
+        return new_path
+
+    def _translate_slice(self, segment_slice: slice, index_map: list[int]) -> None | slice:
+        """Translate a slice from the alignment sequence back to the original sequence."""
+        slice_indices = index_map[segment_slice]
+        slice_indices = list(filter(lambda x: x >= 0, slice_indices))
+        if len(slice_indices) == 0:
+            return None
+        start, end = int(slice_indices[0]), int(slice_indices[-1] + 1)
+        return slice(start, end)
+
+    def _substitution_penalty(self, index: tuple[int, int] | None = None) -> int:
+        """Get the substitution penalty given an index."""
+        index = index or self.index
+        ref_is_not_empty = index[1] > self._last_end_index[1]
+        hyp_is_not_empty = index[0] > self._last_end_index[0]
+        if ref_is_not_empty and hyp_is_not_empty:
+            return self._open_cost
+        return 0
diff --git a/src/error_align/func.py b/src/error_align/func.py
new file mode 100644
index 0000000..6ac8dfc
--- /dev/null
+++ b/src/error_align/func.py
@@ -0,0 +1,38 @@
+from error_align.error_align import ErrorAlign, Path
+from error_align.utils import Alignment, basic_normalizer, basic_tokenizer
+
+
+def error_align(
+    ref: str,
+    hyp: str,
+    tokenizer: callable = basic_tokenizer,
+    normalizer: callable = basic_normalizer,
+    beam_size: int = 100,
+    pbar: bool = False,
+    return_path: bool = False,
+) -> list[Alignment] | Path:
+    """Perform error alignment between two sequences.
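+
+    Example (illustrative):
+        >>> [a.op_type.name for a in error_align("this is a test", "this is the test")]
+        ['MATCH', 'MATCH', 'SUBSTITUTE', 'MATCH']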
+
+    Args:
+        ref (str): The reference sequence/transcript.
+        hyp (str): The hypothesis sequence/transcript.
+        tokenizer (callable): A function to tokenize the sequences. Must be regex-based and return Match objects.
+        normalizer (callable): A function to normalize the tokens. Defaults to basic_normalizer.
+        beam_size (int): The size of the beam for beam search. Defaults to 100.
+        pbar (bool): Whether to display a progress bar. Defaults to False.
+        return_path (bool): Whether to return the path object or just the alignments. Defaults to False.
+
+    Returns:
+        list[Alignment] | Path: A list of Alignment objects, or the best Path if return_path=True.
+
+    """
+    return ErrorAlign(
+        ref,
+        hyp,
+        tokenizer=tokenizer,
+        normalizer=normalizer,
+    ).align(
+        beam_size=beam_size,
+        pbar=pbar,
+        return_path=return_path,
+    )
diff --git a/src/error_align/utils.py b/src/error_align/utils.py
new file mode 100644
index 0000000..45830b1
--- /dev/null
+++ b/src/error_align/utils.py
@@ -0,0 +1,166 @@
+from dataclasses import dataclass
+from enum import IntEnum
+from itertools import chain, combinations
+
+import regex as re
+from unidecode import unidecode
+
+
+class OpType(IntEnum):
+    MATCH = 0
+    INSERT = 1
+    DELETE = 2
+    SUBSTITUTE = 3
+
+
+@dataclass
+class Alignment:
+    """Class representing an alignment operation with its type and aligned spans."""
+
+    op_type: OpType
+    ref_slice: slice | None = None
+    hyp_slice: slice | None = None
+    ref: str | None = None
+    hyp: str | None = None
+    left_compound: bool = False
+    right_compound: bool = False
+
+    @property
+    def hyp_with_compound_markers(self) -> str | None:
+        """Return the hypothesis with compound markers if applicable."""
+        if self.hyp is None:
+            return None
+        return f"{'-' if self.left_compound else ''}{self.hyp}{'-' if self.right_compound else ''}"
+
+    def __repr__(self) -> str:
+        lc = "-" if self.left_compound else ""
+        rc = "-" if self.right_compound else ""
+        if self.op_type == OpType.DELETE:
+            return f'Alignment({self.op_type.name}: "{self.ref}")'
+        if self.op_type == OpType.INSERT:
+            return f'Alignment({self.op_type.name}: "{self.hyp_with_compound_markers}")'
+        if self.op_type == OpType.SUBSTITUTE:
+            return f'Alignment({self.op_type.name}: "{self.ref}" -> {lc}"{self.hyp}"{rc})'
+        return f'Alignment({self.op_type.name}: "{self.ref}" == {lc}"{self.hyp}"{rc})'
+
+
+def op_type_powerset() -> chain:
+    """Generate all possible combinations of operation types, except the empty set.
+
+    Returns:
+        Generator: All possible combinations of operation types.
+
+    """
+    op_types = list(OpType)
+    op_combinations = [combinations(op_types, r) for r in range(1, len(op_types) + 1)]
+    return chain.from_iterable(op_combinations)
+
+
+START_DELIMITER = "<"
+END_DELIMITER = ">"
+DELIMITERS = {START_DELIMITER, END_DELIMITER}
+
+OP_TYPE_MAP = {op_type.value: op_type for op_type in OpType}
+OP_TYPE_COMBO_MAP = {i: op_types for i, op_types in enumerate(op_type_powerset())}
+OP_TYPE_COMBO_MAP_INV = {v: k for k, v in OP_TYPE_COMBO_MAP.items()}
+
+NUMERIC_TOKEN = r"\p{N}+([,.]\p{N}+)*(?=\s|$)"
+STANDARD_TOKEN = r"[\p{L}\p{N}]+(['][\p{L}\p{N}]+)*'?"
+
+
+def is_vowel(c: str) -> bool:
+    """Check if the normalized character is a vowel.
+
+    Args:
+        c (str): The character to check.
+
+    Returns:
+        bool: True if the character is a vowel, False otherwise.
+
+    """
+    assert len(c) == 1, "Input must be a single character."
+    return unidecode(c)[0] in "aeiouy"
+
+
+def is_consonant(c: str) -> bool:
+    """Check if the normalized character is a consonant.
+
+    Args:
+        c (str): The character to check.
+
+    Returns:
+        bool: True if the character is a consonant, False otherwise.
+
+    """
+    assert len(c) == 1, "Input must be a single character."
+    return unidecode(c)[0] in "bcdfghjklmnpqrstvwxyz"
+
+
+def categorize_char(c: str) -> int:
+    """Categorize a character as delimiter, consonant, vowel, or unvoiced.
+
+    Args:
+        c (str): The character to categorize.
+
+    Returns:
+        int: The category of the character (0 = delimiter, 1 = consonant, 2 = vowel, 3 = unvoiced).
+
+    """
+    if c in DELIMITERS:
+        return 0
+    if is_consonant(c):
+        return 1
+    if is_vowel(c):
+        return 2
+    return 3  # NOTE: Unvoiced characters (only apostrophes are expected by default).
+
+
+def get_manhattan_distance(a: tuple[int, int], b: tuple[int, int]) -> int:
+    """Calculate the Manhattan distance between two points a and b."""
+    return abs(a[0] - b[0]) + abs(a[1] - b[1])
+
+
+def basic_tokenizer(text: str) -> list:
+    """Default tokenizer that extracts numeric and word tokens with a regex.
+
+    Args:
+        text (str): The input text to tokenize.
+
+    Returns:
+        list: A list of regex Match objects, one per token.
+
+    """
+    return list(re.finditer(rf"({NUMERIC_TOKEN})|({STANDARD_TOKEN})", text, re.UNICODE | re.VERBOSE))
+
+
+def basic_normalizer(text: str) -> str:
+    """Default normalizer that only converts text to lowercase.
+
+    Args:
+        text (str): The input text to normalize.
+
+    Returns:
+        str: The normalized text.
+
+    """
+    return text.lower()
+
+
+def ensure_length_preservation(normalizer: callable) -> callable:
+    """Decorator to ensure that the normalizer preserves the length of the input text.
+
+    Args:
+        normalizer (callable): The normalizer function to wrap.
+
+    Returns:
+        callable: The wrapped normalizer function that preserves length.
+
+    """
+
+    def wrapper(text: str, *args: list, **kwargs: dict) -> str:
+        normalized = normalizer(text, *args, **kwargs)
+        if len(normalized) != len(text):
+            raise ValueError("Normalizer must preserve length.")
+        return normalized
+
+    return wrapper
diff --git a/src/python_package_template/__init__.py b/src/python_package_template/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/tests/test_default.py b/tests/test_default.py
index 2b57b30..b0728d2 100644
--- a/tests/test_default.py
+++ b/tests/test_default.py
@@ -1,4 +1,129 @@
-def test_stub() -> None:
-    """Lorem Ipsum."""
+from typeguard import suppress_type_checks
 
-    assert True
+from error_align import ErrorAlign, error_align
+from error_align.edit_distance import compute_levenshtein_distance_matrix
+from error_align.utils import OpType, categorize_char
+
+
+def test_error_align() -> None:
+    """Test error alignment for an example including all operation types."""
+
+    ref = "This is a substitution test deleted."
+    hyp = "Inserted this is a contribution test."
+
+    alignments = error_align(ref, hyp, pbar=True)
+    expected_ops = [
+        OpType.INSERT,  # Inserted
+        OpType.MATCH,  # This
+        OpType.MATCH,  # is
+        OpType.MATCH,  # a
+        OpType.SUBSTITUTE,  # contribution -> substitution
+        OpType.MATCH,  # test
+        OpType.DELETE,  # deleted
+    ]
+
+    for op, alignment in zip(expected_ops, alignments, strict=True):
+        assert alignment.op_type == op
+
+
+def test_error_align_full_match() -> None:
+    """Test error alignment for a full match."""
+
+    ref = "This is a test."
+    hyp = "This is a test."
+
+    alignments = error_align(ref, hyp)
+
+    for alignment in alignments:
+        assert alignment.op_type == OpType.MATCH
+
+
+def test_categorize_char() -> None:
+    """Test character categorization."""
+
+    assert categorize_char("<") == 0  # Delimiters
+    assert categorize_char("b") == 1  # Consonants
+    assert categorize_char("a") == 2  # Vowels
+    assert categorize_char("'") == 3  # Unvoiced characters
+
+
+def test_representations() -> None:
+    """Test the string representation of Alignment objects."""
+
+    # Test DELETE operation
+    delete_alignment = error_align("deleted", "")[0]
+    assert repr(delete_alignment) == 'Alignment(DELETE: "deleted")'
+
+    # Test INSERT operation without compound markers
+    insert_alignment = error_align("", "inserted")[0]
+    assert repr(insert_alignment) == 'Alignment(INSERT: "inserted")'
+
+    # Test SUBSTITUTE operation with compound markers
+    substitute_alignment = error_align("substitution", "substitutiontesting")[0]
+    assert substitute_alignment.left_compound is False
+    assert substitute_alignment.right_compound is True
+    assert repr(substitute_alignment) == 'Alignment(SUBSTITUTE: "substitution" -> "substitution"-)'
+
+    # Test MATCH operation without compound markers
+    match_alignment = error_align("test", "test")[0]
+    assert repr(match_alignment) == 'Alignment(MATCH: "test" == "test")'
+
+    # Test ErrorAlign class representation
+    ea = ErrorAlign(ref="test", hyp="pest")
+    assert repr(ea) == 'ErrorAlign(ref="test", hyp="pest")'
+
+    # Test Path class representation
+    path = ea.align(beam_size=10, return_path=True)
+    assert repr(path) == f"Path(({path.ref_idx}, {path.hyp_idx}), score={path.cost})"
+
+
+@suppress_type_checks
+def test_input_type_checks() -> None:
+    """Test input type checks for the ErrorAlign class."""
+
+    try:
+        _ = ErrorAlign(ref=123, hyp="valid")  # type: ignore
+    except TypeError as e:
+        assert str(e) == "Reference sequence must be a string."
+    else:
+        raise AssertionError("Expected a TypeError for a non-string reference.")
+
+    try:
+        _ = ErrorAlign(ref="valid", hyp=456)  # type: ignore
+    except TypeError as e:
+        assert str(e) == "Hypothesis sequence must be a string."
+    else:
+        raise AssertionError("Expected a TypeError for a non-string hypothesis.")
+
+
+def test_backtrace_graph() -> None:
+    """Test backtrace graph generation."""
+
+    ref = "This is a test."
+    hyp = "This is a pest."
+
+    # Create ErrorAlign instance and generate backtrace graph.
+    ea = ErrorAlign(ref, hyp)
+    ea.align(beam_size=10)
+    graph = ea._backtrace_graph
+
+    # Check basic properties of the graph.
+    assert isinstance(graph.get_path(), list)
+    assert isinstance(graph.get_path(sample=True), list)
+    assert graph.number_of_paths == 3
+    for index in graph._iter_topological_order():
+        assert isinstance(index, tuple)
+
+    # Check specific node properties.
+    node = graph.get_node(2, 2)
+    assert node.number_of_ingoing_paths_via(OpType.MATCH) == 3
+    assert node.number_of_outgoing_paths_via(OpType.MATCH) == 3
+    assert node.number_of_ingoing_paths_via(OpType.INSERT) == 0
+    assert node.number_of_outgoing_paths_via(OpType.INSERT) == 0
+
+
+def test_levenshtein_distance_matrix() -> None:
+    """Test Levenshtein distance matrix computation."""
+
+    ref = "kitten"
+    hyp = "sitting"
+
+    distance_matrix = compute_levenshtein_distance_matrix(ref, hyp)
+
+    assert distance_matrix[-1][-1] == 3  # The Levenshtein distance between "kitten" and "sitting" is 3
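+
+
+def test_error_align_distance_matrix() -> None:
+    """Sanity-check the error-align distance matrix on identical inputs (illustrative sketch)."""
+
+    from error_align.edit_distance import compute_error_align_distance_matrix
+
+    # Identical sequences align along the diagonal at zero cost, since matching
+    # characters have a diagonal cost of 0 in _get_error_align_values.
+    distance_matrix = compute_error_align_distance_matrix("test", "test")
+
+    assert distance_matrix[-1][-1] == 0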