Skip to content
2 changes: 1 addition & 1 deletion silnlp/common/compare_usfm_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from machine.tokenization import WhitespaceTokenizer

from .usfm_preservation import CHARACTER_TYPE_EMBEDS, PARAGRAPH_TYPE_EMBEDS
from .usfm_utils import CHARACTER_TYPE_EMBEDS, PARAGRAPH_TYPE_EMBEDS

LOGGER = logging.getLogger(__package__ + ".compare_usfm_structure")

Expand Down
180 changes: 40 additions & 140 deletions silnlp/common/postprocess_draft.py
Original file line number Diff line number Diff line change
@@ -1,97 +1,21 @@
import argparse
import logging
import re
from pathlib import Path
from typing import List, Tuple

import yaml
from machine.corpora import (
FileParatextProjectSettingsParser,
ScriptureRef,
UpdateUsfmMarkerBehavior,
UpdateUsfmParserHandler,
UpdateUsfmTextBehavior,
UsfmFileText,
UsfmStylesheet,
UsfmTextType,
parse_usfm,
)
from machine.scripture import book_id_to_number
from transformers.trainer_utils import get_last_checkpoint

from ..nmt.clearml_connection import SILClearML
from ..nmt.config import Config
from ..nmt.config_utils import create_config
from ..nmt.hugging_face_config import get_best_checkpoint
from .paratext import book_file_name_digits, get_book_path, get_project_dir
from .usfm_preservation import PARAGRAPH_TYPE_EMBEDS, construct_place_markers_handler
from ..nmt.config_utils import load_config
from ..nmt.postprocess import get_draft_paths_from_exp, postprocess_draft
from .paratext import get_project_dir
from .postprocesser import PostprocessConfig, PostprocessHandler
from .utils import get_mt_exp_dir

LOGGER = logging.getLogger(__package__ + ".postprocess_draft")


# NOTE: only using first book of first translate request for now
def get_paths_from_exp(config: Config) -> Tuple[Path, Path]:
    """Locate the source book and its draft for an experiment.

    Reads ``translate_config.yml`` from the experiment directory and uses only the
    first entry under ``translate`` and only the first book listed in that entry.

    Args:
        config: Experiment configuration; ``exp_dir``, ``model_dir`` and
            ``src_projects`` are read from it.

    Returns:
        A ``(source book path, draft path)`` tuple. The draft path has the form
        ``<exp_dir>/infer/<step>/<src_project>/<digits><book>.SFM``.

    Raises:
        ValueError: If the experiment has no ``translate_config.yml``.
    """
    # TODO: default to first draft in the infer folder
    translate_config_path = config.exp_dir / "translate_config.yml"
    if not translate_config_path.exists():
        raise ValueError("Experiment translate_config.yml not found. Please use --source and --draft options instead.")

    with translate_config_path.open("r", encoding="utf-8") as file:
        first_request = yaml.safe_load(file)["translate"][0]

    # The fallback (first configured source project) is evaluated even when
    # "src_project" is present in the request.
    src_project = first_request.get("src_project", next(iter(config.src_projects)))

    # "books" is either a list or a ';'-separated string; take its first entry.
    # TODO: handle partial book translation
    books = first_request["books"]
    if isinstance(books, list):
        book = books[0]
    else:
        book = books.split(";")[0]
    book_num = book_id_to_number(book)

    # Checkpoint directory names have their first 11 characters dropped
    # (presumably the "checkpoint-" prefix -- confirm) to obtain the step.
    ckpt = first_request.get("checkpoint", "last")
    if ckpt == "best":
        step_str = get_best_checkpoint(config.model_dir).name[11:]
    elif ckpt == "last":
        step_str = Path(get_last_checkpoint(config.model_dir)).name[11:]
    else:
        step_str = str(ckpt)

    src_book_path = get_book_path(src_project, book)
    draft_file_name = f"{book_file_name_digits(book_num)}{book}.SFM"
    draft_book_path = config.exp_dir / "infer" / step_str / src_project / draft_file_name
    return src_book_path, draft_book_path


def insert_draft_remarks(usfm: str, remarks: List[str]) -> str:
    r"""Insert one ``\rem`` line per remark right after the first line of ``usfm``.

    Args:
        usfm: USFM text; lines are separated by ``"\n"``.
        remarks: Remark texts, inserted in the given order.

    Returns:
        The text with the remark lines placed between its first line and the rest.
    """
    # str.split always yields at least one element, so ``first_line`` always exists.
    first_line, *remaining_lines = usfm.split("\n")
    output_lines = [first_line]
    output_lines.extend(f"\\rem {remark}" for remark in remarks)
    output_lines.extend(remaining_lines)
    return "\n".join(output_lines)


def get_sentences(
    book_path: Path, stylesheet: UsfmStylesheet, encoding: str, book: str, chapters: List[int] = []
) -> Tuple[List[str], List[ScriptureRef], List[str]]:
    """Collect the text segments of a USFM book together with their references.

    Args:
        book_path: Path of the USFM file to read.
        stylesheet: Stylesheet used to read the file and to look up marker tags.
        encoding: Encoding of the file.
        book: Book id passed on to ``UsfmFileText``.
        chapters: If non-empty, only rows whose chapter number is listed are kept.

    Returns:
        ``(sentences, refs, draft_remarks)``. ``sentences`` and ``refs`` are parallel
        lists; ``draft_remarks`` holds the text of the ``rem`` rows found before the
        first kept row.
    """
    kept_texts: List[str] = []
    kept_refs: List[ScriptureRef] = []
    leading_remarks: List[str] = []

    for row in UsfmFileText(stylesheet, encoding, book, book_path, include_all_text=True):
        # The marker is the name of the last path element of the reference, if any.
        marker = row.ref.path[-1].name if len(row.ref.path) > 0 else ""

        # TODO: \ide and \usfm lines could potentially come before the remark(s)
        if marker == "rem" and not kept_refs:
            leading_remarks.append(row.text)
            continue

        # Skip rules, checked in this order. Rows with empty text are NOT skipped.
        if marker in PARAGRAPH_TYPE_EMBEDS:
            continue
        if stylesheet.get_tag(marker).text_type == UsfmTextType.NOTE_TEXT:
            continue
        if len(chapters) > 0 and row.ref.chapter_num not in chapters:
            continue

        # Strip the ends and collapse runs of spaces into a single space.
        kept_texts.append(re.sub(" +", " ", row.text.strip()))
        kept_refs.append(row.ref)

    return kept_texts, kept_refs, leading_remarks


def main() -> None:
parser = argparse.ArgumentParser(
description="Applies draft postprocessing steps to a draft. Can be used with no postprocessing options to create a base draft."
)
parser = argparse.ArgumentParser(description="Applies draft postprocessing steps to a draft.")
parser.add_argument(
"--experiment",
default=None,
Expand Down Expand Up @@ -155,79 +79,55 @@ def main() -> None:

experiment = args.experiment.replace("\\", "/") if args.experiment else None
if experiment and get_mt_exp_dir(experiment).exists():
exp_dir = get_mt_exp_dir(experiment)
if args.clearml_queue is not None:
if "cpu" not in args.clearml_queue:
raise ValueError("Running this script on a GPU queue will not speed it up. Please only use CPU queues.")
clearml = SILClearML(experiment, args.clearml_queue)
config = clearml.config
else:
with (exp_dir / "config.yml").open("r", encoding="utf-8") as file:
config = yaml.safe_load(file)
config = create_config(exp_dir, config)
config = load_config(experiment)

src_path, draft_path = get_paths_from_exp(config)
if not (config.exp_dir / "translate_config.yml").exists():
raise ValueError(
"Experiment translate_config.yml not found. Please use --source and --draft options instead."
)
src_paths, draft_paths = get_draft_paths_from_exp(config)
elif args.clearml_queue is not None:
raise ValueError("Must use --experiment option to use ClearML.")
else:
src_path = Path(args.source.replace("\\", "/"))
draft_path = Path(args.draft.replace("\\", "/"))

if str(src_path).startswith(str(get_project_dir(""))):
settings = FileParatextProjectSettingsParser(src_path.parent).parse()
stylesheet = settings.stylesheet
encoding = settings.encoding
book = settings.get_book_id(src_path.name)
else:
stylesheet = UsfmStylesheet("usfm.sty")
encoding = "utf-8-sig"
book = args.book
if book is None:
src_paths = [Path(args.source.replace("\\", "/"))]
draft_paths = [Path(args.draft.replace("\\", "/"))]
if not str(src_paths[0]).startswith(str(get_project_dir(""))) and args.book is None:
raise ValueError(
"--book argument must be passed if the source file is not in a Paratext project directory."
)

src_sents, src_refs, _ = get_sentences(src_path, stylesheet, encoding, book)
draft_sents, draft_refs, draft_remarks = get_sentences(draft_path, stylesheet, encoding, book)

if len(src_refs) != len(draft_refs):
raise ValueError("Different number of verses/references between source and draft.")
for src_ref, draft_ref in zip(src_refs, draft_refs):
if src_ref.to_relaxed() != draft_ref.to_relaxed():
raise ValueError(
f"'source' and 'draft' must have the exact same USFM structure. Mismatched ref: {src_ref} {draft_ref}"
)

paragraph_behavior = (
UpdateUsfmMarkerBehavior.PRESERVE if args.include_paragraph_markers else UpdateUsfmMarkerBehavior.STRIP
)
style_behavior = UpdateUsfmMarkerBehavior.PRESERVE if args.include_style_markers else UpdateUsfmMarkerBehavior.STRIP
embed_behavior = UpdateUsfmMarkerBehavior.PRESERVE if args.include_embeds else UpdateUsfmMarkerBehavior.STRIP

update_block_handlers = []
if args.include_paragraph_markers or args.include_style_markers:
update_block_handlers.append(construct_place_markers_handler(src_refs, src_sents, draft_sents))

with src_path.open(encoding=encoding) as f:
usfm = f.read()
handler = UpdateUsfmParserHandler(
rows=[([ref], sent) for ref, sent in zip(src_refs, draft_sents)],
id_text=book,
text_behavior=UpdateUsfmTextBehavior.STRIP_EXISTING,
paragraph_behavior=paragraph_behavior,
embed_behavior=embed_behavior,
style_behavior=style_behavior,
update_block_handlers=update_block_handlers,
)
parse_usfm(usfm, handler)
usfm_out = handler.get_usfm()

usfm_out = insert_draft_remarks(usfm_out, draft_remarks)

out_dir = Path(args.output_folder.replace("\\", "/")) if args.output_folder else draft_path.parent
out_path = out_dir / f"{draft_path.stem}_postprocessed{draft_path.suffix}"
with out_path.open("w", encoding="utf-8" if encoding == "utf-8-sig" else encoding) as f:
f.write(usfm_out)
# If no postprocessing options are used, use any postprocessing requests in the experiment's translate config
if args.include_paragraph_markers or args.include_style_markers or args.include_embeds:
postprocess_configs = [
{
"include_paragraph_markers": args.include_paragraph_markers,
"include_style_markers": args.include_style_markers,
"include_embeds": args.include_embeds,
}
]
else:
if args.experiment:
LOGGER.info("No postprocessing options used. Applying postprocessing requests from translate config.")
with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file:
postprocess_configs = yaml.safe_load(file).get("postprocess", [])
if len(postprocess_configs) == 0:
LOGGER.info("No postprocessing requests found in translate config.")
exit()
else:
LOGGER.info("Please use at least one postprocessing option.")
exit()
postprocess_handler = PostprocessHandler([PostprocessConfig(pc) for pc in postprocess_configs], include_base=False)

if args.output_folder:
args.output_folder = Path(args.output_folder.replace("\\", "/"))
for src_path, draft_path in zip(src_paths, draft_paths):
postprocess_draft(src_path, draft_path, postprocess_handler, args.book, args.output_folder)


if __name__ == "__main__":
Expand Down
102 changes: 102 additions & 0 deletions silnlp/common/postprocesser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Dict, List, Union

from machine.corpora import (
PlaceMarkersAlignmentInfo,
PlaceMarkersUsfmUpdateBlockHandler,
ScriptureRef,
UpdateUsfmMarkerBehavior,
UsfmUpdateBlockHandler,
)
from machine.tokenization import LatinWordTokenizer
from machine.translation import WordAlignmentMatrix

from ..alignment.eflomal import to_word_alignment_matrix
from ..alignment.utils import compute_alignment_scores
from .corpus import load_corpus, write_corpus
from .utils import merge_dict

# Every supported postprocessing option mapped to its default value.
POSTPROCESS_OPTIONS = {"include_paragraph_markers": False, "include_style_markers": False, "include_embeds": False}
# One suffix character per option, in the same order as POSTPROCESS_OPTIONS:
# PostprocessConfig.get_postprocess_suffix zips the two together, so the orders must match.
POSTPROCESS_SUFFIX_CHARS = ["p", "s", "e"]


class PostprocessConfig:
    """Settings for a single postprocessing pass.

    The settings are built by passing a copy of ``POSTPROCESS_OPTIONS`` and the
    caller's overrides to ``merge_dict``. ``update_block_handlers`` starts empty;
    ``PostprocessHandler.create_update_block_handlers`` populates it.
    """

    def __init__(self, config: Union[Dict[str, Union[bool, str]], None] = None) -> None:
        """Create the settings.

        Args:
            config: Option overrides. ``None`` (the default) means no overrides.
        """
        # Use a fresh dict per call; a ``{}`` default would be a single object
        # shared by every call that omits the argument.
        overrides = {} if config is None else config
        # merge_dict is given a copy of the defaults, not the module-level dict itself.
        self._config = merge_dict(dict(POSTPROCESS_OPTIONS), overrides)
        self.update_block_handlers: List[UsfmUpdateBlockHandler] = []

    def _get_usfm_marker_behavior(self, preserve: bool) -> UpdateUsfmMarkerBehavior:
        """Map a flag to ``PRESERVE`` (truthy) or ``STRIP`` (falsy)."""
        return UpdateUsfmMarkerBehavior.PRESERVE if preserve else UpdateUsfmMarkerBehavior.STRIP

    def get_paragraph_behavior(self) -> UpdateUsfmMarkerBehavior:
        """Return the marker behavior selected by ``include_paragraph_markers``."""
        return self._get_usfm_marker_behavior(self._config["include_paragraph_markers"])

    def get_style_behavior(self) -> UpdateUsfmMarkerBehavior:
        """Return the marker behavior selected by ``include_style_markers``."""
        return self._get_usfm_marker_behavior(self._config["include_style_markers"])

    def get_embed_behavior(self) -> UpdateUsfmMarkerBehavior:
        """Return the marker behavior selected by ``include_embeds``."""
        return self._get_usfm_marker_behavior(self._config["include_embeds"])

    def get_postprocess_suffix(self) -> str:
        """Return ``"_"`` plus one character per non-default option, or ``""`` if none.

        The characters come from ``POSTPROCESS_SUFFIX_CHARS``, paired with the
        options by position.
        """
        chars = "".join(
            char
            for (option, default), char in zip(POSTPROCESS_OPTIONS.items(), POSTPROCESS_SUFFIX_CHARS)
            if self._config[option] != default
        )
        return f"_{chars}" if chars else ""

    def get_postprocess_remark(self) -> str:
        """Return a remark listing the non-default options, or ``""`` if there are none."""
        used = [option for option, default in POSTPROCESS_OPTIONS.items() if self._config[option] != default]
        return f"Post-processing options used: {' '.join(used)}" if used else ""

    def __getitem__(self, key):
        """Return the stored value for ``key``; raises ``KeyError`` if it is absent."""
        return self._config[key]


class PostprocessHandler:
    """Holds the postprocessing configurations to apply and builds their update block handlers."""

    def __init__(self, configs: Union[List[PostprocessConfig], None] = None, include_base: bool = True) -> None:
        """Store the configurations.

        Args:
            configs: Configurations to apply. ``None`` (the default) means none.
            include_base: If true, a default ``PostprocessConfig()`` is placed
                before ``configs``.
        """
        # ``None`` sentinel instead of a shared mutable ``[]`` default argument.
        requested = [] if configs is None else configs
        self.configs = ([PostprocessConfig()] if include_base else []) + requested

    # NOTE: Update block handlers may need to be created/recreated at different times
    # For example, the marker placement handler needs to be recreated for each new draft because it uses text alignment,
    # but other handlers may only need to be created once overall, or once per source project.
    # This may change what part of the process we want this function to be called at
    def create_update_block_handlers(self, refs: List[ScriptureRef], source: List[str], translation: List[str]) -> None:
        """Create one marker placement handler and install it on every config that needs it.

        A config needs it when ``include_paragraph_markers`` or
        ``include_style_markers`` is set. Nothing is computed when no config does.

        Args:
            refs: References of the segments.
            source: Source segments, parallel to ``refs``.
            translation: Translated segments, parallel to ``refs``.
        """
        marker_configs = [
            config
            for config in self.configs
            if config["include_paragraph_markers"] or config["include_style_markers"]
        ]
        if not marker_configs:
            return

        # Built once and shared by every config that places markers.
        place_markers_handler = self._construct_place_markers_handler(refs, source, translation)
        for config in marker_configs:
            if len(config.update_block_handlers) == 0:
                config.update_block_handlers.append(place_markers_handler)
            else:  # NOTE: this assumes a set order of update block handlers
                config.update_block_handlers[0] = place_markers_handler

    def _construct_place_markers_handler(
        self, refs: List[ScriptureRef], source: List[str], translation: List[str], aligner: str = "eflomal"
    ) -> PlaceMarkersUsfmUpdateBlockHandler:
        """Tokenize and align each source/translation pair and wrap the results in a handler.

        ``refs``, ``source``, ``translation`` and the computed alignments are zipped,
        so iteration stops at the shortest of them.
        """
        tokenizer = LatinWordTokenizer()
        alignments = self._get_alignment_matrices(source, translation, aligner)
        align_info = [
            PlaceMarkersAlignmentInfo(
                refs=[str(ref)],
                source_tokens=list(tokenizer.tokenize(s)),
                translation_tokens=list(tokenizer.tokenize(t)),
                alignment=alignment,
            )
            for ref, s, t, alignment in zip(refs, source, translation, alignments)
        ]
        return PlaceMarkersUsfmUpdateBlockHandler(align_info)

    def _get_alignment_matrices(
        self, src_sents: List[str], trg_sents: List[str], aligner: str = "eflomal"
    ) -> List[WordAlignmentMatrix]:
        """Align the sentence pairs with ``aligner`` and return one matrix per alignment line.

        The sentences and the alignment output are written to a temporary directory,
        which is removed when the ``with`` block exits.
        """
        with TemporaryDirectory() as td:
            src_path = Path(td, "src_align.txt")
            trg_path = Path(td, "trg_align.txt")
            align_path = Path(td, "sym-align.txt")
            write_corpus(src_path, src_sents)
            write_corpus(trg_path, trg_sents)
            compute_alignment_scores(src_path, trg_path, aligner, align_path)

            # Read the output before leaving the block: the directory is deleted on exit.
            return [to_word_alignment_matrix(line) for line in load_corpus(align_path)]
Loading