Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Join #37

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open

Join #37

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions dictorm/dictorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@

try: # pragma: no cover
from dictorm.pg import Select, Insert, Update, Delete, And
from dictorm.pg import Column, Comparison, Operator
from dictorm.pg import Column, Comparison, Operator, Join
from dictorm.sqlite import Insert as SqliteInsert
from dictorm.sqlite import Column as SqliteColumn
from dictorm.sqlite import Update as SqliteUpdate
except ImportError: # pragma: no cover
from .pg import Select, Insert, Update, Delete, And
from .pg import Column, Comparison, Operator
from .pg import Column, Comparison, Operator, Join
from .sqlite import Insert as SqliteInsert
from .sqlite import Column as SqliteColumn
from .sqlite import Update as SqliteUpdate
Expand Down Expand Up @@ -390,7 +390,7 @@ def __init__(self, table_name, db):
self.db = db
self.curs = db.curs
self.pks = []
self.refs = {}
self.joins = {}
self._refresh_pks()
self.order_by = None
self.fks = {}
Expand Down Expand Up @@ -430,7 +430,7 @@ def __call__(self, *a, **kw):
Used to insert a row into this table.
"""
d = Dict(self, *a, **kw)
for ref_name in self.refs:
for ref_name in self.joins:
d[ref_name] = None
return d

Expand Down Expand Up @@ -472,6 +472,8 @@ def get_where(self, *a, **kw):

"""
# All args/kwargs are combined in an SQL And comparison
if len(a) == 1 and isinstance(a[0], Join):
return ResultsGenerator(self, a[0], self.db)
operator_group = args_to_comp(And(), self, *a, **kw)

order_by = None
Expand Down Expand Up @@ -545,7 +547,7 @@ def columns_info(self):
return self.cached_columns_info


def __setitem__(self, ref_name, ref):
def __setitem__(self, join_name, join):
"""
Create reference that will be gotten by all Dicts created from this
table.
Expand All @@ -555,22 +557,21 @@ def __setitem__(self, ref_name, ref):

For more examples see Table's doc.
"""
if ref.column1.table != self:
if join.column1.table != self:
# Dict.__getitem__ expects the columns to be in a particular order,
# fix any order issues.
ref.column1, ref.column2 = ref.column2, ref.column1
self.fks[ref.column1.column] = ref_name
self.refs[ref_name] = ref
join.column1, join.column2 = join.column2, join.column1
self.fks[join.column1.column] = join_name
self.joins[join_name] = join


def __getitem__(self, ref_name):
def __getitem__(self, join_name):
"""
Get a reference if it has already been created. Otherwise, return a
Column object which is used to create a reference.
Get a Join if it has already been created. Otherwise, return a new Join
"""
if ref_name in self.refs:
return self.refs[ref_name]
return self.db.column(self, ref_name)
if join_name in self.joins:
return self.joins[join_name]
return self.db.column(self, join_name)



Expand Down Expand Up @@ -622,7 +623,7 @@ def flush(self):
All original column/values will bet inserted/updated by this method.
All references will be flushed as well.
"""
if self._table.refs:
if self._table.joins:
for i in self.values():
if isinstance(i, Dict):
i.flush()
Expand Down Expand Up @@ -701,15 +702,15 @@ def no_refs(self):
Return a dictionary without the key/value(s) added by a reference. They
should never be sent in the query to the Database.
"""
return dict([(k,v) for k,v in self.items() if k not in self._table.refs]
return dict([(k,v) for k,v in self.items() if k not in self._table.joins]
)


def references(self):
"""
Return a dictionary of only the referenced rows.
"""
return dict([(k,v) for k,v in self.items() if k in self._table.refs])
return dict([(k,v) for k,v in self.items() if k in self._table.joins])


def __getitem__(self, key):
Expand All @@ -718,7 +719,7 @@ def __getitem__(self, key):
referenced row, get that row first. Will only get a referenced row
once, until the referenced row's foreign key is changed.
"""
ref = self._table.refs.get(key)
ref = self._table.joins.get(key)
if not ref and key not in self:
raise KeyError(str(key))
# Only get the referenced row once, if it has a value, the reference's
Expand Down
126 changes: 122 additions & 4 deletions dictorm/pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,114 @@ def __add__(self, item):
return self


class PreJoin(object):
"""
Used only to store information for join that will be created by a
comparison.__getitem__
"""

def __init__(self, comp, key):
self.comp = comp
self.key = key


def __eq__(self, key):
join = Join(self.comp)
return join.where(join.comps[0].column2.table[self.key] == key)



class Join(object):

query = 'SELECT "{table}".* FROM "{table}"{joins}{where}'
join_query = '"{c1.table.name}" ON {c1}={c2}'

def __init__(self, *comps, kind=''):
self.comps = list(comps)
self.kind = ' '+kind if kind else ''
self.ooc = None


def where(self, *ooc):
join = Join(*self.comps, kind=self.kind)
join.ooc = ooc
return join


def __str__(self):
joined_to = self.comps[0].column1.table
table = joined_to.name
formats = {'table':table,}
formats['joins'] = self.join_str(joined_to)
formats['where'] = ''
if self.ooc:
formats['where'] = ' WHERE ' + ''.join([str(i) for i in self.ooc])
return self.query.format(**formats)


def join_str(self, joined_to):
other_joins = []
my_joins = []
for comp in self.comps:
if isinstance(comp, Join):
other_joins.append(comp.join_str(joined_to))
else:
c1, c2 = comp.column1, comp.column2
if c1.table == joined_to:
c1, c2 = c2, c1
my_joins.append(self.join_query.format(c1=c1, c2=c2))
query = '{kind} JOIN '.format(kind=self.kind)
query = query + query.join(my_joins) + ' '.join(other_joins)
return query


def values(self):
if self.ooc:
return [i.value() for i in self.ooc]
return []


def build(self):
return (str(self), self.values())


def Join(self, *comps, kind=None):
kind = kind or Join
join = kind(*comps)
self.comps.append(join)
return self



class LeftJoin(Join):
def __init__(self, *comps):
super().__init__(*comps, kind='LEFT')



class RightJoin(Join):
def __init__(self, *comps):
super().__init__(*comps, kind='RIGHT')



class InnerJoin(Join):
def __init__(self, *comps):
super().__init__(*comps, kind='INNER')



class FullOuterJoin(Join):
def __init__(self, *comps):
super().__init__(*comps, kind='FULL OUTER')



class FullJoin(Join):
def __init__(self, *comps):
super().__init__(*comps, kind='FULL')



class Insert(object):

Expand Down Expand Up @@ -192,6 +300,7 @@ class Delete(Update):
query = 'DELETE FROM "{table}"'



class Comparison(object):

interpolation_str = '%s'
Expand All @@ -204,6 +313,7 @@ def __init__(self, column1, column2, kind):
self._substratum = None
self._aggregate = False


def __repr__(self): # pragma: no cover
if isinstance(self.column2, Null):
ret = 'Comparison({0}{1})'.format(self.column1, self.kind)
Expand All @@ -215,10 +325,10 @@ def __repr__(self): # pragma: no cover


def __str__(self):
c1 = self.column1.column
c1 = self.column1
if self._null_kind():
return '"{0}"{1}'.format(c1, self.kind)
return '"{0}"{1}{2}'.format(c1, self.kind, self.interpolation_str)
return '"{0}"."{1}"{2}'.format(c1.table.name, c1.column, self.kind)
return '"{0}"."{1}"{2}{3}'.format(c1.table.name, c1.column, self.kind, self.interpolation_str)


def value(self):
Expand Down Expand Up @@ -251,9 +361,14 @@ def Or(self, comp2): return Or(self, comp2)
def Xor(self, comp2): return Xor(self, comp2)
def And(self, comp2): return And(self, comp2)

def __getitem__(self, key):
return PreJoin(self, key)



class Null():

class Null(): pass
def __repr__(self): return ''



Expand All @@ -268,6 +383,9 @@ def __init__(self, table, column):
def __repr__(self): # pragma: no cover
return '{0}.{1}'.format(self.table.name, self.column)

def __str__(self):
return '"{0}"."{1}"'.format(self.table.name, self.column)

def many(self, column):
c = self.comparison(self, column, '=')
c.many = True
Expand Down
1 change: 1 addition & 0 deletions dictorm/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Provide Sqlite3 support by making simple changes to dictorm.pg classes.
'''
from .pg import And
from .pg import Join
from .pg import Column as PostgresqlColumn
from .pg import Comparison as PostgresqlComparison
from .pg import Insert as PostgresqlInsert
Expand Down
33 changes: 32 additions & 1 deletion dictorm/test/test_dictorm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#! /usr/bin/env python
import dictorm
from dictorm.pg import Join
from psycopg2.extras import DictCursor
import os
import psycopg2
Expand All @@ -24,7 +25,9 @@
}

def _no_refs(o):
if isinstance(o, dictorm.Dict):
if o == None:
return None
elif isinstance(o, dictorm.Dict):
return o.no_refs()
l = []
for i in o:
Expand Down Expand Up @@ -1216,6 +1219,34 @@ def test_aggregate(self):
self.assertEqualNoRefs(bob['subordinates_departments'], [it, sales, hr])


def test_join(self):
Person, Car = self.db['person'], self.db['car']
Person['car'] = Person['id'] == Car['person_id']
Car['owner'] = Car['person_id'] == Person['id']

bob = Person(name='Bob').flush()
aly = Person(name='Aly').flush()

bob_car = Car(name='Ford', person_id=bob['id']).flush()
aly_car = Car(name='Dodge', person_id=aly['id']).flush()

self.assertEqualNoRefs(bob['car'], bob_car)
self.assertEqualNoRefs(aly['car'], aly_car)

self.assertEqualNoRefs(
Person.get_where(Person['car']['name'] == 'Ford'),
[bob,])

self.assertEqualNoRefs(
Car.get_where(Car['owner']['name'] == 'Bob'),
[bob_car,])

self.assertEqualNoRefs(
Car.get_where(Join(Car['person_id'] == Person['id']
).where(Person['name'] == 'Bob')),
[bob_car,])



class SqliteTestBase(object):

Expand Down
Loading