diff --git a/.coveragerc b/.coveragerc
index 69e15dd..3f6a94f 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -1,6 +1,6 @@
[run]
-source = src
-omit = ./tests/*
-concurrency = multiprocessing
-parallel = true
-sigterm = true
+branch = True
+source = error_align
+
+[report]
+omit = src/error_align/baselines/*
\ No newline at end of file
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 458d12a..19588ac 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -1,5 +1 @@
-* @corticph/machine-learning-r-d
-
-# No owners of dependencies to allow @github-actions[bot] to approve and auto-merge dependabot PRs.
-pyproject.toml
-poetry.lock
+* @borgholt
diff --git a/.github/assets/logo.svg b/.github/assets/logo.svg
new file mode 100644
index 0000000..381357e
--- /dev/null
+++ b/.github/assets/logo.svg
@@ -0,0 +1,57 @@
+
diff --git a/.github/dependabot.yml b/.github/dependabot.yml
deleted file mode 100644
index 5300f8d..0000000
--- a/.github/dependabot.yml
+++ /dev/null
@@ -1,35 +0,0 @@
-version: 2
-
-registries:
- corti-github:
- type: git
- url: https://github.com
- username: x-access-token
- password: ${{ secrets.DEPENDABOT_TOKEN }}
-
-updates:
- # Maintain dependencies for GitHub Actions
- - package-ecosystem: "github-actions"
- directory: "/"
- labels:
- - "dependabot"
- schedule:
- interval: "weekly"
-
- # Maintain dependencies for Python
- - package-ecosystem: "pip"
- open-pull-requests-limit: 1
- registries: "*" # Use all, i.e. the ecosystem default and corti-github, registries
- directory: "/"
- labels:
- - "dependabot"
- schedule:
- interval: "weekly"
- allow:
- - dependency-type: "all" # Allow both direct and indirect updates for all packages
- commit-message:
- prefix: "pip prod"
- prefix-development: "pip dev"
- include: "scope"
- # https://docs.github.com/en/code-security/supply-chain-security/keeping-your-dependencies-updated-automatically/configuration-options-for-dependency-updates#insecure-external-code-execution
- insecure-external-code-execution: allow
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index 319b529..6f14ea5 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -5,18 +5,6 @@ Do not leave this blank.
This PR [adds/removes/fixes/replaces] the [feature/bug/etc] because [reason] by doing [x].
-->
-## Related Issue(s)
-
-
-
-Related to RD-XXXX
-
-Closes RD-XXXX
-
10 to the power 60)
+ # tested with Python24 vegaseat 07dec2006
+ @staticmethod
+ def int2word(n):
+ """
+ convert an integer number n into a string of english words
+ """
+ # break the number into groups of 3 digits using slicing
+ # each group representing hundred, thousand, million, billion, ...
+ n3 = []
+
+ # create numeric string
+ ns = str(n)
+ for k in range(3, 33, 3):
+ r = ns[-k:]
+ q = len(ns) - k
+ # break if end of ns has been reached
+ if q < -2:
+ break
+ else:
+ if q >= 0:
+ n3.append(int(r[:3]))
+ elif q >= -1:
+ n3.append(int(r[:2]))
+ elif q >= -2:
+ n3.append(int(r[:1]))
+
+
+ #print n3 # test
+
+ # break each group of 3 digits into
+ # ones, tens/twenties, hundreds
+ # and form a string
+ nw = ""
+ for i, x in enumerate(n3):
+ b1 = x % 10
+ b2 = (x % 100)//10
+ b3 = (x % 1000)//100
+ #print b1, b2, b3 # test
+ if x == 0:
+ continue # skip
+ else:
+ t = TextToNumEng.thousands[i]
+ if b2 == 0:
+ nw = TextToNumEng.ones[b1] + t + nw
+ elif b2 == 1:
+ nw = TextToNumEng.tens[b1] + t + nw
+ elif b2 > 1:
+ nw = TextToNumEng.twenties[b2] + TextToNumEng.ones[b1] + t + nw
+ if b3 > 0:
+ nw = TextToNumEng.ones[b3] + "hundred " + nw
+ return nw
+
+class NumToTextEng(object):
+ units = ["", "one", "two", "three", "four", "five",
+ "six", "seven", "eight", "nine"]
+ teens = ["", "eleven", "twelve", "thirteen", "fourteen",
+ "fifteen", "sixteen", "seventeen", "eighteen", "nineteen"]
+ tens = ["", "ten", "twenty", "thirty", "forty",
+ "fifty", "sixty", "seventy", "eighty", "ninety"]
+ thousands = ["","thousand", "million", "billion", "trillion",
+ "quadrillion", "quintillion", "sextillion", "septillion", "octillion",
+ "nonillion", "decillion", "undecillion", "duodecillion", "tredecillion",
+ "quattuordecillion", "quindecillion", "sexdecillion", "septendecillion", "octodecillion",
+ "novemdecillion", "vigintillion"]
+
+ @staticmethod
+ def convertTryYear(num):
+ first = None
+ second = None
+
+ num_str = str(num)
+ if len(num_str) > 4 or len(num_str) == 0:
+            raise ValueError("Invalid date: {0}".format(num))
+ elif len(num_str) >= 3:
+ # Check if it could be a year
+ if len(num_str) == 3:
+ first = int(num_str[0])
+ second = int(num_str[1:len(num_str)])
+ elif len(num_str) == 4:
+ first = int(num_str[0:2])
+                second = int(num_str[2:len(num_str)])
+
+ if not (first and second):
+            raise ValueError("Invalid date: {0:d}".format(num))
+
+ return "{0} {1}".format(NumToTextEng.convert(first), NumToTextEng.convert(second))
+ else:
+ return NumToTextEng.convert(num)
+
+
+ @staticmethod
+ def convert(num):
+ if not isinstance(num, int):
+ raise TypeError
+
+ words = []
+ if num == 0:
+ words.append("zero")
+ else:
+ numStr = "%d" % num
+ numStrLen = len(numStr)
+ groups = (numStrLen + 2) // 3
+ numStr = numStr.zfill((groups) * 3)
+ for i in range(0, groups*3, 3):
+ h = int(numStr[i])
+ t = int(numStr[i+1])
+ u = int(numStr[i+2])
+ g = groups - (i // 3 + 1)
+
+ if h >= 1:
+ words.append(NumToTextEng.units[h])
+ words.append("hundred")
+
+ if t > 1:
+ words.append(NumToTextEng.tens[t])
+ if u >= 1:
+ words.append(NumToTextEng.units[u])
+ elif t == 1:
+ if u >= 1:
+ words.append(NumToTextEng.teens[u])
+ else:
+ words.append(NumToTextEng.tens[t])
+ else:
+ if u >= 1:
+ words.append(NumToTextEng.units[u])
+
+ if g >= 1 and (h + t + u) > 0:
+ words.append(NumToTextEng.thousands[g])
+ return ' '.join(words)
\ No newline at end of file
diff --git a/src/error_align/baselines/power/power/__init__.py b/src/error_align/baselines/power/power/__init__.py
new file mode 100644
index 0000000..4c4a18b
--- /dev/null
+++ b/src/error_align/baselines/power/power/__init__.py
@@ -0,0 +1,3 @@
+# from align_labeler import AlignLabeler
+from error_align.baselines.power.power.levenshtein import ExpandedAlignment, Levenshtein
+from error_align.baselines.power.power.phonemes import Phonemes
diff --git a/src/error_align/baselines/power/power/aligner.py b/src/error_align/baselines/power/power/aligner.py
new file mode 100644
index 0000000..22b239b
--- /dev/null
+++ b/src/error_align/baselines/power/power/aligner.py
@@ -0,0 +1,467 @@
+from __future__ import division
+
+from error_align.baselines.power.power.levenshtein import AlignLabels, ExpandedAlignment, Levenshtein
+from error_align.baselines.power.power.phonemes import Phonemes
+from error_align.baselines.power.power.pronounce import PronouncerBase, PronouncerLex, PronouncerType
+
+
+class TokType:
+ WordBoundary = 1
+ SyllableBoundary = 2
+ Phoneme = 3
+ Empty = 0
+
+ @staticmethod
+ def checkAnnotation(tok):
+ tt = TokType.Phoneme
+ if not tok:
+ tt = TokType.Empty
+ elif tok == '|':
+ tt = TokType.WordBoundary
+ elif tok == '#':
+ tt = TokType.SyllableBoundary
+ return tt
+
+class CharToWordAligner:
+ def __init__(self, ref, hyp, lowercase=False):
+ self.refwords = ref.strip()
+ self.hypwords = hyp.strip()
+ self.ref = ref
+ self.hyp = hyp
+ self.lowercase = lowercase
+ self.char_align = None
+ self.word_align = None
+
+ def charAlign(self):
+ ref_chars = [x for x in self.ref] + [' ']
+ hyp_chars = [x for x in self.hyp] + [' ']
+
+ lev = Levenshtein.align(ref_chars, hyp_chars, lowercase=self.lowercase, reserve_list=set([' ']))
+ lev.editops()
+ self.char_align = lev.expandAlign()
+ return self.char_align
+
+ def charAlignToWordAlign(self):
+ if not self.char_align:
+ raise Exception("char_align is None")
+
+ ref_word_align = []
+ hyp_word_align = []
+ align_word = []
+
+ tmp_ref_word = []
+ tmp_hyp_word = []
+
+ for i in range(len(self.char_align.align)):
+ ref_char = self.char_align.s1[i]
+ hyp_char = self.char_align.s2[i]
+ align_char = self.char_align.align[i]
+
+ # check if both words are completed
+ # There are a few of ways this could happen:
+ if ((align_char == AlignLabels.correct and ref_char == ' ') or
+ (align_char == AlignLabels.deletion and ref_char == ' ') or
+ (align_char == AlignLabels.insertion and hyp_char == ' ')):
+
+ ref_word = ''.join(tmp_ref_word)
+ hyp_word = ''.join(tmp_hyp_word)
+
+ if ref_word or hyp_word:
+ ref_word_align.append(ref_word)
+ hyp_word_align.append(hyp_word)
+ tmp_ref_word = []
+ tmp_hyp_word = []
+
+ # Check align type
+ if ref_word and hyp_word:
+ if ref_word == hyp_word:
+ align_word.append(AlignLabels.correct)
+ else:
+ align_word.append(AlignLabels.substitution)
+ elif ref_word:
+ align_word.append(AlignLabels.deletion)
+ else:
+ align_word.append(AlignLabels.insertion)
+ continue
+
+ # Read current chars and check if one of the words is complete
+ if ref_char == ' ':
+ if len(tmp_ref_word) > 1:
+ # Probably a D
+ ref_word = ''.join(tmp_ref_word)
+ ref_word_align.append(ref_word)
+ hyp_word_align.append('')
+ tmp_ref_word = []
+ align_word.append(AlignLabels.deletion)
+ else:
+ tmp_ref_word.append(ref_char)
+
+ if hyp_char == ' ':
+ if len(tmp_hyp_word) > 1:
+ # Probably an I
+ hyp_word = ''.join(tmp_hyp_word)
+ ref_word_align.append('')
+ hyp_word_align.append(hyp_word)
+ tmp_hyp_word = []
+ align_word.append(AlignLabels.insertion)
+ else:
+ tmp_hyp_word.append(hyp_char)
+
+ self.word_align = ExpandedAlignment(ref_word_align, hyp_word_align, align_word, lowercase=self.lowercase)
+ return self.word_align
+
+class PowerAligner:
+ # Exclusive tokens that can only align to themselves; not other members in this set.
+ reserve_list = set(['|', '#'])
+
+ # R-sounds
+ r_set = set.union(set('r'), Phonemes.r_vowels)
+ exclusive_sets = [Phonemes.vowels, Phonemes.consonants, r_set]
+
+ phoneDistPenalty = 0.25
+    phoneDistPenaltySet = set(['|'])
+
+ def __init__(self, ref, hyp, lowercase=False, verbose=False,
+ pronounce_type=PronouncerType.Lexicon,
+ lexicon=None,
+ word_align_weights=Levenshtein.wordAlignWeights):
+ if not ref:
+ raise Exception("No reference file.\nref: {0}\nhyp: {1}".format(ref, hyp))
+
+ if pronounce_type == PronouncerType.Lexicon:
+ self.pronouncer = PronouncerLex(lexicon)
+ else:
+ self.pronouncer = PronouncerBase()
+
+ self.ref = [x for x in ref.strip().split() if x]
+ self.hyp = [x for x in hyp.strip().split() if x]
+ self.refwords = ' '.join(self.ref)
+ self.hypwords = ' '.join(self.hyp)
+
+ self.lowercase = lowercase
+ self.verbose = verbose
+
+ # Perform word alignment
+ lev = Levenshtein.align(self.ref, self.hyp, lowercase=self.lowercase, weights=word_align_weights)
+ lev.editops()
+ self.wer_alignment = lev.expandAlign()
+ self.wer, self.wer_components = self.wer_alignment.error_rate()
+
+ # Used for POWER alignment
+ self.power_alignment = None
+ self.power = None
+ self.power_components = None
+
+ # Used to find potential error regions
+ self.split_regions = None
+ self.error_indexes = None
+ self.phonetic_alignments = None
+ self.phonetic_lev = None
+
+ def align(self):
+ # Find the error regions that may need to be realigned
+ self.split_regions, self.error_indexes = self.wer_alignment.split_error_regions()
+ self.phonetic_alignments = [None] * len(self.split_regions)
+
+ for error_index in self.error_indexes:
+ seg = self.split_regions[error_index]
+ ref_words = seg.s1_tokens()
+ hyp_words = seg.s2_tokens()
+ ref_phones = self.pronouncer.pronounce(ref_words)
+ hyp_phones = self.pronouncer.pronounce(hyp_words)
+
+ power_seg_alignment, self.phonetic_alignments[error_index] = PowerAligner.phoneAlignToWordAlign(ref_words, hyp_words,
+ ref_phones, hyp_phones)
+
+ # Replace the error region at the current index.
+ self.split_regions[error_index] = power_seg_alignment
+
+ # Merge the alignment segments back together.
+ self.power_alignment = ExpandedAlignment(self.split_regions[0].s1, self.split_regions[0].s2,
+ self.split_regions[0].align,
+ self.split_regions[0].s1_map, self.split_regions[0].s2_map, lowercase=self.lowercase)
+ for i in range(1, len(self.split_regions)):
+ self.power_alignment.append_alignment(self.split_regions[i])
+
+ # Get the alignment score
+ self.power, self.power_components = self.power_alignment.error_rate()
+
+ assert self.hypwords == self.power_alignment.s2_string(), "hyp mismatch:\n{0}\n{1}".format(self.hypwords, self.power_alignment.s2_string())
+ assert self.refwords == self.power_alignment.s1_string(), "ref mismatch:\n{0}\n{1}".format(self.refwords, self.power_alignment.s1_string())
+
+ # TODO: Make this simpler (and maybe recursive)
+ @classmethod
+ def phoneAlignToWordAlign(cls, ref_words, hyp_words, ref_phones, hyp_phones, break_on_syllables=True):
+ ref_word_span = (0, len(ref_words))
+ hyp_word_span = (0, len(hyp_words))
+
+ # Perform Levenshtein Alignment
+ lev = Levenshtein.align(ref=ref_phones,
+ hyp=hyp_phones,
+ reserve_list=PowerAligner.reserve_list,
+ exclusive_sets=PowerAligner.exclusive_sets,
+ weights=Levenshtein.wordAlignWeights) #,
+ #dist_penalty=PowerAligner.phoneDistPenalty, dist_penalty_set=Levenshtein.wordAlignWeights)
+ phone_align = lev.expandAlignCompact()
+
+ worklist = list()
+ worklist.append((ref_word_span, hyp_word_span, phone_align))
+
+ full_reference = list()
+ full_hypothesis = list()
+ full_alignment = list()
+ full_phone_align = list()
+
+ while worklist:
+ # Take the next set of sequence boundaries off the worklist
+ ref_word_span, hyp_word_span, phone_align = worklist.pop()
+ ref_word_index, ref_word_limit = ref_word_span
+ hyp_word_index, hyp_word_limit = hyp_word_span
+
+ # TODO: Currently only checking in the forward direction
+ ref_word_builder = [] # Temp storage of words in alignment span
+ hyp_word_builder = []
+
+ ref_word_iter = enumerate(ref_words[ref_word_span[0]:ref_word_span[1]]) # Iterates through the surface words
+ hyp_word_iter = enumerate(hyp_words[hyp_word_span[0]:hyp_word_span[1]])
+
+ ref_aligned = [] # Finalized alignments
+ hyp_aligned = []
+ alignment = [] # Finalized alignment labels
+
+ ref_extra_syllable_word_index = None # Used for marking words mapping to extra syllables in alignment.
+ hyp_extra_syllable_word_index = None
+ ref_syllable_count = 0
+ hyp_syllable_count = 0
+
+ ref_word_started = False # Indicates whether a word is already accounted for in the alignment when a phoneme is reached.
+ hyp_word_started = False
+
+ advance_worklist = False
+ commit_alignment = False
+
+ for i in range(len(phone_align.align)):
+ ref_type = TokType.checkAnnotation(phone_align.s1[i])
+ hyp_type = TokType.checkAnnotation(phone_align.s2[i])
+
+            # Check if word boundaries are reached, both on ref and hyp -- or the case where no more symbols can be read.
+ if (i == len(phone_align.align) - 1) or (ref_type == TokType.WordBoundary and ref_type == hyp_type):
+ align_tok = None
+ # Only write outputs if either the ref or the hyp has scanned some words.
+ if ref_word_builder:
+ if hyp_word_builder:
+ align_tok = AlignLabels.substitution if ref_word_builder != hyp_word_builder else AlignLabels.correct
+ else:
+ align_tok = AlignLabels.deletion
+ elif hyp_word_builder:
+ align_tok = AlignLabels.insertion
+
+ if align_tok:
+ # Add the remainder to the worklist
+ ref_word_span_next = (ref_word_index + len(ref_word_builder), ref_word_limit)
+ hyp_word_span_next = (hyp_word_index + len(hyp_word_builder), hyp_word_limit)
+ phone_align_next = phone_align.subsequence(i, phone_align.length(), preserve_index=False)
+ worklist.append((ref_word_span_next, hyp_word_span_next, phone_align_next))
+
+ # "Commit" the current alignment
+ if align_tok in (AlignLabels.correct, AlignLabels.substitution):
+ alignment.append(align_tok)
+
+ # Check for syllable conflicts
+ if not break_on_syllables or not ref_extra_syllable_word_index:
+ ref_aligned.append(' '.join(ref_word_builder))
+ ref_syllable_count = 0
+ hyp_syllable_count = 0
+ else:
+ ref_aligned.append(' '.join(ref_word_builder[0:ref_extra_syllable_word_index]))
+ # The remaining words are deletions
+ for word in ref_word_builder[ref_extra_syllable_word_index:]:
+ alignment.append(AlignLabels.deletion)
+ ref_aligned.append(word)
+ hyp_aligned.append('')
+ ref_syllable_count = 0
+
+ if not break_on_syllables or not hyp_extra_syllable_word_index:
+ hyp_aligned.append(' '.join(hyp_word_builder))
+ ref_syllable_count = 0
+ hyp_syllable_count = 0
+ else:
+ hyp_aligned.append(' '.join(hyp_word_builder[0:hyp_extra_syllable_word_index]))
+ # The remaining words are insertions
+ for word in hyp_word_builder[hyp_extra_syllable_word_index:]:
+ alignment.append(AlignLabels.insertion)
+ ref_aligned.append('')
+ hyp_aligned.append(word)
+ hyp_syllable_count = 0
+
+ if align_tok == AlignLabels.substitution:
+ # Check if you need to rework this alignment.
+ if len(ref_word_builder) != len(hyp_word_builder):
+ # Word count mismatch in the alignment span. Is there a possibility that we need to re-align this segment?
+ ref_word_span_curr = (ref_word_index, ref_word_index + len(ref_word_builder))
+ hyp_word_span_curr = (hyp_word_index, hyp_word_index + len(hyp_word_builder))
+ phone_align_curr = phone_align.subsequence(0, i+1, preserve_index=False)
+
+ lev = Levenshtein.align(
+ ref=phone_align_curr.s1_tokens(),
+ hyp=phone_align_curr.s2_tokens(),
+ reserve_list=PowerAligner.reserve_list,
+ exclusive_sets=PowerAligner.exclusive_sets,
+ weights=Levenshtein.wordAlignWeights) #,
+ #dist_penalty=PowerAligner.phoneDistPenalty, dist_penalty_set=Levenshtein.wordAlignWeights)
+
+ phone_align_adjusted = lev.expandAlignCompact()
+
+ if phone_align_curr.align != phone_align_adjusted.align:
+ # Looks like we need to redo the phone-to-word alignment.
+ worklist.append((ref_word_span_curr, hyp_word_span_curr, phone_align_adjusted))
+ else:
+ commit_alignment = True
+ else:
+ commit_alignment = True
+
+ elif align_tok == AlignLabels.deletion:
+ for word in ref_word_builder:
+ alignment.append(align_tok)
+ ref_aligned.append(word)
+ hyp_aligned.append('')
+
+ commit_alignment = True
+ ref_syllable_count = 0
+
+ elif align_tok == AlignLabels.insertion:
+ for word in hyp_word_builder:
+ alignment.append(align_tok)
+ ref_aligned.append('')
+ hyp_aligned.append(word)
+
+ commit_alignment = True
+ hyp_syllable_count = 0
+
+ if commit_alignment:
+ # Commit the alignment.
+ full_reference.extend(ref_aligned)
+ full_hypothesis.extend(hyp_aligned)
+ full_alignment.extend(alignment)
+ full_phone_align.append(phone_align.subsequence(0, i, preserve_index=False))
+ ref_aligned = []
+ hyp_aligned = []
+ alignment = []
+ break
+
+ # Add words if word boundaries are reached.
+ else:
+ if ref_type == TokType.WordBoundary:
+ ref_word_started = False
+ if hyp_type != TokType.WordBoundary and ref_word_builder and not hyp_word_builder:
+ # DELETION
+ # Ref word ended, but no hyp words have been added. Mark the current ref word(s) in the span as deletion errors.
+ # TODO: Dedupe this logic
+ for word in ref_word_builder:
+ alignment.append(AlignLabels.deletion)
+ ref_aligned.append(word)
+ hyp_aligned.append('')
+ ref_syllable_count = 0
+
+ # Commit the alignment.
+ full_reference.extend(ref_aligned)
+ full_hypothesis.extend(hyp_aligned)
+ full_alignment.extend(alignment)
+ full_phone_align.append(phone_align.subsequence(0, i, preserve_index=False))
+
+ # Add the remainder to the worklist
+ ref_word_span_next = (ref_word_index + len(ref_word_builder), ref_word_limit)
+ hyp_word_span_next = (hyp_word_index + len(hyp_word_builder), hyp_word_limit)
+ lev = Levenshtein.align(
+ ref=[x for x in phone_align.s1[i:] if x],
+ hyp=[x for x in phone_align.s2 if x],
+ reserve_list=PowerAligner.reserve_list,
+ exclusive_sets=PowerAligner.exclusive_sets,
+ weights=Levenshtein.wordAlignWeights) #,
+ #dist_penalty=PowerAligner.phoneDistPenalty, dist_penalty_set=Levenshtein.wordAlignWeights)
+ phone_align_next = lev.expandAlignCompact()
+
+ worklist.append((ref_word_span_next, hyp_word_span_next, phone_align_next))
+ break
+ elif ref_type == TokType.Phoneme and not ref_word_started:
+ ref_word_started = True
+ try:
+ ref_word_item = ref_word_iter.__next__()
+ ref_word_builder.append(ref_word_item[1])
+ except StopIteration:
+ pass
+
+ if hyp_type == TokType.WordBoundary:
+ hyp_word_started = False
+ if ref_type != TokType.WordBoundary and hyp_word_builder and not ref_word_builder:
+ # INSERTION
+ # Hyp word ended, but no ref words have been added. Mark the current hyp word(s) in the span as insertion errors.
+ # TODO: Dedupe this logic
+ for word in hyp_word_builder:
+ alignment.append(AlignLabels.insertion)
+ ref_aligned.append('')
+ hyp_aligned.append(word)
+ hyp_syllable_count = 0
+
+ # Commit the alignment.
+ full_reference.extend(ref_aligned)
+ full_hypothesis.extend(hyp_aligned)
+ full_alignment.extend(alignment)
+ full_phone_align.append(phone_align.subsequence(0, i, preserve_index=False))
+
+ # Add the remainder to the worklist
+ ref_word_span_next = (ref_word_index + len(ref_word_builder), ref_word_limit)
+ hyp_word_span_next = (hyp_word_index + len(hyp_word_builder), hyp_word_limit)
+ lev = Levenshtein.align(
+ ref=[x for x in phone_align.s1 if x],
+ hyp=[x for x in phone_align.s2[i:] if x],
+ reserve_list=PowerAligner.reserve_list,
+ exclusive_sets=PowerAligner.exclusive_sets,
+ weights=Levenshtein.wordAlignWeights) #,
+ #dist_penalty=PowerAligner.phoneDistPenalty, dist_penalty_set=Levenshtein.wordAlignWeights)
+ phone_align_next = lev.expandAlignCompact()
+
+ worklist.append((ref_word_span_next, hyp_word_span_next, phone_align_next))
+ break
+ elif hyp_type == TokType.Phoneme and not hyp_word_started:
+ hyp_word_started = True
+ try:
+ hyp_word_item = hyp_word_iter.__next__()
+ hyp_word_builder.append(hyp_word_item[1])
+ except StopIteration:
+ pass
+
+ # Check for syllable mismatches
+ if ref_type == TokType.SyllableBoundary:
+ ref_syllable_count += 1
+ if hyp_type == TokType.SyllableBoundary:
+ hyp_syllable_count += 1
+
+ if (ref_type == TokType.SyllableBoundary == hyp_type or ref_syllable_count == hyp_syllable_count):
+ # No syllable conflicts here!
+ ref_extra_syllable_word_index = None
+ hyp_extra_syllable_word_index = None
+ elif (ref_type == TokType.SyllableBoundary and
+ not ref_extra_syllable_word_index and
+ TokType.checkAnnotation(phone_align.s2[i - 1]) == TokType.WordBoundary):
+ # Extra syllable in hypothesis. We only care if the syllable immediately follows a word boundary.
+ # This is because this indicates that a new word is being formed, which may likely be an insertion in hyp.
+ ref_extra_syllable_word_index = len(ref_word_builder) - 1
+ # print ref_word_builder
+ # print 'Syllable/word mismatch at', i
+ # print 'Extra hyp word:', ref_word_builder[ref_extra_syllable_word_index]
+ elif (hyp_type == TokType.SyllableBoundary and
+ not hyp_extra_syllable_word_index and
+ TokType.checkAnnotation(phone_align.s2[i - 1]) == TokType.WordBoundary):
+ # This time there's an extra syllable in the ref, corresponding to a new ref word.
+ hyp_extra_syllable_word_index = len(hyp_word_builder) - 1
+ # print hyp_word_builder
+ # print 'Syllable/word mismatch at', i
+ # print 'Extra ref word:', hyp_word_builder[hyp_extra_syllable_word_index]
+ # Concatenate all phoneme alignments
+ fp_align = full_phone_align[0]
+ for expand_align in full_phone_align[1:]:
+ fp_align.append_alignment(expand_align)
+
+ return ExpandedAlignment(full_reference, full_hypothesis, full_alignment), fp_align
+
diff --git a/src/error_align/baselines/power/power/levenshtein.py b/src/error_align/baselines/power/power/levenshtein.py
new file mode 100644
index 0000000..a648ec4
--- /dev/null
+++ b/src/error_align/baselines/power/power/levenshtein.py
@@ -0,0 +1,576 @@
+from __future__ import division
+
+import itertools
+import re
+from collections import Counter, defaultdict, deque
+
+
+class AlignLabels:
+ correct = 'C'
+ substitution = 'S'
+ insertion = 'I'
+ deletion = 'D'
+ validOptions = set([correct, substitution, insertion, deletion])
+
+
+class ExpandedAlignment:
+ '''Levenshtein-aligned reference and hypothesis, not just edit distance score.'''
+
+ def __init__(self, s1, s2, align, s1_map=None, s2_map=None, lowercase=False):
+ if not (len(s1) == len(s2) == len(align)):
+ raise Exception("Length mismatch: align:{0:d}, s1:{1:d}, s2:{2:d}".format(
+ len(align), len(s1), len(s2)))
+ if len(align) == 0:
+ raise Exception("No alignment: strings are empty")
+
+ self.s1 = s1
+ self.s2 = s2
+ self.align = align
+ self.s1_map = s1_map
+ self.s2_map = s2_map
+ self.lowercase = lowercase
+
+        if not s1_map or not s2_map:
+ self.recompute_alignment_maps()
+
+ def __str__(self):
+ widths = [max(len(self.s1[i]), len(self.s2[i]))
+ for i in range(len(self.s1))]
+ s1_args = zip(widths, self.s1)
+ s2_args = zip(widths, self.s2)
+ align_args = zip(widths, self.align)
+
+ value = 'REF: %s\n' % ' '.join(['%-*s' % x for x in s1_args])
+ value += 'HYP: %s\n' % ' '.join(['%-*s' % x for x in s2_args])
+ value += 'Eval: %s' % ' '.join(['%-*s' % x for x in align_args])
+
+ return value
+
+ def s1_string(self):
+ return ' '.join(self.s1_tokens())
+
+ def s2_string(self):
+ return ' '.join(self.s2_tokens())
+
+ def s1_tokens(self):
+ return [x for x in self.s1 if x != '']
+
+ def s2_tokens(self):
+ return [x for x in self.s2 if x != '']
+
+ def s1_align_tokens(self, i):
+ return self.s1[i].split()
+
+ def s2_align_tokens(self, i):
+ return self.s2[i].split()
+
+ def ref(self):
+ return self.s1_tokens()
+
+ def hyp(self):
+ return self.s2_tokens()
+
+ def length(self):
+ return len(self.align)
+
+ def pos(self, word):
+ s1_idx = [i for i in range(len(self.s1)) if self.s1[i] == word]
+ s2_idx = [i for i in range(len(self.s2)) if self.s2[i] == word]
+
+ return s1_idx, s2_idx
+
+ def subsequence(self, i, j, preserve_index=False):
+ # TODO: Right now we're losing any s1_map and s2_map components for compatibility reasons. Refactoring necessary.
+ scale = 0 if preserve_index else i
+ s1_map = [self.s1_map[k] -
+ scale for k in range(len(self.s1_map)) if i <= self.s1_map[k] < j]
+ s2_map = [self.s2_map[k] -
+ scale for k in range(len(self.s2_map)) if i <= self.s2_map[k] < j]
+ return ExpandedAlignment(self.s1[i:j], self.s2[i:j], self.align[i:j], s1_map, s2_map, lowercase=self.lowercase)
+
+ def split_error_regions(self, error_pattern='[SDI]*S[SDI]+|[SDI]+S[SDI]*'):
+ '''
+ Splits the object into a list of multiple segments.
+ Some segments are defined as error regions, containing at least one substitution error.
+        These error regions may be candidates for realignment to improve alignment precision for downstream tasks.
+ (i.e. phoneme alignment or character alignment)
+ '''
+ split_regions = []
+ error_indexes = []
+
+ p = re.compile(error_pattern)
+ # Find candidate error regions in the current utterance
+ # Matches go from left to right
+ prev_index = 0
+
+ err_str = ''.join(self.align)
+ for match in p.finditer(err_str):
+ i, j = match.span()
+ if prev_index < i:
+ # Previous items are rolled up into a 'correct' segment
+ split_regions.append(self.subsequence(prev_index, i))
+ # Get the error region
+ error_indexes.append(len(split_regions))
+ split_regions.append(self.subsequence(i, j))
+ prev_index = j
+
+ # Add the trailing segment.
+ if prev_index < len(self.align):
+ split_regions.append(self.subsequence(prev_index, len(self.align)))
+ return split_regions, error_indexes
+
+ def append_alignment(self, expanded_alignment):
+ '''
+ Concatenates a string alignment to the current object.
+ '''
+ map_offset = self.length()
+
+ self.s1 += expanded_alignment.s1
+ self.s2 += expanded_alignment.s2
+ self.align += expanded_alignment.align
+ if self.s1_map and expanded_alignment.s1_map:
+ self.s1_map += [align_pos +
+ map_offset for align_pos in expanded_alignment.s1_map]
+ if self.s2_map and expanded_alignment.s2_map:
+ self.s2_map += [align_pos +
+ map_offset for align_pos in expanded_alignment.s2_map]
+
+ def recompute_alignment_maps(self):
+ '''
+ Regenerates s1_map and s2_map based on the alignment info.
+ '''
+ self.s1_map = []
+ self.s2_map = []
+
+ for i in range(self.length()):
+ if self.align[i] in (AlignLabels.correct, AlignLabels.substitution, AlignLabels.deletion):
+ self.s1_map.extend([i] * len(self.s1[i].split()))
+ if self.align[i] in (AlignLabels.correct, AlignLabels.substitution, AlignLabels.insertion):
+ self.s2_map.extend([i] * len(self.s2[i].split()))
+
+ def error_rate(self, cluster_on_ref=False):
+ '''
+ Computes WER or POWER.
+ self.s1 is considered to be the reference and self.s2 is the hypothesis.
+ '''
+ score_components = {AlignLabels.correct: 0, AlignLabels.substitution: 0,
+ AlignLabels.deletion: 0, AlignLabels.insertion: 0, 'L': 0}
+
+ for i in range(self.length()):
+ alignment = self.align[i]
+
+ magnitude = 1
+ if alignment != AlignLabels.insertion:
+ ref_seg_length = len(self.s1[i].split())
+ hyp_seg_length = len(self.s2[i].split())
+
+ if cluster_on_ref:
+ # NOTE: This relaxes the penalty of errors like
+ # anatomy -> "and that to me"
+ # And sharply penalizes
+ # "and that to me" -> anatomy
+ # but it might be in better sync with classic WER.
+ magnitude = ref_seg_length
+ else:
+ # NOTE: This causes a mismatch between reference length and the number of substitutions! Is this a problem?
+ magnitude = max(ref_seg_length, hyp_seg_length)
+ score_components['L'] += ref_seg_length
+
+ score_components[alignment] += magnitude
+
+ if not self.s1:
+ # No reference. Error is 100%
+ error_rate = 1.0
+ else:
+ error_rate = (score_components[AlignLabels.substitution] + score_components[AlignLabels.deletion] +
+ score_components[AlignLabels.insertion]) / score_components['L']
+ return error_rate, score_components
+
+ def confusion_pairs(self):
+ d = defaultdict(Counter)
+ for i in range(len(self.align)):
+ if self.align[i] == AlignLabels.substitution:
+ s1 = self.s1[i]
+ s2 = self.s2[i]
+ if self.lowercase:
+ s1 = s1.lower()
+ s2 = s2.lower()
+ d[s1] += Counter({s2: 1})
+ return d
+
+ def alignment_capacity(self):
+ '''
+ Returns the number of word slots occupied by each alignment point.
+ '''
+ return [(len(self.s1_align_tokens(i)), len(self.s2_align_tokens(i))) for i in range(self.length())]
+
+ def hyp_oriented_alignment(self, hyp_only=True):
+ '''
+ Returns all alignment tokens.
+        If an S slot is a multiword alignment, duplicates AlignLabels.substitution by the capacity.
+ TODO: Move to subclass.
+ '''
+ alignment = []
+ ref_align_len, hyp_align_len = zip(*self.alignment_capacity())
+
+ for i in range(self.length()):
+ if hyp_only:
+ # Treat each hyp token as a substitution error
+ alignment.extend(self.align[i] * max(1, hyp_align_len[i]))
+ else:
+ len_diff = ref_align_len[i] - hyp_align_len[i]
+ if len_diff == 0:
+ alignment.extend(self.align[i] * hyp_align_len[i])
+ elif len_diff < 0:
+ len_diff = -len_diff
+ alignment.extend(self.align[i] * ref_align_len[i])
+ alignment.extend([AlignLabels.insertion] * len_diff)
+ else:
+ alignment.extend(self.align[i] * hyp_align_len[i])
+ alignment.extend([AlignLabels.deletion] * len_diff)
+ return alignment
+
+
+class Levenshtein:
+ def __init__(self, lowercase=False, tokenMap=None):
+ self.backMatrix = None
+ self.distMatrix = None
+ self.dist = -1
+ self.s1 = None
+ self.s2 = None
+ self.edits = None
+ self.lowercase = lowercase
+
+ uniformWeights = {AlignLabels.correct: 0, AlignLabels.substitution: 1,
+ AlignLabels.deletion: 1, AlignLabels.insertion: 1}
+ wordAlignWeights = {AlignLabels.correct: 0, AlignLabels.substitution: 4,
+ AlignLabels.deletion: 3, AlignLabels.insertion: 3}
+
+ @staticmethod
+ def align(ref, hyp, reserve_list=None, exclusive_sets=None, lowercase=False, weights=None, dist_penalty=0.5, dist_penalty_set=None):
+ '''
+ Creates an alignment with hyp x ref matrix.
+ reserve_list defines tokens that may never have 'S' alignments.
+ exclusive_sets defines families of tokens that can have 'S' alignments. Anything outside of exclusive_sets can be aligned to any other nonmember.
+ '''
+ if not weights:
+ weights = Levenshtein.uniformWeights
+ lev = Levenshtein(lowercase=lowercase)
+ lev.s1 = ref
+ lev.s2 = hyp
+
+        # If using distance penalties:
+ distPenaltyRef = 0
+ distPenaltyHyp = 0
+
+ if lowercase:
+ ref = [x.lower() for x in ref]
+ hyp = [x.lower() for x in hyp]
+
+ #pp = pprint.PrettyPrinter(width=300)
+ lev.backMatrix = BackTrackMatrix(len(ref), len(hyp), weights)
+
+ # Starts with 1st word in hyp
+ for index2, char2 in enumerate(hyp):
+
+ if dist_penalty_set and char2 not in dist_penalty_set:
+ distPenaltyRef += 1
+ else:
+ distPenaltyRef = 0
+
+ # Loop through columns, corresponding to characters in hyp
+ for index1, char1 in enumerate(ref):
+ if dist_penalty_set and char1 not in dist_penalty_set:
+ distPenaltyHyp += 1
+ else:
+ distPenaltyHyp = 0
+
+ match_char = AlignLabels.substitution
+ # Add insert/delete options
+ insPenalty = lev.backMatrix.getWeight(
+ index2, index1+1) + weights[AlignLabels.insertion]
+ delPenalty = lev.backMatrix.getWeight(
+ index2+1, index1) + weights[AlignLabels.deletion]
+
+ if dist_penalty_set:
+ insPenalty += (distPenaltyHyp * dist_penalty) * \
+ weights[AlignLabels.insertion]
+ delPenalty += (distPenaltyRef * dist_penalty) * \
+ weights[AlignLabels.deletion]
+
+ opts = [insPenalty, # I
+ delPenalty, # D
+ ]
+
+ if char1 == char2:
+ opts.append(lev.backMatrix.getWeight(
+ index2, index1) + weights[AlignLabels.correct]) # C
+ match_char = AlignLabels.correct
+ elif not reserve_list or not (char1 in reserve_list or char2 in reserve_list):
+ if exclusive_sets:
+ # Check if char1 and char2 belong to the same exclusive set
+ no_membership_set = set([-1])
+
+ # Change this to handle multiple set membership
+ char1_sets = set(
+ [k for k in range(len(exclusive_sets)) if char1 in exclusive_sets[k]])
+ if not char1_sets:
+ char1_sets = no_membership_set
+ char2_sets = set(
+ [k for k in range(len(exclusive_sets)) if char2 in exclusive_sets[k]])
+ if not char2_sets:
+ char2_sets = no_membership_set
+
+ if set.intersection(char1_sets, char2_sets):
+ opts.append(lev.backMatrix.getWeight(
+ index2, index1) + weights[AlignLabels.substitution]) # S
+ else:
+ opts.append(lev.backMatrix.getWeight(
+ index2, index1) + weights[AlignLabels.substitution]) # S
+
+ # Get the locally minimum score and edit operation
+ minDist = min(opts)
+ minIndices = [i for i in reversed(
+ range(len(opts))) if opts[i] == minDist]
+
+ # Build the backtrack
+ alignLabels = []
+ for minIndex in minIndices:
+ if minIndex == 2:
+ alignLabels.append(match_char)
+ elif minIndex == 0:
+ alignLabels.append(AlignLabels.insertion)
+ elif minIndex == 1:
+ alignLabels.append(AlignLabels.deletion)
+
+ lev.backMatrix.addBackTrack(
+ index2+1, index1+1, alignLabels, minDist) # S/C
+ pass
+
+ lev.dist = lev.backMatrix.getWeight(
+ lev.backMatrix.hyplen, lev.backMatrix.reflen)
+ return lev
+
+ def matchPositions(self, token, token2=None, min_i=None, min_j=None, max_i=None, max_j=None):
+ if not min_i:
+ min_i = 0
+ if not max_i:
+ max_i = len(self.s2)
+
+ if not min_j:
+ min_j = 0
+ if not max_j:
+ max_j = len(self.s1)
+
+ if not token2:
+ token2 = token
+
+ # Find match positions of token in hyp
+ hypIdx = [i+1 for i in range(min_i, max_i) if self.s2[i] == token]
+ refIdx = [j+1 for j in range(min_j, max_j) if self.s1[j] == token]
+
+ return list(itertools.product(*[hypIdx, refIdx]))
+
+ def bestPathsGraph(self, minPos=None, maxPos=None):
+ """
+ Takes all of the best Levenshtein alignment backtrack paths and puts them in a graph.
+ The graph is weighted by the distance between minPos and maxPos for all paths.
+ """
+ import networkx as nx
+ if not minPos:
+ minPos = (0, 0)
+ if not maxPos:
+ maxPos = (self.backMatrix.hyplen, self.backMatrix.reflen)
+
+ chart = deque()
+ chart.appendleft(maxPos)
+
+ G = nx.Graph()
+
+ from time import time
+ start = time()
+ while chart:
+ (i, j) = chart.pop()
+
+ for alignLabel in self.backMatrix.matrix[i][j].backTrackOptions:
+ child = self.backMatrix.matrix[i][j].getBackTrackOffset(
+ alignLabel)
+ prev_i = i + child[1][0]
+ prev_j = j + child[1][1]
+
+ align = child[0]
+ rlabel = self.s1[j-1] if prev_j < j else ''
+ hlabel = self.s2[i-1] if prev_i < i else ''
+
+ # Weight applied based on whether (i & prev_i) or (j & prev_j) are along the hull of minPos or maxPos
+ weight = 1
+ if (i == prev_i and i in (minPos[0], maxPos[0])) or (j == prev_j and j in (minPos[1], maxPos[1])):
+ weight = 0
+
+ G.add_edge((i, j), (prev_i, prev_j), weight=weight,
+ labels=(rlabel, hlabel, align))
+ chart.appendleft((prev_i, prev_j))
+
+ if time() - start > 120:
+ print("\nWarning: Long computation time\n")
+ raise AssertionError("Computation took too long")
+ return G
+
+ def editops(self):
+ '''
+ Records edit distance operations in a compact format, changing s1 to s2.
+ '''
+ i = self.backMatrix.hyplen
+ j = self.backMatrix.reflen
+ back = []
+
+ while i > 0 or j > 0:
+ op = self.backMatrix.matrix[i][j].getBackTrackOffset()
+ off_i, off_j = op[1]
+ i += off_i
+ j += off_j
+ back.append((op[0], (i, j)))
+
+ back.reverse()
+ self.edits = back
+ return back
+
+ def expandAlign(self):
+ '''
+ Expands the edit operations to actually align the strings.
+ Also contains maps to track the positions of each character in the strings
+ to its aligned position.
+ '''
+ if not self.edits:
+ return None
+
+ s1 = []
+ s2 = []
+ align = []
+ s1_map = []
+ s2_map = []
+
+ for op in self.edits:
+ a = op[0]
+ i, j = op[1]
+
+ # Bugfix for empty hypotheses or reference (reference shouldn't happen)
+ c1 = None
+ c2 = None
+ if -1 < j < len(self.s1):
+ c1 = self.s1[j]
+ if -1 < i < len(self.s2):
+ c2 = self.s2[i]
+
+ if a == AlignLabels.deletion:
+ s1.append(c1)
+ s2.append('')
+ s1_map.append(len(align))
+ elif a == AlignLabels.insertion:
+ s1.append('')
+ s2.append(c2)
+ s2_map.append(len(align))
+ else:
+ s1.append(c1)
+ s2.append(c2)
+ s1_map.append(len(align))
+ s2_map.append(len(align))
+
+ align.append(a)
+
+ return ExpandedAlignment(s1, s2, align, s1_map, s2_map, lowercase=self.lowercase)
+
+ def expandAlignCompact(self, minPos=None, maxPos=None):
+ """
+ Using the backtracking matrix, finds all of the paths with the minimum Levenshtein distance score and stores them in a graph.
+ Then, it returns the expanded alignment of the shortest path in the graph (which still has the same minimum Lev distance score.
+ """
+ import networkx as nx
+ minPos = (0, 0)
+ maxPos = (self.backMatrix.hyplen, self.backMatrix.reflen)
+ G = self.bestPathsGraph(minPos, maxPos)
+ path = nx.shortest_path(
+ G, source=minPos, target=maxPos, weight='weight')
+
+ # Expand the best path into the Levenshtein alignment.
+ s1_align, s2_align, align = [list(a) for a in zip(
+ *(G[u][v]['labels'] for (u, v) in zip(path[0:], path[1:])))]
+ return ExpandedAlignment(s1_align, s2_align, align, lowercase=self.lowercase)
+
+ @staticmethod
+ def errorRate(s, d, i, reflength):
+ return (s + d + i) / reflength
+
+
class BackTrackMatrix:
    """Dynamic-programming matrix for Levenshtein alignment.

    Rows index hypothesis positions (0..hyplen), columns index reference
    positions (0..reflen). Each cell is a BackTrackSlot holding the best
    cumulative weight and the co-optimal backtrack labels.
    """

    def __init__(self, reflen, hyplen, weights=Levenshtein.uniformWeights):
        self.reflen = reflen
        self.hyplen = hyplen
        self.weights = weights

        self.row_count = hyplen + 1
        self.col_count = reflen + 1
        self.matrix = [[None] * self.col_count for _ in range(self.row_count)]

        # Top-left corner: zero-cost start state.
        self.matrix[0][0] = BackTrackSlot(0)

        # First column is reachable only through insertions.
        ins_cost = weights[AlignLabels.insertion]
        for row in range(1, self.row_count):
            self.addBackTrack(row, 0, AlignLabels.insertion, row * ins_cost)

        # First row is reachable only through deletions.
        del_cost = weights[AlignLabels.deletion]
        for col in range(1, self.col_count):
            self.addBackTrack(0, col, AlignLabels.deletion, col * del_cost)

    def addBackTrack(self, i, j, alignLabels, weight=1.0):
        """Overwrite cell (i, j) with a fresh slot holding the given labels."""
        slot = BackTrackSlot(weight)
        slot.addOptions(alignLabels)
        self.matrix[i][j] = slot

    def backTrackOptions(self, i, j):
        """Return the BackTrackSlot stored at (i, j)."""
        return self.matrix[i][j]

    def getWeight(self, i, j):
        """Return the cumulative weight stored at (i, j)."""
        return self.matrix[i][j].weight
+
+
class BackTrackSlot:
    """One cell of the backtrack matrix: a cumulative weight plus the
    ordered list of co-optimal backtrack labels."""

    def __init__(self, weight):
        self.weight = weight
        self.backTrackOptions = list()

    def __str__(self):
        joined = ','.join(self.backTrackOptions)
        return "({0}, {1})".format(self.weight, joined)

    def iterOptions(self):
        """Iterate over the recorded backtrack labels."""
        return iter(self.backTrackOptions)

    def addOption(self, alignLabel):
        """Append a single label unless it is already recorded."""
        if alignLabel not in self.backTrackOptions:
            self.backTrackOptions.append(alignLabel)

    def addOptions(self, alignLabels):
        """Append every label not already recorded (membership is checked
        against the list as it was BEFORE this call)."""
        fresh = [x for x in alignLabels if x not in self.backTrackOptions]
        self.backTrackOptions.extend(fresh)

    def getBackTrackOffset(self, alignLabel=None):
        """Return (label, (row_offset, col_offset)) for the given label,
        or for the first recorded label when none is given."""
        if alignLabel:
            # Make sure it exists
            if alignLabel not in AlignLabels.validOptions:
                raise Exception("Invalid backtrack option: %s" % alignLabel)
            if alignLabel not in self.backTrackOptions:
                raise Exception("Illegal backtrack option: %s" % alignLabel)
        else:
            # Just arbitrarily grab the first item
            alignLabel = self.backTrackOptions[0]

        # Diagonal for C/S, left for D, up for I.
        if alignLabel in (AlignLabels.correct, AlignLabels.substitution):
            offset = (-1, -1)
        elif alignLabel == AlignLabels.deletion:
            offset = (0, -1)
        else:
            offset = (-1, 0)
        return (alignLabel, offset)
diff --git a/src/error_align/baselines/power/power/phonemes.py b/src/error_align/baselines/power/power/phonemes.py
new file mode 100644
index 0000000..5a187ea
--- /dev/null
+++ b/src/error_align/baselines/power/power/phonemes.py
@@ -0,0 +1,16 @@
class Phonemes(object):
    """ARPAbet-style phoneme inventory, grouped by phonetic category.

    All members are class-level frozensets-in-spirit (plain sets); `vowels`
    and `consonants` are the unions of their respective subcategories.
    """

    # Vowels
    monophthongs = {'ao', 'aa', 'iy', 'uw', 'eh', 'ih', 'uh', 'ah', 'ax', 'ae'}
    diphthongs = {'ey', 'ay', 'ow', 'aw', 'oy'}
    # Add r-colored monophthongs?
    r_vowels = {'er', 'axr'}  # The remaining are split into two tokens.
    vowels = monophthongs | diphthongs | r_vowels

    # Consonants
    c_stops = {'p', 'b', 't', 'd', 'k', 'g'}
    c_afficates = {'ch', 'jh'}
    c_fricatives = {'f', 'v', 'th', 'dh', 's', 'z', 'sh', 'zh', 'hh'}
    c_nasals = {'m', 'em', 'n', 'en', 'ng', 'eng'}
    c_liquids = {'l', 'el', 'r', 'dx', 'nx'}
    c_semivowels = {'y', 'w', 'q'}
    consonants = c_stops | c_afficates | c_fricatives | c_nasals | c_liquids | c_semivowels
\ No newline at end of file
diff --git a/src/error_align/baselines/power/power/pronounce.py b/src/error_align/baselines/power/power/pronounce.py
new file mode 100644
index 0000000..3e47ae4
--- /dev/null
+++ b/src/error_align/baselines/power/power/pronounce.py
@@ -0,0 +1,85 @@
+'''
+Created on Mar 30, 2015
+
+@author: Nick Ruiz
+Assuming you have a file of word sequences, this processes them through Festival to generate pronunciations.
+
+'''
+import json
+from itertools import groupby
+
+import pyphen
+
+from error_align.baselines.power.normalize import NumToTextEng, splitHyphens
+
+
class PronouncerType:
    """String identifiers for the available pronouncer implementations."""

    Base = "base"
    Lexicon = "lexicon"
+
class PronouncerBase(object):
    """Interface for grapheme-to-phoneme (G2P) converters."""

    def __init__(self):
        pass

    def pronounce(self, words):
        """G2P conversion; subclasses must override."""
        raise NotImplementedError
+
class PronouncerLex(PronouncerBase):
    '''Lexicon-based pronunciation generator. Looks up words in the lexicon and if they aren't found, uses a hacky alternative.
    NOTE: English-only
    '''

    def __init__(self, lexicon):
        # `lexicon` is a path to a JSON file; pronounce() indexes it by
        # lowercased word, so presumably it maps word -> pronunciation string.
        with open(lexicon, 'r') as f:
            self.lexicon = json.load(f)
        # Hyphenation dictionary used as the fallback for out-of-lexicon words.
        self.fallbackDict = pyphen.Pyphen(lang='en_US')

    def pronounce(self, words):
        '''G2P using Pyphen + heuristics.

        Looks each lowercased word up in the lexicon, falling back to
        alt_pronounce() on misses; returns the pronunciations as a flat
        token list with '|' marking word boundaries.
        '''
        wordsLower = (w.lower() for w in words)
        prons = [self.lexicon[w] if w in self.lexicon else self.alt_pronounce(
            w) for w in wordsLower]
        return "| {0} |".format(' | '.join(prons)).split()

    def alt_pronounce(self, word):
        '''Alternative ways to pronounce the word. Adds simple digit to word conversion'''
        prons = []
        # Split words along hyphens
        # NOTE(review): ' '.join(word) space-separates every CHARACTER of the
        # word before calling splitHyphens; splitHyphens(word) looks like the
        # intended call — verify against splitHyphens' expected input.
        wordsSplit = splitHyphens(' '.join(word))
        # Instead of one word, we may have many
        for myword in wordsSplit:
            if myword.isdigit():
                # Spell out integers and pronounce each resulting word,
                # joining them with the syllable separator '#'.
                words = NumToTextEng.convert(int(myword)).split()
                pron = ' # '.join((self.pyphen_pronounce(w) for w in words))
            else:
                pron = self.pyphen_pronounce(myword)
            prons.append(pron)
        # Although the hyphenated word is now pronounced as multiple "words",
        # we treat again as a single word with multiple syllables
        return ' # '.join(prons)


    def pyphen_pronounce(self, word):
        ''' Uses pyphen as a back-off to generate a pseudo-hyphenated-pronounciation for words not in the lexicon. '''
        # Hyphenate the word, then greedily cover each syllable with the
        # longest substrings (from the left) that exist in the lexicon.
        syllables = self.fallbackDict.inserted(word).split('-')
        pron = []
        for syl in syllables:
            m = 0
            n = len(syl)
            sylpron = []
            while m < n:
                # Candidates ordered longest-first (i grows -> shorter span).
                pronidx = [(n-i, self.lexicon[syl[m:n-i]])
                           for i in range(n) if syl[m:n-i] in self.lexicon]
                if not pronidx:
                    break
                j, p = pronidx[0]
                sylpron.append(p)
                m = j
            if not sylpron:
                # TODO: No dictionary entries found for this word. Error?
                # Break syl into characters
                sylpron = syl

            # Remove duplicate adjacent phonemes
            sylpron = list(i for i, x in groupby(sylpron))
            pron.append(' '.join(sylpron))
        return ' # '.join(pron)
diff --git a/src/error_align/baselines/power/power/punct.py b/src/error_align/baselines/power/power/punct.py
new file mode 100644
index 0000000..1ee38c1
--- /dev/null
+++ b/src/error_align/baselines/power/power/punct.py
@@ -0,0 +1,166 @@
+import copy
+import re
+
+from error_align.baselines.power.power.aligner import CharToWordAligner
+from error_align.baselines.power.power.levenshtein import AlignLabels
+
+
class PunctInsertOracle(object):
    '''
    Insert punctuation on hypothesis based on alignment with reference
    '''

    def __init__(self, params):
        '''
        Constructor
        '''
        # `params` is currently unused.
        pass

    @staticmethod
    def insertPunct(error_alignment, ref_punct):
        # Projects the punctuation attached to the punctuated reference
        # string onto the hypothesis side of `error_alignment`. Returns a
        # deep copy; the input alignment is not mutated.
        if not (error_alignment.s1_map and error_alignment.s2_map):
            # Don't try to add punct if either the hyp or ref are empty.
            return error_alignment

        # Align reference with punct against reference without punct
        # Is this needed? Yes, in case there's punctuation that makes an extra token
        ref_punct_tokens = ref_punct.split()
        ref_nopunct_tokens = error_alignment.s1_string().split()  # [x for x in error_alignment.s1 if x]
        hyp_nopunct_tokens = error_alignment.s2_string().split()  # [x for x in error_alignment.s2 if x]

        ref_punct_string = ' '.join(ref_punct_tokens)
        ref_nopunct_string = ' '.join(ref_nopunct_tokens)
        c2w_aligner = CharToWordAligner(ref_punct_string, ref_nopunct_string, lowercase=True)
        c2w_aligner.charAlign()
        ref_punct_align = c2w_aligner.charAlignToWordAlign()

#        lev = Levenshtein.align(ref_punct_tokens, ref_nopunct_tokens, lowercase=True)
#        lev.editops()
#        ref_punct_align = lev.expandAlign()

        error_alignment_punct = copy.deepcopy(error_alignment)

        try:
            err_align_index = None
            err_align_label = None
            ref_punct_index = -1
            ref_nopunct_index = -1
            for ref_align_index in range(len(ref_punct_align.align)):
                # Check the error type.
                ref_align_error_type = ref_punct_align.align[ref_align_index]
                ref_word_punct = ref_punct_align.s1[ref_align_index]
                ref_word_nopunct = ref_punct_align.s2[ref_align_index]

                # Punctuation to attach to the left/right of the hyp word.
                punct_lhs = ""
                punct_rhs = ""

                if ref_align_error_type in (AlignLabels.correct, AlignLabels.substitution):
                    # There's some punct attached to the word. Split it off!
                    # Idea #1, Use regex to find punct around a word sequence. However, this will fail if there's punctuation in between words
                    # NOTE(review): ref_word_nopunct is not re.escape'd, so
                    # regex metacharacters in the token would break this.
                    pattern = '^(.*?){0}(.*?)$'.format(ref_word_nopunct)
                    match = re.search(pattern, ref_word_punct)

                    if not match:
                        # Whoops! something went wrong
                        # NOTE(review): the "Punct:" slot formats
                        # ref_word_nopunct twice; ref_word_punct was likely
                        # intended for the second argument.
                        raise Exception("Regex match not found\nPattern: {0}\nPunct: {1}\nNo Punct: {2}"
                                        .format(pattern, ref_word_nopunct, ref_word_nopunct))
                    punct_lhs = match.group(1)
                    punct_rhs = match.group(2)

                    if punct_lhs:
                        # Apply to front of hyp
                        punct_lhs = punct_lhs.strip()
                    else:
                        punct_lhs = ""
                    if punct_rhs:
                        # Apply to back of hyp
                        punct_rhs = punct_rhs.strip()
                    else:
                        punct_rhs = ""

                    ref_punct_index += 1
                    ref_nopunct_index += 1

                elif ref_align_error_type == AlignLabels.deletion:
                    # TODO: The extra token on the reference side is probably punct. Check if token is punct
                    # Determine if punct should attach as punct_lhs or punct_rhs. We'll make that decision later.
                    punct_lhs = '{0} '.format(ref_word_punct)
                    punct_rhs = ' {0}'.format(ref_word_punct)

                    ref_punct_index += 1
                elif ref_align_error_type == 'I':
                    # NOTE(review): literal 'I' here vs AlignLabels elsewhere
                    # — verify they are the same value.
                    # TODO: uh-oh. This shouldn't happen!
                    ref_nopunct_index += 1
                    raise TypeError("Insertion errors shouldn't happen between punct and nopunct")
                else:
                    # Now this is impossible!
                    raise TypeError("Invalid error type: {0}".format(ref_align_error_type))

                # Advance error_alignment to the word alignment position containing the nopunct reference word
                if ref_nopunct_index >= 0:  # TODO: Check if the words have changed.
                    try:
                        err_align_index = error_alignment.s1_map[ref_nopunct_index]
                        err_align_label = error_alignment_punct.align[err_align_index]
                    except Exception:
                        # Dump context before re-raising for debugging.
                        print(error_alignment)
                        print("ref_nopunct_index:", ref_nopunct_index)
                        print("ref_align_index: ", ref_align_index)
                        raise

                if punct_lhs or punct_rhs:
#                    print "Punct: |_{0}_| |_{1}_|".format(punct_lhs, punct_rhs)
                    if err_align_label == AlignLabels.deletion:
                        # If there's no token here, apply punct to previous or next
                        # TODO: This won't work for long strings of D's. You'll need to find the next closest one.

                        # Scan backward for the previous non-empty hyp word. Try to apply the punct there.
                        err_align_index_shift = err_align_index - 1

                        # ... except if the punct was at the end of the last ref word
                        if ref_nopunct_index == len(ref_nopunct_tokens) - 1:
                            err_align_index_shift = error_alignment.s2_map[len(hyp_nopunct_tokens) - 1]

                        while err_align_index_shift > 0 and not error_alignment.s2[err_align_index_shift]:
                            err_align_index_shift -= 1
                        if err_align_index_shift >= 0:

                            # Apply punct to previous position
                            if ref_align_error_type == AlignLabels.deletion:
                                punct_lhs = ""

                            error_alignment_punct.s2[err_align_index_shift] = "{1}{0}{2}".format(error_alignment_punct.s2[err_align_index_shift], punct_lhs, punct_rhs)
                        else:
                            # Scan forward for the next non-empty hyp word. Try to apply the punct there.
                            err_align_index_shift = err_align_index + 1
                            while err_align_index_shift < len(hyp_nopunct_tokens) and not error_alignment_punct.s2[err_align_index_shift]:
                                err_align_index_shift += 1
                            if err_align_index_shift < len(hyp_nopunct_tokens):

                                # Apply punct to next position
                                # NOTE(review): literal "D" here vs
                                # AlignLabels.deletion above — verify equal.
                                if ref_align_error_type == "D":
                                    punct_rhs = ""

                                error_alignment_punct.s2[err_align_index_shift] = "{1}{0}{2}".format(error_alignment_punct.s2[err_align_index_shift], punct_lhs, punct_rhs)
                            else:
                                # Otherwise ignore punct altogether
                                print("Discarding punct")
                    else:
                        # Apply punct on the end of the current word.
                        if ref_align_error_type == "D":
                            punct_lhs = ""  # TODO: Is this assumption correct? Apply punct to rhs?

                        # Force punct at the end of the last reference word to appear at the end of the hyp word
                        if ref_nopunct_index == len(ref_nopunct_tokens) - 1:
                            err_align_index_shift = error_alignment.s2_map[len(hyp_nopunct_tokens) - 1]
                            error_alignment_punct.s2[err_align_index_shift] = "{0}{1}".format(error_alignment_punct.s2[err_align_index_shift], punct_rhs)
                            punct_rhs = ""

                        error_alignment_punct.s2[err_align_index] = "{1}{0}{2}".format(error_alignment_punct.s2[err_align_index], punct_lhs, punct_rhs)

        except Exception:
            # TODO: Add exception handling to output the sentences as well as the offending segment.
            raise

        return error_alignment_punct
+
+
\ No newline at end of file
diff --git a/src/error_align/baselines/power/power/readers.py b/src/error_align/baselines/power/power/readers.py
new file mode 100644
index 0000000..512aeb6
--- /dev/null
+++ b/src/error_align/baselines/power/power/readers.py
@@ -0,0 +1,48 @@
+import json
+import sys
+
+from error_align.baselines.power.power.levenshtein import ExpandedAlignment
+
+
class AlignmentReaderJson(object):
    """Reads line-delimited JSON alignment records (as produced by
    JsonWriter) back into ExpandedAlignment objects."""

    def __init__(self, filepath):
        self.filepath = filepath

    def read_alignments(self):
        """Yield one parsed record per line of the file (None for blanks)."""
        with open(self.filepath, 'r') as f:
            for line in f:
                yield AlignmentReaderJson.read_json(line)

    @staticmethod
    def read_json(jstr):
        """Parse a single JSON record.

        Returns an ExpandedAlignment, or None when the record is empty.
        """
        in_dict = json.loads(jstr)
        if not in_dict:
            return None

        ref, hyp, align = [], [], []
        ref_map, hyp_map = [], []

        for idx, entry in enumerate(in_dict['alignments']):
            ref.append(entry['ref'])
            hyp.append(entry['hyp'])
            align.append(entry['align'])
            # Each word in a slot maps back to the slot's index.
            if entry['ref']:
                ref_map.extend([idx] * len(entry['ref'].split()))
            if entry['hyp']:
                hyp_map.extend([idx] * len(entry['hyp'].split()))
        return ExpandedAlignment(ref, hyp, align, ref_map, hyp_map)
+
+
def main(args):
    """Print every alignment in the file named by the first CLI argument,
    separating records with '---' lines."""
    reader = AlignmentReaderJson(args[0])
    for alignment in reader.read_alignments():
        print(alignment)
        print('---')


if __name__ == '__main__':
    main(sys.argv[1:])
\ No newline at end of file
diff --git a/src/error_align/baselines/power/power/writers.py b/src/error_align/baselines/power/power/writers.py
new file mode 100644
index 0000000..62e0267
--- /dev/null
+++ b/src/error_align/baselines/power/power/writers.py
@@ -0,0 +1,259 @@
+from __future__ import division
+
+import abc
+import json
+
+from error_align.baselines.power.power.levenshtein import Levenshtein
+
+
def CreateWriter(output_type, filepath, hypfile, reffile):
    """Factory for WER report writers.

    Args:
        output_type: One of "sgml", "snt", "json" or "align".
        filepath: Output path handed to the writer.
        hypfile: Hypothesis file name (used in writer headers).
        reffile: Reference file name (used in writer headers).

    Returns:
        The matching WerWriter subclass instance.

    Raises:
        NotImplementedError: If `output_type` is not a known format.
    """
    if output_type == "sgml":
        return SgmlWriter(filepath, hypfile, reffile)
    if output_type == "snt":
        return SntWriter(filepath, hypfile, reffile)
    if output_type == "json":
        return JsonWriter(filepath, hypfile, reffile)
    if output_type == "align":
        return AlignWriter(filepath, hypfile, reffile)
    # Bug fix: the original interpolated the builtin `format` function into
    # the message instead of the requested output type.
    raise NotImplementedError("Writer not implemented for %s" % output_type)
+
+
class WerWriter(abc.ABC):
    """Abstract base class for per-segment WER report writers.

    Opens `filepath` for writing on construction; `finalize()` closes it
    (idempotently). Subclasses must implement `write`.

    Bug fix: the original declared `__metaclass__ = abc.ABCMeta`, which is
    Python-2 syntax and has no effect under Python 3, so `write` was never
    actually enforced as abstract. Inheriting from `abc.ABC` restores the
    intended behavior; all existing subclasses already override `write`.
    """

    def __init__(self, filepath, hypfile, reffile):
        # hypfile/reffile are unused here but are part of the constructor
        # signature shared by every subclass.
        self.out_file = open(filepath, "w")
        self.filepath = filepath

    @abc.abstractmethod
    def write(self, segid, score_components, expanded_alignment, phonetic_alignments=None):
        """Write one segment's scores/alignment to the output file."""
        return

    def write_blank(self):
        """Write a placeholder for a skipped segment (no-op by default)."""
        return

    def finalize(self):
        """Close the output file; safe to call more than once."""
        if self.out_file:
            self.out_file.close()
            print("File written to {}".format(self.filepath))
            self.out_file = None
+
+
class AlignWriter(WerWriter):
    """Writes one line per segment: the expanded alignment labels joined
    by single spaces."""

    def __init__(self, filepath, hypfile, reffile):
        WerWriter.__init__(self, filepath, hypfile, reffile)

    def write(self, segid, score_components, expanded_alignment, phonetic_alignments=None):
        labels = expanded_alignment.alignment_expanded()
        self.out_file.write("{0}\n".format(" ".join(labels)))

    def finalize(self):
        WerWriter.finalize(self)
+
+
class SgmlWriter(WerWriter):
    """Writes alignments in an sclite-style SGML format.

    NOTE(review): the writes below that interpolate into a bare newline
    literal (newline % tuple) raise TypeError ("not all arguments converted")
    as soon as they run. The SGML tag templates (open tags with %s slots for
    hypfile/reffile/segid) appear to have been stripped from this copy of
    the file — restore them before using this writer.
    """

    def __init__(self, filepath, hypfile, reffile):
        WerWriter.__init__(self, filepath, hypfile, reffile)
        # BUG: missing %s template string (see class note).
        self.out_file.write(
            '\n' % (hypfile, reffile, hypfile)
        )
        self.out_file.write('\n')

    def write(self, segid, score_components, expanded_alignment, phonetic_alignments=None):
        # BUG: missing %s template string (see class note).
        self.out_file.write('\n' % (segid, len(expanded_alignment.align)))
        # Each token triple is emitted as  LABEL,"ref","hyp"  joined by ':'.
        self.out_file.write(
            "%s\n"
            % ":".join(
                [
                    '%s,"%s","%s"' % (expanded_alignment.align[i], expanded_alignment.s1[i], expanded_alignment.s2[i])
                    for i in range(len(expanded_alignment.align))
                ]
            )
        )
        self.out_file.write("\n")

    def finalize(self):
        # Presumably the closing SGML tags were written here (now blanks).
        self.out_file.write("\n")
        self.out_file.write("\n")
        WerWriter.finalize(self)
+
+
class JsonWriter(WerWriter):
    """Writes one JSON object per segment (line-delimited JSON)."""

    def __init__(self, filepath, hypfile, reffile):
        WerWriter.__init__(self, filepath, hypfile, reffile)

    def write(self, segid, score_components, expanded_alignment, phonetic_alignments=None):
        record = dict()
        record["id"] = segid
        # Copy so popping "L" does not mutate the caller's dict.
        record["errorTypes"] = score_components.copy()
        record["errorTypes"]["refLength"] = record["errorTypes"].pop("L")
        record["errRate"] = Levenshtein.errorRate(
            score_components["S"], score_components["D"], score_components["I"], score_components["L"]
        )
        record["alignments"] = [
            {"align": a, "ref": r, "hyp": h}
            for a, r, h in zip(expanded_alignment.align, expanded_alignment.s1, expanded_alignment.s2)
        ]
        self.out_file.write("%s\n" % json.dumps(record))

    def write_blank(self):
        # An empty object marks a skipped segment.
        self.out_file.write("%s\n" % json.dumps({}))

    def finalize(self):
        WerWriter.finalize(self)
+
+
class SntWriter(WerWriter):
    """Writes a human-readable, sclite-style sentence-level (.snt) report."""

    def __init__(self, filepath, hypfile, reffile):
        WerWriter.__init__(self, filepath, hypfile, reffile)
        # Report header.
        self.out_file.write("===============================================================================\n")
        self.out_file.write(" SENTENCE LEVEL REPORT FOR THE SYSTEM:\n")
        self.out_file.write("\tName: %s\n" % hypfile)
        self.out_file.write("===============================================================================\n")
        self.out_file.write("\n\n")

    def write(self, segid, score_components, expanded_alignment, phonetic_alignments=None):
        # score_components keys used here: C, S, D, I counts and L (ref length).
        self.out_file.write("id: (%d)\n" % segid)
        labels = ["C", "S", "D", "I"]
        counts = [score_components[label] for label in labels]

        # Print score components
        self.out_file.write(
            "Scores (%s) %s\n" % (" ".join(["#%s" % label for label in labels]), " ".join([str(x) for x in counts]))
        )

        # Print word alignment
        self.out_file.write("%s\n" % expanded_alignment)
        self.out_file.write("\n")

        # Print phonetic alignments
        if phonetic_alignments:
            for palign in [p for p in phonetic_alignments if p]:
                self.out_file.write("%s\n" % palign)
            self.out_file.write("\n")

        # Print statistics
        # NOTE(review): every ratio below divides by score_components["L"];
        # a zero-length reference raises ZeroDivisionError here.
        self.out_file.write(
            "Correct = {0:4.1%} {1} ({2})\n".format(
                score_components["C"] / score_components["L"], score_components["C"], score_components["L"]
            )
        )
        self.out_file.write(
            "Substitutions = {0:4.1%} {1} ({2})\n".format(
                score_components["S"] / score_components["L"], score_components["S"], score_components["L"]
            )
        )
        self.out_file.write(
            "Deletions = {0:4.1%} {1} ({2})\n".format(
                score_components["D"] / score_components["L"], score_components["D"], score_components["L"]
            )
        )
        self.out_file.write(
            "Insertions = {0:4.1%} {1} ({2})\n".format(
                score_components["I"] / score_components["L"], score_components["I"], score_components["L"]
            )
        )
        self.out_file.write("\n")
        self.out_file.write(
            "Errors = {0:4.1%} {1} ({2})\n".format(
                Levenshtein.errorRate(
                    score_components["S"], score_components["D"], score_components["I"], score_components["L"]
                ),
                score_components["S"] + score_components["D"] + score_components["I"],
                score_components["L"],
            )
        )
        self.out_file.write("\n")
        self.out_file.write(
            "Ref. words = {0} ({1})\n".format(score_components["L"], score_components["L"])
        )
        self.out_file.write(
            "Hyp. words = {0} ({1})\n".format(
                len(expanded_alignment.s2_string().split()), score_components["L"]
            )
        )
        self.out_file.write(
            "Aligned words = {0} ({1})\n".format(
                score_components["C"] + score_components["S"], score_components["L"]
            )
        )
        self.out_file.write("\n")
        self.out_file.write("-------------------------------------------------------------------------------\n")
        self.out_file.write("\n")

    def finalize(self):
        WerWriter.finalize(self)
+
+
class ConfusionPairWriter(WerWriter):
    """Writes ref->hyp confusion-pair counts.

    Both methods are static, open their own files, and never use the
    WerWriter constructor/state.
    """

    @staticmethod
    def write(filepath, hypfile, reffile, conf_dict):
        """Write pairs as tab-separated 'ref ==> hyp count' lines, sorted."""
        with open(filepath, "w") as out_file:
            out_file.write("System name: %s\n" % hypfile)
            out_file.write("Ref file : %s\n" % reffile)

            for key in sorted(conf_dict):
                for item in sorted(conf_dict[key]):
                    out_file.write("%s\t==>\t%s\t%d\n" % (key, item, conf_dict[key][item]))

    @staticmethod
    def write_json(filepath, hypfile, reffile, conf_dict):
        """Dump the whole confusion dict as a single JSON line."""
        with open(filepath, "w") as out_file:
            out_file.write("%s\n" % json.dumps(conf_dict))
+
+
class CompareWriter:
    """Writes a side-by-side POWER vs. plain-WER comparison table."""

    @staticmethod
    def write_comparison(
        filepath,
        hypfile,
        reffile,
        linecount,
        final_power,
        final_wer,
        power_score_components,
        wer_score_components,
        diff_score,
        diff_components,
    ):
        # Score-component dicts use keys L (ref length), C, S, D, I.
        with open(filepath, "w") as out_file:
            out_file.write("System name: %s\n" % hypfile)
            out_file.write("Ref file : %s\n" % reffile)
            out_file.write("Hyp file : %s\n" % hypfile)
            # Template slots: {0}=title, {1}=#sentences, {2}-{7}=WER row,
            # {8}-{13}=POWER row, {14}-{18}=diff row.
            out_file.write(
                """
,---------------------------------------------------------.
|{0:^57}|
|---------------------------------------------------------|
| Metric | # Snt # Wrd | Corr Sub Del Ins Err |
|--------+-------------+----------------------------------|
| POWER | {1:5d} {8:5d} | {9:5d} {10:5d} {11:5d} {12:5d} {13:3.1%} |
| WER | {1:5d} {2:5d} | {3:5d} {4:5d} {5:5d} {6:5d} {7:3.1%} |
|=========================================================|
| Diff | {1:5d} {2:5d} | {14:-5d} {15:-5d} {16:-5d} {17:-5d} {18:-3.1%} |
`---------------------------------------------------------'
""".format(
                    hypfile,
                    linecount,
                    wer_score_components["L"],
                    wer_score_components["C"],
                    wer_score_components["S"],
                    wer_score_components["D"],
                    wer_score_components["I"],
                    final_wer,
                    power_score_components["L"],
                    power_score_components["C"],
                    power_score_components["S"],
                    power_score_components["D"],
                    power_score_components["I"],
                    final_power,
                    diff_components["C"],
                    diff_components["S"],
                    diff_components["D"],
                    diff_components["I"],
                    diff_score,
                )
            )
diff --git a/src/error_align/baselines/power_alignment.py b/src/error_align/baselines/power_alignment.py
new file mode 100644
index 0000000..cbbc816
--- /dev/null
+++ b/src/error_align/baselines/power_alignment.py
@@ -0,0 +1,61 @@
+from error_align.baselines.power.power.aligner import PowerAligner as _PowerAligner
+from error_align.utils import Alignment, OpType
+
+
class PowerAlign:
    """Phonetically-oriented word error alignment (POWER) wrapper.

    Wraps the baseline PowerAligner and converts its alignment labels into
    this project's Alignment/OpType objects.
    """

    # Kept as the default so existing callers are unaffected; note this is a
    # machine-specific absolute path — pass `lexicon` explicitly elsewhere.
    DEFAULT_LEXICON = "/home/lb/repos/power-asr/lex/cmudict.rep.json"

    # Power alignment labels -> project operation types.
    _OP_MAP = {
        "C": OpType.MATCH,
        "S": OpType.SUBSTITUTE,
        "I": OpType.INSERT,
        "D": OpType.DELETE,
    }

    def __init__(
        self,
        ref: str,
        hyp: str,
        lexicon: str = DEFAULT_LEXICON,
        lowercase: bool = True,
        verbose: bool = True,
    ):
        """Initialize the phonetically-oriented word error alignment.

        Args:
            ref (str): The reference sequence/transcript.
            hyp (str): The hypothesis sequence/transcript.
            lexicon (str): Path to the pronunciation lexicon JSON file.
                Generalized from a hard-coded path; default preserved.
            lowercase (bool): Lowercase tokens before aligning.
            verbose (bool): Verbose output from the underlying aligner.
        """
        self.aligner = _PowerAligner(
            ref=ref,
            hyp=hyp,
            lowercase=lowercase,
            verbose=verbose,
            lexicon=lexicon,
        )

    def align(self):
        """Run the two-pass Power alignment algorithm.

        Returns:
            list[Alignment]: A list of Alignment objects.

        Raises:
            KeyError: If the underlying aligner emits an unknown align label
                (the original code would silently reuse the previous label or
                crash with NameError in that case).
        """
        self.aligner.align()
        power = self.aligner.power_alignment

        # The original computed per-slot display widths here and discarded
        # them; that dead code is removed.
        alignments = []
        for ref_token, hyp_token, align_token in zip(power.s1, power.s2, power.align):
            alignments.append(
                Alignment(
                    op_type=self._OP_MAP[align_token],
                    ref=ref_token,
                    hyp=hyp_token,
                )
            )
        return alignments
diff --git a/src/error_align/baselines/rapidfuzz_word_alignment.py b/src/error_align/baselines/rapidfuzz_word_alignment.py
new file mode 100644
index 0000000..4bd11d1
--- /dev/null
+++ b/src/error_align/baselines/rapidfuzz_word_alignment.py
@@ -0,0 +1,100 @@
+from rapidfuzz.distance import Levenshtein
+
+from error_align.utils import (
+ Alignment,
+ OpType,
+ basic_normalizer,
+ basic_tokenizer,
+)
+
# Map from rapidfuzz editop tag names to this package's OpType values.
OPS_MAP = {
    "match": OpType.MATCH,
    "replace": OpType.SUBSTITUTE,
    "insert": OpType.INSERT,
    "delete": OpType.DELETE,
}
+
+
class RapidFuzzWordAlign:
    """Levenshtein-based word-level alignment."""

    def __init__(
        self,
        ref: str,
        hyp: str,
        tokenizer: callable = basic_tokenizer,
        normalizer: callable = basic_normalizer,
    ):
        """Initialize the Levenshtein-based word-level alignment with reference and hypothesis texts.

        Args:
            ref (str): The reference sequence/transcript.
            hyp (str): The hypothesis sequence/transcript.
            tokenizer (callable): A function to tokenize the sequences. Must be regex-based and return Match objects.
            normalizer (callable): A function to normalize the tokens. Defaults to basic_normalizer.
        """
        if not isinstance(ref, str):
            raise TypeError("Reference sequence must be a string.")
        if not isinstance(hyp, str):
            raise TypeError("Hypothesis sequence must be a string.")

        self.ref = ref
        self.hyp = hyp
        self._ref_token_matches = tokenizer(ref)
        self._hyp_token_matches = tokenizer(hyp)
        self._ref = [normalizer(r.group()) for r in self._ref_token_matches]
        self._hyp = [normalizer(h.group()) for h in self._hyp_token_matches]
        self._ref_max_idx = len(self._ref) - 1
        self._hyp_max_idx = len(self._hyp) - 1
        self.end_index = (self._hyp_max_idx, self._ref_max_idx)

    def align(self) -> list[Alignment]:
        """Compute a word-level Levenshtein alignment between reference and hypothesis.

        Returns:
            list[Alignment]: A list of Alignment objects covering every token of both sequences.
        """
        edit_ops = Levenshtein.editops(self._ref, self._hyp).as_list()

        # Add match segments to editops (rapidfuzz only reports edit operations).
        ref_edit_idxs = {op[1] for op in edit_ops if op[0] != "insert"}
        hyp_edit_idxs = {op[2] for op in edit_ops if op[0] != "delete"}
        ref_match_idxs = [i for i in range(len(self._ref)) if i not in ref_edit_idxs]
        hyp_match_idxs = [i for i in range(len(self._hyp)) if i not in hyp_edit_idxs]
        assert len(ref_match_idxs) == len(hyp_match_idxs)
        for ref_idx, hyp_idx in zip(ref_match_idxs, hyp_match_idxs):
            edit_ops.append(("match", ref_idx, hyp_idx))
        edit_ops = sorted(edit_ops, key=lambda x: (x[1], x[2]))

        # Convert to Alignment objects, mapping token indices back to character
        # spans in the original (un-normalized) texts.
        alignments = []
        for op_type, ref_idx, hyp_idx in edit_ops:
            if op_type == "match" or op_type == "replace":
                ref_match = self._ref_token_matches[ref_idx]
                hyp_match = self._hyp_token_matches[hyp_idx]
                alignment = Alignment(
                    op_type=OPS_MAP[op_type],
                    ref_slice=slice(*ref_match.span()),
                    hyp_slice=slice(*hyp_match.span()),
                    ref=ref_match.group(),
                    hyp=hyp_match.group(),
                )
            elif op_type == "delete":
                ref_match = self._ref_token_matches[ref_idx]
                alignment = Alignment(
                    op_type=OPS_MAP[op_type],
                    ref_slice=slice(*ref_match.span()),
                    ref=ref_match.group(),
                )
            elif op_type == "insert":
                hyp_match = self._hyp_token_matches[hyp_idx]
                alignment = Alignment(
                    op_type=OPS_MAP[op_type],
                    hyp_slice=slice(*hyp_match.span()),
                    hyp=hyp_match.group(),
                )
            else:
                raise ValueError(f"Unknown operation type: {op_type}")
            alignments.append(alignment)

        return alignments
diff --git a/src/error_align/baselines/utils.py b/src/error_align/baselines/utils.py
new file mode 100644
index 0000000..4818e1e
--- /dev/null
+++ b/src/error_align/baselines/utils.py
@@ -0,0 +1,93 @@
+import unicodedata
+
+import regex as re
+from num2words import num2words
+
+from error_align.utils import NUMERIC_TOKEN, basic_normalizer, basic_tokenizer
+
+
def strip_accents(text: str) -> str:
    """Remove accent marks from *text*.

    The string is decomposed to NFD form and every combining mark
    (Unicode category "Mn") is dropped.
    """
    decomposed = unicodedata.normalize("NFD", text)
    kept = [char for char in decomposed if unicodedata.category(char) != "Mn"]
    return "".join(kept)
+
+
def normalize_evaluation_segment(segment: str) -> str:
    """Normalize a segment by removing accents and converting to lowercase.

    Args:
        segment (str): The segment to normalize.

    Returns:
        str: The normalized segment, containing only lowercase ASCII letters and digits.
    """
    lowered = strip_accents(segment.lower())
    return re.sub(r"[^a-z0-9]", "", lowered)
+
+
def convert_numbers_to_words(text: str, lang: str = "en") -> str:
    """Convert numeric tokens in the text to their word representation.

    Args:
        text (str): The input text containing numeric tokens.
        lang (str): The language to use for conversion (default is "en").

    Returns:
        str: The text with numeric tokens converted to words; the original text
            is returned unchanged if it is not numeric or conversion fails.
    """
    if not re.match(NUMERIC_TOKEN, text):
        return text

    if lang == "en":
        # NOTE: num2words doesn't support thousands separators
        candidate = text.replace(",", "")
    else:
        # Drop "." (thousands separator) and use "." as the decimal mark.
        candidate = text.replace(".", "").replace(",", ".")

    try:
        return num2words(candidate, lang=lang)
    except Exception:
        # Best-effort: keep the original token on any conversion error.
        return text
+
+
def clean_text(text: str, lang: str = "en") -> str:
    """Normalize a transcript string for evaluation.

    Removes markup tags, re-contracts split apostrophes, normalizes each
    token, and spells out numeric tokens as words.

    Args:
        text (str): The raw transcript text.
        lang (str): The language used for number-to-word conversion.

    Returns:
        str: The cleaned text as a single whitespace-joined string.
    """
    # Remove all tags, e.g., <unk>.
    text = re.sub(r"<[^>]+>", "", text)

    # Re-contract apostrophes.
    text = re.sub(r"(\w) '(\w)", r"\1'\2", text)

    # Get normalized tokens.
    normalized_tokens = [basic_normalizer(token.group()) for token in basic_tokenizer(text)]

    # Convert numbers to words.
    normalized_tokens = [convert_numbers_to_words(token, lang=lang) for token in normalized_tokens]

    return " ".join(normalized_tokens)
+
+
def clean_example(example: dict, lang="en") -> dict:
    """Clean the transcript fields ("ref" and "hyp") of an example in place.

    Args:
        example: The example to clean.
        lang: The language to use for cleaning.

    Returns:
        The same example dict with its transcript fields cleaned.
    """
    for key in ("ref", "hyp"):
        if key in example:
            example[key] = clean_text(example[key], lang=lang)
    return example
diff --git a/src/error_align/edit_distance.py b/src/error_align/edit_distance.py
new file mode 100644
index 0000000..2254737
--- /dev/null
+++ b/src/error_align/edit_distance.py
@@ -0,0 +1,144 @@
+from error_align.utils import DELIMITERS, OP_TYPE_COMBO_MAP_INV, OpType
+
+
+def _get_levenshtein_values(ref_token: str, hyp_token: str):
+ """Compute the Levenshtein values for deletion, insertion, and diagonal (substitution or match).
+
+ Args:
+ ref_token (str): The reference token.
+ hyp_token (str): The hypothesis token.
+
+ Returns:
+ tuple: A tuple containing the deletion cost, insertion cost, and diagonal cost.
+
+ """
+ if hyp_token == ref_token:
+ diag_cost = 0
+ else:
+ diag_cost = 1
+
+ return 1, 1, diag_cost
+
+
+def _get_error_align_values(ref_token: str, hyp_token: str):
+ """Compute the error alignment values for deletion, insertion, and diagonal (substitution or match).
+
+ Args:
+ ref_token (str): The reference token.
+ hyp_token (str): The hypothesis token.
+
+ Returns:
+ tuple: A tuple containing the deletion cost, insertion cost, and diagonal cost.
+
+ """
+ if hyp_token == ref_token:
+ diag_cost = 0
+ elif hyp_token in DELIMITERS or ref_token in DELIMITERS:
+ diag_cost = 3 # NOTE: Will never be chosen as insert + delete (= 2) is equivalent and cheaper.
+ else:
+ diag_cost = 2
+
+ return 1, 1, diag_cost
+
+
+def compute_distance_matrix(
+ ref: str | list[str],
+ hyp: str | list[str],
+ score_func: callable,
+ backtrace: bool = False,
+ dtype: type = int,
+):
+ """Compute the edit distance score matrix between two sequences x (hyp) and y (ref)
+ using only pure Python lists.
+
+ Args:
+ ref (str or list[str]): The reference sequence/transcript.
+ hyp (str or list[str]): The hypothesis sequence/transcript.
+ score_func (callable): A function that takes two tokens (ref_token, hyp_token) and returns
+ a tuple of (deletion_cost, insertion_cost, diagonal_cost).
+ backtrace (bool): Whether to compute the backtrace matrix.
+ dtype (type): The type to store scores (int, float, etc.).
+
+ Returns:
+ list[list]: The score matrix.
+ list[list]: The backtrace matrix, if backtrace=True.
+
+ """
+ hyp_dim, ref_dim = len(hyp) + 1, len(ref) + 1
+
+ # Create empty score matrix of zeros and initialize first row and column.
+ score_matrix = [[dtype(0) for _ in range(ref_dim)] for _ in range(hyp_dim)]
+ for j in range(ref_dim):
+ score_matrix[0][j] = dtype(j)
+ for i in range(hyp_dim):
+ score_matrix[i][0] = dtype(i)
+
+ # Create backtrace matrix and operation combination map and initialize first row and column.
+ # Each operation combination is dynamically assigned a unique integer.
+ if backtrace:
+ backtrace_matrix = [[0 for _ in range(ref_dim)] for _ in range(hyp_dim)]
+ backtrace_matrix[0][0] = OP_TYPE_COMBO_MAP_INV[(OpType.MATCH,)]
+ for j in range(1, ref_dim):
+ backtrace_matrix[0][j] = OP_TYPE_COMBO_MAP_INV[(OpType.DELETE,)]
+ for i in range(1, hyp_dim):
+ backtrace_matrix[i][0] = OP_TYPE_COMBO_MAP_INV[(OpType.INSERT,)]
+
+ # Fill in the score and backtrace matrix.
+ for j in range(1, ref_dim):
+ for i in range(1, hyp_dim):
+ ins_cost, del_cost, diag_cost = score_func(ref[j - 1], hyp[i - 1])
+
+ ins_val = score_matrix[i - 1][j] + ins_cost
+ del_val = score_matrix[i][j - 1] + del_cost
+ diag_val = score_matrix[i - 1][j - 1] + diag_cost
+ new_val = min(ins_val, del_val, diag_val)
+ score_matrix[i][j] = dtype(new_val)
+
+ # Track possible operations (note that the order of operations matters).
+ if backtrace:
+ pos_ops = tuple()
+ if diag_val == new_val and diag_cost == 0:
+ pos_ops += (OpType.MATCH,)
+ if ins_val == new_val:
+ pos_ops += (OpType.INSERT,)
+ if del_val == new_val:
+ pos_ops += (OpType.DELETE,)
+ if diag_val == new_val and diag_cost > 0:
+ pos_ops += (OpType.SUBSTITUTE,)
+ backtrace_matrix[i][j] = OP_TYPE_COMBO_MAP_INV[pos_ops]
+
+ if backtrace:
+ return score_matrix, backtrace_matrix
+ return score_matrix
+
+
def compute_levenshtein_distance_matrix(ref: str | list[str], hyp: str | list[str], backtrace: bool = False):
    """Compute the Levenshtein distance matrix between two sequences.

    Args:
        ref (str or list[str]): The reference sequence/transcript.
        hyp (str or list[str]): The hypothesis sequence/transcript.
        backtrace (bool): Whether to compute the backtrace matrix.

    Returns:
        list[list]: The score matrix (pure Python lists, not a numpy array).
        list[list]: The backtrace matrix, if backtrace=True.

    """
    return compute_distance_matrix(ref, hyp, _get_levenshtein_values, backtrace)
+
+
def compute_error_align_distance_matrix(ref: str | list[str], hyp: str | list[str], backtrace: bool = False):
    """Compute the error alignment distance matrix between two sequences.

    Args:
        ref (str or list[str]): The reference sequence/transcript.
        hyp (str or list[str]): The hypothesis sequence/transcript.
        backtrace (bool): Whether to compute the backtrace matrix.

    Returns:
        list[list]: The score matrix (pure Python lists, not a numpy array).
        list[list]: The backtrace matrix, if backtrace=True.

    """
    return compute_distance_matrix(ref, hyp, _get_error_align_values, backtrace)
diff --git a/src/error_align/error_align.py b/src/error_align/error_align.py
new file mode 100644
index 0000000..56a3bd3
--- /dev/null
+++ b/src/error_align/error_align.py
@@ -0,0 +1,495 @@
+from collections import defaultdict
+from typing import Union
+
+import regex as re
+from tqdm import tqdm
+
+from error_align.backtrace_graph import BacktraceGraph
+from error_align.edit_distance import compute_error_align_distance_matrix
+from error_align.utils import (
+ END_DELIMITER,
+ START_DELIMITER,
+ Alignment,
+ OpType,
+ basic_normalizer,
+ basic_tokenizer,
+ categorize_char,
+ ensure_length_preservation,
+ get_manhattan_distance,
+)
+
+
class ErrorAlign:
    """Error alignment class that performs a two-pass alignment process.

    Pass 1 (in __init__) computes an edit-distance backtrace graph over the
    delimiter-wrapped, normalized character sequences. Pass 2 (in `align`)
    runs a beam search over that graph to produce word-level alignments.
    """

    def __init__(
        self,
        ref: str,
        hyp: str,
        tokenizer: callable = basic_tokenizer,
        normalizer: callable = basic_normalizer,
    ):
        """Initialize the error alignment with reference and hypothesis texts.

        The first pass (backtrace graph extraction) is performed during initialization.

        The second pass (beam search) is performed in the `align` method.

        Args:
            ref (str): The reference sequence/transcript.
            hyp (str): The hypothesis sequence/transcript.
            tokenizer (callable): A function to tokenize the sequences. Must be regex-based and return Match objects.
            normalizer (callable): A function to normalize the tokens. Defaults to basic_normalizer.

        """
        if not isinstance(ref, str):
            raise TypeError("Reference sequence must be a string.")
        if not isinstance(hyp, str):
            raise TypeError("Hypothesis sequence must be a string.")

        self.ref = ref
        self.hyp = hyp

        # Inclusive tokenization: Track the token position in the original text.
        self._ref_token_matches = tokenizer(ref)
        self._hyp_token_matches = tokenizer(hyp)

        # Length-preserving normalization: Ensure that the normalizer preserves token length.
        # Each normalized token is wrapped in <...> delimiters so that token boundaries
        # survive the character-level alignment.
        normalizer = ensure_length_preservation(normalizer)
        self._ref = "".join([f"<{normalizer(r.group())}>" for r in self._ref_token_matches])
        self._hyp = "".join([f"<{normalizer(h.group())}>" for h in self._hyp_token_matches])

        # Categorize characters.
        self._ref_char_types = list(map(categorize_char, self._ref))
        self._hyp_char_types = list(map(categorize_char, self._hyp))

        # Initialize graph attributes.
        self._identical_inputs = self._ref == self._hyp
        self._ref_max_idx = len(self._ref) - 1
        self._hyp_max_idx = len(self._hyp) - 1
        self.end_index = (self._hyp_max_idx, self._ref_max_idx)

        # Create index maps for reference and hypothesis sequences.
        self._ref_index_map = self._create_index_map(self._ref_token_matches)
        self._hyp_index_map = self._create_index_map(self._hyp_token_matches)

        # First pass: Extract backtrace graph. Skipped when the normalized inputs are
        # identical, since the alignment is then trivial.
        if not self._identical_inputs:
            _, backtrace_matrix = compute_error_align_distance_matrix(self._ref, self._hyp, backtrace=True)
            self._backtrace_graph = BacktraceGraph(backtrace_matrix)
            self._backtrace_node_set = self._backtrace_graph.get_node_set()
            self._unambiguous_matches = self._backtrace_graph.get_unambiguous_matches(self._ref)
        else:
            self._backtrace_graph = None
            self._backtrace_node_set = None
            self._unambiguous_matches = None

    def __repr__(self):
        """Return a short representation with truncated input previews."""
        ref_preview = self.ref if len(self.ref) < 20 else self.ref[:17] + "..."
        hyp_preview = self.hyp if len(self.hyp) < 20 else self.hyp[:17] + "..."
        return f'ErrorAlign(ref="{ref_preview}", hyp="{hyp_preview}")'

    def align(
        self,
        beam_size: int = 100,
        pbar: bool = False,
        return_path: bool = False,
    ) -> Union[list[Alignment], "Path"]:
        """Perform beam search to align reference and hypothesis texts.

        Args:
            beam_size (int): The size of the beam for beam search. Defaults to 100.
            pbar (bool): Whether to display a progress bar. Defaults to False.
            return_path (bool): Whether to return the path object or just the alignments. Defaults to False.

        Returns:
            list[Alignment]: A list of Alignment objects. If return_path=True, the best
            Path object instead (or None if no path reached the terminal node).

        """
        # Skip beam search if inputs are identical.
        if self._identical_inputs:
            return self._identical_input_alignments()

        # Initialize the beam with a single path starting at the root node.
        start_path = Path(self)
        beam = {start_path.pid: start_path}
        prune_map = defaultdict(lambda: float("inf"))  # Best known cost per path ID.
        ended = []

        # Setup progress bar, if enabled. Progress is measured as Manhattan distance
        # covered toward the terminal node.
        if pbar:
            total_mdist = self._ref_max_idx + self._hyp_max_idx + 2
            progress_bar = tqdm(total=total_mdist, desc="Aligning transcripts")

        # Expand candidate paths until all have reached the terminal node.
        while len(beam) > 0:
            new_beam = {}

            # Expand each path in the current beam.
            for path in beam.values():
                if path.at_end:
                    ended.append(path)
                    continue

                # Transition to all child nodes; skip children that are strictly worse
                # than an already-seen path with the same ID.
                for new_path in path.expand():
                    if new_path.pid in prune_map:
                        if new_path.cost > prune_map[new_path.pid]:
                            continue
                    prune_map[new_path.pid] = new_path.cost

                    if new_path.pid not in new_beam or new_path.cost < new_beam[new_path.pid].cost:
                        new_beam[new_path.pid] = new_path

            # Update the beam with the newly expanded paths.
            new_beam = list(new_beam.values())
            new_beam.sort(key=lambda p: p.norm_cost)
            beam = new_beam[:beam_size]

            # Keep only the best path if it sits on an unambiguous match node.
            if len(beam) > 0 and beam[0]._at_unambiguous_match_node:
                beam = beam[:1]
                prune_map = defaultdict(lambda: float("inf"))
            beam = {p.pid: p for p in beam}  # Convert to dict for diversity check.

            # Update progress bar, if enabled. The last beam entry is the worst path
            # (beam is insertion-ordered by sorted cost).
            try:
                worst_path = next(reversed(beam.values()))
                mdist = get_manhattan_distance(worst_path.index, self.end_index)
                if pbar:
                    progress_bar.n = total_mdist - mdist
                    progress_bar.refresh()
            except StopIteration:
                # Beam is empty: every remaining path has ended.
                if pbar:
                    progress_bar.n = total_mdist
                    progress_bar.refresh()

        # Return the best path or its alignments.
        ended.sort(key=lambda p: p.cost)
        if return_path:
            return ended[0] if len(ended) > 0 else None
        return ended[0].alignments if len(ended) > 0 else []

    def _create_index_map(self, text_tokens: list[re.Match]) -> list[int]:
        """Create an index map for the given tokens.

        The 'index_map' is used to map each aligned character back to its original position in the input text.

        NOTE: -1 is used for delimiter (<>) and indicates no match in the source sequence.
        """
        index_map = []
        for match in text_tokens:
            index_map.extend([-1])  # Start delimiter
            index_map.extend(list(range(*match.span())))
            index_map.extend([-1])  # End delimiter
        return index_map

    def _identical_input_alignments(self) -> list[Alignment]:
        """Return alignments for identical reference and hypothesis pairs."""
        assert self._identical_inputs, "Inputs are not identical."

        alignments = []
        for ref_match, hyp_match in zip(self._ref_token_matches, self._hyp_token_matches, strict=False):
            ref_slice = slice(*ref_match.span())
            hyp_slice = slice(*hyp_match.span())
            ref_token = self.ref[ref_slice]
            hyp_token = self.hyp[hyp_slice]
            alignment = Alignment(
                op_type=OpType.MATCH,
                ref_slice=ref_slice,
                hyp_slice=hyp_slice,
                ref=ref_token,
                hyp=hyp_token,
            )
            alignments.append(alignment)
        return alignments
+
+
class Path:
    """Class to represent a graph path.

    A path tracks a position in the (hyp, ref) alignment grid together with the
    accumulated cost, split into a closed part (completed segments) and an open
    part (the segment currently being built).
    """

    def __init__(self, src: ErrorAlign):
        """Initialize the Path class with a given path."""
        self.src = src  # Owning ErrorAlign instance (sequences, index maps, graph).
        self.ref_idx = -1  # Index of last consumed ref character (-1 = root).
        self.hyp_idx = -1  # Index of last consumed hyp character (-1 = root).
        self._closed_cost = 0  # Cost of completed segments.
        self._open_cost = 0  # Cost accumulated within the current open segment.
        self._at_unambiguous_match_node = False
        self._last_end_index = (-1, -1)  # Node where the previous segment ended.
        self._end_indices = tuple()  # ((index, segment_cost), ...) per ended segment.
        self._alignments = None  # Cached alignments (see `alignments` property).
        self._alignments_index = None  # Path index at which the cache was computed.

    def __repr__(self):
        return f"Path(({self.ref_idx}, {self.hyp_idx}), score={self.cost})"

    @property
    def alignments(self) -> list[Alignment]:
        """Get the alignments of the path."""
        # Return cached alignments if available and the path has not changed.
        if self._alignments is not None and self._alignments_index == self.index:
            return self._alignments

        self._alignments_index = self.index
        alignments = []
        start_hyp, start_ref = (0, 0)
        # Walk over the recorded segment end points; each consecutive pair of end
        # points delimits one alignment segment in the delimiter-wrapped sequences.
        for (end_hyp, end_ref), score in self._end_indices:
            end_hyp, end_ref = end_hyp + 1, end_ref + 1

            # Construct DELETE alignment.
            if start_hyp == end_hyp:
                assert start_ref < end_ref
                ref_slice = slice(start_ref, end_ref)
                ref_slice = self._translate_slice(ref_slice, self.src._ref_index_map)
                assert ref_slice is not None
                alignment = Alignment(
                    op_type=OpType.DELETE,
                    ref_slice=ref_slice,
                    ref=self.src.ref[ref_slice],
                )
                alignments.append(alignment)

            # Construct INSERT alignment.
            elif start_ref == end_ref:
                assert start_hyp < end_hyp
                hyp_slice = slice(start_hyp, end_hyp)
                hyp_slice = self._translate_slice(hyp_slice, self.src._hyp_index_map)
                assert hyp_slice is not None
                alignment = Alignment(
                    op_type=OpType.INSERT,
                    hyp_slice=hyp_slice,
                    hyp=self.src.hyp[hyp_slice],
                    left_compound=self.src._hyp_index_map[start_hyp] >= 0,
                    right_compound=self.src._hyp_index_map[end_hyp - 1] >= 0,
                )
                alignments.append(alignment)

            # Construct SUBSTITUTE or MATCH alignment.
            else:
                assert start_hyp < end_hyp and start_ref < end_ref
                hyp_slice = slice(start_hyp, end_hyp)
                ref_slice = slice(start_ref, end_ref)
                hyp_slice = self._translate_slice(hyp_slice, self.src._hyp_index_map)
                ref_slice = self._translate_slice(ref_slice, self.src._ref_index_map)
                assert hyp_slice is not None and ref_slice is not None
                is_match_segment = score == 0  # A zero-cost segment is an exact match.
                op_type = OpType.MATCH if is_match_segment else OpType.SUBSTITUTE
                alignment = Alignment(
                    op_type=op_type,
                    ref_slice=ref_slice,
                    hyp_slice=hyp_slice,
                    ref=self.src.ref[ref_slice],
                    hyp=self.src.hyp[hyp_slice],
                    left_compound=self.src._hyp_index_map[start_hyp] >= 0,
                    right_compound=self.src._hyp_index_map[end_hyp - 1] >= 0,
                )
                alignments.append(alignment)

            start_hyp, start_ref = end_hyp, end_ref

        # Cache the computed alignments.
        self._alignments = alignments

        return alignments

    @property
    def pid(self):
        """Get the ID of the path used for pruning."""
        # Paths at the same node with the same last segment boundary are interchangeable.
        return hash((self.index, self._last_end_index))

    @property
    def cost(self):
        """Get the cost of the path."""
        return self._closed_cost + self._open_cost + self._substitution_penalty()

    @property
    def norm_cost(self):
        """Get the normalized cost of the path (cost per consumed character)."""
        if self.cost == 0:
            return 0
        return self.cost / (self.ref_idx + self.hyp_idx + 3)  # NOTE: +3 to avoid zero division. Root = (-1,-1).

    @property
    def index(self):
        """Get the current node index of the path."""
        return (self.hyp_idx, self.ref_idx)

    @property
    def at_end(self):
        """Check if the path has reached the terminal node."""
        return self.index == self.src.end_index

    def expand(self):
        """Expand the path by transitioning to child nodes.

        Yields:
            Path: The expanded child paths.

        """
        # Add delete operation.
        delete_path = self._add_delete()
        if delete_path is not None:
            yield delete_path

        # Add insert operation.
        insert_path = self._add_insert()
        if insert_path is not None:
            yield insert_path

        # Add substitution or match operation.
        sub_or_match_path = self._add_substitution_or_match()
        if sub_or_match_path is not None:
            yield sub_or_match_path

    def _transition_and_shallow_copy(self, ref_step: int, hyp_step: int):
        """Create a shallow copy of the path, advanced by the given steps."""
        new_path = Path(self.src)
        new_path.ref_idx = self.ref_idx + ref_step
        new_path.hyp_idx = self.hyp_idx + hyp_step
        new_path._closed_cost = self._closed_cost
        new_path._open_cost = self._open_cost
        new_path._at_unambiguous_match_node = False
        new_path._last_end_index = self._last_end_index
        new_path._end_indices = self._end_indices

        return new_path

    def _reset_segment_variables(self, index: tuple[int, int]) -> None:
        """Apply updates when segment end is detected."""
        # Fold the open segment cost (plus any substitution penalty) into the
        # closed cost and start a fresh open segment at `index`.
        self._closed_cost += self._open_cost
        self._closed_cost += self._substitution_penalty(index)
        self._last_end_index = index
        self._open_cost = 0

    def _end_insertion_segment(self, index: tuple[int, int]) -> None:
        """End the current segment, if criteria for an insertion are met."""
        hyp_slice = slice(self._last_end_index[0] + 1, index[0] + 1)
        hyp_slice = self._translate_slice(hyp_slice, self.src._hyp_index_map)
        ref_is_empty = index[1] == self._last_end_index[1]
        # Only an insertion if some real hyp characters were consumed and no ref chars.
        if hyp_slice is not None and ref_is_empty:
            self._end_indices += ((index, self._open_cost),)
            self._reset_segment_variables(index)

    def _end_segment(self) -> Union[None, "Path"]:
        """End the current segment, if criteria for an insertion, a substitution, or a match are met."""
        hyp_slice = slice(self._last_end_index[0] + 1, self.index[0] + 1)
        hyp_slice = self._translate_slice(hyp_slice, self.src._hyp_index_map)
        ref_slice = slice(self._last_end_index[1] + 1, self.index[1] + 1)
        ref_slice = self._translate_slice(ref_slice, self.src._ref_index_map)

        assert ref_slice is not None

        hyp_is_empty = self.index[0] == self._last_end_index[0]
        if hyp_is_empty:
            # Deletion segment: ref consumed, no hyp consumed.
            self._end_indices += ((self.index, self._open_cost),)
        else:
            # TODO: Handle edge case where hyp has only covered delimiters.
            if hyp_slice is None:
                return None

            is_match_segment = self._open_cost == 0
            self._at_unambiguous_match_node = is_match_segment and self.index in self.src._unambiguous_matches
            self._end_indices += ((self.index, self._open_cost),)

        # Update the path score and reset segments attributes.
        self._reset_segment_variables(self.index)
        return self

    def _in_backtrace_node_set(self, index) -> bool:
        """Check if the given operation is an optimal transition at the current index."""
        return index in self.src._backtrace_node_set

    def _add_delete(self) -> Union[None, "Path"]:
        """Expand the path by adding a delete operation."""
        # Ensure we are not at the end of the hypothesis sequence.
        if self.hyp_idx >= self.src._hyp_max_idx:
            return None

        # Transition and update costs. Off-backtrace, non-delimiter steps pay a surcharge.
        new_path = self._transition_and_shallow_copy(ref_step=0, hyp_step=1)
        is_backtrace = self._in_backtrace_node_set(self.index)
        is_delimiter = self.src._hyp_char_types[new_path.hyp_idx] == 0  # NOTE: 0 indicates delimiter.
        new_path._open_cost += 1 if is_delimiter else 2
        new_path._open_cost += 0 if is_backtrace or is_delimiter else 1

        # Check for end-of-segment criteria.
        if self.src._hyp[new_path.hyp_idx] == END_DELIMITER:
            new_path._end_insertion_segment(new_path.index)

        return new_path

    def _add_insert(self) -> Union[None, "Path"]:
        """Expand the path by adding an insert operation."""
        # Ensure we are not at the end of the reference sequence.
        if self.ref_idx >= self.src._ref_max_idx:
            return None

        # Transition and check for end-of-segment criteria.
        new_path = self._transition_and_shallow_copy(ref_step=1, hyp_step=0)
        if self.src._ref[new_path.ref_idx] == START_DELIMITER:
            new_path._end_insertion_segment(self.index)

        # Update costs. Off-backtrace, non-delimiter steps pay a surcharge.
        is_backtrace = self._in_backtrace_node_set(self.index)
        is_delimiter = self.src._ref_char_types[new_path.ref_idx] == 0  # NOTE: 0 indicates delimiter.
        new_path._open_cost += 1 if is_delimiter else 2
        new_path._open_cost += 0 if is_backtrace or is_delimiter else 1

        # Check for end-of-segment criteria.
        if self.src._ref[new_path.ref_idx] == END_DELIMITER:
            new_path = new_path._end_segment()

        return new_path

    def _add_substitution_or_match(self) -> Union[None, "Path"]:
        """Expand the given path by adding a substitution or match operation."""
        # Ensure we are not at the end of either sequence.
        if self.ref_idx >= self.src._ref_max_idx or self.hyp_idx >= self.src._hyp_max_idx:
            return None

        # Transition and ensure that the transition is allowed: a delimiter may never
        # be substituted by a non-delimiter character.
        new_path = self._transition_and_shallow_copy(ref_step=1, hyp_step=1)
        is_match = self.src._ref[new_path.ref_idx] == self.src._hyp[new_path.hyp_idx]
        if not is_match:
            ref_is_delimiter = self.src._ref_char_types[new_path.ref_idx] == 0  # NOTE: 0 indicates delimiter
            hyp_is_delimiter = self.src._hyp_char_types[new_path.hyp_idx] == 0  # NOTE: 0 indicates delimiter
            if ref_is_delimiter or hyp_is_delimiter:
                return None

        # Check for end-of-segment criteria.
        if self.src._ref[new_path.ref_idx] == START_DELIMITER:
            new_path._end_insertion_segment(self.index)

        # Update costs, if not a match. Substituting across character categories is
        # more expensive than within a category.
        if not is_match:
            is_backtrace = self._in_backtrace_node_set(self.index)
            is_letter_type_match = (
                self.src._ref_char_types[new_path.ref_idx] == self.src._hyp_char_types[new_path.hyp_idx]
            )
            new_path._open_cost += 2 if is_letter_type_match else 3
            new_path._open_cost += 0 if is_backtrace else 1

        # Check for end-of-segment criteria.
        if self.src._ref[new_path.ref_idx] == END_DELIMITER:
            new_path = new_path._end_segment()

        return new_path

    def _translate_slice(self, segment_slice: slice, index_map: list[int]) -> None | slice:
        """Translate a slice from the alignment sequence back to the original sequence.

        Returns None when the slice covers only delimiter positions (mapped to -1).
        """
        slice_indices = index_map[segment_slice]
        slice_indices = list(filter(lambda x: x >= 0, slice_indices))
        if len(slice_indices) == 0:
            return None
        start, end = int(slice_indices[0]), int(slice_indices[-1] + 1)
        return slice(start, end)

    def _substitution_penalty(self, index: tuple[int, int] | None = None) -> int:
        """Get the substitution penalty given an index.

        The open cost is counted twice (once here, once as open cost) only when the
        current segment has consumed characters from both sequences.
        """
        index = index or self.index
        ref_is_not_empty = index[1] > self._last_end_index[1]
        hyp_is_not_empty = index[0] > self._last_end_index[0]
        if ref_is_not_empty and hyp_is_not_empty:
            return self._open_cost
        return 0
diff --git a/src/error_align/func.py b/src/error_align/func.py
new file mode 100644
index 0000000..6ac8dfc
--- /dev/null
+++ b/src/error_align/func.py
@@ -0,0 +1,38 @@
+from error_align.error_align import ErrorAlign, Path
+from error_align.utils import Alignment, basic_normalizer, basic_tokenizer
+
+
def error_align(
    ref: str,
    hyp: str,
    tokenizer: callable = basic_tokenizer,
    normalizer: callable = basic_normalizer,
    beam_size: int = 100,
    pbar: bool = False,
    return_path: bool = False,
) -> list[Alignment] | Path:
    """Perform error alignment between two sequences.

    Convenience wrapper that constructs an ErrorAlign instance and runs its
    beam-search alignment.

    Args:
        ref (str): The reference sequence/transcript.
        hyp (str): The hypothesis sequence/transcript.
        tokenizer (callable): A function to tokenize the sequences. Must be regex-based and return Match objects.
        normalizer (callable): A function to normalize the tokens. Defaults to basic_normalizer.
        beam_size (int): The size of the beam for beam search. Defaults to 100.
        pbar (bool): Whether to display a progress bar. Defaults to False.
        return_path (bool): Whether to return the path object or just the alignments. Defaults to False.

    Returns:
        list[Alignment] | Path: A list of Alignment objects, or the best Path
            object when ``return_path`` is True.

    """
    return ErrorAlign(
        ref,
        hyp,
        tokenizer=tokenizer,
        normalizer=normalizer,
    ).align(
        beam_size=beam_size,
        pbar=pbar,
        return_path=return_path,
    )
diff --git a/src/error_align/utils.py b/src/error_align/utils.py
new file mode 100644
index 0000000..45830b1
--- /dev/null
+++ b/src/error_align/utils.py
@@ -0,0 +1,166 @@
+from dataclasses import dataclass
+from enum import IntEnum
+from itertools import chain, combinations
+
+import regex as re
+from unidecode import unidecode
+
+
class OpType(IntEnum):
    """Edit-operation types produced by the alignment (int-valued so they can be stored/compared compactly)."""

    MATCH = 0
    INSERT = 1
    DELETE = 2
    SUBSTITUTE = 3
+
+
@dataclass
class Alignment:
    """Class representing an operation with its type and cost."""

    op_type: OpType  # The edit operation relating `ref` to `hyp`.
    ref_slice: slice | None = None  # Span of the token in the reference text, if any.
    hyp_slice: slice | None = None  # Span of the token in the hypothesis text, if any.
    ref: str | None = None  # Reference token text (None when there is no reference side, e.g. INSERT).
    hyp: str | None = None  # Hypothesis token text (None when there is no hypothesis side, e.g. DELETE).
    left_compound: bool = False  # If True, a '-' marker is rendered before the hypothesis token.
    right_compound: bool = False  # If True, a '-' marker is rendered after the hypothesis token.

    @property
    def hyp_with_compound_markers(self) -> str | None:
        """Return the hypothesis with compound markers if applicable, or None when there is no hypothesis token."""
        if self.hyp is None:
            return None
        return f"{'-' if self.left_compound else ''}{self.hyp}{'-' if self.right_compound else ''}"

    def __repr__(self) -> str:
        # NOTE(review): SUBSTITUTE/MATCH render the compound dashes outside the
        # quotes, while INSERT (via `hyp_with_compound_markers`) renders them
        # inside the quotes — presumably intentional; confirm against the repr tests.
        lc = "-" if self.left_compound else ""
        rc = "-" if self.right_compound else ""
        if self.op_type == OpType.DELETE:
            return f'Alignment({self.op_type.name}: "{self.ref}")'
        if self.op_type == OpType.INSERT:
            return f'Alignment({self.op_type.name}: "{self.hyp_with_compound_markers}")'
        if self.op_type == OpType.SUBSTITUTE:
            return f'Alignment({self.op_type.name}: "{self.ref}" -> {lc}"{self.hyp}"{rc})'
        return f'Alignment({self.op_type.name}: "{self.ref}" == {lc}"{self.hyp}"{rc})'
+
+
def op_type_powerset() -> chain:
    """Generate all possible combinations of operation types, except the empty set.

    Returns:
        Generator: All possible combinations of operation types, ordered by
            combination size (singletons first, full set last).

    """
    all_ops = list(OpType)
    sizes = range(1, len(all_ops) + 1)
    return chain.from_iterable(combinations(all_ops, size) for size in sizes)
+
+
# Delimiter characters; `categorize_char` maps these to category 0.
START_DELIMITER = "<"
END_DELIMITER = ">"
DELIMITERS = {START_DELIMITER, END_DELIMITER}

# Lookup tables: raw int -> OpType, index -> OpType combination, and the inverse.
OP_TYPE_MAP = {op_type.value: op_type for op_type in OpType}
OP_TYPE_COMBO_MAP = {i: op_types for i, op_types in enumerate(op_type_powerset())}
OP_TYPE_COMBO_MAP_INV = {v: k for k, v in OP_TYPE_COMBO_MAP.items()}

# Token patterns (for the `regex` module's \p{...} classes): numbers with
# optional ','/'.' separators followed by whitespace/end, and word tokens with
# optional internal (or one trailing) apostrophe.
NUMERIC_TOKEN = r"\p{N}+([,.]\p{N}+)*(?=\s|$)"
STANDARD_TOKEN = r"[\p{L}\p{N}]+(['][\p{L}\p{N}]+)*'?"
+
+
def is_vowel(c: str) -> bool:
    """Check if the normalized character is a vowel.

    Args:
        c (str): The character to check.

    Returns:
        bool: True if the character is a vowel, False otherwise.

    """
    assert len(c) == 1, "Input must be a single character."
    decoded = unidecode(c)
    # Guard: unidecode may return an empty string for characters with no ASCII
    # transliteration; the original `unidecode(c)[0]` raised IndexError there.
    return bool(decoded) and decoded[0] in "aeiouy"
+
+
def is_consonant(c: str) -> bool:
    """Check if the normalized character is a consonant.

    Args:
        c (str): The character to check.

    Returns:
        bool: True if the character is a consonant, False otherwise.

    """
    assert len(c) == 1, "Input must be a single character."
    decoded = unidecode(c)
    # Guard: unidecode may return an empty string for characters with no ASCII
    # transliteration; the original `unidecode(c)[0]` raised IndexError there.
    return bool(decoded) and decoded[0] in "bcdfghjklmnpqrstvwxyz"
+
+
def categorize_char(c: str) -> int:
    """Categorize a character as delimiter (0), consonant (1), vowel (2), or unvoiced (3).

    Args:
        c (str): The character to categorize.

    Returns:
        int: The integer category code of the character (not a string, despite
            the names above: 0=delimiter, 1=consonant, 2=vowel, 3=unvoiced).

    """
    if c in DELIMITERS:
        return 0
    if is_consonant(c):
        return 1
    if is_vowel(c):
        return 2
    return 3  # NOTE: Unvoiced characters (only apostrophes are expected by default).
+
+
def get_manhattan_distance(a: tuple[int, int], b: tuple[int, int]) -> int:
    """Calculate the Manhattan (L1) distance between two 2-D points a and b."""
    return sum(abs(p - q) for p, q in zip(a, b))
+
+
def basic_tokenizer(text: str) -> list:
    """Default tokenizer that splits text into words based on whitespace.

    Args:
        text (str): The input text to tokenize.

    Returns:
        list: A list of regex Match objects, one per token (word or number).

    """
    pattern = rf"({NUMERIC_TOKEN})|({STANDARD_TOKEN})"
    flags = re.UNICODE | re.VERBOSE
    return list(re.finditer(pattern, text, flags))
+
+
def basic_normalizer(text: str) -> str:
    """Default normalizer that only converts text to lowercase.

    Args:
        text (str): The input text to normalize.

    Returns:
        str: The lowercased text.

    """
    lowered = text.lower()
    return lowered
+
+
def ensure_length_preservation(normalizer: callable) -> callable:
    """Decorator to ensure that the normalizer preserves the length of the input text.

    Args:
        normalizer (callable): The normalizer function to wrap.

    Returns:
        callable: The wrapped normalizer function that preserves length.

    Raises:
        ValueError: Raised by the wrapper when the normalized text differs in
            length from the input text.

    """
    from functools import wraps  # Local import keeps the module's import block unchanged.

    @wraps(normalizer)  # Preserve the wrapped function's name/docstring for debugging.
    def wrapper(text: str, *args: list, **kwargs: dict) -> str:
        normalized = normalizer(text, *args, **kwargs)
        if len(normalized) != len(text):
            raise ValueError("Normalizer must preserve length.")
        return normalized

    return wrapper
diff --git a/src/python_package_template/__init__.py b/src/python_package_template/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/tests/test_default.py b/tests/test_default.py
index 2b57b30..b0728d2 100644
--- a/tests/test_default.py
+++ b/tests/test_default.py
@@ -1,4 +1,129 @@
-def test_stub() -> None:
- """Lorem Ipsum."""
+from typeguard import suppress_type_checks
- assert True
+from error_align import ErrorAlign, error_align
+from error_align.edit_distance import compute_levenshtein_distance_matrix
+from error_align.utils import OpType, categorize_char
+
+
def test_error_align() -> None:
    """Test error alignment for an example including all substitution types."""

    ref = "This is a substitution test deleted."
    hyp = "Inserted this is a contribution test."

    expected_ops = [
        OpType.INSERT,  # Inserted
        OpType.MATCH,  # This
        OpType.MATCH,  # is
        OpType.MATCH,  # a
        OpType.SUBSTITUTE,  # contribution -> substitution
        OpType.MATCH,  # test
        OpType.DELETE,  # deleted
    ]

    actual_ops = [alignment.op_type for alignment in error_align(ref, hyp, pbar=True)]
    assert actual_ops == expected_ops
+
+
def test_error_align_full_match() -> None:
    """Test error alignment for full match."""

    text = "This is a test."

    # Aligning a transcript against itself must yield only MATCH operations.
    assert all(alignment.op_type == OpType.MATCH for alignment in error_align(text, text))
+
+
def test_categorize_char() -> None:
    """Test character categorization."""

    # One representative character per category code.
    expected = {"<": 0, "b": 1, "a": 2, "'": 3}  # delimiter, consonant, vowel, unvoiced
    for char, category in expected.items():
        assert categorize_char(char) == category
+
+
def test_representations() -> None:
    """Test the string representation of Alignment objects."""

    # DELETE: reference-only token.
    assert repr(error_align("deleted", "")[0]) == 'Alignment(DELETE: "deleted")'

    # INSERT: hypothesis-only token, no compound markers.
    assert repr(error_align("", "inserted")[0]) == 'Alignment(INSERT: "inserted")'

    # SUBSTITUTE: a right-side compound marker is expected.
    substitute_alignment = error_align("substitution", "substitutiontesting")[0]
    assert substitute_alignment.left_compound is False
    assert substitute_alignment.right_compound is True
    assert repr(substitute_alignment) == 'Alignment(SUBSTITUTE: "substitution" -> "substitution"-)'

    # MATCH: identical tokens, no compound markers.
    assert repr(error_align("test", "test")[0]) == 'Alignment(MATCH: "test" == "test")'

    # ErrorAlign instance representation.
    ea = ErrorAlign(ref="test", hyp="pest")
    assert repr(ea) == 'ErrorAlign(ref="test", hyp="pest")'

    # Path representation, obtained by returning the full path from align.
    path = ea.align(beam_size=10, return_path=True)
    assert repr(path) == f"Path(({path.ref_idx}, {path.hyp_idx}), score={path.cost})"
+
+
@suppress_type_checks
def test_input_type_checks() -> None:
    """Test input type checks for ErrorAlign class."""

    # BUG FIX: the original try/except blocks passed silently when no
    # TypeError was raised at all; `else` now fails the test in that case.
    try:
        _ = ErrorAlign(ref=123, hyp="valid")  # type: ignore
    except TypeError as e:
        assert str(e) == "Reference sequence must be a string."
    else:
        raise AssertionError("Expected TypeError for non-string reference.")

    try:
        _ = ErrorAlign(ref="valid", hyp=456)  # type: ignore
    except TypeError as e:
        assert str(e) == "Hypothesis sequence must be a string."
    else:
        raise AssertionError("Expected TypeError for non-string hypothesis.")
+
+
def test_backtrace_graph() -> None:
    """Test backtrace graph generation."""

    # Align two transcripts differing by a single substitution, then inspect
    # the internal backtrace graph produced by the alignment.
    ea = ErrorAlign("This is a test.", "This is a pest.")
    ea.align(beam_size=10)
    graph = ea._backtrace_graph

    # Basic properties of the graph.
    assert isinstance(graph.get_path(), list)
    assert isinstance(graph.get_path(sample=True), list)
    assert graph.number_of_paths == 3
    assert all(isinstance(index, tuple) for index in graph._iter_topological_order())

    # Path counts through a specific interior node.
    node = graph.get_node(2, 2)
    assert node.number_of_ingoing_paths_via(OpType.MATCH) == 3
    assert node.number_of_outgoing_paths_via(OpType.MATCH) == 3
    assert node.number_of_ingoing_paths_via(OpType.INSERT) == 0
    assert node.number_of_outgoing_paths_via(OpType.INSERT) == 0
+
+
def test_levenshtein_distance_matrix() -> None:
    """Test Levenshtein distance matrix computation."""

    # The classic example: distance between "kitten" and "sitting" is 3.
    matrix = compute_levenshtein_distance_matrix("kitten", "sitting")
    assert matrix[-1][-1] == 3