bleu_transformer.py
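"""Evaluate the trained engTyping-to-Chinese Transformer with sacrebleu's
corpus-level BLEU on a held-out slice of the PTT dataset."""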
import os

# The allocator must be chosen before TensorFlow initializes the GPU.
os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async'

import numpy as np
import tensorflow as tf
from sacrebleu import BLEU

from custom import *
from utils import load_text_vectorization

bleu = BLEU()
model_name = 'Transformer_CWIKI_2023_09_27_VS20000_SL20_H4_Wed_Oct_18_031904_2023.keras'
split_char = '⫯'  # separator between tokens in the target-side lines and in decoded output
max_decoded_sentence_length = 20  # matches the SL20 sequence length in the model name
tv_name = 'CWIKI_2023_09_27'
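# Restore the fitted source/target TextVectorization layers and build an index-to-token lookup for decoding.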
source_vectorization = load_text_vectorization(f"models/{tv_name}_source_vectorization.pkl")
target_vectorization = load_text_vectorization(f"models/{tv_name}_target_vectorization.pkl")
zh_vocab = target_vectorization.get_vocabulary()
zh_index_lookup = dict(zip(range(len(zh_vocab)), zh_vocab))
model = tf.keras.models.load_model(
    f'models/{model_name}',
    custom_objects={'TransformerEncoder': TransformerEncoder,
                    'TransformerDecoder': TransformerDecoder,
                    'PositionalEmbedding': PositionalEmbedding})
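# Evaluation corpus: engTyping input lines paired with Chinese reference lines.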
datasets_name = 'PTT_2023_08_06'
with open(f'datasets/{datasets_name}_engTyping_inserted_lines.txt', encoding='utf8') as file:
    engTyping_inserted_lines = file.read().split('\n')
with open(f'datasets/{datasets_name}_zh_lines.txt', encoding='utf8') as file:
    zh_lines = file.read().split('\n')
lines_len = len(engTyping_inserted_lines)
assert lines_len == len(zh_lines)
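# Keep only the final 15% of the corpus as the evaluation split.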
engTyping_inserted_lines = engTyping_inserted_lines[int(lines_len * 0.85):]
zh_lines = zh_lines[int(lines_len * 0.85):]
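# References: strip the leading/trailing markers and join tokens with spaces for sacrebleu.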
for i in range(len(zh_lines)):
    zh_lines[i] = ' '.join(zh_lines[i].split(split_char)[1:-1])
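# Decode the evaluation set in roughly `loop_times` chunks rather than one huge batch
# (presumably to keep GPU memory usage manageable).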
pred_sentences = []
loop_times = 300
eng_len = len(engTyping_inserted_lines)
split_point = eng_len // loop_times
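# Batched greedy decoding: at each step, append the most probable next token to every unfinished sentence.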
for k in range(loop_times + 1):
    lines = engTyping_inserted_lines[k * split_point:(k + 1) * split_point]
    len_lines = len(lines)
    if len_lines == 0:  # the final chunk can be empty when eng_len divides evenly
        continue
    tokenized_input_sentence = source_vectorization(lines)
    decoded_sentences = ["[start]"] * len_lines
    for i in range(max_decoded_sentence_length):
        tokenized_target_sentence = target_vectorization(decoded_sentences)[:, :-1]
        predictions = model([tokenized_input_sentence, tokenized_target_sentence])
        for j in range(len_lines):
            if decoded_sentences[j].endswith('[end]'):
                continue
            decoded_sentences[j] += split_char + zh_index_lookup[np.argmax(predictions[j, i, :])]
    pred_sentences.extend(decoded_sentences)
    print(f'\r{k}/{loop_times}', end='')
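# Post-process predictions: drop the leading '[start]' (7 characters), remove a trailing '[end]', and space-separate tokens.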
count = 0
for i in range(len(pred_sentences)):
    pred_sentences[i] = pred_sentences[i][7:]
    if pred_sentences[i].endswith('[end]'):
        pred_sentences[i] = pred_sentences[i][:-5]
    pred_sentences[i] = pred_sentences[i].replace(split_char, ' ')
    count += 1
    print(f'{count}/{eng_len}', end='\r')
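# Corpus-level BLEU against the single set of references.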
assert len(pred_sentences) == len(zh_lines)
result = bleu.corpus_score(pred_sentences, [zh_lines])
print(result)