Skip to content

Commit

Permalink
Merge pull request #5 from UN-OCHA/RW-1057
Browse files Browse the repository at this point in the history
[RW-1057] Fix keywords correlation and change text correlation to return the scores as well
  • Loading branch information
orakili committed Aug 30, 2024
2 parents a7214ea + bc788b5 commit 9f09346
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions html/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
import re
import time
from typing import List
from typing import Dict, List

from collections import OrderedDict
from fastapi import FastAPI, HTTPException
Expand All @@ -22,7 +22,7 @@
model.enable_pipe('senter')

# Rankers.
english_ranker = Ranker(model_name='ms-marco-TinyBERT-L-2-v2', cache_dir='/opt/models')
english_ranker = Ranker()#model_name='ms-marco-TinyBERT-L-2-v2', cache_dir='/opt/models')
multilingual_reranker = Ranker(model_name='ms-marco-MultiBERT-L-12', cache_dir='/opt/models')
rankers = {
'en': english_ranker,
Expand Down Expand Up @@ -75,7 +75,7 @@ def split_text_into_sentences(text: str, language: str) -> List[str]:
return sentences

# Rank texts against a query.
def rank_texts(query: str, texts: List[str], language: str, limit: int) -> List[str]:
def rank_texts(query: str, texts: List[str], language: str, limit: int) -> Dict[str, float]:
"""
Split a text into sentences.
Expand All @@ -86,13 +86,13 @@ def rank_texts(query: str, texts: List[str], language: str, limit: int) -> List[
limit (int): the maximum number of most relevant texts to the query.
Returns:
List[str]: most relevant texts.
Dict[str, float]: most relevant texts with their score.
"""
passages = [{'text': text} for text in texts]

rerank_request = RerankRequest(query=query, passages=passages)
results = rankers[language].rerank(rerank_request)
ranked_texts = [result.get('text') for result in results[:limit]]
ranked_texts = {result.get('text'): result.get('score') for result in results[:limit]}

return ranked_texts

Expand Down Expand Up @@ -291,8 +291,7 @@ def text_correlate_keywords(request: TextCorrelateKeywordsRequest) -> TextCorrel

# Find the most relevant terms for the text and add them to the
# keywords extracted from the text.
keywords = [keyword.strip().lower() for keyword in request.keywords]
keywords += rank_texts(request.text, keywords, request.language, request.limit)
keywords += rank_texts(request.text, request.keywords, request.language, request.limit).keys()

# Ensure uniqueness.
keywords = list(OrderedDict.fromkeys(keywords))
Expand Down Expand Up @@ -321,7 +320,7 @@ class TextCorrelateTextsResponse(Response):
Attributes:
texts (List[str]): the most relevant tests.
"""
texts: List[str]
texts: Dict[str, float]

# Endpoint to correlate texts to another text.
@app.post('/text/correlate/texts')
Expand Down

0 comments on commit 9f09346

Please sign in to comment.