diff --git a/html/app.py b/html/app.py index 5b2db23..9f924f6 100644 --- a/html/app.py +++ b/html/app.py @@ -3,7 +3,7 @@ """ import re import time -from typing import List +from typing import Dict, List from collections import OrderedDict from fastapi import FastAPI, HTTPException @@ -22,7 +22,7 @@ model.enable_pipe('senter') # Rankers. -english_ranker = Ranker(model_name='ms-marco-TinyBERT-L-2-v2', cache_dir='/opt/models') +english_ranker = Ranker()#model_name='ms-marco-TinyBERT-L-2-v2', cache_dir='/opt/models') multilingual_reranker = Ranker(model_name='ms-marco-MultiBERT-L-12', cache_dir='/opt/models') rankers = { 'en': english_ranker, @@ -75,7 +75,7 @@ def split_text_into_sentences(text: str, language: str) -> List[str]: return sentences # Rank texts against a query. -def rank_texts(query: str, texts: List[str], language: str, limit: int) -> List[str]: +def rank_texts(query: str, texts: List[str], language: str, limit: int) -> Dict[str, float]: """ Split a text into sentences. @@ -86,13 +86,13 @@ def rank_texts(query: str, texts: List[str], language: str, limit: int) -> List[ limit (int): the maximum number of most relevant texts to the query. Returns: - List[str]: most relevant texts. + Dict[str, float]: most relevant texts with their score. """ passages = [{'text': text} for text in texts] rerank_request = RerankRequest(query=query, passages=passages) results = rankers[language].rerank(rerank_request) - ranked_texts = [result.get('text') for result in results[:limit]] + ranked_texts = {result.get('text'): result.get('score') for result in results[:limit]} return ranked_texts @@ -291,8 +291,7 @@ def text_correlate_keywords(request: TextCorrelateKeywordsRequest) -> TextCorrel # Find the most relevant terms for the text and add them to the # keywords extracted from the text. - keywords = [keyword.strip().lower() for keyword in request.keywords] - keywords += rank_texts(request.text, keywords, request.language, request.limit) + keywords += rank_texts(request.text, request.keywords, request.language, request.limit).keys() # Ensure uniqueness. keywords = list(OrderedDict.fromkeys(keywords)) @@ -321,7 +320,7 @@ class TextCorrelateTextsResponse(Response): Attributes: texts (List[str]): the most relevant tests. """ - texts: List[str] + texts: Dict[str, float] # Endpoint to correlate texts to another text. @app.post('/text/correlate/texts')