Skip to content

Commit

Permalink
[BugFix] ensure a valid cast clause for MySQL/JDBC table (backport #2…
Browse files Browse the repository at this point in the history
…3379) (#30739)

Signed-off-by: packy <[email protected]>
  • Loading branch information
packy92 committed Sep 11, 2023
1 parent f3404f8 commit dbbbefd
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 33 deletions.
7 changes: 2 additions & 5 deletions fe/fe-core/src/main/java/com/starrocks/analysis/CastExpr.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,10 @@ public Expr clone() {

@Override
public String toSqlImpl() {
if (isImplicit) {
return getChild(0).toSql();
}
if (isAnalyzed) {
if (targetTypeDef == null) {
return "CAST(" + getChild(0).toSql() + " AS " + type.toString() + ")";
} else {
return "CAST(" + getChild(0).toSql() + " AS " + targetTypeDef.toString() + ")";
return "CAST(" + getChild(0).toSql() + " AS " + targetTypeDef + ")";
}
}

Expand Down
18 changes: 8 additions & 10 deletions fe/fe-core/src/main/java/com/starrocks/analysis/SlotRef.java
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ public String toSqlImpl() {
if (tblName != null) {
return tblName.toSql() + "." + "`" + col + "`";
} else if (label != null) {
return label + sb.toString();
return label;
} else if (desc.getSourceExprs() != null) {
sb.append("<slot ").append(desc.getId().asInt()).append(">");
for (Expr expr : desc.getSourceExprs()) {
Expand All @@ -269,7 +269,7 @@ public String toSqlImpl() {
}
return sb.toString();
} else {
return "<slot " + desc.getId().asInt() + ">" + sb.toString();
return "<slot " + desc.getId().asInt() + ">";
}
}

Expand All @@ -288,20 +288,18 @@ public String explainImpl() {

@Override
public String toMySql() {
if (col != null) {
return col;
} else {
return "<slot " + Integer.toString(desc.getId().asInt()) + ">";
if (label == null) {
throw new IllegalArgumentException("should set label for cols in MySQLScanNode. SlotRef: " + debugString());
}
return label;
}

@Override
public String toJDBCSQL(boolean isMySQL) {
if (col != null) {
return isMySQL ? "`" + col + "`" : col;
} else {
return "<slot " + Integer.toString(desc.getId().asInt()) + ">";
if (label == null) {
throw new IllegalArgumentException("should set label for cols in JDBCScanNode. SlotRef: " + debugString());
}
return isMySQL ? "`" + label + "`" : label;
}

public TableName getTableName() {
Expand Down
2 changes: 1 addition & 1 deletion fe/fe-core/src/main/java/com/starrocks/load/Load.java
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ private static void replaceSrcSlotDescType(Table tbl, Map<String, Expr> exprsByN
}

SlotRef slotRef = (SlotRef) child;
String columnName = slotRef.getColumn().getName();
String columnName = slotRef.getColumnName();
if (excludedColumns.contains(columnName)) {
continue;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

package com.starrocks.sql.optimizer.rewrite;

import com.google.common.collect.ImmutableSet;
import com.starrocks.catalog.PrimitiveType;
import com.starrocks.sql.optimizer.Utils;
import com.starrocks.sql.optimizer.operator.OperatorType;
import com.starrocks.sql.optimizer.operator.scalar.BetweenPredicateOperator;
Expand All @@ -17,15 +19,25 @@

import java.util.LinkedList;
import java.util.List;
import java.util.Set;

// Extract predicates that can be pushed down to external table
// and predicates that must be reserved
// from the entire predicate
// To be safe, we only allow push down simple predicates
public class ExternalTablePredicateExtractor {

private static Set<PrimitiveType> MYSQL_CAST_TYPE = ImmutableSet.of(PrimitiveType.DATE, PrimitiveType.CHAR,
PrimitiveType.DATETIME, PrimitiveType.DECIMALV2, PrimitiveType.DOUBLE, PrimitiveType.FLOAT, PrimitiveType.JSON);

private final boolean isMySQL;
private List<ScalarOperator> pushedPredicates = new LinkedList<>();
private List<ScalarOperator> reservedPredicates = new LinkedList<>();

public ExternalTablePredicateExtractor(boolean isMySQL) {
this.isMySQL = isMySQL;
}

public ScalarOperator getPushPredicate() {
return Utils.compoundAnd(pushedPredicates);
}
Expand All @@ -35,9 +47,6 @@ public ScalarOperator getReservePredicate() {
}

public void extract(ScalarOperator op) {
pushedPredicates.clear();
reservedPredicates.clear();

if (op.getOpType().equals(OperatorType.COMPOUND)) {
CompoundPredicateOperator operator = (CompoundPredicateOperator) op;
switch (operator.getCompoundType()) {
Expand All @@ -46,7 +55,7 @@ public void extract(ScalarOperator op) {
// for CNF, we can push down each predicate independently
for (ScalarOperator conjunct : conjuncts) {
if (conjunct.accept(new CanFullyPushDownVisitor(), null)) {
pushedPredicates.add(conjunct);
pushedPredicates.add(removeImplicitCast(conjunct));
} else {
reservedPredicates.add(conjunct);
}
Expand All @@ -61,12 +70,12 @@ public void extract(ScalarOperator op) {
return;
}
}
pushedPredicates.add(operator);
pushedPredicates.add(removeImplicitCast(operator));
return;
}
case NOT: {
if (op.getChild(0).accept(new CanFullyPushDownVisitor(), null)) {
pushedPredicates.add(op);
pushedPredicates.add(removeImplicitCast(op));
} else {
reservedPredicates.add(op);
}
Expand All @@ -77,14 +86,33 @@ public void extract(ScalarOperator op) {
return;
}
if (op.accept(new CanFullyPushDownVisitor(), null)) {
pushedPredicates.add(op);

pushedPredicates.add(removeImplicitCast(op));
} else {
reservedPredicates.add(op);
}
}

private ScalarOperator removeImplicitCast(ScalarOperator operator) {
BaseScalarOperatorShuttle removeImplicitCastShuttle = new BaseScalarOperatorShuttle() {
@Override
public ScalarOperator visitCastOperator(CastOperator operator, Void context) {
boolean[] update = {false};
List<ScalarOperator> clonedOperators = visitList(operator.getChildren(), update);
if (operator.isImplicit()) {
return update[0] ? clonedOperators.get(0) : operator.getChild(0);
} else {
return update[0] ? new CastOperator(operator.getType(), clonedOperators.get(0), operator.isImplicit())
: operator;
}
}
};

return operator.accept(removeImplicitCastShuttle, null);
}

// check whether a predicate can be pushed down as a whole
private static class CanFullyPushDownVisitor extends ScalarOperatorVisitor<Boolean, Void> {
private class CanFullyPushDownVisitor extends ScalarOperatorVisitor<Boolean, Void> {
public CanFullyPushDownVisitor() {
}

Expand Down Expand Up @@ -139,6 +167,9 @@ public Boolean visitIsNullPredicate(IsNullPredicateOperator op, Void context) {

@Override
public Boolean visitCastOperator(CastOperator op, Void context) {
if (!op.isImplicit() && isMySQL && !MYSQL_CAST_TYPE.contains(op.getType().getPrimitiveType())) {
return false;
}
return visitAllChildren(op, context);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ public List<OptExpression> transform(OptExpression input, OptimizerContext conte
ScalarOperator predicate = Utils.compoundAnd(lfo.getPredicate(), operator.getPredicate());
ScalarOperator scanPredicate = operator.getPredicate();
ScalarOperator filterPredicate = lfo.getPredicate();

ExternalTablePredicateExtractor extractor = new ExternalTablePredicateExtractor();
ExternalTablePredicateExtractor extractor = new ExternalTablePredicateExtractor(
operator.getOpType() == OperatorType.LOGICAL_MYSQL_SCAN);
extractor.extract(predicate);
ScalarOperator pushedPredicate = extractor.getPushPredicate();
ScalarOperator reservedPredicate = extractor.getReservePredicate();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1157,7 +1157,6 @@ public PlanFragment visitPhysicalMysqlScan(OptExpression optExpression, ExecPlan
List<ScalarOperator> predicates = Utils.extractConjuncts(node.getPredicate());
ScalarOperatorToExpr.FormatterContext formatterContext =
new ScalarOperatorToExpr.FormatterContext(context.getColRefToExpr());
formatterContext.setImplicitCast(true);
for (ScalarOperator predicate : predicates) {
scanNode.getConjuncts().add(ScalarOperatorToExpr.buildExecExpression(predicate, formatterContext));
}
Expand Down Expand Up @@ -1241,7 +1240,6 @@ public PlanFragment visitPhysicalJDBCScan(OptExpression optExpression, ExecPlan
List<ScalarOperator> predicates = Utils.extractConjuncts(node.getPredicate());
ScalarOperatorToExpr.FormatterContext formatterContext =
new ScalarOperatorToExpr.FormatterContext(context.getColRefToExpr());
formatterContext.setImplicitCast(true);
for (ScalarOperator predicate : predicates) {
scanNode.getConjuncts().add(ScalarOperatorToExpr.buildExecExpression(predicate, formatterContext));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ interface BuildExpr {
public static class FormatterContext {
private final Map<ColumnRefOperator, Expr> colRefToExpr;
private final Map<ColumnRefOperator, ScalarOperator> projectOperatorMap;
private boolean implicitCast = false;

public FormatterContext(Map<ColumnRefOperator, Expr> variableToSlotRef) {
this.colRefToExpr = variableToSlotRef;
Expand All @@ -105,9 +104,6 @@ public FormatterContext(Map<ColumnRefOperator, Expr> variableToSlotRef,
this.projectOperatorMap = projectOperatorMap;
}

public void setImplicitCast(boolean isImplicit) {
this.implicitCast = isImplicit;
}
}

public static class Formatter extends ScalarOperatorVisitor<Expr, FormatterContext> {
Expand Down Expand Up @@ -513,7 +509,7 @@ public Expr visitCall(CallOperator call, FormatterContext context) {
public Expr visitCastOperator(CastOperator operator, FormatterContext context) {
CastExpr expr =
new CastExpr(operator.getType(), buildExpr.build(operator.getChild(0), context));
expr.setImplicit(context.implicitCast);
expr.setImplicit(operator.isImplicit());
hackTypeNull(expr);
return expr;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Copyright 2021-present StarRocks, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.starrocks.sql.plan;

import com.google.common.collect.Lists;
import com.starrocks.common.FeConstants;
import com.starrocks.common.Pair;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import java.util.List;
import java.util.stream.Stream;

public class MySQLTableCastTest extends PlanTestBase {

@BeforeAll
public static void beforeClass() throws Exception {
PlanTestBase.beforeClass();
FeConstants.runningUnitTest = true;
}


@ParameterizedTest(name = "sql_{index}: {0}.")
@MethodSource("explicitCastSqls")
void testExplicitCast(Pair<String, String> pair) throws Exception {
String sql = pair.first;
String plan = getFragmentPlan(sql);
assertContains(plan, pair.second);
}

@ParameterizedTest(name = "sql_{index}: {0}.")
@MethodSource("implicitCastSqls")
void testImplicitCast(Pair<String, String> pair) throws Exception {
String sql = pair.first;
String plan = getFragmentPlan(sql);
assertContains(plan, pair.second);
}

private static Stream<Arguments> explicitCastSqls() {
List<Pair<String, String>> sqls = Lists.newArrayList();
sqls.add(Pair.create("select * from ods_order where cast(cast(org_order_no as float) as date)",
"WHERE (CAST(CAST(org_order_no AS FLOAT) AS DATE))"));

sqls.add(Pair.create("select * from ods_order where cast(org_order_no as varchar)", "WHERE (org_order_no)"));

sqls.add(Pair.create("select * from ods_order where cast(org_order_no as date)", "WHERE (CAST(org_order_no AS DATE))"));

sqls.add(Pair.create("select * from ods_order join mysql_table where cast(org_order_no as date) " +
"and k1 = cast(k2 as date)", "WHERE (k1 = CAST(k2 AS DATE))"));

sqls.add(Pair.create("select * from ods_order where cast(org_order_no as int)",
"predicates: CAST(CAST(org_order_no AS INT) AS BOOLEAN)"));

return sqls.stream().map(e -> Arguments.of(e));
}

private static Stream<Arguments> implicitCastSqls() {
List<Pair<String, String>> sqls = Lists.newArrayList();
sqls.add(Pair.create("select * from ods_order where org_order_no", "WHERE (org_order_no)"));
sqls.add(Pair.create("select * from ods_order where org_order_no or up_trade_no",
"WHERE ((org_order_no) OR (up_trade_no))"));
sqls.add(Pair.create("select * from ods_order where org_order_no and cast(up_trade_no as date)",
"WHERE (org_order_no) AND (CAST(up_trade_no AS DATE))"));
sqls.add(Pair.create("select * from ods_order where org_order_no or (up_trade_no = order_dt) or (mchnt_no = pay_st)",
"WHERE (((org_order_no) OR (up_trade_no = order_dt)) OR (mchnt_no = pay_st))"));
sqls.add(Pair.create("select * from ods_order join mysql_table where k1 = 'a' and order_dt = 'c'",
"WHERE (order_dt = 'c')"));
sqls.add(Pair.create("select * from (select * from ods_order join mysql_table where k1 = 'a' and order_dt = 'c')" +
" t1 where t1.k2 = 'c'", "WHERE (k2 = 'c') AND (k1 = 'a')"));

return sqls.stream().map(e -> Arguments.of(e));
}
}
12 changes: 12 additions & 0 deletions fe/fe-core/src/test/java/com/starrocks/sql/plan/ScanTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -396,4 +396,16 @@ public void testPushDownExternalTableMissNot() throws Exception {
plan = getFragmentPlan(sql);
assertContains(plan, "FROM `ods_order` WHERE (order_no NOT IN ('1', '2', '3'))");
}

@Test
public void testImplicitCast() throws Exception {
String sql = "select count(distinct v1||v2) from t0";
String plan = getFragmentPlan(sql);
assertContains(plan, "2:AGGREGATE (update finalize)\n" +
" | output: multi_distinct_count(4: expr)\n" +
" | group by: \n" +
" | \n" +
" 1:Project\n" +
" | <slot 4> : (CAST(1: v1 AS BOOLEAN)) OR (CAST(2: v2 AS BOOLEAN))");
}
}

0 comments on commit dbbbefd

Please sign in to comment.