@@ -1093,29 +1093,25 @@ def translate_test_files(
1093
1093
self ,
1094
1094
input_paths : List [Path ],
1095
1095
translation_paths : List [Path ],
1096
- src_trg_isos : List [Tuple [str , str ]],
1097
1096
produce_multiple_translations : bool = False ,
1098
1097
vref_paths : Optional [List [Path ]] = None ,
1099
1098
ckpt : Union [CheckpointType , str , int ] = CheckpointType .LAST ,
1100
1099
) -> None :
1101
- lang_codes : Dict [str , str ] = self ._config .data ["lang_codes" ]
1102
1100
tokenizer = self ._config .get_tokenizer ()
1103
1101
model = self ._create_inference_model (ckpt , tokenizer )
1104
- model .to (0 )
1105
- model = torch .compile (model )
1106
- for input_path , translation_path , src_trg_iso , vref_path in zip (
1102
+ pipeline = PretokenizedTranslationPipeline (
1103
+ model = model ,
1104
+ tokenizer = tokenizer ,
1105
+ src_lang = self ._config .test_src_lang ,
1106
+ tgt_lang = self ._config .test_trg_lang ,
1107
+ device = 0 ,
1108
+ )
1109
+ pipeline .model = torch .compile (pipeline .model )
1110
+ for input_path , translation_path , vref_path in zip (
1107
1111
input_paths ,
1108
1112
translation_paths ,
1109
- src_trg_isos ,
1110
1113
cast (Iterable [Optional [Path ]], repeat (None ) if vref_paths is None else vref_paths ),
1111
1114
):
1112
- pipeline = PretokenizedTranslationPipeline (
1113
- model = model ,
1114
- tokenizer = tokenizer ,
1115
- src_lang = lang_codes .get (src_trg_iso [0 ]),
1116
- tgt_lang = lang_codes .get (src_trg_iso [1 ]),
1117
- device = 0 ,
1118
- )
1119
1115
length = count_lines (input_path )
1120
1116
with ExitStack () as stack :
1121
1117
src_file = stack .enter_context (input_path .open ("r" , encoding = "utf-8-sig" ))
0 commit comments