Skip to content

Commit

Permalink
[Refactor][BugFix][CherryPick] refactor mv expression rewrite for 3.1 (
Browse files Browse the repository at this point in the history
…#30111) (#30732)

Signed-off-by: ABingHuang <[email protected]>
  • Loading branch information
ABingHuang authored Sep 11, 2023
1 parent e321e85 commit 4b150a9
Show file tree
Hide file tree
Showing 3 changed files with 361 additions and 246 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

package com.starrocks.sql.optimizer.rewrite;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.starrocks.sql.optimizer.operator.scalar.ArrayOperator;
import com.starrocks.sql.optimizer.operator.scalar.ArraySliceOperator;
Expand All @@ -40,243 +42,216 @@
import com.starrocks.sql.optimizer.operator.scalar.SubfieldOperator;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;

/**
* When you want to replace some types of nodes in a scalarOperator tree, you can extend this class and
* override specific visit methods. It will return a new scalarOperator tree with specific nodes replaced.
* shuttle means a scalarOperator bus, it takes you traverse the scalarOperator tree.
*/
public class BaseScalarOperatorShuttle extends ScalarOperatorVisitor<ScalarOperator, Void> {
private static final Map<Class<? extends ScalarOperator>, BiFunction<ScalarOperator, List<ScalarOperator>, ScalarOperator>>
CLONE_FUNCTIONS;

static {
CLONE_FUNCTIONS = ImmutableMap.<Class<? extends ScalarOperator>,
BiFunction<ScalarOperator, List<ScalarOperator>, ScalarOperator>>builder()
.put(ConstantOperator.class, (op, childOps) -> op)
.put(ColumnRefOperator.class, (op, childOps) -> op)
.put(ArrayOperator.class, (op, childOps) -> new ArrayOperator(op.getType(), op.isNullable(), childOps))
.put(CollectionElementOperator.class, (op, childOps) -> new CollectionElementOperator(op.getType(),
childOps.get(0), childOps.get(1)))
.put(ArraySliceOperator.class, (op, childOps) -> new ArraySliceOperator(op.getType(), childOps))
.put(CallOperator.class, (op, childOps) -> {
CallOperator call = (CallOperator) op;
return new CallOperator(call.getFnName(), call.getType(), childOps, call.getFunction(), call.isDistinct()); })
.put(PredicateOperator.class, (op, childOps) -> op)
.put(BetweenPredicateOperator.class, (op, childOps) -> {
BetweenPredicateOperator between = (BetweenPredicateOperator) op;
return new BetweenPredicateOperator(between.isNotBetween(), childOps); })
.put(BinaryPredicateOperator.class, (op, childOps) -> {
BinaryPredicateOperator binary = (BinaryPredicateOperator) op;
return new BinaryPredicateOperator(binary.getBinaryType(), childOps); })
.put(CompoundPredicateOperator.class, (op, childOps) -> {
CompoundPredicateOperator compound = (CompoundPredicateOperator) op;
return new CompoundPredicateOperator(compound.getCompoundType(), childOps); })
.put(ExistsPredicateOperator.class, (op, childOps) -> {
ExistsPredicateOperator exist = (ExistsPredicateOperator) op;
return new ExistsPredicateOperator(exist.isNotExists(), childOps); })
.put(InPredicateOperator.class, (op, childOps) -> {
InPredicateOperator inPredicate = (InPredicateOperator) op;
return new InPredicateOperator(inPredicate.isNotIn(), childOps); })
.put(IsNullPredicateOperator.class, (op, childOps) -> {
IsNullPredicateOperator isNullPredicate = (IsNullPredicateOperator) op;
return new IsNullPredicateOperator(isNullPredicate.isNotNull(), childOps.get(0)); })
.put(LikePredicateOperator.class, (op, childOps) -> {
LikePredicateOperator like = (LikePredicateOperator) op;
return new LikePredicateOperator(like.getLikeType(), childOps); })
.put(CastOperator.class, (op, childOps) -> {
CastOperator cast = (CastOperator) op;
return new CastOperator(cast.getType(), childOps.get(0), cast.isImplicit()); })
.put(CaseWhenOperator.class, (op, childOps) -> {
CaseWhenOperator caseWhen = (CaseWhenOperator) op;
ScalarOperator clonedCaseClause = null;
ScalarOperator clonedElseClause = null;
List<ScalarOperator> clonedWhenThenClauses;
if (caseWhen.hasCase()) {
clonedCaseClause = childOps.get(0);
}
if (caseWhen.hasElse()) {
clonedElseClause = childOps.get(childOps.size() - 1);
}

int whenThenEndIdx = caseWhen.hasElse() ? childOps.size() - 1 : childOps.size();
clonedWhenThenClauses = childOps.subList(caseWhen.getWhenStart(), whenThenEndIdx);

return new CaseWhenOperator(caseWhen.getType(), clonedCaseClause, clonedElseClause, clonedWhenThenClauses);
})
.put(SubfieldOperator.class, (op, childOps) -> {
SubfieldOperator subfield = (SubfieldOperator) op;
return new SubfieldOperator(childOps.get(0), subfield.getType(), subfield.getFieldNames()); })
.put(MapOperator.class, (op, childOps) -> new MapOperator(op.getType(), childOps))
.put(MultiInPredicateOperator.class, (op, childOps) -> {
MultiInPredicateOperator multiIn = (MultiInPredicateOperator) op;
return new MultiInPredicateOperator(multiIn.isNotIn(), childOps, multiIn.getTupleSize()); })
.put(LambdaFunctionOperator.class, (op, childOps) -> {
LambdaFunctionOperator lambda = (LambdaFunctionOperator) op;
return new LambdaFunctionOperator(lambda.getRefColumns(), childOps.get(0), lambda.getType()); })
.put(CloneOperator.class, (op, childOps) -> new CloneOperator(childOps.get(0)))
.build();
}

public ScalarOperator visit(ScalarOperator scalarOperator, Void context) {
return scalarOperator;
}

public Optional<ScalarOperator> preprocess(ScalarOperator scalarOperator) {
return Optional.empty();
}

public ScalarOperator shuttleIfUpdate(ScalarOperator operator) {
Optional<ScalarOperator> preprocessed = preprocess(operator);
if (preprocessed.isPresent()) {
return preprocessed.get();
}
boolean[] update = {false};
List<ScalarOperator> clonedChildOperators = visitList(operator.getChildren(), update);
if (update[0]) {
BiFunction<ScalarOperator, List<ScalarOperator>, ScalarOperator> cloningFunction =
CLONE_FUNCTIONS.get(operator.getClass());
Preconditions.checkNotNull(cloningFunction);
return cloningFunction.apply(operator, clonedChildOperators);
} else {
return operator;
}
}


@Override
public ScalarOperator visitConstant(ConstantOperator literal, Void context) {
return literal;
return shuttleIfUpdate(literal);
}

@Override
public ScalarOperator visitVariableReference(ColumnRefOperator variable, Void context) {
return variable;
return shuttleIfUpdate(variable);
}

@Override
public ScalarOperator visitArray(ArrayOperator array, Void context) {
boolean[] update = {false};
List<ScalarOperator> clonedOperators = visitList(array.getChildren(), update);
if (update[0]) {
return new ArrayOperator(array.getType(), array.isNullable(), clonedOperators);
} else {
return array;
}
return shuttleIfUpdate(array);
}

@Override
public ScalarOperator visitCollectionElement(CollectionElementOperator collectionElementOp, Void context) {
boolean[] update = {false};
List<ScalarOperator> clonedOperators = visitList(collectionElementOp.getChildren(), update);
if (update[0]) {
return new CollectionElementOperator(collectionElementOp.getType(), clonedOperators.get(0),
clonedOperators.get(1));
}
return collectionElementOp;
return shuttleIfUpdate(collectionElementOp);
}

@Override
public ScalarOperator visitArraySlice(ArraySliceOperator array, Void context) {
boolean[] update = {false};
List<ScalarOperator> clonedOperators = visitList(array.getChildren(), update);
if (update[0]) {
return new ArraySliceOperator(array.getType(), clonedOperators);
} else {
return array;
}
return shuttleIfUpdate(array);
}

@Override
public ScalarOperator visitCall(CallOperator call, Void context) {
boolean[] update = {false};
List<ScalarOperator> clonedOperators = visitList(call.getChildren(), update);
if (update[0]) {
return new CallOperator(call.getFnName(), call.getType(), clonedOperators,
call.getFunction(), call.isDistinct());
} else {
return call;
}
return shuttleIfUpdate(call);
}

@Override
public ScalarOperator visitPredicate(PredicateOperator predicate, Void context) {
return predicate;
return shuttleIfUpdate(predicate);
}

@Override
public ScalarOperator visitBetweenPredicate(BetweenPredicateOperator predicate, Void context) {
boolean[] update = {false};
List<ScalarOperator> clonedOperators = visitList(predicate.getChildren(), update);
if (update[0]) {
return new BetweenPredicateOperator(predicate.isNotBetween(), clonedOperators);
} else {
return predicate;
}
return shuttleIfUpdate(predicate);
}

@Override
public ScalarOperator visitBinaryPredicate(BinaryPredicateOperator predicate, Void context) {
boolean[] update = {false};
List<ScalarOperator> clonedOperators = visitList(predicate.getChildren(), update);
if (update[0]) {
return new BinaryPredicateOperator(predicate.getBinaryType(), clonedOperators);
} else {
return predicate;
}
return shuttleIfUpdate(predicate);
}

@Override
public ScalarOperator visitCompoundPredicate(CompoundPredicateOperator predicate, Void context) {
boolean[] update = {false};
List<ScalarOperator> clonedOperators = visitList(predicate.getChildren(), update);
if (update[0]) {
return new CompoundPredicateOperator(predicate.getCompoundType(), clonedOperators);
} else {
return predicate;
}
return shuttleIfUpdate(predicate);
}

@Override
public ScalarOperator visitExistsPredicate(ExistsPredicateOperator predicate, Void context) {
boolean[] update = {false};
List<ScalarOperator> clonedOperators = visitList(predicate.getChildren(), update);
if (update[0]) {
return new ExistsPredicateOperator(predicate.isNotExists(), clonedOperators);
} else {
return predicate;
}
return shuttleIfUpdate(predicate);
}

@Override
public ScalarOperator visitInPredicate(InPredicateOperator predicate, Void context) {
boolean[] update = {false};
List<ScalarOperator> clonedOperators = visitList(predicate.getChildren(), update);
if (update[0]) {
return new InPredicateOperator(predicate.isNotIn(), clonedOperators);
} else {
return predicate;
}
return shuttleIfUpdate(predicate);
}

@Override
public ScalarOperator visitIsNullPredicate(IsNullPredicateOperator predicate, Void context) {
boolean[] update = {false};
List<ScalarOperator> clonedOperators = visitList(predicate.getChildren(), update);
if (update[0]) {
return new IsNullPredicateOperator(predicate.isNotNull(), clonedOperators.get(0));
} else {
return predicate;
}
return shuttleIfUpdate(predicate);
}

@Override
public ScalarOperator visitLikePredicateOperator(LikePredicateOperator predicate, Void context) {
boolean[] update = {false};
List<ScalarOperator> clonedOperators = visitList(predicate.getChildren(), update);
if (update[0]) {
return new LikePredicateOperator(predicate.getLikeType(), clonedOperators);
} else {
return predicate;
}
return shuttleIfUpdate(predicate);
}

@Override
public ScalarOperator visitCastOperator(CastOperator operator, Void context) {
boolean[] update = {false};
List<ScalarOperator> clonedOperators = visitList(operator.getChildren(), update);
if (update[0]) {
return new CastOperator(operator.getType(), clonedOperators.get(0), operator.isImplicit());
} else {
return operator;
}
return shuttleIfUpdate(operator);
}

@Override
public ScalarOperator visitCaseWhenOperator(CaseWhenOperator operator, Void context) {
boolean[] update = {false};
List<ScalarOperator> clonedOperators = visitList(operator.getChildren(), update);
if (update[0]) {
ScalarOperator clonedCaseClause = null;
ScalarOperator clonedElseClause = null;
List<ScalarOperator> clonedWhenThenClauses;
if (operator.hasCase()) {
clonedCaseClause = clonedOperators.get(0);
}
if (operator.hasElse()) {
clonedElseClause = clonedOperators.get(clonedOperators.size() - 1);
}

int whenThenEndIdx = operator.hasElse() ? clonedOperators.size() - 1 : clonedOperators.size();
clonedWhenThenClauses = clonedOperators.subList(operator.getWhenStart(), whenThenEndIdx);

return new CaseWhenOperator(operator.getType(), clonedCaseClause, clonedElseClause, clonedWhenThenClauses);
} else {
return operator;
}
return shuttleIfUpdate(operator);
}

@Override
public ScalarOperator visitSubfield(SubfieldOperator operator, Void context) {
boolean[] update = {false};
List<ScalarOperator> child = visitList(operator.getChildren(), update);
if (update[0]) {
return new SubfieldOperator(child.get(0), operator.getType(), operator.getFieldNames());
} else {
return operator;
}
return shuttleIfUpdate(operator);
}

@Override
public ScalarOperator visitMap(MapOperator operator, Void context) {
boolean[] update = {false};
List<ScalarOperator> children = visitList(operator.getChildren(), update);
if (update[0]) {
return new MapOperator(operator.getType(), children);
} else {
return operator;
}
return shuttleIfUpdate(operator);
}

@Override
public ScalarOperator visitMultiInPredicate(MultiInPredicateOperator operator, Void context) {
boolean[] update = {false};
List<ScalarOperator> children = visitList(operator.getChildren(), update);
if (update[0]) {
return new MultiInPredicateOperator(operator.isNotIn(), children, operator.getTupleSize());
} else {
return operator;
}
return shuttleIfUpdate(operator);
}

@Override
public ScalarOperator visitLambdaFunctionOperator(LambdaFunctionOperator operator, Void context) {
boolean[] update = {false};
List<ScalarOperator> children = visitList(operator.getChildren(), update);
if (update[0]) {
return new LambdaFunctionOperator(operator.getRefColumns(), children.get(0), operator.getType());
} else {
return operator;
}
return shuttleIfUpdate(operator);
}

@Override
public ScalarOperator visitCloneOperator(CloneOperator operator, Void context) {
boolean[] update = {false};
List<ScalarOperator> children = visitList(operator.getChildren(), update);
if (update[0]) {
return new CloneOperator(children.get(0));
} else {
return operator;
}
return shuttleIfUpdate(operator);
}

protected List<ScalarOperator> visitList(List<ScalarOperator> operators, boolean[] update) {
Expand Down
Loading

0 comments on commit 4b150a9

Please sign in to comment.