Skip to content

Add other input types for SQLAlchemy BIT model #137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions pgvector/bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,12 @@ def from_binary(cls, value):

@classmethod
def _to_db(cls, value):
if value is None:
return value

if not isinstance(value, cls):
raise ValueError('expected bit')

value = cls(value)
return value.to_text()

@classmethod
Expand All @@ -73,3 +76,9 @@ def _to_db_binary(cls, value):
raise ValueError('expected bit')

return value.to_binary()

@classmethod
def _from_db(cls, value):
if value is None or isinstance(value, cls):
return value
return cls.from_text(value)
21 changes: 20 additions & 1 deletion pgvector/sqlalchemy/bit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncpg
from sqlalchemy.dialects.postgresql.asyncpg import PGDialect_asyncpg
from sqlalchemy.dialects.postgresql.base import ischema_names
from sqlalchemy.types import UserDefinedType, Float

from .. import Bit

class BIT(UserDefinedType):
cache_ok = True
Expand All @@ -14,6 +16,23 @@ def get_col_spec(self, **kw):
return 'BIT'
return 'BIT(%d)' % self.length

def bind_processor(self, dialect):
def process(value):
value = Bit._to_db(value)
if value and isinstance(dialect, PGDialect_asyncpg):
return asyncpg.BitString(value)
return value
return process

def result_processor(self, dialect, coltype):
def process(value):
if value is None: return None
else:
if isinstance(dialect, PGDialect_asyncpg):
return value.as_string()
return Bit._from_db(value).to_text()
return process

class comparator_factory(UserDefinedType.Comparator):
def hamming_distance(self, other):
return self.op('<~>', return_type=Float)(other)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ async def test_bit(self, engine):

async with async_session() as session:
async with session.begin():
embedding = asyncpg.BitString('101') if engine == asyncpg_engine else '101'
embedding = '101'
session.add(Item(id=1, binary_embedding=embedding))
item = await session.get(Item, 1)
assert item.binary_embedding == embedding
Expand Down