|
1 | | -import asyncio |
2 | | - |
3 | | -from tqdm.asyncio import tqdm as tqdm_async |
4 | | - |
| 1 | +from abc import ABC, abstractmethod |
5 | 2 | from graphgen.bases.datatypes import QAPair |
6 | | -from graphgen.utils import create_event_loop |
7 | 3 |
|
8 | 4 |
|
9 | | -class BaseEvaluator: |
10 | | - def __init__(self, max_concurrent: int = 100): |
11 | | - self.max_concurrent = max_concurrent |
12 | | - self.results: list[float] = None |
13 | | - |
14 | | - def evaluate(self, pairs: list[QAPair]) -> list[float]: |
| 5 | +class BaseEvaluator(ABC): |
| 6 | + @abstractmethod |
| 7 | + def evaluate(self, pair: QAPair) -> float: |
15 | 8 | """ |
16 | 9 | Evaluate the text and return a score. |
17 | 10 | """ |
18 | | - return create_event_loop().run_until_complete(self.async_evaluate(pairs)) |
19 | | - |
20 | | - async def async_evaluate(self, pairs: list[QAPair]) -> list[float]: |
21 | | - semaphore = asyncio.Semaphore(self.max_concurrent) |
22 | | - |
23 | | - async def evaluate_with_semaphore(pair): |
24 | | - async with semaphore: # 获取Semaphore |
25 | | - return await self.evaluate_single(pair) |
26 | | - |
27 | | - results = [] |
28 | | - for result in tqdm_async( |
29 | | - asyncio.as_completed([evaluate_with_semaphore(pair) for pair in pairs]), |
30 | | - total=len(pairs), |
31 | | - ): |
32 | | - results.append(await result) |
33 | | - return results |
34 | | - |
35 | | - async def evaluate_single(self, pair: QAPair) -> float: |
36 | | - raise NotImplementedError() |
37 | | - |
38 | | - def get_average_score(self, pairs: list[QAPair]) -> float: |
39 | | - """ |
40 | | - Get the average score of a batch of texts. |
41 | | - """ |
42 | | - results = self.evaluate(pairs) |
43 | | - self.results = results |
44 | | - return sum(self.results) / len(pairs) |
45 | | - |
46 | | - def get_min_max_score(self, pairs: list[QAPair]) -> tuple[float, float]: |
47 | | - """ |
48 | | - Get the min and max score of a batch of texts. |
49 | | - """ |
50 | | - if self.results is None: |
51 | | - self.get_average_score(pairs) |
52 | | - return min(self.results), max(self.results) |
0 commit comments