22
33from graphgen .bases .base_evaluator import BaseEvaluator
44from graphgen .bases .datatypes import QAPair
5- from graphgen .utils import NLTKHelper , create_event_loop , detect_main_language
6-
7- nltk_helper = NLTKHelper ()
5+ from graphgen .utils import NLTKHelper , detect_main_language
86
97
108class MTLDEvaluator (BaseEvaluator ):
119 """
12- 衡量文本词汇多样性的指标
10+ A metric for measuring the lexical diversity of text.
1311 """
1412
15- def __init__ (self , max_concurrent : int = 100 ):
16- super ().__init__ (max_concurrent )
17- self .stopwords_en : Set [str ] = set (nltk_helper .get_stopwords ("english" ))
18- self .stopwords_zh : Set [str ] = set (nltk_helper .get_stopwords ("chinese" ))
19-
20- async def evaluate_single (self , pair : QAPair ) -> float :
21- loop = create_event_loop ()
22- return await loop .run_in_executor (None , self ._calculate_mtld_score , pair .answer )
13+ def __init__ (self , threshold : float = 0.72 ):
14+ self .nltk_helper = NLTKHelper ()
15+ self .stopwords_en : Set [str ] = set (self .nltk_helper .get_stopwords ("english" ))
16+ self .stopwords_zh : Set [str ] = set (self .nltk_helper .get_stopwords ("chinese" ))
17+ self .threshold = threshold
2318
24- def _calculate_mtld_score (self , text : str , threshold = 0.72 ) -> float :
19+ def evaluate (self , pair : QAPair ) -> float :
2520 """
26- 计算MTLD (向前和向后的平均值)
21+ Calculate the MTLD (Measure of Textual Lexical Diversity) score of the pair's answer text, averaged over a forward and a backward pass.
2722
2823 min is 1.0 (0 is returned when the text is empty or no tokens remain after filtering)
2924 higher is better
3025 """
26+ text = pair .answer
3127 if not text or not text .strip ():
3228 return 0.0
3329
3430 lang = detect_main_language (text )
35- tokens = nltk_helper .word_tokenize (text , lang )
31+ tokens = self . nltk_helper .word_tokenize (text , lang )
3632
3733 stopwords = self .stopwords_zh if lang == "zh" else self .stopwords_en
3834 filtered_tokens = [word for word in tokens if word not in stopwords ]
@@ -41,13 +37,13 @@ def _calculate_mtld_score(self, text: str, threshold=0.72) -> float:
4137 if not filtered_tokens :
4238 return 0
4339
44- # 计算向前的MTLD
45- forward_factors = self ._compute_factors (filtered_tokens , threshold )
40+ # Compute forward factors
41+ forward_factors = self ._compute_factors (filtered_tokens , self . threshold )
4642
47- # 计算向后的MTLD
48- backward_factors = self ._compute_factors (filtered_tokens [::- 1 ], threshold )
43+ # Compute backward factors
44+ backward_factors = self ._compute_factors (filtered_tokens [::- 1 ], self . threshold )
4945
50- # 取平均值
46+ # Average the forward and backward results
5147 return (forward_factors + backward_factors ) / 2
5248
5349 @staticmethod
@@ -66,7 +62,7 @@ def _compute_factors(tokens: list, threshold: float) -> float:
6662 current_segment = []
6763 unique_words = set ()
6864
69- # 处理最后一个不完整片段
65+ # Handle the trailing incomplete segment, if any
7066 if current_segment :
7167 ttr = len (unique_words ) / len (current_segment )
7268 if ttr <= threshold :
0 commit comments