diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index cda4b9f539..8bbe41de90 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -37,6 +37,7 @@ def annotate_types( expression_metadata: t.Optional[ExpressionMetadataType] = None, coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, dialect: DialectType = None, + overwrite_types: bool = True, ) -> E: """ Infers the types of an expression, annotating its AST accordingly. @@ -52,8 +53,9 @@ def annotate_types( Args: expression: Expression to annotate. schema: Database schema. - annotators: Maps expression type to corresponding annotation function. + expression_metadata: Maps expression type to corresponding annotation function. coerces_to: Maps expression type to set of types that it can be coerced into. + overwrite_types: Re-annotate the existing AST types. Returns: The expression annotated with types. @@ -61,7 +63,12 @@ def annotate_types( schema = ensure_schema(schema, dialect=dialect) - return TypeAnnotator(schema, expression_metadata, coerces_to).annotate(expression) + return TypeAnnotator( + schema=schema, + expression_metadata=expression_metadata, + coerces_to=coerces_to, + overwrite_types=overwrite_types, + ).annotate(expression) def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type: @@ -178,6 +185,7 @@ def __init__( expression_metadata: t.Optional[ExpressionMetadataType] = None, coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, binary_coercions: t.Optional[BinaryCoercions] = None, + overwrite_types: bool = True, ) -> None: self.schema = schema self.expression_metadata = ( @@ -202,6 +210,14 @@ def __init__( # would reprocess the entire subtree to coerce the types of its operands' projections self._setop_column_types: t.Dict[int, t.Dict[str, exp.DataType | exp.DataType.Type]] = {} + # When set to False, this enables partial annotation by skipping already-annotated nodes + self._overwrite_types = overwrite_types + + def clear(self) -> None: + self._visited.clear() + self._null_expressions.clear() + self._setop_column_types.clear() + def _set_type( self, expression: exp.Expression, target_type: t.Optional[exp.DataType | exp.DataType.Type] ) -> None: @@ -219,9 +235,12 @@ def _set_type( elif prev_type and t.cast(exp.DataType, prev_type).this == exp.DataType.Type.NULL: self._null_expressions.pop(expression_id, None) - def annotate(self, expression: E) -> E: - for scope in traverse_scope(expression): - self.annotate_scope(scope) + def annotate(self, expression: E, annotate_scope: bool = True) -> E: + # This flag is used to avoid costly scope traversals when we only care about annotating + # non-column expressions (partial type inference), e.g., when simplifying in the optimizer + if annotate_scope: + for scope in traverse_scope(expression): + self.annotate_scope(scope) # This takes care of non-traversable expressions expression = self._maybe_annotate(expression) @@ -370,7 +389,11 @@ def annotate_scope(self, scope: Scope) -> None: scope.expression.meta["query_type"] = struct_type def _maybe_annotate(self, expression: E) -> E: - if id(expression) in self._visited: + if id(expression) in self._visited or ( + not self._overwrite_types + and expression.type + and not expression.is_type(exp.DataType.Type.UNKNOWN) + ): return expression # We've already inferred the expression's type spec = self.expression_metadata.get(expression.__class__) diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index 610833d4af..f9865725c7 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -6,7 +6,7 @@ from sqlglot.errors import OptimizeError from sqlglot.helper import while_changing from sqlglot.optimizer.scope import find_all_in_scope -from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort +from sqlglot.optimizer.simplify import Simplifier, flatten logger = logging.getLogger("sqlglot") @@ -28,6 +28,8 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = Returns: sqlglot.Expression: normalized expression """ + simplifier = Simplifier(annotate_new_expressions=False) + for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))): if isinstance(node, exp.Connector): if normalized(node, dnf=dnf): @@ -35,7 +37,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = root = node is expression original = node.copy() - node.transform(rewrite_between, copy=False) + node.transform(simplifier.rewrite_between, copy=False) distance = normalization_distance(node, dnf=dnf, max_=max_distance) if distance > max_distance: @@ -46,7 +48,10 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = try: node = node.replace( - while_changing(node, lambda e: distributive_law(e, dnf, max_distance)) + while_changing( + node, + lambda e: distributive_law(e, dnf, max_distance, simplifier=simplifier), + ) ) except OptimizeError as e: logger.info(e) @@ -146,7 +151,7 @@ def _predicate_lengths(expression, dnf, max_=float("inf"), depth=0): yield from _predicate_lengths(right, dnf, max_, depth) -def distributive_law(expression, dnf, max_distance): +def distributive_law(expression, dnf, max_distance, simplifier=None): """ x OR (y AND z) -> (x OR y) AND (x OR z) (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) @@ -168,32 +173,34 @@ def distributive_law(expression, dnf, max_distance): from_func = exp.and_ if from_exp == exp.And else exp.or_ to_func = exp.and_ if to_exp == exp.And else exp.or_ + simplifier = simplifier or Simplifier(annotate_new_expressions=False) + if isinstance(a, to_exp) and isinstance(b, to_exp): if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))): - return _distribute(a, b, from_func, to_func) - return _distribute(b, a, from_func, to_func) + return _distribute(a, b, from_func, to_func, simplifier) + return _distribute(b, a, from_func, to_func, simplifier) if isinstance(a, to_exp): - return _distribute(b, a, from_func, to_func) + return _distribute(b, a, from_func, to_func, simplifier) if isinstance(b, to_exp): - return _distribute(a, b, from_func, to_func) + return _distribute(a, b, from_func, to_func, simplifier) return expression -def _distribute(a, b, from_func, to_func): +def _distribute(a, b, from_func, to_func, simplifier): if isinstance(a, exp.Connector): exp.replace_children( a, lambda c: to_func( - uniq_sort(flatten(from_func(c, b.left))), - uniq_sort(flatten(from_func(c, b.right))), + simplifier.uniq_sort(flatten(from_func(c, b.left))), + simplifier.uniq_sort(flatten(from_func(c, b.right))), copy=False, ), ) else: a = to_func( - uniq_sort(flatten(from_func(a, b.left))), - uniq_sort(flatten(from_func(a, b.right))), + simplifier.uniq_sort(flatten(from_func(a, b.left))), + simplifier.uniq_sort(flatten(from_func(a, b.right))), copy=False, ) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 6520866151..280185b461 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -6,34 +6,25 @@ import itertools import typing as t from collections import deque, defaultdict -from functools import reduce +from functools import reduce, wraps import sqlglot from sqlglot import Dialect, exp from sqlglot.helper import first, merge_ranges, while_changing +from sqlglot.optimizer.annotate_types import TypeAnnotator from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope +from sqlglot.schema import ensure_schema if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType + DateRange = t.Tuple[datetime.date, datetime.date] DateTruncBinaryTransform = t.Callable[ [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression] ] -logger = logging.getLogger("sqlglot") - -# Final means that an expression should not be simplified -FINAL = "final" - -# Value ranges for byte-sized signed/unsigned integers -TINYINT_MIN = -128 -TINYINT_MAX = 127 -UTINYINT_MIN = 0 -UTINYINT_MAX = 255 - -class UnsupportedUnit(Exception): - pass +logger = logging.getLogger("sqlglot") def simplify( @@ -60,96 +51,15 @@ def simplify( Returns: sqlglot.Expression: simplified expression """ + return Simplifier(dialect=dialect).simplify( + expression, + constant_propagation=constant_propagation, + coalesce_simplification=coalesce_simplification, + ) - dialect = Dialect.get_or_raise(dialect) - - def _simplify(expression): - pre_transformation_stack = [expression] - post_transformation_stack = [] - - while pre_transformation_stack: - node = pre_transformation_stack.pop() - - if node.meta.get(FINAL): - continue - - # group by expressions cannot be simplified, for example - # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 - # the projection must exactly match the group by key - group = node.args.get("group") - - if group and hasattr(node, "selects"): - groups = set(group.expressions) - group.meta[FINAL] = True - - for s in node.selects: - for n in s.walk(): - if n in groups: - s.meta[FINAL] = True - break - - having = node.args.get("having") - if having: - for n in having.walk(): - if n in groups: - having.meta[FINAL] = True - break - - parent = node.parent - root = node is expression - - new_node = rewrite_between(node) - new_node = uniq_sort(new_node, root) - new_node = absorb_and_eliminate(new_node, root) - new_node = simplify_concat(new_node) - new_node = simplify_conditionals(new_node) - - if constant_propagation: - new_node = propagate_constants(new_node, root) - - if new_node is not node: - node.replace(new_node) - - pre_transformation_stack.extend( - n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL) - ) - post_transformation_stack.append((new_node, parent)) - - while post_transformation_stack: - node, parent = post_transformation_stack.pop() - root = node is expression - - # Resets parent, arg_key, index pointers– this is needed because some of the - # previous transformations mutate the AST, leading to an inconsistent state - for k, v in tuple(node.args.items()): - node.set(k, v) - - # Post-order transformations - new_node = simplify_not(node, dialect) - new_node = flatten(new_node) - new_node = simplify_connectors(new_node, root) - new_node = remove_complements(new_node, root) - - if coalesce_simplification: - new_node = simplify_coalesce(new_node, dialect) - - new_node.parent = parent - - new_node = simplify_literals(new_node, root) - new_node = simplify_equality(new_node) - new_node = simplify_parens(new_node, dialect) - new_node = simplify_datetrunc(new_node, dialect) - new_node = sort_comparison(new_node) - new_node = simplify_startswith(new_node) - - if new_node is not node: - node.replace(new_node) - - return new_node - expression = while_changing(expression, _simplify) - remove_where_true(expression) - return expression +class UnsupportedUnit(Exception): + pass def catch(*exceptions): @@ -167,91 +77,30 @@ def wrapped(expression, *args, **kwargs): return decorator -def rewrite_between(expression: exp.Expression) -> exp.Expression: - """Rewrite x between y and z to x >= y AND x <= z. - - This is done because comparison simplification is only done on lt/lte/gt/gte. - """ - if isinstance(expression, exp.Between): - negate = isinstance(expression.parent, exp.Not) +def annotate_types_on_change(func): + @wraps(func) + def _func(self, expression: exp.Expression, *args, **kwargs) -> t.Optional[exp.Expression]: + new_expression = func(self, expression, *args, **kwargs) - expression = exp.and_( - exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), - exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), - copy=False, - ) + if new_expression is None: + return new_expression - if negate: - expression = exp.paren(expression, copy=False) + if self.annotate_new_expressions and expression != new_expression: + self._annotator.clear() - return expression + # We annotate this to ensure new children nodes are also annotated + new_expression = self._annotator.annotate( + expression=new_expression, + annotate_scope=False, + ) + # Whatever expression the original expression is transformed into needs to preserve + # the original type, otherwise the simplification could result in a different schema + new_expression.type = expression.type -COMPLEMENT_COMPARISONS = { - exp.LT: exp.GTE, - exp.GT: exp.LTE, - exp.LTE: exp.GT, - exp.GTE: exp.LT, - exp.EQ: exp.NEQ, - exp.NEQ: exp.EQ, -} + return new_expression -COMPLEMENT_SUBQUERY_PREDICATES = { - exp.All: exp.Any, - exp.Any: exp.All, -} - - -def simplify_not(expression: exp.Expression, dialect: Dialect) -> exp.Expression: - """ - Demorgan's Law - NOT (x OR y) -> NOT x AND NOT y - NOT (x AND y) -> NOT x OR NOT y - """ - if isinstance(expression, exp.Not): - this = expression.this - if is_null(this): - return exp.null() - if this.__class__ in COMPLEMENT_COMPARISONS: - right = this.expression - complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__) - if complement_subquery_predicate: - right = complement_subquery_predicate(this=right.this) - - return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right) - if isinstance(this, exp.Paren): - condition = this.unnest() - if isinstance(condition, exp.And): - return exp.paren( - exp.or_( - exp.not_(condition.left, copy=False), - exp.not_(condition.right, copy=False), - copy=False, - ), - copy=False, - ) - if isinstance(condition, exp.Or): - return exp.paren( - exp.and_( - exp.not_(condition.left, copy=False), - exp.not_(condition.right, copy=False), - copy=False, - ), - copy=False, - ) - if is_null(condition): - return exp.null() - if always_true(this): - return exp.false() - if is_false(this): - return exp.true() - if isinstance(this, exp.Not) and dialect.SAFE_TO_ELIMINATE_DOUBLE_NEGATION: - inner = this.this - if inner.is_type(exp.DataType.Type.BOOLEAN) or isinstance(inner, exp.Predicate): - # double negation - # NOT NOT x -> x, if x is BOOLEAN type - return inner - return expression + return _func def flatten(expression): @@ -267,247 +116,45 @@ def flatten(expression): return expression -def simplify_connectors(expression, root=True): - def _simplify_connectors(expression, left, right): - if isinstance(expression, exp.And): - if is_false(left) or is_false(right): - return exp.false() - if is_zero(left) or is_zero(right): - return exp.false() - if ( - (is_null(left) and is_null(right)) - or (is_null(left) and always_true(right)) - or (always_true(left) and is_null(right)) - ): - return exp.null() - if always_true(left) and always_true(right): - return exp.true() - if always_true(left): - return right - if always_true(right): - return left - return _simplify_comparison(expression, left, right) - elif isinstance(expression, exp.Or): - if always_true(left) or always_true(right): - return exp.true() - if ( - (is_null(left) and is_null(right)) - or (is_null(left) and always_false(right)) - or (always_false(left) and is_null(right)) - ): - return exp.null() - if is_false(left): - return right - if is_false(right): - return left - return _simplify_comparison(expression, left, right, or_=True) - - if isinstance(expression, exp.Connector): - return _flat_simplify(expression, _simplify_connectors, root) - return expression - - -LT_LTE = (exp.LT, exp.LTE) -GT_GTE = (exp.GT, exp.GTE) - -COMPARISONS = ( - *LT_LTE, - *GT_GTE, - exp.EQ, - exp.NEQ, - exp.Is, -) - -INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { - exp.LT: exp.GT, - exp.GT: exp.LT, - exp.LTE: exp.GTE, - exp.GTE: exp.LTE, -} - -NONDETERMINISTIC = (exp.Rand, exp.Randn) -AND_OR = (exp.And, exp.Or) - - -def _simplify_comparison(expression, left, right, or_=False): - if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): - ll, lr = left.args.values() - rl, rr = right.args.values() - - largs = {ll, lr} - rargs = {rl, rr} - - matching = largs & rargs - columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)} - - if matching and columns: - try: - l = first(largs - columns) - r = first(rargs - columns) - except StopIteration: - return expression - - if l.is_number and r.is_number: - l = l.to_py() - r = r.to_py() - elif l.is_string and r.is_string: - l = l.name - r = r.name - else: - l = extract_date(l) - if not l: - return None - r = extract_date(r) - if not r: - return None - # python won't compare date and datetime, but many engines will upcast - l, r = cast_as_datetime(l), cast_as_datetime(r) - - for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): - if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): - return left if (av > bv if or_ else av <= bv) else right - if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): - return left if (av < bv if or_ else av >= bv) else right - - # we can't ever shortcut to true because the column could be null - if not or_: - if isinstance(a, exp.LT) and isinstance(b, GT_GTE): - if av <= bv: - return exp.false() - elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): - if av >= bv: - return exp.false() - elif isinstance(a, exp.EQ): - if isinstance(b, exp.LT): - return exp.false() if av >= bv else a - if isinstance(b, exp.LTE): - return exp.false() if av > bv else a - if isinstance(b, exp.GT): - return exp.false() if av <= bv else a - if isinstance(b, exp.GTE): - return exp.false() if av < bv else a - if isinstance(b, exp.NEQ): - return exp.false() if av == bv else a - return None - - -def remove_complements(expression, root=True): - """ - Removing complements. - - A AND NOT A -> FALSE - A OR NOT A -> TRUE - """ - if isinstance(expression, AND_OR) and (root or not expression.same_parent): - ops = set(expression.flatten()) - for op in ops: - if isinstance(op, exp.Not) and op.this in ops: - return exp.false() if isinstance(expression, exp.And) else exp.true() - - return expression - +def simplify_parens(expression: exp.Expression, dialect: DialectType) -> exp.Expression: + if not isinstance(expression, exp.Paren): + return expression -def uniq_sort(expression, root=True): - """ - Uniq and sort a connector. + this = expression.this + parent = expression.parent + parent_is_predicate = isinstance(parent, exp.Predicate) - C AND A AND B AND B -> A AND B AND C - """ - if isinstance(expression, exp.Connector) and (root or not expression.same_parent): - flattened = tuple(expression.flatten()) - - if isinstance(expression, exp.Xor): - result_func = exp.xor - # Do not deduplicate XOR as A XOR A != A if A == True - deduped = None - arr = tuple((gen(e), e) for e in flattened) - else: - result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ - deduped = {gen(e): e for e in flattened} - arr = tuple(deduped.items()) - - # check if the operands are already sorted, if not sort them - # A AND C AND B -> A AND B AND C - for i, (sql, e) in enumerate(arr[1:]): - if sql < arr[i][0]: - expression = result_func(*(e for _, e in sorted(arr)), copy=False) - break - else: - # we didn't have to sort but maybe we need to dedup - if deduped and len(deduped) < len(flattened): - expression = result_func(*deduped.values(), copy=False) + if isinstance(this, exp.Select): + return expression - return expression + if isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)): + return expression + # Handle risingwave struct columns + # see https://docs.risingwave.com/sql/data-types/struct#retrieve-data-in-a-struct + if ( + dialect == "risingwave" + and isinstance(parent, exp.Dot) + and (isinstance(parent.right, (exp.Identifier, exp.Star))) + ): + return expression -def absorb_and_eliminate(expression, root=True): - """ - absorption: - A AND (A OR B) -> A - A OR (A AND B) -> A - A AND (NOT A OR B) -> A AND B - A OR (NOT A AND B) -> A OR B - elimination: - (A AND B) OR (A AND NOT B) -> A - (A OR B) AND (A OR NOT B) -> A - """ - if isinstance(expression, AND_OR) and (root or not expression.same_parent): - kind = exp.Or if isinstance(expression, exp.And) else exp.And - - ops = tuple(expression.flatten()) - - # Initialize lookup tables: - # Set of all operands, used to find complements for absorption. - op_set = set() - # Sub-operands, used to find subsets for absorption. - subops = defaultdict(list) - # Pairs of complements, used for elimination. - pairs = defaultdict(list) - - # Populate the lookup tables - for op in ops: - op_set.add(op) - - if not isinstance(op, kind): - # In cases like: A OR (A AND B) - # Subop will be: ^ - subops[op].append({op}) - continue - - # In cases like: (A AND B) OR (A AND B AND C) - # Subops will be: ^ ^ - subset = set(op.flatten()) - for i in subset: - subops[i].append(subset) - - a, b = op.unnest_operands() - if isinstance(a, exp.Not): - pairs[frozenset((a.this, b))].append((op, b)) - if isinstance(b, exp.Not): - pairs[frozenset((a, b.this))].append((op, a)) - - for op in ops: - if not isinstance(op, kind): - continue - - a, b = op.unnest_operands() - - # Absorb - if isinstance(a, exp.Not) and a.this in op_set: - a.replace(exp.true() if kind == exp.And else exp.false()) - continue - if isinstance(b, exp.Not) and b.this in op_set: - b.replace(exp.true() if kind == exp.And else exp.false()) - continue - superset = set(op.flatten()) - if any(any(subset < superset for subset in subops[i]) for i in superset): - op.replace(exp.false() if kind == exp.And else exp.true()) - continue - - # Eliminate - for other, complement in pairs[frozenset((a, b))]: - op.replace(complement) - other.replace(complement) + if ( + not isinstance(parent, (exp.Condition, exp.Binary)) + or isinstance(parent, exp.Paren) + or ( + not isinstance(this, exp.Binary) + and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate) + ) + or ( + isinstance(this, exp.Predicate) + and not (parent_is_predicate or isinstance(parent, exp.Neg)) + ) + or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) + or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) + or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) + ): + return this return expression @@ -551,20 +198,6 @@ def propagate_constants(expression, root=True): return expression -INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { - exp.DateAdd: exp.Sub, - exp.DateSub: exp.Add, - exp.DatetimeAdd: exp.Sub, - exp.DatetimeSub: exp.Add, -} - -INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { - **INVERSE_DATE_OPS, - exp.Add: exp.Sub, - exp.Sub: exp.Add, -} - - def _is_number(expression: exp.Expression) -> bool: return expression.is_number @@ -573,210 +206,6 @@ def _is_interval(expression: exp.Expression) -> bool: return isinstance(expression, exp.Interval) and extract_interval(expression) is not None -@catch(ModuleNotFoundError, UnsupportedUnit) -def simplify_equality(expression: exp.Expression) -> exp.Expression: - """ - Use the subtraction and addition properties of equality to simplify expressions: - - x + 1 = 3 becomes x = 2 - - There are two binary operations in the above expression: + and = - Here's how we reference all the operands in the code below: - - l r - x + 1 = 3 - a b - """ - if isinstance(expression, COMPARISONS): - l, r = expression.left, expression.right - - if l.__class__ not in INVERSE_OPS: - return expression - - if r.is_number: - a_predicate = _is_number - b_predicate = _is_number - elif _is_date_literal(r): - a_predicate = _is_date_literal - b_predicate = _is_interval - else: - return expression - - if l.__class__ in INVERSE_DATE_OPS: - l = t.cast(exp.IntervalOp, l) - a = l.this - b = l.interval() - else: - l = t.cast(exp.Binary, l) - a, b = l.left, l.right - - if not a_predicate(a) and b_predicate(b): - pass - elif not a_predicate(b) and b_predicate(a): - a, b = b, a - else: - return expression - - return expression.__class__( - this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b) - ) - return expression - - -def simplify_literals(expression, root=True): - if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): - return _flat_simplify(expression, _simplify_binary, root) - - if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg): - return expression.this.this - - if type(expression) in INVERSE_DATE_OPS: - return _simplify_binary(expression, expression.this, expression.interval()) or expression - - return expression - - -NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ) - - -def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression: - if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast): - this = _simplify_integer_cast(expr.this) - else: - this = expr.this - - if isinstance(expr, exp.Cast) and this.is_int: - num = this.to_py() - - # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any - # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is - # engine-dependent - if ( - TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES - ) or ( - UTINYINT_MIN <= num <= UTINYINT_MAX - and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES - ): - return this - - return expr - - -def _simplify_binary(expression, a, b): - if isinstance(expression, COMPARISONS): - a = _simplify_integer_cast(a) - b = _simplify_integer_cast(b) - - if isinstance(expression, exp.Is): - if isinstance(b, exp.Not): - c = b.this - not_ = True - else: - c = b - not_ = False - - if is_null(c): - if isinstance(a, exp.Literal): - return exp.true() if not_ else exp.false() - if is_null(a): - return exp.false() if not_ else exp.true() - elif isinstance(expression, NULL_OK): - return None - elif is_null(a) or is_null(b): - return exp.null() - - if a.is_number and b.is_number: - num_a = a.to_py() - num_b = b.to_py() - - if isinstance(expression, exp.Add): - return exp.Literal.number(num_a + num_b) - if isinstance(expression, exp.Mul): - return exp.Literal.number(num_a * num_b) - - # We only simplify Sub, Div if a and b have the same parent because they're not associative - if isinstance(expression, exp.Sub): - return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None - if isinstance(expression, exp.Div): - # engines have differing int div behavior so intdiv is not safe - if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent: - return None - return exp.Literal.number(num_a / num_b) - - boolean = eval_boolean(expression, num_a, num_b) - - if boolean: - return boolean - elif a.is_string and b.is_string: - boolean = eval_boolean(expression, a.this, b.this) - - if boolean: - return boolean - elif _is_date_literal(a) and isinstance(b, exp.Interval): - date, b = extract_date(a), extract_interval(b) - if date and b: - if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)): - return date_literal(date + b, extract_type(a)) - if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)): - return date_literal(date - b, extract_type(a)) - elif isinstance(a, exp.Interval) and _is_date_literal(b): - a, date = extract_interval(a), extract_date(b) - # you cannot subtract a date from an interval - if a and b and isinstance(expression, exp.Add): - return date_literal(a + date, extract_type(b)) - elif _is_date_literal(a) and _is_date_literal(b): - if isinstance(expression, exp.Predicate): - a, b = extract_date(a), extract_date(b) - boolean = eval_boolean(expression, a, b) - if boolean: - return boolean - - return None - - -def simplify_parens(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression: - if not isinstance(expression, exp.Paren): - return expression - - this = expression.this - parent = expression.parent - parent_is_predicate = isinstance(parent, exp.Predicate) - - if isinstance(this, exp.Select): - return expression - - if isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)): - return expression - - # Handle risingwave struct columns - # see https://docs.risingwave.com/sql/data-types/struct#retrieve-data-in-a-struct - if ( - dialect == "risingwave" - and isinstance(parent, exp.Dot) - and (isinstance(parent.right, (exp.Identifier, exp.Star))) - ): - return expression - - if ( - not isinstance(parent, (exp.Condition, exp.Binary)) - or isinstance(parent, exp.Paren) - or ( - not isinstance(this, exp.Binary) - and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate) - ) - or ( - isinstance(this, exp.Predicate) - and not (parent_is_predicate or isinstance(parent, exp.Neg)) - ) - or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) - or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) - or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) - ): - return this - - return expression - - def _is_nonnull_constant(expression: exp.Expression) -> bool: return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression) @@ -785,164 +214,6 @@ def _is_constant(expression: exp.Expression) -> bool: return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression) -def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression: - # COALESCE(x) -> x - if ( - isinstance(expression, exp.Coalesce) - and (not expression.expressions or _is_nonnull_constant(expression.this)) - # COALESCE is also used as a Spark partitioning hint - and not isinstance(expression.parent, exp.Hint) - ): - return expression.this - - # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift, - # because they are not always equivalent. For example, if `x` is `NULL` and it comes - # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE` - if dialect == "redshift": - return expression - - if not isinstance(expression, COMPARISONS): - return expression - - if isinstance(expression.left, exp.Coalesce): - coalesce = expression.left - other = expression.right - elif isinstance(expression.right, exp.Coalesce): - coalesce = expression.right - other = expression.left - else: - return expression - - # This transformation is valid for non-constants, - # but it really only does anything if they are both constants. - if not _is_constant(other): - return expression - - # Find the first constant arg - for arg_index, arg in enumerate(coalesce.expressions): - if _is_constant(arg): - break - else: - return expression - - coalesce.set("expressions", coalesce.expressions[:arg_index]) - - # Remove the COALESCE function. This is an optimization, skipping a simplify iteration, - # since we already remove COALESCE at the top of this function. - coalesce = coalesce if coalesce.expressions else coalesce.this - - # This expression is more complex than when we started, but it will get simplified further - return exp.paren( - exp.or_( - exp.and_( - coalesce.is_(exp.null()).not_(copy=False), - expression.copy(), - copy=False, - ), - exp.and_( - coalesce.is_(exp.null()), - type(expression)(this=arg.copy(), expression=other.copy()), - copy=False, - ), - copy=False, - ) - ) - - -CONCATS = (exp.Concat, exp.DPipe) - - -def simplify_concat(expression): - """Reduces all groups that contain string literals by concatenating them.""" - if not isinstance(expression, CONCATS) or ( - # We can't reduce a CONCAT_WS call if we don't statically know the separator - isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string - ): - return expression - - if isinstance(expression, exp.ConcatWs): - sep_expr, *expressions = expression.expressions - sep = sep_expr.name - concat_type = exp.ConcatWs - args = {} - else: - expressions = expression.expressions - sep = "" - concat_type = exp.Concat - args = { - "safe": expression.args.get("safe"), - "coalesce": expression.args.get("coalesce"), - } - - new_args = [] - for is_string_group, group in itertools.groupby( - expressions or expression.flatten(), lambda e: e.is_string - ): - if is_string_group: - new_args.append(exp.Literal.string(sep.join(string.name for string in group))) - else: - new_args.extend(group) - - if len(new_args) == 1 and new_args[0].is_string: - return new_args[0] - - if concat_type is exp.ConcatWs: - new_args = [sep_expr] + new_args - elif isinstance(expression, exp.DPipe): - return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args) - - return concat_type(expressions=new_args, **args) - - -def simplify_conditionals(expression): - """Simplifies expressions like IF, CASE if their condition is statically known.""" - if isinstance(expression, exp.Case): - this = expression.this - for case in expression.args["ifs"]: - cond = case.this - if this: - # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ... - cond = cond.replace(this.pop().eq(cond)) - - if always_true(cond): - return case.args["true"] - - if always_false(cond): - case.pop() - if not expression.args["ifs"]: - return expression.args.get("default") or exp.null() - elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case): - if always_true(expression.this): - return expression.args["true"] - if always_false(expression.this): - return expression.args.get("false") or exp.null() - - return expression - - -def simplify_startswith(expression: exp.Expression) -> exp.Expression: - """ - Reduces a prefix check to either TRUE or FALSE if both the string and the - prefix are statically known. - - Example: - >>> from sqlglot import parse_one - >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() - 'TRUE' - """ - if ( - isinstance(expression, exp.StartsWith) - and expression.this.is_string - and expression.expression.is_string - ): - return exp.convert(expression.name.startswith(expression.expression.name)) - - return expression - - -DateRange = t.Tuple[datetime.date, datetime.date] - - def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]: """ Get the date range for a DATE_TRUNC equality comparison: @@ -1004,136 +275,6 @@ def _datetrunc_neq( ) -DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = { - exp.LT: lambda l, dt, u, d, t: l - < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t), - exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t), - exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t), - exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t), - exp.EQ: _datetrunc_eq, - exp.NEQ: _datetrunc_neq, -} -DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} -DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc) - - -def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool: - return isinstance(left, DATETRUNCS) and _is_date_literal(right) - - -@catch(ModuleNotFoundError, UnsupportedUnit) -def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression: - """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" - comparison = expression.__class__ - - if isinstance(expression, DATETRUNCS): - this = expression.this - trunc_type = extract_type(this) - date = extract_date(this) - if date and expression.unit: - return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type) - elif comparison not in DATETRUNC_COMPARISONS: - return expression - - if isinstance(expression, exp.Binary): - l, r = expression.left, expression.right - - if not _is_datetrunc_predicate(l, r): - return expression - - l = t.cast(exp.DateTrunc, l) - trunc_arg = l.this - unit = l.unit.name.lower() - date = extract_date(r) - - if not date: - return expression - - return ( - DATETRUNC_BINARY_COMPARISONS[comparison]( - trunc_arg, date, unit, dialect, extract_type(r) - ) - or expression - ) - - if isinstance(expression, exp.In): - l = expression.this - rs = expression.expressions - - if rs and all(_is_datetrunc_predicate(l, r) for r in rs): - l = t.cast(exp.DateTrunc, l) - unit = l.unit.name.lower() - - ranges = [] - for r in rs: - date = extract_date(r) - if not date: - return expression - drange = _datetrunc_range(date, unit, dialect) - if drange: - ranges.append(drange) - - if not ranges: - return expression - - ranges = merge_ranges(ranges) - target_type = extract_type(*rs) - - return exp.or_( - *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False - ) - - return expression - - -def sort_comparison(expression: exp.Expression) -> exp.Expression: - if expression.__class__ in COMPLEMENT_COMPARISONS: - l, r = expression.this, expression.expression - l_column = isinstance(l, exp.Column) - r_column = isinstance(r, exp.Column) - l_const = _is_constant(l) - r_const = _is_constant(r) - - if ( - (l_column and not r_column) - or (r_const and not l_const) - or isinstance(r, exp.SubqueryPredicate) - ): - return expression - if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)): - return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)( - this=r, expression=l - ) - return expression - - -# CROSS joins result in an empty table if the right table is empty. -# So we can only simplify certain types of joins to CROSS. -# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x -JOINS = { - ("", ""), - ("", "INNER"), - ("RIGHT", ""), - ("RIGHT", "OUTER"), -} - - -def remove_where_true(expression): - for where in expression.find_all(exp.Where): - if always_true(where.this): - where.pop() - for join in expression.find_all(exp.Join): - if ( - always_true(join.args.get("on")) - and not join.args.get("using") - and not join.args.get("method") - and (join.side, join.kind) in JOINS - ): - join.args["on"].pop() - join.set("side", None) - join.set("kind", "CROSS") - - def always_true(expression): return (isinstance(expression, exp.Boolean) and expression.this) or ( isinstance(expression, exp.Literal) and expression.is_number and not is_zero(expression) @@ -1318,30 +459,934 @@ def boolean_literal(condition): return exp.true() if condition else exp.false() -def _flat_simplify(expression, simplifier, root=True): - if root or not expression.same_parent: - operands = [] - queue = deque(expression.flatten(unnest=False)) - size = len(queue) +class Simplifier: + def __init__(self, dialect: DialectType = None, annotate_new_expressions: bool = True): + self.dialect = Dialect.get_or_raise(dialect) + self.annotate_new_expressions = annotate_new_expressions + + self._annotator: TypeAnnotator = TypeAnnotator( + schema=ensure_schema(None, dialect=self.dialect), overwrite_types=False + ) + + # Final means that an expression should not be simplified + FINAL = "final" + + # Value ranges for byte-sized signed/unsigned integers + TINYINT_MIN = -128 + TINYINT_MAX = 127 + UTINYINT_MIN = 0 + UTINYINT_MAX = 255 + + COMPLEMENT_COMPARISONS = { + exp.LT: exp.GTE, + exp.GT: exp.LTE, + exp.LTE: exp.GT, + exp.GTE: exp.LT, + exp.EQ: exp.NEQ, + exp.NEQ: exp.EQ, + } + + COMPLEMENT_SUBQUERY_PREDICATES = { + exp.All: exp.Any, + exp.Any: exp.All, + } + + LT_LTE = (exp.LT, exp.LTE) + GT_GTE = (exp.GT, exp.GTE) + + COMPARISONS = ( + *LT_LTE, + *GT_GTE, + exp.EQ, + exp.NEQ, + exp.Is, + ) + + INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { + exp.LT: exp.GT, + exp.GT: exp.LT, + exp.LTE: exp.GTE, + exp.GTE: exp.LTE, + } + + NONDETERMINISTIC = (exp.Rand, exp.Randn) + AND_OR = (exp.And, exp.Or) + + INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { + exp.DateAdd: exp.Sub, + exp.DateSub: exp.Add, + exp.DatetimeAdd: exp.Sub, + exp.DatetimeSub: exp.Add, + } + + INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { + **INVERSE_DATE_OPS, + exp.Add: exp.Sub, + exp.Sub: exp.Add, + } + + NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ) + + CONCATS = (exp.Concat, exp.DPipe) + + DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = { + exp.LT: lambda l, dt, u, d, t: l + < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t), + exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t), + exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t), + exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t), + exp.EQ: _datetrunc_eq, + exp.NEQ: _datetrunc_neq, + } + DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} + DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc) + + # CROSS joins result in an empty table if the right table is empty. + # So we can only simplify certain types of joins to CROSS. + # Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x + JOINS = { + ("", ""), + ("", "INNER"), + ("RIGHT", ""), + ("RIGHT", "OUTER"), + } + + def simplify( + self, + expression: exp.Expression, + constant_propagation: bool = False, + coalesce_simplification: bool = False, + ): + def _simplify(expression): + pre_transformation_stack = [expression] + post_transformation_stack = [] + + while pre_transformation_stack: + node = pre_transformation_stack.pop() + + if node.meta.get(self.FINAL): + continue + + # group by expressions cannot be simplified, for example + # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 + # the projection must exactly match the group by key + group = node.args.get("group") + + if group and hasattr(node, "selects"): + groups = set(group.expressions) + group.meta[self.FINAL] = True + + for s in node.selects: + for n in s.walk(): + if n in groups: + s.meta[self.FINAL] = True + break + + having = node.args.get("having") + if having: + for n in having.walk(): + if n in groups: + having.meta[self.FINAL] = True + break + + parent = node.parent + root = node is expression + + new_node = self.rewrite_between(node) + new_node = self.uniq_sort(new_node, root) + new_node = self.absorb_and_eliminate(new_node, root) + new_node = self.simplify_concat(new_node) + new_node = self.simplify_conditionals(new_node) + + if constant_propagation: + new_node = propagate_constants(new_node, root) + + if new_node is not node: + node.replace(new_node) + + pre_transformation_stack.extend( + n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(self.FINAL) + ) + post_transformation_stack.append((new_node, parent)) - while queue: - a = queue.popleft() + while post_transformation_stack: + node, parent = post_transformation_stack.pop() + root = node is expression - for b in queue: - result = simplifier(expression, a, b) + # Resets parent, arg_key, index pointers– this is needed because some of the + # previous transformations mutate the AST, leading to an inconsistent state + for k, v in tuple(node.args.items()): + node.set(k, v) + + # Post-order transformations + new_node = self.simplify_not(node) + new_node = flatten(new_node) + new_node = self.simplify_connectors(new_node, root) + new_node = self.remove_complements(new_node, root) + + if coalesce_simplification: + new_node = self.simplify_coalesce(new_node) + new_node.parent = parent + + new_node = self.simplify_literals(new_node, root) + new_node = self.simplify_equality(new_node) + new_node = simplify_parens(new_node, dialect=self.dialect) + new_node = self.simplify_datetrunc(new_node) + new_node = self.sort_comparison(new_node) + new_node = self.simplify_startswith(new_node) + + if new_node is not node: + node.replace(new_node) + + return new_node + + expression = while_changing(expression, _simplify) + self.remove_where_true(expression) + return expression + + @annotate_types_on_change + def rewrite_between(self, expression: exp.Expression) -> exp.Expression: + """Rewrite x between y and z to x >= y AND x <= z. + + This is done because comparison simplification is only done on lt/lte/gt/gte. + """ + if isinstance(expression, exp.Between): + negate = isinstance(expression.parent, exp.Not) + + expression = exp.and_( + exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), + exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), + copy=False, + ) + + if negate: + expression = exp.paren(expression, copy=False) + + return expression - if result and result is not expression: - queue.remove(b) - queue.appendleft(result) + @annotate_types_on_change + def simplify_not(self, expression: exp.Expression) -> exp.Expression: + """ + Demorgan's Law + NOT (x OR y) -> NOT x AND NOT y + NOT (x AND y) -> NOT x OR NOT y + """ + if isinstance(expression, exp.Not): + this = expression.this + if is_null(this): + return exp.null() + if this.__class__ in self.COMPLEMENT_COMPARISONS: + right = this.expression + complement_subquery_predicate = self.COMPLEMENT_SUBQUERY_PREDICATES.get( + right.__class__ + ) + if complement_subquery_predicate: + right = complement_subquery_predicate(this=right.this) + + return self.COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right) + if isinstance(this, exp.Paren): + condition = this.unnest() + if isinstance(condition, exp.And): + return exp.paren( + exp.or_( + exp.not_(condition.left, copy=False), + exp.not_(condition.right, copy=False), + copy=False, + ), + copy=False, + ) + if isinstance(condition, exp.Or): + return exp.paren( + exp.and_( + exp.not_(condition.left, copy=False), + exp.not_(condition.right, copy=False), + copy=False, + ), + copy=False, + ) + if is_null(condition): + return exp.null() + if always_true(this): + return exp.false() + if is_false(this): + return exp.true() + if isinstance(this, exp.Not) and self.dialect.SAFE_TO_ELIMINATE_DOUBLE_NEGATION: + inner = this.this + if inner.is_type(exp.DataType.Type.BOOLEAN): + # double negation + # NOT NOT x -> x, if x is BOOLEAN type + return inner + return expression + + @annotate_types_on_change + def simplify_connectors(self, expression, root=True): + def _simplify_connectors(expression, left, right): + if isinstance(expression, exp.And): + if is_false(left) or is_false(right): + return exp.false() + if is_zero(left) or is_zero(right): + return exp.false() + if ( + (is_null(left) and is_null(right)) + or (is_null(left) and always_true(right)) + or (always_true(left) and is_null(right)) + ): + return exp.null() + if always_true(left) and always_true(right): + return exp.true() + if always_true(left) and right.is_type(exp.DataType.Type.BOOLEAN): + return right + if always_true(right) and left.is_type(exp.DataType.Type.BOOLEAN): + return left + return self._simplify_comparison(expression, left, right) + elif isinstance(expression, exp.Or): + if always_true(left) or always_true(right): + return exp.true() + if ( + (is_null(left) and is_null(right)) + or (is_null(left) and always_false(right)) + or (always_false(left) and is_null(right)) + ): + return exp.null() + if is_false(left) and right.is_type(exp.DataType.Type.BOOLEAN): + return right + if is_false(right) and left.is_type(exp.DataType.Type.BOOLEAN): + return left + return self._simplify_comparison(expression, left, right, or_=True) + + if isinstance(expression, exp.Connector): + return self._flat_simplify(expression, _simplify_connectors, root) + return expression + + @annotate_types_on_change + def _simplify_comparison(self, expression, left, right, or_=False): + if isinstance(left, self.COMPARISONS) and isinstance(right, self.COMPARISONS): + ll, lr = left.args.values() + rl, rr = right.args.values() + + largs = {ll, lr} + rargs = {rl, rr} + + matching = largs & rargs + columns = { + m for m in matching if not _is_constant(m) and not m.find(*self.NONDETERMINISTIC) + } + + if matching and columns: + try: + l = first(largs - columns) + r = first(rargs - columns) + except StopIteration: + return expression + + if l.is_number and r.is_number: + l = l.to_py() + r = r.to_py() + elif l.is_string and r.is_string: + l = l.name + r = r.name + else: + l = extract_date(l) + if not l: + return None + r = extract_date(r) + if not r: + return None + # python won't compare date and datetime, but many engines will upcast + l, r = cast_as_datetime(l), cast_as_datetime(r) + + for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): + if isinstance(a, self.LT_LTE) and isinstance(b, self.LT_LTE): + return left if (av > bv if or_ else av <= bv) else right + if isinstance(a, self.GT_GTE) and isinstance(b, self.GT_GTE): + return left if (av < bv if or_ else av >= bv) else right + + # we can't ever shortcut to true because the column could be null + if not or_: + if isinstance(a, exp.LT) and isinstance(b, self.GT_GTE): + if av <= bv: + return exp.false() + elif isinstance(a, exp.GT) and isinstance(b, self.LT_LTE): + if av >= bv: + return exp.false() + elif isinstance(a, exp.EQ): + if isinstance(b, exp.LT): + return exp.false() if av >= bv else a + if isinstance(b, exp.LTE): + return exp.false() if av > bv else a + if isinstance(b, exp.GT): + return exp.false() if av <= bv else a + if isinstance(b, exp.GTE): + return exp.false() if av < bv else a + if isinstance(b, exp.NEQ): + return exp.false() if av == bv else a + return None + + @annotate_types_on_change + def remove_complements(self, expression, root=True): + """ + Removing complements. + + A AND NOT A -> FALSE + A OR NOT A -> TRUE + """ + if isinstance(expression, self.AND_OR) and (root or not expression.same_parent): + ops = set(expression.flatten()) + for op in ops: + if isinstance(op, exp.Not) and op.this in ops: + return exp.false() if isinstance(expression, exp.And) else exp.true() + + return expression + + @annotate_types_on_change + def uniq_sort(self, expression, root=True): + """ + Uniq and sort a connector. + + C AND A AND B AND B -> A AND B AND C + """ + if isinstance(expression, exp.Connector) and (root or not expression.same_parent): + flattened = tuple(expression.flatten()) + + if isinstance(expression, exp.Xor): + result_func = exp.xor + # Do not deduplicate XOR as A XOR A != A if A == True + deduped = None + arr = tuple((gen(e), e) for e in flattened) + else: + result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ + deduped = {gen(e): e for e in flattened} + arr = tuple(deduped.items()) + + # check if the operands are already sorted, if not sort them + # A AND C AND B -> A AND B AND C + for i, (sql, e) in enumerate(arr[1:]): + if sql < arr[i][0]: + expression = result_func(*(e for _, e in sorted(arr)), copy=False) break else: - operands.append(a) + # we didn't have to sort but maybe we need to dedup + if deduped and len(deduped) < len(flattened): + expression = result_func(*deduped.values(), copy=False) + + return expression + + @annotate_types_on_change + def absorb_and_eliminate(self, expression, root=True): + """ + absorption: + A AND (A OR B) -> A + A OR (A AND B) -> A + A AND (NOT A OR B) -> A AND B + A OR (NOT A AND B) -> A OR B + elimination: + (A AND B) OR (A AND NOT B) -> A + (A OR B) AND (A OR NOT B) -> A + """ + if isinstance(expression, self.AND_OR) and (root or not expression.same_parent): + kind = exp.Or if isinstance(expression, exp.And) else exp.And + + ops = tuple(expression.flatten()) + + # Initialize lookup tables: + # Set of all operands, used to find complements for absorption. + op_set = set() + # Sub-operands, used to find subsets for absorption. + subops = defaultdict(list) + # Pairs of complements, used for elimination. + pairs = defaultdict(list) + + # Populate the lookup tables + for op in ops: + op_set.add(op) + + if not isinstance(op, kind): + # In cases like: A OR (A AND B) + # Subop will be: ^ + subops[op].append({op}) + continue + + # In cases like: (A AND B) OR (A AND B AND C) + # Subops will be: ^ ^ + subset = set(op.flatten()) + for i in subset: + subops[i].append(subset) + + a, b = op.unnest_operands() + if isinstance(a, exp.Not): + pairs[frozenset((a.this, b))].append((op, b)) + if isinstance(b, exp.Not): + pairs[frozenset((a, b.this))].append((op, a)) + + for op in ops: + if not isinstance(op, kind): + continue + + a, b = op.unnest_operands() + + # Absorb + if isinstance(a, exp.Not) and a.this in op_set: + a.replace(exp.true() if kind == exp.And else exp.false()) + continue + if isinstance(b, exp.Not) and b.this in op_set: + b.replace(exp.true() if kind == exp.And else exp.false()) + continue + superset = set(op.flatten()) + if any(any(subset < superset for subset in subops[i]) for i in superset): + op.replace(exp.false() if kind == exp.And else exp.true()) + continue + + # Eliminate + for other, complement in pairs[frozenset((a, b))]: + op.replace(complement) + other.replace(complement) + + return expression + + @annotate_types_on_change + @catch(ModuleNotFoundError, UnsupportedUnit) + def simplify_equality(self, expression: exp.Expression) -> exp.Expression: + """ + Use the subtraction and addition properties of equality to simplify expressions: - if len(operands) < size: - return functools.reduce( - lambda a, b: expression.__class__(this=a, expression=b), operands + x + 1 = 3 becomes x = 2 + + There are two binary operations in the above expression: + and = + Here's how we reference all the operands in the code below: + + l r + x + 1 = 3 + a b + """ + if isinstance(expression, self.COMPARISONS): + l, r = expression.left, expression.right + + if l.__class__ not in self.INVERSE_OPS: + return expression + + if r.is_number: + a_predicate = _is_number + b_predicate = _is_number + elif _is_date_literal(r): + a_predicate = _is_date_literal + b_predicate = _is_interval + else: + return expression + + if l.__class__ in self.INVERSE_DATE_OPS: + l = t.cast(exp.IntervalOp, l) + a = l.this + b = l.interval() + else: + l = t.cast(exp.Binary, l) + a, b = l.left, l.right + + if not a_predicate(a) and b_predicate(b): + pass + elif not a_predicate(b) and b_predicate(a): + a, b = b, a + else: + return expression + + return expression.__class__( + this=a, expression=self.INVERSE_OPS[l.__class__](this=r, expression=b) ) - return expression + return expression + + @annotate_types_on_change + def simplify_literals(self, expression, root=True): + if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): + return self._flat_simplify(expression, self._simplify_binary, root) + + if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg): + return expression.this.this + + if type(expression) in self.INVERSE_DATE_OPS: + return ( + self._simplify_binary(expression, expression.this, expression.interval()) + or expression + ) + + return expression + + def _simplify_integer_cast(self, expr: exp.Expression) -> exp.Expression: + if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast): + this = self._simplify_integer_cast(expr.this) + else: + this = expr.this + + if isinstance(expr, exp.Cast) and this.is_int: + num = this.to_py() + + # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any + # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is + # engine-dependent + if ( + self.TINYINT_MIN <= num <= self.TINYINT_MAX + and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES + ) or ( + self.UTINYINT_MIN <= num <= self.UTINYINT_MAX + and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES + ): + return this + + return expr + + def _simplify_binary(self, expression, a, b): + if isinstance(expression, self.COMPARISONS): + a = self._simplify_integer_cast(a) + b = self._simplify_integer_cast(b) + + if isinstance(expression, exp.Is): + if isinstance(b, exp.Not): + c = b.this + not_ = True + else: + c = b + not_ = False + + if is_null(c): + if isinstance(a, exp.Literal): + return exp.true() if not_ else exp.false() + if is_null(a): + return exp.false() if not_ else exp.true() + elif isinstance(expression, self.NULL_OK): + return None + elif is_null(a) or is_null(b): + return exp.null() + + if a.is_number and b.is_number: + num_a = a.to_py() + num_b = b.to_py() + + if isinstance(expression, exp.Add): + return exp.Literal.number(num_a + num_b) + if isinstance(expression, exp.Mul): + return exp.Literal.number(num_a * num_b) + + # We only simplify Sub, Div if a and b have the same parent because they're not associative + if isinstance(expression, exp.Sub): + return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None + if isinstance(expression, exp.Div): + # engines have differing int div behavior so intdiv is not safe + if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent: + return None + return exp.Literal.number(num_a / num_b) + + boolean = eval_boolean(expression, num_a, num_b) + + if boolean: + return boolean + elif a.is_string and b.is_string: + boolean = eval_boolean(expression, a.this, b.this) + + if boolean: + return boolean + elif _is_date_literal(a) and isinstance(b, exp.Interval): + date, b = extract_date(a), extract_interval(b) + if date and b: + if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)): + return date_literal(date + b, extract_type(a)) + if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)): + return date_literal(date - b, extract_type(a)) + elif isinstance(a, exp.Interval) and _is_date_literal(b): + a, date = extract_interval(a), extract_date(b) + # you cannot subtract a date from an interval + if a and b and isinstance(expression, exp.Add): + return date_literal(a + date, extract_type(b)) + elif _is_date_literal(a) and _is_date_literal(b): + if isinstance(expression, exp.Predicate): + a, b = extract_date(a), extract_date(b) + boolean = eval_boolean(expression, a, b) + if boolean: + return boolean + + return None + + @annotate_types_on_change + def simplify_coalesce(self, expression: exp.Expression) -> exp.Expression: + # COALESCE(x) -> x + if ( + isinstance(expression, exp.Coalesce) + and (not expression.expressions or _is_nonnull_constant(expression.this)) + # COALESCE is also used as a Spark partitioning hint + and not isinstance(expression.parent, exp.Hint) + ): + return expression.this + + # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift, + # because they are not always equivalent. For example, if `x` is `NULL` and it comes + # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE` + if self.dialect == "redshift": + return expression + + if not isinstance(expression, self.COMPARISONS): + return expression + + if isinstance(expression.left, exp.Coalesce): + coalesce = expression.left + other = expression.right + elif isinstance(expression.right, exp.Coalesce): + coalesce = expression.right + other = expression.left + else: + return expression + + # This transformation is valid for non-constants, + # but it really only does anything if they are both constants. + if not _is_constant(other): + return expression + + # Find the first constant arg + for arg_index, arg in enumerate(coalesce.expressions): + if _is_constant(arg): + break + else: + return expression + + coalesce.set("expressions", coalesce.expressions[:arg_index]) + + # Remove the COALESCE function. This is an optimization, skipping a simplify iteration, + # since we already remove COALESCE at the top of this function. + coalesce = coalesce if coalesce.expressions else coalesce.this + + # This expression is more complex than when we started, but it will get simplified further + return exp.paren( + exp.or_( + exp.and_( + coalesce.is_(exp.null()).not_(copy=False), + expression.copy(), + copy=False, + ), + exp.and_( + coalesce.is_(exp.null()), + type(expression)(this=arg.copy(), expression=other.copy()), + copy=False, + ), + copy=False, + ), + copy=False, + ) + + @annotate_types_on_change + def simplify_concat(self, expression): + """Reduces all groups that contain string literals by concatenating them.""" + if not isinstance(expression, self.CONCATS) or ( + # We can't reduce a CONCAT_WS call if we don't statically know the separator + isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string + ): + return expression + + if isinstance(expression, exp.ConcatWs): + sep_expr, *expressions = expression.expressions + sep = sep_expr.name + concat_type = exp.ConcatWs + args = {} + else: + expressions = expression.expressions + sep = "" + concat_type = exp.Concat + args = { + "safe": expression.args.get("safe"), + "coalesce": expression.args.get("coalesce"), + } + + new_args = [] + for is_string_group, group in itertools.groupby( + expressions or expression.flatten(), lambda e: e.is_string + ): + if is_string_group: + new_args.append(exp.Literal.string(sep.join(string.name for string in group))) + else: + new_args.extend(group) + + if len(new_args) == 1 and new_args[0].is_string: + return new_args[0] + + if concat_type is exp.ConcatWs: + new_args = [sep_expr] + new_args + elif isinstance(expression, exp.DPipe): + return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args) + + return concat_type(expressions=new_args, **args) + + @annotate_types_on_change + def simplify_conditionals(self, expression): + """Simplifies expressions like IF, CASE if their condition is statically known.""" + if isinstance(expression, exp.Case): + this = expression.this + for case in expression.args["ifs"]: + cond = case.this + if this: + # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ... + cond = cond.replace(this.pop().eq(cond)) + + if always_true(cond): + return case.args["true"] + + if always_false(cond): + case.pop() + if not expression.args["ifs"]: + return expression.args.get("default") or exp.null() + elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case): + if always_true(expression.this): + return expression.args["true"] + if always_false(expression.this): + return expression.args.get("false") or exp.null() + + return expression + + @annotate_types_on_change + def simplify_startswith(self, expression: exp.Expression) -> exp.Expression: + """ + Reduces a prefix check to either TRUE or FALSE if both the string and the + prefix are statically known. + + Example: + >>> from sqlglot import parse_one + >>> Simplifier().simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() + 'TRUE' + """ + if ( + isinstance(expression, exp.StartsWith) + and expression.this.is_string + and expression.expression.is_string + ): + return exp.convert(expression.name.startswith(expression.expression.name)) + + return expression + + def _is_datetrunc_predicate(self, left: exp.Expression, right: exp.Expression) -> bool: + return isinstance(left, self.DATETRUNCS) and _is_date_literal(right) + + @annotate_types_on_change + @catch(ModuleNotFoundError, UnsupportedUnit) + def simplify_datetrunc(self, expression: exp.Expression) -> exp.Expression: + """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" + comparison = expression.__class__ + + if isinstance(expression, self.DATETRUNCS): + this = expression.this + trunc_type = extract_type(this) + date = extract_date(this) + if date and expression.unit: + return date_literal( + date_floor(date, expression.unit.name.lower(), self.dialect), trunc_type + ) + elif comparison not in self.DATETRUNC_COMPARISONS: + return expression + + if isinstance(expression, exp.Binary): + l, r = expression.left, expression.right + + if not self._is_datetrunc_predicate(l, r): + return expression + + l = t.cast(exp.DateTrunc, l) + trunc_arg = l.this + unit = l.unit.name.lower() + date = extract_date(r) + + if not date: + return expression + + return ( + self.DATETRUNC_BINARY_COMPARISONS[comparison]( + trunc_arg, date, unit, self.dialect, extract_type(r) + ) + or expression + ) + + if isinstance(expression, exp.In): + l = expression.this + rs = expression.expressions + + if rs and all(self._is_datetrunc_predicate(l, r) for r in rs): + l = t.cast(exp.DateTrunc, l) + unit = l.unit.name.lower() + + ranges = [] + for r in rs: + date = extract_date(r) + if not date: + return expression + drange = _datetrunc_range(date, unit, self.dialect) + if drange: + ranges.append(drange) + + if not ranges: + return expression + + ranges = merge_ranges(ranges) + target_type = extract_type(*rs) + + return exp.or_( + *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], + copy=False, + ) + + return expression + + @annotate_types_on_change + def sort_comparison(self, expression: exp.Expression) -> exp.Expression: + if expression.__class__ in self.COMPLEMENT_COMPARISONS: + l, r = expression.this, expression.expression + l_column = isinstance(l, exp.Column) + r_column = isinstance(r, exp.Column) + l_const = _is_constant(l) + r_const = _is_constant(r) + + if ( + (l_column and not r_column) + or (r_const and not l_const) + or isinstance(r, exp.SubqueryPredicate) + ): + return expression + if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)): + return self.INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)( + this=r, expression=l + ) + return expression + + def remove_where_true(self, expression): + for where in expression.find_all(exp.Where): + if always_true(where.this): + where.pop() + for join in expression.find_all(exp.Join): + if ( + always_true(join.args.get("on")) + and not join.args.get("using") + and not join.args.get("method") + and (join.side, join.kind) in self.JOINS + ): + join.args["on"].pop() + join.set("side", None) + join.set("kind", "CROSS") + + def _flat_simplify(self, expression, simplifier, root=True): + if root or not expression.same_parent: + operands = [] + queue = deque(expression.flatten(unnest=False)) + size = len(queue) + + while queue: + a = queue.popleft() + + for b in queue: + result = simplifier(expression, a, b) + + if result and result is not expression: + queue.remove(b) + queue.appendleft(result) + break + else: + operands.append(a) + + if len(operands) < size: + return functools.reduce( + lambda a, b: expression.__class__(this=a, expression=b), operands + ) + return expression def gen(expression: t.Any, comments: bool = False) -> str: diff --git a/tests/fixtures/optimizer/normalize.sql b/tests/fixtures/optimizer/normalize.sql index 4b115422e9..8e94c192ca 100644 --- a/tests/fixtures/optimizer/normalize.sql +++ b/tests/fixtures/optimizer/normalize.sql @@ -14,7 +14,7 @@ (A AND B AND C AND D AND E AND F AND G) OR (H AND I AND J AND K AND L AND M AND N) OR (O AND P AND Q); NOT NOT NOT (A OR B); -NOT NOT NOT A AND NOT NOT NOT B; +NOT A AND NOT B; A OR B; A OR B; diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index 8857d89490..98f7d9970b 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -153,6 +153,18 @@ COALESCE(x, y) <> ALL (SELECT z FROM w); SELECT NOT (2 <> ALL (SELECT 2 UNION ALL SELECT 3)); SELECT 2 = ANY(SELECT 2 UNION ALL SELECT 3); +SELECT t_bool.a AND TRUE FROM t_bool; +SELECT t_bool.a FROM t_bool; + +SELECT TRUE AND t_bool.a FROM t_bool; +SELECT t_bool.a FROM t_bool; + +SELECT t_bool.a OR FALSE FROM t_bool; +SELECT t_bool.a FROM t_bool; + +SELECT FALSE OR t_bool.a FROM t_bool; +SELECT t_bool.a FROM t_bool; + -------------------------------------- -- Absorption -------------------------------------- @@ -160,34 +172,34 @@ SELECT 2 = ANY(SELECT 2 UNION ALL SELECT 3); (A OR B) AND (C OR NOT A); A AND (A OR B); -A; +A AND TRUE; A AND D AND E AND (B OR A); -A AND D AND E; +A AND D AND E AND TRUE; D AND A AND E AND (B OR A); -A AND D AND E; +A AND D AND E AND TRUE; (A OR B) AND A; -A; +A AND TRUE; C AND D AND (A OR B) AND E AND F AND A; -A AND C AND D AND E AND F; +A AND C AND D AND E AND F AND TRUE; A OR (A AND B); -A; +A OR FALSE; (A AND B) OR A; -A; +A OR FALSE; A AND (NOT A OR B); -A AND B; +A AND (B OR FALSE); (NOT A OR B) AND A; -A AND B; +A AND (B OR FALSE); A OR (NOT A AND B); -A OR B; +A OR (B AND TRUE); (A OR C) AND ((A OR C) OR B); A OR C; @@ -199,7 +211,7 @@ A AND (B AND C) AND (D AND E); A AND B AND C AND D AND E; A AND (A OR B) AND (A OR B OR C); -A; +A AND TRUE; (A OR B) AND (A OR C) AND (A OR B OR C); (A OR B) AND (A OR C); diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index d5cf860329..420325b432 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -46,7 +46,10 @@ def pushdown_projections(expression, **kwargs): def normalize(expression, **kwargs): + schema = kwargs.get("schema") + expression = optimizer.normalize.normalize(expression, dnf=False) + expression = annotate_types(expression, schema=schema) return optimizer.simplify.simplify(expression) @@ -300,7 +303,7 @@ def test_normalize(self): "x AND (y OR z)", ) - self.check_file("normalize", normalize) + self.check_file("normalize", normalize, schema=self.schema) @patch("sqlglot.generator.logger") def test_qualify_columns(self, logger):