11"""Tests for operation utilities.""" 
22
33import  os 
4- from  unittest .mock  import  Mock 
4+ from  unittest .mock  import  Mock ,  patch 
55
66import  pytest 
77from  bson  import  ObjectId 
88from  pymongo  import  MongoClient 
99from  pymongo .collection  import  Collection 
1010
11- from  pymongo_vectorsearch_utils .operation  import  bulk_embed_and_insert_texts 
11+ from  pymongo_vectorsearch_utils  import  drop_vector_search_index 
12+ from  pymongo_vectorsearch_utils .index  import  create_vector_search_index , wait_for_docs_in_index 
13+ from  pymongo_vectorsearch_utils .operation  import  bulk_embed_and_insert_texts , execute_search_query 
1214
1315DB_NAME  =  "vectorsearch_utils_test" 
1416COLLECTION_NAME  =  "test_operation" 
17+ VECTOR_INDEX_NAME  =  "operation_vector_index" 
1518
1619
1720@pytest .fixture (scope = "module" ) 
@@ -22,6 +25,17 @@ def client():
2225    client .close ()
2326
2427
28+ @pytest .fixture (scope = "module" ) 
29+ def  preserved_collection (client ):
30+     if  COLLECTION_NAME  not  in client [DB_NAME ].list_collection_names ():
31+         clxn  =  client [DB_NAME ].create_collection (COLLECTION_NAME )
32+     else :
33+         clxn  =  client [DB_NAME ][COLLECTION_NAME ]
34+     clxn .delete_many ({})
35+     yield  clxn 
36+     clxn .delete_many ({})
37+ 
38+ 
2539@pytest .fixture  
2640def  collection (client ):
2741    if  COLLECTION_NAME  not  in client [DB_NAME ].list_collection_names ():
@@ -266,3 +280,182 @@ def test_custom_field_names(self, collection: Collection, mock_embedding_func):
266280        assert  "vector"  in  doc 
267281        assert  doc ["content" ] ==  texts [0 ]
268282        assert  doc ["vector" ] ==  [0.0 , 0.0 , 0.0 ]
283+ 
284+ 
285+ class  TestExecuteSearchQuery :
286+     @pytest .fixture (scope = "class" , autouse = True ) 
287+     def  vector_search_index (self , client ):
288+         coll  =  client [DB_NAME ][COLLECTION_NAME ]
289+         if  len (coll .list_search_indexes (VECTOR_INDEX_NAME ).to_list ()) ==  0 :
290+             create_vector_search_index (
291+                 collection = coll ,
292+                 index_name = VECTOR_INDEX_NAME ,
293+                 dimensions = 3 ,
294+                 path = "embedding" ,
295+                 similarity = "cosine" ,
296+                 filters = ["category" , "color" , "wheels" ],
297+                 wait_until_complete = 120 ,
298+             )
299+         yield 
300+         drop_vector_search_index (collection = coll , index_name = VECTOR_INDEX_NAME )
301+ 
302+     @pytest .fixture (scope = "class" , autouse = True ) 
303+     def  sample_docs (self , preserved_collection : Collection , vector_search_index ):
304+         texts  =  ["apple fruit" , "banana fruit" , "car vehicle" , "bike vehicle" ]
305+         metadatas  =  [
306+             {"category" : "fruit" , "color" : "red" },
307+             {"category" : "fruit" , "color" : "yellow" },
308+             {"category" : "vehicle" , "wheels" : 4 },
309+             {"category" : "vehicle" , "wheels" : 2 },
310+         ]
311+ 
312+         def  embeddings (texts ):
313+             mapping  =  {
314+                 "apple fruit" : [1.0 , 0.5 , 0.0 ],
315+                 "banana fruit" : [0.5 , 0.5 , 0.0 ],
316+                 "car vehicle" : [0.0 , 0.5 , 1.0 ],
317+                 "bike vehicle" : [0.0 , 1.0 , 0.5 ],
318+             }
319+             return  [mapping [text ] for  text  in  texts ]
320+ 
321+         bulk_embed_and_insert_texts (
322+             texts = texts ,
323+             metadatas = metadatas ,
324+             embedding_func = embeddings ,
325+             collection = preserved_collection ,
326+             text_key = "text" ,
327+             embedding_key = "embedding" ,
328+         )
329+         # Add a document that should not be returned in searches 
330+         preserved_collection .insert_one (
331+             {
332+                 "category" : "fruit" ,
333+                 "color" : "red" ,
334+                 "embedding" : [1.0 , 1.0 , 1.0 ],
335+             }
336+         )
337+         wait_for_docs_in_index (preserved_collection , VECTOR_INDEX_NAME , n_docs = 5 )
338+         return  preserved_collection 
339+ 
340+     def  test_basic_search_query (self , sample_docs : Collection ):
341+         query_vector  =  [1.0 , 0.5 , 0.0 ]
342+ 
343+         result  =  execute_search_query (
344+             query_vector = query_vector ,
345+             collection = sample_docs ,
346+             embedding_key = "embedding" ,
347+             text_key = "text" ,
348+             index_name = VECTOR_INDEX_NAME ,
349+             k = 2 ,
350+         )
351+ 
352+         assert  len (result ) ==  2 
353+         assert  result [0 ]["text" ] ==  "apple fruit" 
354+         assert  result [1 ]["text" ] ==  "banana fruit" 
355+         assert  "score"  in  result [0 ]
356+         assert  "score"  in  result [1 ]
357+ 
358+     def  test_search_with_pre_filter (self , sample_docs : Collection ):
359+         query_vector  =  [1.0 , 0.5 , 1.0 ]
360+         pre_filter  =  {"category" : "fruit" }
361+ 
362+         result  =  execute_search_query (
363+             query_vector = query_vector ,
364+             collection = sample_docs ,
365+             embedding_key = "embedding" ,
366+             text_key = "text" ,
367+             index_name = VECTOR_INDEX_NAME ,
368+             k = 4 ,
369+             pre_filter = pre_filter ,
370+         )
371+ 
372+         assert  len (result ) ==  2 
373+         assert  result [0 ]["category" ] ==  "fruit" 
374+         assert  result [1 ]["category" ] ==  "fruit" 
375+ 
376+     def  test_search_with_post_filter_pipeline (self , sample_docs : Collection ):
377+         query_vector  =  [1.0 , 0.5 , 0.0 ]
378+         post_filter_pipeline  =  [
379+             {"$match" : {"score" : {"$gte" : 0.99 }}},
380+             {"$sort" : {"score" : - 1 }},
381+         ]
382+ 
383+         result  =  execute_search_query (
384+             query_vector = query_vector ,
385+             collection = sample_docs ,
386+             embedding_key = "embedding" ,
387+             text_key = "text" ,
388+             index_name = VECTOR_INDEX_NAME ,
389+             k = 2 ,
390+             post_filter_pipeline = post_filter_pipeline ,
391+         )
392+ 
393+         assert  len (result ) ==  1 
394+ 
395+     def  test_search_with_embeddings_included (self , sample_docs : Collection ):
396+         query_vector  =  [1.0 , 0.5 , 0.0 ]
397+ 
398+         result  =  execute_search_query (
399+             query_vector = query_vector ,
400+             collection = sample_docs ,
401+             embedding_key = "embedding" ,
402+             text_key = "text" ,
403+             index_name = VECTOR_INDEX_NAME ,
404+             k = 1 ,
405+             include_embeddings = True ,
406+         )
407+ 
408+         assert  len (result ) ==  1 
409+         assert  "embedding"  in  result [0 ]
410+         assert  result [0 ]["embedding" ] ==  [1.0 , 0.5 , 0.0 ]
411+ 
412+     def  test_search_with_custom_field_names (self , sample_docs : Collection ):
413+         query_vector  =  [1.0 , 0.5 , 0.25 ]
414+ 
415+         mock_cursor  =  [
416+             {
417+                 "_id" : ObjectId (),
418+                 "content" : "apple fruit" ,
419+                 "vector" : [1.0 , 0.5 , 0.25 ],
420+                 "score" : 0.9 ,
421+             }
422+         ]
423+ 
424+         with  patch .object (sample_docs , "aggregate" ) as  mock_aggregate :
425+             mock_aggregate .return_value  =  mock_cursor 
426+ 
427+             result  =  execute_search_query (
428+                 query_vector = query_vector ,
429+                 collection = sample_docs ,
430+                 embedding_key = "vector" ,
431+                 text_key = "content" ,
432+                 index_name = VECTOR_INDEX_NAME ,
433+                 k = 1 ,
434+             )
435+ 
436+             assert  len (result ) ==  1 
437+             assert  "content"  in  result [0 ]
438+             assert  result [0 ]["content" ] ==  "apple fruit" 
439+ 
440+             pipeline_arg  =  mock_aggregate .call_args [0 ][0 ]
441+             vector_search_stage  =  pipeline_arg [0 ]["$vectorSearch" ]
442+             assert  vector_search_stage ["path" ] ==  "vector" 
443+             assert  {"$project" : {"vector" : 0 }} in  pipeline_arg 
444+ 
445+     def  test_search_filters_documents_without_text_key (self , sample_docs : Collection ):
446+         query_vector  =  [1.0 , 0.5 , 0.0 ]
447+ 
448+         result  =  execute_search_query (
449+             query_vector = query_vector ,
450+             collection = sample_docs ,
451+             embedding_key = "embedding" ,
452+             text_key = "text" ,
453+             index_name = VECTOR_INDEX_NAME ,
454+             k = 3 ,
455+         )
456+ 
457+         # Should only return documents with text field 
458+         assert  len (result ) ==  2 
459+         assert  all ("text"  in  doc  for  doc  in  result )
460+         assert  result [0 ]["text" ] ==  "apple fruit" 
461+         assert  result [1 ]["text" ] ==  "banana fruit" 
0 commit comments