Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions sqlglot/optimizer/annotate_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def annotate_types(
expression_metadata: t.Optional[ExpressionMetadataType] = None,
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
dialect: DialectType = None,
overwrite_types: bool = True,
) -> E:
"""
Infers the types of an expression, annotating its AST accordingly.
Expand All @@ -52,16 +53,22 @@ def annotate_types(
Args:
expression: Expression to annotate.
schema: Database schema.
annotators: Maps expression type to corresponding annotation function.
expression_metadata: Maps expression type to corresponding annotation function.
coerces_to: Maps expression type to set of types that it can be coerced into.
overwrite_types: Whether to re-annotate nodes that already have a type. If False, nodes whose type is already set (and is not UNKNOWN) are skipped.

Returns:
The expression annotated with types.
"""

schema = ensure_schema(schema, dialect=dialect)

return TypeAnnotator(schema, expression_metadata, coerces_to).annotate(expression)
return TypeAnnotator(
schema=schema,
expression_metadata=expression_metadata,
coerces_to=coerces_to,
overwrite_types=overwrite_types,
).annotate(expression)


def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
Expand Down Expand Up @@ -178,6 +185,7 @@ def __init__(
expression_metadata: t.Optional[ExpressionMetadataType] = None,
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
binary_coercions: t.Optional[BinaryCoercions] = None,
overwrite_types: bool = True,
) -> None:
self.schema = schema
self.expression_metadata = (
Expand All @@ -202,6 +210,14 @@ def __init__(
# would reprocess the entire subtree to coerce the types of its operands' projections
self._setop_column_types: t.Dict[int, t.Dict[str, exp.DataType | exp.DataType.Type]] = {}

# When set to False, this enables partial annotation by skipping already-annotated nodes
self._overwrite_types = overwrite_types

def clear(self) -> None:
    """Empty the `_visited`, `_null_expressions` and `_setop_column_types` containers in place."""
    # Attributes are looked up and cleared one at a time, in this fixed order.
    for attr_name in ("_visited", "_null_expressions", "_setop_column_types"):
        getattr(self, attr_name).clear()

def _set_type(
self, expression: exp.Expression, target_type: t.Optional[exp.DataType | exp.DataType.Type]
) -> None:
Expand All @@ -219,9 +235,12 @@ def _set_type(
elif prev_type and t.cast(exp.DataType, prev_type).this == exp.DataType.Type.NULL:
self._null_expressions.pop(expression_id, None)

def annotate(self, expression: E) -> E:
for scope in traverse_scope(expression):
self.annotate_scope(scope)
def annotate(self, expression: E, annotate_scope: bool = True) -> E:
# This flag is used to avoid costly scope traversals when we only care about annotating
# non-column expressions (partial type inference), e.g., when simplifying in the optimizer
if annotate_scope:
for scope in traverse_scope(expression):
self.annotate_scope(scope)

# This takes care of non-traversable expressions
expression = self._maybe_annotate(expression)
Expand Down Expand Up @@ -370,7 +389,11 @@ def annotate_scope(self, scope: Scope) -> None:
scope.expression.meta["query_type"] = struct_type

def _maybe_annotate(self, expression: E) -> E:
if id(expression) in self._visited:
if id(expression) in self._visited or (
not self._overwrite_types
and expression.type
and not expression.is_type(exp.DataType.Type.UNKNOWN)
):
return expression # We've already inferred the expression's type

spec = self.expression_metadata.get(expression.__class__)
Expand Down
33 changes: 20 additions & 13 deletions sqlglot/optimizer/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sqlglot.errors import OptimizeError
from sqlglot.helper import while_changing
from sqlglot.optimizer.scope import find_all_in_scope
from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
from sqlglot.optimizer.simplify import Simplifier, flatten

logger = logging.getLogger("sqlglot")

Expand All @@ -28,14 +28,16 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
Returns:
sqlglot.Expression: normalized expression
"""
simplifier = Simplifier(annotate_new_expressions=False)

for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))):
if isinstance(node, exp.Connector):
if normalized(node, dnf=dnf):
continue
root = node is expression
original = node.copy()

node.transform(rewrite_between, copy=False)
node.transform(simplifier.rewrite_between, copy=False)
distance = normalization_distance(node, dnf=dnf, max_=max_distance)

if distance > max_distance:
Expand All @@ -46,7 +48,10 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =

try:
node = node.replace(
while_changing(node, lambda e: distributive_law(e, dnf, max_distance))
while_changing(
node,
lambda e: distributive_law(e, dnf, max_distance, simplifier=simplifier),
)
)
except OptimizeError as e:
logger.info(e)
Expand Down Expand Up @@ -146,7 +151,7 @@ def _predicate_lengths(expression, dnf, max_=float("inf"), depth=0):
yield from _predicate_lengths(right, dnf, max_, depth)


def distributive_law(expression, dnf, max_distance):
def distributive_law(expression, dnf, max_distance, simplifier=None):
"""
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
Expand All @@ -168,32 +173,34 @@ def distributive_law(expression, dnf, max_distance):
from_func = exp.and_ if from_exp == exp.And else exp.or_
to_func = exp.and_ if to_exp == exp.And else exp.or_

simplifier = simplifier or Simplifier(annotate_new_expressions=False)

if isinstance(a, to_exp) and isinstance(b, to_exp):
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
return _distribute(a, b, from_func, to_func)
return _distribute(b, a, from_func, to_func)
return _distribute(a, b, from_func, to_func, simplifier)
return _distribute(b, a, from_func, to_func, simplifier)
if isinstance(a, to_exp):
return _distribute(b, a, from_func, to_func)
return _distribute(b, a, from_func, to_func, simplifier)
if isinstance(b, to_exp):
return _distribute(a, b, from_func, to_func)
return _distribute(a, b, from_func, to_func, simplifier)

return expression


def _distribute(a, b, from_func, to_func):
def _distribute(a, b, from_func, to_func, simplifier):
if isinstance(a, exp.Connector):
exp.replace_children(
a,
lambda c: to_func(
uniq_sort(flatten(from_func(c, b.left))),
uniq_sort(flatten(from_func(c, b.right))),
simplifier.uniq_sort(flatten(from_func(c, b.left))),
simplifier.uniq_sort(flatten(from_func(c, b.right))),
copy=False,
),
)
else:
a = to_func(
uniq_sort(flatten(from_func(a, b.left))),
uniq_sort(flatten(from_func(a, b.right))),
simplifier.uniq_sort(flatten(from_func(a, b.left))),
simplifier.uniq_sort(flatten(from_func(a, b.right))),
copy=False,
)

Expand Down
Loading