Skip to content

Commit cb8b0f9

Browse files
authored
update query handler and documentation (#4)
1 parent 1b84923 commit cb8b0f9

File tree

3 files changed

+66
-29
lines changed

3 files changed

+66
-29
lines changed

.env

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Postgres database address for cocoindex
2+
COCOINDEX_DATABASE_URL=postgres://cocoindex:cocoindex@localhost/cocoindex

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ export COCOINDEX_DATABASE_URL="postgresql://cocoindex:cocoindex@localhost:5432/c
2828
Setup index:
2929

3030
```bash
31-
python quickstart.py cocoindex setup
31+
cocoindex setup quickstart.py
3232
```
3333

3434
Update index:
3535

3636
```bash
37-
python quickstart.py cocoindex update
37+
cocoindex update quickstart.py
3838
```
3939

4040
Run query:

quickstart.py

Lines changed: 62 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,22 @@
11
import cocoindex
2+
from dotenv import load_dotenv
3+
from psycopg_pool import ConnectionPool
4+
import os
5+
6+
@cocoindex.transform_flow()
def text_to_embedding(
    text: cocoindex.DataSlice[str],
) -> cocoindex.DataSlice[list[float]]:
    """
    Turn a text slice into an embedding slice via a SentenceTransformer model.

    Kept as a standalone transform flow so that indexing and querying both
    apply exactly the same embedding step.
    """
    embedder = cocoindex.functions.SentenceTransformerEmbed(
        model="sentence-transformers/all-MiniLM-L6-v2"
    )
    return text.transform(embedder)
19+
220
@cocoindex.flow_def(name="TextEmbedding")
321
def text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope):
422
# Add a data source to read files from a directory
@@ -18,9 +36,7 @@ def text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoind
1836
# Transform data of each chunk
1937
with doc["chunks"].row() as chunk:
2038
# Embed the chunk, put into `embedding` field
21-
chunk["embedding"] = chunk["text"].transform(
22-
cocoindex.functions.SentenceTransformerEmbed(
23-
model="sentence-transformers/all-MiniLM-L6-v2"))
39+
chunk["embedding"] = text_to_embedding(chunk["text"])
2440

2541
# Collect the chunk into the collector.
2642
doc_embeddings.collect(filename=doc["filename"], location=chunk["location"],
@@ -31,35 +47,54 @@ def text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoind
3147
"doc_embeddings",
3248
cocoindex.storages.Postgres(),
3349
primary_key_fields=["filename", "location"],
34-
vector_index=[("embedding", cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY)])
50+
vector_indexes=[
51+
cocoindex.VectorIndexDef(
52+
field_name="embedding",
53+
metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
54+
)
55+
],
56+
)
3557

36-
query_handler = cocoindex.query.SimpleSemanticsQueryHandler(
37-
name="SemanticsSearch",
38-
flow=text_embedding_flow,
39-
target_name="doc_embeddings",
40-
query_transform_flow=lambda text: text.transform(
41-
cocoindex.functions.SentenceTransformerEmbed(
42-
model="sentence-transformers/all-MiniLM-L6-v2")),
43-
default_similarity_metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY)
58+
def search(pool: ConnectionPool, query: str, top_k: int = 5):
    """
    Return up to `top_k` indexed chunks ordered by distance to `query`.

    Each result is a dict with the keys `filename`, `text` and `score`, where
    `score` is 1.0 minus the value of the `<=>` expression in the SQL below.
    """
    # Name of the table behind the "doc_embeddings" export target of the flow.
    table_name = cocoindex.utils.get_target_storage_default_name(
        text_embedding_flow, "doc_embeddings"
    )
    # Embed the query with the same transform flow that the indexing flow uses.
    query_vector = text_to_embedding.eval(query)

    # Only the table name is interpolated; the vector and the limit are bound
    # as query parameters.
    statement = f"""
                    SELECT filename, text, embedding <=> %s::vector AS distance
                    FROM {table_name} ORDER BY distance LIMIT %s
                """
    with pool.connection() as conn, conn.cursor() as cur:
        cur.execute(statement, (query_vector, top_k))
        rows = cur.fetchall()

    hits = []
    for filename, text, distance in rows:
        hits.append({"filename": filename, "text": text, "score": 1.0 - distance})
    return hits
4479

45-
@cocoindex.main_fn()
4680
def _main():
    """
    Run an interactive search loop over the indexed documents.

    Reads queries from standard input until an empty line is entered (or the
    prompt is interrupted with Ctrl-C / Ctrl-D), and prints the score,
    filename and text of every result returned by `search`.
    """
    # Use the pool as a context manager so that it is closed on every exit
    # path, including errors, instead of being left open at exit.
    with ConnectionPool(os.getenv("COCOINDEX_DATABASE_URL")) as pool:
        # Run queries in a loop to demonstrate the query capabilities.
        while True:
            try:
                query = input("Enter search query (or Enter to quit): ")
            except (EOFError, KeyboardInterrupt):
                # Treat Ctrl-D / Ctrl-C at the prompt as a request to quit
                # rather than ending the program with a traceback.
                break
            if query == "":
                break
            # Run the query function with the connection pool and the query.
            results = search(pool, query)
            print("\nSearch results:")
            for result in results:
                print(f"[{result['score']:.3f}] {result['filename']}")
                print(f" {result['text']}")
                print("---")
            print()
6396

6497
if __name__ == "__main__":
    # Load variables from a local .env file into the process environment
    # first; _main() reads COCOINDEX_DATABASE_URL through os.getenv.
    load_dotenv()
    # NOTE(review): presumably cocoindex.init() also relies on that
    # environment being populated — confirm against the cocoindex docs.
    cocoindex.init()
    _main()

0 commit comments

Comments
 (0)