From dbbbefdd2ece5f6e045620f9f1e00feaec2eeb4a Mon Sep 17 00:00:00 2001 From: packy92 <110370499+packy92@users.noreply.github.com> Date: Mon, 11 Sep 2023 19:26:17 +0800 Subject: [PATCH] [BugFix] ensure a valid cast clause for MySQL/JDBC table (backport #23379) (#30739) Signed-off-by: packy --- .../java/com/starrocks/analysis/CastExpr.java | 7 +- .../java/com/starrocks/analysis/SlotRef.java | 18 ++-- .../main/java/com/starrocks/load/Load.java | 2 +- .../ExternalTablePredicateExtractor.java | 47 ++++++++-- ...hDownPredicateToExternalTableScanRule.java | 4 +- .../sql/plan/PlanFragmentBuilder.java | 2 - .../sql/plan/ScalarOperatorToExpr.java | 6 +- .../sql/plan/MySQLTableCastTest.java | 87 +++++++++++++++++++ .../java/com/starrocks/sql/plan/ScanTest.java | 12 +++ 9 files changed, 152 insertions(+), 33 deletions(-) create mode 100644 fe/fe-core/src/test/java/com/starrocks/sql/plan/MySQLTableCastTest.java diff --git a/fe/fe-core/src/main/java/com/starrocks/analysis/CastExpr.java b/fe/fe-core/src/main/java/com/starrocks/analysis/CastExpr.java index 3e28d3f69e17d..cb6fb92f50569 100644 --- a/fe/fe-core/src/main/java/com/starrocks/analysis/CastExpr.java +++ b/fe/fe-core/src/main/java/com/starrocks/analysis/CastExpr.java @@ -100,13 +100,10 @@ public Expr clone() { @Override public String toSqlImpl() { - if (isImplicit) { - return getChild(0).toSql(); - } - if (isAnalyzed) { + if (targetTypeDef == null) { return "CAST(" + getChild(0).toSql() + " AS " + type.toString() + ")"; } else { - return "CAST(" + getChild(0).toSql() + " AS " + targetTypeDef.toString() + ")"; + return "CAST(" + getChild(0).toSql() + " AS " + targetTypeDef + ")"; } } diff --git a/fe/fe-core/src/main/java/com/starrocks/analysis/SlotRef.java b/fe/fe-core/src/main/java/com/starrocks/analysis/SlotRef.java index be3ef4597d0f1..44e4dcb771a53 100644 --- a/fe/fe-core/src/main/java/com/starrocks/analysis/SlotRef.java +++ b/fe/fe-core/src/main/java/com/starrocks/analysis/SlotRef.java @@ -260,7 +260,7 @@ public String toSqlImpl() { if (tblName != null) { return tblName.toSql() + "." + "`" + col + "`"; } else if (label != null) { - return label + sb.toString(); + return label; } else if (desc.getSourceExprs() != null) { sb.append(""); for (Expr expr : desc.getSourceExprs()) { @@ -269,7 +269,7 @@ public String toSqlImpl() { } return sb.toString(); } else { - return "" + sb.toString(); + return ""; } } @@ -288,20 +288,18 @@ public String explainImpl() { @Override public String toMySql() { - if (col != null) { - return col; - } else { - return ""; + if (label == null) { + throw new IllegalArgumentException("should set label for cols in MySQLScanNode. SlotRef: " + debugString()); } + return label; } @Override public String toJDBCSQL(boolean isMySQL) { - if (col != null) { - return isMySQL ? "`" + col + "`" : col; - } else { - return ""; + if (label == null) { + throw new IllegalArgumentException("should set label for cols in JDBCScanNode. SlotRef: " + debugString()); } + return isMySQL ? "`" + label + "`" : label; } public TableName getTableName() { diff --git a/fe/fe-core/src/main/java/com/starrocks/load/Load.java b/fe/fe-core/src/main/java/com/starrocks/load/Load.java index 3975b3f57c5fe..cb7a73f09a2ae 100644 --- a/fe/fe-core/src/main/java/com/starrocks/load/Load.java +++ b/fe/fe-core/src/main/java/com/starrocks/load/Load.java @@ -602,7 +602,7 @@ private static void replaceSrcSlotDescType(Table tbl, Map exprsByN } SlotRef slotRef = (SlotRef) child; - String columnName = slotRef.getColumn().getName(); + String columnName = slotRef.getColumnName(); if (excludedColumns.contains(columnName)) { continue; } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rewrite/ExternalTablePredicateExtractor.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rewrite/ExternalTablePredicateExtractor.java index 6825ed1ee43b4..be7ab150d4a1e 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rewrite/ExternalTablePredicateExtractor.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rewrite/ExternalTablePredicateExtractor.java @@ -2,6 +2,8 @@ package com.starrocks.sql.optimizer.rewrite; +import com.google.common.collect.ImmutableSet; +import com.starrocks.catalog.PrimitiveType; import com.starrocks.sql.optimizer.Utils; import com.starrocks.sql.optimizer.operator.OperatorType; import com.starrocks.sql.optimizer.operator.scalar.BetweenPredicateOperator; @@ -17,15 +19,25 @@ import java.util.LinkedList; import java.util.List; +import java.util.Set; // Extract predicates that can be pushed down to external table // and predicates that must be reserved // from the entire predicate // To be safe, we only allow push down simple predicates public class ExternalTablePredicateExtractor { + + private static Set MYSQL_CAST_TYPE = ImmutableSet.of(PrimitiveType.DATE, PrimitiveType.CHAR, + PrimitiveType.DATETIME, PrimitiveType.DECIMALV2, PrimitiveType.DOUBLE, PrimitiveType.FLOAT, PrimitiveType.JSON); + + private final boolean isMySQL; private List pushedPredicates = new LinkedList<>(); private List reservedPredicates = new LinkedList<>(); + public ExternalTablePredicateExtractor(boolean isMySQL) { + this.isMySQL = isMySQL; + } + public ScalarOperator getPushPredicate() { return Utils.compoundAnd(pushedPredicates); } @@ -35,9 +47,6 @@ public ScalarOperator getReservePredicate() { } public void extract(ScalarOperator op) { - pushedPredicates.clear(); - reservedPredicates.clear(); - if (op.getOpType().equals(OperatorType.COMPOUND)) { CompoundPredicateOperator operator = (CompoundPredicateOperator) op; switch (operator.getCompoundType()) { @@ -46,7 +55,7 @@ public void extract(ScalarOperator op) { // for CNF, we can push down each predicate independently for (ScalarOperator conjunct : conjuncts) { if (conjunct.accept(new CanFullyPushDownVisitor(), null)) { - pushedPredicates.add(conjunct); + pushedPredicates.add(removeImplicitCast(conjunct)); } else { reservedPredicates.add(conjunct); } @@ -61,12 +70,12 @@ public void extract(ScalarOperator op) { return; } } - pushedPredicates.add(operator); + pushedPredicates.add(removeImplicitCast(operator)); return; } case NOT: { if (op.getChild(0).accept(new CanFullyPushDownVisitor(), null)) { - pushedPredicates.add(op); + pushedPredicates.add(removeImplicitCast(op)); } else { reservedPredicates.add(op); } @@ -77,14 +86,33 @@ public void extract(ScalarOperator op) { return; } if (op.accept(new CanFullyPushDownVisitor(), null)) { - pushedPredicates.add(op); + + pushedPredicates.add(removeImplicitCast(op)); } else { reservedPredicates.add(op); } } + private ScalarOperator removeImplicitCast(ScalarOperator operator) { + BaseScalarOperatorShuttle removeImplicitCastShuttle = new BaseScalarOperatorShuttle() { + @Override + public ScalarOperator visitCastOperator(CastOperator operator, Void context) { + boolean[] update = {false}; + List clonedOperators = visitList(operator.getChildren(), update); + if (operator.isImplicit()) { + return update[0] ? clonedOperators.get(0) : operator.getChild(0); + } else { + return update[0] ? new CastOperator(operator.getType(), clonedOperators.get(0), operator.isImplicit()) + : operator; + } + } + }; + + return operator.accept(removeImplicitCastShuttle, null); + } + // check whether a predicate can be pushed down as a whole - private static class CanFullyPushDownVisitor extends ScalarOperatorVisitor { + private class CanFullyPushDownVisitor extends ScalarOperatorVisitor { public CanFullyPushDownVisitor() { } @@ -139,6 +167,9 @@ public Boolean visitIsNullPredicate(IsNullPredicateOperator op, Void context) { @Override public Boolean visitCastOperator(CastOperator op, Void context) { + if (!op.isImplicit() && isMySQL && !MYSQL_CAST_TYPE.contains(op.getType().getPrimitiveType())) { + return false; + } return visitAllChildren(op, context); } } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/PushDownPredicateToExternalTableScanRule.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/PushDownPredicateToExternalTableScanRule.java index 2f61cf49495ba..9d8b81ea1d6c9 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/PushDownPredicateToExternalTableScanRule.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/PushDownPredicateToExternalTableScanRule.java @@ -49,8 +49,8 @@ public List transform(OptExpression input, OptimizerContext conte ScalarOperator predicate = Utils.compoundAnd(lfo.getPredicate(), operator.getPredicate()); ScalarOperator scanPredicate = operator.getPredicate(); ScalarOperator filterPredicate = lfo.getPredicate(); - - ExternalTablePredicateExtractor extractor = new ExternalTablePredicateExtractor(); + ExternalTablePredicateExtractor extractor = new ExternalTablePredicateExtractor( + operator.getOpType() == OperatorType.LOGICAL_MYSQL_SCAN); extractor.extract(predicate); ScalarOperator pushedPredicate = extractor.getPushPredicate(); ScalarOperator reservedPredicate = extractor.getReservePredicate(); diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java b/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java index f2d7df102ed72..939b89c2e7ee8 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java @@ -1157,7 +1157,6 @@ public PlanFragment visitPhysicalMysqlScan(OptExpression optExpression, ExecPlan List predicates = Utils.extractConjuncts(node.getPredicate()); ScalarOperatorToExpr.FormatterContext formatterContext = new ScalarOperatorToExpr.FormatterContext(context.getColRefToExpr()); - formatterContext.setImplicitCast(true); for (ScalarOperator predicate : predicates) { scanNode.getConjuncts().add(ScalarOperatorToExpr.buildExecExpression(predicate, formatterContext)); } @@ -1241,7 +1240,6 @@ public PlanFragment visitPhysicalJDBCScan(OptExpression optExpression, ExecPlan List predicates = Utils.extractConjuncts(node.getPredicate()); ScalarOperatorToExpr.FormatterContext formatterContext = new ScalarOperatorToExpr.FormatterContext(context.getColRefToExpr()); - formatterContext.setImplicitCast(true); for (ScalarOperator predicate : predicates) { scanNode.getConjuncts().add(ScalarOperatorToExpr.buildExecExpression(predicate, formatterContext)); } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/plan/ScalarOperatorToExpr.java b/fe/fe-core/src/main/java/com/starrocks/sql/plan/ScalarOperatorToExpr.java index 0528e2febb5b6..6c735bce5af4c 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/plan/ScalarOperatorToExpr.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/plan/ScalarOperatorToExpr.java @@ -92,7 +92,6 @@ interface BuildExpr { public static class FormatterContext { private final Map colRefToExpr; private final Map projectOperatorMap; - private boolean implicitCast = false; public FormatterContext(Map variableToSlotRef) { this.colRefToExpr = variableToSlotRef; @@ -105,9 +104,6 @@ public FormatterContext(Map variableToSlotRef, this.projectOperatorMap = projectOperatorMap; } - public void setImplicitCast(boolean isImplicit) { - this.implicitCast = isImplicit; - } } public static class Formatter extends ScalarOperatorVisitor { @@ -513,7 +509,7 @@ public Expr visitCall(CallOperator call, FormatterContext context) { public Expr visitCastOperator(CastOperator operator, FormatterContext context) { CastExpr expr = new CastExpr(operator.getType(), buildExpr.build(operator.getChild(0), context)); - expr.setImplicit(context.implicitCast); + expr.setImplicit(operator.isImplicit()); hackTypeNull(expr); return expr; } diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/MySQLTableCastTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/MySQLTableCastTest.java new file mode 100644 index 0000000000000..575bb3a94f50f --- /dev/null +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/MySQLTableCastTest.java @@ -0,0 +1,87 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.starrocks.sql.plan; + +import com.google.common.collect.Lists; +import com.starrocks.common.FeConstants; +import com.starrocks.common.Pair; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.List; +import java.util.stream.Stream; + +public class MySQLTableCastTest extends PlanTestBase { + + @BeforeAll + public static void beforeClass() throws Exception { + PlanTestBase.beforeClass(); + FeConstants.runningUnitTest = true; + } + + + @ParameterizedTest(name = "sql_{index}: {0}.") + @MethodSource("explicitCastSqls") + void testExplicitCast(Pair pair) throws Exception { + String sql = pair.first; + String plan = getFragmentPlan(sql); + assertContains(plan, pair.second); + } + + @ParameterizedTest(name = "sql_{index}: {0}.") + @MethodSource("implicitCastSqls") + void testImplicitCast(Pair pair) throws Exception { + String sql = pair.first; + String plan = getFragmentPlan(sql); + assertContains(plan, pair.second); + } + + private static Stream explicitCastSqls() { + List> sqls = Lists.newArrayList(); + sqls.add(Pair.create("select * from ods_order where cast(cast(org_order_no as float) as date)", + "WHERE (CAST(CAST(org_order_no AS FLOAT) AS DATE))")); + + sqls.add(Pair.create("select * from ods_order where cast(org_order_no as varchar)", "WHERE (org_order_no)")); + + sqls.add(Pair.create("select * from ods_order where cast(org_order_no as date)", "WHERE (CAST(org_order_no AS DATE))")); + + sqls.add(Pair.create("select * from ods_order join mysql_table where cast(org_order_no as date) " + + "and k1 = cast(k2 as date)", "WHERE (k1 = CAST(k2 AS DATE))")); + + sqls.add(Pair.create("select * from ods_order where cast(org_order_no as int)", + "predicates: CAST(CAST(org_order_no AS INT) AS BOOLEAN)")); + + return sqls.stream().map(e -> Arguments.of(e)); + } + + private static Stream implicitCastSqls() { + List> sqls = Lists.newArrayList(); + sqls.add(Pair.create("select * from ods_order where org_order_no", "WHERE (org_order_no)")); + sqls.add(Pair.create("select * from ods_order where org_order_no or up_trade_no", + "WHERE ((org_order_no) OR (up_trade_no))")); + sqls.add(Pair.create("select * from ods_order where org_order_no and cast(up_trade_no as date)", + "WHERE (org_order_no) AND (CAST(up_trade_no AS DATE))")); + sqls.add(Pair.create("select * from ods_order where org_order_no or (up_trade_no = order_dt) or (mchnt_no = pay_st)", + "WHERE (((org_order_no) OR (up_trade_no = order_dt)) OR (mchnt_no = pay_st))")); + sqls.add(Pair.create("select * from ods_order join mysql_table where k1 = 'a' and order_dt = 'c'", + "WHERE (order_dt = 'c')")); + sqls.add(Pair.create("select * from (select * from ods_order join mysql_table where k1 = 'a' and order_dt = 'c')" + + " t1 where t1.k2 = 'c'", "WHERE (k2 = 'c') AND (k1 = 'a')")); + + return sqls.stream().map(e -> Arguments.of(e)); + } +} diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/ScanTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/ScanTest.java index 3c7b62105a806..d6b4bf4c9d26a 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/ScanTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/ScanTest.java @@ -396,4 +396,16 @@ public void testPushDownExternalTableMissNot() throws Exception { plan = getFragmentPlan(sql); assertContains(plan, "FROM `ods_order` WHERE (order_no NOT IN ('1', '2', '3'))"); } + + @Test + public void testImplicitCast() throws Exception { + String sql = "select count(distinct v1||v2) from t0"; + String plan = getFragmentPlan(sql); + assertContains(plan, "2:AGGREGATE (update finalize)\n" + + " | output: multi_distinct_count(4: expr)\n" + + " | group by: \n" + + " | \n" + + " 1:Project\n" + + " | : (CAST(1: v1 AS BOOLEAN)) OR (CAST(2: v2 AS BOOLEAN))"); + } }