diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
new file mode 100644
index 0000000000000..4e1900f80b68f
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
@@ -0,0 +1,1050 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.rel.rules;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptRuleOperand;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.logical.LogicalAggregate;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexLiteral;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.tools.RelBuilderFactory;
+import org.apache.calcite.util.CompositeList;
+import org.apache.calcite.util.ImmutableIntList;
+import org.apache.calcite.util.Util;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.immutables.value.Value;
+
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.EnumSet;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.function.IntPredicate;
+import java.util.function.Predicate;
+
+/**
+ * Planner rule that reduces aggregate functions in {@link org.apache.calcite.rel.core.Aggregate}s
+ * to simpler forms. This rule is copied to fix the correctness issue in Flink before upgrading to
+ * the corresponding Calcite version. Flink modifications:
+ *
+ *
Lines 555 ~ 565 to fix CALCITE-7192.
+ *
+ *
Rewrites:
+ *
+ *
+ * - AVG(x) → SUM(x) / COUNT(x)
+ *
- STDDEV_POP(x) → SQRT( (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / COUNT(x))
+ *
- STDDEV_SAMP(x) → SQRT( (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / CASE COUNT(x) WHEN
+ * 1 THEN NULL ELSE COUNT(x) - 1 END)
+ *
- VAR_POP(x) → (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / COUNT(x)
+ *
- VAR_SAMP(x) → (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / CASE COUNT(x) WHEN 1 THEN
+ * NULL ELSE COUNT(x) - 1 END
+ *
- COVAR_POP(x, y) → (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y)) /
+ * REGR_COUNT(x, y)
+ *
- COVAR_SAMP(x, y) → (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y)) / CASE
+ * REGR_COUNT(x, y) WHEN 1 THEN NULL ELSE REGR_COUNT(x, y) - 1 END
+ *
- REGR_SXX(x, y) → REGR_COUNT(x, y) * VAR_POP(y)
+ *
- REGR_SYY(x, y) → REGR_COUNT(x, y) * VAR_POP(x)
+ *
+ *
+ * Since many of these rewrites introduce multiple occurrences of simpler forms like {@code
+ * COUNT(x)}, the rule gathers common sub-expressions as it goes.
+ *
+ * @see CoreRules#AGGREGATE_REDUCE_FUNCTIONS
+ */
+@Value.Enclosing
+public class AggregateReduceFunctionsRule extends RelRule
+ implements TransformationRule {
+ // ~ Static fields/initializers ---------------------------------------------
+
+ private static void validateFunction(SqlKind function) {
+ if (!isValid(function)) {
+ throw new IllegalArgumentException(
+ "AggregateReduceFunctionsRule doesn't " + "support function: " + function.sql);
+ }
+ }
+
+ private static boolean isValid(SqlKind function) {
+ return SqlKind.AVG_AGG_FUNCTIONS.contains(function)
+ || SqlKind.COVAR_AVG_AGG_FUNCTIONS.contains(function)
+ || function == SqlKind.SUM;
+ }
+
+ private final Set functionsToReduce;
+
+ // ~ Constructors -----------------------------------------------------------
+
+ /** Creates an AggregateReduceFunctionsRule. */
+ protected AggregateReduceFunctionsRule(Config config) {
+ super(config);
+ this.functionsToReduce = ImmutableSet.copyOf(config.actualFunctionsToReduce());
+ }
+
+ @Deprecated // to be removed before 2.0
+ public AggregateReduceFunctionsRule(
+ RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory) {
+ this(
+ Config.DEFAULT
+ .withRelBuilderFactory(relBuilderFactory)
+ .withOperandSupplier(b -> b.exactly(operand))
+ .as(Config.class)
+ // reduce all functions handled by this rule
+ .withFunctionsToReduce(null));
+ }
+
+ @Deprecated // to be removed before 2.0
+ public AggregateReduceFunctionsRule(
+ Class extends Aggregate> aggregateClass,
+ RelBuilderFactory relBuilderFactory,
+ EnumSet functionsToReduce) {
+ this(
+ Config.DEFAULT
+ .withRelBuilderFactory(relBuilderFactory)
+ .as(Config.class)
+ .withOperandFor(aggregateClass)
+ // reduce specific functions provided by the client
+ .withFunctionsToReduce(
+ Objects.requireNonNull(functionsToReduce, "functionsToReduce")));
+ }
+
+ // ~ Methods ----------------------------------------------------------------
+
+ @Override
+ public boolean matches(RelOptRuleCall call) {
+ if (!super.matches(call)) {
+ return false;
+ }
+ Aggregate oldAggRel = (Aggregate) call.rels[0];
+ return containsAvgStddevVarCall(oldAggRel.getAggCallList());
+ }
+
+ @Override
+ public void onMatch(RelOptRuleCall ruleCall) {
+ Aggregate oldAggRel = (Aggregate) ruleCall.rels[0];
+ reduceAggs(ruleCall, oldAggRel);
+ }
+
+ /**
+ * Returns whether any of the aggregates are calls to AVG, STDDEV_*, VAR_*.
+ *
+ * @param aggCallList List of aggregate calls
+ */
+ private boolean containsAvgStddevVarCall(List aggCallList) {
+ return aggCallList.stream().anyMatch(this::canReduce);
+ }
+
+ /** Returns whether this rule can reduce a given aggregate function call. */
+ public boolean canReduce(AggregateCall call) {
+ return functionsToReduce.contains(call.getAggregation().getKind())
+ && config.extraCondition().test(call);
+ }
+
+ /**
+ * Returns whether this rule can reduce some agg-call, which its arg exists in the aggregate's
+ * group.
+ */
+ public boolean canReduceAggCallByGrouping(Aggregate oldAggRel, AggregateCall call) {
+ if (!Aggregate.isSimple(oldAggRel)) {
+ return false;
+ }
+ if (call.hasFilter()
+ || call.distinctKeys != null
+ || call.collation != RelCollations.EMPTY) {
+ return false;
+ }
+ final List argList = call.getArgList();
+ if (argList.size() != 1) {
+ return false;
+ }
+ if (!oldAggRel.getGroupSet().asSet().contains(argList.get(0))) {
+ // arg doesn't exist in aggregate's group.
+ return false;
+ }
+ final SqlKind kind = call.getAggregation().getKind();
+ switch (kind) {
+ case AVG:
+ case MAX:
+ case MIN:
+ case ANY_VALUE:
+ case FIRST_VALUE:
+ case LAST_VALUE:
+ return true;
+ default:
+ return false;
+ }
+ }
+
+ /**
+ * Reduces calls to functions AVG, SUM, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP, COVAR_POP,
+ * COVAR_SAMP, REGR_SXX, REGR_SYY if the function is present in {@link
+ * AggregateReduceFunctionsRule#functionsToReduce}
+ *
+ * It handles newly generated common subexpressions since this was done at the sql2rel stage.
+ */
+ private void reduceAggs(RelOptRuleCall ruleCall, Aggregate oldAggRel) {
+ RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
+
+ List oldCalls = oldAggRel.getAggCallList();
+ final int groupCount = oldAggRel.getGroupCount();
+
+ final List newCalls = new ArrayList<>();
+ final Map aggCallMapping = new HashMap<>();
+
+ final List projList = new ArrayList<>();
+
+ // pass through group key
+ for (int i = 0; i < groupCount; ++i) {
+ projList.add(rexBuilder.makeInputRef(oldAggRel, i));
+ }
+
+ // List of input expressions. If a particular aggregate needs more, it
+ // will add an expression to the end, and we will create an extra
+ // project.
+ final RelBuilder relBuilder = ruleCall.builder();
+ relBuilder.push(oldAggRel.getInput());
+ final List inputExprs = new ArrayList<>(relBuilder.fields());
+
+ // create new aggregate function calls and rest of project list together
+ for (AggregateCall oldCall : oldCalls) {
+ projList.add(reduceAgg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs));
+ }
+
+ final int extraArgCount =
+ inputExprs.size() - relBuilder.peek().getRowType().getFieldCount();
+ if (extraArgCount > 0) {
+ relBuilder.project(
+ inputExprs,
+ CompositeList.of(
+ relBuilder.peek().getRowType().getFieldNames(),
+ Collections.nCopies(extraArgCount, null)));
+ }
+ newAggregateRel(relBuilder, oldAggRel, newCalls);
+ newCalcRel(relBuilder, oldAggRel.getRowType(), projList);
+ final RelNode build = relBuilder.build();
+ ruleCall.transformTo(build);
+ }
+
+ private RexNode reduceAgg(
+ Aggregate oldAggRel,
+ AggregateCall oldCall,
+ List newCalls,
+ Map aggCallMapping,
+ List inputExprs) {
+ if (canReduceAggCallByGrouping(oldAggRel, oldCall)) {
+ // replace original MAX/MIN/AVG/ANY_VALUE/FIRST_VALUE/LAST_VALUE(x) with
+ // target field of x, when x exists in group
+ final RexNode reducedNode = reduceAggCallByGrouping(oldAggRel, oldCall);
+ return reducedNode;
+ } else if (canReduce(oldCall)) {
+ final Integer y;
+ final Integer x;
+ final SqlKind kind = oldCall.getAggregation().getKind();
+ switch (kind) {
+ case SUM:
+ // replace original SUM(x) with
+ // case COUNT(x) when 0 then null else SUM0(x) end
+ return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
+ case AVG:
+ // replace original AVG(x) with SUM(x) / COUNT(x)
+ return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs);
+ case COVAR_POP:
+ // replace original COVAR_POP(x, y) with
+ // (SUM(x * y) - SUM(y) * SUM(y) / COUNT(x))
+ // / COUNT(x))
+ return reduceCovariance(
+ oldAggRel, oldCall, true, newCalls, aggCallMapping, inputExprs);
+ case COVAR_SAMP:
+ // replace original COVAR_SAMP(x, y) with
+ // SQRT(
+ // (SUM(x * y) - SUM(x) * SUM(y) / COUNT(x))
+ // / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END)
+ return reduceCovariance(
+ oldAggRel, oldCall, false, newCalls, aggCallMapping, inputExprs);
+ case REGR_SXX:
+ // replace original REGR_SXX(x, y) with
+ // REGR_COUNT(x, y) * VAR_POP(y)
+ assert oldCall.getArgList().size() == 2 : oldCall.getArgList();
+ x = oldCall.getArgList().get(0);
+ y = oldCall.getArgList().get(1);
+ //noinspection SuspiciousNameCombination
+ return reduceRegrSzz(
+ oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs, y, y, x);
+ case REGR_SYY:
+ // replace original REGR_SYY(x, y) with
+ // REGR_COUNT(x, y) * VAR_POP(x)
+ assert oldCall.getArgList().size() == 2 : oldCall.getArgList();
+ x = oldCall.getArgList().get(0);
+ y = oldCall.getArgList().get(1);
+ //noinspection SuspiciousNameCombination
+ return reduceRegrSzz(
+ oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs, x, x, y);
+ case STDDEV_POP:
+ // replace original STDDEV_POP(x) with
+ // SQRT(
+ // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
+ // / COUNT(x))
+ return reduceStddev(
+ oldAggRel, oldCall, true, true, newCalls, aggCallMapping, inputExprs);
+ case STDDEV_SAMP:
+ // replace original STDDEV_POP(x) with
+ // SQRT(
+ // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
+ // / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END)
+ return reduceStddev(
+ oldAggRel, oldCall, false, true, newCalls, aggCallMapping, inputExprs);
+ case VAR_POP:
+ // replace original VAR_POP(x) with
+ // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
+ // / COUNT(x)
+ return reduceStddev(
+ oldAggRel, oldCall, true, false, newCalls, aggCallMapping, inputExprs);
+ case VAR_SAMP:
+ // replace original VAR_POP(x) with
+ // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
+ // / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END
+ return reduceStddev(
+ oldAggRel, oldCall, false, false, newCalls, aggCallMapping, inputExprs);
+ default:
+ throw Util.unexpected(kind);
+ }
+ } else {
+ // anything else: preserve original call
+ RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
+ final int nGroups = oldAggRel.getGroupCount();
+ return rexBuilder.addAggCall(
+ oldCall,
+ nGroups,
+ newCalls,
+ aggCallMapping,
+ oldAggRel.getInput()::fieldIsNullable);
+ }
+ }
+
+ private static AggregateCall createAggregateCallWithBinding(
+ RelDataTypeFactory typeFactory,
+ SqlAggFunction aggFunction,
+ RelDataType operandType,
+ Aggregate oldAggRel,
+ AggregateCall oldCall,
+ int argOrdinal,
+ int filter) {
+ final Aggregate.AggCallBinding binding =
+ new Aggregate.AggCallBinding(
+ typeFactory,
+ aggFunction,
+ ImmutableList.of(operandType),
+ oldAggRel.getGroupCount(),
+ filter >= 0);
+ return AggregateCall.create(
+ aggFunction,
+ oldCall.isDistinct(),
+ oldCall.isApproximate(),
+ oldCall.ignoreNulls(),
+ ImmutableIntList.of(argOrdinal),
+ filter,
+ oldCall.distinctKeys,
+ oldCall.collation,
+ aggFunction.inferReturnType(binding),
+ null);
+ }
+
+ private static RexNode reduceAvg(
+ Aggregate oldAggRel,
+ AggregateCall oldCall,
+ List newCalls,
+ Map aggCallMapping,
+ @SuppressWarnings("unused") List inputExprs) {
+ final int nGroups = oldAggRel.getGroupCount();
+ final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
+ final AggregateCall sumCall =
+ AggregateCall.create(
+ SqlStdOperatorTable.SUM,
+ oldCall.isDistinct(),
+ oldCall.isApproximate(),
+ oldCall.ignoreNulls(),
+ oldCall.getArgList(),
+ oldCall.filterArg,
+ oldCall.distinctKeys,
+ oldCall.collation,
+ oldAggRel.getGroupCount(),
+ oldAggRel.getInput(),
+ null,
+ null);
+ final AggregateCall countCall =
+ AggregateCall.create(
+ SqlStdOperatorTable.COUNT,
+ oldCall.isDistinct(),
+ oldCall.isApproximate(),
+ oldCall.ignoreNulls(),
+ oldCall.getArgList(),
+ oldCall.filterArg,
+ oldCall.distinctKeys,
+ oldCall.collation,
+ oldAggRel.getGroupCount(),
+ oldAggRel.getInput(),
+ null,
+ null);
+
+ // NOTE: these references are with respect to the output
+ // of newAggRel
+ RexNode numeratorRef =
+ rexBuilder.addAggCall(
+ sumCall,
+ nGroups,
+ newCalls,
+ aggCallMapping,
+ oldAggRel.getInput()::fieldIsNullable);
+ final RexNode denominatorRef =
+ rexBuilder.addAggCall(
+ countCall,
+ nGroups,
+ newCalls,
+ aggCallMapping,
+ oldAggRel.getInput()::fieldIsNullable);
+
+ final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
+ final RelDataType avgType =
+ typeFactory.createTypeWithNullability(
+ oldCall.getType(), numeratorRef.getType().isNullable());
+ numeratorRef = rexBuilder.ensureType(avgType, numeratorRef, true);
+ final RexNode divideRef =
+ rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numeratorRef, denominatorRef);
+ return rexBuilder.makeCast(oldCall.getType(), divideRef);
+ }
+
+ private static RexNode reduceSum(
+ Aggregate oldAggRel,
+ AggregateCall oldCall,
+ List newCalls,
+ Map aggCallMapping) {
+ final int nGroups = oldAggRel.getGroupCount();
+ RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
+
+ final AggregateCall sumZeroCall =
+ AggregateCall.create(
+ SqlStdOperatorTable.SUM0,
+ oldCall.isDistinct(),
+ oldCall.isApproximate(),
+ oldCall.ignoreNulls(),
+ oldCall.getArgList(),
+ oldCall.filterArg,
+ oldCall.distinctKeys,
+ oldCall.collation,
+ oldAggRel.getGroupCount(),
+ oldAggRel.getInput(),
+ null,
+ oldCall.name);
+ final AggregateCall countCall =
+ AggregateCall.create(
+ SqlStdOperatorTable.COUNT,
+ oldCall.isDistinct(),
+ oldCall.isApproximate(),
+ oldCall.ignoreNulls(),
+ oldCall.getArgList(),
+ oldCall.filterArg,
+ oldCall.distinctKeys,
+ oldCall.collation,
+ oldAggRel.getGroupCount(),
+ oldAggRel,
+ null,
+ null);
+
+ // NOTE: these references are with respect to the output
+ // of newAggRel
+ RexNode sumZeroRef =
+ rexBuilder.addAggCall(
+ sumZeroCall,
+ nGroups,
+ newCalls,
+ aggCallMapping,
+ oldAggRel.getInput()::fieldIsNullable);
+ if (!oldCall.getType().isNullable()) {
+ // If SUM(x) is not nullable, the validator must have determined that
+ // nulls are impossible (because the group is never empty and x is never
+ // null). Therefore we translate to SUM0(x).
+ return sumZeroRef;
+ }
+ RexNode countRef =
+ rexBuilder.addAggCall(
+ countCall,
+ nGroups,
+ newCalls,
+ aggCallMapping,
+ oldAggRel.getInput()::fieldIsNullable);
+ return rexBuilder.makeCall(
+ SqlStdOperatorTable.CASE,
+ rexBuilder.makeCall(
+ SqlStdOperatorTable.EQUALS,
+ countRef,
+ rexBuilder.makeExactLiteral(BigDecimal.ZERO)),
+ rexBuilder.makeNullLiteral(sumZeroRef.getType()),
+ sumZeroRef);
+ }
+
+ private static RexNode reduceStddev(
+ Aggregate oldAggRel,
+ AggregateCall oldCall,
+ boolean biased,
+ boolean sqrt,
+ List newCalls,
+ Map aggCallMapping,
+ List inputExprs) {
+ // stddev_pop(x) ==>
+ // power(
+ // (sum(x * x) - sum(x) * sum(x) / count(x))
+ // / count(x),
+ // .5)
+ //
+ // stddev_samp(x) ==>
+ // power(
+ // (sum(x * x) - sum(x) * sum(x) / count(x))
+ // / nullif(count(x) - 1, 0),
+ // .5)
+ final int nGroups = oldAggRel.getGroupCount();
+ final RelOptCluster cluster = oldAggRel.getCluster();
+ final RexBuilder rexBuilder = cluster.getRexBuilder();
+ final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
+
+ assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
+ final int argOrdinal = oldCall.getArgList().get(0);
+ final IntPredicate fieldIsNullable = oldAggRel.getInput()::fieldIsNullable;
+ final RelDataType oldCallType =
+ typeFactory.createTypeWithNullability(
+ oldCall.getType(), fieldIsNullable.test(argOrdinal));
+
+ final RexNode argRef = rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), true);
+
+ final RexNode argSquared =
+ rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef);
+ final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);
+
+ // FLINK MODIFICATION BEGIN
+ final AggregateCall sumArgSquaredAggCall =
+ createAggregateCallWithBinding(
+ typeFactory,
+ SqlStdOperatorTable.SUM,
+ argSquared.getType(),
+ oldAggRel,
+ oldCall,
+ argSquaredOrdinal,
+ oldCall.filterArg);
+ // FLINK MODIFICATION END
+
+ final RexNode sumArgSquared =
+ rexBuilder.addAggCall(
+ sumArgSquaredAggCall,
+ nGroups,
+ newCalls,
+ aggCallMapping,
+ oldAggRel.getInput()::fieldIsNullable);
+
+ final AggregateCall sumArgAggCall =
+ AggregateCall.create(
+ SqlStdOperatorTable.SUM,
+ oldCall.isDistinct(),
+ oldCall.isApproximate(),
+ oldCall.ignoreNulls(),
+ ImmutableIntList.of(argOrdinal),
+ oldCall.filterArg,
+ oldCall.distinctKeys,
+ oldCall.collation,
+ oldAggRel.getGroupCount(),
+ oldAggRel.getInput(),
+ null,
+ null);
+
+ final RexNode sumArg =
+ rexBuilder.addAggCall(
+ sumArgAggCall,
+ nGroups,
+ newCalls,
+ aggCallMapping,
+ oldAggRel.getInput()::fieldIsNullable);
+ final RexNode sumArgCast = rexBuilder.ensureType(oldCallType, sumArg, true);
+ final RexNode sumSquaredArg =
+ rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumArgCast, sumArgCast);
+
+ final AggregateCall countArgAggCall =
+ AggregateCall.create(
+ SqlStdOperatorTable.COUNT,
+ oldCall.isDistinct(),
+ oldCall.isApproximate(),
+ oldCall.ignoreNulls(),
+ oldCall.getArgList(),
+ oldCall.filterArg,
+ oldCall.distinctKeys,
+ oldCall.collation,
+ oldAggRel.getGroupCount(),
+ oldAggRel,
+ null,
+ null);
+
+ final RexNode countArg =
+ rexBuilder.addAggCall(
+ countArgAggCall,
+ nGroups,
+ newCalls,
+ aggCallMapping,
+ oldAggRel.getInput()::fieldIsNullable);
+
+ final RexNode div = divide(biased, rexBuilder, sumArgSquared, sumSquaredArg, countArg);
+
+ final RexNode result;
+ if (sqrt) {
+ final RexNode half = rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
+ result = rexBuilder.makeCall(SqlStdOperatorTable.POWER, div, half);
+ } else {
+ result = div;
+ }
+
+ return rexBuilder.makeCast(oldCall.getType(), result);
+ }
+
+ private static RexNode reduceAggCallByGrouping(Aggregate oldAggRel, AggregateCall oldCall) {
+
+ final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
+ final List oldGroups = oldAggRel.getGroupSet().asList();
+ final Integer firstArg = oldCall.getArgList().get(0);
+ final int index = oldGroups.lastIndexOf(firstArg);
+ assert index >= 0;
+
+ final RexInputRef refByGroup = RexInputRef.of(index, oldAggRel.getRowType().getFieldList());
+ if (refByGroup.getType().equals(oldCall.getType())) {
+ return refByGroup;
+ } else {
+ return rexBuilder.makeCast(oldCall.getType(), refByGroup);
+ }
+ }
+
+ private static RexNode getSumAggregatedRexNode(
+ Aggregate oldAggRel,
+ AggregateCall oldCall,
+ List newCalls,
+ Map aggCallMapping,
+ RexBuilder rexBuilder,
+ int argOrdinal,
+ int filterArg) {
+ final AggregateCall aggregateCall =
+ AggregateCall.create(
+ SqlStdOperatorTable.SUM,
+ oldCall.isDistinct(),
+ oldCall.isApproximate(),
+ oldCall.ignoreNulls(),
+ ImmutableIntList.of(argOrdinal),
+ filterArg,
+ oldCall.distinctKeys,
+ oldCall.collation,
+ oldAggRel.getGroupCount(),
+ oldAggRel.getInput(),
+ null,
+ null);
+ return rexBuilder.addAggCall(
+ aggregateCall,
+ oldAggRel.getGroupCount(),
+ newCalls,
+ aggCallMapping,
+ oldAggRel.getInput()::fieldIsNullable);
+ }
+
+ private static RexNode getSumAggregatedRexNodeWithBinding(
+ Aggregate oldAggRel,
+ AggregateCall oldCall,
+ List newCalls,
+ Map aggCallMapping,
+ RelDataType operandType,
+ int argOrdinal,
+ int filter) {
+ RelOptCluster cluster = oldAggRel.getCluster();
+ final AggregateCall sumArgSquaredAggCall =
+ createAggregateCallWithBinding(
+ cluster.getTypeFactory(),
+ SqlStdOperatorTable.SUM,
+ operandType,
+ oldAggRel,
+ oldCall,
+ argOrdinal,
+ filter);
+
+ return cluster.getRexBuilder()
+ .addAggCall(
+ sumArgSquaredAggCall,
+ oldAggRel.getGroupCount(),
+ newCalls,
+ aggCallMapping,
+ oldAggRel.getInput()::fieldIsNullable);
+ }
+
+ private static RexNode getRegrCountRexNode(
+ Aggregate oldAggRel,
+ AggregateCall oldCall,
+ List newCalls,
+ Map aggCallMapping,
+ ImmutableIntList argOrdinals,
+ int filterArg) {
+ final AggregateCall countArgAggCall =
+ AggregateCall.create(
+ SqlStdOperatorTable.REGR_COUNT,
+ oldCall.isDistinct(),
+ oldCall.isApproximate(),
+ oldCall.ignoreNulls(),
+ argOrdinals,
+ filterArg,
+ oldCall.distinctKeys,
+ oldCall.collation,
+ oldAggRel.getGroupCount(),
+ oldAggRel,
+ null,
+ null);
+
+ return oldAggRel
+ .getCluster()
+ .getRexBuilder()
+ .addAggCall(
+ countArgAggCall,
+ oldAggRel.getGroupCount(),
+ newCalls,
+ aggCallMapping,
+ oldAggRel.getInput()::fieldIsNullable);
+ }
+
+ private static RexNode reduceRegrSzz(
+ Aggregate oldAggRel,
+ AggregateCall oldCall,
+ List newCalls,
+ Map aggCallMapping,
+ List inputExprs,
+ int xIndex,
+ int yIndex,
+ int nullFilterIndex) {
+ // regr_sxx(x, y) ==>
+ // sum(y * y, x) - sum(y, x) * sum(y, x) / regr_count(x, y)
+ //
+
+ final RelOptCluster cluster = oldAggRel.getCluster();
+ final RexBuilder rexBuilder = cluster.getRexBuilder();
+ final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
+ final IntPredicate fieldIsNullable = oldAggRel.getInput()::fieldIsNullable;
+
+ final RelDataType oldCallType =
+ typeFactory.createTypeWithNullability(
+ oldCall.getType(),
+ fieldIsNullable.test(xIndex)
+ || fieldIsNullable.test(yIndex)
+ || fieldIsNullable.test(nullFilterIndex));
+
+ final RexNode argX = rexBuilder.ensureType(oldCallType, inputExprs.get(xIndex), true);
+ final RexNode argY = rexBuilder.ensureType(oldCallType, inputExprs.get(yIndex), true);
+ final RexNode argNullFilter =
+ rexBuilder.ensureType(oldCallType, inputExprs.get(nullFilterIndex), true);
+
+ final RexNode argXArgY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argX, argY);
+ final int argSquaredOrdinal = lookupOrAdd(inputExprs, argXArgY);
+
+ final RexNode argXAndYNotNullFilter =
+ rexBuilder.makeCall(
+ SqlStdOperatorTable.AND,
+ rexBuilder.makeCall(
+ SqlStdOperatorTable.AND,
+ rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argX),
+ rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argY)),
+ rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argNullFilter));
+ final int argXAndYNotNullFilterOrdinal = lookupOrAdd(inputExprs, argXAndYNotNullFilter);
+ final RexNode sumXY =
+ getSumAggregatedRexNodeWithBinding(
+ oldAggRel,
+ oldCall,
+ newCalls,
+ aggCallMapping,
+ argXArgY.getType(),
+ argSquaredOrdinal,
+ argXAndYNotNullFilterOrdinal);
+ final RexNode sumXYCast = rexBuilder.ensureType(oldCallType, sumXY, true);
+
+ final RexNode sumX =
+ getSumAggregatedRexNode(
+ oldAggRel,
+ oldCall,
+ newCalls,
+ aggCallMapping,
+ rexBuilder,
+ xIndex,
+ argXAndYNotNullFilterOrdinal);
+ final RexNode sumY =
+ xIndex == yIndex
+ ? sumX
+ : getSumAggregatedRexNode(
+ oldAggRel,
+ oldCall,
+ newCalls,
+ aggCallMapping,
+ rexBuilder,
+ yIndex,
+ argXAndYNotNullFilterOrdinal);
+
+ final RexNode sumXSumY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumX, sumY);
+
+ final RexNode countArg =
+ getRegrCountRexNode(
+ oldAggRel,
+ oldCall,
+ newCalls,
+ aggCallMapping,
+ ImmutableIntList.of(xIndex),
+ argXAndYNotNullFilterOrdinal);
+
+ RexLiteral zero = rexBuilder.makeExactLiteral(BigDecimal.ZERO);
+ RexNode nul = rexBuilder.makeNullLiteral(zero.getType());
+ final RexNode avgSumXSumY =
+ rexBuilder.makeCall(
+ SqlStdOperatorTable.CASE,
+ rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, zero),
+ nul,
+ rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumXSumY, countArg));
+ final RexNode avgSumXSumYCast = rexBuilder.ensureType(oldCallType, avgSumXSumY, true);
+ final RexNode result =
+ rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumXYCast, avgSumXSumYCast);
+ return rexBuilder.makeCast(oldCall.getType(), result);
+ }
+
+ private static RexNode reduceCovariance(
+ Aggregate oldAggRel,
+ AggregateCall oldCall,
+ boolean biased,
+ List newCalls,
+ Map aggCallMapping,
+ List inputExprs) {
+ // covar_pop(x, y) ==>
+ // (sum(x * y) - sum(x) * sum(y) / regr_count(x, y))
+ // / regr_count(x, y)
+ //
+ // covar_samp(x, y) ==>
+ // (sum(x * y) - sum(x) * sum(y) / regr_count(x, y))
+ // / regr_count(count(x, y) - 1, 0)
+ final RelOptCluster cluster = oldAggRel.getCluster();
+ final RexBuilder rexBuilder = cluster.getRexBuilder();
+ final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
+ assert oldCall.getArgList().size() == 2 : oldCall.getArgList();
+ final int argXOrdinal = oldCall.getArgList().get(0);
+ final int argYOrdinal = oldCall.getArgList().get(1);
+ final IntPredicate fieldIsNullable = oldAggRel.getInput()::fieldIsNullable;
+ final RelDataType oldCallType =
+ typeFactory.createTypeWithNullability(
+ oldCall.getType(),
+ fieldIsNullable.test(argXOrdinal) || fieldIsNullable.test(argYOrdinal));
+ final RexNode argX = rexBuilder.ensureType(oldCallType, inputExprs.get(argXOrdinal), true);
+ final RexNode argY = rexBuilder.ensureType(oldCallType, inputExprs.get(argYOrdinal), true);
+ final RexNode argXAndYNotNullFilter =
+ rexBuilder.makeCall(
+ SqlStdOperatorTable.AND,
+ rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argX),
+ rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argY));
+ final int argXAndYNotNullFilterOrdinal = lookupOrAdd(inputExprs, argXAndYNotNullFilter);
+ final RexNode argXY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argX, argY);
+ final int argXYOrdinal = lookupOrAdd(inputExprs, argXY);
+ final RexNode sumXY =
+ getSumAggregatedRexNodeWithBinding(
+ oldAggRel,
+ oldCall,
+ newCalls,
+ aggCallMapping,
+ argXY.getType(),
+ argXYOrdinal,
+ argXAndYNotNullFilterOrdinal);
+ final RexNode sumX =
+ getSumAggregatedRexNode(
+ oldAggRel,
+ oldCall,
+ newCalls,
+ aggCallMapping,
+ rexBuilder,
+ argXOrdinal,
+ argXAndYNotNullFilterOrdinal);
+ final RexNode sumY =
+ getSumAggregatedRexNode(
+ oldAggRel,
+ oldCall,
+ newCalls,
+ aggCallMapping,
+ rexBuilder,
+ argYOrdinal,
+ argXAndYNotNullFilterOrdinal);
+ final RexNode sumXSumY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumX, sumY);
+ final RexNode countArg =
+ getRegrCountRexNode(
+ oldAggRel,
+ oldCall,
+ newCalls,
+ aggCallMapping,
+ ImmutableIntList.of(argXOrdinal, argYOrdinal),
+ argXAndYNotNullFilterOrdinal);
+ final RexNode result = divide(biased, rexBuilder, sumXY, sumXSumY, countArg);
+ return rexBuilder.makeCast(oldCall.getType(), result);
+ }
+
+ private static RexNode divide(
+ boolean biased,
+ RexBuilder rexBuilder,
+ RexNode sumXY,
+ RexNode sumXSumY,
+ RexNode countArg) {
+ final RexNode avgSumSquaredArg =
+ rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumXSumY, countArg);
+ final RexNode diff =
+ rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumXY, avgSumSquaredArg);
+ final RexNode denominator;
+ if (biased) {
+ denominator = countArg;
+ } else {
+ final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
+ final RexNode nul = rexBuilder.makeNullLiteral(countArg.getType());
+ final RexNode countMinusOne =
+ rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one);
+ final RexNode countEqOne =
+ rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one);
+ denominator =
+ rexBuilder.makeCall(SqlStdOperatorTable.CASE, countEqOne, nul, countMinusOne);
+ }
+ return rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, diff, denominator);
+ }
+
+ /**
+ * Finds the ordinal of an element in a list, or adds it.
+ *
+ * @param list List
+ * @param element Element to lookup or add
+ * @param Element type
+ * @return Ordinal of element in list
+ */
+ private static int lookupOrAdd(List list, T element) {
+ int ordinal = list.indexOf(element);
+ if (ordinal == -1) {
+ ordinal = list.size();
+ list.add(element);
+ }
+ return ordinal;
+ }
+
+ /**
+ * Does a shallow clone of oldAggRel and updates aggCalls. Could be refactored into Aggregate
+ * and subclasses - but it's only needed for some subclasses.
+ *
+ * @param relBuilder Builder of relational expressions; at the top of its stack is its input
+ * @param oldAggregate LogicalAggregate to clone.
+ * @param newCalls New list of AggregateCalls
+ */
+ protected void newAggregateRel(
+ RelBuilder relBuilder, Aggregate oldAggregate, List newCalls) {
+ relBuilder.aggregate(
+ relBuilder.groupKey(oldAggregate.getGroupSet(), oldAggregate.getGroupSets()),
+ newCalls);
+ }
+
+ /**
+ * Adds a calculation with the expressions to compute the original aggregate calls from the
+ * decomposed ones.
+ *
+ * @param relBuilder Builder of relational expressions; at the top of its stack is its input
+ * @param rowType The output row type of the original aggregate.
+ * @param exprs The expressions to compute the original aggregate calls
+ */
+ protected void newCalcRel(RelBuilder relBuilder, RelDataType rowType, List exprs) {
+ relBuilder.project(exprs, rowType.getFieldNames());
+ }
+
+ /** Rule configuration. */
+ @Value.Immutable
+ public interface Config extends RelRule.Config {
+ Config DEFAULT =
+ ImmutableAggregateReduceFunctionsRule.Config.of()
+ .withOperandFor(LogicalAggregate.class);
+
+ Set DEFAULT_FUNCTIONS_TO_REDUCE =
+ ImmutableSet.builder()
+ .addAll(SqlKind.AVG_AGG_FUNCTIONS)
+ .addAll(SqlKind.COVAR_AVG_AGG_FUNCTIONS)
+ .add(SqlKind.SUM)
+ .build();
+
+ @Override
+ default AggregateReduceFunctionsRule toRule() {
+ return new AggregateReduceFunctionsRule(this);
+ }
+
+ /**
+ * The set of aggregate function types to try to reduce.
+ *
+ * Any aggregate function whose type is omitted from this set, OR which does not pass the
+ * {@link #extraCondition}, will be ignored.
+ */
+ @Nullable Set functionsToReduce();
+
+ /**
+ * A test that must pass before attempting to reduce any aggregate function.
+ *
+ * Any aggegate function which does not pass, OR whose type is omitted from {@link
+ * #functionsToReduce}, will be ignored. The default predicate always passes.
+ */
+ @Value.Default
+ default Predicate extraCondition() {
+ return ignored -> true;
+ }
+
+ /** Sets {@link #functionsToReduce}. */
+ Config withFunctionsToReduce(@Nullable Iterable functionSet);
+
+ default Config withFunctionsToReduce(@Nullable Set functionSet) {
+ return withFunctionsToReduce((Iterable) functionSet);
+ }
+
+ /** Sets {@link #extraCondition}. */
+ Config withExtraCondition(Predicate test);
+
+ /**
+ * Returns the validated set of functions to reduce, or the default set if not specified.
+ */
+ default Set actualFunctionsToReduce() {
+ final Set set = Util.first(functionsToReduce(), DEFAULT_FUNCTIONS_TO_REDUCE);
+ set.forEach(AggregateReduceFunctionsRule::validateFunction);
+ return set;
+ }
+
+ /** Defines an operand tree for the given classes. */
+ default Config withOperandFor(Class extends Aggregate> aggregateClass) {
+ return withOperandSupplier(b -> b.operand(aggregateClass).anyInputs()).as(Config.class);
+ }
+ }
+}
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceFunctionsRuleTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceFunctionsRuleTest.java
new file mode 100644
index 0000000000000..fc2325587592d
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceFunctionsRuleTest.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.planner.calcite.CalciteConfig;
+import org.apache.flink.table.planner.plan.optimize.program.BatchOptimizeContext;
+import org.apache.flink.table.planner.plan.optimize.program.FlinkBatchProgram;
+import org.apache.flink.table.planner.plan.optimize.program.FlinkHepRuleSetProgramBuilder;
+import org.apache.flink.table.planner.plan.optimize.program.HEP_RULES_EXECUTION_TYPE;
+import org.apache.flink.table.planner.utils.BatchTableTestUtil;
+import org.apache.flink.table.planner.utils.TableConfigUtils;
+import org.apache.flink.table.planner.utils.TableTestBase;
+
+import org.apache.calcite.plan.hep.HepMatchOrder;
+import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule;
+import org.apache.calcite.rel.rules.CoreRules;
+import org.apache.calcite.tools.RuleSets;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+/** Test for {@link AggregateReduceFunctionsRule}. */
+public class AggregateReduceFunctionsRuleTest extends TableTestBase {
+
+ private BatchTableTestUtil util;
+
+ @BeforeEach
+ void setup() {
+ util = batchTestUtil(TableConfig.getDefault());
+ util.buildBatchProgram(FlinkBatchProgram.DEFAULT_REWRITE());
+ CalciteConfig calciteConfig =
+ TableConfigUtils.getCalciteConfig(util.tableEnv().getConfig());
+ calciteConfig
+ .getBatchProgram()
+ .get()
+ .addLast(
+ "rules",
+ FlinkHepRuleSetProgramBuilder.newBuilder()
+ .setHepRulesExecutionType(
+ HEP_RULES_EXECUTION_TYPE.RULE_COLLECTION())
+ .setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
+ .add(RuleSets.ofList(CoreRules.AGGREGATE_REDUCE_FUNCTIONS))
+ .build());
+ util.tableEnv()
+ .executeSql(
+ "CREATE TABLE src (\n"
+ + " a VARCHAR,\n"
+ + " b BIGINT\n"
+ + ") WITH (\n"
+ + " 'connector' = 'values',\n"
+ + " 'bounded' = 'true'\n"
+ + ")");
+ }
+
+ @Test
+ void testVarianceStddevWithFilter() {
+ util.verifyRelPlan(
+ "SELECT a, \n"
+ + "STDDEV_POP(b) FILTER (WHERE b > 10), \n"
+ + "STDDEV_SAMP(b) FILTER (WHERE b > 20), \n"
+ + "VAR_POP(b) FILTER (WHERE b > 30), \n"
+ + "VAR_SAMP(b) FILTER (WHERE b > 40), \n"
+ + "AVG(b) FILTER (WHERE b > 50)\n"
+ + "FROM src GROUP BY a");
+ }
+}
diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceFunctionsRuleTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceFunctionsRuleTest.xml
new file mode 100644
index 0000000000000..c2fceebec56ec
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceFunctionsRuleTest.xml
@@ -0,0 +1,47 @@
+
+
+
+
+
+ 10),
+STDDEV_SAMP(b) FILTER (WHERE b > 20),
+VAR_POP(b) FILTER (WHERE b > 30),
+VAR_SAMP(b) FILTER (WHERE b > 40),
+AVG(b) FILTER (WHERE b > 50)
+FROM src GROUP BY a]]>
+
+
+ ($1, 10))], $f3=[IS TRUE(>($1, 20))], $f4=[IS TRUE(>($1, 30))], $f5=[IS TRUE(>($1, 40))], $f6=[IS TRUE(>($1, 50))])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src]])
+]]>
+
+
+ ($1, 10))], $f3=[IS TRUE(>($1, 20))], $f4=[IS TRUE(>($1, 30))], $f5=[IS TRUE(>($1, 40))], $f6=[IS TRUE(>($1, 50))])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src]])
+]]>
+
+
+