Skip to content

Commit 21af258

Browse files
authored
Try to eliminate redundant Project and Sort For right table of Join clause in some self-join cases
1 parent f8bcf2a commit 21af258

33 files changed

+713
-496
lines changed

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java

+5-10
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,16 @@
3737
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedSumAccumulator;
3838
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedVarianceAccumulator;
3939
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
40-
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference;
4140

4241
import org.apache.tsfile.enums.TSDataType;
4342

4443
import java.util.List;
4544
import java.util.Map;
4645

4746
import static com.google.common.base.Preconditions.checkState;
48-
import static org.apache.iotdb.commons.schema.table.TsTable.TIME_COLUMN_NAME;
4947
import static org.apache.iotdb.db.queryengine.plan.relational.metadata.TableBuiltinAggregationFunction.FIRST_BY;
5048
import static org.apache.iotdb.db.queryengine.plan.relational.metadata.TableBuiltinAggregationFunction.LAST_BY;
49+
import static org.apache.iotdb.db.queryengine.plan.relational.planner.ir.GlobalTimePredicateExtractVisitor.isTimeColumn;
5150

5251
public class AccumulatorFactory {
5352

@@ -57,7 +56,8 @@ public static TableAccumulator createAccumulator(
5756
List<TSDataType> inputDataTypes,
5857
List<Expression> inputExpressions,
5958
Map<String, String> inputAttributes,
60-
boolean ascending) {
59+
boolean ascending,
60+
String timeColumnName) {
6161
if (aggregationType == TAggregationType.UDAF) {
6262
// If UDAF accumulator receives raw input, it needs to check input's attribute
6363
throw new UnsupportedOperationException();
@@ -66,9 +66,9 @@ public static TableAccumulator createAccumulator(
6666
&& inputExpressions.size() > 1) {
6767
boolean xIsTimeColumn = false;
6868
boolean yIsTimeColumn = false;
69-
if (isTimeColumn(inputExpressions.get(1))) {
69+
if (isTimeColumn(inputExpressions.get(1), timeColumnName)) {
7070
yIsTimeColumn = true;
71-
} else if (isTimeColumn(inputExpressions.get(0))) {
71+
} else if (isTimeColumn(inputExpressions.get(0), timeColumnName)) {
7272
xIsTimeColumn = true;
7373
}
7474
if (LAST_BY.getFunctionName().equals(functionName)) {
@@ -326,9 +326,4 @@ private static TableAccumulator createBuiltinSingleInputAccumulator(
326326
public interface KeepEvaluator {
327327
boolean apply(long keep);
328328
}
329-
330-
public static boolean isTimeColumn(Expression expression) {
331-
return expression instanceof SymbolReference
332-
&& TIME_COLUMN_NAME.equals(((SymbolReference) expression).getName());
333-
}
334329
}

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/PredicateUtils.java

+5-4
Original file line numberDiff line numberDiff line change
@@ -309,16 +309,17 @@ public static Filter convertPredicateToFilter(
309309

310310
public static Filter convertPredicateToFilter(
311311
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression predicate,
312-
List<String> allMeasurements,
313-
Map<Symbol, ColumnSchema> schemaMap) {
312+
Map<String, Integer> measurementColumnsIndexMap,
313+
Map<Symbol, ColumnSchema> schemaMap,
314+
String timeColumnName) {
314315
if (predicate == null) {
315316
return null;
316317
}
317318
return predicate.accept(
318319
new org.apache.iotdb.db.queryengine.plan.relational.analyzer.predicate
319-
.ConvertPredicateToFilterVisitor(),
320+
.ConvertPredicateToFilterVisitor(timeColumnName),
320321
new org.apache.iotdb.db.queryengine.plan.relational.analyzer.predicate
321-
.ConvertPredicateToFilterVisitor.Context(allMeasurements, schemaMap));
322+
.ConvertPredicateToFilterVisitor.Context(measurementColumnsIndexMap, schemaMap));
322323
}
323324

324325
/**

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/memory/TableModelStatementMemorySourceVisitor.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,11 @@ public StatementMemorySource visitExplain(
8383
final TableDistributedPlanGenerator.PlanContext planContext =
8484
new TableDistributedPlanGenerator.PlanContext();
8585
final PlanNode outputNodeWithExchange =
86-
new TableDistributedPlanner(context.getAnalysis(), symbolAllocator, logicalPlan)
86+
new TableDistributedPlanner(
87+
context.getAnalysis(),
88+
symbolAllocator,
89+
logicalPlan,
90+
LocalExecutionPlanner.getInstance().metadata)
8791
.generateDistributedPlanWithOptimize(planContext);
8892

8993
final List<String> lines =

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java

+86-41
Original file line numberDiff line numberDiff line change
@@ -176,12 +176,12 @@
176176
import static com.google.common.collect.ImmutableList.toImmutableList;
177177
import static java.util.Objects.requireNonNull;
178178
import static org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory.MEASUREMENT;
179+
import static org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory.TIME;
179180
import static org.apache.iotdb.db.queryengine.common.DataNodeEndPoints.isSameNode;
180181
import static org.apache.iotdb.db.queryengine.execution.operator.process.join.merge.MergeSortComparator.getComparatorForTable;
181182
import static org.apache.iotdb.db.queryengine.execution.operator.source.relational.TableScanOperator.constructAlignedPath;
182183
import static org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AccumulatorFactory.createAccumulator;
183184
import static org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AccumulatorFactory.createGroupedAccumulator;
184-
import static org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AccumulatorFactory.isTimeColumn;
185185
import static org.apache.iotdb.db.queryengine.plan.analyze.PredicateUtils.convertPredicateToFilter;
186186
import static org.apache.iotdb.db.queryengine.plan.planner.OperatorTreeGenerator.ASC_TIME_COMPARATOR;
187187
import static org.apache.iotdb.db.queryengine.plan.planner.OperatorTreeGenerator.IDENTITY_FILL;
@@ -190,6 +190,7 @@
190190
import static org.apache.iotdb.db.queryengine.plan.planner.OperatorTreeGenerator.getPreviousFill;
191191
import static org.apache.iotdb.db.queryengine.plan.relational.metadata.TableBuiltinAggregationFunction.getAggregationTypeByFuncName;
192192
import static org.apache.iotdb.db.queryengine.plan.relational.planner.SortOrder.ASC_NULLS_LAST;
193+
import static org.apache.iotdb.db.queryengine.plan.relational.planner.ir.GlobalTimePredicateExtractVisitor.isTimeColumn;
193194
import static org.apache.iotdb.db.queryengine.plan.relational.type.InternalTypeManager.getTSDataType;
194195
import static org.apache.iotdb.db.utils.constant.SqlConstant.AVG;
195196
import static org.apache.iotdb.db.utils.constant.SqlConstant.COUNT;
@@ -309,6 +310,8 @@ public Operator visitTableScan(TableScanNode node, LocalExecutionPlanContext con
309310
Map<Symbol, ColumnSchema> columnSchemaMap = node.getAssignments();
310311
Map<Symbol, Integer> idAndAttributeColumnsIndexMap = node.getIdAndAttributeIndexMap();
311312
List<String> measurementColumnNames = new ArrayList<>();
313+
Map<String, Integer> measurementColumnsIndexMap = new HashMap<>();
314+
String timeColumnName = null;
312315
List<IMeasurementSchema> measurementSchemas = new ArrayList<>();
313316
int measurementColumnCount = 0;
314317
int idx = 0;
@@ -327,14 +330,16 @@ public Operator visitTableScan(TableScanNode node, LocalExecutionPlanContext con
327330
case MEASUREMENT:
328331
columnsIndexArray[idx++] = measurementColumnCount;
329332
measurementColumnCount++;
330-
measurementColumnNames.add(columnName.getName());
333+
measurementColumnNames.add(schema.getName());
331334
measurementSchemas.add(
332335
new MeasurementSchema(schema.getName(), getTSDataType(schema.getType())));
333336
columnSchemas.add(schema);
337+
measurementColumnsIndexMap.put(columnName.getName(), measurementColumnCount - 1);
334338
break;
335339
case TIME:
336340
columnsIndexArray[idx++] = -1;
337341
columnSchemas.add(schema);
342+
timeColumnName = columnName.getName();
338343
break;
339344
default:
340345
throw new IllegalArgumentException(
@@ -347,17 +352,20 @@ public Operator visitTableScan(TableScanNode node, LocalExecutionPlanContext con
347352
if (!outputSet.contains(entry.getKey())
348353
&& entry.getValue().getColumnCategory() == MEASUREMENT) {
349354
measurementColumnCount++;
350-
measurementColumnNames.add(entry.getKey().getName());
355+
measurementColumnNames.add(entry.getValue().getName());
351356
measurementSchemas.add(
352357
new MeasurementSchema(
353358
entry.getValue().getName(), getTSDataType(entry.getValue().getType())));
359+
measurementColumnsIndexMap.put(entry.getKey().getName(), measurementColumnCount - 1);
360+
} else if (entry.getValue().getColumnCategory() == TIME) {
361+
timeColumnName = entry.getKey().getName();
354362
}
355363
}
356364

357365
SeriesScanOptions.Builder scanOptionsBuilder =
358-
node.getTimePredicate()
359-
.map(timePredicate -> getSeriesScanOptionsBuilder(context, timePredicate))
360-
.orElse(new SeriesScanOptions.Builder());
366+
node.getTimePredicate().isPresent()
367+
? getSeriesScanOptionsBuilder(context, node.getTimePredicate().get())
368+
: new SeriesScanOptions.Builder();
361369
scanOptionsBuilder.withPushDownLimit(node.getPushDownLimit());
362370
scanOptionsBuilder.withPushDownOffset(node.getPushDownOffset());
363371
scanOptionsBuilder.withPushLimitToEachDevice(node.isPushLimitToEachDevice());
@@ -366,7 +374,8 @@ public Operator visitTableScan(TableScanNode node, LocalExecutionPlanContext con
366374
Expression pushDownPredicate = node.getPushDownPredicate();
367375
if (pushDownPredicate != null) {
368376
scanOptionsBuilder.withPushDownFilter(
369-
convertPredicateToFilter(pushDownPredicate, measurementColumnNames, columnSchemaMap));
377+
convertPredicateToFilter(
378+
pushDownPredicate, measurementColumnsIndexMap, columnSchemaMap, timeColumnName));
370379
}
371380

372381
OperatorContext operatorContext =
@@ -1178,19 +1187,38 @@ public Operator visitJoin(JoinNode node, LocalExecutionPlanContext context) {
11781187
Operator leftChild = node.getLeftChild().accept(this, context);
11791188
Operator rightChild = node.getRightChild().accept(this, context);
11801189

1181-
int leftTimeColumnPosition =
1182-
node.getLeftChild().getOutputSymbols().indexOf(node.getCriteria().get(0).getLeft());
1190+
ImmutableMap<Symbol, Integer> leftColumnNamesMap =
1191+
makeLayoutFromOutputSymbols(node.getLeftChild().getOutputSymbols());
1192+
Integer leftTimeColumnPosition = leftColumnNamesMap.get(node.getCriteria().get(0).getLeft());
1193+
if (leftTimeColumnPosition == null) {
1194+
throw new IllegalStateException("Left child of JoinNode doesn't contain time column");
1195+
}
11831196
int[] leftOutputSymbolIdx = new int[node.getLeftOutputSymbols().size()];
11841197
for (int i = 0; i < leftOutputSymbolIdx.length; i++) {
1185-
leftOutputSymbolIdx[i] =
1186-
node.getLeftChild().getOutputSymbols().indexOf(node.getLeftOutputSymbols().get(i));
1198+
Integer index = leftColumnNamesMap.get(node.getLeftOutputSymbols().get(i));
1199+
if (index == null) {
1200+
throw new IllegalStateException(
1201+
"Left child of JoinNode doesn't contain LeftOutputSymbol "
1202+
+ node.getLeftOutputSymbols().get(i));
1203+
}
1204+
leftOutputSymbolIdx[i] = index;
1205+
}
1206+
1207+
ImmutableMap<Symbol, Integer> rightColumnNamesMap =
1208+
makeLayoutFromOutputSymbols(node.getRightChild().getOutputSymbols());
1209+
Integer rightTimeColumnPosition = rightColumnNamesMap.get(node.getCriteria().get(0).getRight());
1210+
if (rightTimeColumnPosition == null) {
1211+
throw new IllegalStateException("Right child of JoinNode doesn't contain time column");
11871212
}
1188-
int rightTimeColumnPosition =
1189-
node.getRightChild().getOutputSymbols().indexOf(node.getCriteria().get(0).getRight());
11901213
int[] rightOutputSymbolIdx = new int[node.getRightOutputSymbols().size()];
11911214
for (int i = 0; i < rightOutputSymbolIdx.length; i++) {
1192-
rightOutputSymbolIdx[i] =
1193-
node.getRightChild().getOutputSymbols().indexOf(node.getRightOutputSymbols().get(i));
1215+
Integer index = rightColumnNamesMap.get(node.getRightOutputSymbols().get(i));
1216+
if (index == null) {
1217+
throw new IllegalStateException(
1218+
"Right child of JoinNode doesn't contain RightOutputSymbol "
1219+
+ node.getLeftOutputSymbols().get(i));
1220+
}
1221+
rightOutputSymbolIdx[i] = index;
11941222
}
11951223

11961224
if (requireNonNull(node.getJoinType()) == JoinNode.JoinType.INNER) {
@@ -1364,17 +1392,20 @@ private Operator planGlobalAggregation(
13641392
aggregationMap.get(symbol),
13651393
node.getStep(),
13661394
typeProvider,
1367-
true)));
1395+
true,
1396+
null)));
13681397
return new AggregationOperator(context, child, aggregatorBuilder.build());
13691398
}
13701399

1400+
// timeColumnName will only be set for AggTableScan.
13711401
private TableAggregator buildAggregator(
13721402
Map<Symbol, Integer> childLayout,
13731403
Symbol symbol,
13741404
AggregationNode.Aggregation aggregation,
13751405
AggregationNode.Step step,
13761406
TypeProvider typeProvider,
1377-
boolean scanAscending) {
1407+
boolean scanAscending,
1408+
String timeColumnName) {
13781409
List<Integer> argumentChannels = new ArrayList<>();
13791410
for (Expression argument : aggregation.getArguments()) {
13801411
Symbol argumentSymbol = Symbol.from(argument);
@@ -1393,7 +1424,8 @@ private TableAggregator buildAggregator(
13931424
originalArgumentTypes,
13941425
aggregation.getArguments(),
13951426
Collections.emptyMap(),
1396-
scanAscending);
1427+
scanAscending,
1428+
timeColumnName);
13971429

13981430
return new TableAggregator(
13991431
accumulator,
@@ -1424,7 +1456,8 @@ private Operator planGroupByAggregation(
14241456
.forEach(
14251457
(k, v) ->
14261458
aggregatorBuilder.add(
1427-
buildAggregator(childLayout, k, v, node.getStep(), typeProvider, true)));
1459+
buildAggregator(
1460+
childLayout, k, v, node.getStep(), typeProvider, true, null)));
14281461

14291462
return new StreamingAggregationOperator(
14301463
operatorContext,
@@ -1563,9 +1596,6 @@ public Operator visitAggregationTableScan(
15631596
List<TableAggregator> aggregators = new ArrayList<>(node.getAggregations().size());
15641597
Map<Symbol, Integer> columnLayout = new HashMap<>(node.getAggregations().size());
15651598

1566-
boolean[] ret = checkStatisticAndScanOrder(node);
1567-
boolean canUseStatistic = ret[0];
1568-
boolean scanAscending = ret[1];
15691599
int distinctArgumentCount = node.getAssignments().size();
15701600
int aggregationsCount = node.getAggregations().size();
15711601
List<Integer> aggColumnIndexes = new ArrayList<>();
@@ -1577,6 +1607,8 @@ public Operator visitAggregationTableScan(
15771607
List<ColumnSchema> columnSchemas = new ArrayList<>(aggregationsCount);
15781608
int[] columnsIndexArray = new int[distinctArgumentCount];
15791609
List<String> measurementColumnNames = new ArrayList<>();
1610+
Map<String, Integer> measurementColumnsIndexMap = new HashMap<>();
1611+
String timeColumnName = null;
15801612
List<IMeasurementSchema> measurementSchemas = new ArrayList<>();
15811613

15821614
for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
@@ -1599,16 +1631,18 @@ public Operator visitAggregationTableScan(
15991631
if (!columnLayout.containsKey(symbol)) {
16001632
columnsIndexArray[channel] = measurementColumnCount;
16011633
measurementColumnCount++;
1602-
measurementColumnNames.add(symbol.getName());
1634+
measurementColumnNames.add(schema.getName());
16031635
measurementSchemas.add(
16041636
new MeasurementSchema(schema.getName(), getTSDataType(schema.getType())));
16051637
columnSchemas.add(schema);
1638+
measurementColumnsIndexMap.put(symbol.getName(), measurementColumnCount - 1);
16061639
}
16071640
break;
16081641
case TIME:
16091642
if (!columnLayout.containsKey(symbol)) {
16101643
columnsIndexArray[channel] = -1;
16111644
columnSchemas.add(schema);
1645+
timeColumnName = symbol.getName();
16121646
}
16131647
break;
16141648
default:
@@ -1623,29 +1657,38 @@ public Operator visitAggregationTableScan(
16231657
aggColumnIndexes.add(columnLayout.get(symbol));
16241658
}
16251659
}
1626-
1627-
aggregators.add(
1628-
buildAggregator(
1629-
columnLayout,
1630-
entry.getKey(),
1631-
entry.getValue(),
1632-
node.getStep(),
1633-
context.getTypeProvider(),
1634-
scanAscending));
16351660
}
16361661

1637-
// TODO if this needed?
16381662
for (Map.Entry<Symbol, ColumnSchema> entry : node.getAssignments().entrySet()) {
16391663
if (!columnLayout.containsKey(entry.getKey())
16401664
&& entry.getValue().getColumnCategory() == MEASUREMENT) {
16411665
measurementColumnCount++;
1642-
measurementColumnNames.add(entry.getKey().getName());
1666+
measurementColumnNames.add(entry.getValue().getName());
16431667
measurementSchemas.add(
16441668
new MeasurementSchema(
16451669
entry.getValue().getName(), getTSDataType(entry.getValue().getType())));
1670+
measurementColumnsIndexMap.put(entry.getKey().getName(), measurementColumnCount - 1);
1671+
} else if (entry.getValue().getColumnCategory() == TIME) {
1672+
timeColumnName = entry.getKey().getName();
16461673
}
16471674
}
16481675

1676+
boolean[] ret = checkStatisticAndScanOrder(node, timeColumnName);
1677+
boolean canUseStatistic = ret[0];
1678+
boolean scanAscending = ret[1];
1679+
1680+
for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
1681+
aggregators.add(
1682+
buildAggregator(
1683+
columnLayout,
1684+
entry.getKey(),
1685+
entry.getValue(),
1686+
node.getStep(),
1687+
context.getTypeProvider(),
1688+
scanAscending,
1689+
timeColumnName));
1690+
}
1691+
16491692
ITableTimeRangeIterator timeRangeIterator = null;
16501693
List<ColumnSchema> groupingKeySchemas = null;
16511694
int[] groupingKeyIndex = null;
@@ -1699,17 +1742,18 @@ public Operator visitAggregationTableScan(
16991742
node.getPlanNodeId(),
17001743
AggregationTableScanNode.class.getSimpleName());
17011744
SeriesScanOptions.Builder scanOptionsBuilder =
1702-
node.getTimePredicate()
1703-
.map(timePredicate -> getSeriesScanOptionsBuilder(context, timePredicate))
1704-
.orElse(new SeriesScanOptions.Builder());
1745+
node.getTimePredicate().isPresent()
1746+
? getSeriesScanOptionsBuilder(context, node.getTimePredicate().get())
1747+
: new SeriesScanOptions.Builder();
17051748
scanOptionsBuilder.withPushDownLimit(node.getPushDownLimit());
17061749
scanOptionsBuilder.withPushDownOffset(node.getPushDownOffset());
17071750
scanOptionsBuilder.withPushLimitToEachDevice(node.isPushLimitToEachDevice());
17081751
scanOptionsBuilder.withAllSensors(new HashSet<>(measurementColumnNames));
17091752
Expression pushDownPredicate = node.getPushDownPredicate();
17101753
if (pushDownPredicate != null) {
17111754
scanOptionsBuilder.withPushDownFilter(
1712-
convertPredicateToFilter(pushDownPredicate, measurementColumnNames, columnSchemaMap));
1755+
convertPredicateToFilter(
1756+
pushDownPredicate, measurementColumnsIndexMap, columnSchemaMap, timeColumnName));
17131757
}
17141758

17151759
Set<String> allSensors = new HashSet<>(measurementColumnNames);
@@ -1755,7 +1799,8 @@ public Operator visitAggregationTableScan(
17551799
return aggTableScanOperator;
17561800
}
17571801

1758-
private boolean[] checkStatisticAndScanOrder(AggregationTableScanNode node) {
1802+
private boolean[] checkStatisticAndScanOrder(
1803+
AggregationTableScanNode node, String timeColumnName) {
17591804
boolean canUseStatistic = true;
17601805
int ascendingCount = 0, descendingCount = 0;
17611806

@@ -1797,8 +1842,8 @@ private boolean[] checkStatisticAndScanOrder(AggregationTableScanNode node) {
17971842

17981843
// only last_by(time, x) or last_by(x,time) can use statistic
17991844
if ((LAST_BY_AGGREGATION.equals(funcName) || FIRST_BY_AGGREGATION.equals(funcName))
1800-
&& !isTimeColumn(aggregation.getArguments().get(0))
1801-
&& !isTimeColumn(aggregation.getArguments().get(1))) {
1845+
&& !isTimeColumn(aggregation.getArguments().get(0), timeColumnName)
1846+
&& !isTimeColumn(aggregation.getArguments().get(1), timeColumnName)) {
18021847
canUseStatistic = false;
18031848
}
18041849
break;

0 commit comments

Comments
 (0)