Skip to content

Commit

Permalink
Merge pull request #896 from khoj-ai/features/add-support-for-custom-confidence
Browse files Browse the repository at this point in the history

Add support for custom search model-specific thresholds
  • Loading branch information
sabaimran committed Aug 25, 2024
2 parents fa4d808 + 4b77325 commit af4e998
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Generated by Django 5.0.7 on 2024-08-24 18:19

from django.db import migrations, models


class Migration(migrations.Migration):
    """Add `bi_encoder_confidence_threshold` to SearchModelConfig.

    Gives each search model configuration its own bi-encoder distance
    cutoff instead of a single hard-coded value; 0.18 matches the
    default previously embedded in the callers.
    """

    # Must apply after the chat-model alteration migration on the
    # same app, so the schema history stays linear.
    dependencies = [
        ("database", "0058_alter_chatmodeloptions_chat_model"),
    ]

    operations = [
        # Non-null float with a default, so existing rows are
        # backfilled with 0.18 and no data migration is needed.
        migrations.AddField(
            model_name="searchmodelconfig",
            name="bi_encoder_confidence_threshold",
            field=models.FloatField(default=0.18),
        ),
    ]
2 changes: 2 additions & 0 deletions src/khoj/database/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@ class ModelType(models.TextChoices):
cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
# Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server
cross_encoder_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
# The confidence threshold of the bi_encoder model to consider the embeddings as relevant
bi_encoder_confidence_threshold = models.FloatField(default=0.18)


class TextToImageModelConfig(BaseModel):
Expand Down
3 changes: 1 addition & 2 deletions src/khoj/routers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def search(
n=n,
t=t,
r=r,
max_distance=max_distance,
max_distance=max_distance or math.inf,
dedupe=dedupe,
)

Expand Down Expand Up @@ -117,7 +117,6 @@ async def execute_search(
# initialize variables
user_query = q.strip()
results_count = n or 5
max_distance = max_distance or math.inf
search_futures: List[concurrent.futures.Future] = []

# return cached results, if available
Expand Down
4 changes: 2 additions & 2 deletions src/khoj/routers/api_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ async def chat(
common: CommonQueryParams,
q: str,
n: int = 7,
d: float = 0.18,
d: float = None,
stream: Optional[bool] = False,
title: Optional[str] = None,
conversation_id: Optional[int] = None,
Expand Down Expand Up @@ -764,7 +764,7 @@ def collect_telemetry():
meta_log,
q,
(n or 7),
(d or 0.18),
d,
conversation_id,
conversation_commands,
location,
Expand Down
9 changes: 7 additions & 2 deletions src/khoj/search_type/text_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,18 +100,23 @@ async def query(
raw_query: str,
type: SearchType = SearchType.All,
question_embedding: Union[torch.Tensor, None] = None,
max_distance: float = math.inf,
max_distance: float = None,
) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query"

file_type = search_type_to_embeddings_type[type.value]

query = raw_query
search_model = await sync_to_async(get_user_search_model_or_default)(user)
if not max_distance:
if search_model.bi_encoder_confidence_threshold:
max_distance = search_model.bi_encoder_confidence_threshold
else:
max_distance = math.inf

# Encode the query using the bi-encoder
if question_embedding is None:
with timer("Query Encode Time", logger, state.device):
search_model = await sync_to_async(get_user_search_model_or_default)(user)
question_embedding = state.embeddings_model[search_model.name].embed_query(query)

# Find relevant entries for the query
Expand Down

0 comments on commit af4e998

Please sign in to comment.