176
176
import static com .google .common .collect .ImmutableList .toImmutableList ;
177
177
import static java .util .Objects .requireNonNull ;
178
178
import static org .apache .iotdb .commons .schema .table .column .TsTableColumnCategory .MEASUREMENT ;
179
+ import static org .apache .iotdb .commons .schema .table .column .TsTableColumnCategory .TIME ;
179
180
import static org .apache .iotdb .db .queryengine .common .DataNodeEndPoints .isSameNode ;
180
181
import static org .apache .iotdb .db .queryengine .execution .operator .process .join .merge .MergeSortComparator .getComparatorForTable ;
181
182
import static org .apache .iotdb .db .queryengine .execution .operator .source .relational .TableScanOperator .constructAlignedPath ;
182
183
import static org .apache .iotdb .db .queryengine .execution .operator .source .relational .aggregation .AccumulatorFactory .createAccumulator ;
183
184
import static org .apache .iotdb .db .queryengine .execution .operator .source .relational .aggregation .AccumulatorFactory .createGroupedAccumulator ;
184
- import static org .apache .iotdb .db .queryengine .execution .operator .source .relational .aggregation .AccumulatorFactory .isTimeColumn ;
185
185
import static org .apache .iotdb .db .queryengine .plan .analyze .PredicateUtils .convertPredicateToFilter ;
186
186
import static org .apache .iotdb .db .queryengine .plan .planner .OperatorTreeGenerator .ASC_TIME_COMPARATOR ;
187
187
import static org .apache .iotdb .db .queryengine .plan .planner .OperatorTreeGenerator .IDENTITY_FILL ;
190
190
import static org .apache .iotdb .db .queryengine .plan .planner .OperatorTreeGenerator .getPreviousFill ;
191
191
import static org .apache .iotdb .db .queryengine .plan .relational .metadata .TableBuiltinAggregationFunction .getAggregationTypeByFuncName ;
192
192
import static org .apache .iotdb .db .queryengine .plan .relational .planner .SortOrder .ASC_NULLS_LAST ;
193
+ import static org .apache .iotdb .db .queryengine .plan .relational .planner .ir .GlobalTimePredicateExtractVisitor .isTimeColumn ;
193
194
import static org .apache .iotdb .db .queryengine .plan .relational .type .InternalTypeManager .getTSDataType ;
194
195
import static org .apache .iotdb .db .utils .constant .SqlConstant .AVG ;
195
196
import static org .apache .iotdb .db .utils .constant .SqlConstant .COUNT ;
@@ -309,6 +310,8 @@ public Operator visitTableScan(TableScanNode node, LocalExecutionPlanContext con
309
310
Map <Symbol , ColumnSchema > columnSchemaMap = node .getAssignments ();
310
311
Map <Symbol , Integer > idAndAttributeColumnsIndexMap = node .getIdAndAttributeIndexMap ();
311
312
List <String > measurementColumnNames = new ArrayList <>();
313
+ Map <String , Integer > measurementColumnsIndexMap = new HashMap <>();
314
+ String timeColumnName = null ;
312
315
List <IMeasurementSchema > measurementSchemas = new ArrayList <>();
313
316
int measurementColumnCount = 0 ;
314
317
int idx = 0 ;
@@ -327,14 +330,16 @@ public Operator visitTableScan(TableScanNode node, LocalExecutionPlanContext con
327
330
case MEASUREMENT :
328
331
columnsIndexArray [idx ++] = measurementColumnCount ;
329
332
measurementColumnCount ++;
330
- measurementColumnNames .add (columnName .getName ());
333
+ measurementColumnNames .add (schema .getName ());
331
334
measurementSchemas .add (
332
335
new MeasurementSchema (schema .getName (), getTSDataType (schema .getType ())));
333
336
columnSchemas .add (schema );
337
+ measurementColumnsIndexMap .put (columnName .getName (), measurementColumnCount - 1 );
334
338
break ;
335
339
case TIME :
336
340
columnsIndexArray [idx ++] = -1 ;
337
341
columnSchemas .add (schema );
342
+ timeColumnName = columnName .getName ();
338
343
break ;
339
344
default :
340
345
throw new IllegalArgumentException (
@@ -347,17 +352,20 @@ public Operator visitTableScan(TableScanNode node, LocalExecutionPlanContext con
347
352
if (!outputSet .contains (entry .getKey ())
348
353
&& entry .getValue ().getColumnCategory () == MEASUREMENT ) {
349
354
measurementColumnCount ++;
350
- measurementColumnNames .add (entry .getKey ().getName ());
355
+ measurementColumnNames .add (entry .getValue ().getName ());
351
356
measurementSchemas .add (
352
357
new MeasurementSchema (
353
358
entry .getValue ().getName (), getTSDataType (entry .getValue ().getType ())));
359
+ measurementColumnsIndexMap .put (entry .getKey ().getName (), measurementColumnCount - 1 );
360
+ } else if (entry .getValue ().getColumnCategory () == TIME ) {
361
+ timeColumnName = entry .getKey ().getName ();
354
362
}
355
363
}
356
364
357
365
SeriesScanOptions .Builder scanOptionsBuilder =
358
- node .getTimePredicate ()
359
- . map ( timePredicate -> getSeriesScanOptionsBuilder (context , timePredicate ))
360
- . orElse ( new SeriesScanOptions .Builder () );
366
+ node .getTimePredicate (). isPresent ()
367
+ ? getSeriesScanOptionsBuilder (context , node . getTimePredicate (). get ( ))
368
+ : new SeriesScanOptions .Builder ();
361
369
scanOptionsBuilder .withPushDownLimit (node .getPushDownLimit ());
362
370
scanOptionsBuilder .withPushDownOffset (node .getPushDownOffset ());
363
371
scanOptionsBuilder .withPushLimitToEachDevice (node .isPushLimitToEachDevice ());
@@ -366,7 +374,8 @@ public Operator visitTableScan(TableScanNode node, LocalExecutionPlanContext con
366
374
Expression pushDownPredicate = node .getPushDownPredicate ();
367
375
if (pushDownPredicate != null ) {
368
376
scanOptionsBuilder .withPushDownFilter (
369
- convertPredicateToFilter (pushDownPredicate , measurementColumnNames , columnSchemaMap ));
377
+ convertPredicateToFilter (
378
+ pushDownPredicate , measurementColumnsIndexMap , columnSchemaMap , timeColumnName ));
370
379
}
371
380
372
381
OperatorContext operatorContext =
@@ -1178,19 +1187,38 @@ public Operator visitJoin(JoinNode node, LocalExecutionPlanContext context) {
1178
1187
Operator leftChild = node .getLeftChild ().accept (this , context );
1179
1188
Operator rightChild = node .getRightChild ().accept (this , context );
1180
1189
1181
- int leftTimeColumnPosition =
1182
- node .getLeftChild ().getOutputSymbols ().indexOf (node .getCriteria ().get (0 ).getLeft ());
1190
+ ImmutableMap <Symbol , Integer > leftColumnNamesMap =
1191
+ makeLayoutFromOutputSymbols (node .getLeftChild ().getOutputSymbols ());
1192
+ Integer leftTimeColumnPosition = leftColumnNamesMap .get (node .getCriteria ().get (0 ).getLeft ());
1193
+ if (leftTimeColumnPosition == null ) {
1194
+ throw new IllegalStateException ("Left child of JoinNode doesn't contain time column" );
1195
+ }
1183
1196
int [] leftOutputSymbolIdx = new int [node .getLeftOutputSymbols ().size ()];
1184
1197
for (int i = 0 ; i < leftOutputSymbolIdx .length ; i ++) {
1185
- leftOutputSymbolIdx [i ] =
1186
- node .getLeftChild ().getOutputSymbols ().indexOf (node .getLeftOutputSymbols ().get (i ));
1198
+ Integer index = leftColumnNamesMap .get (node .getLeftOutputSymbols ().get (i ));
1199
+ if (index == null ) {
1200
+ throw new IllegalStateException (
1201
+ "Left child of JoinNode doesn't contain LeftOutputSymbol "
1202
+ + node .getLeftOutputSymbols ().get (i ));
1203
+ }
1204
+ leftOutputSymbolIdx [i ] = index ;
1205
+ }
1206
+
1207
+ ImmutableMap <Symbol , Integer > rightColumnNamesMap =
1208
+ makeLayoutFromOutputSymbols (node .getRightChild ().getOutputSymbols ());
1209
+ Integer rightTimeColumnPosition = rightColumnNamesMap .get (node .getCriteria ().get (0 ).getRight ());
1210
+ if (rightTimeColumnPosition == null ) {
1211
+ throw new IllegalStateException ("Right child of JoinNode doesn't contain time column" );
1187
1212
}
1188
- int rightTimeColumnPosition =
1189
- node .getRightChild ().getOutputSymbols ().indexOf (node .getCriteria ().get (0 ).getRight ());
1190
1213
int [] rightOutputSymbolIdx = new int [node .getRightOutputSymbols ().size ()];
1191
1214
for (int i = 0 ; i < rightOutputSymbolIdx .length ; i ++) {
1192
- rightOutputSymbolIdx [i ] =
1193
- node .getRightChild ().getOutputSymbols ().indexOf (node .getRightOutputSymbols ().get (i ));
1215
+ Integer index = rightColumnNamesMap .get (node .getRightOutputSymbols ().get (i ));
1216
+ if (index == null ) {
1217
+ throw new IllegalStateException (
1218
+ "Right child of JoinNode doesn't contain RightOutputSymbol "
1219
+ + node .getLeftOutputSymbols ().get (i ));
1220
+ }
1221
+ rightOutputSymbolIdx [i ] = index ;
1194
1222
}
1195
1223
1196
1224
if (requireNonNull (node .getJoinType ()) == JoinNode .JoinType .INNER ) {
@@ -1364,17 +1392,20 @@ private Operator planGlobalAggregation(
1364
1392
aggregationMap .get (symbol ),
1365
1393
node .getStep (),
1366
1394
typeProvider ,
1367
- true )));
1395
+ true ,
1396
+ null )));
1368
1397
return new AggregationOperator (context , child , aggregatorBuilder .build ());
1369
1398
}
1370
1399
1400
+ // timeColumnName will only be set for AggTableScan.
1371
1401
private TableAggregator buildAggregator (
1372
1402
Map <Symbol , Integer > childLayout ,
1373
1403
Symbol symbol ,
1374
1404
AggregationNode .Aggregation aggregation ,
1375
1405
AggregationNode .Step step ,
1376
1406
TypeProvider typeProvider ,
1377
- boolean scanAscending ) {
1407
+ boolean scanAscending ,
1408
+ String timeColumnName ) {
1378
1409
List <Integer > argumentChannels = new ArrayList <>();
1379
1410
for (Expression argument : aggregation .getArguments ()) {
1380
1411
Symbol argumentSymbol = Symbol .from (argument );
@@ -1393,7 +1424,8 @@ private TableAggregator buildAggregator(
1393
1424
originalArgumentTypes ,
1394
1425
aggregation .getArguments (),
1395
1426
Collections .emptyMap (),
1396
- scanAscending );
1427
+ scanAscending ,
1428
+ timeColumnName );
1397
1429
1398
1430
return new TableAggregator (
1399
1431
accumulator ,
@@ -1424,7 +1456,8 @@ private Operator planGroupByAggregation(
1424
1456
.forEach (
1425
1457
(k , v ) ->
1426
1458
aggregatorBuilder .add (
1427
- buildAggregator (childLayout , k , v , node .getStep (), typeProvider , true )));
1459
+ buildAggregator (
1460
+ childLayout , k , v , node .getStep (), typeProvider , true , null )));
1428
1461
1429
1462
return new StreamingAggregationOperator (
1430
1463
operatorContext ,
@@ -1563,9 +1596,6 @@ public Operator visitAggregationTableScan(
1563
1596
List <TableAggregator > aggregators = new ArrayList <>(node .getAggregations ().size ());
1564
1597
Map <Symbol , Integer > columnLayout = new HashMap <>(node .getAggregations ().size ());
1565
1598
1566
- boolean [] ret = checkStatisticAndScanOrder (node );
1567
- boolean canUseStatistic = ret [0 ];
1568
- boolean scanAscending = ret [1 ];
1569
1599
int distinctArgumentCount = node .getAssignments ().size ();
1570
1600
int aggregationsCount = node .getAggregations ().size ();
1571
1601
List <Integer > aggColumnIndexes = new ArrayList <>();
@@ -1577,6 +1607,8 @@ public Operator visitAggregationTableScan(
1577
1607
List <ColumnSchema > columnSchemas = new ArrayList <>(aggregationsCount );
1578
1608
int [] columnsIndexArray = new int [distinctArgumentCount ];
1579
1609
List <String > measurementColumnNames = new ArrayList <>();
1610
+ Map <String , Integer > measurementColumnsIndexMap = new HashMap <>();
1611
+ String timeColumnName = null ;
1580
1612
List <IMeasurementSchema > measurementSchemas = new ArrayList <>();
1581
1613
1582
1614
for (Map .Entry <Symbol , AggregationNode .Aggregation > entry : node .getAggregations ().entrySet ()) {
@@ -1599,16 +1631,18 @@ public Operator visitAggregationTableScan(
1599
1631
if (!columnLayout .containsKey (symbol )) {
1600
1632
columnsIndexArray [channel ] = measurementColumnCount ;
1601
1633
measurementColumnCount ++;
1602
- measurementColumnNames .add (symbol .getName ());
1634
+ measurementColumnNames .add (schema .getName ());
1603
1635
measurementSchemas .add (
1604
1636
new MeasurementSchema (schema .getName (), getTSDataType (schema .getType ())));
1605
1637
columnSchemas .add (schema );
1638
+ measurementColumnsIndexMap .put (symbol .getName (), measurementColumnCount - 1 );
1606
1639
}
1607
1640
break ;
1608
1641
case TIME :
1609
1642
if (!columnLayout .containsKey (symbol )) {
1610
1643
columnsIndexArray [channel ] = -1 ;
1611
1644
columnSchemas .add (schema );
1645
+ timeColumnName = symbol .getName ();
1612
1646
}
1613
1647
break ;
1614
1648
default :
@@ -1623,29 +1657,38 @@ public Operator visitAggregationTableScan(
1623
1657
aggColumnIndexes .add (columnLayout .get (symbol ));
1624
1658
}
1625
1659
}
1626
-
1627
- aggregators .add (
1628
- buildAggregator (
1629
- columnLayout ,
1630
- entry .getKey (),
1631
- entry .getValue (),
1632
- node .getStep (),
1633
- context .getTypeProvider (),
1634
- scanAscending ));
1635
1660
}
1636
1661
1637
- // TODO if this needed?
1638
1662
for (Map .Entry <Symbol , ColumnSchema > entry : node .getAssignments ().entrySet ()) {
1639
1663
if (!columnLayout .containsKey (entry .getKey ())
1640
1664
&& entry .getValue ().getColumnCategory () == MEASUREMENT ) {
1641
1665
measurementColumnCount ++;
1642
- measurementColumnNames .add (entry .getKey ().getName ());
1666
+ measurementColumnNames .add (entry .getValue ().getName ());
1643
1667
measurementSchemas .add (
1644
1668
new MeasurementSchema (
1645
1669
entry .getValue ().getName (), getTSDataType (entry .getValue ().getType ())));
1670
+ measurementColumnsIndexMap .put (entry .getKey ().getName (), measurementColumnCount - 1 );
1671
+ } else if (entry .getValue ().getColumnCategory () == TIME ) {
1672
+ timeColumnName = entry .getKey ().getName ();
1646
1673
}
1647
1674
}
1648
1675
1676
+ boolean [] ret = checkStatisticAndScanOrder (node , timeColumnName );
1677
+ boolean canUseStatistic = ret [0 ];
1678
+ boolean scanAscending = ret [1 ];
1679
+
1680
+ for (Map .Entry <Symbol , AggregationNode .Aggregation > entry : node .getAggregations ().entrySet ()) {
1681
+ aggregators .add (
1682
+ buildAggregator (
1683
+ columnLayout ,
1684
+ entry .getKey (),
1685
+ entry .getValue (),
1686
+ node .getStep (),
1687
+ context .getTypeProvider (),
1688
+ scanAscending ,
1689
+ timeColumnName ));
1690
+ }
1691
+
1649
1692
ITableTimeRangeIterator timeRangeIterator = null ;
1650
1693
List <ColumnSchema > groupingKeySchemas = null ;
1651
1694
int [] groupingKeyIndex = null ;
@@ -1699,17 +1742,18 @@ public Operator visitAggregationTableScan(
1699
1742
node .getPlanNodeId (),
1700
1743
AggregationTableScanNode .class .getSimpleName ());
1701
1744
SeriesScanOptions .Builder scanOptionsBuilder =
1702
- node .getTimePredicate ()
1703
- . map ( timePredicate -> getSeriesScanOptionsBuilder (context , timePredicate ))
1704
- . orElse ( new SeriesScanOptions .Builder () );
1745
+ node .getTimePredicate (). isPresent ()
1746
+ ? getSeriesScanOptionsBuilder (context , node . getTimePredicate (). get ( ))
1747
+ : new SeriesScanOptions .Builder ();
1705
1748
scanOptionsBuilder .withPushDownLimit (node .getPushDownLimit ());
1706
1749
scanOptionsBuilder .withPushDownOffset (node .getPushDownOffset ());
1707
1750
scanOptionsBuilder .withPushLimitToEachDevice (node .isPushLimitToEachDevice ());
1708
1751
scanOptionsBuilder .withAllSensors (new HashSet <>(measurementColumnNames ));
1709
1752
Expression pushDownPredicate = node .getPushDownPredicate ();
1710
1753
if (pushDownPredicate != null ) {
1711
1754
scanOptionsBuilder .withPushDownFilter (
1712
- convertPredicateToFilter (pushDownPredicate , measurementColumnNames , columnSchemaMap ));
1755
+ convertPredicateToFilter (
1756
+ pushDownPredicate , measurementColumnsIndexMap , columnSchemaMap , timeColumnName ));
1713
1757
}
1714
1758
1715
1759
Set <String > allSensors = new HashSet <>(measurementColumnNames );
@@ -1755,7 +1799,8 @@ public Operator visitAggregationTableScan(
1755
1799
return aggTableScanOperator ;
1756
1800
}
1757
1801
1758
- private boolean [] checkStatisticAndScanOrder (AggregationTableScanNode node ) {
1802
+ private boolean [] checkStatisticAndScanOrder (
1803
+ AggregationTableScanNode node , String timeColumnName ) {
1759
1804
boolean canUseStatistic = true ;
1760
1805
int ascendingCount = 0 , descendingCount = 0 ;
1761
1806
@@ -1797,8 +1842,8 @@ private boolean[] checkStatisticAndScanOrder(AggregationTableScanNode node) {
1797
1842
1798
1843
// only last_by(time, x) or last_by(x,time) can use statistic
1799
1844
if ((LAST_BY_AGGREGATION .equals (funcName ) || FIRST_BY_AGGREGATION .equals (funcName ))
1800
- && !isTimeColumn (aggregation .getArguments ().get (0 ))
1801
- && !isTimeColumn (aggregation .getArguments ().get (1 ))) {
1845
+ && !isTimeColumn (aggregation .getArguments ().get (0 ), timeColumnName )
1846
+ && !isTimeColumn (aggregation .getArguments ().get (1 ), timeColumnName )) {
1802
1847
canUseStatistic = false ;
1803
1848
}
1804
1849
break ;
0 commit comments