From e5ef77e0ecf1037c5386b0aeededf0a9b7a4452e Mon Sep 17 00:00:00 2001 From: liutang123 Date: Tue, 6 Aug 2024 10:40:44 +0800 Subject: [PATCH] [opt](optimizer) Remove unused code to unify code Now, Agg's child predicates will not spread to agg. For example: select a, sum(b) from ( select a,b from t where a = 1 and b = 2 ) t group by a `a = 1` in scan can be propagated to `a` of agg. But `b = 2` in scan can not be propagated to `sum(b)` of agg. --- .../rules/rewrite/PullUpPredicates.java | 18 +------- .../rules/rewrite/InferPredicatesTest.java | 44 +++++++++++++++++++ 2 files changed, 45 insertions(+), 17 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java index dec288a6f52d4c..7f7314d483cd2f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java @@ -35,11 +35,9 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet.Builder; import com.google.common.collect.Lists; -import com.google.common.collect.Maps; import com.google.common.collect.Sets; import java.util.IdentityHashMap; -import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Set; @@ -107,21 +105,7 @@ public ImmutableSet visitLogicalAggregate(LogicalAggregate { ImmutableSet childPredicates = aggregate.child().accept(this, context); // TODO - List outputExpressions = aggregate.getOutputExpressions(); - - Map expressionSlotMap - = Maps.newLinkedHashMapWithExpectedSize(outputExpressions.size()); - for (NamedExpression output : outputExpressions) { - if (hasAgg(output)) { - expressionSlotMap.putIfAbsent( - output instanceof Alias ? output.child(0) : output, output.toSlot() - ); - } - } - Expression expression = ExpressionUtils.replace( - ExpressionUtils.and(Lists.newArrayList(childPredicates)), - expressionSlotMap - ); + Expression expression = ExpressionUtils.and(Lists.newArrayList(childPredicates)); Set predicates = Sets.newLinkedHashSet(ExpressionUtils.extractConjunction(expression)); return getAvailableExpressions(predicates, aggregate); }); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java index f79bdff9ec5405..c2eb7543b1901b 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java @@ -73,6 +73,12 @@ protected void runBeforeAll() throws Exception { + "distributed by hash(k2) buckets 1\n" + "properties('replication_num' = '1');"); + createTables("CREATE TABLE `test`.`test_tt` (\n" + + "`key` varchar(*) NOT NULL,\n" + + " `value` varchar(*) NULL\n" + + ") ENGINE=OLAP\n" + + "DISTRIBUTED BY HASH(`key`) BUCKETS 1\n" + + "PROPERTIES ('replication_allocation' = 'tag.location.default: 1');"); connectContext.setDatabase("test"); connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION"); } @@ -645,4 +651,42 @@ void inferPredicateByConstValue() { )) ); } + + @Test + void testAggMultiAliasWithSameChild() { + String sql = "SELECT t.*\n" + + "FROM (\n" + + " SELECT `key`, a , b \n" + + " FROM (\n" + + " SELECT `key`,\n" + + " any_value(value) AS a,\n" + + " any_value(CAST(value AS double)) AS b\n" + + " FROM (\n" + + " SELECT `key`, CAST(value AS double) AS value\n" + + " FROM test_tt\n" + + " WHERE `key` = '1'\n" + + " ) agg\n" + + " GROUP BY `key`\n" + + " ) proj\n" + + ") t\n" + + "LEFT JOIN\n" + + "( SELECT id, name FROM student) t2\n" + + "ON t.`key`=t2.`name`"; + PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree(); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches( + logicalJoin( + any(), + logicalProject( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getConjuncts().size() == 1 + && filter.getPredicate().toSql().contains("name = '1'")) + ) + ).when(join -> join.getJoinType() == JoinType.LEFT_OUTER_JOIN) + ); + + } }