Skip to content

Commit 77bb00d

Browse files
refactor: refactor base_evaluator
1 parent 978b76c commit 77bb00d

File tree

1 file changed

+4
-46
lines changed

1 file changed

+4
-46
lines changed

graphgen/bases/base_evaluator.py

Lines changed: 4 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,10 @@
1-
import asyncio
2-
3-
from tqdm.asyncio import tqdm as tqdm_async
4-
1+
from abc import ABC, abstractmethod
52
from graphgen.bases.datatypes import QAPair
6-
from graphgen.utils import create_event_loop
73

84

9-
class BaseEvaluator:
10-
def __init__(self, max_concurrent: int = 100):
11-
self.max_concurrent = max_concurrent
12-
self.results: list[float] = None
13-
14-
def evaluate(self, pairs: list[QAPair]) -> list[float]:
5+
class BaseEvaluator(ABC):
6+
@abstractmethod
7+
def evaluate(self, pair: QAPair) -> float:
158
"""
169
Evaluate the text and return a score.
1710
"""
18-
return create_event_loop().run_until_complete(self.async_evaluate(pairs))
19-
20-
async def async_evaluate(self, pairs: list[QAPair]) -> list[float]:
21-
semaphore = asyncio.Semaphore(self.max_concurrent)
22-
23-
async def evaluate_with_semaphore(pair):
24-
async with semaphore: # 获取Semaphore
25-
return await self.evaluate_single(pair)
26-
27-
results = []
28-
for result in tqdm_async(
29-
asyncio.as_completed([evaluate_with_semaphore(pair) for pair in pairs]),
30-
total=len(pairs),
31-
):
32-
results.append(await result)
33-
return results
34-
35-
async def evaluate_single(self, pair: QAPair) -> float:
36-
raise NotImplementedError()
37-
38-
def get_average_score(self, pairs: list[QAPair]) -> float:
39-
"""
40-
Get the average score of a batch of texts.
41-
"""
42-
results = self.evaluate(pairs)
43-
self.results = results
44-
return sum(self.results) / len(pairs)
45-
46-
def get_min_max_score(self, pairs: list[QAPair]) -> tuple[float, float]:
47-
"""
48-
Get the min and max score of a batch of texts.
49-
"""
50-
if self.results is None:
51-
self.get_average_score(pairs)
52-
return min(self.results), max(self.results)

0 commit comments

Comments
 (0)