Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

More robust quoted names handling #167

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyrseas/augment/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def link_current(self, tables):
"""
for (sch, tbl) in self:
if not (sch, tbl) in tables:
raise KeyError("Table %s.%s not in current database" % (
sch, tbl))
raise KeyError("Table %s not in current database" %
quote_id(sch, tbl))
if not hasattr(self[(sch, tbl)], 'current'):
self[(sch, tbl)].current = tables[(sch, tbl)]
4 changes: 2 additions & 2 deletions pyrseas/augment/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
DbAugment and CfgTriggerDict derived from DbAugmentDict.
"""
from pyrseas.augment import DbAugmentDict, DbAugment
from pyrseas.dbobject import split_schema_obj
from pyrseas.dbobject import split_schema_obj, quote_id
from pyrseas.dbobject.trigger import Trigger


Expand Down Expand Up @@ -35,7 +35,7 @@ def apply(self, table):
newtrg.procedure[:14], table.name)
(sch, fnc) = split_schema_obj(newtrg.procedure)
if sch != table.schema:
newtrg.procedure = "%s.%s" % (table.schema, fnc)
newtrg.procedure = quote_id(table.schema, fnc)
table.triggers.update({newtrg.name: newtrg})


Expand Down
110 changes: 86 additions & 24 deletions pyrseas/dbobject/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,28 @@ def fetch_reserved_words(db):
WHERE catcode = 'R'""")]


def quote_id(name):
def quote_id(*names):
"""Quotes an identifier if necessary.

:param name: string to be quoted
:param names: strings to be quoted. If more than one they will be merged
by dot.

:return: possibly quoted string
"""
regular_id = True
if not name[0] in VALID_FIRST_CHARS or name in RESERVED_WORDS:
regular_id = False
else:
for ltr in name[1:]:
if ltr not in VALID_CHARS:
regular_id = False
break
rv = []
for name in names:
regular_id = True
if not name[0] in VALID_FIRST_CHARS or name in RESERVED_WORDS:
regular_id = False
else:
for ltr in name[1:]:
if ltr not in VALID_CHARS:
regular_id = False
break

return regular_id and name or '"%s"' % name
rv.append(regular_id and name or '"%s"' % name.replace('"', '""'))

return '.'.join(rv)


def split_schema_obj(obj, sch=None):
Expand All @@ -68,22 +73,79 @@ def split_schema_obj(obj, sch=None):
qualsch = sch
if sch is None:
qualsch = 'public'
if obj[0] == '"' and obj[-1] == '"':
if '"."' in obj:
(qualsch, obj) = obj.split('"."')
qualsch = qualsch[1:]
obj = obj[:-1]
else:
obj = obj[1:-1]

tokens = tokenize_identifiers(obj)
if len(tokens) == 1:
obj = tokens[0]
elif len(tokens) == 2:
qualsch, obj = tokens
else:
# TODO: properly handle functions
if '.' in obj and '(' not in obj:
(qualsch, obj) = obj.split('.')
raise ValueError("invalid object name: %s")

if sch != qualsch:
sch = qualsch
return (sch, obj)


def tokenize_identifiers(s):
    """Parse a string representing a dotted sequence of Postgres identifiers

    Double-quoted identifiers may contain dots and parentheses; a doubled
    double-quote inside a quoted identifier stands for one literal
    double-quote.  Parsing stops at the first unquoted ``(`` in case the
    name passed is actually a function with arguments.

    :param s: string holding one or more identifiers separated by dots

    :return: list of the tokens found, with surrounding double-quotes
        removed and doubled double-quotes collapsed

    :raises ValueError: if a token starts with a dot, a double-quote
        appears inside an unquoted token, a closing quote is followed by
        anything other than ``"``, ``.`` or ``(``, or a quoted identifier
        is never closed
    """
    # QUOTE_END means "just saw a double-quote while inside a quoted token":
    # the next character decides whether it was an escaped quote or the
    # closing one.
    START, NAME, QUOTE, QUOTE_END = range(4)
    state = START

    tokens = []
    chars = []      # characters of the token currently being built
    for c in s:
        if state == START:
            if c == '"':
                state = QUOTE
            elif c == '.':
                raise ValueError("invalid object name: %s" % s)
            else:
                state = NAME
                chars.append(c)
        elif state == NAME:
            if c == '"':
                raise ValueError("invalid object name: %s" % s)
            elif c == '.':
                # end of token
                tokens.append(''.join(chars))
                chars = []
                state = START
            elif c == '(':
                # start of a function argument list: nothing more to parse
                break
            else:
                chars.append(c)
        elif state == QUOTE:
            if c == '"':
                state = QUOTE_END
            else:
                chars.append(c)
        else:
            # state == QUOTE_END
            if c == '"':
                # escaped quote: keep one and stay inside the quoted token
                chars.append('"')
                state = QUOTE
            elif c == '.':
                tokens.append(''.join(chars))
                chars = []
                state = START
            elif c == '(':
                break
            else:
                # e.g. "a"b -- previously the stray character was dropped
                # silently and the parser stayed in the quoted state
                raise ValueError("invalid object name: %s" % s)

    if state == QUOTE:
        # input ended inside a quoted identifier (no closing quote)
        raise ValueError("invalid object name: %s" % s)

    tokens.append(''.join(chars))

    return tokens


def split_func_args(obj):
"""Split function name and argument from a signature, e.g. fun(int, text)

Expand Down Expand Up @@ -546,7 +608,7 @@ def qualname(self, objname=None):
if objname is None:
objname = self.name
return self.schema == 'public' and quote_id(objname) \
or "%s.%s" % (quote_id(self.schema), quote_id(objname))
or quote_id(self.schema, objname)

def unqualify(self):
"""Adjust the schema and table name if the latter is qualified"""
Expand All @@ -562,8 +624,8 @@ def extern_filename(self, ext='yaml'):
return super(DbSchemaObject, self).extern_filename(ext, True)

def rename(self, oldname):
return "ALTER %s %s.%s RENAME TO %s" % (
self.objtype, quote_id(self.schema), quote_id(oldname),
return "ALTER %s %s RENAME TO %s" % (
self.objtype, quote_id(self.schema, oldname),
quote_id(self.name))

def get_implied_deps(self, db):
Expand Down
6 changes: 3 additions & 3 deletions pyrseas/dbobject/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def get_implied_deps(self, db):
elif isinstance(self._table, Domain):
deps.add(db.types[self.schema, self.table])
else:
raise KeyError("Constraint '%s.%s' on unknown type/class" % (
self.schema, self.name))
raise KeyError("Constraint '%s' on unknown type/class" % (
quote_id(self.schema, self.name)))

return deps

Expand Down Expand Up @@ -471,7 +471,7 @@ def _from_catalog(self):
constr.unqualify()
oid = constr.oid
sch, tbl, cns = constr.key()
sch, tbl = split_schema_obj('%s.%s' % (sch, tbl))
sch, tbl = split_schema_obj(quote_id(sch, tbl)) # TODO why?
constr_type = constr.type
del constr.type
if constr_type != 'f':
Expand Down
2 changes: 1 addition & 1 deletion pyrseas/dbobject/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def _from_catalog(self):
index.unqualify()
oid = index.oid
sch, tbl, idx = index.key()
sch, tbl = split_schema_obj('%s.%s' % (sch, tbl))
sch, tbl = split_schema_obj(quote_id(sch, tbl)) # TODO: why?
keydefs, _, _ = index.defn.partition(' WHERE ')
_, _, keydefs = keydefs.partition(' USING ')
keydefs = keydefs[keydefs.find(' (') + 2:-1]
Expand Down
4 changes: 2 additions & 2 deletions pyrseas/dbobject/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def qualname(self):

No qualification is used if the schema is 'public'.
"""
return self.schema == 'public' and self.name \
or "%s.%s" % (quote_id(self.schema), self.name)
return self.schema == 'public' and quote_id(self.name) \
or quote_id(self.schema, self.name)

def identifier(self):
"""Return a full identifier for an operator object
Expand Down
2 changes: 1 addition & 1 deletion pyrseas/dbobject/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def get_attrs(self, dbconn):
data = dbconn.fetchone(
"""SELECT start_value, increment_by, max_value, min_value,
cache_value
FROM %s.%s""" % (quote_id(self.schema), quote_id(self.name)))
FROM %s""" % quote_id(self.schema, self.name))
for key, val in list(data.items()):
setattr(self, key, val)

Expand Down
1 change: 0 additions & 1 deletion pyrseas/dbobject/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ def get_implied_deps(self, db):
# TODO: this breaks if a function name contains a '('
# (another case for a robust lookup function in db)
fschema, fname = split_schema_obj(self.procedure, self.schema)
fname, _ = fname.split('(', 1) # implicitly assert there is a (
if not fname.startswith('tsvector_update_trigger'):
deps.add(db.functions[fschema, fname, ''])

Expand Down
15 changes: 8 additions & 7 deletions pyrseas/relation/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""
pyrseas.relation.join
"""
from pyrseas.dbobject import quote_id
from pyrseas.relation.attribute import Attribute
from pyrseas.relation.tuple import Tuple

Expand Down Expand Up @@ -98,9 +99,9 @@ def where_clause(self, qry_args=None):
attrs = {}
for name, attr in self.attributes:
if attr.name != attr.basename:
expr = "%s.%s" % (attr.projection.rangevar, attr.basename)
expr = quote_id(attr.projection.rangevar, attr.basename)
else:
expr = "%s.%s" % (attr.projection.rangevar, attr.name)
expr = quote_id(attr.projection.rangevar, attr.name)
attrs.update({attr.name: (expr, attr.type)})
subclauses = []
params = {}
Expand Down Expand Up @@ -156,12 +157,12 @@ def getsubset_qry():
exprs = []
for name, attr in self.attributes:
if attr.name != attr.basename:
exprs.append("%s.%s AS %s" % (
attr.projection.rangevar, attr.basename,
attr.name))
exprs.append("%s AS %s" % (
quote_id(attr.projection.rangevar, attr.basename),
quote_id(attr.name)))
else:
exprs.append("%s.%s" % (attr.projection.rangevar,
attr.name))
exprs.append(
quote_id(attr.projection.rangevar, attr.name))
self.getsubset_qry = "SELECT %s FROM %s" % (
", ".join(exprs), self.from_clause)
return self.getsubset_qry
Expand Down