Skip to content

Commit 028b043

Browse files
refactor: refactor MTLDEvaluator
1 parent 19510d9 commit 028b043

File tree

1 file changed

+17
-21
lines changed

1 file changed

+17
-21
lines changed

graphgen/models/evaluator/qa/mtld_evaluator.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,33 @@
22

33
from graphgen.bases.base_evaluator import BaseEvaluator
44
from graphgen.bases.datatypes import QAPair
5-
from graphgen.utils import NLTKHelper, create_event_loop, detect_main_language
6-
7-
nltk_helper = NLTKHelper()
5+
from graphgen.utils import NLTKHelper, detect_main_language
86

97

108
class MTLDEvaluator(BaseEvaluator):
119
"""
12-
衡量文本词汇多样性的指标
10+
Metrics for measuring the lexical diversity of text.
1311
"""
1412

15-
def __init__(self, max_concurrent: int = 100):
16-
super().__init__(max_concurrent)
17-
self.stopwords_en: Set[str] = set(nltk_helper.get_stopwords("english"))
18-
self.stopwords_zh: Set[str] = set(nltk_helper.get_stopwords("chinese"))
19-
20-
async def evaluate_single(self, pair: QAPair) -> float:
21-
loop = create_event_loop()
22-
return await loop.run_in_executor(None, self._calculate_mtld_score, pair.answer)
13+
def __init__(self, threshold: float = 0.72):
14+
self.nltk_helper = NLTKHelper()
15+
self.stopwords_en: Set[str] = set(self.nltk_helper.get_stopwords("english"))
16+
self.stopwords_zh: Set[str] = set(self.nltk_helper.get_stopwords("chinese"))
17+
self.threshold = threshold
2318

24-
def _calculate_mtld_score(self, text: str, threshold=0.72) -> float:
19+
def evaluate(self, pair: QAPair) -> float:
2520
"""
26-
计算MTLD (向前和向后的平均值)
21+
Calculate the MTLD (Mean Token Length Diversity) score for a given text.
2722
2823
min is 1.0
2924
higher is better
3025
"""
26+
text = pair.answer
3127
if not text or not text.strip():
3228
return 0.0
3329

3430
lang = detect_main_language(text)
35-
tokens = nltk_helper.word_tokenize(text, lang)
31+
tokens = self.nltk_helper.word_tokenize(text, lang)
3632

3733
stopwords = self.stopwords_zh if lang == "zh" else self.stopwords_en
3834
filtered_tokens = [word for word in tokens if word not in stopwords]
@@ -41,13 +37,13 @@ def _calculate_mtld_score(self, text: str, threshold=0.72) -> float:
4137
if not filtered_tokens:
4238
return 0
4339

44-
# 计算向前的MTLD
45-
forward_factors = self._compute_factors(filtered_tokens, threshold)
40+
# Compute forward factors
41+
forward_factors = self._compute_factors(filtered_tokens, self.threshold)
4642

47-
# 计算向后的MTLD
48-
backward_factors = self._compute_factors(filtered_tokens[::-1], threshold)
43+
# Compute backward factors
44+
backward_factors = self._compute_factors(filtered_tokens[::-1], self.threshold)
4945

50-
# 取平均值
46+
# Compute average factors
5147
return (forward_factors + backward_factors) / 2
5248

5349
@staticmethod
@@ -66,7 +62,7 @@ def _compute_factors(tokens: list, threshold: float) -> float:
6662
current_segment = []
6763
unique_words = set()
6864

69-
# 处理最后一个不完整片段
65+
# handle last segment
7066
if current_segment:
7167
ttr = len(unique_words) / len(current_segment)
7268
if ttr <= threshold:

0 commit comments

Comments
 (0)