diff --git a/html/app.py b/html/app.py
index 5b2db23..9f924f6 100644
--- a/html/app.py
+++ b/html/app.py
@@ -3,7 +3,7 @@
"""
import re
import time
-from typing import List
+from typing import Dict, List
from collections import OrderedDict
from fastapi import FastAPI, HTTPException
@@ -22,7 +22,7 @@
model.enable_pipe('senter')
# Rankers.
-english_ranker = Ranker(model_name='ms-marco-TinyBERT-L-2-v2', cache_dir='/opt/models')
+english_ranker = Ranker()#model_name='ms-marco-TinyBERT-L-2-v2', cache_dir='/opt/models')
multilingual_reranker = Ranker(model_name='ms-marco-MultiBERT-L-12', cache_dir='/opt/models')
rankers = {
'en': english_ranker,
@@ -75,7 +75,7 @@ def split_text_into_sentences(text: str, language: str) -> List[str]:
return sentences
# Rank texts against a query.
-def rank_texts(query: str, texts: List[str], language: str, limit: int) -> List[str]:
+def rank_texts(query: str, texts: List[str], language: str, limit: int) -> Dict[str, float]:
"""
Split a text into sentences.
@@ -86,13 +86,13 @@ def rank_texts(query: str, texts: List[str], language: str, limit: int) -> List[
limit (int): the maximum number of most relevant texts to the query.
Returns:
- List[str]: most relevant texts.
+ Dict[str, float]: most relevant texts with their score.
"""
passages = [{'text': text} for text in texts]
rerank_request = RerankRequest(query=query, passages=passages)
results = rankers[language].rerank(rerank_request)
- ranked_texts = [result.get('text') for result in results[:limit]]
+ ranked_texts = {result.get('text'): result.get('score') for result in results[:limit]}
return ranked_texts
@@ -291,8 +291,7 @@ def text_correlate_keywords(request: TextCorrelateKeywordsRequest) -> TextCorrel
# Find the most relevant terms for the text and add them to the
# keywords extracted from the text.
- keywords = [keyword.strip().lower() for keyword in request.keywords]
- keywords += rank_texts(request.text, keywords, request.language, request.limit)
+ keywords += rank_texts(request.text, request.keywords, request.language, request.limit).keys()
# Ensure uniqueness.
keywords = list(OrderedDict.fromkeys(keywords))
@@ -321,7 +320,7 @@ class TextCorrelateTextsResponse(Response):
Attributes:
texts (List[str]): the most relevant tests.
"""
- texts: List[str]
+ texts: Dict[str, float]
# Endpoint to correlate texts to another text.
@app.post('/text/correlate/texts')