Skip to content

Commit 2669e2f

Browse files
Revert "Multi Target Support (#651)" (#673)
This reverts commit a063c55.
1 parent: ee51962 · commit: 2669e2f

File tree

3 files changed

+9
-18
lines changed

3 files changed

+9
-18
lines changed

silnlp/nmt/config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,6 @@ def translate_test_files(
320320
self,
321321
input_paths: List[Path],
322322
translation_paths: List[Path],
323-
src_trg_isos: List[Tuple[str, str]],
324323
produce_multiple_translations: bool = False,
325324
vref_paths: Optional[List[Path]] = None,
326325
ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST,

silnlp/nmt/hugging_face_config.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,29 +1093,25 @@ def translate_test_files(
10931093
self,
10941094
input_paths: List[Path],
10951095
translation_paths: List[Path],
1096-
src_trg_isos: List[Tuple[str, str]],
10971096
produce_multiple_translations: bool = False,
10981097
vref_paths: Optional[List[Path]] = None,
10991098
ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST,
11001099
) -> None:
1101-
lang_codes: Dict[str, str] = self._config.data["lang_codes"]
11021100
tokenizer = self._config.get_tokenizer()
11031101
model = self._create_inference_model(ckpt, tokenizer)
1104-
model.to(0)
1105-
model = torch.compile(model)
1106-
for input_path, translation_path, src_trg_iso, vref_path in zip(
1102+
pipeline = PretokenizedTranslationPipeline(
1103+
model=model,
1104+
tokenizer=tokenizer,
1105+
src_lang=self._config.test_src_lang,
1106+
tgt_lang=self._config.test_trg_lang,
1107+
device=0,
1108+
)
1109+
pipeline.model = torch.compile(pipeline.model)
1110+
for input_path, translation_path, vref_path in zip(
11071111
input_paths,
11081112
translation_paths,
1109-
src_trg_isos,
11101113
cast(Iterable[Optional[Path]], repeat(None) if vref_paths is None else vref_paths),
11111114
):
1112-
pipeline = PretokenizedTranslationPipeline(
1113-
model=model,
1114-
tokenizer=tokenizer,
1115-
src_lang=lang_codes.get(src_trg_iso[0]),
1116-
tgt_lang=lang_codes.get(src_trg_iso[1]),
1117-
device=0,
1118-
)
11191115
length = count_lines(input_path)
11201116
with ExitStack() as stack:
11211117
src_file = stack.enter_context(input_path.open("r", encoding="utf-8-sig"))

silnlp/nmt/test.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,6 @@ def test_checkpoint(
369369
translation_file_names: List[str] = []
370370
refs_patterns: List[str] = []
371371
translation_detok_file_names: List[str] = []
372-
src_trg_isos: List[Tuple[str, str]] = []
373372
suffix_str = "_".join(map(lambda n: book_number_to_id(n), sorted(books.keys())))
374373
if len(suffix_str) > 0:
375374
suffix_str += "-"
@@ -383,7 +382,6 @@ def test_checkpoint(
383382
translation_file_names.append(f"test.trg-predictions.txt.{suffix_str}")
384383
refs_patterns.append("test.trg.detok*.txt")
385384
translation_detok_file_names.append(f"test.trg-predictions.detok.txt.{suffix_str}")
386-
src_trg_isos.append((config.default_test_src_iso, config.default_test_trg_iso))
387385
else:
388386
# test data is split into separate files
389387
for src_iso in sorted(config.test_src_isos):
@@ -398,7 +396,6 @@ def test_checkpoint(
398396
translation_file_names.append(f"{prefix}.trg-predictions.txt.{suffix_str}")
399397
refs_patterns.append(f"{prefix}.trg.detok*.txt")
400398
translation_detok_file_names.append(f"{prefix}.trg-predictions.detok.txt.{suffix_str}")
401-
src_trg_isos.append((src_iso, trg_iso))
402399

403400
checkpoint_name = "averaged checkpoint" if step == -1 else f"checkpoint {step}"
404401

@@ -417,7 +414,6 @@ def test_checkpoint(
417414
model.translate_test_files(
418415
source_paths,
419416
translation_paths,
420-
src_trg_isos,
421417
produce_multiple_translations,
422418
vref_paths,
423419
step if checkpoint_type is CheckpointType.OTHER else checkpoint_type,

0 commit comments

Comments (0)