Skip to content

Commit

Permalink
[BugFix] follow the output type after eliminating agg (#52741)
Browse files Browse the repository at this point in the history
Signed-off-by: zihe.liu <[email protected]>
  • Loading branch information
ZiheLiu authored Nov 8, 2024
1 parent 2e50513 commit b094453
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
import com.starrocks.analysis.Expr;
import com.starrocks.catalog.Function;
import com.starrocks.catalog.FunctionSet;
import com.starrocks.catalog.PrimitiveType;
import com.starrocks.catalog.ScalarType;
import com.starrocks.catalog.Type;
import com.starrocks.sql.optimizer.OptExpression;
import com.starrocks.sql.optimizer.OptimizerContext;
Expand Down Expand Up @@ -145,37 +143,43 @@ private ScalarOperator handleAggregationFunction(String fnName, CallOperator cal
} else if (fnName.equals(FunctionSet.SUM) || fnName.equals(FunctionSet.AVG) ||
fnName.equals(FunctionSet.FIRST_VALUE) || fnName.equals(FunctionSet.MAX) ||
fnName.equals(FunctionSet.MIN) || fnName.equals(FunctionSet.GROUP_CONCAT)) {
return rewriteCastFunction(callOperator);
return rewriteNormalFunction(callOperator);
}
return callOperator;
}

private ScalarOperator rewriteCountFunction(CallOperator callOperator) {
Type outType = callOperator.getType();

if (callOperator.getArguments().isEmpty()) {
return ConstantOperator.createInt(1);
return rewriteCastFunction(outType, ConstantOperator.createInt(1));
}

IsNullPredicateOperator isNullPredicateOperator =
new IsNullPredicateOperator(callOperator.getArguments().get(0));
ArrayList<ScalarOperator> ifArgs = Lists.newArrayList();
ScalarOperator thenExpr = ConstantOperator.createInt(0);
ScalarOperator elseExpr = ConstantOperator.createInt(1);
ScalarOperator thenExpr = rewriteCastFunction(outType, ConstantOperator.createInt(0));
ScalarOperator elseExpr = rewriteCastFunction(outType, ConstantOperator.createInt(1));
ifArgs.add(isNullPredicateOperator);
ifArgs.add(thenExpr);
ifArgs.add(elseExpr);

Type[] argumentTypes = ifArgs.stream().map(ScalarOperator::getType).toArray(Type[]::new);
Function fn =
Expr.getBuiltinFunction(FunctionSet.IF, argumentTypes, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
return new CallOperator(FunctionSet.IF, ScalarType.createType(PrimitiveType.TINYINT), ifArgs, fn);
return new CallOperator(FunctionSet.IF, outType, ifArgs, fn);
}

private ScalarOperator rewriteCastFunction(CallOperator callOperator) {
private ScalarOperator rewriteNormalFunction(CallOperator callOperator) {
ScalarOperator argument = callOperator.getArguments().get(0);
if (callOperator.getType().equals(argument.getType())) {
return argument;
return rewriteCastFunction(callOperator.getType(), argument);
}

private ScalarOperator rewriteCastFunction(Type outType, ScalarOperator func) {
if (outType.equals(func.getType())) {
return func;
}
return new CastOperator(callOperator.getType(), argument);
return new CastOperator(outType, func);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,11 @@ public void testEliminateAgg1() throws Exception {
assertContains(plan, " 1:Project\n" +
" | output columns:\n" +
" | 1 <-> [1: id, INT, false]\n" +
" | 6 <-> if[(5: varchar_value IS NULL, 0, 1); " +
"args: BOOLEAN,INT,INT; result: TINYINT; args nullable: false; result nullable: true]\n" +
" | 6 <-> if[(5: varchar_value IS NULL, cast(0 as BIGINT), cast(1 as BIGINT)); " +
"args: BOOLEAN,BIGINT,BIGINT; result: BIGINT; args nullable: false; result nullable: true]\n" +
" | cardinality: 1\n" +
" | \n" +
" 0:OlapScanNode\n" +
" table: test_agg_group_single_unique_key, rollup: test_agg_group_single_unique_key\n" +
" preAggregation: off. Reason: None aggregate function\n");
" 0:OlapScanNode");

sql = "SELECT\n" +
" id,\n" +
Expand Down Expand Up @@ -135,13 +133,11 @@ public void testEliminateAgg1() throws Exception {
" | output columns:\n" +
" | 1 <-> [1: id, INT, false]\n" +
" | 2 <-> [2: big_value, BIGINT, true]\n" +
" | 6 <-> if[(5: varchar_value IS NULL, 0, 1); args: BOOLEAN,INT,INT;" +
" result: TINYINT; args nullable: false; result nullable: true]\n" +
" | 6 <-> if[(5: varchar_value IS NULL, cast(0 as BIGINT), cast(1 as BIGINT)); " +
"args: BOOLEAN,BIGINT,BIGINT; result: BIGINT; args nullable: false; result nullable: true]\n" +
" | cardinality: 1\n" +
" | \n" +
" 0:OlapScanNode\n" +
" table: test_agg_group_multi_unique_key, rollup: test_agg_group_multi_unique_key\n" +
" preAggregation: off. Reason: None aggregate function\n");
" 0:OlapScanNode");
sql = "SELECT\n" +
" id,\n" +
" big_value,\n" +
Expand Down Expand Up @@ -438,7 +434,7 @@ public void testEliminateAggAfterAgg() throws Exception {
" 2:Project\n" +
" | <slot 6> : 6: c16\n" +
" | <slot 7> : 7: sum\n" +
" | <slot 8> : if(7: sum IS NULL, 0, 1)\n" +
" | <slot 8> : if(7: sum IS NULL, CAST(0 AS BIGINT), CAST(1 AS BIGINT))\n" +
" | \n" +
" 1:AGGREGATE (update finalize)\n" +
" | output: sum(1: c11)\n" +
Expand Down Expand Up @@ -491,9 +487,59 @@ public void testEliminateAggAfterAgg() throws Exception {
" 1:Project\n" +
" | <slot 1> : 1: c11\n" +
" | <slot 6> : 6: c16\n" +
" | <slot 8> : if(CAST(1: c11 AS BIGINT) IS NULL, 0, 1)\n" +
" | <slot 8> : if(CAST(1: c11 AS BIGINT) IS NULL, CAST(0 AS BIGINT), CAST(1 AS BIGINT))\n" +
" | \n" +
" 0:OlapScanNode");
}

@Test
public void testEliminateAggForCountReturnType() throws Exception {
String sql;
String plan;

sql = "select c21, count(c22) from tt2 group by c21";
plan = getVerboseExplain(sql);
assertContains(plan, "if[(2: c22 IS NULL, cast(0 as BIGINT), cast(1 as BIGINT)); " +
"args: BOOLEAN,BIGINT,BIGINT; result: BIGINT; args nullable: false; result nullable: true]");

sql = "select c21, count(1) as cnt from tt2 group by c21 order by cnt";
plan = getVerboseExplain(sql);
assertContains(plan, " 2:SORT\n" +
" | order by: [7, BIGINT, false] ASC\n" +
" | offset: 0\n" +
" | cardinality: 1\n" +
" | \n" +
" 1:Project\n" +
" | output columns:\n" +
" | 1 <-> [1: c21, INT, true]\n" +
" | 7 <-> 1\n" +
" | cardinality: 1");

sql = "select c21, count(*) as cnt from tt2 group by c21 order by cnt";
plan = getVerboseExplain(sql);
assertContains(plan, " 2:SORT\n" +
" | order by: [7, BIGINT, false] ASC\n" +
" | offset: 0\n" +
" | cardinality: 1\n" +
" | \n" +
" 1:Project\n" +
" | output columns:\n" +
" | 1 <-> [1: c21, INT, true]\n" +
" | 7 <-> 1\n" +
" | cardinality: 1");

sql = "select c21, count() as cnt from tt2 group by c21 order by cnt";
plan = getVerboseExplain(sql);
assertContains(plan, " 2:SORT\n" +
" | order by: [7, BIGINT, false] ASC\n" +
" | offset: 0\n" +
" | cardinality: 1\n" +
" | \n" +
" 1:Project\n" +
" | output columns:\n" +
" | 1 <-> [1: c21, INT, true]\n" +
" | 7 <-> 1\n" +
" | cardinality: 1");
}

}

0 comments on commit b094453

Please sign in to comment.