1
1
import cocoindex
2
+ from dotenv import load_dotenv
3
+ from psycopg_pool import ConnectionPool
4
+ import os
5
+
6
+ @cocoindex .transform_flow ()
7
+ def text_to_embedding (
8
+ text : cocoindex .DataSlice [str ],
9
+ ) -> cocoindex .DataSlice [list [float ]]:
10
+ """
11
+ Embed the text using a SentenceTransformer model.
12
+ This is a shared logic between indexing and querying, so extract it as a function.
13
+ """
14
+ return text .transform (
15
+ cocoindex .functions .SentenceTransformerEmbed (
16
+ model = "sentence-transformers/all-MiniLM-L6-v2"
17
+ )
18
+ )
19
+
2
20
@cocoindex .flow_def (name = "TextEmbedding" )
3
21
def text_embedding_flow (flow_builder : cocoindex .FlowBuilder , data_scope : cocoindex .DataScope ):
4
22
# Add a data source to read files from a directory
@@ -18,9 +36,7 @@ def text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoind
18
36
# Transform data of each chunk
19
37
with doc ["chunks" ].row () as chunk :
20
38
# Embed the chunk, put into `embedding` field
21
- chunk ["embedding" ] = chunk ["text" ].transform (
22
- cocoindex .functions .SentenceTransformerEmbed (
23
- model = "sentence-transformers/all-MiniLM-L6-v2" ))
39
+ chunk ["embedding" ] = text_to_embedding (chunk ["text" ])
24
40
25
41
# Collect the chunk into the collector.
26
42
doc_embeddings .collect (filename = doc ["filename" ], location = chunk ["location" ],
@@ -31,35 +47,54 @@ def text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoind
31
47
"doc_embeddings" ,
32
48
cocoindex .storages .Postgres (),
33
49
primary_key_fields = ["filename" , "location" ],
34
- vector_index = [("embedding" , cocoindex .VectorSimilarityMetric .COSINE_SIMILARITY )])
50
+ vector_indexes = [
51
+ cocoindex .VectorIndexDef (
52
+ field_name = "embedding" ,
53
+ metric = cocoindex .VectorSimilarityMetric .COSINE_SIMILARITY ,
54
+ )
55
+ ],
56
+ )
35
57
36
- query_handler = cocoindex .query .SimpleSemanticsQueryHandler (
37
- name = "SemanticsSearch" ,
38
- flow = text_embedding_flow ,
39
- target_name = "doc_embeddings" ,
40
- query_transform_flow = lambda text : text .transform (
41
- cocoindex .functions .SentenceTransformerEmbed (
42
- model = "sentence-transformers/all-MiniLM-L6-v2" )),
43
- default_similarity_metric = cocoindex .VectorSimilarityMetric .COSINE_SIMILARITY )
58
+ def search (pool : ConnectionPool , query : str , top_k : int = 5 ):
59
+ # Get the table name, for the export target in the text_embedding_flow above.
60
+ table_name = cocoindex .utils .get_target_storage_default_name (
61
+ text_embedding_flow , "doc_embeddings"
62
+ )
63
+ # Evaluate the transform flow defined above with the input query, to get the embedding.
64
+ query_vector = text_to_embedding .eval (query )
65
+ # Run the query and get the results.
66
+ with pool .connection () as conn :
67
+ with conn .cursor () as cur :
68
+ cur .execute (
69
+ f"""
70
+ SELECT filename, text, embedding <=> %s::vector AS distance
71
+ FROM { table_name } ORDER BY distance LIMIT %s
72
+ """ ,
73
+ (query_vector , top_k ),
74
+ )
75
+ return [
76
+ {"filename" : row [0 ], "text" : row [1 ], "score" : 1.0 - row [2 ]}
77
+ for row in cur .fetchall ()
78
+ ]
44
79
45
- @cocoindex .main_fn ()
46
80
def _main ():
47
- # Run queries to demonstrate the query capabilities.
81
+ # Initialize the database connection pool.
82
+ pool = ConnectionPool (os .getenv ("COCOINDEX_DATABASE_URL" ))
83
+ # Run queries in a loop to demonstrate the query capabilities.
48
84
while True :
49
- try :
50
- query = input ("Enter search query (or Enter to quit): " )
51
- if query == '' :
52
- break
53
- results , _ = query_handler .search (query , 10 )
54
- print ("\n Search results:" )
55
- for result in results :
56
- print (f"[{ result .score :.3f} ] { result .data ['filename' ]} " )
57
- print (f" { result .data ['text' ]} " )
58
- print ("---" )
59
- print ()
60
- except KeyboardInterrupt :
85
+ query = input ("Enter search query (or Enter to quit): " )
86
+ if query == "" :
61
87
break
62
-
88
+ # Run the query function with the database connection pool and the query.
89
+ results = search (pool , query )
90
+ print ("\n Search results:" )
91
+ for result in results :
92
+ print (f"[{ result ['score' ]:.3f} ] { result ['filename' ]} " )
93
+ print (f" { result ['text' ]} " )
94
+ print ("---" )
95
+ print ()
63
96
64
97
if __name__ == "__main__" :
98
+ load_dotenv ()
99
+ cocoindex .init ()
65
100
_main ()
0 commit comments