diff --git a/learning_resources_search/api.py b/learning_resources_search/api.py index b2689d41d4..eec3673942 100644 --- a/learning_resources_search/api.py +++ b/learning_resources_search/api.py @@ -13,6 +13,7 @@ from learning_resources.models import LearningResource from learning_resources_search.connection import ( get_default_alias_name, + get_vector_model_id, ) from learning_resources_search.constants import ( CONTENT_FILE_TYPE, @@ -654,12 +655,18 @@ def add_text_query_to_search( text_query = {"bool": {"must": [text_query], "filter": query_type_query}} if use_hybrid_search: - encoder = dense_encoder() - query_vector = encoder.embed_query(text) + model_id = get_vector_model_id() + + if not model_id: + log.error("Vector model not found. Cannot perform hybrid search.") + error_message = "Vector model not found." + raise ValueError(error_message) + vector_query = { - "knn": { + "neural": { "vector_embedding": { - "vector": query_vector, + "query_text": text, + "model_id": model_id, "k": HYBRID_SEARCH_KNN_K_VALUE, } } diff --git a/learning_resources_search/api_test.py b/learning_resources_search/api_test.py index 1dec281539..bdb9c485a7 100644 --- a/learning_resources_search/api_test.py +++ b/learning_resources_search/api_test.py @@ -2415,8 +2415,10 @@ def test_execute_learn_search_with_hybrid_search(mocker, settings, opensearch): settings.DEFAULT_SEARCH_MODE = "best_fields" - mock_encoder = mocker.patch("learning_resources_search.api.dense_encoder")() - mock_encoder.embed_query.return_value = [0.1, 0.2, 0.3] + mocker.patch( + "learning_resources_search.api.get_vector_model_id", + return_value="vector_model_id", + ) search_params = { "aggregations": ["offered_by"], @@ -2723,7 +2725,15 @@ def test_execute_learn_search_with_hybrid_search(mocker, settings, opensearch): "filter": {"exists": {"field": "resource_type"}}, } }, - {"knn": {"vector_embedding": {"vector": [0.1, 0.2, 0.3], "k": 5}}}, + { + "neural": { + "vector_embedding": { + "query_text": "math", + "model_id": "vector_model_id", + "k": 5, + } + } + }, ], } }, diff --git a/learning_resources_search/connection.py b/learning_resources_search/connection.py index 8fb5fd893d..7cf41bc751 100644 --- a/learning_resources_search/connection.py +++ b/learning_resources_search/connection.py @@ -3,10 +3,12 @@ """ import uuid +from contextlib import suppress from functools import partial from django.conf import settings from opensearch_dsl.connections import connections +from opensearchpy.exceptions import ConflictError from learning_resources_search.constants import ( ALL_INDEX_TYPES, @@ -135,3 +137,153 @@ def refresh_index(index): """ conn = get_conn() conn.indices.refresh(index) + + +def create_openai_embedding_connector_and_model( + model_name=settings.OPENSEARCH_VECTOR_MODEL_BASE_NAME, + openai_model=settings.QDRANT_DENSE_MODEL, +): + """ + Create OpenAI embedding connector and model for opensearch vector search. + The model will be used to generate embeddings for user queries + + Args: + model_name: Name param for the model in opensearch + openai_model: Name of the OpenAI model that will be loaded + """ + + conn = get_conn() + + body = { + "name": f"{model_name}_connector", + "description": "openAI Embedding Connector ", + "version": "0.1", + "protocol": "http", + "parameters": { + "model": openai_model, + }, + "credential": {"openAI_key": settings.OPENAI_API_KEY}, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://api.openai.com/v1/embeddings", + "headers": { + "Authorization": "Bearer ${credential.openAI_key}", + }, + "request_body": '{"input": ${parameters.input}, "model": "${parameters.model}" }', # noqa: E501 + "pre_process_function": "connector.pre_process.openai.embedding", + "post_process_function": "connector.post_process.openai.embedding", + } + ], + } + + connector_response = conn.transport.perform_request( + "POST", "/_plugins/_ml/connectors/_create", body=body + ) + + connector_id = connector_response["connector_id"] + + model_group_response = conn.transport.perform_request( + "POST", + "/_plugins/_ml/model_groups/_register", + body={ + "name": f"{model_name}_group", + "description": "OpenAI Embedding Model Group", + }, + ) + + model_group_id = model_group_response["model_group_id"] + + conn.transport.perform_request( + "POST", + "/_plugins/_ml/models/_register", + body={ + "name": model_name, + "function_name": "remote", + "model_group_id": model_group_id, + "description": "OpenAI embedding model", + "connector_id": connector_id, + }, + ) + + +def get_vector_model_id(model_name=settings.OPENSEARCH_VECTOR_MODEL_BASE_NAME): + """ + Get the model ID for the currently loaded opensearch vector model + Args: + model_name: Name of the model to get the id for + Returns: + str or None: The model ID if found, else None + """ + conn = get_conn() + body = {"query": {"term": {"name.keyword": model_name}}} + models = conn.transport.perform_request( + "GET", "/_plugins/_ml/models/_search", body=body + ) + + if len(models.get("hits", {}).get("hits", [])) > 0: + return models["hits"]["hits"][0]["_id"] + + return None + + +def deploy_vector_model(model_name=settings.OPENSEARCH_VECTOR_MODEL_BASE_NAME): + """ + Deploy an opensearch vector model + + Args: + model_name: Name of the model to deploy + """ + conn = get_conn() + model_id = get_vector_model_id(model_name=model_name) + conn.transport.perform_request("POST", f"/_plugins/_ml/models/{model_id}/_deploy") + + +def cleanup_vector_models( + exclude_model_names=[settings.OPENSEARCH_VECTOR_MODEL_BASE_NAME], # noqa: B006 +): + """ + Delete an opensearch vector models. If exclude_model_name is provided, + do not delete that model. + + Args: + exclude_model_names: List of names of the models to keep or None + """ + conn = get_conn() + body = {"query": {"match_all": {}}} + models_response = conn.transport.perform_request( + "GET", "/_plugins/_ml/models/_search", body=body + ) + + deleted_models = [] + for model in models_response.get("hits", {}).get("hits", []): + model_id = model.get("_id") + model_name = model.get("_source", {}).get("name") + model_group_id = model.get("_source", {}).get("model_group_id") + connector_id = model.get("_source", {}).get("connector_id") + + if model_name not in (exclude_model_names or []): + if model.get("_source", {}).get("model_state") == "DEPLOYED": + conn.transport.perform_request( + "POST", f"/_plugins/_ml/models/{model_id}/_undeploy" + ) + + conn.transport.perform_request("DELETE", f"/_plugins/_ml/models/{model_id}") + deleted_models.append(model_name) + + if model_group_id: + # ConflictError is raised if other models still use the group + with suppress(ConflictError): + conn.transport.perform_request( + "DELETE", f"/_plugins/_ml/model_groups/{model_group_id}" + ) + + if connector_id: + # ConflictError is raised if other models still use the connector + with suppress(ConflictError): + conn.transport.perform_request( + "DELETE", f"/_plugins/_ml/connectors/{connector_id}" + ) + + return deleted_models diff --git a/main/settings.py b/main/settings.py index e827a47fac..94cc64d015 100644 --- a/main/settings.py +++ b/main/settings.py @@ -740,6 +740,10 @@ def get_all_config_keys(): MICROMASTERS_CMS_API_URL = get_string("MICROMASTERS_CMS_API_URL", None) +OPENSEARCH_VECTOR_MODEL_BASE_NAME = get_string( + name="OPENSEARCH_VECTOR_MODEL_BASE_NAME", + default="hybrid_search_model", +) POSTHOG_PROJECT_API_KEY = get_string( name="POSTHOG_PROJECT_API_KEY", default="", @@ -795,7 +799,9 @@ def get_all_config_keys(): QDRANT_BASE_COLLECTION_NAME = get_string( name="QDRANT_COLLECTION_NAME", default="resource_embeddings" ) -QDRANT_DENSE_MODEL = get_string(name="QDRANT_DENSE_MODEL", default=None) +QDRANT_DENSE_MODEL = get_string( + name="QDRANT_DENSE_MODEL", default="text-embedding-3-small" +) QDRANT_SPARSE_MODEL = get_string( name="QDRANT_SPARSE_MODEL", default="prithivida/Splade_PP_en_v1" ) diff --git a/vector_search/encoders/litellm.py b/vector_search/encoders/litellm.py index 6d7613e429..d27019c937 100644 --- a/vector_search/encoders/litellm.py +++ b/vector_search/encoders/litellm.py @@ -16,7 +16,7 @@ class LiteLLMEncoder(BaseEncoder): token_encoding_name = settings.LITELLM_TOKEN_ENCODING_NAME - def __init__(self, model_name="text-embedding-3-small"): + def __init__(self, model_name): self.model_name = model_name try: self.token_encoding_name = tiktoken.encoding_name_for_model(model_name)