Skip to content

Commit

Permalink
[Enhancement] Push agg on small broadcast join (#52150)
Browse files Browse the repository at this point in the history
Signed-off-by: zihe.liu <[email protected]>
  • Loading branch information
ZiheLiu authored and stephen-shelby committed Nov 6, 2024
1 parent 2616004 commit bb07a54
Show file tree
Hide file tree
Showing 20 changed files with 1,793 additions and 274 deletions.
13 changes: 13 additions & 0 deletions fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,8 @@ public class SessionVariable implements Serializable, Writable, Cloneable {
public static final String CBO_MAX_REORDER_NODE = "cbo_max_reorder_node";
public static final String CBO_PRUNE_SHUFFLE_COLUMN_RATE = "cbo_prune_shuffle_column_rate";
public static final String CBO_PUSH_DOWN_AGGREGATE_MODE = "cbo_push_down_aggregate_mode";
public static final String CBO_PUSH_DOWN_AGGREGATE_ON_BROADCAST_JOIN = "cbo_push_down_aggregate_on_broadcast_join";


public static final String CBO_PUSH_DOWN_DISTINCT_BELOW_WINDOW = "cbo_push_down_distinct_below_window";
public static final String CBO_PUSH_DOWN_AGGREGATE = "cbo_push_down_aggregate";
Expand Down Expand Up @@ -1484,6 +1486,9 @@ public static MaterializedViewRewriteMode parse(String str) {
show = CBO_PUSH_DOWN_AGGREGATE_MODE, flag = VariableMgr.INVISIBLE)
private int cboPushDownAggregateMode = -1;

@VarAttr(name = CBO_PUSH_DOWN_AGGREGATE_ON_BROADCAST_JOIN, flag = VariableMgr.INVISIBLE)
private boolean cboPushDownAggregateOnBroadcastJoin = true;

// auto, global, local
@VarAttr(name = CBO_PUSH_DOWN_AGGREGATE, flag = VariableMgr.INVISIBLE)
private String cboPushDownAggregate = "global";
Expand Down Expand Up @@ -3463,6 +3468,14 @@ public void setCboPushDownAggregateMode(int cboPushDownAggregateMode) {
this.cboPushDownAggregateMode = cboPushDownAggregateMode;
}

/**
 * Reports the current value of the {@code cbo_push_down_aggregate_on_broadcast_join}
 * session variable (the backing field is initialised to {@code true}).
 *
 * @return {@code true} when the flag is set, {@code false} otherwise
 */
public boolean isCboPushDownAggregateOnBroadcastJoin() {
    return this.cboPushDownAggregateOnBroadcastJoin;
}

/**
 * Overwrites the value of the {@code cbo_push_down_aggregate_on_broadcast_join}
 * session variable.
 *
 * @param cboPushDownAggregateOnBroadcastJoin the new value of the flag
 */
public void setCboPushDownAggregateOnBroadcastJoin(boolean cboPushDownAggregateOnBroadcastJoin) {
this.cboPushDownAggregateOnBroadcastJoin = cboPushDownAggregateOnBroadcastJoin;
}

/**
 * Returns the raw string stored for the {@code cbo_push_down_aggregate} session variable.
 *
 * @return the configured value, returned as-is without validation
 */
public String getCboPushDownAggregate() {
    return this.cboPushDownAggregate;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.


package com.starrocks.sql.optimizer;

import com.google.common.base.Preconditions;
Expand Down Expand Up @@ -97,13 +96,14 @@ private static void collectForceCteStatistics(Group root, OptimizerContext conte
// Collect statistics of CTEProduceOperator outside of memo; used by features that run outside the memo,
// such as table pruning and aggregation push-down.
public static void collectForceCteStatisticsOutsideMemo(OptExpression root, OptimizerContext context) {
root.getInputs().forEach(input -> collectForceCteStatisticsOutsideMemo(input, context));
if (OperatorType.LOGICAL_CTE_ANCHOR.equals(root.getOp().getOpType())) {
Preconditions.checkState(root.getInputs().get(0).getOp() instanceof LogicalCTEProduceOperator);
LogicalCTEProduceOperator produce = (LogicalCTEProduceOperator) root.getInputs().get(0).getOp();
calculateStatistics(root.inputAt(0), context);
context.getCteContext().addCTEStatistics(produce.getCteId(), root.inputAt(0).getStatistics());
if (OperatorType.LOGICAL_CTE_PRODUCE.equals(root.getOp().getOpType())) {
LogicalCTEProduceOperator produce = (LogicalCTEProduceOperator) root.getOp();

calculateStatistics(root, context);
context.getCteContext().addCTEStatistics(produce.getCteId(), root.getStatistics());
}
}

private static void calculateStatistics(OptExpression expr, OptimizerContext context) {
// don't ask cte consume children
if (expr.getOp().getOpType() != OperatorType.LOGICAL_CTE_CONSUME) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import com.starrocks.sql.optimizer.rule.Rule;
import com.starrocks.sql.optimizer.rule.RuleSetType;
import com.starrocks.sql.optimizer.rule.implementation.OlapScanImplementationRule;
import com.starrocks.sql.optimizer.rule.join.JoinReorderFactory;
import com.starrocks.sql.optimizer.rule.join.ReorderJoinRule;
import com.starrocks.sql.optimizer.rule.mv.MaterializedViewRule;
import com.starrocks.sql.optimizer.rule.transformation.ApplyExceptionRule;
Expand Down Expand Up @@ -770,6 +771,20 @@ private OptExpression pushDownAggregation(OptExpression tree, TaskContext rootTa
}

if (context.getSessionVariable().getCboPushDownAggregateMode() != -1) {
if (context.getSessionVariable().isCboPushDownAggregateOnBroadcastJoin()) {
// Reorder joins before applying PushDownAggregateRule to better decide where to push down aggregator.
// For example, do not push down a not very efficient aggregator below a very small broadcast join.
ruleRewriteOnlyOnce(tree, rootTaskContext, RuleSetType.PARTITION_PRUNE);
ruleRewriteIterative(tree, rootTaskContext, new MergeTwoProjectRule());
ruleRewriteIterative(tree, rootTaskContext, new MergeProjectWithChildRule());
CTEUtils.collectForceCteStatisticsOutsideMemo(tree, context);
deriveLogicalProperty(tree);
tree = new ReorderJoinRule().rewrite(tree, JoinReorderFactory.createJoinReorderAdaptive(), context);
tree = new SeparateProjectRule().rewrite(tree, rootTaskContext);
deriveLogicalProperty(tree);
Utils.calculateStatistics(tree, context);
}

PushDownAggregateRule rule = new PushDownAggregateRule(rootTaskContext);
rule.getRewriter().collectRewriteContext(tree);
if (rule.getRewriter().isNeedRewrite()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ public boolean containsAny(Collection<ColumnRefOperator> rhs) {
return rhs.stream().anyMatch(this::contains);
}

public boolean containsAll(List<Integer> rhs) {
/**
 * Checks whether every id in {@code rhs} is present in the underlying bit set.
 *
 * @param rhs the ids to look up
 * @return {@code true} if no id from {@code rhs} is missing; an empty collection yields {@code true}
 */
public boolean containsAll(Collection<Integer> rhs) {
    for (Integer id : rhs) {
        if (!bitSet.contains(id)) {
            // First missing id decides the answer; no need to look further.
            return false;
        }
    }
    return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
// limitations under the License.
package com.starrocks.sql.optimizer.rule.join;

import com.google.api.client.util.Lists;
import com.starrocks.qe.SessionVariable;
import com.starrocks.sql.optimizer.OptimizerContext;

import java.util.List;

// JoinReorderFactory is used to choose the join reorder algorithm(s) in the RBO phase.
// It is used by the ReorderJoinRule.rewrite method. At present JoinReorderFactory
// provides the following factory implementations:
Expand All @@ -23,15 +27,33 @@
// 2. factory for creating JoinReorderCardinalityPreserving, which is used by the table pruning
// feature.
public interface JoinReorderFactory {
JoinOrder create(OptimizerContext context);
/**
 * Produces the join-order algorithms to apply to one multi-join node.
 *
 * @param context       optimizer context handed to the algorithms that are created
 * @param multiJoinNode the multi-join node about to be reordered; implementations may
 *                      inspect it when deciding which algorithms to return
 * @return the algorithms to run, in the order they should be tried
 */
List<JoinOrder> create(OptimizerContext context, MultiJoinNode multiJoinNode);

// used by AutoMV to eliminate cross join.
static JoinReorderFactory createJoinReorderDummyStatisticsFactory() {
return JoinReorderDummyStatistics::new;
return (context, multiJoinNode) -> List.of(new JoinReorderDummyStatistics(context));
}

// used by table pruning feature.
static JoinReorderFactory createJoinReorderCardinalityPreserving() {
return JoinReorderCardinalityPreserving::new;
return (context, multiJoinNode) -> List.of(new JoinReorderCardinalityPreserving(context));
}

/**
 * Builds a factory that selects reorder algorithms per multi-join node from session settings.
 *
 * <p>The returned list always starts with {@code JoinReorderLeftDeep}. {@code JoinReorderDP}
 * is appended when the node has at most {@code getCboMaxReorderNodeUseDP()} atoms and DP
 * reordering is enabled; {@code JoinReorderGreedy} is appended when greedy reordering is enabled.
 *
 * @return a factory producing between one and three algorithms, in the order described above
 */
static JoinReorderFactory createJoinReorderAdaptive() {
    return (context, multiJoinNode) -> {
        // NOTE(review): in this file `Lists` is imported from com.google.api.client.util,
        // not from Guava (com.google.common.collect) -- confirm that import is intended.
        List<JoinOrder> candidates = Lists.newArrayList();

        // Left-deep reordering is unconditional.
        candidates.add(new JoinReorderLeftDeep(context));

        SessionVariable sessionVariable = context.getSessionVariable();

        // DP reordering is only offered for sufficiently small join graphs.
        boolean useDp = multiJoinNode.getAtoms().size() <= sessionVariable.getCboMaxReorderNodeUseDP()
                && sessionVariable.isCboEnableDPJoinReorder();
        if (useDp) {
            candidates.add(new JoinReorderDP(context));
        }

        boolean useGreedy = sessionVariable.isCboEnableGreedyJoinReorder();
        if (useGreedy) {
            candidates.add(new JoinReorderGreedy(context));
        }

        return candidates;
    };
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,12 @@ Optional<OptExpression> enumerate(JoinOrder reorderAlgorithm, OptimizerContext c
if (copyIntoMemo) {
context.getMemo().copyIn(innerJoinRoot.getGroupExpression().getGroup(), joinExpr);
} else {
joinExpr.deriveLogicalPropertyItself();
ExpressionContext expressionContext = new ExpressionContext(joinExpr);
StatisticsCalculator statisticsCalculator =
new StatisticsCalculator(expressionContext, context.getColumnRefFactory(), context);
statisticsCalculator.estimatorStats();
joinExpr.setStatistics(expressionContext.getStatistics());
return Optional.of(joinExpr);
}
}
Expand Down Expand Up @@ -184,8 +190,16 @@ public OptExpression rewrite(OptExpression input, JoinReorderFactory joinReorder
if (!multiJoinNode.checkDependsPredicate()) {
continue;
}
Optional<OptExpression> newChild =
enumerate(joinReorderFactory.create(context), context, child, multiJoinNode, false);

List<JoinOrder> orderAlgorithms = joinReorderFactory.create(context, multiJoinNode);
Optional<OptExpression> newChild = Optional.empty();
for (JoinOrder orderAlgorithm : orderAlgorithms) {
newChild = enumerate(orderAlgorithm, context, child, multiJoinNode, false);
if (newChild.isEmpty()) {
break;
}
}

if (newChild.isPresent()) {
int prevNumCrossJoins =
Utils.countJoinNodeSize(child, Sets.newHashSet(JoinOperator.CROSS_JOIN));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ public OptExpression rewriteImpl(OptExpression root) {
if (!(operator instanceof LogicalProjectOperator) && operator.getProjection() != null) {
Projection projection = operator.getProjection();
operator.setProjection(null);
// Clear statistics to recompute statistics when calculating statistics later,
// because some operators only recompute statistics when statistics is null.
root.setStatistics(null);
return OptExpression.create(new LogicalProjectOperator(projection.getColumnRefMap()), root);
} else {
return root;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.


package com.starrocks.sql.optimizer.rule.tree.pdagg;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.starrocks.catalog.Function;
import com.starrocks.catalog.FunctionSet;
import com.starrocks.common.Pair;
import com.starrocks.sql.optimizer.OptExpression;
import com.starrocks.sql.optimizer.base.ColumnRefFactory;
import com.starrocks.sql.optimizer.operator.logical.LogicalAggregationOperator;
import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
Expand Down Expand Up @@ -53,6 +53,12 @@ public class AggregatePushDownContext {
// count's column ref operator.
public final Map<CallOperator, Pair<ColumnRefOperator, ColumnRefOperator>> avgToSumCountMapping = Maps.newHashMap();

// Aggregator will be pushed down to the position above targetPosition.
public OptExpression targetPosition = null;
// Whether targetPosition is an immediate left child of a small broadcast join.
public boolean immediateChildOfSmallBroadcastJoin = false;
public int rootToLeafPathIndex = 0;

public boolean hasWindow = false;

// record push down path
Expand All @@ -66,6 +72,11 @@ public AggregatePushDownContext() {
pushPaths = Lists.newArrayList();
}

/**
 * Creates a context via the no-argument constructor and then records the given path index.
 *
 * @param rootToLeafPathIndex value stored in the {@code rootToLeafPathIndex} field
 *                            (NOTE(review): presumably identifies one root-to-leaf path of the
 *                            plan tree -- confirm against the callers)
 */
public AggregatePushDownContext(int rootToLeafPathIndex) {
// Run the default initialisation first, then override the index.
this();
this.rootToLeafPathIndex = rootToLeafPathIndex;
}

public void setAggregator(LogicalAggregationOperator aggregator) {
this.origAggregator = aggregator;
this.aggregations.putAll(aggregator.getAggregations());
Expand Down Expand Up @@ -116,7 +127,7 @@ public void registerAggRewriteInfo(CallOperator aggFunc,
}

public void registerOrigAggRewriteInfo(CallOperator aggFunc,
CallOperator origAgg) {
CallOperator origAgg) {
aggToOrigAggMap.put(aggFunc, origAgg);
}

Expand All @@ -133,4 +144,8 @@ public void combine(AggregatePushDownContext ctx) {
/**
 * Tells whether the given aggregate call has an entry in {@code aggToFinalAggMap}.
 *
 * @param aggCall the aggregate call to look up
 * @return {@code true} if {@code aggToFinalAggMap} contains {@code aggCall} as a key
 */
public boolean isRewrittenByEquivalent(CallOperator aggCall) {
return aggToFinalAggMap.containsKey(aggCall);
}

/**
 * Returns the expression above which the aggregator is to be pushed down.
 *
 * @return the current target position; {@code null} until one has been assigned
 */
public OptExpression getTargetPosition() {
    return this.targetPosition;
}
}
Loading

0 comments on commit bb07a54

Please sign in to comment.