Skip to content

Commit

Permalink
feat: support parsing of SQL queries with APPLY (substrait-io#106)
Browse files Browse the repository at this point in the history
* feat: support parsing of SQL queries with APPLY

This change adds support for parsing SQL queries with APPLY (a join with a
correlated subquery) and for building the OuterReferences map of correlated
variables present in the query's join predicates. The OuterRefs will be used
while constructing Substrait plans to bind correlated variables. The change
also adds a few example queries which depend on the APPLY / LATERAL operators.

This change still does not map the Calcite correlated join to Substrait, as the
spec for APPLY is not yet approved. As such, while the parsing of Calcite
query plans will succeed after this change, the unit tests and runtime
conversion will continue to fail in the final step of building the
Substrait plan. Additional changes are needed to support APPLY.

Refs #substrait-io/substrait/issues/357

* fix: unit test cases to validate correlated vars

This change addresses review comments: the unit tests now validate the outer
reference map built from Calcite plans of APPLY queries.

* fix: add test for nested APPLY

This change addresses review comments. A new test case to validate
nested APPLY join parsing is added. Also added validation of depth
information in existing tests.
  • Loading branch information
ashvina authored Dec 2, 2022
1 parent ce9ac66 commit 179764e
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import java.util.Map;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rex.*;
Expand Down Expand Up @@ -48,6 +49,32 @@ public RelNode visit(LogicalFilter filter) throws RuntimeException {
return super.visit(filter);
}

@Override
public RelNode visit(LogicalCorrelate correlate) throws RuntimeException {
for (CorrelationId id : correlate.getVariablesSet()) {
if (!nestedDepth.containsKey(id)) {
nestedDepth.put(id, 0);
}
}

apply(correlate.getLeft());

// Correlated join is a special case. The right-rel is a correlated sub-query but not a REX. So,
// the RexVisitor cannot be applied to it to correctly compute the depth map. Hence, we need to
// manually compute the depth map for the right-rel.
for (Map.Entry<CorrelationId, Integer> entry : nestedDepth.entrySet()) {
nestedDepth.put(entry.getKey(), entry.getValue() + 1);
}

apply(correlate.getRight()); // look inside sub-queries

for (Map.Entry<CorrelationId, Integer> entry : nestedDepth.entrySet()) {
nestedDepth.put(entry.getKey(), entry.getValue() - 1);
}

return correlate;
}

@Override
public RelNode visitOther(RelNode other) throws RuntimeException {
for (RelNode child : other.getInputs()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ WHERE p_size <
AND PS.ps_suppkey = l.l_suppkey))
Filter --- $coor0
Filter --- $corr0
/ \ condition
/ p_size < RexSubquery
Scan(P) |
Expand All @@ -23,7 +23,7 @@ WHERE p_size <
|
Project
|
Filter --- $coor2
Filter --- $corr2
/ \
/ \
Scan (L) \
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package io.substrait.isthmus;

import io.substrait.function.ImmutableSimpleExtension;
import io.substrait.function.SimpleExtension;
import io.substrait.type.NamedStruct;
import java.io.IOException;
Expand Down Expand Up @@ -67,15 +66,17 @@ protected SqlConverterBase(FeatureBoard features) {
new ProxyingMetadataHandlerProvider(DefaultRelMetadataProvider.INSTANCE);
return new RelMetadataQuery(handler);
});
parserConfig = SqlParser.Config.DEFAULT.withParserFactory(SqlDdlParserImpl.FACTORY);
featureBoard = features == null ? FEATURES_DEFAULT : features;
parserConfig =
SqlParser.Config.DEFAULT
.withParserFactory(SqlDdlParserImpl.FACTORY)
.withConformance(featureBoard.sqlConformanceMode());
}

protected static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION;

static {
SimpleExtension.ExtensionCollection defaults =
ImmutableSimpleExtension.ExtensionCollection.builder().build();
SimpleExtension.ExtensionCollection defaults;
try {
defaults = SimpleExtension.loadDefaults();
} catch (IOException e) {
Expand Down
41 changes: 26 additions & 15 deletions isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.substrait.isthmus;

import com.google.common.annotations.VisibleForTesting;
import io.substrait.expression.proto.FunctionCollector;
import io.substrait.proto.Plan;
import io.substrait.proto.PlanRel;
Expand All @@ -13,6 +14,7 @@
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.schema.Schema;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.validate.SqlValidator;
Expand Down Expand Up @@ -95,6 +97,17 @@ private List<RelRoot> sqlToRelNode(
if (!featureBoard.allowsSqlBatch() && parsedList.size() > 1) {
throw new UnsupportedOperationException("SQL must contain only a single statement: " + sql);
}
SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader);
List<RelRoot> roots =
parsedList.stream()
.map(parsed -> getBestExpRelRoot(converter, parsed))
.collect(java.util.stream.Collectors.toList());
return roots;
}

@VisibleForTesting
SqlToRelConverter createSqlToRelConverter(
SqlValidator validator, CalciteCatalogReader catalogReader) {
SqlToRelConverter converter =
new SqlToRelConverter(
null,
Expand All @@ -103,20 +116,18 @@ private List<RelRoot> sqlToRelNode(
relOptCluster,
StandardConvertletTable.INSTANCE,
converterConfig);
List<RelRoot> roots =
parsedList.stream()
.map(
parsed -> {
RelRoot root = converter.convertQuery(parsed, true, true);
{
var program = HepProgram.builder().build();
HepPlanner hepPlanner = new HepPlanner(program);
hepPlanner.setRoot(root.rel);
root = root.withRel(hepPlanner.findBestExp());
}
return root;
})
.collect(java.util.stream.Collectors.toList());
return roots;
return converter;
}

@VisibleForTesting
static RelRoot getBestExpRelRoot(SqlToRelConverter converter, SqlNode parsed) {
RelRoot root = converter.convertQuery(parsed, true, true);
{
var program = HepProgram.builder().build();
HepPlanner hepPlanner = new HepPlanner(program);
hepPlanner.setRoot(root.rel);
root = root.withRel(hepPlanner.findBestExp());
}
return root;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
@SuppressWarnings("UnstableApiUsage")
@Value.Enclosing
public class SubstraitRelVisitor extends RelNodeVisitor<Rel, RuntimeException> {

static final org.slf4j.Logger logger =
org.slf4j.LoggerFactory.getLogger(SubstraitRelVisitor.class);
private static final FeatureBoard FEATURES_DEFAULT = ImmutableFeatureBoard.builder().build();
Expand Down Expand Up @@ -196,6 +197,19 @@ public Rel visit(LogicalJoin join) {

@Override
public Rel visit(LogicalCorrelate correlate) {
// left input of correlated-join is similar to the left input of a logical join
apply(correlate.getLeft());

// right input of correlated-join is similar to a correlated sub-query
apply(correlate.getRight());

var joinType =
switch (correlate.getJoinType()) {
case INNER -> Join.JoinType.INNER; // corresponds to CROSS APPLY join
case LEFT -> Join.JoinType.LEFT; // corresponds to OUTER APPLY join
default -> throw new IllegalArgumentException(
"Invalid correlated join type: " + correlate.getJoinType());
};
return super.visit(correlate);
}

Expand Down
173 changes: 173 additions & 0 deletions isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
package io.substrait.isthmus;

import java.util.Map;
import org.apache.calcite.adapter.tpcds.TpcdsSchema;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.validate.SqlConformanceEnum;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

public class ApplyJoinPlanTest {

private static RelRoot getCalcitePlan(SqlToSubstrait s, TpcdsSchema schema, String sql)
throws SqlParseException {
var pair = s.registerSchema("tpcds", schema);
var converter = s.createSqlToRelConverter(pair.left, pair.right);
SqlParser parser = SqlParser.create(sql, s.parserConfig);
var root = s.getBestExpRelRoot(converter, parser.parseQuery());
return root;
}

private static void validateOuterRef(
Map<RexFieldAccess, Integer> fieldAccessDepthMap, String refName, String colName, int depth) {
var entry =
fieldAccessDepthMap.entrySet().stream()
.filter(f -> f.getKey().getReferenceExpr().toString().equals(refName))
.filter(f -> f.getKey().getField().getName().equals(colName))
.filter(f -> f.getValue() == depth)
.findFirst();
Assertions.assertTrue(entry.isPresent());
}

private static Map<RexFieldAccess, Integer> buildOuterFieldRefMap(RelRoot root) {
final OuterReferenceResolver resolver = new OuterReferenceResolver();
var fieldAccessDepthMap = resolver.getFieldAccessDepthMap();
Assertions.assertEquals(0, fieldAccessDepthMap.size());
resolver.apply(root.rel);
return fieldAccessDepthMap;
}

@Test
public void lateralJoinQuery() throws SqlParseException {
TpcdsSchema schema = new TpcdsSchema(1.0);
String sql;
sql =
"""
SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk
FROM store_sales CROSS JOIN LATERAL
(select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk)""";

/* the calcite plan for the above query is:
LogicalProject(SS_SOLD_DATE_SK=[$0], SS_ITEM_SK=[$2], SS_CUSTOMER_SK=[$3])
LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{2}])
LogicalTableScan(table=[[tpcds, STORE_SALES]])
LogicalProject(I_ITEM_SK=[$0])
LogicalFilter(condition=[=($0, $cor0.SS_ITEM_SK)])
LogicalTableScan(table=[[tpcds, ITEM]])
*/

// validate outer reference map
RelRoot root = getCalcitePlan(new SqlToSubstrait(), schema, sql);
Map<RexFieldAccess, Integer> fieldAccessDepthMap = buildOuterFieldRefMap(root);
Assertions.assertEquals(1, fieldAccessDepthMap.size());
validateOuterRef(fieldAccessDepthMap, "$cor0", "SS_ITEM_SK", 1);

// TODO validate end to end conversion
var sE2E = new SqlToSubstrait();
Assertions.assertThrows(
UnsupportedOperationException.class,
() -> sE2E.execute(sql, "tpcds", schema),
"Lateral join is not supported");
}

@Test
public void outerApplyQuery() throws SqlParseException {
TpcdsSchema schema = new TpcdsSchema(1.0);
String sql;
sql =
"""
SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk
FROM store_sales OUTER APPLY
(select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk)""";

FeatureBoard featureBoard =
ImmutableFeatureBoard.builder()
.sqlConformanceMode(SqlConformanceEnum.SQL_SERVER_2008)
.build();
SqlToSubstrait s = new SqlToSubstrait(featureBoard);
RelRoot root = getCalcitePlan(s, schema, sql);

Map<RexFieldAccess, Integer> fieldAccessDepthMap = buildOuterFieldRefMap(root);
Assertions.assertEquals(1, fieldAccessDepthMap.size());
validateOuterRef(fieldAccessDepthMap, "$cor0", "SS_ITEM_SK", 1);

// TODO validate end to end conversion
Assertions.assertThrows(
UnsupportedOperationException.class,
() -> s.execute(sql, "tpcds", schema),
"APPLY is not supported");
}

@Test
public void nestedApplyJoinQuery() throws SqlParseException {
TpcdsSchema schema = new TpcdsSchema(1.0);
String sql;
sql =
"""
SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk
FROM store_sales CROSS APPLY
( SELECT i_item_sk
FROM item CROSS APPLY
( SELECT p_promo_sk
FROM promotion
WHERE p_item_sk = i_item_sk AND p_item_sk = ss_item_sk )
WHERE item.i_item_sk = store_sales.ss_item_sk )""";

/* the calcite plan for the above query is:
LogicalProject(SS_SOLD_DATE_SK=[$0], SS_ITEM_SK=[$2], SS_CUSTOMER_SK=[$3])
LogicalCorrelate(correlation=[$cor2], joinType=[inner], requiredColumns=[{2}])
LogicalTableScan(table=[[tpcds, STORE_SALES]])
LogicalProject(I_ITEM_SK=[$0])
LogicalFilter(condition=[=($0, $cor2.SS_ITEM_SK)])
LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}])
LogicalTableScan(table=[[tpcds, ITEM]])
LogicalProject(P_PROMO_SK=[$0])
LogicalFilter(condition=[AND(=($4, $cor0.I_ITEM_SK), =($4, $cor2.SS_ITEM_SK))])
LogicalTableScan(table=[[tpcds, PROMOTION]])
*/
FeatureBoard featureBoard =
ImmutableFeatureBoard.builder()
.sqlConformanceMode(SqlConformanceEnum.SQL_SERVER_2008)
.build();
SqlToSubstrait s = new SqlToSubstrait(featureBoard);
RelRoot root = getCalcitePlan(s, schema, sql);

Map<RexFieldAccess, Integer> fieldAccessDepthMap = buildOuterFieldRefMap(root);
Assertions.assertEquals(3, fieldAccessDepthMap.size());
validateOuterRef(fieldAccessDepthMap, "$cor2", "SS_ITEM_SK", 1);
validateOuterRef(fieldAccessDepthMap, "$cor2", "SS_ITEM_SK", 2);
validateOuterRef(fieldAccessDepthMap, "$cor0", "I_ITEM_SK", 1);

// TODO validate end to end conversion
Assertions.assertThrows(
UnsupportedOperationException.class,
() -> s.execute(sql, "tpcds", schema),
"APPLY is not supported");
}

@Test
public void crossApplyQuery() throws SqlParseException {
TpcdsSchema schema = new TpcdsSchema(1.0);
String sql;
sql =
"""
SELECT ss_sold_date_sk, ss_item_sk, ss_customer_sk
FROM store_sales CROSS APPLY
(select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk)""";

FeatureBoard featureBoard =
ImmutableFeatureBoard.builder()
.sqlConformanceMode(SqlConformanceEnum.SQL_SERVER_2008)
.build();
SqlToSubstrait s = new SqlToSubstrait(featureBoard);

// TODO validate end to end conversion
Assertions.assertThrows(
UnsupportedOperationException.class,
() -> s.execute(sql, "tpcds", schema),
"APPLY is not supported");
}
}

0 comments on commit 179764e

Please sign in to comment.