Skip to content

Commit a7930af

Browse files
vishwarajananddishaprakash
authored andcommitted
refactor: parse embedding dimensions as float[] for numpyV2 (#277)
1 parent 2fb91b1 commit a7930af

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

src/langchain_google_cloud_sql_pg/async_vectorstore.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,11 @@ async def __aadd_embeddings(
249249
else ""
250250
)
251251
insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"("{self.id_column}", "{self.content_column}", "{self.embedding_column}"{metadata_col_names}'
252-
values = {"id": id, "content": content, "embedding": str(embedding)}
252+
values = {
253+
"id": id,
254+
"content": content,
255+
"embedding": str([float(dimension) for dimension in embedding]),
256+
}
253257
values_stmt = "VALUES (:id, :content, :embedding"
254258

255259
# Add metadata
@@ -554,9 +558,9 @@ async def __query_collection(
554558
columns.append(self.metadata_json_column)
555559

556560
column_names = ", ".join(f'"{col}"' for col in columns)
557-
558561
filter = f"WHERE {filter}" if filter else ""
559-
stmt = f"SELECT {column_names}, {search_function}({self.embedding_column}, '{embedding}') as distance FROM \"{self.schema_name}\".\"{self.table_name}\" {filter} ORDER BY {self.embedding_column} {operator} '{embedding}' LIMIT {k};"
562+
embedding_string = f"'{[float(dimension) for dimension in embedding]}'"
563+
stmt = f'SELECT {column_names}, {search_function}({self.embedding_column}, {embedding_string}) as distance FROM "{self.schema_name}"."{self.table_name}" {filter} ORDER BY {self.embedding_column} {operator} {embedding_string} LIMIT {k};'
560564
if self.index_query_options:
561565
async with self.pool.connect() as conn:
562566
await conn.execute(

tests/test_engine.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,9 @@ async def test_init_table(self, engine):
130130
id = str(uuid.uuid4())
131131
content = "coffee"
132132
embedding = await embeddings_service.aembed_query(content)
133-
stmt = f"INSERT INTO {DEFAULT_TABLE} (langchain_id, content, embedding) VALUES ('{id}', '{content}','{embedding}');"
133+
# Note: DeterministicFakeEmbedding generates a numpy array, converting to list a list of float values
134+
embedding_string = [float(dimension) for dimension in embedding]
135+
stmt = f"INSERT INTO {DEFAULT_TABLE} (langchain_id, content, embedding) VALUES ('{id}', '{content}','{embedding_string}');"
134136
await aexecute(engine, stmt)
135137

136138
async def test_init_table_custom(self, engine):
@@ -350,7 +352,9 @@ async def test_init_table(self, engine):
350352
id = str(uuid.uuid4())
351353
content = "coffee"
352354
embedding = await embeddings_service.aembed_query(content)
353-
stmt = f"INSERT INTO {DEFAULT_TABLE_SYNC} (langchain_id, content, embedding) VALUES ('{id}', '{content}','{embedding}');"
355+
# Note: DeterministicFakeEmbedding generates a numpy array, converting to list a list of float values
356+
embedding_string = [float(dimension) for dimension in embedding]
357+
stmt = f"INSERT INTO {DEFAULT_TABLE_SYNC} (langchain_id, content, embedding) VALUES ('{id}', '{content}','{embedding_string}');"
354358
await aexecute(engine, stmt)
355359

356360
async def test_init_table_custom(self, engine):

0 commit comments

Comments
 (0)