From 333989ffa70efc18c6707c3cf670bac826925fe1 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Tue, 24 Oct 2023 12:10:01 -0500 Subject: [PATCH 01/35] feat: convert sql expression into proto extended expressions --- .../extension/ExtensionCollector.java | 49 +++++ .../io/substrait/isthmus/SqlToSubstrait.java | 196 ++++++++++++++++++ .../isthmus/ExtendedExpressionTestBase.java | 51 +++++ .../SimpleExtendedExpressionsTest.java | 13 ++ 4 files changed, 309 insertions(+) create mode 100644 isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java create mode 100644 isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java diff --git a/core/src/main/java/io/substrait/extension/ExtensionCollector.java b/core/src/main/java/io/substrait/extension/ExtensionCollector.java index bcdd969d4..714eaec93 100644 --- a/core/src/main/java/io/substrait/extension/ExtensionCollector.java +++ b/core/src/main/java/io/substrait/extension/ExtensionCollector.java @@ -1,5 +1,6 @@ package io.substrait.extension; +import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; import io.substrait.proto.SimpleExtensionURI; @@ -98,6 +99,54 @@ public void addExtensionsToPlan(Plan.Builder builder) { builder.addAllExtensions(extensionList); } + public void addExtensionsToExtendedExpression(ExtendedExpression.Builder builder) { + var uriPos = new AtomicInteger(1); + var uris = new HashMap(); + + var extensionList = new ArrayList(); + for (var e : funcMap.forwardMap.entrySet()) { + SimpleExtensionURI uri = + uris.computeIfAbsent( + e.getValue().namespace(), + k -> + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(uriPos.getAndIncrement()) + .setUri(k) + .build()); + var decl = + SimpleExtensionDeclaration.newBuilder() + .setExtensionFunction( + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(e.getKey()) + .setName(e.getValue().key()) + .setExtensionUriReference(uri.getExtensionUriAnchor())) + .build(); + extensionList.add(decl); + } + for (var e : typeMap.forwardMap.entrySet()) { + SimpleExtensionURI uri = + uris.computeIfAbsent( + e.getValue().namespace(), + k -> + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(uriPos.getAndIncrement()) + .setUri(k) + .build()); + var decl = + SimpleExtensionDeclaration.newBuilder() + .setExtensionType( + SimpleExtensionDeclaration.ExtensionType.newBuilder() + .setTypeAnchor(e.getKey()) + .setName(e.getValue().key()) + .setExtensionUriReference(uri.getExtensionUriAnchor())) + .build(); + extensionList.add(decl); + } + + builder.addAllExtensionUris(uris.values()); + builder.addAllExtensions(extensionList); + } + /** We don't depend on guava... */ private static class BidiMap { private final Map forwardMap; diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index 7a850499a..e0f4c0d9b 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -1,22 +1,44 @@ package io.substrait.isthmus; +import com.github.bsideup.jabel.Desugar; import com.google.common.annotations.VisibleForTesting; import io.substrait.extension.ExtensionCollector; +import io.substrait.proto.Expression; +import io.substrait.proto.Expression.ScalarFunction; +import io.substrait.proto.ExpressionReference; +import io.substrait.proto.ExtendedExpression; +import io.substrait.proto.FunctionArgument; import io.substrait.proto.Plan; import io.substrait.proto.PlanRel; +import io.substrait.proto.SimpleExtensionDeclaration; +import io.substrait.proto.SimpleExtensionURI; import io.substrait.relation.RelProtoConverter; import io.substrait.type.NamedStruct; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import io.substrait.type.proto.TypeProtoConverter; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.function.Function; import org.apache.calcite.plan.hep.HepPlanner; import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.prepare.CalciteCatalogReader; import org.apache.calcite.rel.RelRoot; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.schema.Schema; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.sql2rel.SqlToRelConverter; import org.apache.calcite.sql2rel.StandardConvertletTable; @@ -48,6 +70,12 @@ public Plan execute(String sql, String name, Schema schema) throws SqlParseExcep return executeInner(sql, factory, pair.left, pair.right); } + public ExtendedExpression executeExpression(String expr, List tables) + throws SqlParseException { + var pair = registerCreateTables(tables); + return executeInnerExpression(expr, pair.left, pair.right); + } + // Package protected for testing List sqlToRelNode(String sql, List tables) throws SqlParseException { var pair = registerCreateTables(tables); @@ -91,6 +119,138 @@ private Plan executeInner( return plan.build(); } + private ExtendedExpression executeInnerExpression( + String sql, SqlValidator validator, CalciteCatalogReader catalogReader) + throws SqlParseException { + ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); + ExtensionCollector functionCollector = new ExtensionCollector(); + sqlToRexNode(sql, validator, catalogReader) + .forEach( + rexNode -> { + // FIXME! Implement it dynamically for more expression types + ResulTraverseRowExpression result = TraverseRexNode.getRowExpression(rexNode); + + // FIXME! Get output type dynamically: + // final static Map getTypeCreator = new HashMap<>(){{put("BOOLEAN", + // TypeCreator.of(true).BOOLEAN);}}; + // getTypeCreator.get(rexNode.getType()).accept(...) + io.substrait.proto.Type output = + TypeCreator.NULLABLE.BOOLEAN.accept(new TypeProtoConverter(functionCollector)); + + // FIXME! setFunctionReference, addArguments(index: 0, 1) + Expression.Builder expressionBuilder = + Expression.newBuilder() + .setScalarFunction( + ScalarFunction.newBuilder() + .setFunctionReference(1) + .setOutputType(output) + .addArguments( + 0, + FunctionArgument.newBuilder().setValue(result.referenceBuilder())) + .addArguments( + 1, + FunctionArgument.newBuilder() + .setValue(result.expressionBuilderLiteral()))); + ExpressionReference.Builder expressionReferenceBuilder = + ExpressionReference.newBuilder() + .setExpression(expressionBuilder) + .addOutputNames(result.ref().getName()); + + // FIXME! Get schema dynamically + // (as the same for Plan with: + // TypeConverter.DEFAULT.toNamedStruct(rexNode.getType());) + List columnNames = + Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); + List dataTypes = + Arrays.asList( + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING, + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING); + NamedStruct namedStruct = + NamedStruct.of( + columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); + + extendedExpressionBuilder + .addReferredExpr(0, expressionReferenceBuilder) + .setBaseSchema(namedStruct.toProto(new TypeProtoConverter(functionCollector))); + + // Extensions URI FIXME! Populate/create this dynamically + HashMap extensionUris = new HashMap<>(); + extensionUris.put( + "key-001", + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) + .setUri("/functions_comparison.yaml") + .build()); + + // Extensions FIXME! Populate/create this dynamically, maybe use rexNode.getKind() + ArrayList extensions = new ArrayList<>(); + SimpleExtensionDeclaration extensionFunctionLowerThan = + SimpleExtensionDeclaration.newBuilder() + .setExtensionFunction( + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("gt:any_any") + .setExtensionUriReference(1)) + .build(); + extensions.add(extensionFunctionLowerThan); + + extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); + extendedExpressionBuilder.addAllExtensions(extensions); + }); + return extendedExpressionBuilder.build(); + } + + static class TraverseRexNode { + static RexInputRef ref = null; + static Expression.Builder referenceBuilder = null; + static Expression.Builder expressionBuilderLiteral = null; + + static ResulTraverseRowExpression getRowExpression(RexNode rexNode) { + + switch (rexNode.getClass().getSimpleName().toUpperCase()) { + case "REXCALL": + for (RexNode rexInternal : ((RexCall) rexNode).operands) { + getRowExpression(rexInternal); + } + ; + break; + case "REXINPUTREF": + ref = (RexInputRef) rexNode; + referenceBuilder = + Expression.newBuilder() + .setSelection( + Expression.FieldReference.newBuilder() + .setDirectReference( + Expression.ReferenceSegment.newBuilder() + .setStructField( + Expression.ReferenceSegment.StructField.newBuilder() + .setField(ref.getIndex())))); + break; + case "REXLITERAL": + RexLiteral literal = (RexLiteral) rexNode; + expressionBuilderLiteral = + Expression.newBuilder() + .setLiteral( + Expression.Literal.newBuilder().setI32(literal.getValueAs(Integer.class))); + break; + default: + throw new AssertionError( + "Unsupported type for: " + rexNode.getClass().getSimpleName().toUpperCase()); + } + ResulTraverseRowExpression result = + new ResulTraverseRowExpression(ref, referenceBuilder, expressionBuilderLiteral); + return result; + } + } + + @Desugar + private record ResulTraverseRowExpression( + RexInputRef ref, + Expression.Builder referenceBuilder, + Expression.Builder expressionBuilderLiteral) {} + private List sqlToRelNode( String sql, SqlValidator validator, CalciteCatalogReader catalogReader) throws SqlParseException { @@ -107,6 +267,42 @@ private List sqlToRelNode( return roots; } + private List sqlToRexNode( + String sql, SqlValidator validator, CalciteCatalogReader catalogReader) + throws SqlParseException { + SqlParser parser = SqlParser.create(sql, parserConfig); + SqlNode sqlNode = parser.parseExpression(); + Result result = getResult(validator); + SqlNode validSQLNode = + validator.validateParameterizedExpression( + sqlNode, + result.nameToTypeMap()); // FIXME! It may be optional to include this validation + SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader); + RexNode rexNode = converter.convertExpression(validSQLNode, result.nameToNodeMap()); + + return Collections.singletonList(rexNode); + } + + private static Result getResult(SqlValidator validator) { + // FIXME! Needs to be created dinamycally, this is for PoC purpose + HashMap nameToNodeMap = new HashMap<>(); + nameToNodeMap.put( + "N_NATIONKEY", + new RexInputRef(0, validator.getTypeFactory().createSqlType(SqlTypeName.BIGINT))); + nameToNodeMap.put( + "N_REGIONKEY", + new RexInputRef(1, validator.getTypeFactory().createSqlType(SqlTypeName.BIGINT))); + final Map nameToTypeMap = new HashMap<>(); + for (Map.Entry entry : nameToNodeMap.entrySet()) { + nameToTypeMap.put(entry.getKey(), entry.getValue().getType()); + } + Result result = new Result(nameToNodeMap, nameToTypeMap); + return result; + } + + private @Desugar record Result( + HashMap nameToNodeMap, Map nameToTypeMap) {} + @VisibleForTesting SqlToRelConverter createSqlToRelConverter( SqlValidator validator, CalciteCatalogReader catalogReader) { diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java new file mode 100644 index 000000000..10f3f57e3 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java @@ -0,0 +1,51 @@ +package io.substrait.isthmus; + +import com.google.common.base.Charsets; +import com.google.common.io.Resources; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.util.JsonFormat; +import io.substrait.proto.ExtendedExpression; +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import org.apache.calcite.sql.parser.SqlParseException; + +public class ExtendedExpressionTestBase { + public static String asString(String resource) throws IOException { + return Resources.toString(Resources.getResource(resource), Charsets.UTF_8); + } + + public static List tpchSchemaCreateStatements() throws IOException { + String[] values = asString("tpch/schema.sql").split(";"); + return Arrays.stream(values) + .filter(t -> !t.trim().isBlank()) + .collect(java.util.stream.Collectors.toList()); + } + + protected ExtendedExpression assertProtoExtendedExpressionRoundrip(String query) + throws IOException, SqlParseException { + return assertProtoExtendedExpressionRoundrip(query, new SqlToSubstrait()); + } + + protected ExtendedExpression assertProtoExtendedExpressionRoundrip(String query, SqlToSubstrait s) + throws IOException, SqlParseException { + return assertProtoExtendedExpressionRoundrip(query, s, tpchSchemaCreateStatements()); + } + + protected ExtendedExpression assertProtoExtendedExpressionRoundrip( + String query, SqlToSubstrait s, List creates) throws SqlParseException { + io.substrait.proto.ExtendedExpression protoExtendedExpression = + s.executeExpression(query, creates); + + try { + String ee = JsonFormat.printer().print(protoExtendedExpression); + System.out.println("Proto Extended Expression: \n" + ee); + + // FIXME! Implement test validation as the same as proto Plan implementation + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + + return protoExtendedExpression; + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java new file mode 100644 index 000000000..bfcea38c9 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java @@ -0,0 +1,13 @@ +package io.substrait.isthmus; + +import java.io.IOException; +import org.apache.calcite.sql.parser.SqlParseException; +import org.junit.jupiter.api.Test; + +public class SimpleExtendedExpressionsTest extends ExtendedExpressionTestBase { + + @Test + public void filter() throws IOException, SqlParseException { + assertProtoExtendedExpressionRoundrip("N_NATIONKEY > 18"); + } +} From f4b6581a6b177654eff632c9c3719fcc83c1b7b3 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Thu, 26 Oct 2023 11:58:18 -0500 Subject: [PATCH 02/35] fix: implement nameToNodeMap and nameToTypeMap dyamically instead of hard coded --- build.gradle.kts | 9 ++- isthmus/build.gradle.kts | 2 + .../substrait/isthmus/SqlConverterBase.java | 38 +++++++++- .../io/substrait/isthmus/SqlToSubstrait.java | 46 ++++-------- .../io/substrait/isthmus/SubstraitToSql.java | 4 +- .../ExtendedExpressionIntegrationTest.java | 68 ++++++++++++++++++ .../test/resources/tpch/data/nation.parquet | Bin 0 -> 2319 bytes 7 files changed, 129 insertions(+), 38 deletions(-) create mode 100644 isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java create mode 100644 isthmus/src/test/resources/tpch/data/nation.parquet diff --git a/build.gradle.kts b/build.gradle.kts index 47a9da29f..3d711fbd5 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -33,8 +33,13 @@ val submodulesUpdate by } allprojects { - repositories { mavenCentral() } - + repositories { + mavenCentral() + maven { + name = "github" + url = uri("https://nightlies.apache.org/arrow/java") + } + } tasks.configureEach { val javaToolchains = project.extensions.getByType() useJUnitPlatform() diff --git a/isthmus/build.gradle.kts b/isthmus/build.gradle.kts index 9941f51de..a3437d076 100644 --- a/isthmus/build.gradle.kts +++ b/isthmus/build.gradle.kts @@ -94,6 +94,8 @@ dependencies { implementation("org.immutables:value-annotations:2.8.8") annotationProcessor("org.immutables:value:2.8.8") testImplementation("org.apache.calcite:calcite-plus:${CALCITE_VERSION}") + testImplementation("org.apache.arrow:arrow-dataset:14.0.0-SNAPSHOT") + testImplementation("org.apache.arrow:arrow-memory-netty:14.0.0-SNAPSHOT") annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java index ec83bbc82..40a853539 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java @@ -1,11 +1,15 @@ package io.substrait.isthmus; +import com.github.bsideup.jabel.Desugar; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.calcite.SubstraitOperatorTable; import io.substrait.type.NamedStruct; import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.function.Function; import org.apache.calcite.config.CalciteConnectionConfig; import org.apache.calcite.config.CalciteConnectionProperty; @@ -22,8 +26,12 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.schema.Schema; +import org.apache.calcite.schema.Schemas; import org.apache.calcite.schema.Table; import org.apache.calcite.schema.impl.AbstractTable; import org.apache.calcite.sql.SqlNode; @@ -86,8 +94,24 @@ protected SqlConverterBase(FeatureBoard features) { EXTENSION_COLLECTION = defaults; } - Pair registerCreateTables(List tables) + /* + HashMap nameToNodeMap = new HashMap<>(); + nameToNodeMap.put( + "N_NATIONKEY", + new RexInputRef(0, validator.getTypeFactory().createSqlType(SqlTypeName.BIGINT))); + nameToNodeMap.put( + "N_REGIONKEY", + new RexInputRef(1, validator.getTypeFactory().createSqlType(SqlTypeName.BIGINT))); + final Map nameToTypeMap = new HashMap<>(); + for (Map.Entry entry : nameToNodeMap.entrySet()) { + nameToTypeMap.put(entry.getKey(), entry.getValue().getType()); + } + */ + + Result registerCreateTables(List tables) throws SqlParseException { + Map nameToTypeMap = new HashMap<>(); + Map nameToNodeMap = new HashMap<>(); CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); CalciteCatalogReader catalogReader = new CalciteCatalogReader(rootSchema, List.of(), factory, config); @@ -97,10 +121,20 @@ Pair registerCreateTables(List table List tList = parseCreateTable(factory, validator, tableDef); for (DefinedTable t : tList) { rootSchema.add(t.getName(), t); + for (RelDataTypeField field : t.type.getFieldList()) { + nameToTypeMap.put(field.getName(), field.getType()); + nameToNodeMap.put(field.getName(), new RexInputRef(field.getIndex(), field.getType())); + } } } } - return Pair.of(validator, catalogReader); + return new Result(validator, catalogReader, nameToTypeMap, nameToNodeMap); + } + + @Desugar + public record Result(SqlValidator validator, CalciteCatalogReader catalogReader, + Map nameToTypeMap, Map nameToNodeMap) { + } Pair registerCreateTables( diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index e0f4c0d9b..c275dcdaa 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -61,8 +61,8 @@ public Plan execute(String sql, Function, NamedStruct> tableLookup) } public Plan execute(String sql, List tables) throws SqlParseException { - var pair = registerCreateTables(tables); - return executeInner(sql, factory, pair.left, pair.right); + var result = registerCreateTables(tables); + return executeInner(sql, factory, result.validator(), result.catalogReader()); } public Plan execute(String sql, String name, Schema schema) throws SqlParseException { @@ -72,14 +72,15 @@ public Plan execute(String sql, String name, Schema schema) throws SqlParseExcep public ExtendedExpression executeExpression(String expr, List tables) throws SqlParseException { - var pair = registerCreateTables(tables); - return executeInnerExpression(expr, pair.left, pair.right); + var result = registerCreateTables(tables); + return executeInnerExpression(expr, result.validator(), result.catalogReader(), + result.nameToTypeMap(), result.nameToNodeMap()); } // Package protected for testing List sqlToRelNode(String sql, List tables) throws SqlParseException { - var pair = registerCreateTables(tables); - return sqlToRelNode(sql, pair.left, pair.right); + var result = registerCreateTables(tables); + return sqlToRelNode(sql, result.validator(), result.catalogReader()); } // Package protected for testing @@ -120,11 +121,12 @@ private Plan executeInner( } private ExtendedExpression executeInnerExpression( - String sql, SqlValidator validator, CalciteCatalogReader catalogReader) + String sql, SqlValidator validator, CalciteCatalogReader catalogReader, + Map nameToTypeMap, Map nameToNodeMap) throws SqlParseException { ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); ExtensionCollector functionCollector = new ExtensionCollector(); - sqlToRexNode(sql, validator, catalogReader) + sqlToRexNode(sql, validator, catalogReader, nameToTypeMap, nameToNodeMap) .forEach( rexNode -> { // FIXME! Implement it dynamically for more expression types @@ -268,41 +270,21 @@ private List sqlToRelNode( } private List sqlToRexNode( - String sql, SqlValidator validator, CalciteCatalogReader catalogReader) + String sql, SqlValidator validator, CalciteCatalogReader catalogReader, + Map nameToTypeMap, Map nameToNodeMap) throws SqlParseException { SqlParser parser = SqlParser.create(sql, parserConfig); SqlNode sqlNode = parser.parseExpression(); - Result result = getResult(validator); SqlNode validSQLNode = validator.validateParameterizedExpression( sqlNode, - result.nameToTypeMap()); // FIXME! It may be optional to include this validation + nameToTypeMap); // FIXME! It may be optional to include this validation SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader); - RexNode rexNode = converter.convertExpression(validSQLNode, result.nameToNodeMap()); + RexNode rexNode = converter.convertExpression(validSQLNode, nameToNodeMap); return Collections.singletonList(rexNode); } - private static Result getResult(SqlValidator validator) { - // FIXME! Needs to be created dinamycally, this is for PoC purpose - HashMap nameToNodeMap = new HashMap<>(); - nameToNodeMap.put( - "N_NATIONKEY", - new RexInputRef(0, validator.getTypeFactory().createSqlType(SqlTypeName.BIGINT))); - nameToNodeMap.put( - "N_REGIONKEY", - new RexInputRef(1, validator.getTypeFactory().createSqlType(SqlTypeName.BIGINT))); - final Map nameToTypeMap = new HashMap<>(); - for (Map.Entry entry : nameToNodeMap.entrySet()) { - nameToTypeMap.put(entry.getKey(), entry.getValue().getType()); - } - Result result = new Result(nameToNodeMap, nameToTypeMap); - return result; - } - - private @Desugar record Result( - HashMap nameToNodeMap, Map nameToTypeMap) {} - @VisibleForTesting SqlToRelConverter createSqlToRelConverter( SqlValidator validator, CalciteCatalogReader catalogReader) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java index f402aacd3..5a18cd27e 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java @@ -22,8 +22,8 @@ public SubstraitToSql() { public RelNode substraitRelToCalciteRel(Rel relRoot, List tables) throws SqlParseException { - var pair = registerCreateTables(tables); - return SubstraitRelNodeConverter.convert(relRoot, relOptCluster, pair.right, parserConfig); + var result = registerCreateTables(tables); + return SubstraitRelNodeConverter.convert(relRoot, relOptCluster, result.catalogReader(), parserConfig); } public RelNode substraitRelToCalciteRel( diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java new file mode 100644 index 000000000..4e0b29ec6 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java @@ -0,0 +1,68 @@ +package io.substrait.isthmus.integration; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.ibm.icu.impl.ClassLoaderUtil; +import io.substrait.isthmus.ExtendedExpressionTestBase; +import io.substrait.isthmus.SqlToSubstrait; +import io.substrait.proto.ExtendedExpression; +import java.io.IOException; +import java.net.URISyntaxException; +import java.net.URL; +import java.nio.ByteBuffer; +import java.util.Base64; +import java.util.Optional; +import org.apache.arrow.dataset.file.FileFormat; +import org.apache.arrow.dataset.file.FileSystemDatasetFactory; +import org.apache.arrow.dataset.jni.NativeMemoryPool; +import org.apache.arrow.dataset.scanner.ScanOptions; +import org.apache.arrow.dataset.scanner.Scanner; +import org.apache.arrow.dataset.source.Dataset; +import org.apache.arrow.dataset.source.DatasetFactory; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.calcite.sql.parser.SqlParseException; +import org.junit.jupiter.api.Test; + +public class ExtendedExpressionIntegrationTest { + + @Test + public void projectAndFilterDataset() throws SqlParseException, IOException, URISyntaxException { + URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); + ScanOptions options = + new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitFilter(getSubstraitExpressionFilter()) + .build(); + try (BufferAllocator allocator = new RootAllocator(); + DatasetFactory datasetFactory = + new FileSystemDatasetFactory( + allocator, NativeMemoryPool.getDefault(), FileFormat.PARQUET, resource.toURI().toString()); + Dataset dataset = datasetFactory.finish(); + Scanner scanner = dataset.newScan(options); + ArrowReader reader = scanner.scanBatches()) { + int count = 0; + while (reader.loadNextBatch()) { + count += reader.getVectorSchemaRoot().getRowCount(); + System.out.println(reader.getVectorSchemaRoot().contentToTSVString()); + } + assertEquals(4, count); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static ByteBuffer getSubstraitExpressionFilter() throws IOException, SqlParseException { + ExtendedExpression extendedExpression = + new SqlToSubstrait() + .executeExpression( + "N_NATIONKEY > 20", ExtendedExpressionTestBase.tpchSchemaCreateStatements()); + byte[] extendedExpressions = + Base64.getDecoder() + .decode(Base64.getEncoder().encodeToString(extendedExpression.toByteArray())); + ByteBuffer substraitExpressionFilter = ByteBuffer.allocateDirect(extendedExpressions.length); + substraitExpressionFilter.put(extendedExpressions); + return substraitExpressionFilter; + } +} diff --git a/isthmus/src/test/resources/tpch/data/nation.parquet b/isthmus/src/test/resources/tpch/data/nation.parquet new file mode 100644 index 0000000000000000000000000000000000000000..0189118ce7344297b1ad6f1095a11162ddd960ea GIT binary patch literal 2319 zcmdT`U5p#m6~5OQZ#?W6cEDrC$}Zfk*6!}sJH*az2w_7yyLOWy>!0=7K~71-_4wMJ z-SLcfev)+p51>k*Nc(^qP^Cy!C{zmSLsb>=5Gqv(EfJNdih}yI;-L>nZFxu~S^-Kc z&g@Tq#9QAg&1aoC_xztT=en#`G7^$L!SJM|ERX}z07A=SA%svC!wB5^Wmr;cIgO|l zbxsTKk&kQnYEdm<{kckQ3ETz+sui_rK1Yse#Ur^=A)0tFwp3NC`K7IHoV}|a_clZ1 z5QoLU$F3-cQ<%dt=0rik2+smJEsmvwUIhr+2j>3ri1d5$E_=Ux;O4tq1>Oe&q(j?n zlPN@}4r}?Q*(WW-q9$pwp6wc*3xvmXkGN(Z&S;w&3!nx9Eqjf*r03XO!(}_ix^6p7 z!)g=HCSHdThqZl^I)uW3Z+Wgl8n)R4_NvcZFiuU|S^;}t8K|~vG#Rm5o2or#ZCct1 z1VWVF?6^Iq8{oZ1^%dHN03d8a8@BHe{Nw?{u`NS~&>T2Y2>+fqon~O3Q5O=m+R@43 zG;s}+S)NoDr;!u=UABX!TfSr1K4eJj)65C43ZCJ0@bYlW^jj^%?1w)a(t{xJ>)dG; zPCn6)eam-!YT|gzZrY~dg=dkGJTv6CArSuK*s5>nE9U5N!LeNzD#^LuF7O>%Ax&( z;`=Baeh!&ae_c>P&UFb}(|s3eq0?pCtnZQjkSQW%W@9gO$FMvo5!BvMU!VBR0k+n2 z7-&$p-|cyJ^1ITPh!rkrSNW`RJ%kJ`9P8Nn3hUz=xju7T=qqDA(ts2@%pBP(WWi$> zMW9>z&xiQ+sY`?fjbfH=jJ_kx!Ku0|t&cTq+Xcs--An!|bB&+P_~@`C^B)OgAdz20 z3O|9;(pN&;Xn8Y7W%@X<4L#9S|0VuW?G*l7@9?LF9O&}N(Hg!Ifr*;AGQY8b<=j4*&n&s zy~g>j@@Xi1sP-~FTQY6<}b`dhn!_FBpn`kEwf9xj`?$|R}igVMvOg^ z_)7UvXgy;*Ec_5el0QAU=ZWJ`!4HNl+p`_Za0@@yj%VKFHR=d3%rrWN6nr^myLOOa z(~l%?A^O@squsrv>zkg@GYu%SBCTaVM-hv*@qzW9AcKt|`o_QcLGcx#2zq!Oxo^vh zR|@-}t=Gkk)f>@`L;1`FL4iZ+juV(?Lx&=ned(#he@0y{O^nA-Nqt8;d-W``)Mc*2 z>WSogoW=t5XsC4Efv8`WMj`paFbFt6dl-4Tn*^L-wD+oU7C}hIe`*g%Z!sAUWFf^_6Dh zl&g1Gmr4;Ng_0q8&^l#N&w$UK^^@O0k^sDcfY)V7sQI=C?2UUxc z;lx4;4TC2MCGN2)h2q)x*BG8uWLD;aOGXa%JX|~DcfA24&=)rKxBYLbZgBCRbE%g? zv-0lky)EzFzJa76?B#7ZoAa~yG3FJpk)008`~HV} g0m}3L5GE8n4il7$pz(nmOlWQn{TlAWGW-?#1~YkXEC2ui literal 0 HcmV?d00001 From a79f57d742277d44b5952188d8446b3c1fd09ff6 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Thu, 26 Oct 2023 17:29:38 -0500 Subject: [PATCH 03/35] fix: cover support also for project extended expression --- .../substrait/isthmus/SqlConverterBase.java | 14 ++-- .../io/substrait/isthmus/SqlToSubstrait.java | 43 +++++++++--- .../io/substrait/isthmus/SubstraitToSql.java | 3 +- .../ExtendedExpressionIntegrationTest.java | 65 +++++++++++++++++-- 4 files changed, 101 insertions(+), 24 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java index 40a853539..9448c34ce 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java @@ -6,7 +6,6 @@ import io.substrait.type.NamedStruct; import java.io.IOException; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -31,7 +30,6 @@ import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.schema.Schema; -import org.apache.calcite.schema.Schemas; import org.apache.calcite.schema.Table; import org.apache.calcite.schema.impl.AbstractTable; import org.apache.calcite.sql.SqlNode; @@ -108,8 +106,7 @@ protected SqlConverterBase(FeatureBoard features) { } */ - Result registerCreateTables(List tables) - throws SqlParseException { + Result registerCreateTables(List tables) throws SqlParseException { Map nameToTypeMap = new HashMap<>(); Map nameToNodeMap = new HashMap<>(); CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); @@ -132,10 +129,11 @@ Result registerCreateTables(List tables) } @Desugar - public record Result(SqlValidator validator, CalciteCatalogReader catalogReader, - Map nameToTypeMap, Map nameToNodeMap) { - - } + public record Result( + SqlValidator validator, + CalciteCatalogReader catalogReader, + Map nameToTypeMap, + Map nameToNodeMap) {} Pair registerCreateTables( Function, NamedStruct> tableLookup) throws SqlParseException { diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index c275dcdaa..baf67514c 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -3,6 +3,8 @@ import com.github.bsideup.jabel.Desugar; import com.google.common.annotations.VisibleForTesting; import io.substrait.extension.ExtensionCollector; +import io.substrait.isthmus.expression.RexExpressionConverter; +import io.substrait.isthmus.expression.ScalarFunctionConverter; import io.substrait.proto.Expression; import io.substrait.proto.Expression.ScalarFunction; import io.substrait.proto.ExpressionReference; @@ -38,7 +40,6 @@ import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.parser.SqlParser; -import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.sql2rel.SqlToRelConverter; import org.apache.calcite.sql2rel.StandardConvertletTable; @@ -46,6 +47,12 @@ /** Take a SQL statement and a set of table definitions and return a substrait plan. */ public class SqlToSubstrait extends SqlConverterBase { + private final ScalarFunctionConverter functionConverter = + new ScalarFunctionConverter(EXTENSION_COLLECTION.scalarFunctions(), factory); + + private final RexExpressionConverter rexExpressionConverter = + new RexExpressionConverter(functionConverter); + public SqlToSubstrait() { this(null); } @@ -73,8 +80,12 @@ public Plan execute(String sql, String name, Schema schema) throws SqlParseExcep public ExtendedExpression executeExpression(String expr, List tables) throws SqlParseException { var result = registerCreateTables(tables); - return executeInnerExpression(expr, result.validator(), result.catalogReader(), - result.nameToTypeMap(), result.nameToNodeMap()); + return executeInnerExpression( + expr, + result.validator(), + result.catalogReader(), + result.nameToTypeMap(), + result.nameToNodeMap()); } // Package protected for testing @@ -121,11 +132,15 @@ private Plan executeInner( } private ExtendedExpression executeInnerExpression( - String sql, SqlValidator validator, CalciteCatalogReader catalogReader, - Map nameToTypeMap, Map nameToNodeMap) + String sql, + SqlValidator validator, + CalciteCatalogReader catalogReader, + Map nameToTypeMap, + Map nameToNodeMap) throws SqlParseException { ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); ExtensionCollector functionCollector = new ExtensionCollector(); + RelProtoConverter relProtoConverter = new RelProtoConverter(functionCollector); sqlToRexNode(sql, validator, catalogReader, nameToTypeMap, nameToNodeMap) .forEach( rexNode -> { @@ -186,6 +201,11 @@ private ExtendedExpression executeInnerExpression( .setUri("/functions_comparison.yaml") .build()); + io.substrait.expression.Expression.ScalarFunctionInvocation func = + (io.substrait.expression.Expression.ScalarFunctionInvocation) + rexNode.accept(rexExpressionConverter); + String declaration = func.declaration().key(); + // Extensions FIXME! Populate/create this dynamically, maybe use rexNode.getKind() ArrayList extensions = new ArrayList<>(); SimpleExtensionDeclaration extensionFunctionLowerThan = @@ -193,7 +213,7 @@ private ExtendedExpression executeInnerExpression( .setExtensionFunction( SimpleExtensionDeclaration.ExtensionFunction.newBuilder() .setFunctionAnchor(1) - .setName("gt:any_any") + .setName(declaration) .setExtensionUriReference(1)) .build(); extensions.add(extensionFunctionLowerThan); @@ -229,6 +249,7 @@ static ResulTraverseRowExpression getRowExpression(RexNode rexNode) { .setStructField( Expression.ReferenceSegment.StructField.newBuilder() .setField(ref.getIndex())))); + break; case "REXLITERAL": RexLiteral literal = (RexLiteral) rexNode; @@ -270,15 +291,17 @@ private List sqlToRelNode( } private List sqlToRexNode( - String sql, SqlValidator validator, CalciteCatalogReader catalogReader, - Map nameToTypeMap, Map nameToNodeMap) + String sql, + SqlValidator validator, + CalciteCatalogReader catalogReader, + Map nameToTypeMap, + Map nameToNodeMap) throws SqlParseException { SqlParser parser = SqlParser.create(sql, parserConfig); SqlNode sqlNode = parser.parseExpression(); SqlNode validSQLNode = validator.validateParameterizedExpression( - sqlNode, - nameToTypeMap); // FIXME! It may be optional to include this validation + sqlNode, nameToTypeMap); // FIXME! It may be optional to include this validation SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader); RexNode rexNode = converter.convertExpression(validSQLNode, nameToNodeMap); diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java index 5a18cd27e..d43fda1c1 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java @@ -23,7 +23,8 @@ public SubstraitToSql() { public RelNode substraitRelToCalciteRel(Rel relRoot, List tables) throws SqlParseException { var result = registerCreateTables(tables); - return SubstraitRelNodeConverter.convert(relRoot, relOptCluster, result.catalogReader(), parserConfig); + return SubstraitRelNodeConverter.convert( + relRoot, relOptCluster, result.catalogReader(), parserConfig); } public RelNode substraitRelToCalciteRel( diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java index 4e0b29ec6..b5cae82eb 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java @@ -21,6 +21,7 @@ import org.apache.arrow.dataset.source.DatasetFactory; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.calcite.sql.parser.SqlParseException; import org.junit.jupiter.api.Test; @@ -28,17 +29,21 @@ public class ExtendedExpressionIntegrationTest { @Test - public void projectAndFilterDataset() throws SqlParseException, IOException, URISyntaxException { + public void filterDataset() throws SqlParseException, IOException, URISyntaxException { URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); + String sqlExpression = "N_NATIONKEY > 20"; ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) .columns(Optional.empty()) - .substraitFilter(getSubstraitExpressionFilter()) + .substraitFilter(getFilterExtendedExpression(sqlExpression)) .build(); try (BufferAllocator allocator = new RootAllocator(); DatasetFactory datasetFactory = new FileSystemDatasetFactory( - allocator, NativeMemoryPool.getDefault(), FileFormat.PARQUET, resource.toURI().toString()); + allocator, + NativeMemoryPool.getDefault(), + FileFormat.PARQUET, + resource.toURI().toString()); Dataset dataset = datasetFactory.finish(); Scanner scanner = dataset.newScan(options); ArrowReader reader = scanner.scanBatches()) { @@ -53,11 +58,47 @@ public void projectAndFilterDataset() throws SqlParseException, IOException, URI } } - private static ByteBuffer getSubstraitExpressionFilter() throws IOException, SqlParseException { + @Test + public void projectDataset() throws SqlParseException, IOException, URISyntaxException { + URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); + String sqlExpression = "N_NATIONKEY + 20"; + ScanOptions options = + new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitProjection(getProjectExtendedExpression(sqlExpression)) + .build(); + try (BufferAllocator allocator = new RootAllocator(); + DatasetFactory datasetFactory = + new FileSystemDatasetFactory( + allocator, + NativeMemoryPool.getDefault(), + FileFormat.PARQUET, + resource.toURI().toString()); + Dataset dataset = datasetFactory.finish(); + Scanner scanner = dataset.newScan(options); + ArrowReader reader = scanner.scanBatches()) { + int count = 0; + int sum = 0; + while (reader.loadNextBatch()) { + count += reader.getVectorSchemaRoot().getRowCount(); + IntVector intVector = (IntVector) reader.getVectorSchemaRoot().getVector(0); + for (int i = 0; i < intVector.getValueCount(); i++) { + sum += intVector.get(i); + } + } + assertEquals(25, count); + assertEquals(24 * 25 / 2 + 20 * count, sum); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static ByteBuffer getFilterExtendedExpression(String sqlExpression) + throws IOException, SqlParseException { ExtendedExpression extendedExpression = new SqlToSubstrait() .executeExpression( - "N_NATIONKEY > 20", ExtendedExpressionTestBase.tpchSchemaCreateStatements()); + sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); byte[] extendedExpressions = Base64.getDecoder() .decode(Base64.getEncoder().encodeToString(extendedExpression.toByteArray())); @@ -65,4 +106,18 @@ private static ByteBuffer getSubstraitExpressionFilter() throws IOException, Sql substraitExpressionFilter.put(extendedExpressions); return substraitExpressionFilter; } + + private static ByteBuffer getProjectExtendedExpression(String sqlExpression) + throws IOException, SqlParseException { + ExtendedExpression extendedExpression = + new SqlToSubstrait() + .executeExpression( + sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); + byte[] extendedExpressions = + Base64.getDecoder() + .decode(Base64.getEncoder().encodeToString(extendedExpression.toByteArray())); + ByteBuffer substraitExpressionProject = ByteBuffer.allocateDirect(extendedExpressions.length); + substraitExpressionProject.put(extendedExpressions); + return substraitExpressionProject; + } } From a37be9224dc11ba9b08b1a428ffeccc9143b3d39 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Thu, 26 Oct 2023 23:15:54 -0500 Subject: [PATCH 04/35] fix: cover support also for project extended expression --- .editorconfig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.editorconfig b/.editorconfig index 3d674d593..984db0c67 100644 --- a/.editorconfig +++ b/.editorconfig @@ -10,7 +10,7 @@ trim_trailing_whitespace = true [*.{yaml,yml}] indent_size = 2 -[{**/*.sql,**/OuterReferenceResolver.md,gradlew.bat}] +[{**/*.sql,**/OuterReferenceResolver.md,gradlew.bat,**/*.parquet}] charset = unset end_of_line = unset insert_final_newline = unset From 9f6aaf3eb5dce49239c9b15d3d71b5ade7d7f413 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Wed, 15 Nov 2023 10:16:52 -0500 Subject: [PATCH 05/35] fix: create schema dynamically --- build.gradle.kts | 8 +----- isthmus/build.gradle.kts | 4 +-- .../substrait/isthmus/SqlConverterBase.java | 7 ++---- .../io/substrait/isthmus/SqlToSubstrait.java | 25 ++++--------------- .../io/substrait/isthmus/TypeConverter.java | 13 ++++++++++ .../ExtendedExpressionIntegrationTest.java | 8 ++++++ 6 files changed, 31 insertions(+), 34 deletions(-) diff --git a/build.gradle.kts b/build.gradle.kts index 3d711fbd5..293163045 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -33,13 +33,7 @@ val submodulesUpdate by } allprojects { - repositories { - mavenCentral() - maven { - name = "github" - url = uri("https://nightlies.apache.org/arrow/java") - } - } + repositories { mavenCentral() } tasks.configureEach { val javaToolchains = project.extensions.getByType() useJUnitPlatform() diff --git a/isthmus/build.gradle.kts b/isthmus/build.gradle.kts index a3437d076..a5ae3abd2 100644 --- a/isthmus/build.gradle.kts +++ b/isthmus/build.gradle.kts @@ -94,8 +94,8 @@ dependencies { implementation("org.immutables:value-annotations:2.8.8") annotationProcessor("org.immutables:value:2.8.8") testImplementation("org.apache.calcite:calcite-plus:${CALCITE_VERSION}") - testImplementation("org.apache.arrow:arrow-dataset:14.0.0-SNAPSHOT") - testImplementation("org.apache.arrow:arrow-memory-netty:14.0.0-SNAPSHOT") + testImplementation("org.apache.arrow:arrow-dataset:14.0.0") + testImplementation("org.apache.arrow:arrow-memory-netty:14.0.0") annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java index 9448c34ce..3466cf826 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java @@ -5,10 +5,7 @@ import io.substrait.isthmus.calcite.SubstraitOperatorTable; import io.substrait.type.NamedStruct; import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.function.Function; import org.apache.calcite.config.CalciteConnectionConfig; import org.apache.calcite.config.CalciteConnectionProperty; @@ -107,7 +104,7 @@ protected SqlConverterBase(FeatureBoard features) { */ Result registerCreateTables(List tables) throws SqlParseException { - Map nameToTypeMap = new HashMap<>(); + Map nameToTypeMap = new LinkedHashMap<>(); Map nameToNodeMap = new HashMap<>(); CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); CalciteCatalogReader catalogReader = diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index baf67514c..70ff4c087 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -16,11 +16,9 @@ import io.substrait.proto.SimpleExtensionURI; import io.substrait.relation.RelProtoConverter; import io.substrait.type.NamedStruct; -import io.substrait.type.Type; import io.substrait.type.TypeCreator; import io.substrait.type.proto.TypeProtoConverter; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -173,24 +171,7 @@ private ExtendedExpression executeInnerExpression( .setExpression(expressionBuilder) .addOutputNames(result.ref().getName()); - // FIXME! Get schema dynamically - // (as the same for Plan with: - // TypeConverter.DEFAULT.toNamedStruct(rexNode.getType());) - List columnNames = - Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); - List dataTypes = - Arrays.asList( - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING, - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING); - NamedStruct namedStruct = - NamedStruct.of( - columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); - - extendedExpressionBuilder - .addReferredExpr(0, expressionReferenceBuilder) - .setBaseSchema(namedStruct.toProto(new TypeProtoConverter(functionCollector))); + extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); // Extensions URI FIXME! Populate/create this dynamically HashMap extensionUris = new HashMap<>(); @@ -221,6 +202,10 @@ private ExtendedExpression executeInnerExpression( extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); extendedExpressionBuilder.addAllExtensions(extensions); }); + NamedStruct namedStruct = TypeConverter.DEFAULT.toNamedStruct(nameToTypeMap); + extendedExpressionBuilder.setBaseSchema( + namedStruct.toProto(new TypeProtoConverter(functionCollector))); + return extendedExpressionBuilder.build(); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java index ba68d5cfc..73c846d45 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java @@ -3,6 +3,7 @@ import static io.substrait.isthmus.SubstraitTypeSystem.DAY_SECOND_INTERVAL; import static io.substrait.isthmus.SubstraitTypeSystem.YEAR_MONTH_INTERVAL; +import com.google.common.collect.Lists; import io.substrait.function.NullableType; import io.substrait.function.TypeExpression; import io.substrait.type.NamedStruct; @@ -11,6 +12,7 @@ import io.substrait.type.TypeVisitor; import java.util.ArrayList; import java.util.List; +import java.util.Map; import javax.annotation.Nullable; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; @@ -56,6 +58,17 @@ public NamedStruct toNamedStruct(RelDataType type) { return NamedStruct.of(names, struct); } + public NamedStruct toNamedStruct(Map nameToTypeMap) { + var names = Lists.newArrayList(); + var types = Lists.newArrayList(); + nameToTypeMap.forEach( + (k, v) -> { + names.add(k); + types.add(toSubstrait(v, names)); + }); + return NamedStruct.of(names, Type.Struct.builder().fields(types).nullable(false).build()); + } + private Type toSubstrait(RelDataType type, List names) { // Check for user mapped types first as they may re-use SqlTypeNames var userType = userTypeMapper.toSubstrait(type); diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java index b5cae82eb..267b2525b 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java @@ -2,6 +2,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; +import com.google.protobuf.util.JsonFormat; import com.ibm.icu.impl.ClassLoaderUtil; import io.substrait.isthmus.ExtendedExpressionTestBase; import io.substrait.isthmus.SqlToSubstrait; @@ -85,6 +86,7 @@ public void projectDataset() throws SqlParseException, IOException, URISyntaxExc for (int i = 0; i < intVector.getValueCount(); i++) { sum += intVector.get(i); } + System.out.println(reader.getVectorSchemaRoot().contentToTSVString()); } assertEquals(25, count); assertEquals(24 * 25 / 2 + 20 * count, sum); @@ -99,6 +101,9 @@ private static ByteBuffer getFilterExtendedExpression(String sqlExpression) new SqlToSubstrait() .executeExpression( sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); + System.out.println( + "JsonFormat.printer().print(getFilterExtendedExpression): " + + JsonFormat.printer().print(extendedExpression)); byte[] extendedExpressions = Base64.getDecoder() .decode(Base64.getEncoder().encodeToString(extendedExpression.toByteArray())); @@ -113,6 +118,9 @@ private static ByteBuffer getProjectExtendedExpression(String sqlExpression) new SqlToSubstrait() .executeExpression( sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); + System.out.println( + "JsonFormat.printer().print(getProjectExtendedExpression): " + + JsonFormat.printer().print(extendedExpression)); byte[] extendedExpressions = Base64.getDecoder() .decode(Base64.getEncoder().encodeToString(extendedExpression.toByteArray())); From 52b41e341940e7ade458560c4befbd5408de054d Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Thu, 16 Nov 2023 16:37:27 -0500 Subject: [PATCH 06/35] fix: set function reference and extensions dinamically --- .pre-commit-config.yaml | 4 +- .../io/substrait/isthmus/SqlToSubstrait.java | 120 +++++++++++------- .../ExtendedExpressionIntegrationTest.java | 8 +- 3 files changed, 77 insertions(+), 55 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 102b99017..a2505b4ca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,11 +1,11 @@ repos: - repo: https://github.com/adrienverge/yamllint.git - rev: v1.26.0 + rev: v1.33.0 hooks: - id: yamllint args: [-c=.yamllint.yaml] - repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook - rev: v8.0.0 + rev: v9.9.0 hooks: - id: commitlint stages: [commit-msg] diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index 70ff4c087..37d088a3e 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -3,6 +3,7 @@ import com.github.bsideup.jabel.Desugar; import com.google.common.annotations.VisibleForTesting; import io.substrait.extension.ExtensionCollector; +import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.expression.RexExpressionConverter; import io.substrait.isthmus.expression.ScalarFunctionConverter; import io.substrait.proto.Expression; @@ -18,11 +19,8 @@ import io.substrait.type.NamedStruct; import io.substrait.type.TypeCreator; import io.substrait.type.proto.TypeProtoConverter; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.io.IOException; +import java.util.*; import java.util.function.Function; import org.apache.calcite.plan.hep.HepPlanner; import org.apache.calcite.plan.hep.HepProgram; @@ -138,34 +136,33 @@ private ExtendedExpression executeInnerExpression( throws SqlParseException { ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); ExtensionCollector functionCollector = new ExtensionCollector(); - RelProtoConverter relProtoConverter = new RelProtoConverter(functionCollector); sqlToRexNode(sql, validator, catalogReader, nameToTypeMap, nameToNodeMap) .forEach( rexNode -> { // FIXME! Implement it dynamically for more expression types - ResulTraverseRowExpression result = TraverseRexNode.getRowExpression(rexNode); + // ResulTraverseRowExpression result = TraverseRexNode.getRowExpression(rexNode); + ResulTraverseRowExpression result = new TraverseRexNode().getRowExpression(rexNode); - // FIXME! Get output type dynamically: - // final static Map getTypeCreator = new HashMap<>(){{put("BOOLEAN", - // TypeCreator.of(true).BOOLEAN);}}; - // getTypeCreator.get(rexNode.getType()).accept(...) io.substrait.proto.Type output = TypeCreator.NULLABLE.BOOLEAN.accept(new TypeProtoConverter(functionCollector)); - // FIXME! setFunctionReference, addArguments(index: 0, 1) + List functionArgumentList = new ArrayList<>(); + result + .expressionBuilderMap() + .forEach( + (k, v) -> { + functionArgumentList.add(FunctionArgument.newBuilder().setValue(v).build()); + }); + + ScalarFunction.Builder scalarFunctionBuilder = + ScalarFunction.newBuilder() + .setFunctionReference(1) // rel_01 + .setOutputType(output) + .addAllArguments(functionArgumentList); + Expression.Builder expressionBuilder = - Expression.newBuilder() - .setScalarFunction( - ScalarFunction.newBuilder() - .setFunctionReference(1) - .setOutputType(output) - .addArguments( - 0, - FunctionArgument.newBuilder().setValue(result.referenceBuilder())) - .addArguments( - 1, - FunctionArgument.newBuilder() - .setValue(result.expressionBuilderLiteral()))); + Expression.newBuilder().setScalarFunction(scalarFunctionBuilder); + ExpressionReference.Builder expressionReferenceBuilder = ExpressionReference.newBuilder() .setExpression(expressionBuilder) @@ -173,32 +170,53 @@ private ExtendedExpression executeInnerExpression( extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); - // Extensions URI FIXME! Populate/create this dynamically - HashMap extensionUris = new HashMap<>(); - extensionUris.put( - "key-001", - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(1) - .setUri("/functions_comparison.yaml") - .build()); - io.substrait.expression.Expression.ScalarFunctionInvocation func = (io.substrait.expression.Expression.ScalarFunctionInvocation) rexNode.accept(rexExpressionConverter); - String declaration = func.declaration().key(); + String declaration = + func.declaration().key(); // values example: gt:any_any, add:i64_i64 + + // this is not mandatory to be defined; it is working without this definition. It is + // only created here to create a proto message that has the correct semantics + HashMap extensionUris = new HashMap<>(); + SimpleExtensionURI simpleExtensionURI; + try { + simpleExtensionURI = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) // rel_02 + .setUri( + SimpleExtension.loadDefaults().scalarFunctions().stream() + .filter(s -> s.toString().equalsIgnoreCase(declaration)) + .findFirst() + .orElseThrow( + () -> + new IllegalArgumentException( + String.format( + "Failed to get URI resource for %s.", declaration))) + .uri()) + .build(); + } catch (IOException e) { + throw new RuntimeException(e); + } + extensionUris.put("uri", simpleExtensionURI); - // Extensions FIXME! Populate/create this dynamically, maybe use rexNode.getKind() ArrayList extensions = new ArrayList<>(); SimpleExtensionDeclaration extensionFunctionLowerThan = SimpleExtensionDeclaration.newBuilder() .setExtensionFunction( SimpleExtensionDeclaration.ExtensionFunction.newBuilder() - .setFunctionAnchor(1) + .setFunctionAnchor( + scalarFunctionBuilder.getFunctionReference()) // rel_01 .setName(declaration) - .setExtensionUriReference(1)) + .setExtensionUriReference( + simpleExtensionURI.getExtensionUriAnchor())) // rel_02 .build(); extensions.add(extensionFunctionLowerThan); + System.out.println( + "extendedExpressionBuilder.getExtensionUrisList(): " + + extendedExpressionBuilder.getExtensionUrisList()); + // adding it for semantic purposes, it is not mandatory or needed extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); extendedExpressionBuilder.addAllExtensions(extensions); }); @@ -209,13 +227,14 @@ private ExtendedExpression executeInnerExpression( return extendedExpressionBuilder.build(); } - static class TraverseRexNode { - static RexInputRef ref = null; - static Expression.Builder referenceBuilder = null; - static Expression.Builder expressionBuilderLiteral = null; - - static ResulTraverseRowExpression getRowExpression(RexNode rexNode) { + class TraverseRexNode { + RexInputRef ref = null; + int control = 0; + Expression.Builder referenceBuilder = null; + Expression.Builder literalBuilder = null; + Map expressionBuilderMap = new LinkedHashMap<>(); + ResulTraverseRowExpression getRowExpression(RexNode rexNode) { switch (rexNode.getClass().getSimpleName().toUpperCase()) { case "REXCALL": for (RexNode rexInternal : ((RexCall) rexNode).operands) { @@ -234,22 +253,24 @@ static ResulTraverseRowExpression getRowExpression(RexNode rexNode) { .setStructField( Expression.ReferenceSegment.StructField.newBuilder() .setField(ref.getIndex())))); - + expressionBuilderMap.put(control, referenceBuilder); + control++; break; case "REXLITERAL": RexLiteral literal = (RexLiteral) rexNode; - expressionBuilderLiteral = + literalBuilder = Expression.newBuilder() .setLiteral( Expression.Literal.newBuilder().setI32(literal.getValueAs(Integer.class))); + expressionBuilderMap.put(control, literalBuilder); + control++; break; default: throw new AssertionError( "Unsupported type for: " + rexNode.getClass().getSimpleName().toUpperCase()); } - ResulTraverseRowExpression result = - new ResulTraverseRowExpression(ref, referenceBuilder, expressionBuilderLiteral); - return result; + return new ResulTraverseRowExpression( + ref, referenceBuilder, literalBuilder, expressionBuilderMap); } } @@ -257,7 +278,8 @@ static ResulTraverseRowExpression getRowExpression(RexNode rexNode) { private record ResulTraverseRowExpression( RexInputRef ref, Expression.Builder referenceBuilder, - Expression.Builder expressionBuilderLiteral) {} + Expression.Builder literalBuilder, + Map expressionBuilderMap) {} private List sqlToRelNode( String sql, SqlValidator validator, CalciteCatalogReader catalogReader) diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java index 267b2525b..77472d714 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java @@ -8,7 +8,6 @@ import io.substrait.isthmus.SqlToSubstrait; import io.substrait.proto.ExtendedExpression; import java.io.IOException; -import java.net.URISyntaxException; import java.net.URL; import java.nio.ByteBuffer; import java.util.Base64; @@ -30,7 +29,7 @@ public class ExtendedExpressionIntegrationTest { @Test - public void filterDataset() throws SqlParseException, IOException, URISyntaxException { + public void filterDataset() throws SqlParseException, IOException { URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); String sqlExpression = "N_NATIONKEY > 20"; ScanOptions options = @@ -55,14 +54,15 @@ public void filterDataset() throws SqlParseException, IOException, URISyntaxExce } assertEquals(4, count); } catch (Exception e) { + e.printStackTrace(); throw new RuntimeException(e); } } @Test - public void projectDataset() throws SqlParseException, IOException, URISyntaxException { + public void projectDataset() throws SqlParseException, IOException { URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); - String sqlExpression = "N_NATIONKEY + 20"; + String sqlExpression = "20 + N_NATIONKEY"; ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) .columns(Optional.empty()) From 3d80d1f8087081fa6fc1f56856a335a0b984b553 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Thu, 16 Nov 2023 17:40:03 -0500 Subject: [PATCH 07/35] fix: clean code --- build.gradle.kts | 1 + .../extension/ExtensionCollector.java | 49 ------ isthmus/build.gradle.kts | 5 +- .../substrait/isthmus/SqlConverterBase.java | 22 +-- .../io/substrait/isthmus/SqlToSubstrait.java | 162 ++++++++---------- 5 files changed, 84 insertions(+), 155 deletions(-) diff --git a/build.gradle.kts b/build.gradle.kts index 293163045..47a9da29f 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -34,6 +34,7 @@ val submodulesUpdate by allprojects { repositories { mavenCentral() } + tasks.configureEach { val javaToolchains = project.extensions.getByType() useJUnitPlatform() diff --git a/core/src/main/java/io/substrait/extension/ExtensionCollector.java b/core/src/main/java/io/substrait/extension/ExtensionCollector.java index 714eaec93..bcdd969d4 100644 --- a/core/src/main/java/io/substrait/extension/ExtensionCollector.java +++ b/core/src/main/java/io/substrait/extension/ExtensionCollector.java @@ -1,6 +1,5 @@ package io.substrait.extension; -import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; import io.substrait.proto.SimpleExtensionURI; @@ -99,54 +98,6 @@ public void addExtensionsToPlan(Plan.Builder builder) { builder.addAllExtensions(extensionList); } - public void addExtensionsToExtendedExpression(ExtendedExpression.Builder builder) { - var uriPos = new AtomicInteger(1); - var uris = new HashMap(); - - var extensionList = new ArrayList(); - for (var e : funcMap.forwardMap.entrySet()) { - SimpleExtensionURI uri = - uris.computeIfAbsent( - e.getValue().namespace(), - k -> - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(uriPos.getAndIncrement()) - .setUri(k) - .build()); - var decl = - SimpleExtensionDeclaration.newBuilder() - .setExtensionFunction( - SimpleExtensionDeclaration.ExtensionFunction.newBuilder() - .setFunctionAnchor(e.getKey()) - .setName(e.getValue().key()) - .setExtensionUriReference(uri.getExtensionUriAnchor())) - .build(); - extensionList.add(decl); - } - for (var e : typeMap.forwardMap.entrySet()) { - SimpleExtensionURI uri = - uris.computeIfAbsent( - e.getValue().namespace(), - k -> - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(uriPos.getAndIncrement()) - .setUri(k) - .build()); - var decl = - SimpleExtensionDeclaration.newBuilder() - .setExtensionType( - SimpleExtensionDeclaration.ExtensionType.newBuilder() - .setTypeAnchor(e.getKey()) - .setName(e.getValue().key()) - .setExtensionUriReference(uri.getExtensionUriAnchor())) - .build(); - extensionList.add(decl); - } - - builder.addAllExtensionUris(uris.values()); - builder.addAllExtensions(extensionList); - } - /** We don't depend on guava... */ private static class BidiMap { private final Map forwardMap; diff --git a/isthmus/build.gradle.kts b/isthmus/build.gradle.kts index a5ae3abd2..abf5e412c 100644 --- a/isthmus/build.gradle.kts +++ b/isthmus/build.gradle.kts @@ -72,6 +72,7 @@ java { } var CALCITE_VERSION = "1.34.0" +var ARROW_VERSION = "14.0.0" dependencies { implementation(project(":core")) @@ -94,8 +95,8 @@ dependencies { implementation("org.immutables:value-annotations:2.8.8") annotationProcessor("org.immutables:value:2.8.8") testImplementation("org.apache.calcite:calcite-plus:${CALCITE_VERSION}") - testImplementation("org.apache.arrow:arrow-dataset:14.0.0") - testImplementation("org.apache.arrow:arrow-memory-netty:14.0.0") + testImplementation("org.apache.arrow:arrow-dataset:${ARROW_VERSION}") + testImplementation("org.apache.arrow:arrow-memory-netty:${ARROW_VERSION}") annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java index 3466cf826..716fdb66e 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java @@ -89,20 +89,6 @@ protected SqlConverterBase(FeatureBoard features) { EXTENSION_COLLECTION = defaults; } - /* - HashMap nameToNodeMap = new HashMap<>(); - nameToNodeMap.put( - "N_NATIONKEY", - new RexInputRef(0, validator.getTypeFactory().createSqlType(SqlTypeName.BIGINT))); - nameToNodeMap.put( - "N_REGIONKEY", - new RexInputRef(1, validator.getTypeFactory().createSqlType(SqlTypeName.BIGINT))); - final Map nameToTypeMap = new HashMap<>(); - for (Map.Entry entry : nameToNodeMap.entrySet()) { - nameToTypeMap.put(entry.getKey(), entry.getValue().getType()); - } - */ - Result registerCreateTables(List tables) throws SqlParseException { Map nameToTypeMap = new LinkedHashMap<>(); Map nameToNodeMap = new HashMap<>(); @@ -116,8 +102,12 @@ Result registerCreateTables(List tables) throws SqlParseException { for (DefinedTable t : tList) { rootSchema.add(t.getName(), t); for (RelDataTypeField field : t.type.getFieldList()) { - nameToTypeMap.put(field.getName(), field.getType()); - nameToNodeMap.put(field.getName(), new RexInputRef(field.getIndex(), field.getType())); + nameToTypeMap.put( + field.getName(), field.getType()); // to validate the sql expression tree + nameToNodeMap.put( + field.getName(), + new RexInputRef( + field.getIndex(), field.getType())); // to convert sql expression into RexNode } } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index 37d088a3e..50d18f361 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -136,90 +136,80 @@ private ExtendedExpression executeInnerExpression( throws SqlParseException { ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); ExtensionCollector functionCollector = new ExtensionCollector(); - sqlToRexNode(sql, validator, catalogReader, nameToTypeMap, nameToNodeMap) + RexNode rexNode = sqlToRexNode(sql, validator, catalogReader, nameToTypeMap, nameToNodeMap); + ResulTraverseRowExpression result = new TraverseRexNode().getRowExpression(rexNode); + io.substrait.proto.Type output = + TypeCreator.NULLABLE.BOOLEAN.accept(new TypeProtoConverter(functionCollector)); + List functionArgumentList = new ArrayList<>(); + result + .expressionBuilderMap() .forEach( - rexNode -> { - // FIXME! Implement it dynamically for more expression types - // ResulTraverseRowExpression result = TraverseRexNode.getRowExpression(rexNode); - ResulTraverseRowExpression result = new TraverseRexNode().getRowExpression(rexNode); - - io.substrait.proto.Type output = - TypeCreator.NULLABLE.BOOLEAN.accept(new TypeProtoConverter(functionCollector)); - - List functionArgumentList = new ArrayList<>(); - result - .expressionBuilderMap() - .forEach( - (k, v) -> { - functionArgumentList.add(FunctionArgument.newBuilder().setValue(v).build()); - }); - - ScalarFunction.Builder scalarFunctionBuilder = - ScalarFunction.newBuilder() - .setFunctionReference(1) // rel_01 - .setOutputType(output) - .addAllArguments(functionArgumentList); - - Expression.Builder expressionBuilder = - Expression.newBuilder().setScalarFunction(scalarFunctionBuilder); - - ExpressionReference.Builder expressionReferenceBuilder = - ExpressionReference.newBuilder() - .setExpression(expressionBuilder) - .addOutputNames(result.ref().getName()); - - extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); - - io.substrait.expression.Expression.ScalarFunctionInvocation func = - (io.substrait.expression.Expression.ScalarFunctionInvocation) - rexNode.accept(rexExpressionConverter); - String declaration = - func.declaration().key(); // values example: gt:any_any, add:i64_i64 - - // this is not mandatory to be defined; it is working without this definition. It is - // only created here to create a proto message that has the correct semantics - HashMap extensionUris = new HashMap<>(); - SimpleExtensionURI simpleExtensionURI; - try { - simpleExtensionURI = - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(1) // rel_02 - .setUri( - SimpleExtension.loadDefaults().scalarFunctions().stream() - .filter(s -> s.toString().equalsIgnoreCase(declaration)) - .findFirst() - .orElseThrow( - () -> - new IllegalArgumentException( - String.format( - "Failed to get URI resource for %s.", declaration))) - .uri()) - .build(); - } catch (IOException e) { - throw new RuntimeException(e); - } - extensionUris.put("uri", simpleExtensionURI); + (k, v) -> { + functionArgumentList.add(FunctionArgument.newBuilder().setValue(v).build()); + }); - ArrayList extensions = new ArrayList<>(); - SimpleExtensionDeclaration extensionFunctionLowerThan = - SimpleExtensionDeclaration.newBuilder() - .setExtensionFunction( - SimpleExtensionDeclaration.ExtensionFunction.newBuilder() - .setFunctionAnchor( - scalarFunctionBuilder.getFunctionReference()) // rel_01 - .setName(declaration) - .setExtensionUriReference( - simpleExtensionURI.getExtensionUriAnchor())) // rel_02 - .build(); - extensions.add(extensionFunctionLowerThan); + ScalarFunction.Builder scalarFunctionBuilder = + ScalarFunction.newBuilder() + .setFunctionReference(1) // rel_01 + .setOutputType(output) + .addAllArguments(functionArgumentList); + + Expression.Builder expressionBuilder = + Expression.newBuilder().setScalarFunction(scalarFunctionBuilder); + + ExpressionReference.Builder expressionReferenceBuilder = + ExpressionReference.newBuilder() + .setExpression(expressionBuilder) + .addOutputNames(result.ref().getName()); + + extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); + + io.substrait.expression.Expression.ScalarFunctionInvocation func = + (io.substrait.expression.Expression.ScalarFunctionInvocation) + rexNode.accept(rexExpressionConverter); + String declaration = func.declaration().key(); // values example: gt:any_any, add:i64_i64 + + // this is not mandatory to be defined; it is working without this definition. It is + // only created here to create a proto message that has the correct semantics + HashMap extensionUris = new HashMap<>(); + SimpleExtensionURI simpleExtensionURI; + try { + simpleExtensionURI = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) // rel_02 + .setUri( + SimpleExtension.loadDefaults().scalarFunctions().stream() + .filter(s -> s.toString().equalsIgnoreCase(declaration)) + .findFirst() + .orElseThrow( + () -> + new IllegalArgumentException( + String.format("Failed to get URI resource for %s.", declaration))) + .uri()) + .build(); + } catch (IOException e) { + throw new RuntimeException(e); + } + extensionUris.put("uri", simpleExtensionURI); + + ArrayList extensions = new ArrayList<>(); + SimpleExtensionDeclaration extensionFunctionLowerThan = + SimpleExtensionDeclaration.newBuilder() + .setExtensionFunction( + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(scalarFunctionBuilder.getFunctionReference()) // rel_01 + .setName(declaration) + .setExtensionUriReference(simpleExtensionURI.getExtensionUriAnchor())) // rel_02 + .build(); + extensions.add(extensionFunctionLowerThan); + + System.out.println( + "extendedExpressionBuilder.getExtensionUrisList(): " + + extendedExpressionBuilder.getExtensionUrisList()); + // adding it for semantic purposes, it is not mandatory or needed + extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); + extendedExpressionBuilder.addAllExtensions(extensions); - System.out.println( - "extendedExpressionBuilder.getExtensionUrisList(): " - + extendedExpressionBuilder.getExtensionUrisList()); - // adding it for semantic purposes, it is not mandatory or needed - extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); - extendedExpressionBuilder.addAllExtensions(extensions); - }); NamedStruct namedStruct = TypeConverter.DEFAULT.toNamedStruct(nameToTypeMap); extendedExpressionBuilder.setBaseSchema( namedStruct.toProto(new TypeProtoConverter(functionCollector))); @@ -297,7 +287,7 @@ private List sqlToRelNode( return roots; } - private List sqlToRexNode( + private RexNode sqlToRexNode( String sql, SqlValidator validator, CalciteCatalogReader catalogReader, @@ -306,13 +296,9 @@ private List sqlToRexNode( throws SqlParseException { SqlParser parser = SqlParser.create(sql, parserConfig); SqlNode sqlNode = parser.parseExpression(); - SqlNode validSQLNode = - validator.validateParameterizedExpression( - sqlNode, nameToTypeMap); // FIXME! It may be optional to include this validation + SqlNode validSQLNode = validator.validateParameterizedExpression(sqlNode, nameToTypeMap); SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader); - RexNode rexNode = converter.convertExpression(validSQLNode, nameToNodeMap); - - return Collections.singletonList(rexNode); + return converter.convertExpression(validSQLNode, nameToNodeMap); } @VisibleForTesting From 5954a626af7aab6045158581d0baa6d17d6e645d Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Thu, 16 Nov 2023 18:24:23 -0500 Subject: [PATCH 08/35] fix: clean code --- .../integration/ExtendedExpressionIntegrationTest.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java index 77472d714..da34e62f1 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java @@ -29,7 +29,7 @@ public class ExtendedExpressionIntegrationTest { @Test - public void filterDataset() throws SqlParseException, IOException { + public void filterDatasetUsingExtendedExpression() throws SqlParseException, IOException { URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); String sqlExpression = "N_NATIONKEY > 20"; ScanOptions options = @@ -60,7 +60,7 @@ public void filterDataset() throws SqlParseException, IOException { } @Test - public void projectDataset() throws SqlParseException, IOException { + public void projectDatasetUsingExtendedExpression() throws SqlParseException, IOException { URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); String sqlExpression = "20 + N_NATIONKEY"; ScanOptions options = From fc33a3233800f91300393e1f001dd7e91aa40f5b Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Fri, 17 Nov 2023 17:05:30 -0500 Subject: [PATCH 09/35] fix: rename variables to clean code --- .../io/substrait/isthmus/SqlToSubstrait.java | 22 ++++++++--- .../isthmus/ExtendedExpressionTestBase.java | 2 +- .../SimpleExtendedExpressionsTest.java | 22 ++++++++++- .../ExtendedExpressionIntegrationTest.java | 39 +++++++++++++++++-- 4 files changed, 73 insertions(+), 12 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index 50d18f361..fa61c5e81 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -73,11 +73,20 @@ public Plan execute(String sql, String name, Schema schema) throws SqlParseExcep return executeInner(sql, factory, pair.left, pair.right); } - public ExtendedExpression executeExpression(String expr, List tables) + /** + * Process to execute an SQL Expression to convert into an Extended expression protobuf message + * + * @param sqlExpression expression defined by the user + * @param tables of names of table needed to consider to load into memory for catalog, schema, + * validate and parse sql + * @return extended expression protobuf message + * @throws SqlParseException + */ + public ExtendedExpression executeSQLExpression(String sqlExpression, List tables) throws SqlParseException { var result = registerCreateTables(tables); - return executeInnerExpression( - expr, + return executeInnerSQLExpression( + sqlExpression, result.validator(), result.catalogReader(), result.nameToTypeMap(), @@ -127,8 +136,8 @@ private Plan executeInner( return plan.build(); } - private ExtendedExpression executeInnerExpression( - String sql, + private ExtendedExpression executeInnerSQLExpression( + String sqlExpression, SqlValidator validator, CalciteCatalogReader catalogReader, Map nameToTypeMap, @@ -136,7 +145,8 @@ private ExtendedExpression executeInnerExpression( throws SqlParseException { ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); ExtensionCollector functionCollector = new ExtensionCollector(); - RexNode rexNode = sqlToRexNode(sql, validator, catalogReader, nameToTypeMap, nameToNodeMap); + RexNode rexNode = + sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap); ResulTraverseRowExpression result = new TraverseRexNode().getRowExpression(rexNode); io.substrait.proto.Type output = TypeCreator.NULLABLE.BOOLEAN.accept(new TypeProtoConverter(functionCollector)); diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java index 10f3f57e3..3bee0b61e 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java @@ -35,7 +35,7 @@ protected ExtendedExpression assertProtoExtendedExpressionRoundrip(String query, protected ExtendedExpression assertProtoExtendedExpressionRoundrip( String query, SqlToSubstrait s, List creates) throws SqlParseException { io.substrait.proto.ExtendedExpression protoExtendedExpression = - s.executeExpression(query, creates); + s.executeSQLExpression(query, creates); try { String ee = JsonFormat.printer().print(protoExtendedExpression); diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java index bfcea38c9..8b0248ca4 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java @@ -1,5 +1,6 @@ package io.substrait.isthmus; +import io.substrait.proto.ExtendedExpression; import java.io.IOException; import org.apache.calcite.sql.parser.SqlParseException; import org.junit.jupiter.api.Test; @@ -8,6 +9,25 @@ public class SimpleExtendedExpressionsTest extends ExtendedExpressionTestBase { @Test public void filter() throws IOException, SqlParseException { - assertProtoExtendedExpressionRoundrip("N_NATIONKEY > 18"); + ExtendedExpression extendedExpression = + assertProtoExtendedExpressionRoundrip("L_ORDERKEY > 10"); + } + + @Test + public void in() throws IOException, SqlParseException { + ExtendedExpression extendedExpression = + assertProtoExtendedExpressionRoundrip("L_ORDERKEY IN (10, 20)"); + } + + @Test + public void isNotNull() throws IOException, SqlParseException { + ExtendedExpression extendedExpression = + assertProtoExtendedExpressionRoundrip("L_ORDERKEY is not null"); + } + + @Test + public void isNull() throws IOException, SqlParseException { + ExtendedExpression extendedExpression = + assertProtoExtendedExpressionRoundrip("L_ORDERKEY is null"); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java index da34e62f1..03d5f63ad 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java @@ -29,7 +29,7 @@ public class ExtendedExpressionIntegrationTest { @Test - public void filterDatasetUsingExtendedExpression() throws SqlParseException, IOException { + public void filterDataset() throws SqlParseException, IOException { URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); String sqlExpression = "N_NATIONKEY > 20"; ScanOptions options = @@ -60,7 +60,7 @@ public void filterDatasetUsingExtendedExpression() throws SqlParseException, IOE } @Test - public void projectDatasetUsingExtendedExpression() throws SqlParseException, IOException { + public void projectDataset() throws SqlParseException, IOException { URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); String sqlExpression = "20 + N_NATIONKEY"; ScanOptions options = @@ -95,11 +95,42 @@ public void projectDatasetUsingExtendedExpression() throws SqlParseException, IO } } + @Test + public void filterDatasetUsingExtendedExpression() throws SqlParseException, IOException { + URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); + String sqlExpression = "N_NATIONKEY > 20"; + ScanOptions options = + new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitFilter(getFilterExtendedExpression(sqlExpression)) + .build(); + try (BufferAllocator allocator = new RootAllocator(); + DatasetFactory datasetFactory = + new FileSystemDatasetFactory( + allocator, + NativeMemoryPool.getDefault(), + FileFormat.PARQUET, + resource.toURI().toString()); + Dataset dataset = datasetFactory.finish(); + Scanner scanner = dataset.newScan(options); + ArrowReader reader = scanner.scanBatches()) { + int count = 0; + while (reader.loadNextBatch()) { + count += reader.getVectorSchemaRoot().getRowCount(); + System.out.println(reader.getVectorSchemaRoot().contentToTSVString()); + } + assertEquals(4, count); + } catch (Exception e) { + e.printStackTrace(); + throw new RuntimeException(e); + } + } + private static ByteBuffer getFilterExtendedExpression(String sqlExpression) throws IOException, SqlParseException { ExtendedExpression extendedExpression = new SqlToSubstrait() - .executeExpression( + .executeSQLExpression( sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); System.out.println( "JsonFormat.printer().print(getFilterExtendedExpression): " @@ -116,7 +147,7 @@ private static ByteBuffer getProjectExtendedExpression(String sqlExpression) throws IOException, SqlParseException { ExtendedExpression extendedExpression = new SqlToSubstrait() - .executeExpression( + .executeSQLExpression( sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); System.out.println( "JsonFormat.printer().print(getProjectExtendedExpression): " From 217f2a0a6160de5e229e7abe9d882786076dcb55 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Thu, 23 Nov 2023 16:46:51 -0500 Subject: [PATCH 10/35] fix: from/to pojo/protobuf --- core/build.gradle.kts | 5 + .../expression/ExpressionReference.java | 11 + .../expression/ExtendedExpression.java | 20 ++ .../ExtendedExpressionProtoConverter.java | 210 ++++++++++++++++++ .../extended/expression/MyCliente.java | 75 +++++++ .../ProtoExtendedExpressionConverter.java | 176 +++++++++++++++ .../extension/ExtensionCollector.java | 49 ++++ .../extension/ImmutableExtensionLookup.java | 44 ++++ .../io/substrait/plan/PlanProtoConverter.java | 4 + .../ExtendedExpressionProtoConverterTest.java | 76 +++++++ .../ProtoExtendedExpressionConverterTest.java | 94 ++++++++ .../io/substrait/isthmus/SqlToSubstrait.java | 105 +++++++++ .../isthmus/ExtendedExpressionTestBase.java | 11 +- .../SimpleExtendedExpressionsTest.java | 8 +- .../ExtendedExpressionIntegrationTest.java | 2 +- isthmus/src/test/resources/tpch/schema.sql | 71 ------ 16 files changed, 887 insertions(+), 74 deletions(-) create mode 100644 core/src/main/java/io/substrait/extended/expression/ExpressionReference.java create mode 100644 core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java create mode 100644 core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java create mode 100644 core/src/main/java/io/substrait/extended/expression/MyCliente.java create mode 100644 core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java create mode 100644 core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java create mode 100644 core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java diff --git a/core/build.gradle.kts b/core/build.gradle.kts index 6b8cfac66..c06a00e86 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -85,6 +85,11 @@ dependencies { compileOnly("org.immutables:value-annotations:2.8.8") annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") + + implementation("com.google.protobuf:protobuf-java-util:3.17.3") { + exclude("com.google.guava", "guava") + .because("Brings in Guava for Android, which we don't want (and breaks multimaps).") + } } java { diff --git a/core/src/main/java/io/substrait/extended/expression/ExpressionReference.java b/core/src/main/java/io/substrait/extended/expression/ExpressionReference.java new file mode 100644 index 000000000..2214f0438 --- /dev/null +++ b/core/src/main/java/io/substrait/extended/expression/ExpressionReference.java @@ -0,0 +1,11 @@ +package io.substrait.extended.expression; + +import io.substrait.expression.Expression; +import org.immutables.value.Value; + +@Value.Immutable +public abstract class ExpressionReference { + public abstract Expression getExpression(); + + public abstract String getOutputNames(); +} diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java b/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java new file mode 100644 index 000000000..e0a5c03f9 --- /dev/null +++ b/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java @@ -0,0 +1,20 @@ +package io.substrait.extended.expression; + +import io.substrait.expression.Expression; +import io.substrait.proto.AdvancedExtension; +import io.substrait.type.NamedStruct; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import org.immutables.value.Value; + +@Value.Immutable +public abstract class ExtendedExpression { + public abstract Map getReferredExpr(); + + public abstract NamedStruct getBaseSchema(); + + public abstract List getExpectedTypeUrls(); + + public abstract Optional getAdvancedExtension(); +} diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java new file mode 100644 index 000000000..e4015f3cf --- /dev/null +++ b/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java @@ -0,0 +1,210 @@ +package io.substrait.extended.expression; + +import com.google.protobuf.util.JsonFormat; +import io.substrait.expression.*; +import io.substrait.expression.proto.ExpressionProtoConverter; +import io.substrait.extension.ExtensionCollector; +import io.substrait.extension.SimpleExtension; +import io.substrait.proto.ExpressionReference; +import io.substrait.proto.ExtendedExpression; +import io.substrait.type.NamedStruct; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import io.substrait.type.proto.TypeProtoConverter; +import java.io.IOException; +import java.util.*; + +public class ExtendedExpressionProtoConverter { + public ExtendedExpression toProto( + io.substrait.extended.expression.ExtendedExpression extendedExpressionPojo) { + + ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); + ExtensionCollector functionCollector = new ExtensionCollector(); + + final ExpressionProtoConverter expressionProtoConverter = + new ExpressionProtoConverter(functionCollector, null); + + // convert expression pojo into expression protobuf + io.substrait.proto.Expression expressionProto = + expressionProtoConverter.visit( + (Expression.ScalarFunctionInvocation) extendedExpressionPojo.getReferredExpr().get(0)); + + ExpressionReference.Builder expressionReferenceBuilder = + ExpressionReference.newBuilder().setExpression(expressionProto).addOutputNames("column-01"); + + extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); + extendedExpressionBuilder.setBaseSchema( + extendedExpressionPojo.getBaseSchema().toProto(new TypeProtoConverter(functionCollector))); + + + functionCollector.addExtensionsToPlan(extendedExpressionBuilder); + if (extendedExpressionPojo.getAdvancedExtension().isPresent()) { + extendedExpressionBuilder.setAdvancedExtensions( + extendedExpressionPojo.getAdvancedExtension().get()); + } + return extendedExpressionBuilder.build(); + } + + public static void main(String[] args) throws IOException { + SimpleExtension.ExtensionCollection defaultExtensionCollection = SimpleExtension.loadDefaults(); + System.out.println( + "defaultExtensionCollection.scalarFunctions(): " + + defaultExtensionCollection.scalarFunctions()); + System.out.println( + "defaultExtensionCollection.windowFunctions(): " + + defaultExtensionCollection.windowFunctions()); + System.out.println( + "defaultExtensionCollection.aggregateFunctions(): " + + defaultExtensionCollection.aggregateFunctions()); + + Optional equal = + defaultExtensionCollection.scalarFunctions().stream() + .filter( + s -> { + return s.name().equalsIgnoreCase("add"); + }) + .findFirst() + .map( + declaration -> + ExpressionCreator.scalarFunction( + declaration, + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.I32) + .build(), + ExpressionCreator.i32(false, 183))); + + Map indexToExpressionMap = new HashMap<>(); + indexToExpressionMap.put(0, equal.get()); + List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); + List dataTypes = + Arrays.asList( + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING, + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING); + NamedStruct namedStruct = + NamedStruct.of( + columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); + ImmutableExtendedExpression.Builder builder = + ImmutableExtendedExpression.builder() + .putAllReferredExpr(indexToExpressionMap) + .baseSchema(namedStruct); + + ExtendedExpression proto = new ExtendedExpressionProtoConverter().toProto(builder.build()); + + System.out.println( + "JsonFormat.printer().print(getFilterExtendedExpression): " + + JsonFormat.printer().print(proto)); + } + + public static ExtendedExpression createExtendedExpression( + io.substrait.expression.Expression.ScalarFunctionInvocation expr) { + ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); + + io.substrait.proto.Expression expression = new ExpressionProtoConverter(null, null).visit(expr); + ExpressionReference.Builder expressionReferenceBuilder = + ExpressionReference.newBuilder() + .setExpression(expression.toBuilder()) + .addOutputNames("col-01"); + + extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); + + return extendedExpressionBuilder.build(); + } + + public static void createExtendedExpressionManually() { + + Map nameToExpressionMap = new HashMap<>(); + ImmutableExpression.I32Literal build = Expression.I32Literal.builder().value(10).build(); + nameToExpressionMap.put("out_01", build); + + List expressionList = new ArrayList<>(); + expressionList.add(0, null); + + // nation table + List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); + List dataTypes = + Arrays.asList( + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING, + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING); + NamedStruct namedStruct = + NamedStruct.of( + columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); + + ExtensionCollector functionCollector = new ExtensionCollector(); + + new ExpressionProtoConverter(new ExtensionCollector(), null); + + FunctionArg functionArg = + new FunctionArg() { + @Override + public R accept( + SimpleExtension.Function fnDef, int argIdx, FuncArgVisitor fnArgVisitor) + throws E { + return null; + } + }; + + // var argVisitor = FunctionArg.toProto(new TypeProtoConverter(new ExtensionCollector()), this); + + /* + + // FIXME! setFunctionReference, addArguments(index: 0, 1) + io.substrait.proto.Expression.Builder expressionBuilder = + io.substrait.proto.Expression.newBuilder() + .setScalarFunction( + io.substrait.proto.Expression.ScalarFunction.newBuilder() + .setFunctionReference(1) + .setOutputType(output) + .addArguments( + 0, + FunctionArgument.newBuilder().setValue(result.referenceBuilder())) + .addArguments( + 1, + FunctionArgument.newBuilder() + .setValue(result.expressionBuilderLiteral()))); + io.substrait.proto.ExpressionReference.Builder expressionReferenceBuilder = + ExpressionReference.newBuilder() + .setExpression(expressionBuilder) + .addOutputNames(result.ref().getName()); + + */ + + /* + + io.substrait.extended.expression.ExtendedExpression extendedExpression = new io.substrait.extended.expression.ExtendedExpression() { + @Override + public List getReferredExpr() { + io.substrait.extended.expression.ExpressionReference + + @Override + public NamedStruct getBaseSchema() { + return null; + } + + @Override + public List getExpectedTypeUrls() { + return null; + } + + @Override + public Optional getAdvancedExtension() { + return Optional.empty(); + } + }; + + System.out.println("inicio"); + System.out.println(extendedExpression.getReferredExpr().get(0)); + System.out.println(extendedExpression.getReferredExpr().get(0).getType()); + System.out.println("fin"); + + ExpressionReferenceOrBuilder + + */ + + } +} diff --git a/core/src/main/java/io/substrait/extended/expression/MyCliente.java b/core/src/main/java/io/substrait/extended/expression/MyCliente.java new file mode 100644 index 000000000..0ec8a90ed --- /dev/null +++ b/core/src/main/java/io/substrait/extended/expression/MyCliente.java @@ -0,0 +1,75 @@ +package io.substrait.extended.expression; + +import com.google.protobuf.util.JsonFormat; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.FieldReference; +import io.substrait.expression.ImmutableFieldReference; +import io.substrait.extension.SimpleExtension; +import io.substrait.type.ImmutableNamedStruct; +import io.substrait.type.NamedStruct; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.io.IOException; +import java.util.*; + +public class MyCliente { + public static void main(String[] args) throws IOException { + SimpleExtension.ExtensionCollection defaultExtensionCollection = SimpleExtension.loadDefaults(); + Optional equal = + defaultExtensionCollection.scalarFunctions().stream() + .filter( + s -> { + System.out.println(":>>>>"); + System.out.println(s); + System.out.println(s.uri()); + System.out.println(s.returnType()); + System.out.println(s.description()); + System.out.println("s.name(): " + s.name()); + System.out.println(s.key()); + return s.name().equalsIgnoreCase("add"); + }) + .findFirst() + .map( + declaration -> { + System.out.println("declaration: " + declaration); + System.out.println("declaration.name(): " + declaration.name()); + return ExpressionCreator.scalarFunction( + declaration, + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.I32) + .build(), + ExpressionCreator.i32(false, 183)); + }); + + Map indexToExpressionMap = new HashMap<>(); + indexToExpressionMap.put(0, equal.get()); + List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); + List dataTypes = + Arrays.asList( + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING, + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING); + NamedStruct namedStruct = + NamedStruct.of( + columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); + + ImmutableNamedStruct.builder() + .addNames("id") + .struct(Type.Struct.builder().nullable(false).addFields(TypeCreator.REQUIRED.I32).build()) + .build(); + + ImmutableExtendedExpression.Builder builder = + ImmutableExtendedExpression.builder() + .putAllReferredExpr(indexToExpressionMap) + .baseSchema(namedStruct); + + System.out.println( + "JsonFormat.printer().print(getFilterExtendedExpression): " + + JsonFormat.printer() + .print(new ExtendedExpressionProtoConverter().toProto(builder.build()))); + } +} diff --git a/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java new file mode 100644 index 000000000..ed1fc377b --- /dev/null +++ b/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java @@ -0,0 +1,176 @@ +package io.substrait.extended.expression; + +import io.substrait.expression.Expression; +import io.substrait.expression.proto.ProtoExpressionConverter; +import io.substrait.extension.*; +import io.substrait.proto.ExpressionReference; +import io.substrait.proto.NamedStruct; +import io.substrait.relation.ProtoRelConverter; +import io.substrait.type.ImmutableNamedStruct; +import io.substrait.type.Type; +import io.substrait.type.proto.ProtoTypeConverter; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +public class ProtoExtendedExpressionConverter { + private ExtensionCollector lookup = new ExtensionCollector(); + private ProtoTypeConverter protoTypeConverter = + new ProtoTypeConverter( + lookup, ImmutableSimpleExtension.ExtensionCollection.builder().build()); + + private ProtoExpressionConverter getPprotoExpressionConverter(ExtensionLookup functionLookup) { + return new ProtoExpressionConverter( + functionLookup, + this.extensionCollection, + null, + null); + } + + private ProtoExpressionConverter getPprotoExpressionConverter(ExtensionLookup functionLookup, io.substrait.type.NamedStruct namedStruct) { + return new ProtoExpressionConverter( + functionLookup, + this.extensionCollection, + namedStruct.struct(), + null); + } + + protected final SimpleExtension.ExtensionCollection extensionCollection; + + public ProtoExtendedExpressionConverter() throws IOException { + this(SimpleExtension.loadDefaults()); + } + + public ProtoExtendedExpressionConverter(SimpleExtension.ExtensionCollection extensionCollection) { + this.extensionCollection = extensionCollection; + } + + protected ProtoRelConverter getProtoRelConverter(ExtensionLookup functionLookup) { + return new ProtoRelConverter(functionLookup, this.extensionCollection); + } + + public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExpressionProto) { + ExtensionLookup functionLookup = + ImmutableExtensionLookup.builder().from(extendedExpressionProto).build(); + + + // para struct + NamedStruct baseSchema = extendedExpressionProto.getBaseSchema(); + io.substrait.type.NamedStruct namedStruct = newNamedStruct(baseSchema); + + System.out.println("namedStruct"); + System.out.println(namedStruct); + + ProtoExpressionConverter protoExpressionConverter = + getPprotoExpressionConverter(functionLookup, namedStruct); + + Map indexToExpressionMap = new HashMap<>(); + for (ExpressionReference expressionReference : extendedExpressionProto.getReferredExprList()) { + System.out.println( + "expressionReference.getExpression(): " + expressionReference.getExpression()); + indexToExpressionMap.put( + 0, protoExpressionConverter.from(expressionReference.getExpression())); + } + + // para struct + /* + NamedStruct baseSchema = extendedExpressionProto.getBaseSchema(); + io.substrait.type.NamedStruct namedStruct = newNamedStruct(baseSchema); + + System.out.println("namedStruct"); + System.out.println(namedStruct); + + */ + + ImmutableExtendedExpression.Builder builder = + ImmutableExtendedExpression.builder() + .putAllReferredExpr(indexToExpressionMap) + .advancedExtension( + Optional.ofNullable( + extendedExpressionProto.hasAdvancedExtensions() + ? extendedExpressionProto.getAdvancedExtensions() + : null)) + .baseSchema(namedStruct); + /* + ProtocolStringList namesList = baseSchema.getNamesList(); + + Type.Struct struct = baseSchema.getStruct(); + Type types = struct.getTypes(0); + System.out.println("types.getDescriptorForType().getName(): " + types.getDescriptorForType().); + + + */ + + /* + System.out.println("namesList: " + namesList); + System.out.println("baseSchema.getStruct(): " + baseSchema.getStruct()); + System.out.println("}}{{{{{{{{{{''------>"); + System.out.println("baseSchema.getStruct(): " + baseSchema.getStruct().getTypes(0)); + + + */ + + /* + ImmutableNamedStruct.builder(). + + // para expression + + Optional equal = + defaultExtensionCollection.scalarFunctions().stream() + .filter( + s -> { + System.out.println(":>>>>"); + System.out.println(s); + System.out.println(s.uri()); + System.out.println(s.returnType()); + System.out.println(s.description()); + System.out.println("s.name(): " + s.name()); + System.out.println(s.key()); + return s.name().equalsIgnoreCase("add"); + }) + .findFirst() + .map( + declaration -> { + System.out.println("declaration: " + declaration); + System.out.println("declaration.name(): " + declaration.name()); + return ExpressionCreator.scalarFunction( + declaration, + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.I32) + .build(), + ExpressionCreator.i32(false, 183) + ); + } + ); + + Map indexToExpressionMap = new HashMap<>(); + indexToExpressionMap.put(0, equal.get()); + + ImmutableExtendedExpression.Builder builder = + ImmutableExtendedExpression.builder() + .putAllReferredExpr(indexToExpressionMap) + .baseSchema(namedStruct); + + */ + + return builder.build(); + } + + private io.substrait.type.NamedStruct newNamedStruct(NamedStruct namedStruct) { + var struct = namedStruct.getStruct(); + return ImmutableNamedStruct.builder() + .names(namedStruct.getNamesList()) + .struct( + Type.Struct.builder() + .fields( + struct.getTypesList().stream() + .map(protoTypeConverter::from) + .collect(java.util.stream.Collectors.toList())) + .nullable(ProtoTypeConverter.isNullable(struct.getNullability())) + .build()) + .build(); + } +} diff --git a/core/src/main/java/io/substrait/extension/ExtensionCollector.java b/core/src/main/java/io/substrait/extension/ExtensionCollector.java index bcdd969d4..f2a4a6f18 100644 --- a/core/src/main/java/io/substrait/extension/ExtensionCollector.java +++ b/core/src/main/java/io/substrait/extension/ExtensionCollector.java @@ -1,5 +1,6 @@ package io.substrait.extension; +import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; import io.substrait.proto.SimpleExtensionURI; @@ -98,6 +99,54 @@ public void addExtensionsToPlan(Plan.Builder builder) { builder.addAllExtensions(extensionList); } + public void addExtensionsToPlan(ExtendedExpression.Builder builder) { + var uriPos = new AtomicInteger(1); + var uris = new HashMap(); + + var extensionList = new ArrayList(); + for (var e : funcMap.forwardMap.entrySet()) { + SimpleExtensionURI uri = + uris.computeIfAbsent( + e.getValue().namespace(), + k -> + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(uriPos.getAndIncrement()) + .setUri(k) + .build()); + var decl = + SimpleExtensionDeclaration.newBuilder() + .setExtensionFunction( + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(e.getKey()) + .setName(e.getValue().key()) + .setExtensionUriReference(uri.getExtensionUriAnchor())) + .build(); + extensionList.add(decl); + } + for (var e : typeMap.forwardMap.entrySet()) { + SimpleExtensionURI uri = + uris.computeIfAbsent( + e.getValue().namespace(), + k -> + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(uriPos.getAndIncrement()) + .setUri(k) + .build()); + var decl = + SimpleExtensionDeclaration.newBuilder() + .setExtensionType( + SimpleExtensionDeclaration.ExtensionType.newBuilder() + .setTypeAnchor(e.getKey()) + .setName(e.getValue().key()) + .setExtensionUriReference(uri.getExtensionUriAnchor())) + .build(); + extensionList.add(decl); + } + + builder.addAllExtensionUris(uris.values()); + builder.addAllExtensions(extensionList); + } + /** We don't depend on guava... */ private static class BidiMap { private final Map forwardMap; diff --git a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java index 2cab2cce8..3d600002b 100644 --- a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java @@ -1,5 +1,6 @@ package io.substrait.extension; +import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; import java.util.Collections; @@ -73,6 +74,49 @@ public Builder from(Plan p) { return this; } + public Builder from(ExtendedExpression p) { + Map namespaceMap = new HashMap<>(); + for (var extension : p.getExtensionUrisList()) { + namespaceMap.put(extension.getExtensionUriAnchor(), extension.getUri()); + } + + // Add all functions used in plan to the functionMap + for (var extension : p.getExtensionsList()) { + if (!extension.hasExtensionFunction()) { + continue; + } + SimpleExtensionDeclaration.ExtensionFunction func = extension.getExtensionFunction(); + int reference = func.getFunctionAnchor(); + String namespace = namespaceMap.get(func.getExtensionUriReference()); + if (namespace == null) { + throw new IllegalStateException( + "Could not find extension URI of " + func.getExtensionUriReference()); + } + String name = func.getName(); + SimpleExtension.FunctionAnchor anchor = SimpleExtension.FunctionAnchor.of(namespace, name); + functionMap.put(reference, anchor); + } + + // Add all types used in plan to the typeMap + for (var extension : p.getExtensionsList()) { + if (!extension.hasExtensionType()) { + continue; + } + SimpleExtensionDeclaration.ExtensionType type = extension.getExtensionType(); + int reference = type.getTypeAnchor(); + String namespace = namespaceMap.get(type.getExtensionUriReference()); + if (namespace == null) { + throw new IllegalStateException( + "Could not find extension URI of " + type.getExtensionUriReference()); + } + String name = type.getName(); + SimpleExtension.TypeAnchor anchor = SimpleExtension.TypeAnchor.of(namespace, name); + typeMap.put(reference, anchor); + } + + return this; + } + public ImmutableExtensionLookup build() { return new ImmutableExtensionLookup( Collections.unmodifiableMap(functionMap), Collections.unmodifiableMap(typeMap)); diff --git a/core/src/main/java/io/substrait/plan/PlanProtoConverter.java b/core/src/main/java/io/substrait/plan/PlanProtoConverter.java index 0bdf7d68c..af0f6d69a 100644 --- a/core/src/main/java/io/substrait/plan/PlanProtoConverter.java +++ b/core/src/main/java/io/substrait/plan/PlanProtoConverter.java @@ -34,6 +34,10 @@ public Plan toProto(io.substrait.plan.Plan plan) { if (plan.getAdvancedExtension().isPresent()) { builder.setAdvancedExtensions(plan.getAdvancedExtension().get()); } + /* + extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); + extendedExpressionBuilder.addAllExtensions(extensions); + */ return builder.build(); } } diff --git a/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java b/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java new file mode 100644 index 000000000..4004e9de7 --- /dev/null +++ b/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java @@ -0,0 +1,76 @@ +package io.substrait.extended.expression; + +import com.google.protobuf.util.JsonFormat; +import io.substrait.TestBase; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.FieldReference; +import io.substrait.expression.ImmutableFieldReference; +import io.substrait.type.ImmutableNamedStruct; +import io.substrait.type.NamedStruct; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.io.IOException; +import java.util.*; +import org.junit.jupiter.api.Test; + +public class ExtendedExpressionProtoConverterTest extends TestBase { + @Test + public void toProtoTest() throws IOException { + Optional equal = + defaultExtensionCollection.scalarFunctions().stream() + .filter( + s -> { + System.out.println(":>>>>"); + System.out.println(s); + System.out.println(s.uri()); + System.out.println(s.returnType()); + System.out.println(s.description()); + System.out.println("s.name(): " + s.name()); + System.out.println(s.key()); + return s.name().equalsIgnoreCase("add"); + }) + .findFirst() + .map( + declaration -> { + System.out.println("declaration: " + declaration); + System.out.println("declaration.name(): " + declaration.name()); + return ExpressionCreator.scalarFunction( + declaration, + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.I32) + .build(), + ExpressionCreator.i32(false, 183)); + }); + + Map indexToExpressionMap = new HashMap<>(); + indexToExpressionMap.put(0, equal.get()); + List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); + List dataTypes = + Arrays.asList( + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING, + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING); + NamedStruct namedStruct = + NamedStruct.of( + columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); + + ImmutableNamedStruct.builder() + .addNames("id") + .struct(Type.Struct.builder().nullable(false).addFields(TypeCreator.REQUIRED.I32).build()) + .build(); + + ImmutableExtendedExpression.Builder builder = + ImmutableExtendedExpression.builder() + .putAllReferredExpr(indexToExpressionMap) + .baseSchema(namedStruct); + + System.out.println( + "JsonFormat.printer().print(getFilterExtendedExpression): " + + JsonFormat.printer() + .print(new ExtendedExpressionProtoConverter().toProto(builder.build()))); + } +} diff --git a/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java b/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java new file mode 100644 index 000000000..1c1fa40b5 --- /dev/null +++ b/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java @@ -0,0 +1,94 @@ +package io.substrait.extended.expression; + +import com.google.protobuf.util.JsonFormat; +import io.substrait.TestBase; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.FieldReference; +import io.substrait.expression.ImmutableFieldReference; +import io.substrait.proto.ExtendedExpression; +import io.substrait.type.ImmutableNamedStruct; +import io.substrait.type.NamedStruct; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.io.IOException; +import java.util.*; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class ProtoExtendedExpressionConverterTest extends TestBase { + @Test + public void fromTest() throws IOException { + Optional equal = + defaultExtensionCollection.scalarFunctions().stream() + .filter( + s -> { + System.out.println(":>>>>"); + System.out.println(s); + System.out.println(s.uri()); + System.out.println(s.returnType()); + System.out.println(s.description()); + System.out.println("s.name(): " + s.name()); + System.out.println(s.key()); + return s.name().equalsIgnoreCase("add"); + }) + .findFirst() + .map( + declaration -> { + System.out.println("declaration: " + declaration); + System.out.println("declaration.name(): " + declaration.name()); + return ExpressionCreator.scalarFunction( + declaration, + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.I32) + .build(), + ExpressionCreator.i32(false, 183)); + }); + + Map indexToExpressionMap = new HashMap<>(); + indexToExpressionMap.put(0, equal.get()); + List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); + /* + List dataTypes = + Arrays.asList( + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING, + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING); + NamedStruct namedStruct = + NamedStruct.of( + columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); + */ + ImmutableNamedStruct id = ImmutableNamedStruct.builder() + .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") + .struct(Type.Struct.builder().nullable(false).addFields(TypeCreator.NULLABLE.I32, + TypeCreator.REQUIRED.STRING, + TypeCreator.REQUIRED.I32, + TypeCreator.REQUIRED.STRING).build()) + .build(); + + ImmutableExtendedExpression.Builder builder = + ImmutableExtendedExpression.builder() + .putAllReferredExpr(indexToExpressionMap) + .baseSchema(id); + + ExtendedExpression proto = new ExtendedExpressionProtoConverter().toProto(builder.build()); + + System.out.println("=======POJO 01======="); + System.out.println("xxxx: " + builder); + System.out.println("=======PROTO 02======="); + System.out.println("yyyy: " + JsonFormat.printer().print(proto)); + + System.out.println("=======POJO 03======="); + io.substrait.extended.expression.ExtendedExpression from = + new ProtoExtendedExpressionConverter().from(proto); + System.out.println("zzzz: " + from); + System.out.println("11111111"); + + + Assertions.assertEquals(from, builder.build()); + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index fa61c5e81..2b35dffef 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -155,6 +155,8 @@ private ExtendedExpression executeInnerSQLExpression( .expressionBuilderMap() .forEach( (k, v) -> { + System.out.println("k->" + k); + System.out.println("v->" + v); functionArgumentList.add(FunctionArgument.newBuilder().setValue(v).build()); }); @@ -224,6 +226,109 @@ private ExtendedExpression executeInnerSQLExpression( extendedExpressionBuilder.setBaseSchema( namedStruct.toProto(new TypeProtoConverter(functionCollector))); + /* + builder.addAllExtensionUris(uris.values()); + builder.addAllExtensions(extensionList); + */ + + return extendedExpressionBuilder.build(); + } + + private ExtendedExpression executeInnerSQLExpressionPojo( + String sqlExpression, + SqlValidator validator, + CalciteCatalogReader catalogReader, + Map nameToTypeMap, + Map nameToNodeMap) + throws SqlParseException { + ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); + ExtensionCollector functionCollector = new ExtensionCollector(); + RexNode rexNode = + sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap); + ResulTraverseRowExpression result = new TraverseRexNode().getRowExpression(rexNode); + io.substrait.proto.Type output = + TypeCreator.NULLABLE.BOOLEAN.accept(new TypeProtoConverter(functionCollector)); + List functionArgumentList = new ArrayList<>(); + result + .expressionBuilderMap() + .forEach( + (k, v) -> { + System.out.println("k->" + k); + System.out.println("v->" + v); + functionArgumentList.add(FunctionArgument.newBuilder().setValue(v).build()); + }); + + ScalarFunction.Builder scalarFunctionBuilder = + ScalarFunction.newBuilder() + .setFunctionReference(1) // rel_01 + .setOutputType(output) + .addAllArguments(functionArgumentList); + + Expression.Builder expressionBuilder = + Expression.newBuilder().setScalarFunction(scalarFunctionBuilder); + + ExpressionReference.Builder expressionReferenceBuilder = + ExpressionReference.newBuilder() + .setExpression(expressionBuilder) + .addOutputNames(result.ref().getName()); + + extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); + + io.substrait.expression.Expression.ScalarFunctionInvocation func = + (io.substrait.expression.Expression.ScalarFunctionInvocation) + rexNode.accept(rexExpressionConverter); + String declaration = func.declaration().key(); // values example: gt:any_any, add:i64_i64 + + // this is not mandatory to be defined; it is working without this definition. It is + // only created here to create a proto message that has the correct semantics + HashMap extensionUris = new HashMap<>(); + SimpleExtensionURI simpleExtensionURI; + try { + simpleExtensionURI = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) // rel_02 + .setUri( + SimpleExtension.loadDefaults().scalarFunctions().stream() + .filter(s -> s.toString().equalsIgnoreCase(declaration)) + .findFirst() + .orElseThrow( + () -> + new IllegalArgumentException( + String.format("Failed to get URI resource for %s.", declaration))) + .uri()) + .build(); + } catch (IOException e) { + throw new RuntimeException(e); + } + extensionUris.put("uri", simpleExtensionURI); + + ArrayList extensions = new ArrayList<>(); + SimpleExtensionDeclaration extensionFunctionLowerThan = + SimpleExtensionDeclaration.newBuilder() + .setExtensionFunction( + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(scalarFunctionBuilder.getFunctionReference()) // rel_01 + .setName(declaration) + .setExtensionUriReference(simpleExtensionURI.getExtensionUriAnchor())) // rel_02 + .build(); + extensions.add(extensionFunctionLowerThan); + + System.out.println( + "extendedExpressionBuilder.getExtensionUrisList(): " + + extendedExpressionBuilder.getExtensionUrisList()); + // adding it for semantic purposes, it is not mandatory or needed + extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); + extendedExpressionBuilder.addAllExtensions(extensions); + + NamedStruct namedStruct = TypeConverter.DEFAULT.toNamedStruct(nameToTypeMap); + extendedExpressionBuilder.setBaseSchema( + namedStruct.toProto(new TypeProtoConverter(functionCollector))); + + /* + builder.addAllExtensionUris(uris.values()); + builder.addAllExtensions(extensionList); + */ + return extendedExpressionBuilder.build(); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java index 3bee0b61e..1c3742b52 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java @@ -4,11 +4,14 @@ import com.google.common.io.Resources; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.util.JsonFormat; +import io.substrait.extended.expression.ExtendedExpressionProtoConverter; +import io.substrait.extended.expression.ProtoExtendedExpressionConverter; import io.substrait.proto.ExtendedExpression; import java.io.IOException; import java.util.Arrays; import java.util.List; import org.apache.calcite.sql.parser.SqlParseException; +import org.junit.jupiter.api.Assertions; public class ExtendedExpressionTestBase { public static String asString(String resource) throws IOException { @@ -33,7 +36,7 @@ protected ExtendedExpression assertProtoExtendedExpressionRoundrip(String query, } protected ExtendedExpression assertProtoExtendedExpressionRoundrip( - String query, SqlToSubstrait s, List creates) throws SqlParseException { + String query, SqlToSubstrait s, List creates) throws SqlParseException, IOException { io.substrait.proto.ExtendedExpression protoExtendedExpression = s.executeSQLExpression(query, creates); @@ -41,6 +44,12 @@ protected ExtendedExpression assertProtoExtendedExpressionRoundrip( String ee = JsonFormat.printer().print(protoExtendedExpression); System.out.println("Proto Extended Expression: \n" + ee); + io.substrait.extended.expression.ExtendedExpression from = + new ProtoExtendedExpressionConverter().from(protoExtendedExpression); + + ExtendedExpression proto = new ExtendedExpressionProtoConverter().toProto(from); + + Assertions.assertEquals(proto, protoExtendedExpression); // FIXME! Implement test validation as the same as proto Plan implementation } catch (InvalidProtocolBufferException e) { throw new RuntimeException(e); diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java index 8b0248ca4..5a407a936 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java @@ -10,7 +10,13 @@ public class SimpleExtendedExpressionsTest extends ExtendedExpressionTestBase { @Test public void filter() throws IOException, SqlParseException { ExtendedExpression extendedExpression = - assertProtoExtendedExpressionRoundrip("L_ORDERKEY > 10"); + assertProtoExtendedExpressionRoundrip("N_NATIONKEY > 10"); + } + + @Test + public void projection() throws IOException, SqlParseException { + ExtendedExpression extendedExpression = + assertProtoExtendedExpressionRoundrip("L_ORDERKEY + 10"); } @Test diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java index 03d5f63ad..0121fae9c 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java @@ -31,7 +31,7 @@ public class ExtendedExpressionIntegrationTest { @Test public void filterDataset() throws SqlParseException, IOException { URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); - String sqlExpression = "N_NATIONKEY > 20"; + String sqlExpression = "N_REGIONKEY > 20"; ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) .columns(Optional.empty()) diff --git a/isthmus/src/test/resources/tpch/schema.sql b/isthmus/src/test/resources/tpch/schema.sql index 81f6f927b..b8fb4cfd0 100644 --- a/isthmus/src/test/resources/tpch/schema.sql +++ b/isthmus/src/test/resources/tpch/schema.sql @@ -1,77 +1,6 @@ -CREATE TABLE PART ( - P_PARTKEY BIGINT NOT NULL, - P_NAME VARCHAR(55), - P_MFGR CHAR(25), - P_BRAND CHAR(10), - P_TYPE VARCHAR(25), - P_SIZE INTEGER, - P_CONTAINER CHAR(10), - P_RETAILPRICE DECIMAL, - P_COMMENT VARCHAR(23) -); -CREATE TABLE SUPPLIER ( - S_SUPPKEY BIGINT NOT NULL, - S_NAME CHAR(25), - S_ADDRESS VARCHAR(40), - S_NATIONKEY BIGINT NOT NULL, - S_PHONE CHAR(15), - S_ACCTBAL DECIMAL, - S_COMMENT VARCHAR(101) -); -CREATE TABLE PARTSUPP ( - PS_PARTKEY BIGINT NOT NULL, - PS_SUPPKEY BIGINT NOT NULL, - PS_AVAILQTY INTEGER, - PS_SUPPLYCOST DECIMAL, - PS_COMMENT VARCHAR(199) -); -CREATE TABLE CUSTOMER ( - C_CUSTKEY BIGINT NOT NULL, - C_NAME VARCHAR(25), - C_ADDRESS VARCHAR(40), - C_NATIONKEY BIGINT NOT NULL, - C_PHONE CHAR(15), - C_ACCTBAL DECIMAL, - C_MKTSEGMENT CHAR(10), - C_COMMENT VARCHAR(117) -); -CREATE TABLE ORDERS ( - O_ORDERKEY BIGINT NOT NULL, - O_CUSTKEY BIGINT NOT NULL, - O_ORDERSTATUS CHAR(1), - O_TOTALPRICE DECIMAL, - O_ORDERDATE DATE, - O_ORDERPRIORITY CHAR(15), - O_CLERK CHAR(15), - O_SHIPPRIORITY INTEGER, - O_COMMENT VARCHAR(79) -); -CREATE TABLE LINEITEM ( - L_ORDERKEY BIGINT NOT NULL, - L_PARTKEY BIGINT NOT NULL, - L_SUPPKEY BIGINT NOT NULL, - L_LINENUMBER INTEGER, - L_QUANTITY DECIMAL, - L_EXTENDEDPRICE DECIMAL, - L_DISCOUNT DECIMAL, - L_TAX DECIMAL, - L_RETURNFLAG CHAR(1), - L_LINESTATUS CHAR(1), - L_SHIPDATE DATE, - L_COMMITDATE DATE, - L_RECEIPTDATE DATE, - L_SHIPINSTRUCT CHAR(25), - L_SHIPMODE CHAR(10), - L_COMMENT VARCHAR(44) -); CREATE TABLE NATION ( N_NATIONKEY BIGINT NOT NULL, N_NAME CHAR(25), N_REGIONKEY BIGINT NOT NULL, N_COMMENT VARCHAR(152) ); -CREATE TABLE REGION ( - R_REGIONKEY BIGINT NOT NULL, - R_NAME CHAR(25), - R_COMMENT VARCHAR(152) -); From 75e4f48621a5ed589c484531d5f8d4409aa3b943 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Fri, 24 Nov 2023 16:56:10 -0500 Subject: [PATCH 11/35] feat: enable support from/to pojo/protobuf for extended expressions --- .../expression/ExtendedExpression.java | 28 +++++++ .../ExtendedExpressionProtoConverter.java | 50 +++++++++++ .../ProtoExtendedExpressionConverter.java | 84 +++++++++++++++++++ .../extension/ExtensionCollector.java | 25 +++++- .../extension/ImmutableExtensionLookup.java | 13 +-- .../io/substrait/plan/ProtoPlanConverter.java | 5 +- .../ExtendedExpressionProtoConverterTest.java | 73 ++++++++++++++++ .../ProtoExtendedExpressionConverterTest.java | 80 ++++++++++++++++++ 8 files changed, 349 insertions(+), 9 deletions(-) create mode 100644 core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java create mode 100644 core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java create mode 100644 core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java create mode 100644 core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java create mode 100644 core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java b/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java new file mode 100644 index 000000000..4f705e82c --- /dev/null +++ b/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java @@ -0,0 +1,28 @@ +package io.substrait.extended.expression; + +import io.substrait.expression.Expression; +import io.substrait.proto.AdvancedExtension; +import io.substrait.type.NamedStruct; +import java.util.List; +import java.util.Optional; +import org.immutables.value.Value; + +@Value.Immutable +public abstract class ExtendedExpression { + public abstract List getReferredExpr(); + + public abstract NamedStruct getBaseSchema(); + + public abstract List getExpectedTypeUrls(); + + // creating simple extensions, such as extensionURIs and extensions, is performed on the fly + + public abstract Optional getAdvancedExtension(); + + @Value.Immutable + public abstract static class ExpressionReference { + public abstract Expression getReferredExpr(); + + public abstract List getOutputNames(); + } +} diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java new file mode 100644 index 000000000..6ace03df4 --- /dev/null +++ b/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java @@ -0,0 +1,50 @@ +package io.substrait.extended.expression; + +import io.substrait.expression.Expression; +import io.substrait.expression.proto.ExpressionProtoConverter; +import io.substrait.extension.ExtensionCollector; +import io.substrait.proto.ExpressionReference; +import io.substrait.proto.ExtendedExpression; +import io.substrait.type.proto.TypeProtoConverter; + +/** + * Converts from {@link io.substrait.extended.expression.ExtendedExpression} to {@link + * ExtendedExpression} + */ +public class ExtendedExpressionProtoConverter { + public ExtendedExpression toProto( + io.substrait.extended.expression.ExtendedExpression extendedExpression) { + + ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); + ExtensionCollector functionCollector = new ExtensionCollector(); + + final ExpressionProtoConverter expressionProtoConverter = + new ExpressionProtoConverter(functionCollector, null); + + for (io.substrait.extended.expression.ExtendedExpression.ExpressionReference + expressionReference : extendedExpression.getReferredExpr()) { + + io.substrait.proto.Expression expressionProto = + expressionProtoConverter.visit( + (Expression.ScalarFunctionInvocation) expressionReference.getReferredExpr()); + + ExpressionReference.Builder expressionReferenceBuilder = + ExpressionReference.newBuilder() + .setExpression(expressionProto) + .addAllOutputNames(expressionReference.getOutputNames()); + + extendedExpressionBuilder.addReferredExpr(expressionReferenceBuilder); + } + extendedExpressionBuilder.setBaseSchema( + extendedExpression.getBaseSchema().toProto(new TypeProtoConverter(functionCollector))); + + // the process of adding simple extensions, such as extensionURIs and extensions, is handled on + // the fly + functionCollector.addExtensionsToExtendedExpression(extendedExpressionBuilder); + if (extendedExpression.getAdvancedExtension().isPresent()) { + extendedExpressionBuilder.setAdvancedExtensions( + extendedExpression.getAdvancedExtension().get()); + } + return extendedExpressionBuilder.build(); + } +} diff --git a/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java new file mode 100644 index 000000000..14c82b5ac --- /dev/null +++ b/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java @@ -0,0 +1,84 @@ +package io.substrait.extended.expression; + +import io.substrait.expression.Expression; +import io.substrait.expression.proto.ProtoExpressionConverter; +import io.substrait.extension.*; +import io.substrait.proto.ExpressionReference; +import io.substrait.proto.NamedStruct; +import io.substrait.type.ImmutableNamedStruct; +import io.substrait.type.Type; +import io.substrait.type.proto.ProtoTypeConverter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +/** Converts from {@link io.substrait.proto.ExtendedExpression} to {@link ExtendedExpression} */ +public class ProtoExtendedExpressionConverter { + private final SimpleExtension.ExtensionCollection extensionCollection; + + public ProtoExtendedExpressionConverter() throws IOException { + this(SimpleExtension.loadDefaults()); + } + + public ProtoExtendedExpressionConverter(SimpleExtension.ExtensionCollection extensionCollection) { + this.extensionCollection = extensionCollection; + } + + private final ProtoTypeConverter protoTypeConverter = + new ProtoTypeConverter( + new ExtensionCollector(), ImmutableSimpleExtension.ExtensionCollection.builder().build()); + + public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExpression) { + // fill in simple extension information through a discovery in the current proto-extended + // expression + ExtensionLookup functionLookup = + ImmutableExtensionLookup.builder() + .from(extendedExpression.getExtensionUrisList(), extendedExpression.getExtensionsList()) + .build(); + + NamedStruct baseSchemaProto = extendedExpression.getBaseSchema(); + io.substrait.type.NamedStruct namedStruct = convertNamedStrutProtoToPojo(baseSchemaProto); + + ProtoExpressionConverter protoExpressionConverter = + new ProtoExpressionConverter( + functionLookup, this.extensionCollection, namedStruct.struct(), null); + + List expressionReferences = new ArrayList<>(); + for (ExpressionReference expressionReference : extendedExpression.getReferredExprList()) { + Expression expressionPojo = + protoExpressionConverter.from(expressionReference.getExpression()); + expressionReferences.add( + ImmutableExpressionReference.builder() + .referredExpr(expressionPojo) + .addAllOutputNames(expressionReference.getOutputNamesList()) + .build()); + } + + ImmutableExtendedExpression.Builder builder = + ImmutableExtendedExpression.builder() + .referredExpr(expressionReferences) + .advancedExtension( + Optional.ofNullable( + extendedExpression.hasAdvancedExtensions() + ? extendedExpression.getAdvancedExtensions() + : null)) + .baseSchema(namedStruct); + return builder.build(); + } + + private io.substrait.type.NamedStruct convertNamedStrutProtoToPojo(NamedStruct namedStruct) { + var struct = namedStruct.getStruct(); + return ImmutableNamedStruct.builder() + .names(namedStruct.getNamesList()) + .struct( + Type.Struct.builder() + .fields( + struct.getTypesList().stream() + .map(protoTypeConverter::from) + .collect(java.util.stream.Collectors.toList())) + .nullable(ProtoTypeConverter.isNullable(struct.getNullability())) + .build()) + .build(); + } +} diff --git a/core/src/main/java/io/substrait/extension/ExtensionCollector.java b/core/src/main/java/io/substrait/extension/ExtensionCollector.java index bcdd969d4..402a8c94c 100644 --- a/core/src/main/java/io/substrait/extension/ExtensionCollector.java +++ b/core/src/main/java/io/substrait/extension/ExtensionCollector.java @@ -1,5 +1,7 @@ package io.substrait.extension; +import com.github.bsideup.jabel.Desugar; +import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; import io.substrait.proto.SimpleExtensionURI; @@ -51,6 +53,20 @@ public int getTypeReference(SimpleExtension.TypeAnchor typeAnchor) { } public void addExtensionsToPlan(Plan.Builder builder) { + SimpleExtensions simpleExtensions = getExtensions(); + + builder.addAllExtensionUris(simpleExtensions.uris().values()); + builder.addAllExtensions(simpleExtensions.extensionList()); + } + + public void addExtensionsToExtendedExpression(ExtendedExpression.Builder builder) { + SimpleExtensions simpleExtensions = getExtensions(); + + builder.addAllExtensionUris(simpleExtensions.uris().values()); + builder.addAllExtensions(simpleExtensions.extensionList()); + } + + private SimpleExtensions getExtensions() { var uriPos = new AtomicInteger(1); var uris = new HashMap(); @@ -93,11 +109,14 @@ public void addExtensionsToPlan(Plan.Builder builder) { .build(); extensionList.add(decl); } - - builder.addAllExtensionUris(uris.values()); - builder.addAllExtensions(extensionList); + return new SimpleExtensions(uris, extensionList); } + @Desugar + private record SimpleExtensions( + HashMap uris, + ArrayList extensionList) {} + /** We don't depend on guava... */ private static class BidiMap { private final Map forwardMap; diff --git a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java index 2cab2cce8..c88bafc1c 100644 --- a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java @@ -1,9 +1,10 @@ package io.substrait.extension; -import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; +import io.substrait.proto.SimpleExtensionURI; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; /** @@ -30,14 +31,16 @@ public static class Builder { private final Map functionMap = new HashMap<>(); private final Map typeMap = new HashMap<>(); - public Builder from(Plan p) { + public Builder from( + List simpleExtensionURIs, + List simpleExtensionDeclarations) { Map namespaceMap = new HashMap<>(); - for (var extension : p.getExtensionUrisList()) { + for (var extension : simpleExtensionURIs) { namespaceMap.put(extension.getExtensionUriAnchor(), extension.getUri()); } // Add all functions used in plan to the functionMap - for (var extension : p.getExtensionsList()) { + for (var extension : simpleExtensionDeclarations) { if (!extension.hasExtensionFunction()) { continue; } @@ -54,7 +57,7 @@ public Builder from(Plan p) { } // Add all types used in plan to the typeMap - for (var extension : p.getExtensionsList()) { + for (var extension : simpleExtensionDeclarations) { if (!extension.hasExtensionType()) { continue; } diff --git a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java index be4f4ad9f..7222eb7ed 100644 --- a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java +++ b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java @@ -32,7 +32,10 @@ protected ProtoRelConverter getProtoRelConverter(ExtensionLookup functionLookup) } public Plan from(io.substrait.proto.Plan plan) { - ExtensionLookup functionLookup = ImmutableExtensionLookup.builder().from(plan).build(); + ExtensionLookup functionLookup = + ImmutableExtensionLookup.builder() + .from(plan.getExtensionUrisList(), plan.getExtensionsList()) + .build(); ProtoRelConverter relConverter = getProtoRelConverter(functionLookup); List roots = new ArrayList<>(); for (PlanRel planRel : plan.getRelationsList()) { diff --git a/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java b/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java new file mode 100644 index 000000000..20079e24f --- /dev/null +++ b/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java @@ -0,0 +1,73 @@ +package io.substrait.extended.expression; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.substrait.TestBase; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.FieldReference; +import io.substrait.expression.ImmutableFieldReference; +import io.substrait.type.ImmutableNamedStruct; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import org.junit.jupiter.api.Test; + +public class ExtendedExpressionProtoConverterTest extends TestBase { + @Test + public void toProtoTest() { + // create predefined POJO extended expression + Optional scalarFunctionExpression = + defaultExtensionCollection.scalarFunctions().stream() + .filter(s -> s.name().equalsIgnoreCase("add")) + .findFirst() + .map( + declaration -> + ExpressionCreator.scalarFunction( + declaration, + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.decimal(10, 2)) + .build(), + ExpressionCreator.i32(false, 183))); + + ImmutableExpressionReference expressionReference = + ImmutableExpressionReference.builder() + .referredExpr(scalarFunctionExpression.get()) + .addOutputNames("new-column") + .build(); + + List expressionReferences = new ArrayList<>(); + expressionReferences.add(expressionReference); + + ImmutableNamedStruct namedStruct = + ImmutableNamedStruct.builder() + .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") + .struct( + Type.Struct.builder() + .nullable(false) + .addFields( + TypeCreator.NULLABLE.decimal(10, 2), + TypeCreator.REQUIRED.STRING, + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING) + .build()) + .build(); + + ImmutableExtendedExpression.Builder extendedExpression = + ImmutableExtendedExpression.builder() + .referredExpr(expressionReferences) + .baseSchema(namedStruct); + + // convert POJO extended expression into PROTOBUF extended expression + io.substrait.proto.ExtendedExpression proto = + new ExtendedExpressionProtoConverter().toProto(extendedExpression.build()); + + assertEquals( + "/functions_arithmetic_decimal.yaml", proto.getExtensionUrisList().get(0).getUri()); + assertEquals("add:dec_dec", proto.getExtensionsList().get(0).getExtensionFunction().getName()); + } +} diff --git a/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java b/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java new file mode 100644 index 000000000..9ab84f274 --- /dev/null +++ b/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java @@ -0,0 +1,80 @@ +package io.substrait.extended.expression; + +import io.substrait.TestBase; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.FieldReference; +import io.substrait.expression.ImmutableFieldReference; +import io.substrait.proto.ExtendedExpression; +import io.substrait.type.ImmutableNamedStruct; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class ProtoExtendedExpressionConverterTest extends TestBase { + @Test + public void fromTest() throws IOException { + // create predefined POJO extended expression + Optional scalarFunctionExpression = + defaultExtensionCollection.scalarFunctions().stream() + .filter(s -> s.name().equalsIgnoreCase("add")) + .findFirst() + .map( + declaration -> + ExpressionCreator.scalarFunction( + declaration, + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.decimal(10, 2)) + .build(), + ExpressionCreator.i32(false, 183))); + + ImmutableExpressionReference expressionReference = + ImmutableExpressionReference.builder() + .referredExpr(scalarFunctionExpression.get()) + .addOutputNames("new-column") + .build(); + + List + expressionReferences = new ArrayList<>(); + expressionReferences.add(expressionReference); + + ImmutableNamedStruct namedStruct = + ImmutableNamedStruct.builder() + .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") + .struct( + Type.Struct.builder() + .nullable(false) + .addFields( + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING, + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING) + .build()) + .build(); + + // pojo initial extended expression + ImmutableExtendedExpression extendedExpressionPojoInitial = + ImmutableExtendedExpression.builder() + .referredExpr(expressionReferences) + .baseSchema(namedStruct) + .build(); + + // proto extended expression + ExtendedExpression extendedExpressionProto = + new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoInitial); + + // pojo final extended expression + io.substrait.extended.expression.ExtendedExpression extendedExpressionPojoFinal = + new ProtoExtendedExpressionConverter().from(extendedExpressionProto); + + // validate extended expression pojo initial equals to final roundtrip + Assertions.assertEquals(extendedExpressionPojoInitial, extendedExpressionPojoFinal); + } +} From 5adc79fe735d2cf0107a04bdacda32a1c4404279 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Fri, 24 Nov 2023 18:53:24 -0500 Subject: [PATCH 12/35] fix: consume core module for proto/pojo conversions --- core/build.gradle.kts | 5 - .../expression/ExpressionReference.java | 11 - .../expression/ExtendedExpression.java | 12 +- .../ExtendedExpressionProtoConverter.java | 206 ++------------ .../extended/expression/MyCliente.java | 75 ----- .../ProtoExtendedExpressionConverter.java | 157 +++------- .../extension/ExtensionCollector.java | 70 ++--- .../extension/ImmutableExtensionLookup.java | 57 +--- .../io/substrait/plan/PlanProtoConverter.java | 4 - .../io/substrait/plan/ProtoPlanConverter.java | 5 +- .../ExtendedExpressionProtoConverterTest.java | 99 ++++--- .../ProtoExtendedExpressionConverterTest.java | 114 ++++---- .../substrait/isthmus/SqlConverterBase.java | 19 +- .../io/substrait/isthmus/SqlToSubstrait.java | 269 ++---------------- .../io/substrait/isthmus/SubstraitToSql.java | 5 +- .../isthmus/ExtendedExpressionTestBase.java | 28 +- .../SimpleExtendedExpressionsTest.java | 2 +- .../ExtendedExpressionIntegrationTest.java | 61 +--- isthmus/src/test/resources/tpch/schema.sql | 71 +++++ 19 files changed, 329 insertions(+), 941 deletions(-) delete mode 100644 core/src/main/java/io/substrait/extended/expression/ExpressionReference.java delete mode 100644 core/src/main/java/io/substrait/extended/expression/MyCliente.java diff --git a/core/build.gradle.kts b/core/build.gradle.kts index c06a00e86..6b8cfac66 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -85,11 +85,6 @@ dependencies { compileOnly("org.immutables:value-annotations:2.8.8") annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") - - implementation("com.google.protobuf:protobuf-java-util:3.17.3") { - exclude("com.google.guava", "guava") - .because("Brings in Guava for Android, which we don't want (and breaks multimaps).") - } } java { diff --git a/core/src/main/java/io/substrait/extended/expression/ExpressionReference.java b/core/src/main/java/io/substrait/extended/expression/ExpressionReference.java deleted file mode 100644 index 2214f0438..000000000 --- a/core/src/main/java/io/substrait/extended/expression/ExpressionReference.java +++ /dev/null @@ -1,11 +0,0 @@ -package io.substrait.extended.expression; - -import io.substrait.expression.Expression; -import org.immutables.value.Value; - -@Value.Immutable -public abstract class ExpressionReference { - public abstract Expression getExpression(); - - public abstract String getOutputNames(); -} diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java b/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java index e0a5c03f9..4f705e82c 100644 --- a/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java +++ b/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java @@ -4,17 +4,25 @@ import io.substrait.proto.AdvancedExtension; import io.substrait.type.NamedStruct; import java.util.List; -import java.util.Map; import java.util.Optional; import org.immutables.value.Value; @Value.Immutable public abstract class ExtendedExpression { - public abstract Map getReferredExpr(); + public abstract List getReferredExpr(); public abstract NamedStruct getBaseSchema(); public abstract List getExpectedTypeUrls(); + // creating simple extensions, such as extensionURIs and extensions, is performed on the fly + public abstract Optional getAdvancedExtension(); + + @Value.Immutable + public abstract static class ExpressionReference { + public abstract Expression getReferredExpr(); + + public abstract List getOutputNames(); + } } diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java index e4015f3cf..cffdfefd0 100644 --- a/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java @@ -1,22 +1,19 @@ package io.substrait.extended.expression; -import com.google.protobuf.util.JsonFormat; import io.substrait.expression.*; import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.extension.ExtensionCollector; -import io.substrait.extension.SimpleExtension; import io.substrait.proto.ExpressionReference; import io.substrait.proto.ExtendedExpression; -import io.substrait.type.NamedStruct; -import io.substrait.type.Type; -import io.substrait.type.TypeCreator; import io.substrait.type.proto.TypeProtoConverter; -import java.io.IOException; -import java.util.*; +/** + * Converts from {@link io.substrait.extended.expression.ExtendedExpression} to {@link + * io.substrait.proto.ExtendedExpression} + */ public class ExtendedExpressionProtoConverter { public ExtendedExpression toProto( - io.substrait.extended.expression.ExtendedExpression extendedExpressionPojo) { + io.substrait.extended.expression.ExtendedExpression extendedExpression) { ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); ExtensionCollector functionCollector = new ExtensionCollector(); @@ -24,187 +21,30 @@ public ExtendedExpression toProto( final ExpressionProtoConverter expressionProtoConverter = new ExpressionProtoConverter(functionCollector, null); - // convert expression pojo into expression protobuf - io.substrait.proto.Expression expressionProto = - expressionProtoConverter.visit( - (Expression.ScalarFunctionInvocation) extendedExpressionPojo.getReferredExpr().get(0)); + for (io.substrait.extended.expression.ExtendedExpression.ExpressionReference + expressionReference : extendedExpression.getReferredExpr()) { - ExpressionReference.Builder expressionReferenceBuilder = - ExpressionReference.newBuilder().setExpression(expressionProto).addOutputNames("column-01"); + io.substrait.proto.Expression expressionProto = + expressionProtoConverter.visit( + (Expression.ScalarFunctionInvocation) expressionReference.getReferredExpr()); - extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); - extendedExpressionBuilder.setBaseSchema( - extendedExpressionPojo.getBaseSchema().toProto(new TypeProtoConverter(functionCollector))); + ExpressionReference.Builder expressionReferenceBuilder = + ExpressionReference.newBuilder() + .setExpression(expressionProto) + .addAllOutputNames(expressionReference.getOutputNames()); + extendedExpressionBuilder.addReferredExpr(expressionReferenceBuilder); + } + extendedExpressionBuilder.setBaseSchema( + extendedExpression.getBaseSchema().toProto(new TypeProtoConverter(functionCollector))); - functionCollector.addExtensionsToPlan(extendedExpressionBuilder); - if (extendedExpressionPojo.getAdvancedExtension().isPresent()) { + // the process of adding simple extensions, such as extensionURIs and extensions, is handled on + // the fly + functionCollector.addExtensionsToExtendedExpression(extendedExpressionBuilder); + if (extendedExpression.getAdvancedExtension().isPresent()) { extendedExpressionBuilder.setAdvancedExtensions( - extendedExpressionPojo.getAdvancedExtension().get()); + extendedExpression.getAdvancedExtension().get()); } return extendedExpressionBuilder.build(); } - - public static void main(String[] args) throws IOException { - SimpleExtension.ExtensionCollection defaultExtensionCollection = SimpleExtension.loadDefaults(); - System.out.println( - "defaultExtensionCollection.scalarFunctions(): " - + defaultExtensionCollection.scalarFunctions()); - System.out.println( - "defaultExtensionCollection.windowFunctions(): " - + defaultExtensionCollection.windowFunctions()); - System.out.println( - "defaultExtensionCollection.aggregateFunctions(): " - + defaultExtensionCollection.aggregateFunctions()); - - Optional equal = - defaultExtensionCollection.scalarFunctions().stream() - .filter( - s -> { - return s.name().equalsIgnoreCase("add"); - }) - .findFirst() - .map( - declaration -> - ExpressionCreator.scalarFunction( - declaration, - TypeCreator.REQUIRED.BOOLEAN, - ImmutableFieldReference.builder() - .addSegments(FieldReference.StructField.of(0)) - .type(TypeCreator.REQUIRED.I32) - .build(), - ExpressionCreator.i32(false, 183))); - - Map indexToExpressionMap = new HashMap<>(); - indexToExpressionMap.put(0, equal.get()); - List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); - List dataTypes = - Arrays.asList( - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING, - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING); - NamedStruct namedStruct = - NamedStruct.of( - columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); - ImmutableExtendedExpression.Builder builder = - ImmutableExtendedExpression.builder() - .putAllReferredExpr(indexToExpressionMap) - .baseSchema(namedStruct); - - ExtendedExpression proto = new ExtendedExpressionProtoConverter().toProto(builder.build()); - - System.out.println( - "JsonFormat.printer().print(getFilterExtendedExpression): " - + JsonFormat.printer().print(proto)); - } - - public static ExtendedExpression createExtendedExpression( - io.substrait.expression.Expression.ScalarFunctionInvocation expr) { - ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); - - io.substrait.proto.Expression expression = new ExpressionProtoConverter(null, null).visit(expr); - ExpressionReference.Builder expressionReferenceBuilder = - ExpressionReference.newBuilder() - .setExpression(expression.toBuilder()) - .addOutputNames("col-01"); - - extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); - - return extendedExpressionBuilder.build(); - } - - public static void createExtendedExpressionManually() { - - Map nameToExpressionMap = new HashMap<>(); - ImmutableExpression.I32Literal build = Expression.I32Literal.builder().value(10).build(); - nameToExpressionMap.put("out_01", build); - - List expressionList = new ArrayList<>(); - expressionList.add(0, null); - - // nation table - List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); - List dataTypes = - Arrays.asList( - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING, - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING); - NamedStruct namedStruct = - NamedStruct.of( - columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); - - ExtensionCollector functionCollector = new ExtensionCollector(); - - new ExpressionProtoConverter(new ExtensionCollector(), null); - - FunctionArg functionArg = - new FunctionArg() { - @Override - public R accept( - SimpleExtension.Function fnDef, int argIdx, FuncArgVisitor fnArgVisitor) - throws E { - return null; - } - }; - - // var argVisitor = FunctionArg.toProto(new TypeProtoConverter(new ExtensionCollector()), this); - - /* - - // FIXME! setFunctionReference, addArguments(index: 0, 1) - io.substrait.proto.Expression.Builder expressionBuilder = - io.substrait.proto.Expression.newBuilder() - .setScalarFunction( - io.substrait.proto.Expression.ScalarFunction.newBuilder() - .setFunctionReference(1) - .setOutputType(output) - .addArguments( - 0, - FunctionArgument.newBuilder().setValue(result.referenceBuilder())) - .addArguments( - 1, - FunctionArgument.newBuilder() - .setValue(result.expressionBuilderLiteral()))); - io.substrait.proto.ExpressionReference.Builder expressionReferenceBuilder = - ExpressionReference.newBuilder() - .setExpression(expressionBuilder) - .addOutputNames(result.ref().getName()); - - */ - - /* - - io.substrait.extended.expression.ExtendedExpression extendedExpression = new io.substrait.extended.expression.ExtendedExpression() { - @Override - public List getReferredExpr() { - io.substrait.extended.expression.ExpressionReference - - @Override - public NamedStruct getBaseSchema() { - return null; - } - - @Override - public List getExpectedTypeUrls() { - return null; - } - - @Override - public Optional getAdvancedExtension() { - return Optional.empty(); - } - }; - - System.out.println("inicio"); - System.out.println(extendedExpression.getReferredExpr().get(0)); - System.out.println(extendedExpression.getReferredExpr().get(0).getType()); - System.out.println("fin"); - - ExpressionReferenceOrBuilder - - */ - - } } diff --git a/core/src/main/java/io/substrait/extended/expression/MyCliente.java b/core/src/main/java/io/substrait/extended/expression/MyCliente.java deleted file mode 100644 index 0ec8a90ed..000000000 --- a/core/src/main/java/io/substrait/extended/expression/MyCliente.java +++ /dev/null @@ -1,75 +0,0 @@ -package io.substrait.extended.expression; - -import com.google.protobuf.util.JsonFormat; -import io.substrait.expression.Expression; -import io.substrait.expression.ExpressionCreator; -import io.substrait.expression.FieldReference; -import io.substrait.expression.ImmutableFieldReference; -import io.substrait.extension.SimpleExtension; -import io.substrait.type.ImmutableNamedStruct; -import io.substrait.type.NamedStruct; -import io.substrait.type.Type; -import io.substrait.type.TypeCreator; -import java.io.IOException; -import java.util.*; - -public class MyCliente { - public static void main(String[] args) throws IOException { - SimpleExtension.ExtensionCollection defaultExtensionCollection = SimpleExtension.loadDefaults(); - Optional equal = - defaultExtensionCollection.scalarFunctions().stream() - .filter( - s -> { - System.out.println(":>>>>"); - System.out.println(s); - System.out.println(s.uri()); - System.out.println(s.returnType()); - System.out.println(s.description()); - System.out.println("s.name(): " + s.name()); - System.out.println(s.key()); - return s.name().equalsIgnoreCase("add"); - }) - .findFirst() - .map( - declaration -> { - System.out.println("declaration: " + declaration); - System.out.println("declaration.name(): " + declaration.name()); - return ExpressionCreator.scalarFunction( - declaration, - TypeCreator.REQUIRED.BOOLEAN, - ImmutableFieldReference.builder() - .addSegments(FieldReference.StructField.of(0)) - .type(TypeCreator.REQUIRED.I32) - .build(), - ExpressionCreator.i32(false, 183)); - }); - - Map indexToExpressionMap = new HashMap<>(); - indexToExpressionMap.put(0, equal.get()); - List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); - List dataTypes = - Arrays.asList( - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING, - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING); - NamedStruct namedStruct = - NamedStruct.of( - columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); - - ImmutableNamedStruct.builder() - .addNames("id") - .struct(Type.Struct.builder().nullable(false).addFields(TypeCreator.REQUIRED.I32).build()) - .build(); - - ImmutableExtendedExpression.Builder builder = - ImmutableExtendedExpression.builder() - .putAllReferredExpr(indexToExpressionMap) - .baseSchema(namedStruct); - - System.out.println( - "JsonFormat.printer().print(getFilterExtendedExpression): " - + JsonFormat.printer() - .print(new ExtendedExpressionProtoConverter().toProto(builder.build()))); - } -} diff --git a/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java index ed1fc377b..41fc1bef7 100644 --- a/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java +++ b/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java @@ -5,38 +5,18 @@ import io.substrait.extension.*; import io.substrait.proto.ExpressionReference; import io.substrait.proto.NamedStruct; -import io.substrait.relation.ProtoRelConverter; import io.substrait.type.ImmutableNamedStruct; import io.substrait.type.Type; import io.substrait.type.proto.ProtoTypeConverter; import java.io.IOException; -import java.util.HashMap; -import java.util.Map; -import java.util.Optional; +import java.util.*; +/** + * Converts from {@link io.substrait.proto.ExtendedExpression} to {@link + * io.substrait.extended.expression.ExtendedExpression} + */ public class ProtoExtendedExpressionConverter { - private ExtensionCollector lookup = new ExtensionCollector(); - private ProtoTypeConverter protoTypeConverter = - new ProtoTypeConverter( - lookup, ImmutableSimpleExtension.ExtensionCollection.builder().build()); - - private ProtoExpressionConverter getPprotoExpressionConverter(ExtensionLookup functionLookup) { - return new ProtoExpressionConverter( - functionLookup, - this.extensionCollection, - null, - null); - } - - private ProtoExpressionConverter getPprotoExpressionConverter(ExtensionLookup functionLookup, io.substrait.type.NamedStruct namedStruct) { - return new ProtoExpressionConverter( - functionLookup, - this.extensionCollection, - namedStruct.struct(), - null); - } - - protected final SimpleExtension.ExtensionCollection extensionCollection; + private final SimpleExtension.ExtensionCollection extensionCollection; public ProtoExtendedExpressionConverter() throws IOException { this(SimpleExtension.loadDefaults()); @@ -46,120 +26,49 @@ public ProtoExtendedExpressionConverter(SimpleExtension.ExtensionCollection exte this.extensionCollection = extensionCollection; } - protected ProtoRelConverter getProtoRelConverter(ExtensionLookup functionLookup) { - return new ProtoRelConverter(functionLookup, this.extensionCollection); - } + private final ProtoTypeConverter protoTypeConverter = + new ProtoTypeConverter( + new ExtensionCollector(), ImmutableSimpleExtension.ExtensionCollection.builder().build()); - public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExpressionProto) { + public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExpression) { + // fill in simple extension information through a discovery in the current proto-extended + // expression ExtensionLookup functionLookup = - ImmutableExtensionLookup.builder().from(extendedExpressionProto).build(); - + ImmutableExtensionLookup.builder() + .from(extendedExpression.getExtensionUrisList(), extendedExpression.getExtensionsList()) + .build(); - // para struct - NamedStruct baseSchema = extendedExpressionProto.getBaseSchema(); - io.substrait.type.NamedStruct namedStruct = newNamedStruct(baseSchema); - - System.out.println("namedStruct"); - System.out.println(namedStruct); + NamedStruct baseSchemaProto = extendedExpression.getBaseSchema(); + io.substrait.type.NamedStruct namedStruct = convertNamedStrutProtoToPojo(baseSchemaProto); ProtoExpressionConverter protoExpressionConverter = - getPprotoExpressionConverter(functionLookup, namedStruct); - - Map indexToExpressionMap = new HashMap<>(); - for (ExpressionReference expressionReference : extendedExpressionProto.getReferredExprList()) { - System.out.println( - "expressionReference.getExpression(): " + expressionReference.getExpression()); - indexToExpressionMap.put( - 0, protoExpressionConverter.from(expressionReference.getExpression())); + new ProtoExpressionConverter( + functionLookup, this.extensionCollection, namedStruct.struct(), null); + + List expressionReferences = new ArrayList<>(); + for (ExpressionReference expressionReference : extendedExpression.getReferredExprList()) { + Expression expressionPojo = + protoExpressionConverter.from(expressionReference.getExpression()); + expressionReferences.add( + ImmutableExpressionReference.builder() + .referredExpr(expressionPojo) + .addAllOutputNames(expressionReference.getOutputNamesList()) + .build()); } - // para struct - /* - NamedStruct baseSchema = extendedExpressionProto.getBaseSchema(); - io.substrait.type.NamedStruct namedStruct = newNamedStruct(baseSchema); - - System.out.println("namedStruct"); - System.out.println(namedStruct); - - */ - ImmutableExtendedExpression.Builder builder = ImmutableExtendedExpression.builder() - .putAllReferredExpr(indexToExpressionMap) + .referredExpr(expressionReferences) .advancedExtension( Optional.ofNullable( - extendedExpressionProto.hasAdvancedExtensions() - ? extendedExpressionProto.getAdvancedExtensions() + extendedExpression.hasAdvancedExtensions() + ? extendedExpression.getAdvancedExtensions() : null)) .baseSchema(namedStruct); - /* - ProtocolStringList namesList = baseSchema.getNamesList(); - - Type.Struct struct = baseSchema.getStruct(); - Type types = struct.getTypes(0); - System.out.println("types.getDescriptorForType().getName(): " + types.getDescriptorForType().); - - - */ - - /* - System.out.println("namesList: " + namesList); - System.out.println("baseSchema.getStruct(): " + baseSchema.getStruct()); - System.out.println("}}{{{{{{{{{{''------>"); - System.out.println("baseSchema.getStruct(): " + baseSchema.getStruct().getTypes(0)); - - - */ - - /* - ImmutableNamedStruct.builder(). - - // para expression - - Optional equal = - defaultExtensionCollection.scalarFunctions().stream() - .filter( - s -> { - System.out.println(":>>>>"); - System.out.println(s); - System.out.println(s.uri()); - System.out.println(s.returnType()); - System.out.println(s.description()); - System.out.println("s.name(): " + s.name()); - System.out.println(s.key()); - return s.name().equalsIgnoreCase("add"); - }) - .findFirst() - .map( - declaration -> { - System.out.println("declaration: " + declaration); - System.out.println("declaration.name(): " + declaration.name()); - return ExpressionCreator.scalarFunction( - declaration, - TypeCreator.REQUIRED.BOOLEAN, - ImmutableFieldReference.builder() - .addSegments(FieldReference.StructField.of(0)) - .type(TypeCreator.REQUIRED.I32) - .build(), - ExpressionCreator.i32(false, 183) - ); - } - ); - - Map indexToExpressionMap = new HashMap<>(); - indexToExpressionMap.put(0, equal.get()); - - ImmutableExtendedExpression.Builder builder = - ImmutableExtendedExpression.builder() - .putAllReferredExpr(indexToExpressionMap) - .baseSchema(namedStruct); - - */ - return builder.build(); } - private io.substrait.type.NamedStruct newNamedStruct(NamedStruct namedStruct) { + private io.substrait.type.NamedStruct convertNamedStrutProtoToPojo(NamedStruct namedStruct) { var struct = namedStruct.getStruct(); return ImmutableNamedStruct.builder() .names(namedStruct.getNamesList()) diff --git a/core/src/main/java/io/substrait/extension/ExtensionCollector.java b/core/src/main/java/io/substrait/extension/ExtensionCollector.java index f2a4a6f18..402a8c94c 100644 --- a/core/src/main/java/io/substrait/extension/ExtensionCollector.java +++ b/core/src/main/java/io/substrait/extension/ExtensionCollector.java @@ -1,5 +1,6 @@ package io.substrait.extension; +import com.github.bsideup.jabel.Desugar; import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; @@ -52,6 +53,20 @@ public int getTypeReference(SimpleExtension.TypeAnchor typeAnchor) { } public void addExtensionsToPlan(Plan.Builder builder) { + SimpleExtensions simpleExtensions = getExtensions(); + + builder.addAllExtensionUris(simpleExtensions.uris().values()); + builder.addAllExtensions(simpleExtensions.extensionList()); + } + + public void addExtensionsToExtendedExpression(ExtendedExpression.Builder builder) { + SimpleExtensions simpleExtensions = getExtensions(); + + builder.addAllExtensionUris(simpleExtensions.uris().values()); + builder.addAllExtensions(simpleExtensions.extensionList()); + } + + private SimpleExtensions getExtensions() { var uriPos = new AtomicInteger(1); var uris = new HashMap(); @@ -94,58 +109,13 @@ public void addExtensionsToPlan(Plan.Builder builder) { .build(); extensionList.add(decl); } - - builder.addAllExtensionUris(uris.values()); - builder.addAllExtensions(extensionList); + return new SimpleExtensions(uris, extensionList); } - public void addExtensionsToPlan(ExtendedExpression.Builder builder) { - var uriPos = new AtomicInteger(1); - var uris = new HashMap(); - - var extensionList = new ArrayList(); - for (var e : funcMap.forwardMap.entrySet()) { - SimpleExtensionURI uri = - uris.computeIfAbsent( - e.getValue().namespace(), - k -> - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(uriPos.getAndIncrement()) - .setUri(k) - .build()); - var decl = - SimpleExtensionDeclaration.newBuilder() - .setExtensionFunction( - SimpleExtensionDeclaration.ExtensionFunction.newBuilder() - .setFunctionAnchor(e.getKey()) - .setName(e.getValue().key()) - .setExtensionUriReference(uri.getExtensionUriAnchor())) - .build(); - extensionList.add(decl); - } - for (var e : typeMap.forwardMap.entrySet()) { - SimpleExtensionURI uri = - uris.computeIfAbsent( - e.getValue().namespace(), - k -> - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(uriPos.getAndIncrement()) - .setUri(k) - .build()); - var decl = - SimpleExtensionDeclaration.newBuilder() - .setExtensionType( - SimpleExtensionDeclaration.ExtensionType.newBuilder() - .setTypeAnchor(e.getKey()) - .setName(e.getValue().key()) - .setExtensionUriReference(uri.getExtensionUriAnchor())) - .build(); - extensionList.add(decl); - } - - builder.addAllExtensionUris(uris.values()); - builder.addAllExtensions(extensionList); - } + @Desugar + private record SimpleExtensions( + HashMap uris, + ArrayList extensionList) {} /** We don't depend on guava... */ private static class BidiMap { diff --git a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java index 3d600002b..c88bafc1c 100644 --- a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java @@ -1,10 +1,10 @@ package io.substrait.extension; -import io.substrait.proto.ExtendedExpression; -import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; +import io.substrait.proto.SimpleExtensionURI; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; /** @@ -31,14 +31,16 @@ public static class Builder { private final Map functionMap = new HashMap<>(); private final Map typeMap = new HashMap<>(); - public Builder from(Plan p) { + public Builder from( + List simpleExtensionURIs, + List simpleExtensionDeclarations) { Map namespaceMap = new HashMap<>(); - for (var extension : p.getExtensionUrisList()) { + for (var extension : simpleExtensionURIs) { namespaceMap.put(extension.getExtensionUriAnchor(), extension.getUri()); } // Add all functions used in plan to the functionMap - for (var extension : p.getExtensionsList()) { + for (var extension : simpleExtensionDeclarations) { if (!extension.hasExtensionFunction()) { continue; } @@ -55,50 +57,7 @@ public Builder from(Plan p) { } // Add all types used in plan to the typeMap - for (var extension : p.getExtensionsList()) { - if (!extension.hasExtensionType()) { - continue; - } - SimpleExtensionDeclaration.ExtensionType type = extension.getExtensionType(); - int reference = type.getTypeAnchor(); - String namespace = namespaceMap.get(type.getExtensionUriReference()); - if (namespace == null) { - throw new IllegalStateException( - "Could not find extension URI of " + type.getExtensionUriReference()); - } - String name = type.getName(); - SimpleExtension.TypeAnchor anchor = SimpleExtension.TypeAnchor.of(namespace, name); - typeMap.put(reference, anchor); - } - - return this; - } - - public Builder from(ExtendedExpression p) { - Map namespaceMap = new HashMap<>(); - for (var extension : p.getExtensionUrisList()) { - namespaceMap.put(extension.getExtensionUriAnchor(), extension.getUri()); - } - - // Add all functions used in plan to the functionMap - for (var extension : p.getExtensionsList()) { - if (!extension.hasExtensionFunction()) { - continue; - } - SimpleExtensionDeclaration.ExtensionFunction func = extension.getExtensionFunction(); - int reference = func.getFunctionAnchor(); - String namespace = namespaceMap.get(func.getExtensionUriReference()); - if (namespace == null) { - throw new IllegalStateException( - "Could not find extension URI of " + func.getExtensionUriReference()); - } - String name = func.getName(); - SimpleExtension.FunctionAnchor anchor = SimpleExtension.FunctionAnchor.of(namespace, name); - functionMap.put(reference, anchor); - } - - // Add all types used in plan to the typeMap - for (var extension : p.getExtensionsList()) { + for (var extension : simpleExtensionDeclarations) { if (!extension.hasExtensionType()) { continue; } diff --git a/core/src/main/java/io/substrait/plan/PlanProtoConverter.java b/core/src/main/java/io/substrait/plan/PlanProtoConverter.java index af0f6d69a..0bdf7d68c 100644 --- a/core/src/main/java/io/substrait/plan/PlanProtoConverter.java +++ b/core/src/main/java/io/substrait/plan/PlanProtoConverter.java @@ -34,10 +34,6 @@ public Plan toProto(io.substrait.plan.Plan plan) { if (plan.getAdvancedExtension().isPresent()) { builder.setAdvancedExtensions(plan.getAdvancedExtension().get()); } - /* - extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); - extendedExpressionBuilder.addAllExtensions(extensions); - */ return builder.build(); } } diff --git a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java index be4f4ad9f..7222eb7ed 100644 --- a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java +++ b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java @@ -32,7 +32,10 @@ protected ProtoRelConverter getProtoRelConverter(ExtensionLookup functionLookup) } public Plan from(io.substrait.proto.Plan plan) { - ExtensionLookup functionLookup = ImmutableExtensionLookup.builder().from(plan).build(); + ExtensionLookup functionLookup = + ImmutableExtensionLookup.builder() + .from(plan.getExtensionUrisList(), plan.getExtensionsList()) + .build(); ProtoRelConverter relConverter = getProtoRelConverter(functionLookup); List roots = new ArrayList<>(); for (PlanRel planRel : plan.getRelationsList()) { diff --git a/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java b/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java index 4004e9de7..20079e24f 100644 --- a/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java +++ b/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java @@ -1,76 +1,73 @@ package io.substrait.extended.expression; -import com.google.protobuf.util.JsonFormat; +import static org.junit.jupiter.api.Assertions.assertEquals; + import io.substrait.TestBase; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.FieldReference; import io.substrait.expression.ImmutableFieldReference; import io.substrait.type.ImmutableNamedStruct; -import io.substrait.type.NamedStruct; import io.substrait.type.Type; import io.substrait.type.TypeCreator; -import java.io.IOException; -import java.util.*; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; import org.junit.jupiter.api.Test; public class ExtendedExpressionProtoConverterTest extends TestBase { @Test - public void toProtoTest() throws IOException { - Optional equal = + public void toProtoTest() { + // create predefined POJO extended expression + Optional scalarFunctionExpression = defaultExtensionCollection.scalarFunctions().stream() - .filter( - s -> { - System.out.println(":>>>>"); - System.out.println(s); - System.out.println(s.uri()); - System.out.println(s.returnType()); - System.out.println(s.description()); - System.out.println("s.name(): " + s.name()); - System.out.println(s.key()); - return s.name().equalsIgnoreCase("add"); - }) + .filter(s -> s.name().equalsIgnoreCase("add")) .findFirst() .map( - declaration -> { - System.out.println("declaration: " + declaration); - System.out.println("declaration.name(): " + declaration.name()); - return ExpressionCreator.scalarFunction( - declaration, - TypeCreator.REQUIRED.BOOLEAN, - ImmutableFieldReference.builder() - .addSegments(FieldReference.StructField.of(0)) - .type(TypeCreator.REQUIRED.I32) - .build(), - ExpressionCreator.i32(false, 183)); - }); + declaration -> + ExpressionCreator.scalarFunction( + declaration, + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.decimal(10, 2)) + .build(), + ExpressionCreator.i32(false, 183))); + + ImmutableExpressionReference expressionReference = + ImmutableExpressionReference.builder() + .referredExpr(scalarFunctionExpression.get()) + .addOutputNames("new-column") + .build(); - Map indexToExpressionMap = new HashMap<>(); - indexToExpressionMap.put(0, equal.get()); - List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); - List dataTypes = - Arrays.asList( - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING, - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING); - NamedStruct namedStruct = - NamedStruct.of( - columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); + List expressionReferences = new ArrayList<>(); + expressionReferences.add(expressionReference); - ImmutableNamedStruct.builder() - .addNames("id") - .struct(Type.Struct.builder().nullable(false).addFields(TypeCreator.REQUIRED.I32).build()) - .build(); + ImmutableNamedStruct namedStruct = + ImmutableNamedStruct.builder() + .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") + .struct( + Type.Struct.builder() + .nullable(false) + .addFields( + TypeCreator.NULLABLE.decimal(10, 2), + TypeCreator.REQUIRED.STRING, + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING) + .build()) + .build(); - ImmutableExtendedExpression.Builder builder = + ImmutableExtendedExpression.Builder extendedExpression = ImmutableExtendedExpression.builder() - .putAllReferredExpr(indexToExpressionMap) + .referredExpr(expressionReferences) .baseSchema(namedStruct); - System.out.println( - "JsonFormat.printer().print(getFilterExtendedExpression): " - + JsonFormat.printer() - .print(new ExtendedExpressionProtoConverter().toProto(builder.build()))); + // convert POJO extended expression into PROTOBUF extended expression + io.substrait.proto.ExtendedExpression proto = + new ExtendedExpressionProtoConverter().toProto(extendedExpression.build()); + + assertEquals( + "/functions_arithmetic_decimal.yaml", proto.getExtensionUrisList().get(0).getUri()); + assertEquals("add:dec_dec", proto.getExtensionsList().get(0).getExtensionFunction().getName()); } } diff --git a/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java b/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java index 1c1fa40b5..9ab84f274 100644 --- a/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java +++ b/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java @@ -1,6 +1,5 @@ package io.substrait.extended.expression; -import com.google.protobuf.util.JsonFormat; import io.substrait.TestBase; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; @@ -8,87 +7,74 @@ import io.substrait.expression.ImmutableFieldReference; import io.substrait.proto.ExtendedExpression; import io.substrait.type.ImmutableNamedStruct; -import io.substrait.type.NamedStruct; import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.io.IOException; -import java.util.*; - +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; public class ProtoExtendedExpressionConverterTest extends TestBase { @Test public void fromTest() throws IOException { - Optional equal = + // create predefined POJO extended expression + Optional scalarFunctionExpression = defaultExtensionCollection.scalarFunctions().stream() - .filter( - s -> { - System.out.println(":>>>>"); - System.out.println(s); - System.out.println(s.uri()); - System.out.println(s.returnType()); - System.out.println(s.description()); - System.out.println("s.name(): " + s.name()); - System.out.println(s.key()); - return s.name().equalsIgnoreCase("add"); - }) + .filter(s -> s.name().equalsIgnoreCase("add")) .findFirst() .map( - declaration -> { - System.out.println("declaration: " + declaration); - System.out.println("declaration.name(): " + declaration.name()); - return ExpressionCreator.scalarFunction( - declaration, - TypeCreator.REQUIRED.BOOLEAN, - ImmutableFieldReference.builder() - .addSegments(FieldReference.StructField.of(0)) - .type(TypeCreator.REQUIRED.I32) - .build(), - ExpressionCreator.i32(false, 183)); - }); + declaration -> + ExpressionCreator.scalarFunction( + declaration, + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.decimal(10, 2)) + .build(), + ExpressionCreator.i32(false, 183))); - Map indexToExpressionMap = new HashMap<>(); - indexToExpressionMap.put(0, equal.get()); - List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); - /* - List dataTypes = - Arrays.asList( - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING, - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING); - NamedStruct namedStruct = - NamedStruct.of( - columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); - */ - ImmutableNamedStruct id = ImmutableNamedStruct.builder() - .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") - .struct(Type.Struct.builder().nullable(false).addFields(TypeCreator.NULLABLE.I32, - TypeCreator.REQUIRED.STRING, - TypeCreator.REQUIRED.I32, - TypeCreator.REQUIRED.STRING).build()) - .build(); + ImmutableExpressionReference expressionReference = + ImmutableExpressionReference.builder() + .referredExpr(scalarFunctionExpression.get()) + .addOutputNames("new-column") + .build(); - ImmutableExtendedExpression.Builder builder = - ImmutableExtendedExpression.builder() - .putAllReferredExpr(indexToExpressionMap) - .baseSchema(id); + List + expressionReferences = new ArrayList<>(); + expressionReferences.add(expressionReference); - ExtendedExpression proto = new ExtendedExpressionProtoConverter().toProto(builder.build()); + ImmutableNamedStruct namedStruct = + ImmutableNamedStruct.builder() + .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") + .struct( + Type.Struct.builder() + .nullable(false) + .addFields( + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING, + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING) + .build()) + .build(); - System.out.println("=======POJO 01======="); - System.out.println("xxxx: " + builder); - System.out.println("=======PROTO 02======="); - System.out.println("yyyy: " + JsonFormat.printer().print(proto)); + // pojo initial extended expression + ImmutableExtendedExpression extendedExpressionPojoInitial = + ImmutableExtendedExpression.builder() + .referredExpr(expressionReferences) + .baseSchema(namedStruct) + .build(); - System.out.println("=======POJO 03======="); - io.substrait.extended.expression.ExtendedExpression from = - new ProtoExtendedExpressionConverter().from(proto); - System.out.println("zzzz: " + from); - System.out.println("11111111"); + // proto extended expression + ExtendedExpression extendedExpressionProto = + new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoInitial); + // pojo final extended expression + io.substrait.extended.expression.ExtendedExpression extendedExpressionPojoFinal = + new ProtoExtendedExpressionConverter().from(extendedExpressionProto); - Assertions.assertEquals(from, builder.build()); + // validate extended expression pojo initial equals to final roundtrip + Assertions.assertEquals(extendedExpressionPojoInitial, extendedExpressionPojoFinal); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java index 716fdb66e..67ae9a8ea 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java @@ -89,7 +89,24 @@ protected SqlConverterBase(FeatureBoard features) { EXTENSION_COLLECTION = defaults; } - Result registerCreateTables(List tables) throws SqlParseException { + Pair registerCreateTables(List tables) + throws SqlParseException { + CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); + CalciteCatalogReader catalogReader = + new CalciteCatalogReader(rootSchema, List.of(), factory, config); + SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT); + if (tables != null) { + for (String tableDef : tables) { + List tList = parseCreateTable(factory, validator, tableDef); + for (DefinedTable t : tList) { + rootSchema.add(t.getName(), t); + } + } + } + return Pair.of(validator, catalogReader); + } + + Result registerCreateTablesForExtendedExpression(List tables) throws SqlParseException { Map nameToTypeMap = new LinkedHashMap<>(); Map nameToNodeMap = new HashMap<>(); CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index 2b35dffef..3052898e1 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -1,25 +1,17 @@ package io.substrait.isthmus; -import com.github.bsideup.jabel.Desugar; import com.google.common.annotations.VisibleForTesting; +import io.substrait.extended.expression.ExtendedExpressionProtoConverter; +import io.substrait.extended.expression.ImmutableExpressionReference; +import io.substrait.extended.expression.ImmutableExtendedExpression; import io.substrait.extension.ExtensionCollector; -import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.expression.RexExpressionConverter; import io.substrait.isthmus.expression.ScalarFunctionConverter; -import io.substrait.proto.Expression; -import io.substrait.proto.Expression.ScalarFunction; -import io.substrait.proto.ExpressionReference; import io.substrait.proto.ExtendedExpression; -import io.substrait.proto.FunctionArgument; import io.substrait.proto.Plan; import io.substrait.proto.PlanRel; -import io.substrait.proto.SimpleExtensionDeclaration; -import io.substrait.proto.SimpleExtensionURI; import io.substrait.relation.RelProtoConverter; import io.substrait.type.NamedStruct; -import io.substrait.type.TypeCreator; -import io.substrait.type.proto.TypeProtoConverter; -import java.io.IOException; import java.util.*; import java.util.function.Function; import org.apache.calcite.plan.hep.HepPlanner; @@ -28,9 +20,6 @@ import org.apache.calcite.rel.RelRoot; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; -import org.apache.calcite.rex.RexCall; -import org.apache.calcite.rex.RexInputRef; -import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.schema.Schema; import org.apache.calcite.sql.SqlNode; @@ -64,8 +53,8 @@ public Plan execute(String sql, Function, NamedStruct> tableLookup) } public Plan execute(String sql, List tables) throws SqlParseException { - var result = registerCreateTables(tables); - return executeInner(sql, factory, result.validator(), result.catalogReader()); + var pair = registerCreateTables(tables); + return executeInner(sql, factory, pair.left, pair.right); } public Plan execute(String sql, String name, Schema schema) throws SqlParseException { @@ -84,7 +73,7 @@ public Plan execute(String sql, String name, Schema schema) throws SqlParseExcep */ public ExtendedExpression executeSQLExpression(String sqlExpression, List tables) throws SqlParseException { - var result = registerCreateTables(tables); + var result = registerCreateTablesForExtendedExpression(tables); return executeInnerSQLExpression( sqlExpression, result.validator(), @@ -95,8 +84,8 @@ public ExtendedExpression executeSQLExpression(String sqlExpression, List sqlToRelNode(String sql, List tables) throws SqlParseException { - var result = registerCreateTables(tables); - return sqlToRelNode(sql, result.validator(), result.catalogReader()); + var pair = registerCreateTables(tables); + return sqlToRelNode(sql, pair.left, pair.right); } // Package protected for testing @@ -143,249 +132,27 @@ private ExtendedExpression executeInnerSQLExpression( Map nameToTypeMap, Map nameToNodeMap) throws SqlParseException { - ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); - ExtensionCollector functionCollector = new ExtensionCollector(); - RexNode rexNode = - sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap); - ResulTraverseRowExpression result = new TraverseRexNode().getRowExpression(rexNode); - io.substrait.proto.Type output = - TypeCreator.NULLABLE.BOOLEAN.accept(new TypeProtoConverter(functionCollector)); - List functionArgumentList = new ArrayList<>(); - result - .expressionBuilderMap() - .forEach( - (k, v) -> { - System.out.println("k->" + k); - System.out.println("v->" + v); - functionArgumentList.add(FunctionArgument.newBuilder().setValue(v).build()); - }); - - ScalarFunction.Builder scalarFunctionBuilder = - ScalarFunction.newBuilder() - .setFunctionReference(1) // rel_01 - .setOutputType(output) - .addAllArguments(functionArgumentList); - - Expression.Builder expressionBuilder = - Expression.newBuilder().setScalarFunction(scalarFunctionBuilder); - - ExpressionReference.Builder expressionReferenceBuilder = - ExpressionReference.newBuilder() - .setExpression(expressionBuilder) - .addOutputNames(result.ref().getName()); - - extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); - - io.substrait.expression.Expression.ScalarFunctionInvocation func = - (io.substrait.expression.Expression.ScalarFunctionInvocation) - rexNode.accept(rexExpressionConverter); - String declaration = func.declaration().key(); // values example: gt:any_any, add:i64_i64 - - // this is not mandatory to be defined; it is working without this definition. It is - // only created here to create a proto message that has the correct semantics - HashMap extensionUris = new HashMap<>(); - SimpleExtensionURI simpleExtensionURI; - try { - simpleExtensionURI = - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(1) // rel_02 - .setUri( - SimpleExtension.loadDefaults().scalarFunctions().stream() - .filter(s -> s.toString().equalsIgnoreCase(declaration)) - .findFirst() - .orElseThrow( - () -> - new IllegalArgumentException( - String.format("Failed to get URI resource for %s.", declaration))) - .uri()) - .build(); - } catch (IOException e) { - throw new RuntimeException(e); - } - extensionUris.put("uri", simpleExtensionURI); - - ArrayList extensions = new ArrayList<>(); - SimpleExtensionDeclaration extensionFunctionLowerThan = - SimpleExtensionDeclaration.newBuilder() - .setExtensionFunction( - SimpleExtensionDeclaration.ExtensionFunction.newBuilder() - .setFunctionAnchor(scalarFunctionBuilder.getFunctionReference()) // rel_01 - .setName(declaration) - .setExtensionUriReference(simpleExtensionURI.getExtensionUriAnchor())) // rel_02 - .build(); - extensions.add(extensionFunctionLowerThan); - - System.out.println( - "extendedExpressionBuilder.getExtensionUrisList(): " - + extendedExpressionBuilder.getExtensionUrisList()); - // adding it for semantic purposes, it is not mandatory or needed - extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); - extendedExpressionBuilder.addAllExtensions(extensions); - - NamedStruct namedStruct = TypeConverter.DEFAULT.toNamedStruct(nameToTypeMap); - extendedExpressionBuilder.setBaseSchema( - namedStruct.toProto(new TypeProtoConverter(functionCollector))); - - /* - builder.addAllExtensionUris(uris.values()); - builder.addAllExtensions(extensionList); - */ - - return extendedExpressionBuilder.build(); - } - - private ExtendedExpression executeInnerSQLExpressionPojo( - String sqlExpression, - SqlValidator validator, - CalciteCatalogReader catalogReader, - Map nameToTypeMap, - Map nameToNodeMap) - throws SqlParseException { - ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); - ExtensionCollector functionCollector = new ExtensionCollector(); RexNode rexNode = sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap); - ResulTraverseRowExpression result = new TraverseRexNode().getRowExpression(rexNode); - io.substrait.proto.Type output = - TypeCreator.NULLABLE.BOOLEAN.accept(new TypeProtoConverter(functionCollector)); - List functionArgumentList = new ArrayList<>(); - result - .expressionBuilderMap() - .forEach( - (k, v) -> { - System.out.println("k->" + k); - System.out.println("v->" + v); - functionArgumentList.add(FunctionArgument.newBuilder().setValue(v).build()); - }); - - ScalarFunction.Builder scalarFunctionBuilder = - ScalarFunction.newBuilder() - .setFunctionReference(1) // rel_01 - .setOutputType(output) - .addAllArguments(functionArgumentList); - - Expression.Builder expressionBuilder = - Expression.newBuilder().setScalarFunction(scalarFunctionBuilder); - - ExpressionReference.Builder expressionReferenceBuilder = - ExpressionReference.newBuilder() - .setExpression(expressionBuilder) - .addOutputNames(result.ref().getName()); - - extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); - io.substrait.expression.Expression.ScalarFunctionInvocation func = (io.substrait.expression.Expression.ScalarFunctionInvocation) rexNode.accept(rexExpressionConverter); - String declaration = func.declaration().key(); // values example: gt:any_any, add:i64_i64 - - // this is not mandatory to be defined; it is working without this definition. It is - // only created here to create a proto message that has the correct semantics - HashMap extensionUris = new HashMap<>(); - SimpleExtensionURI simpleExtensionURI; - try { - simpleExtensionURI = - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(1) // rel_02 - .setUri( - SimpleExtension.loadDefaults().scalarFunctions().stream() - .filter(s -> s.toString().equalsIgnoreCase(declaration)) - .findFirst() - .orElseThrow( - () -> - new IllegalArgumentException( - String.format("Failed to get URI resource for %s.", declaration))) - .uri()) - .build(); - } catch (IOException e) { - throw new RuntimeException(e); - } - extensionUris.put("uri", simpleExtensionURI); - - ArrayList extensions = new ArrayList<>(); - SimpleExtensionDeclaration extensionFunctionLowerThan = - SimpleExtensionDeclaration.newBuilder() - .setExtensionFunction( - SimpleExtensionDeclaration.ExtensionFunction.newBuilder() - .setFunctionAnchor(scalarFunctionBuilder.getFunctionReference()) // rel_01 - .setName(declaration) - .setExtensionUriReference(simpleExtensionURI.getExtensionUriAnchor())) // rel_02 - .build(); - extensions.add(extensionFunctionLowerThan); - - System.out.println( - "extendedExpressionBuilder.getExtensionUrisList(): " - + extendedExpressionBuilder.getExtensionUrisList()); - // adding it for semantic purposes, it is not mandatory or needed - extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); - extendedExpressionBuilder.addAllExtensions(extensions); - NamedStruct namedStruct = TypeConverter.DEFAULT.toNamedStruct(nameToTypeMap); - extendedExpressionBuilder.setBaseSchema( - namedStruct.toProto(new TypeProtoConverter(functionCollector))); - - /* - builder.addAllExtensionUris(uris.values()); - builder.addAllExtensions(extensionList); - */ + ImmutableExpressionReference expressionReference = + ImmutableExpressionReference.builder().referredExpr(func).addOutputNames("output").build(); - return extendedExpressionBuilder.build(); - } + List + expressionReferences = new ArrayList<>(); + expressionReferences.add(expressionReference); - class TraverseRexNode { - RexInputRef ref = null; - int control = 0; - Expression.Builder referenceBuilder = null; - Expression.Builder literalBuilder = null; - Map expressionBuilderMap = new LinkedHashMap<>(); + ImmutableExtendedExpression.Builder extendedExpression = + ImmutableExtendedExpression.builder() + .referredExpr(expressionReferences) + .baseSchema(namedStruct); - ResulTraverseRowExpression getRowExpression(RexNode rexNode) { - switch (rexNode.getClass().getSimpleName().toUpperCase()) { - case "REXCALL": - for (RexNode rexInternal : ((RexCall) rexNode).operands) { - getRowExpression(rexInternal); - } - ; - break; - case "REXINPUTREF": - ref = (RexInputRef) rexNode; - referenceBuilder = - Expression.newBuilder() - .setSelection( - Expression.FieldReference.newBuilder() - .setDirectReference( - Expression.ReferenceSegment.newBuilder() - .setStructField( - Expression.ReferenceSegment.StructField.newBuilder() - .setField(ref.getIndex())))); - expressionBuilderMap.put(control, referenceBuilder); - control++; - break; - case "REXLITERAL": - RexLiteral literal = (RexLiteral) rexNode; - literalBuilder = - Expression.newBuilder() - .setLiteral( - Expression.Literal.newBuilder().setI32(literal.getValueAs(Integer.class))); - expressionBuilderMap.put(control, literalBuilder); - control++; - break; - default: - throw new AssertionError( - "Unsupported type for: " + rexNode.getClass().getSimpleName().toUpperCase()); - } - return new ResulTraverseRowExpression( - ref, referenceBuilder, literalBuilder, expressionBuilderMap); - } + return new ExtendedExpressionProtoConverter().toProto(extendedExpression.build()); } - @Desugar - private record ResulTraverseRowExpression( - RexInputRef ref, - Expression.Builder referenceBuilder, - Expression.Builder literalBuilder, - Map expressionBuilderMap) {} - private List sqlToRelNode( String sql, SqlValidator validator, CalciteCatalogReader catalogReader) throws SqlParseException { diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java index d43fda1c1..f402aacd3 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java @@ -22,9 +22,8 @@ public SubstraitToSql() { public RelNode substraitRelToCalciteRel(Rel relRoot, List tables) throws SqlParseException { - var result = registerCreateTables(tables); - return SubstraitRelNodeConverter.convert( - relRoot, relOptCluster, result.catalogReader(), parserConfig); + var pair = registerCreateTables(tables); + return SubstraitRelNodeConverter.convert(relRoot, relOptCluster, pair.right, parserConfig); } public RelNode substraitRelToCalciteRel( diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java index 1c3742b52..d29ad3c07 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java @@ -2,8 +2,6 @@ import com.google.common.base.Charsets; import com.google.common.io.Resources; -import com.google.protobuf.InvalidProtocolBufferException; -import com.google.protobuf.util.JsonFormat; import io.substrait.extended.expression.ExtendedExpressionProtoConverter; import io.substrait.extended.expression.ProtoExtendedExpressionConverter; import io.substrait.proto.ExtendedExpression; @@ -37,24 +35,20 @@ protected ExtendedExpression assertProtoExtendedExpressionRoundrip(String query, protected ExtendedExpression assertProtoExtendedExpressionRoundrip( String query, SqlToSubstrait s, List creates) throws SqlParseException, IOException { - io.substrait.proto.ExtendedExpression protoExtendedExpression = - s.executeSQLExpression(query, creates); + // proto initial extended expression + ExtendedExpression extendedExpressionProtoInitial = s.executeSQLExpression(query, creates); - try { - String ee = JsonFormat.printer().print(protoExtendedExpression); - System.out.println("Proto Extended Expression: \n" + ee); + // pojo final extended expression + io.substrait.extended.expression.ExtendedExpression extendedExpressionPojoFinal = + new ProtoExtendedExpressionConverter().from(extendedExpressionProtoInitial); - io.substrait.extended.expression.ExtendedExpression from = - new ProtoExtendedExpressionConverter().from(protoExtendedExpression); + // proto final extended expression + ExtendedExpression extendedExpressionProtoFinal = + new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoFinal); - ExtendedExpression proto = new ExtendedExpressionProtoConverter().toProto(from); + // round-trip to validate extended expression proto initial equals to final + Assertions.assertEquals(extendedExpressionProtoFinal, extendedExpressionProtoInitial); - Assertions.assertEquals(proto, protoExtendedExpression); - // FIXME! Implement test validation as the same as proto Plan implementation - } catch (InvalidProtocolBufferException e) { - throw new RuntimeException(e); - } - - return protoExtendedExpression; + return extendedExpressionProtoInitial; } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java index 5a407a936..d08f588d6 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java @@ -10,7 +10,7 @@ public class SimpleExtendedExpressionsTest extends ExtendedExpressionTestBase { @Test public void filter() throws IOException, SqlParseException { ExtendedExpression extendedExpression = - assertProtoExtendedExpressionRoundrip("N_NATIONKEY > 10"); + assertProtoExtendedExpressionRoundrip("L_ORDERKEY > 10"); } @Test diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java index 0121fae9c..4badb4a5e 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java @@ -2,7 +2,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals; -import com.google.protobuf.util.JsonFormat; import com.ibm.icu.impl.ClassLoaderUtil; import io.substrait.isthmus.ExtendedExpressionTestBase; import io.substrait.isthmus.SqlToSubstrait; @@ -21,7 +20,7 @@ import org.apache.arrow.dataset.source.DatasetFactory; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.calcite.sql.parser.SqlParseException; import org.junit.jupiter.api.Test; @@ -31,7 +30,9 @@ public class ExtendedExpressionIntegrationTest { @Test public void filterDataset() throws SqlParseException, IOException { URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); - String sqlExpression = "N_REGIONKEY > 20"; + // Make sure you pass appropriate data, for example, if you pass N_NATIONKEY > 20 the engine + // creates an i64 but casts it to i32 = 20, causing casting problems. + String sqlExpression = "N_NATIONKEY > 9223372036854771827 - 9223372036854771807"; ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) .columns(Optional.empty()) @@ -50,11 +51,9 @@ public void filterDataset() throws SqlParseException, IOException { int count = 0; while (reader.loadNextBatch()) { count += reader.getVectorSchemaRoot().getRowCount(); - System.out.println(reader.getVectorSchemaRoot().contentToTSVString()); } assertEquals(4, count); } catch (Exception e) { - e.printStackTrace(); throw new RuntimeException(e); } } @@ -62,7 +61,9 @@ public void filterDataset() throws SqlParseException, IOException { @Test public void projectDataset() throws SqlParseException, IOException { URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); - String sqlExpression = "20 + N_NATIONKEY"; + // Make sure you pass appropriate data, for example, if you pass N_NATIONKEY + 20 the engine + // creates an i64 but casts it to i32 = 20, causing casting problems. + String sqlExpression = "N_NATIONKEY + 9888486986"; ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) .columns(Optional.empty()) @@ -79,62 +80,27 @@ public void projectDataset() throws SqlParseException, IOException { Scanner scanner = dataset.newScan(options); ArrowReader reader = scanner.scanBatches()) { int count = 0; - int sum = 0; + Long sum = 0L; while (reader.loadNextBatch()) { count += reader.getVectorSchemaRoot().getRowCount(); - IntVector intVector = (IntVector) reader.getVectorSchemaRoot().getVector(0); - for (int i = 0; i < intVector.getValueCount(); i++) { - sum += intVector.get(i); + BigIntVector bigIntVector = (BigIntVector) reader.getVectorSchemaRoot().getVector(0); + for (int i = 0; i < bigIntVector.getValueCount(); i++) { + sum += bigIntVector.get(i); } - System.out.println(reader.getVectorSchemaRoot().contentToTSVString()); } assertEquals(25, count); - assertEquals(24 * 25 / 2 + 20 * count, sum); + assertEquals(24 * 25 / 2 + 9888486986L * count, sum); } catch (Exception e) { throw new RuntimeException(e); } } - @Test - public void filterDatasetUsingExtendedExpression() throws SqlParseException, IOException { - URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); - String sqlExpression = "N_NATIONKEY > 20"; - ScanOptions options = - new ScanOptions.Builder(/*batchSize*/ 32768) - .columns(Optional.empty()) - .substraitFilter(getFilterExtendedExpression(sqlExpression)) - .build(); - try (BufferAllocator allocator = new RootAllocator(); - DatasetFactory datasetFactory = - new FileSystemDatasetFactory( - allocator, - NativeMemoryPool.getDefault(), - FileFormat.PARQUET, - resource.toURI().toString()); - Dataset dataset = datasetFactory.finish(); - Scanner scanner = dataset.newScan(options); - ArrowReader reader = scanner.scanBatches()) { - int count = 0; - while (reader.loadNextBatch()) { - count += reader.getVectorSchemaRoot().getRowCount(); - System.out.println(reader.getVectorSchemaRoot().contentToTSVString()); - } - assertEquals(4, count); - } catch (Exception e) { - e.printStackTrace(); - throw new RuntimeException(e); - } - } - private static ByteBuffer getFilterExtendedExpression(String sqlExpression) throws IOException, SqlParseException { ExtendedExpression extendedExpression = new SqlToSubstrait() .executeSQLExpression( sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); - System.out.println( - "JsonFormat.printer().print(getFilterExtendedExpression): " - + JsonFormat.printer().print(extendedExpression)); byte[] extendedExpressions = Base64.getDecoder() .decode(Base64.getEncoder().encodeToString(extendedExpression.toByteArray())); @@ -149,9 +115,6 @@ private static ByteBuffer getProjectExtendedExpression(String sqlExpression) new SqlToSubstrait() .executeSQLExpression( sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); - System.out.println( - "JsonFormat.printer().print(getProjectExtendedExpression): " - + JsonFormat.printer().print(extendedExpression)); byte[] extendedExpressions = Base64.getDecoder() .decode(Base64.getEncoder().encodeToString(extendedExpression.toByteArray())); diff --git a/isthmus/src/test/resources/tpch/schema.sql b/isthmus/src/test/resources/tpch/schema.sql index b8fb4cfd0..81f6f927b 100644 --- a/isthmus/src/test/resources/tpch/schema.sql +++ b/isthmus/src/test/resources/tpch/schema.sql @@ -1,6 +1,77 @@ +CREATE TABLE PART ( + P_PARTKEY BIGINT NOT NULL, + P_NAME VARCHAR(55), + P_MFGR CHAR(25), + P_BRAND CHAR(10), + P_TYPE VARCHAR(25), + P_SIZE INTEGER, + P_CONTAINER CHAR(10), + P_RETAILPRICE DECIMAL, + P_COMMENT VARCHAR(23) +); +CREATE TABLE SUPPLIER ( + S_SUPPKEY BIGINT NOT NULL, + S_NAME CHAR(25), + S_ADDRESS VARCHAR(40), + S_NATIONKEY BIGINT NOT NULL, + S_PHONE CHAR(15), + S_ACCTBAL DECIMAL, + S_COMMENT VARCHAR(101) +); +CREATE TABLE PARTSUPP ( + PS_PARTKEY BIGINT NOT NULL, + PS_SUPPKEY BIGINT NOT NULL, + PS_AVAILQTY INTEGER, + PS_SUPPLYCOST DECIMAL, + PS_COMMENT VARCHAR(199) +); +CREATE TABLE CUSTOMER ( + C_CUSTKEY BIGINT NOT NULL, + C_NAME VARCHAR(25), + C_ADDRESS VARCHAR(40), + C_NATIONKEY BIGINT NOT NULL, + C_PHONE CHAR(15), + C_ACCTBAL DECIMAL, + C_MKTSEGMENT CHAR(10), + C_COMMENT VARCHAR(117) +); +CREATE TABLE ORDERS ( + O_ORDERKEY BIGINT NOT NULL, + O_CUSTKEY BIGINT NOT NULL, + O_ORDERSTATUS CHAR(1), + O_TOTALPRICE DECIMAL, + O_ORDERDATE DATE, + O_ORDERPRIORITY CHAR(15), + O_CLERK CHAR(15), + O_SHIPPRIORITY INTEGER, + O_COMMENT VARCHAR(79) +); +CREATE TABLE LINEITEM ( + L_ORDERKEY BIGINT NOT NULL, + L_PARTKEY BIGINT NOT NULL, + L_SUPPKEY BIGINT NOT NULL, + L_LINENUMBER INTEGER, + L_QUANTITY DECIMAL, + L_EXTENDEDPRICE DECIMAL, + L_DISCOUNT DECIMAL, + L_TAX DECIMAL, + L_RETURNFLAG CHAR(1), + L_LINESTATUS CHAR(1), + L_SHIPDATE DATE, + L_COMMITDATE DATE, + L_RECEIPTDATE DATE, + L_SHIPINSTRUCT CHAR(25), + L_SHIPMODE CHAR(10), + L_COMMENT VARCHAR(44) +); CREATE TABLE NATION ( N_NATIONKEY BIGINT NOT NULL, N_NAME CHAR(25), N_REGIONKEY BIGINT NOT NULL, N_COMMENT VARCHAR(152) ); +CREATE TABLE REGION ( + R_REGIONKEY BIGINT NOT NULL, + R_NAME CHAR(25), + R_COMMENT VARCHAR(152) +); From 940f70399fa4c4c22c581f855004a1de44463c85 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Fri, 24 Nov 2023 19:16:58 -0500 Subject: [PATCH 13/35] fix: clean code redundant method --- .../ExtendedExpressionIntegrationTest.java | 26 +++++-------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java index 4badb4a5e..297517ec8 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java @@ -36,7 +36,7 @@ public void filterDataset() throws SqlParseException, IOException { ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) .columns(Optional.empty()) - .substraitFilter(getFilterExtendedExpression(sqlExpression)) + .substraitFilter(getExtendedExpression(sqlExpression)) .build(); try (BufferAllocator allocator = new RootAllocator(); DatasetFactory datasetFactory = @@ -67,7 +67,7 @@ public void projectDataset() throws SqlParseException, IOException { ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) .columns(Optional.empty()) - .substraitProjection(getProjectExtendedExpression(sqlExpression)) + .substraitProjection(getExtendedExpression(sqlExpression)) .build(); try (BufferAllocator allocator = new RootAllocator(); DatasetFactory datasetFactory = @@ -95,7 +95,7 @@ public void projectDataset() throws SqlParseException, IOException { } } - private static ByteBuffer getFilterExtendedExpression(String sqlExpression) + private static ByteBuffer getExtendedExpression(String sqlExpression) throws IOException, SqlParseException { ExtendedExpression extendedExpression = new SqlToSubstrait() @@ -104,22 +104,8 @@ private static ByteBuffer getFilterExtendedExpression(String sqlExpression) byte[] extendedExpressions = Base64.getDecoder() .decode(Base64.getEncoder().encodeToString(extendedExpression.toByteArray())); - ByteBuffer substraitExpressionFilter = ByteBuffer.allocateDirect(extendedExpressions.length); - substraitExpressionFilter.put(extendedExpressions); - return substraitExpressionFilter; - } - - private static ByteBuffer getProjectExtendedExpression(String sqlExpression) - throws IOException, SqlParseException { - ExtendedExpression extendedExpression = - new SqlToSubstrait() - .executeSQLExpression( - sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); - byte[] extendedExpressions = - Base64.getDecoder() - .decode(Base64.getEncoder().encodeToString(extendedExpression.toByteArray())); - ByteBuffer substraitExpressionProject = ByteBuffer.allocateDirect(extendedExpressions.length); - substraitExpressionProject.put(extendedExpressions); - return substraitExpressionProject; + ByteBuffer substraitExpression = ByteBuffer.allocateDirect(extendedExpressions.length); + substraitExpression.put(extendedExpressions); + return substraitExpression; } } From f817eb0cc0012bb4357cb619fd496c64f2e88ed7 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Wed, 29 Nov 2023 10:29:04 -0500 Subject: [PATCH 14/35] fix: apply suggestions from code review Co-authored-by: Dane Pitkin <48041712+danepitkin@users.noreply.github.com> --- .../io/substrait/extended/expression/ExtendedExpression.java | 4 ++-- .../extended/expression/ExtendedExpressionProtoConverter.java | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java b/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java index 4f705e82c..2aee599c2 100644 --- a/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java +++ b/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java @@ -9,7 +9,7 @@ @Value.Immutable public abstract class ExtendedExpression { - public abstract List getReferredExpr(); + public abstract List getReferredExpressions(); public abstract NamedStruct getBaseSchema(); @@ -21,7 +21,7 @@ public abstract class ExtendedExpression { @Value.Immutable public abstract static class ExpressionReference { - public abstract Expression getReferredExpr(); + public abstract Expression getExpression(); public abstract List getOutputNames(); } diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java index 6ace03df4..f3d8441ae 100644 --- a/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java @@ -15,7 +15,7 @@ public class ExtendedExpressionProtoConverter { public ExtendedExpression toProto( io.substrait.extended.expression.ExtendedExpression extendedExpression) { - ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); + ExtendedExpression.Builder builder = ExtendedExpression.newBuilder(); ExtensionCollector functionCollector = new ExtensionCollector(); final ExpressionProtoConverter expressionProtoConverter = From b1c96bd458e8096c7d8663ab2a0710199f080d7a Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Wed, 29 Nov 2023 11:29:03 -0500 Subject: [PATCH 15/35] fix: code review core module --- .../ExtendedExpression.java | 2 +- .../ExtendedExpressionProtoConverter.java | 26 +++++++-------- .../ProtoExtendedExpressionConverter.java | 32 +++++-------------- .../extension/ImmutableExtensionLookup.java | 13 +++++++- .../io/substrait/plan/ProtoPlanConverter.java | 5 +-- .../java/io/substrait/type/NamedStruct.java | 17 ++++++++++ .../ExtendedExpressionProtoConverterTest.java | 6 ++-- .../ProtoExtendedExpressionConverterTest.java | 10 +++--- 8 files changed, 58 insertions(+), 53 deletions(-) rename core/src/main/java/io/substrait/{extended/expression => extendedexpression}/ExtendedExpression.java (94%) rename core/src/main/java/io/substrait/{extended/expression => extendedexpression}/ExtendedExpressionProtoConverter.java (66%) rename core/src/main/java/io/substrait/{extended/expression => extendedexpression}/ProtoExtendedExpressionConverter.java (69%) rename core/src/test/java/io/substrait/{extended/expression => extendedexpression}/ExtendedExpressionProtoConverterTest.java (94%) rename core/src/test/java/io/substrait/{extended/expression => extendedexpression}/ProtoExtendedExpressionConverterTest.java (90%) diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java similarity index 94% rename from core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java rename to core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java index 2aee599c2..de405f9a3 100644 --- a/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java @@ -1,4 +1,4 @@ -package io.substrait.extended.expression; +package io.substrait.extendedexpression; import io.substrait.expression.Expression; import io.substrait.proto.AdvancedExtension; diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java similarity index 66% rename from core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java rename to core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java index f3d8441ae..a123e4b9f 100644 --- a/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java @@ -1,4 +1,4 @@ -package io.substrait.extended.expression; +package io.substrait.extendedexpression; import io.substrait.expression.Expression; import io.substrait.expression.proto.ExpressionProtoConverter; @@ -7,13 +7,10 @@ import io.substrait.proto.ExtendedExpression; import io.substrait.type.proto.TypeProtoConverter; -/** - * Converts from {@link io.substrait.extended.expression.ExtendedExpression} to {@link - * ExtendedExpression} - */ +/** Converts from {@link ExtendedExpression} to {@link ExtendedExpression} */ public class ExtendedExpressionProtoConverter { public ExtendedExpression toProto( - io.substrait.extended.expression.ExtendedExpression extendedExpression) { + io.substrait.extendedexpression.ExtendedExpression extendedExpression) { ExtendedExpression.Builder builder = ExtendedExpression.newBuilder(); ExtensionCollector functionCollector = new ExtensionCollector(); @@ -21,30 +18,29 @@ public ExtendedExpression toProto( final ExpressionProtoConverter expressionProtoConverter = new ExpressionProtoConverter(functionCollector, null); - for (io.substrait.extended.expression.ExtendedExpression.ExpressionReference - expressionReference : extendedExpression.getReferredExpr()) { + for (io.substrait.extendedexpression.ExtendedExpression.ExpressionReference + expressionReference : extendedExpression.getReferredExpressions()) { io.substrait.proto.Expression expressionProto = expressionProtoConverter.visit( - (Expression.ScalarFunctionInvocation) expressionReference.getReferredExpr()); + (Expression.ScalarFunctionInvocation) expressionReference.getExpression()); ExpressionReference.Builder expressionReferenceBuilder = ExpressionReference.newBuilder() .setExpression(expressionProto) .addAllOutputNames(expressionReference.getOutputNames()); - extendedExpressionBuilder.addReferredExpr(expressionReferenceBuilder); + builder.addReferredExpr(expressionReferenceBuilder); } - extendedExpressionBuilder.setBaseSchema( + builder.setBaseSchema( extendedExpression.getBaseSchema().toProto(new TypeProtoConverter(functionCollector))); // the process of adding simple extensions, such as extensionURIs and extensions, is handled on // the fly - functionCollector.addExtensionsToExtendedExpression(extendedExpressionBuilder); + functionCollector.addExtensionsToExtendedExpression(builder); if (extendedExpression.getAdvancedExtension().isPresent()) { - extendedExpressionBuilder.setAdvancedExtensions( - extendedExpression.getAdvancedExtension().get()); + builder.setAdvancedExtensions(extendedExpression.getAdvancedExtension().get()); } - return extendedExpressionBuilder.build(); + return builder.build(); } } diff --git a/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java similarity index 69% rename from core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java rename to core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java index 14c82b5ac..3af6ee20d 100644 --- a/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java @@ -1,12 +1,10 @@ -package io.substrait.extended.expression; +package io.substrait.extendedexpression; import io.substrait.expression.Expression; import io.substrait.expression.proto.ProtoExpressionConverter; import io.substrait.extension.*; import io.substrait.proto.ExpressionReference; import io.substrait.proto.NamedStruct; -import io.substrait.type.ImmutableNamedStruct; -import io.substrait.type.Type; import io.substrait.type.proto.ProtoTypeConverter; import java.io.IOException; import java.util.ArrayList; @@ -33,12 +31,13 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp // fill in simple extension information through a discovery in the current proto-extended // expression ExtensionLookup functionLookup = - ImmutableExtensionLookup.builder() - .from(extendedExpression.getExtensionUrisList(), extendedExpression.getExtensionsList()) - .build(); + ImmutableExtensionLookup.builder().from(extendedExpression).build(); NamedStruct baseSchemaProto = extendedExpression.getBaseSchema(); - io.substrait.type.NamedStruct namedStruct = convertNamedStrutProtoToPojo(baseSchemaProto); + + io.substrait.type.NamedStruct namedStruct = + io.substrait.type.NamedStruct.convertNamedStructProtoToPojo( + baseSchemaProto, protoTypeConverter); ProtoExpressionConverter protoExpressionConverter = new ProtoExpressionConverter( @@ -50,14 +49,14 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp protoExpressionConverter.from(expressionReference.getExpression()); expressionReferences.add( ImmutableExpressionReference.builder() - .referredExpr(expressionPojo) + .expression(expressionPojo) .addAllOutputNames(expressionReference.getOutputNamesList()) .build()); } ImmutableExtendedExpression.Builder builder = ImmutableExtendedExpression.builder() - .referredExpr(expressionReferences) + .referredExpressions(expressionReferences) .advancedExtension( Optional.ofNullable( extendedExpression.hasAdvancedExtensions() @@ -66,19 +65,4 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp .baseSchema(namedStruct); return builder.build(); } - - private io.substrait.type.NamedStruct convertNamedStrutProtoToPojo(NamedStruct namedStruct) { - var struct = namedStruct.getStruct(); - return ImmutableNamedStruct.builder() - .names(namedStruct.getNamesList()) - .struct( - Type.Struct.builder() - .fields( - struct.getTypesList().stream() - .map(protoTypeConverter::from) - .collect(java.util.stream.Collectors.toList())) - .nullable(ProtoTypeConverter.isNullable(struct.getNullability())) - .build()) - .build(); - } } diff --git a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java index c88bafc1c..70034d9b1 100644 --- a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java @@ -1,5 +1,7 @@ package io.substrait.extension; +import io.substrait.proto.ExtendedExpression; +import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; import io.substrait.proto.SimpleExtensionURI; import java.util.Collections; @@ -31,7 +33,16 @@ public static class Builder { private final Map functionMap = new HashMap<>(); private final Map typeMap = new HashMap<>(); - public Builder from( + public Builder from(Plan plan) { + return from(plan.getExtensionUrisList(), plan.getExtensionsList()); + } + + public Builder from(ExtendedExpression extendedExpression) { + return from( + extendedExpression.getExtensionUrisList(), extendedExpression.getExtensionsList()); + } + + private Builder from( List simpleExtensionURIs, List simpleExtensionDeclarations) { Map namespaceMap = new HashMap<>(); diff --git a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java index 7222eb7ed..be4f4ad9f 100644 --- a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java +++ b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java @@ -32,10 +32,7 @@ protected ProtoRelConverter getProtoRelConverter(ExtensionLookup functionLookup) } public Plan from(io.substrait.proto.Plan plan) { - ExtensionLookup functionLookup = - ImmutableExtensionLookup.builder() - .from(plan.getExtensionUrisList(), plan.getExtensionsList()) - .build(); + ExtensionLookup functionLookup = ImmutableExtensionLookup.builder().from(plan).build(); ProtoRelConverter relConverter = getProtoRelConverter(functionLookup); List roots = new ArrayList<>(); for (PlanRel planRel : plan.getRelationsList()) { diff --git a/core/src/main/java/io/substrait/type/NamedStruct.java b/core/src/main/java/io/substrait/type/NamedStruct.java index 8bf345aa9..11fdd38ad 100644 --- a/core/src/main/java/io/substrait/type/NamedStruct.java +++ b/core/src/main/java/io/substrait/type/NamedStruct.java @@ -1,5 +1,6 @@ package io.substrait.type; +import io.substrait.type.proto.ProtoTypeConverter; import io.substrait.type.proto.TypeProtoConverter; import java.util.List; import org.immutables.value.Value; @@ -21,4 +22,20 @@ default io.substrait.proto.NamedStruct toProto(TypeProtoConverter typeProtoConve .addAllNames(names()) .build(); } + + static io.substrait.type.NamedStruct convertNamedStructProtoToPojo( + io.substrait.proto.NamedStruct namedStruct, ProtoTypeConverter protoTypeConverter) { + var struct = namedStruct.getStruct(); + return ImmutableNamedStruct.builder() + .names(namedStruct.getNamesList()) + .struct( + Type.Struct.builder() + .fields( + struct.getTypesList().stream() + .map(protoTypeConverter::from) + .collect(java.util.stream.Collectors.toList())) + .nullable(ProtoTypeConverter.isNullable(struct.getNullability())) + .build()) + .build(); + } } diff --git a/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java similarity index 94% rename from core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java rename to core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java index 20079e24f..fbe3526eb 100644 --- a/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java +++ b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java @@ -1,4 +1,4 @@ -package io.substrait.extended.expression; +package io.substrait.extendedexpression; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -36,7 +36,7 @@ public void toProtoTest() { ImmutableExpressionReference expressionReference = ImmutableExpressionReference.builder() - .referredExpr(scalarFunctionExpression.get()) + .expression(scalarFunctionExpression.get()) .addOutputNames("new-column") .build(); @@ -59,7 +59,7 @@ public void toProtoTest() { ImmutableExtendedExpression.Builder extendedExpression = ImmutableExtendedExpression.builder() - .referredExpr(expressionReferences) + .referredExpressions(expressionReferences) .baseSchema(namedStruct); // convert POJO extended expression into PROTOBUF extended expression diff --git a/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java b/core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java similarity index 90% rename from core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java rename to core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java index 9ab84f274..69a03a90a 100644 --- a/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java +++ b/core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java @@ -1,4 +1,4 @@ -package io.substrait.extended.expression; +package io.substrait.extendedexpression; import io.substrait.TestBase; import io.substrait.expression.Expression; @@ -37,11 +37,11 @@ public void fromTest() throws IOException { ImmutableExpressionReference expressionReference = ImmutableExpressionReference.builder() - .referredExpr(scalarFunctionExpression.get()) + .expression(scalarFunctionExpression.get()) .addOutputNames("new-column") .build(); - List + List expressionReferences = new ArrayList<>(); expressionReferences.add(expressionReference); @@ -62,7 +62,7 @@ public void fromTest() throws IOException { // pojo initial extended expression ImmutableExtendedExpression extendedExpressionPojoInitial = ImmutableExtendedExpression.builder() - .referredExpr(expressionReferences) + .referredExpressions(expressionReferences) .baseSchema(namedStruct) .build(); @@ -71,7 +71,7 @@ public void fromTest() throws IOException { new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoInitial); // pojo final extended expression - io.substrait.extended.expression.ExtendedExpression extendedExpressionPojoFinal = + io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal = new ProtoExtendedExpressionConverter().from(extendedExpressionProto); // validate extended expression pojo initial equals to final roundtrip From 3d9b92729445e80d28dfff5b1aab6b5abad07447 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Wed, 29 Nov 2023 18:45:12 -0500 Subject: [PATCH 16/35] fix: code review core module testing side --- .../ExtendedExpression.java | 15 +++++++- .../ExtendedExpressionProtoConverter.java | 34 +++++++++++------ .../ProtoExtendedExpressionConverter.java | 23 +++++++---- .../ExtendedExpressionProtoConverterTest.java | 38 ++++++++----------- .../ProtoExtendedExpressionConverterTest.java | 30 +++++++-------- 5 files changed, 83 insertions(+), 57 deletions(-) diff --git a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java index de405f9a3..26c1e3803 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java @@ -2,6 +2,7 @@ import io.substrait.expression.Expression; import io.substrait.proto.AdvancedExtension; +import io.substrait.proto.AggregateFunction; import io.substrait.type.NamedStruct; import java.util.List; import java.util.Optional; @@ -21,8 +22,20 @@ public abstract class ExtendedExpression { @Value.Immutable public abstract static class ExpressionReference { - public abstract Expression getExpression(); + public abstract ExpressionTypeReference getExpressionType(); public abstract List getOutputNames(); } + + public abstract static class ExpressionTypeReference {} + + @Value.Immutable + public abstract static class ExpressionType extends ExpressionTypeReference { + public abstract Expression getExpression(); + } + + @Value.Immutable + public abstract static class AggregateFunctionType extends ExpressionTypeReference { + public abstract AggregateFunction getMeasure(); + } } diff --git a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java index a123e4b9f..ac8d08180 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java @@ -20,17 +20,29 @@ public ExtendedExpression toProto( for (io.substrait.extendedexpression.ExtendedExpression.ExpressionReference expressionReference : extendedExpression.getReferredExpressions()) { - - io.substrait.proto.Expression expressionProto = - expressionProtoConverter.visit( - (Expression.ScalarFunctionInvocation) expressionReference.getExpression()); - - ExpressionReference.Builder expressionReferenceBuilder = - ExpressionReference.newBuilder() - .setExpression(expressionProto) - .addAllOutputNames(expressionReference.getOutputNames()); - - builder.addReferredExpr(expressionReferenceBuilder); + io.substrait.extendedexpression.ExtendedExpression.ExpressionTypeReference expressionType = + expressionReference.getExpressionType(); + if (expressionType + instanceof io.substrait.extendedexpression.ExtendedExpression.ExpressionType) { + io.substrait.proto.Expression expressionProto = + expressionProtoConverter.visit( + (Expression.ScalarFunctionInvocation) + ((io.substrait.extendedexpression.ExtendedExpression.ExpressionType) + expressionType) + .getExpression()); + ExpressionReference.Builder expressionReferenceBuilder = + ExpressionReference.newBuilder() + .setExpression(expressionProto) + .addAllOutputNames(expressionReference.getOutputNames()); + builder.addReferredExpr(expressionReferenceBuilder); + } else if (expressionType + instanceof io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionType) { + throw new UnsupportedOperationException( + "Aggregate function types are not supported in conversion to proto Extended Expressions for now"); + } else { + throw new UnsupportedOperationException( + "Only Expression or Aggregate Function type are supported in conversion to proto Extended Expressions for now"); + } } builder.setBaseSchema( extendedExpression.getBaseSchema().toProto(new TypeProtoConverter(functionCollector))); diff --git a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java index 3af6ee20d..14bbf209e 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java @@ -45,13 +45,22 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp List expressionReferences = new ArrayList<>(); for (ExpressionReference expressionReference : extendedExpression.getReferredExprList()) { - Expression expressionPojo = - protoExpressionConverter.from(expressionReference.getExpression()); - expressionReferences.add( - ImmutableExpressionReference.builder() - .expression(expressionPojo) - .addAllOutputNames(expressionReference.getOutputNamesList()) - .build()); + if (expressionReference.getExprTypeCase().getNumber() == 1) { // Expression + Expression expressionPojo = + protoExpressionConverter.from(expressionReference.getExpression()); + expressionReferences.add( + ImmutableExpressionReference.builder() + .expressionType( + ImmutableExpressionType.builder().expression(expressionPojo).build()) + .addAllOutputNames(expressionReference.getOutputNamesList()) + .build()); + } else if (expressionReference.getExprTypeCase().getNumber() == 2) { // AggregateFunction + throw new UnsupportedOperationException( + "Aggregate function types are not supported in conversion from proto Extended Expressions for now"); + } else { + throw new UnsupportedOperationException( + "Only Expression or Aggregate Function type are supported in conversion from proto Extended Expressions for now"); + } } ImmutableExtendedExpression.Builder builder = diff --git a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java index fbe3526eb..fa4cd2ac3 100644 --- a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java +++ b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java @@ -3,40 +3,35 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import io.substrait.TestBase; -import io.substrait.expression.Expression; -import io.substrait.expression.ExpressionCreator; -import io.substrait.expression.FieldReference; -import io.substrait.expression.ImmutableFieldReference; +import io.substrait.expression.*; import io.substrait.type.ImmutableNamedStruct; import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.util.ArrayList; import java.util.List; -import java.util.Optional; import org.junit.jupiter.api.Test; public class ExtendedExpressionProtoConverterTest extends TestBase { + static final String NAMESPACE = "/functions_arithmetic_decimal.yaml"; + @Test public void toProtoTest() { // create predefined POJO extended expression - Optional scalarFunctionExpression = - defaultExtensionCollection.scalarFunctions().stream() - .filter(s -> s.name().equalsIgnoreCase("add")) - .findFirst() - .map( - declaration -> - ExpressionCreator.scalarFunction( - declaration, - TypeCreator.REQUIRED.BOOLEAN, - ImmutableFieldReference.builder() - .addSegments(FieldReference.StructField.of(0)) - .type(TypeCreator.REQUIRED.decimal(10, 2)) - .build(), - ExpressionCreator.i32(false, 183))); + Expression.ScalarFunctionInvocation scalarFunctionInvocation = + b.scalarFn( + NAMESPACE, + "add:dec_dec", + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.decimal(10, 2)) + .build(), + ExpressionCreator.i32(false, 183)); ImmutableExpressionReference expressionReference = ImmutableExpressionReference.builder() - .expression(scalarFunctionExpression.get()) + .expressionType( + ImmutableExpressionType.builder().expression(scalarFunctionInvocation).build()) .addOutputNames("new-column") .build(); @@ -66,8 +61,7 @@ public void toProtoTest() { io.substrait.proto.ExtendedExpression proto = new ExtendedExpressionProtoConverter().toProto(extendedExpression.build()); - assertEquals( - "/functions_arithmetic_decimal.yaml", proto.getExtensionUrisList().get(0).getUri()); + assertEquals(NAMESPACE, proto.getExtensionUrisList().get(0).getUri()); assertEquals("add:dec_dec", proto.getExtensionsList().get(0).getExtensionFunction().getName()); } } diff --git a/core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java b/core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java index 69a03a90a..b0f34b783 100644 --- a/core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java +++ b/core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java @@ -12,32 +12,30 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; -import java.util.Optional; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; public class ProtoExtendedExpressionConverterTest extends TestBase { + static final String NAMESPACE = "/functions_arithmetic_decimal.yaml"; + @Test public void fromTest() throws IOException { // create predefined POJO extended expression - Optional scalarFunctionExpression = - defaultExtensionCollection.scalarFunctions().stream() - .filter(s -> s.name().equalsIgnoreCase("add")) - .findFirst() - .map( - declaration -> - ExpressionCreator.scalarFunction( - declaration, - TypeCreator.REQUIRED.BOOLEAN, - ImmutableFieldReference.builder() - .addSegments(FieldReference.StructField.of(0)) - .type(TypeCreator.REQUIRED.decimal(10, 2)) - .build(), - ExpressionCreator.i32(false, 183))); + Expression.ScalarFunctionInvocation scalarFunctionInvocation = + b.scalarFn( + NAMESPACE, + "add:dec_dec", + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.decimal(10, 2)) + .build(), + ExpressionCreator.i32(false, 183)); ImmutableExpressionReference expressionReference = ImmutableExpressionReference.builder() - .expression(scalarFunctionExpression.get()) + .expressionType( + ImmutableExpressionType.builder().expression(scalarFunctionInvocation).build()) .addOutputNames("new-column") .build(); From e7904926ad19d31dd452c8fa7df3da5aa3c07088 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Tue, 5 Dec 2023 21:15:30 -0500 Subject: [PATCH 17/35] feat: support aggregation function in extended expression from/to pojo/proto --- .../ExtendedExpressionProtoConverter.java | 14 +- .../ProtoExtendedExpressionConverter.java | 18 +- .../AggregateFunctionProtoController.java | 43 ++++ .../java/io/substrait/type/NamedStruct.java | 2 +- .../ExtendedExpressionProtoConverterTest.java | 67 ----- .../ExtendedExpressionRoundTripTest.java | 229 ++++++++++++++++++ .../ProtoExtendedExpressionConverterTest.java | 78 ------ 7 files changed, 297 insertions(+), 154 deletions(-) create mode 100644 core/src/main/java/io/substrait/relation/AggregateFunctionProtoController.java delete mode 100644 core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java create mode 100644 core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java delete mode 100644 core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java diff --git a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java index ac8d08180..6d1f2efcb 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java @@ -3,6 +3,7 @@ import io.substrait.expression.Expression; import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.extension.ExtensionCollector; +import io.substrait.proto.AggregateFunction; import io.substrait.proto.ExpressionReference; import io.substrait.proto.ExtendedExpression; import io.substrait.type.proto.TypeProtoConverter; @@ -37,11 +38,18 @@ public ExtendedExpression toProto( builder.addReferredExpr(expressionReferenceBuilder); } else if (expressionType instanceof io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionType) { - throw new UnsupportedOperationException( - "Aggregate function types are not supported in conversion to proto Extended Expressions for now"); + AggregateFunction measure = + ((io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionType) + expressionType) + .getMeasure(); + ExpressionReference.Builder expressionReferenceBuilder = + ExpressionReference.newBuilder() + .setMeasure(measure.toBuilder()) + .addAllOutputNames(expressionReference.getOutputNames()); + builder.addReferredExpr(expressionReferenceBuilder); } else { throw new UnsupportedOperationException( - "Only Expression or Aggregate Function type are supported in conversion to proto Extended Expressions for now"); + "Only Expression or Aggregate Function type are supported in conversion to proto Extended Expressions"); } } builder.setBaseSchema( diff --git a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java index 14bbf209e..8daf41cf0 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java @@ -2,7 +2,12 @@ import io.substrait.expression.Expression; import io.substrait.expression.proto.ProtoExpressionConverter; -import io.substrait.extension.*; +import io.substrait.extension.ExtensionCollector; +import io.substrait.extension.ExtensionLookup; +import io.substrait.extension.ImmutableExtensionLookup; +import io.substrait.extension.ImmutableSimpleExtension; +import io.substrait.extension.SimpleExtension; +import io.substrait.proto.AggregateFunction; import io.substrait.proto.ExpressionReference; import io.substrait.proto.NamedStruct; import io.substrait.type.proto.ProtoTypeConverter; @@ -36,8 +41,7 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp NamedStruct baseSchemaProto = extendedExpression.getBaseSchema(); io.substrait.type.NamedStruct namedStruct = - io.substrait.type.NamedStruct.convertNamedStructProtoToPojo( - baseSchemaProto, protoTypeConverter); + io.substrait.type.NamedStruct.fromProto(baseSchemaProto, protoTypeConverter); ProtoExpressionConverter protoExpressionConverter = new ProtoExpressionConverter( @@ -55,8 +59,12 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp .addAllOutputNames(expressionReference.getOutputNamesList()) .build()); } else if (expressionReference.getExprTypeCase().getNumber() == 2) { // AggregateFunction - throw new UnsupportedOperationException( - "Aggregate function types are not supported in conversion from proto Extended Expressions for now"); + AggregateFunction measure = expressionReference.getMeasure(); + ImmutableExpressionReference.Builder builder = + ImmutableExpressionReference.builder() + .expressionType(ImmutableAggregateFunctionType.builder().measure(measure).build()) + .addAllOutputNames(expressionReference.getOutputNamesList()); + expressionReferences.add(builder.build()); } else { throw new UnsupportedOperationException( "Only Expression or Aggregate Function type are supported in conversion from proto Extended Expressions for now"); diff --git a/core/src/main/java/io/substrait/relation/AggregateFunctionProtoController.java b/core/src/main/java/io/substrait/relation/AggregateFunctionProtoController.java new file mode 100644 index 000000000..7904bfa72 --- /dev/null +++ b/core/src/main/java/io/substrait/relation/AggregateFunctionProtoController.java @@ -0,0 +1,43 @@ +package io.substrait.relation; + +import io.substrait.expression.FunctionArg; +import io.substrait.expression.proto.ExpressionProtoConverter; +import io.substrait.extension.ExtensionCollector; +import io.substrait.proto.AggregateFunction; +import io.substrait.type.proto.TypeProtoConverter; +import java.util.stream.IntStream; + +/** + * Converts from {@link io.substrait.relation.Aggregate.Measure} to {@link + * io.substrait.proto.AggregateFunction} + */ +public class AggregateFunctionProtoController { + + private final ExpressionProtoConverter exprProtoConverter; + private final TypeProtoConverter typeProtoConverter; + private final ExtensionCollector functionCollector; + + public AggregateFunctionProtoController(ExtensionCollector functionCollector) { + this.functionCollector = functionCollector; + this.exprProtoConverter = new ExpressionProtoConverter(functionCollector, null); + this.typeProtoConverter = new TypeProtoConverter(functionCollector); + } + + public AggregateFunction toProto(Aggregate.Measure measure) { + var argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter); + var args = measure.getFunction().arguments(); + var aggFuncDef = measure.getFunction().declaration(); + + return AggregateFunction.newBuilder() + .setPhase(measure.getFunction().aggregationPhase().toProto()) + .setInvocation(measure.getFunction().invocation().toProto()) + .setOutputType(measure.getFunction().getType().accept(typeProtoConverter)) + .addAllArguments( + IntStream.range(0, args.size()) + .mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor)) + .collect(java.util.stream.Collectors.toList())) + .setFunctionReference( + functionCollector.getFunctionReference(measure.getFunction().declaration())) + .build(); + } +} diff --git a/core/src/main/java/io/substrait/type/NamedStruct.java b/core/src/main/java/io/substrait/type/NamedStruct.java index 11fdd38ad..e38a95fb5 100644 --- a/core/src/main/java/io/substrait/type/NamedStruct.java +++ b/core/src/main/java/io/substrait/type/NamedStruct.java @@ -23,7 +23,7 @@ default io.substrait.proto.NamedStruct toProto(TypeProtoConverter typeProtoConve .build(); } - static io.substrait.type.NamedStruct convertNamedStructProtoToPojo( + static io.substrait.type.NamedStruct fromProto( io.substrait.proto.NamedStruct namedStruct, ProtoTypeConverter protoTypeConverter) { var struct = namedStruct.getStruct(); return ImmutableNamedStruct.builder() diff --git a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java deleted file mode 100644 index fa4cd2ac3..000000000 --- a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java +++ /dev/null @@ -1,67 +0,0 @@ -package io.substrait.extendedexpression; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -import io.substrait.TestBase; -import io.substrait.expression.*; -import io.substrait.type.ImmutableNamedStruct; -import io.substrait.type.Type; -import io.substrait.type.TypeCreator; -import java.util.ArrayList; -import java.util.List; -import org.junit.jupiter.api.Test; - -public class ExtendedExpressionProtoConverterTest extends TestBase { - static final String NAMESPACE = "/functions_arithmetic_decimal.yaml"; - - @Test - public void toProtoTest() { - // create predefined POJO extended expression - Expression.ScalarFunctionInvocation scalarFunctionInvocation = - b.scalarFn( - NAMESPACE, - "add:dec_dec", - TypeCreator.REQUIRED.BOOLEAN, - ImmutableFieldReference.builder() - .addSegments(FieldReference.StructField.of(0)) - .type(TypeCreator.REQUIRED.decimal(10, 2)) - .build(), - ExpressionCreator.i32(false, 183)); - - ImmutableExpressionReference expressionReference = - ImmutableExpressionReference.builder() - .expressionType( - ImmutableExpressionType.builder().expression(scalarFunctionInvocation).build()) - .addOutputNames("new-column") - .build(); - - List expressionReferences = new ArrayList<>(); - expressionReferences.add(expressionReference); - - ImmutableNamedStruct namedStruct = - ImmutableNamedStruct.builder() - .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") - .struct( - Type.Struct.builder() - .nullable(false) - .addFields( - TypeCreator.NULLABLE.decimal(10, 2), - TypeCreator.REQUIRED.STRING, - TypeCreator.REQUIRED.decimal(10, 2), - TypeCreator.REQUIRED.STRING) - .build()) - .build(); - - ImmutableExtendedExpression.Builder extendedExpression = - ImmutableExtendedExpression.builder() - .referredExpressions(expressionReferences) - .baseSchema(namedStruct); - - // convert POJO extended expression into PROTOBUF extended expression - io.substrait.proto.ExtendedExpression proto = - new ExtendedExpressionProtoConverter().toProto(extendedExpression.build()); - - assertEquals(NAMESPACE, proto.getExtensionUrisList().get(0).getUri()); - assertEquals("add:dec_dec", proto.getExtensionsList().get(0).getExtensionFunction().getName()); - } -} diff --git a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java new file mode 100644 index 000000000..1da2d6195 --- /dev/null +++ b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java @@ -0,0 +1,229 @@ +package io.substrait.extendedexpression; + +import io.substrait.TestBase; +import io.substrait.expression.*; +import io.substrait.relation.Aggregate; +import io.substrait.relation.AggregateFunctionProtoController; +import io.substrait.relation.ImmutableMeasure; +import io.substrait.type.ImmutableNamedStruct; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class ExtendedExpressionRoundTripTest extends TestBase { + static final String NAMESPACE = "/functions_arithmetic_decimal.yaml"; + + @Test + public void expressionRoundTrip() throws IOException { + // create predefined POJO extended expression + Expression.ScalarFunctionInvocation scalarFunctionInvocation = + b.scalarFn( + NAMESPACE, + "add:dec_dec", + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.decimal(10, 2)) + .build(), + ExpressionCreator.i32(false, 183)); + + ImmutableExpressionReference expressionReference = + ImmutableExpressionReference.builder() + .expressionType( + ImmutableExpressionType.builder().expression(scalarFunctionInvocation).build()) + .addOutputNames("new-column") + .build(); + + List + expressionReferences = new ArrayList<>(); + expressionReferences.add(expressionReference); + + ImmutableNamedStruct namedStruct = + ImmutableNamedStruct.builder() + .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") + .struct( + Type.Struct.builder() + .nullable(false) + .addFields( + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING, + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING) + .build()) + .build(); + + // pojo initial extended expression + ImmutableExtendedExpression extendedExpressionPojoInitial = + ImmutableExtendedExpression.builder() + .referredExpressions(expressionReferences) + .baseSchema(namedStruct) + .build(); + + // proto extended expression + io.substrait.proto.ExtendedExpression extendedExpressionProto = + new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoInitial); + + // pojo final extended expression + io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal = + new ProtoExtendedExpressionConverter().from(extendedExpressionProto); + + // validate extended expression pojo initial equals to final roundtrip + Assertions.assertEquals(extendedExpressionPojoInitial, extendedExpressionPojoFinal); + } + + @Test + public void aggregationRoundTrip() throws IOException { + // create predefined POJO aggregation function + ImmutableMeasure measure = + Aggregate.Measure.builder() + .function( + AggregateFunctionInvocation.builder() + .arguments(Collections.emptyList()) + .declaration(defaultExtensionCollection.aggregateFunctions().get(0)) + .outputType(TypeCreator.of(false).I64) + .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) + .invocation(Expression.AggregationInvocation.ALL) + .build()) + .build(); + + ImmutableAggregateFunctionType aggregateFunctionType = + ImmutableAggregateFunctionType.builder() + .measure(new AggregateFunctionProtoController(functionCollector).toProto(measure)) + .build(); + + ImmutableExpressionReference expressionReference = + ImmutableExpressionReference.builder() + .expressionType(aggregateFunctionType) + .addOutputNames("new-column") + .build(); + + List + expressionReferences = new ArrayList<>(); + expressionReferences.add(expressionReference); + + ImmutableNamedStruct namedStruct = + ImmutableNamedStruct.builder() + .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") + .struct( + Type.Struct.builder() + .nullable(false) + .addFields( + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING, + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING) + .build()) + .build(); + + // pojo initial aggregation function + ImmutableExtendedExpression extendedExpressionPojoInitial = + ImmutableExtendedExpression.builder() + .referredExpressions(expressionReferences) + .baseSchema(namedStruct) + .build(); + + // proto aggregation function + io.substrait.proto.ExtendedExpression extendedExpressionProto = + new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoInitial); + + // pojo final aggregation function + io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal = + new ProtoExtendedExpressionConverter().from(extendedExpressionProto); + + // validate aggregation function pojo initial equals to final roundtrip + Assertions.assertEquals(extendedExpressionPojoInitial, extendedExpressionPojoFinal); + } + + @Test + public void expressionAndAggregationRoundTrip() throws IOException { + // POJO 01 + // create predefined POJO extended expression + Expression.ScalarFunctionInvocation scalarFunctionInvocation = + b.scalarFn( + NAMESPACE, + "add:dec_dec", + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.decimal(10, 2)) + .build(), + ExpressionCreator.i32(false, 183)); + + ImmutableExpressionReference expressionReferenceExpression = + ImmutableExpressionReference.builder() + .expressionType( + ImmutableExpressionType.builder().expression(scalarFunctionInvocation).build()) + .addOutputNames("new-column") + .build(); + + List + expressionReferences = new ArrayList<>(); + + // POJO 02 + // create predefined POJO aggregation function + ImmutableMeasure measure = + Aggregate.Measure.builder() + .function( + AggregateFunctionInvocation.builder() + .arguments(Collections.emptyList()) + .declaration(defaultExtensionCollection.aggregateFunctions().get(0)) + .outputType(TypeCreator.of(false).I64) + .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) + .invocation(Expression.AggregationInvocation.ALL) + .build()) + .build(); + + ImmutableAggregateFunctionType aggregateFunctionType = + ImmutableAggregateFunctionType.builder() + .measure(new AggregateFunctionProtoController(functionCollector).toProto(measure)) + .build(); + + ImmutableExpressionReference expressionReferenceAggregation = + ImmutableExpressionReference.builder() + .expressionType(aggregateFunctionType) + .addOutputNames("new-column") + .build(); + + // adding expression + expressionReferences.add(expressionReferenceExpression); + // adding aggregation function + expressionReferences.add(expressionReferenceAggregation); + + ImmutableNamedStruct namedStruct = + ImmutableNamedStruct.builder() + .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") + .struct( + Type.Struct.builder() + .nullable(false) + .addFields( + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING, + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING) + .build()) + .build(); + + // pojo initial extended expression + aggregation + ImmutableExtendedExpression extendedExpressionPojoInitial = + ImmutableExtendedExpression.builder() + .referredExpressions(expressionReferences) + .baseSchema(namedStruct) + .build(); + + // proto extended expression + aggregation + io.substrait.proto.ExtendedExpression extendedExpressionProto = + new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoInitial); + + // pojo final extended expression + aggregation + io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal = + new ProtoExtendedExpressionConverter().from(extendedExpressionProto); + + // validate extended expression + aggregation pojo initial equals to final roundtrip + Assertions.assertEquals(extendedExpressionPojoInitial, extendedExpressionPojoFinal); + } +} diff --git a/core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java b/core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java deleted file mode 100644 index b0f34b783..000000000 --- a/core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java +++ /dev/null @@ -1,78 +0,0 @@ -package io.substrait.extendedexpression; - -import io.substrait.TestBase; -import io.substrait.expression.Expression; -import io.substrait.expression.ExpressionCreator; -import io.substrait.expression.FieldReference; -import io.substrait.expression.ImmutableFieldReference; -import io.substrait.proto.ExtendedExpression; -import io.substrait.type.ImmutableNamedStruct; -import io.substrait.type.Type; -import io.substrait.type.TypeCreator; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -public class ProtoExtendedExpressionConverterTest extends TestBase { - static final String NAMESPACE = "/functions_arithmetic_decimal.yaml"; - - @Test - public void fromTest() throws IOException { - // create predefined POJO extended expression - Expression.ScalarFunctionInvocation scalarFunctionInvocation = - b.scalarFn( - NAMESPACE, - "add:dec_dec", - TypeCreator.REQUIRED.BOOLEAN, - ImmutableFieldReference.builder() - .addSegments(FieldReference.StructField.of(0)) - .type(TypeCreator.REQUIRED.decimal(10, 2)) - .build(), - ExpressionCreator.i32(false, 183)); - - ImmutableExpressionReference expressionReference = - ImmutableExpressionReference.builder() - .expressionType( - ImmutableExpressionType.builder().expression(scalarFunctionInvocation).build()) - .addOutputNames("new-column") - .build(); - - List - expressionReferences = new ArrayList<>(); - expressionReferences.add(expressionReference); - - ImmutableNamedStruct namedStruct = - ImmutableNamedStruct.builder() - .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") - .struct( - Type.Struct.builder() - .nullable(false) - .addFields( - TypeCreator.REQUIRED.decimal(10, 2), - TypeCreator.REQUIRED.STRING, - TypeCreator.REQUIRED.decimal(10, 2), - TypeCreator.REQUIRED.STRING) - .build()) - .build(); - - // pojo initial extended expression - ImmutableExtendedExpression extendedExpressionPojoInitial = - ImmutableExtendedExpression.builder() - .referredExpressions(expressionReferences) - .baseSchema(namedStruct) - .build(); - - // proto extended expression - ExtendedExpression extendedExpressionProto = - new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoInitial); - - // pojo final extended expression - io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal = - new ProtoExtendedExpressionConverter().from(extendedExpressionProto); - - // validate extended expression pojo initial equals to final roundtrip - Assertions.assertEquals(extendedExpressionPojoInitial, extendedExpressionPojoFinal); - } -} From c26fecd7cf1f6b61e7597e02a440d0af5ae6334b Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Tue, 5 Dec 2023 22:08:35 -0500 Subject: [PATCH 18/35] fix: merge from/to proto/pojo + solve comments on the PR --- .../substrait/isthmus/SqlConverterBase.java | 6 +- .../isthmus/SqlExpressionToSubstrait.java | 120 ++++++++++++++++++ .../io/substrait/isthmus/SqlToSubstrait.java | 83 +----------- .../io/substrait/isthmus/TypeConverter.java | 5 +- .../isthmus/ExtendedExpressionTestBase.java | 9 +- .../ExtendedExpressionIntegrationTest.java | 4 +- 6 files changed, 135 insertions(+), 92 deletions(-) create mode 100644 isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java index 67ae9a8ea..6501316b1 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java @@ -5,7 +5,11 @@ import io.substrait.isthmus.calcite.SubstraitOperatorTable; import io.substrait.type.NamedStruct; import java.io.IOException; -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import java.util.function.Function; import org.apache.calcite.config.CalciteConnectionConfig; import org.apache.calcite.config.CalciteConnectionProperty; diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java new file mode 100644 index 000000000..8e31fe30a --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -0,0 +1,120 @@ +package io.substrait.isthmus; + +import com.google.common.annotations.VisibleForTesting; +import io.substrait.extendedexpression.ExtendedExpressionProtoConverter; +import io.substrait.extendedexpression.ImmutableExpressionReference; +import io.substrait.extendedexpression.ImmutableExpressionType; +import io.substrait.extendedexpression.ImmutableExtendedExpression; +import io.substrait.isthmus.expression.RexExpressionConverter; +import io.substrait.isthmus.expression.ScalarFunctionConverter; +import io.substrait.proto.ExtendedExpression; +import io.substrait.type.NamedStruct; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.apache.calcite.prepare.CalciteCatalogReader; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.parser.SqlParseException; +import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql2rel.SqlToRelConverter; +import org.apache.calcite.sql2rel.StandardConvertletTable; + +public class SqlExpressionToSubstrait extends SqlConverterBase { + + public SqlExpressionToSubstrait() { + this(null); + } + + protected SqlExpressionToSubstrait(FeatureBoard features) { + super(features); + } + + private final ScalarFunctionConverter functionConverter = + new ScalarFunctionConverter(EXTENSION_COLLECTION.scalarFunctions(), factory); + + private final RexExpressionConverter rexExpressionConverter = + new RexExpressionConverter(functionConverter); + + /** + * Process to execute an SQL Expression to convert into an Extended expression protobuf message + * + * @param sqlExpression expression defined by the user + * @param tables of names of table needed to consider to load into memory for catalog, schema, + * validate and parse sql + * @return extended expression protobuf message + * @throws SqlParseException + */ + public ExtendedExpression executeSQLExpression(String sqlExpression, List tables) + throws SqlParseException { + var result = registerCreateTablesForExtendedExpression(tables); + return executeInnerSQLExpression( + sqlExpression, + result.validator(), + result.catalogReader(), + result.nameToTypeMap(), + result.nameToNodeMap()); + } + + private ExtendedExpression executeInnerSQLExpression( + String sqlExpression, + SqlValidator validator, + CalciteCatalogReader catalogReader, + Map nameToTypeMap, + Map nameToNodeMap) + throws SqlParseException { + RexNode rexNode = + sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap); + io.substrait.expression.Expression.ScalarFunctionInvocation func = + (io.substrait.expression.Expression.ScalarFunctionInvocation) + rexNode.accept(rexExpressionConverter); + NamedStruct namedStruct = TypeConverter.DEFAULT.toNamedStruct(nameToTypeMap); + + ImmutableExpressionReference expressionReference = + ImmutableExpressionReference.builder() + .expressionType(ImmutableExpressionType.builder().expression(func).build()) + .addOutputNames("new-column") + .build(); + + List + expressionReferences = new ArrayList<>(); + expressionReferences.add(expressionReference); + + ImmutableExtendedExpression.Builder extendedExpression = + ImmutableExtendedExpression.builder() + .referredExpressions(expressionReferences) + .baseSchema(namedStruct); + + return new ExtendedExpressionProtoConverter().toProto(extendedExpression.build()); + } + + private RexNode sqlToRexNode( + String sql, + SqlValidator validator, + CalciteCatalogReader catalogReader, + Map nameToTypeMap, + Map nameToNodeMap) + throws SqlParseException { + SqlParser parser = SqlParser.create(sql, parserConfig); + SqlNode sqlNode = parser.parseExpression(); + SqlNode validSQLNode = validator.validateParameterizedExpression(sqlNode, nameToTypeMap); + SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader); + return converter.convertExpression(validSQLNode, nameToNodeMap); + } + + @VisibleForTesting + SqlToRelConverter createSqlToRelConverter( + SqlValidator validator, CalciteCatalogReader catalogReader) { + SqlToRelConverter converter = + new SqlToRelConverter( + null, + validator, + catalogReader, + relOptCluster, + StandardConvertletTable.INSTANCE, + converterConfig); + return converter; + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index 6ee3f11af..7a850499a 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -1,27 +1,18 @@ package io.substrait.isthmus; import com.google.common.annotations.VisibleForTesting; -import io.substrait.extendedexpression.ExtendedExpressionProtoConverter; -import io.substrait.extendedexpression.ImmutableExpressionReference; -import io.substrait.extendedexpression.ImmutableExpressionType; -import io.substrait.extendedexpression.ImmutableExtendedExpression; import io.substrait.extension.ExtensionCollector; -import io.substrait.isthmus.expression.RexExpressionConverter; -import io.substrait.isthmus.expression.ScalarFunctionConverter; -import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import io.substrait.proto.PlanRel; import io.substrait.relation.RelProtoConverter; import io.substrait.type.NamedStruct; -import java.util.*; +import java.util.List; import java.util.function.Function; import org.apache.calcite.plan.hep.HepPlanner; import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.prepare.CalciteCatalogReader; import org.apache.calcite.rel.RelRoot; -import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; -import org.apache.calcite.rex.RexNode; import org.apache.calcite.schema.Schema; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.parser.SqlParseException; @@ -33,12 +24,6 @@ /** Take a SQL statement and a set of table definitions and return a substrait plan. */ public class SqlToSubstrait extends SqlConverterBase { - private final ScalarFunctionConverter functionConverter = - new ScalarFunctionConverter(EXTENSION_COLLECTION.scalarFunctions(), factory); - - private final RexExpressionConverter rexExpressionConverter = - new RexExpressionConverter(functionConverter); - public SqlToSubstrait() { this(null); } @@ -63,26 +48,6 @@ public Plan execute(String sql, String name, Schema schema) throws SqlParseExcep return executeInner(sql, factory, pair.left, pair.right); } - /** - * Process to execute an SQL Expression to convert into an Extended expression protobuf message - * - * @param sqlExpression expression defined by the user - * @param tables of names of table needed to consider to load into memory for catalog, schema, - * validate and parse sql - * @return extended expression protobuf message - * @throws SqlParseException - */ - public ExtendedExpression executeSQLExpression(String sqlExpression, List tables) - throws SqlParseException { - var result = registerCreateTablesForExtendedExpression(tables); - return executeInnerSQLExpression( - sqlExpression, - result.validator(), - result.catalogReader(), - result.nameToTypeMap(), - result.nameToNodeMap()); - } - // Package protected for testing List sqlToRelNode(String sql, List tables) throws SqlParseException { var pair = registerCreateTables(tables); @@ -126,38 +91,6 @@ private Plan executeInner( return plan.build(); } - private ExtendedExpression executeInnerSQLExpression( - String sqlExpression, - SqlValidator validator, - CalciteCatalogReader catalogReader, - Map nameToTypeMap, - Map nameToNodeMap) - throws SqlParseException { - RexNode rexNode = - sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap); - io.substrait.expression.Expression.ScalarFunctionInvocation func = - (io.substrait.expression.Expression.ScalarFunctionInvocation) - rexNode.accept(rexExpressionConverter); - NamedStruct namedStruct = TypeConverter.DEFAULT.toNamedStruct(nameToTypeMap); - - ImmutableExpressionReference expressionReference = - ImmutableExpressionReference.builder() - .expressionType(ImmutableExpressionType.builder().expression(func).build()) - .addOutputNames("new-column") - .build(); - - List - expressionReferences = new ArrayList<>(); - expressionReferences.add(expressionReference); - - ImmutableExtendedExpression.Builder extendedExpression = - ImmutableExtendedExpression.builder() - .referredExpressions(expressionReferences) - .baseSchema(namedStruct); - - return new ExtendedExpressionProtoConverter().toProto(extendedExpression.build()); - } - private List sqlToRelNode( String sql, SqlValidator validator, CalciteCatalogReader catalogReader) throws SqlParseException { @@ -174,20 +107,6 @@ private List sqlToRelNode( return roots; } - private RexNode sqlToRexNode( - String sql, - SqlValidator validator, - CalciteCatalogReader catalogReader, - Map nameToTypeMap, - Map nameToNodeMap) - throws SqlParseException { - SqlParser parser = SqlParser.create(sql, parserConfig); - SqlNode sqlNode = parser.parseExpression(); - SqlNode validSQLNode = validator.validateParameterizedExpression(sqlNode, nameToTypeMap); - SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader); - return converter.convertExpression(validSQLNode, nameToNodeMap); - } - @VisibleForTesting SqlToRelConverter createSqlToRelConverter( SqlValidator validator, CalciteCatalogReader catalogReader) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java index 73c846d45..6b69f3fff 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java @@ -3,7 +3,6 @@ import static io.substrait.isthmus.SubstraitTypeSystem.DAY_SECOND_INTERVAL; import static io.substrait.isthmus.SubstraitTypeSystem.YEAR_MONTH_INTERVAL; -import com.google.common.collect.Lists; import io.substrait.function.NullableType; import io.substrait.function.TypeExpression; import io.substrait.type.NamedStruct; @@ -59,8 +58,8 @@ public NamedStruct toNamedStruct(RelDataType type) { } public NamedStruct toNamedStruct(Map nameToTypeMap) { - var names = Lists.newArrayList(); - var types = Lists.newArrayList(); + var names = new ArrayList(); + var types = new ArrayList(); nameToTypeMap.forEach( (k, v) -> { names.add(k); diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java index 052e91c9c..4a9dde8b7 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java @@ -25,16 +25,17 @@ public static List tpchSchemaCreateStatements() throws IOException { protected ExtendedExpression assertProtoExtendedExpressionRoundrip(String query) throws IOException, SqlParseException { - return assertProtoExtendedExpressionRoundrip(query, new SqlToSubstrait()); + return assertProtoExtendedExpressionRoundrip(query, new SqlExpressionToSubstrait()); } - protected ExtendedExpression assertProtoExtendedExpressionRoundrip(String query, SqlToSubstrait s) - throws IOException, SqlParseException { + protected ExtendedExpression assertProtoExtendedExpressionRoundrip( + String query, SqlExpressionToSubstrait s) throws IOException, SqlParseException { return assertProtoExtendedExpressionRoundrip(query, s, tpchSchemaCreateStatements()); } protected ExtendedExpression assertProtoExtendedExpressionRoundrip( - String query, SqlToSubstrait s, List creates) throws SqlParseException, IOException { + String query, SqlExpressionToSubstrait s, List creates) + throws SqlParseException, IOException { // proto initial extended expression ExtendedExpression extendedExpressionProtoInitial = s.executeSQLExpression(query, creates); diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java index 297517ec8..0c25c603b 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java @@ -4,7 +4,7 @@ import com.ibm.icu.impl.ClassLoaderUtil; import io.substrait.isthmus.ExtendedExpressionTestBase; -import io.substrait.isthmus.SqlToSubstrait; +import io.substrait.isthmus.SqlExpressionToSubstrait; import io.substrait.proto.ExtendedExpression; import java.io.IOException; import java.net.URL; @@ -98,7 +98,7 @@ public void projectDataset() throws SqlParseException, IOException { private static ByteBuffer getExtendedExpression(String sqlExpression) throws IOException, SqlParseException { ExtendedExpression extendedExpression = - new SqlToSubstrait() + new SqlExpressionToSubstrait() .executeSQLExpression( sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); byte[] extendedExpressions = From c3cd3e6d0ceb87e0b4ef18527716820b433c80e6 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Wed, 6 Dec 2023 08:05:03 -0500 Subject: [PATCH 19/35] feat: expose sql expression to extended expression thru Pico CLI --- .../isthmus/ExpressionEntryPoint.java | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 isthmus/src/main/java/io/substrait/isthmus/ExpressionEntryPoint.java diff --git a/isthmus/src/main/java/io/substrait/isthmus/ExpressionEntryPoint.java b/isthmus/src/main/java/io/substrait/isthmus/ExpressionEntryPoint.java new file mode 100644 index 000000000..aa7f0ccbd --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/ExpressionEntryPoint.java @@ -0,0 +1,61 @@ +package io.substrait.isthmus; + +import com.google.common.annotations.VisibleForTesting; +import com.google.protobuf.util.JsonFormat; +import io.substrait.proto.ExtendedExpression; +import java.util.List; +import java.util.concurrent.Callable; +import org.apache.calcite.sql.validate.SqlConformanceEnum; +import picocli.CommandLine; + +@CommandLine.Command( + name = "isthmus", + version = "isthmus 0.1", + description = "Converts a SQL Expression to a Substrait Extended Expression") +public class ExpressionEntryPoint implements Callable { + + @CommandLine.Parameters(index = "0", description = "The sql expression we should parse.") + private String sqlExpression; + + @CommandLine.Option( + names = {"-c", "--create"}, + description = + "One or multiple create table statements e.g. CREATE TABLE T1(foo int, bar bigint)") + private List createStatements; + + @CommandLine.Option( + names = {"-m", "--multistatement"}, + description = "Allow multiple statements terminated with a semicolon") + private boolean allowMultiStatement; + + @CommandLine.Option( + names = {"--sqlconformancemode"}, + description = "One of built-in Calcite SQL compatibility modes: ${COMPLETION-CANDIDATES}") + private SqlConformanceEnum sqlConformanceMode = SqlConformanceEnum.DEFAULT; + + // this example implements Callable, so parsing, error handling and handling user + // requests for usage help or version help can be done with one line of code. + public static void main(String... args) { + int exitCode = new CommandLine(new PlanEntryPoint()).execute(args); + System.exit(exitCode); + } + + @Override + public Integer call() throws Exception { + FeatureBoard featureBoard = buildFeatureBoard(); + SqlExpressionToSubstrait converter = new SqlExpressionToSubstrait(featureBoard); + ExtendedExpression extendedExpression = + converter.executeSQLExpression(sqlExpression, createStatements); + System.out.println( + JsonFormat.printer().includingDefaultValueFields().print(extendedExpression)); + return 0; + } + + @VisibleForTesting + FeatureBoard buildFeatureBoard() { + return ImmutableFeatureBoard.builder() + .allowsSqlBatch(allowMultiStatement) + .sqlConformanceMode(sqlConformanceMode) + .build(); + } +} From 8d0f81d6767b679688a6bea19e4cba55fdf495ef Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Tue, 12 Dec 2023 16:19:00 -0500 Subject: [PATCH 20/35] feat: expose Isthmus native image commands for SQL Query and SQL Expressions --- isthmus/build.gradle.kts | 4 +- .../isthmus/ExpressionEntryPoint.java | 61 ------------------- ...EntryPoint.java => IsthmusEntryPoint.java} | 38 +++++++++--- .../substrait/isthmus/RegisterAtRuntime.java | 2 +- ...ntTest.java => IsthmusEntryPointTest.java} | 8 +-- isthmus/src/test/script/smoke.sh | 12 +++- 6 files changed, 45 insertions(+), 80 deletions(-) delete mode 100644 isthmus/src/main/java/io/substrait/isthmus/ExpressionEntryPoint.java rename isthmus/src/main/java/io/substrait/isthmus/{PlanEntryPoint.java => IsthmusEntryPoint.java} (58%) rename isthmus/src/test/java/io/substrait/isthmus/{PlanEntryPointTest.java => IsthmusEntryPointTest.java} (90%) diff --git a/isthmus/build.gradle.kts b/isthmus/build.gradle.kts index abf5e412c..6916cb2fd 100644 --- a/isthmus/build.gradle.kts +++ b/isthmus/build.gradle.kts @@ -102,9 +102,9 @@ dependencies { } graal { - mainClass("io.substrait.isthmus.PlanEntryPoint") + mainClass("io.substrait.isthmus.IsthmusEntryPoint") outputName("isthmus") - graalVersion("22.0.0.2") + graalVersion("22.1.0") javaVersion("17") option("--no-fallback") option( diff --git a/isthmus/src/main/java/io/substrait/isthmus/ExpressionEntryPoint.java b/isthmus/src/main/java/io/substrait/isthmus/ExpressionEntryPoint.java deleted file mode 100644 index aa7f0ccbd..000000000 --- a/isthmus/src/main/java/io/substrait/isthmus/ExpressionEntryPoint.java +++ /dev/null @@ -1,61 +0,0 @@ -package io.substrait.isthmus; - -import com.google.common.annotations.VisibleForTesting; -import com.google.protobuf.util.JsonFormat; -import io.substrait.proto.ExtendedExpression; -import java.util.List; -import java.util.concurrent.Callable; -import org.apache.calcite.sql.validate.SqlConformanceEnum; -import picocli.CommandLine; - -@CommandLine.Command( - name = "isthmus", - version = "isthmus 0.1", - description = "Converts a SQL Expression to a Substrait Extended Expression") -public class ExpressionEntryPoint implements Callable { - - @CommandLine.Parameters(index = "0", description = "The sql expression we should parse.") - private String sqlExpression; - - @CommandLine.Option( - names = {"-c", "--create"}, - description = - "One or multiple create table statements e.g. CREATE TABLE T1(foo int, bar bigint)") - private List createStatements; - - @CommandLine.Option( - names = {"-m", "--multistatement"}, - description = "Allow multiple statements terminated with a semicolon") - private boolean allowMultiStatement; - - @CommandLine.Option( - names = {"--sqlconformancemode"}, - description = "One of built-in Calcite SQL compatibility modes: ${COMPLETION-CANDIDATES}") - private SqlConformanceEnum sqlConformanceMode = SqlConformanceEnum.DEFAULT; - - // this example implements Callable, so parsing, error handling and handling user - // requests for usage help or version help can be done with one line of code. - public static void main(String... args) { - int exitCode = new CommandLine(new PlanEntryPoint()).execute(args); - System.exit(exitCode); - } - - @Override - public Integer call() throws Exception { - FeatureBoard featureBoard = buildFeatureBoard(); - SqlExpressionToSubstrait converter = new SqlExpressionToSubstrait(featureBoard); - ExtendedExpression extendedExpression = - converter.executeSQLExpression(sqlExpression, createStatements); - System.out.println( - JsonFormat.printer().includingDefaultValueFields().print(extendedExpression)); - return 0; - } - - @VisibleForTesting - FeatureBoard buildFeatureBoard() { - return ImmutableFeatureBoard.builder() - .allowsSqlBatch(allowMultiStatement) - .sqlConformanceMode(sqlConformanceMode) - .build(); - } -} diff --git a/isthmus/src/main/java/io/substrait/isthmus/PlanEntryPoint.java b/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java similarity index 58% rename from isthmus/src/main/java/io/substrait/isthmus/PlanEntryPoint.java rename to isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java index b237f931e..6f7df516f 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/PlanEntryPoint.java +++ b/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java @@ -7,7 +7,9 @@ import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.util.JsonFormat; import io.substrait.isthmus.SubstraitRelVisitor.CrossJoinPolicy; +import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; +import java.util.Arrays; import java.util.List; import java.util.concurrent.Callable; import org.apache.calcite.sql.validate.SqlConformanceEnum; @@ -16,14 +18,18 @@ @Command( name = "isthmus", version = "isthmus 0.1", - description = "Converts a SQL query to a Substrait Plan") -public class PlanEntryPoint implements Callable { + description = "Substrait Java Native Image for parsing SQL Query and SQL Expressions") +public class IsthmusEntryPoint implements Callable { + static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(IsthmusEntryPoint.class); - static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(PlanEntryPoint.class); - - @Parameters(index = "0", description = "The sql we should parse.") + @Parameters(index = "0", arity = "0..1", description = "The sql we should parse.") private String sql; + @Option( + names = {"-e", "--sqlExpression"}, + description = "The sql expression we should parse.") + private String sqlExpression; + @Option( names = {"-c", "--create"}, description = @@ -48,16 +54,30 @@ public class PlanEntryPoint implements Callable { // this example implements Callable, so parsing, error handling and handling user // requests for usage help or version help can be done with one line of code. public static void main(String... args) { - int exitCode = new CommandLine(new PlanEntryPoint()).execute(args); + logger.debug(Arrays.toString(args)); + int exitCode = new CommandLine(new IsthmusEntryPoint()).execute(args); System.exit(exitCode); } @Override public Integer call() throws Exception { FeatureBoard featureBoard = buildFeatureBoard(); - SqlToSubstrait converter = new SqlToSubstrait(featureBoard); - Plan plan = converter.execute(sql, createStatements); - System.out.println(JsonFormat.printer().includingDefaultValueFields().print(plan)); + // Isthmus image is paring SQL Expression if that argument is defined + if (sqlExpression != null) { + logger.debug(sqlExpression); + logger.debug(String.valueOf(createStatements)); + SqlExpressionToSubstrait converter = new SqlExpressionToSubstrait(featureBoard); + ExtendedExpression extendedExpression = + converter.executeSQLExpression(sqlExpression, createStatements); + System.out.println( + JsonFormat.printer().includingDefaultValueFields().print(extendedExpression)); + } else { // by default Isthmus image are parsing SQL Query + logger.debug(sql); + logger.debug(String.valueOf(createStatements)); + SqlToSubstrait converter = new SqlToSubstrait(featureBoard); + Plan plan = converter.execute(sql, createStatements); + System.out.println(JsonFormat.printer().includingDefaultValueFields().print(plan)); + } return 0; } diff --git a/isthmus/src/main/java/io/substrait/isthmus/RegisterAtRuntime.java b/isthmus/src/main/java/io/substrait/isthmus/RegisterAtRuntime.java index 4f634434f..676a12a88 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/RegisterAtRuntime.java +++ b/isthmus/src/main/java/io/substrait/isthmus/RegisterAtRuntime.java @@ -26,7 +26,7 @@ public void beforeAnalysis(BeforeAnalysisAccess access) { try { Reflections substrait = new Reflections("io.substrait"); // cli picocli - register(PlanEntryPoint.class); + register(IsthmusEntryPoint.class); // Empty class register(Empty.class); diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanEntryPointTest.java b/isthmus/src/test/java/io/substrait/isthmus/IsthmusEntryPointTest.java similarity index 90% rename from isthmus/src/test/java/io/substrait/isthmus/PlanEntryPointTest.java rename to isthmus/src/test/java/io/substrait/isthmus/IsthmusEntryPointTest.java index 3b23d82d2..6c2998757 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanEntryPointTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/IsthmusEntryPointTest.java @@ -12,12 +12,12 @@ import picocli.CommandLine; import picocli.CommandLine.ParameterException; -class PlanEntryPointTest { +class IsthmusEntryPointTest { /** Test that the default values are set correctly into the {@link FeatureBoard}. */ @Test void defaultFeatureBoard() { - PlanEntryPoint planEntryPoint = new PlanEntryPoint(); + IsthmusEntryPoint planEntryPoint = new IsthmusEntryPoint(); new CommandLine(planEntryPoint); FeatureBoard features = planEntryPoint.buildFeatureBoard(); assertFalse(features.allowsSqlBatch()); @@ -28,7 +28,7 @@ void defaultFeatureBoard() { /** Test that the command line options are correctly parsed into the {@link FeatureBoard}. */ @Test void customFeatureBoard() { - PlanEntryPoint planEntryPoint = new PlanEntryPoint(); + IsthmusEntryPoint planEntryPoint = new IsthmusEntryPoint(); new CommandLine(planEntryPoint) .parseArgs( "--multistatement", @@ -47,7 +47,7 @@ void customFeatureBoard() { */ @Test void invalidCmdOptions() { - PlanEntryPoint planEntryPoint = new PlanEntryPoint(); + IsthmusEntryPoint planEntryPoint = new IsthmusEntryPoint(); assertThrows( ParameterException.class, () -> diff --git a/isthmus/src/test/script/smoke.sh b/isthmus/src/test/script/smoke.sh index 16a18c7c8..b5e558f60 100755 --- a/isthmus/src/test/script/smoke.sh +++ b/isthmus/src/test/script/smoke.sh @@ -6,11 +6,17 @@ LINEITEM="CREATE TABLE LINEITEM (L_ORDERKEY BIGINT NOT NULL, L_PARTKEY BIGINT NO echo $LINEITEM #set -x -# Simple +# SQL Query - Simple $CMD 'select * from lineitem' --create "${LINEITEM}" -# With condition +# SQL Query - With condition $CMD 'select * from lineitem where l_orderkey > 10' --create "${LINEITEM}" -# Aggregate +# SQL Query - Aggregate $CMD 'select l_orderkey, count(l_partkey) from lineitem group by l_orderkey' --create "${LINEITEM}" + +# SQL Expression - Filter +$CMD --sqlExpression 'l_orderkey > 10' --create "${LINEITEM}" + +# SQL Expression - Projection +$CMD --sqlExpression 'l_orderkey + 9888486986' --create "${LINEITEM}" From 6fe70bbc79a3036527dd27d05f3f53bc3c8dbe45 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Fri, 19 Jan 2024 21:48:05 -0500 Subject: [PATCH 21/35] fix: delete files not needed --- .../src/test/resources/tpch/data/nation.parquet | Bin 2319 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 isthmus/src/test/resources/tpch/data/nation.parquet diff --git a/isthmus/src/test/resources/tpch/data/nation.parquet b/isthmus/src/test/resources/tpch/data/nation.parquet deleted file mode 100644 index 0189118ce7344297b1ad6f1095a11162ddd960ea..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2319 zcmdT`U5p#m6~5OQZ#?W6cEDrC$}Zfk*6!}sJH*az2w_7yyLOWy>!0=7K~71-_4wMJ z-SLcfev)+p51>k*Nc(^qP^Cy!C{zmSLsb>=5Gqv(EfJNdih}yI;-L>nZFxu~S^-Kc z&g@Tq#9QAg&1aoC_xztT=en#`G7^$L!SJM|ERX}z07A=SA%svC!wB5^Wmr;cIgO|l zbxsTKk&kQnYEdm<{kckQ3ETz+sui_rK1Yse#Ur^=A)0tFwp3NC`K7IHoV}|a_clZ1 z5QoLU$F3-cQ<%dt=0rik2+smJEsmvwUIhr+2j>3ri1d5$E_=Ux;O4tq1>Oe&q(j?n zlPN@}4r}?Q*(WW-q9$pwp6wc*3xvmXkGN(Z&S;w&3!nx9Eqjf*r03XO!(}_ix^6p7 z!)g=HCSHdThqZl^I)uW3Z+Wgl8n)R4_NvcZFiuU|S^;}t8K|~vG#Rm5o2or#ZCct1 z1VWVF?6^Iq8{oZ1^%dHN03d8a8@BHe{Nw?{u`NS~&>T2Y2>+fqon~O3Q5O=m+R@43 zG;s}+S)NoDr;!u=UABX!TfSr1K4eJj)65C43ZCJ0@bYlW^jj^%?1w)a(t{xJ>)dG; zPCn6)eam-!YT|gzZrY~dg=dkGJTv6CArSuK*s5>nE9U5N!LeNzD#^LuF7O>%Ax&( z;`=Baeh!&ae_c>P&UFb}(|s3eq0?pCtnZQjkSQW%W@9gO$FMvo5!BvMU!VBR0k+n2 z7-&$p-|cyJ^1ITPh!rkrSNW`RJ%kJ`9P8Nn3hUz=xju7T=qqDA(ts2@%pBP(WWi$> zMW9>z&xiQ+sY`?fjbfH=jJ_kx!Ku0|t&cTq+Xcs--An!|bB&+P_~@`C^B)OgAdz20 z3O|9;(pN&;Xn8Y7W%@X<4L#9S|0VuW?G*l7@9?LF9O&}N(Hg!Ifr*;AGQY8b<=j4*&n&s zy~g>j@@Xi1sP-~FTQY6<}b`dhn!_FBpn`kEwf9xj`?$|R}igVMvOg^ z_)7UvXgy;*Ec_5el0QAU=ZWJ`!4HNl+p`_Za0@@yj%VKFHR=d3%rrWN6nr^myLOOa z(~l%?A^O@squsrv>zkg@GYu%SBCTaVM-hv*@qzW9AcKt|`o_QcLGcx#2zq!Oxo^vh zR|@-}t=Gkk)f>@`L;1`FL4iZ+juV(?Lx&=ned(#he@0y{O^nA-Nqt8;d-W``)Mc*2 z>WSogoW=t5XsC4Efv8`WMj`paFbFt6dl-4Tn*^L-wD+oU7C}hIe`*g%Z!sAUWFf^_6Dh zl&g1Gmr4;Ng_0q8&^l#N&w$UK^^@O0k^sDcfY)V7sQI=C?2UUxc z;lx4;4TC2MCGN2)h2q)x*BG8uWLD;aOGXa%JX|~DcfA24&=)rKxBYLbZgBCRbE%g? zv-0lky)EzFzJa76?B#7ZoAa~yG3FJpk)008`~HV} g0m}3L5GE8n4il7$pz(nmOlWQn{TlAWGW-?#1~YkXEC2ui From c112a2729647c8317d88e7cad003b6f720679d49 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Fri, 19 Jan 2024 21:54:36 -0500 Subject: [PATCH 22/35] fix: clean code --- .editorconfig | 2 +- .../AggregateFunctionProtoController.java | 43 ------------------- isthmus/build.gradle.kts | 3 -- .../io/substrait/isthmus/TypeConverter.java | 12 ------ 4 files changed, 1 insertion(+), 59 deletions(-) delete mode 100644 core/src/main/java/io/substrait/relation/AggregateFunctionProtoController.java diff --git a/.editorconfig b/.editorconfig index 984db0c67..3d674d593 100644 --- a/.editorconfig +++ b/.editorconfig @@ -10,7 +10,7 @@ trim_trailing_whitespace = true [*.{yaml,yml}] indent_size = 2 -[{**/*.sql,**/OuterReferenceResolver.md,gradlew.bat,**/*.parquet}] +[{**/*.sql,**/OuterReferenceResolver.md,gradlew.bat}] charset = unset end_of_line = unset insert_final_newline = unset diff --git a/core/src/main/java/io/substrait/relation/AggregateFunctionProtoController.java b/core/src/main/java/io/substrait/relation/AggregateFunctionProtoController.java deleted file mode 100644 index 7904bfa72..000000000 --- a/core/src/main/java/io/substrait/relation/AggregateFunctionProtoController.java +++ /dev/null @@ -1,43 +0,0 @@ -package io.substrait.relation; - -import io.substrait.expression.FunctionArg; -import io.substrait.expression.proto.ExpressionProtoConverter; -import io.substrait.extension.ExtensionCollector; -import io.substrait.proto.AggregateFunction; -import io.substrait.type.proto.TypeProtoConverter; -import java.util.stream.IntStream; - -/** - * Converts from {@link io.substrait.relation.Aggregate.Measure} to {@link - * io.substrait.proto.AggregateFunction} - */ -public class AggregateFunctionProtoController { - - private final ExpressionProtoConverter exprProtoConverter; - private final TypeProtoConverter typeProtoConverter; - private final ExtensionCollector functionCollector; - - public AggregateFunctionProtoController(ExtensionCollector functionCollector) { - this.functionCollector = functionCollector; - this.exprProtoConverter = new ExpressionProtoConverter(functionCollector, null); - this.typeProtoConverter = new TypeProtoConverter(functionCollector); - } - - public AggregateFunction toProto(Aggregate.Measure measure) { - var argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter); - var args = measure.getFunction().arguments(); - var aggFuncDef = measure.getFunction().declaration(); - - return AggregateFunction.newBuilder() - .setPhase(measure.getFunction().aggregationPhase().toProto()) - .setInvocation(measure.getFunction().invocation().toProto()) - .setOutputType(measure.getFunction().getType().accept(typeProtoConverter)) - .addAllArguments( - IntStream.range(0, args.size()) - .mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor)) - .collect(java.util.stream.Collectors.toList())) - .setFunctionReference( - functionCollector.getFunctionReference(measure.getFunction().declaration())) - .build(); - } -} diff --git a/isthmus/build.gradle.kts b/isthmus/build.gradle.kts index 9ba1d6221..707a4b8e5 100644 --- a/isthmus/build.gradle.kts +++ b/isthmus/build.gradle.kts @@ -72,7 +72,6 @@ java { } var CALCITE_VERSION = "1.34.0" -var ARROW_VERSION = "14.0.0" dependencies { implementation(project(":core")) @@ -95,8 +94,6 @@ dependencies { implementation("org.immutables:value-annotations:2.8.8") annotationProcessor("org.immutables:value:2.8.8") testImplementation("org.apache.calcite:calcite-plus:${CALCITE_VERSION}") - testImplementation("org.apache.arrow:arrow-dataset:${ARROW_VERSION}") - testImplementation("org.apache.arrow:arrow-memory-netty:${ARROW_VERSION}") annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") } diff --git a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java index 68cae2220..92ccecbe4 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java @@ -12,7 +12,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.Map; import javax.annotation.Nullable; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; @@ -58,17 +57,6 @@ public NamedStruct toNamedStruct(RelDataType type) { return NamedStruct.of(names, struct); } - public NamedStruct toNamedStruct(Map nameToTypeMap) { - var names = new ArrayList(); - var types = new ArrayList(); - nameToTypeMap.forEach( - (k, v) -> { - names.add(k); - types.add(toSubstrait(v, names)); - }); - return NamedStruct.of(names, Type.Struct.builder().fields(types).nullable(false).build()); - } - private Type toSubstrait(RelDataType type, List names) { // Check for user mapped types first as they may re-use SqlTypeNames var userType = userTypeMapper.toSubstrait(type); From 1667cd7817f9d17cadb9b35aa0a75edc1c630664 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Mon, 22 Jan 2024 11:20:20 -0500 Subject: [PATCH 23/35] fix: clean validation for empty values that could be supported for Literal expressions --- .../java/io/substrait/isthmus/SqlExpressionToSubstrait.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java index 82ec0998f..4306d9603 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -151,9 +151,6 @@ private Result registerCreateTablesForExtendedExpression(List tables) } } } - } else { - throw new IllegalArgumentException( - "Information regarding the data and types must be passed."); } return new Result(validator, catalogReader, nameToTypeMap, nameToNodeMap); } From 4677b9e52b54fbc896cdacfc2a460e982cf1eea9 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Thu, 25 Jan 2024 00:38:25 -0500 Subject: [PATCH 24/35] fix: clean code --- .../io/substrait/isthmus/IsthmusEntryPoint.java | 2 +- isthmus/src/test/script/smoke.sh | 14 ++++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java b/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java index 3b16f227f..dfaa5b7e5 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java +++ b/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java @@ -28,7 +28,7 @@ public class IsthmusEntryPoint implements Callable { private String sql; @Option( - names = {"-e", "--sqlExpression"}, + names = {"-e", "--expression"}, description = "The sql expression we should parse.") private String sqlExpression; diff --git a/isthmus/src/test/script/smoke.sh b/isthmus/src/test/script/smoke.sh index b5e558f60..fbea5289a 100755 --- a/isthmus/src/test/script/smoke.sh +++ b/isthmus/src/test/script/smoke.sh @@ -15,8 +15,14 @@ $CMD 'select * from lineitem where l_orderkey > 10' --create "${LINEITEM}" # SQL Query - Aggregate $CMD 'select l_orderkey, count(l_partkey) from lineitem group by l_orderkey' --create "${LINEITEM}" -# SQL Expression - Filter -$CMD --sqlExpression 'l_orderkey > 10' --create "${LINEITEM}" +# SQL Expression - Literal expression +$CMD --expression '10' -# SQL Expression - Projection -$CMD --sqlExpression 'l_orderkey + 9888486986' --create "${LINEITEM}" +# SQL Expression - Reference expression +$CMD --expression 'l_suppkey' --create "${LINEITEM}" + +# SQL Expression - Filter expression +$CMD --expression 'l_orderkey > 10' --create "${LINEITEM}" + +# SQL Expression - Projection expression +$CMD --expression 'l_orderkey + 9888486986' --create "${LINEITEM}" From 0553366bb389cbfdef3b24cf594ac3be8a287d6d Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Thu, 25 Jan 2024 03:07:34 -0500 Subject: [PATCH 25/35] fix: test if error appear on sequential process --- gradle.properties | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gradle.properties b/gradle.properties index 6f6c93563..e1246e997 100644 --- a/gradle.properties +++ b/gradle.properties @@ -3,7 +3,7 @@ org.gradle.jvmargs=-XX:+UseG1GC -Xmx512m -XX:MaxMetaspaceSize=512m --add-exports --add-exports jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED \ --add-exports jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED \ --add-exports jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED -org.gradle.parallel=true +#org.gradle.parallel=true # Build cache can be disabled with --no-build-cache option org.gradle.caching=true From 370fa30dc063b0d306c9187c2fce366eb54f1e26 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Thu, 1 Feb 2024 22:04:27 -0500 Subject: [PATCH 26/35] fix: upgrade gradle build action version --- .github/workflows/pr.yml | 4 ++-- .github/workflows/release.yml | 4 ++-- gradle.properties | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index a5f529057..e8df6ad69 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -42,7 +42,7 @@ jobs: java-version: '17' distribution: 'adopt' - name: Setup Gradle - uses: gradle/gradle-build-action@v2 + uses: gradle/gradle-build-action@v3 - name: Build with Gradle run: gradle build isthmus-native-image-mac-linux: @@ -63,7 +63,7 @@ jobs: # helps avoid rate-limiting issues github-token: ${{ secrets.GITHUB_TOKEN }} - name: Setup Gradle - uses: gradle/gradle-build-action@v2 + uses: gradle/gradle-build-action@v3 - name: Report Java Version run: java -version - name: Install GraalVM native image diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 752617293..026fb999f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -29,7 +29,7 @@ jobs: # helps avoid rate-limiting issues github-token: ${{ secrets.GITHUB_TOKEN }} - name: Setup Gradle - uses: gradle/gradle-build-action@v2 + uses: gradle/gradle-build-action@v3 - name: Report Java Version run: java -version - name: Install GraalVM native image @@ -66,7 +66,7 @@ jobs: with: node-version: '20' - name: Setup Gradle - uses: gradle/gradle-build-action@v2 + uses: gradle/gradle-build-action@v3 - name: Download isthmus-ubuntu-latest binary uses: actions/download-artifact@v4 with: diff --git a/gradle.properties b/gradle.properties index e1246e997..6f6c93563 100644 --- a/gradle.properties +++ b/gradle.properties @@ -3,7 +3,7 @@ org.gradle.jvmargs=-XX:+UseG1GC -Xmx512m -XX:MaxMetaspaceSize=512m --add-exports --add-exports jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED \ --add-exports jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED \ --add-exports jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED -#org.gradle.parallel=true +org.gradle.parallel=true # Build cache can be disabled with --no-build-cache option org.gradle.caching=true From 6f9f6fb807af0999642163a27185d80fc879b322 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Thu, 1 Feb 2024 22:45:43 -0500 Subject: [PATCH 27/35] fix: delegate to gradle/actions/setup-gradle@v3 --- .github/workflows/pr.yml | 4 ++-- .github/workflows/release.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index e8df6ad69..3bf33dc1c 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -42,7 +42,7 @@ jobs: java-version: '17' distribution: 'adopt' - name: Setup Gradle - uses: gradle/gradle-build-action@v3 + uses: gradle/actions/setup-gradle@v3 - name: Build with Gradle run: gradle build isthmus-native-image-mac-linux: @@ -63,7 +63,7 @@ jobs: # helps avoid rate-limiting issues github-token: ${{ secrets.GITHUB_TOKEN }} - name: Setup Gradle - uses: gradle/gradle-build-action@v3 + uses: gradle/actions/setup-gradle@v3 - name: Report Java Version run: java -version - name: Install GraalVM native image diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 026fb999f..019c5e5ee 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -29,7 +29,7 @@ jobs: # helps avoid rate-limiting issues github-token: ${{ secrets.GITHUB_TOKEN }} - name: Setup Gradle - uses: gradle/gradle-build-action@v3 + uses: gradle/actions/setup-gradle@v3 - name: Report Java Version run: java -version - name: Install GraalVM native image @@ -66,7 +66,7 @@ jobs: with: node-version: '20' - name: Setup Gradle - uses: gradle/gradle-build-action@v3 + uses: gradle/actions/setup-gradle@v3 - name: Download isthmus-ubuntu-latest binary uses: actions/download-artifact@v4 with: From f5e743c07a1c0b068ecea4e79bc8f432e24115db Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Sat, 3 Feb 2024 00:42:45 -0500 Subject: [PATCH 28/35] fix: code review suggestion Co-authored-by: Dane Pitkin <48041712+danepitkin@users.noreply.github.com> --- .../src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java b/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java index dfaa5b7e5..7b8923753 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java +++ b/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java @@ -76,7 +76,7 @@ public static void main(String... args) { @Override public Integer call() throws Exception { FeatureBoard featureBoard = buildFeatureBoard(); - // Isthmus image is paring SQL Expression if that argument is defined + // Isthmus image is parsing SQL Expression if that argument is defined if (sqlExpression != null) { logger.debug(sqlExpression); logger.debug(String.valueOf(createStatements)); From 67a254338f6d2bacf11d25f4742e3334f7a8d11e Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Sat, 3 Feb 2024 03:13:31 -0500 Subject: [PATCH 29/35] fix: rename planEntryPoint to isthmusEntryPoint --- .../substrait/isthmus/IsthmusEntryPointTest.java | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/isthmus/src/test/java/io/substrait/isthmus/IsthmusEntryPointTest.java b/isthmus/src/test/java/io/substrait/isthmus/IsthmusEntryPointTest.java index 6c2998757..262b31fe3 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/IsthmusEntryPointTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/IsthmusEntryPointTest.java @@ -17,9 +17,9 @@ class IsthmusEntryPointTest { /** Test that the default values are set correctly into the {@link FeatureBoard}. */ @Test void defaultFeatureBoard() { - IsthmusEntryPoint planEntryPoint = new IsthmusEntryPoint(); - new CommandLine(planEntryPoint); - FeatureBoard features = planEntryPoint.buildFeatureBoard(); + IsthmusEntryPoint isthmusEntryPoint = new IsthmusEntryPoint(); + new CommandLine(isthmusEntryPoint); + FeatureBoard features = isthmusEntryPoint.buildFeatureBoard(); assertFalse(features.allowsSqlBatch()); assertEquals(SqlConformanceEnum.DEFAULT, features.sqlConformanceMode()); assertEquals(CrossJoinPolicy.KEEP_AS_CROSS_JOIN, features.crossJoinPolicy()); @@ -28,14 +28,14 @@ void defaultFeatureBoard() { /** Test that the command line options are correctly parsed into the {@link FeatureBoard}. */ @Test void customFeatureBoard() { - IsthmusEntryPoint planEntryPoint = new IsthmusEntryPoint(); - new CommandLine(planEntryPoint) + IsthmusEntryPoint isthmusEntryPoint = new IsthmusEntryPoint(); + new CommandLine(isthmusEntryPoint) .parseArgs( "--multistatement", "--sqlconformancemode=SQL_SERVER_2008", "--crossjoinpolicy=CONVERT_TO_INNER_JOIN", "SELECT * FROM foo"); - FeatureBoard features = planEntryPoint.buildFeatureBoard(); + FeatureBoard features = isthmusEntryPoint.buildFeatureBoard(); assertTrue(features.allowsSqlBatch()); assertEquals( (SqlConformance) SqlConformanceEnum.SQL_SERVER_2008, features.sqlConformanceMode()); @@ -47,11 +47,11 @@ void customFeatureBoard() { */ @Test void invalidCmdOptions() { - IsthmusEntryPoint planEntryPoint = new IsthmusEntryPoint(); + IsthmusEntryPoint isthmusEntryPoint = new IsthmusEntryPoint(); assertThrows( ParameterException.class, () -> - new CommandLine(planEntryPoint) + new CommandLine(isthmusEntryPoint) .parseArgs( "--sqlconformancemode=SQL_SERVER_2008", "--crossjoinpolicy=REWRITE_TO_INNER_JOIN")); From 0e1d142681135ede02c7507f119580c54a1ed28e Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Sat, 3 Feb 2024 23:12:25 -0500 Subject: [PATCH 30/35] feat: add support for help version image command options --- isthmus/README.md | 243 +++++++++++++++++- isthmus/build.gradle.kts | 3 +- .../substrait/isthmus/IsthmusEntryPoint.java | 32 ++- 3 files changed, 258 insertions(+), 20 deletions(-) diff --git a/isthmus/README.md b/isthmus/README.md index dd6742f72..37e0b2818 100644 --- a/isthmus/README.md +++ b/isthmus/README.md @@ -2,7 +2,7 @@ ## Overview -Substrait Isthmus is a Java library which enables serializing SQL to [Substrait Protobuf](https://substrait.io/serialization/binary_serialization/) via +Substrait Isthmus is a Java library which enables serializing SQL to [Substrait Protobuf](https://substrait.io/serialization/binary_serialization/) and SQL Expression to [Extended Expression](https://substrait.io/expressions/extended_expression/) via the Calcite SQL compiler. Optionally, you can leverage the Calcite RelNode to Substrait Plan translator as an IR translation. ## Build @@ -15,20 +15,51 @@ Isthmus can be built as a native executable via Graal ## Usage +### Version + +``` +$ ./isthmus/build/graal/isthmus --version + +isthmus 0.1 +``` + +### Help + ``` -$ ./isthmus/build/graal/isthmus +$ ./isthmus/build/graal/isthmus --help -Usage: isthmus [-m] [-c=]... -Converts a SQL query to a Substrait Plan - The sql we should parse. +Usage: isthmus [-hmV] [--crossjoinpolicy=] + [-e=] [--outputformat=] + [--sqlconformancemode=] + [-c=]... [] +Substrait Java Native Image for parsing SQL Query and SQL Expressions + [] The sql we should parse. -c, --create= One or multiple create table statements e.g. CREATE TABLE T1(foo int, bar bigint) + --crossjoinpolicy= + One of built-in Calcite SQL compatibility modes: + KEEP_AS_CROSS_JOIN, CONVERT_TO_INNER_JOIN + -e, --expression= + The sql expression we should parse. + -h, --help Show this help message and exit. -m, --multistatement Allow multiple statements terminated with a semicolon + --outputformat= + Set the output format for the generated plan: + PROTOJSON, PROTOTEXT, BINARY + --sqlconformancemode= + One of built-in Calcite SQL compatibility modes: + DEFAULT, LENIENT, BABEL, STRICT_92, STRICT_99, + PRAGMATIC_99, BIG_QUERY, MYSQL_5, ORACLE_10, + ORACLE_12, STRICT_2003, PRAGMATIC_2003, PRESTO, + SQL_SERVER_2008 + -V, --version Print version information and exit. ``` ## Example +### SQL to Substrait Plan + ``` > $ ./isthmus/build/graal/isthmus \ -c "CREATE TABLE Persons ( firstName VARCHAR, lastName VARCHAR, zip INT )" \ @@ -155,3 +186,205 @@ Converts a SQL query to a Substrait Plan "expectedTypeUrls": [] } ``` + +### SQL Expression to Substrait Extended Expression + +#### Projection + +``` +$ ./isthmus/build/graal/isthmus -c "CREATE TABLE NATION (N_NATIONKEY BIGINT NOT NULL, N_NAME CHAR(25), N_REGIONKEY BIGINT NOT NULL, N_COMMENT VARCHAR(152))" \ + -e "N_REGIONKEY + 10" + +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "add:i64_i64" + } + }], + "referredExpr": [{ + "expression": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "i32": 10, + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + }], + "options": [] + } + }, + "outputNames": ["new-column"] + }], + "baseSchema": { + "names": ["N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "varchar": { + "length": 152, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "expectedTypeUrls": [] +} +``` + +#### Filter + +``` +$ ./isthmus/build/graal/isthmus -c "CREATE TABLE NATION (N_NATIONKEY BIGINT NOT NULL, N_NAME CHAR(25), N_REGIONKEY BIGINT NOT NULL, N_COMMENT VARCHAR(152))" \ + -e "N_REGIONKEY > 10" + +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "gt:any_any" + } + }], + "referredExpr": [{ + "expression": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "i32": 10, + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + }], + "options": [] + } + }, + "outputNames": ["new-column"] + }], + "baseSchema": { + "names": ["N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "varchar": { + "length": 152, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "expectedTypeUrls": [] +} +``` diff --git a/isthmus/build.gradle.kts b/isthmus/build.gradle.kts index 707a4b8e5..cf3f8b332 100644 --- a/isthmus/build.gradle.kts +++ b/isthmus/build.gradle.kts @@ -81,7 +81,8 @@ dependencies { implementation("org.reflections:reflections:0.9.12") implementation("com.google.guava:guava:29.0-jre") implementation("org.graalvm.sdk:graal-sdk:22.1.0") - implementation("info.picocli:picocli:4.6.1") + implementation("info.picocli:picocli:4.7.5") + annotationProcessor("info.picocli:picocli-codegen:4.7.5") implementation("com.fasterxml.jackson.core:jackson-databind:2.13.4") implementation("com.fasterxml.jackson.core:jackson-annotations:2.13.4") implementation("com.fasterxml.jackson.datatype:jackson-datatype-jdk8:2.13.4") diff --git a/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java b/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java index 7b8923753..1eefa8032 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java +++ b/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java @@ -11,7 +11,6 @@ import io.substrait.isthmus.SubstraitRelVisitor.CrossJoinPolicy; import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; -import java.util.Arrays; import java.util.List; import java.util.concurrent.Callable; import org.apache.calcite.sql.validate.SqlConformanceEnum; @@ -20,10 +19,9 @@ @Command( name = "isthmus", version = "isthmus 0.1", - description = "Substrait Java Native Image for parsing SQL Query and SQL Expressions") + description = "Substrait Java Native Image for parsing SQL Query and SQL Expressions", + mixinStandardHelpOptions = true) public class IsthmusEntryPoint implements Callable { - static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(IsthmusEntryPoint.class); - @Parameters(index = "0", arity = "0..1", description = "The sql we should parse.") private String sql; @@ -65,11 +63,23 @@ enum OutputFormat { description = "One of built-in Calcite SQL compatibility modes: ${COMPLETION-CANDIDATES}") private CrossJoinPolicy crossJoinPolicy = CrossJoinPolicy.KEEP_AS_CROSS_JOIN; - // this example implements Callable, so parsing, error handling and handling user - // requests for usage help or version help can be done with one line of code. public static void main(String... args) { - logger.debug(Arrays.toString(args)); - int exitCode = new CommandLine(new IsthmusEntryPoint()).execute(args); + CommandLine commandLine = new CommandLine(new IsthmusEntryPoint()); + commandLine.setCaseInsensitiveEnumValuesAllowed(true); + CommandLine.ParseResult parseResult = commandLine.parseArgs(args); + if (parseResult.originalArgs().isEmpty()) { // If no arguments print usage help + commandLine.usage(System.out); + System.exit(0); + } + if (commandLine.isUsageHelpRequested()) { + commandLine.usage(System.out); + System.exit(0); + } + if (commandLine.isVersionHelpRequested()) { + commandLine.printVersionHelp(System.out); + System.exit(0); + } + int exitCode = commandLine.execute(args); System.exit(exitCode); } @@ -78,13 +88,9 @@ public Integer call() throws Exception { FeatureBoard featureBoard = buildFeatureBoard(); // Isthmus image is parsing SQL Expression if that argument is defined if (sqlExpression != null) { - logger.debug(sqlExpression); - logger.debug(String.valueOf(createStatements)); SqlExpressionToSubstrait converter = new SqlExpressionToSubstrait(featureBoard, SimpleExtension.loadDefaults()); ExtendedExpression extendedExpression = converter.convert(sqlExpression, createStatements); - System.out.println( - JsonFormat.printer().includingDefaultValueFields().print(extendedExpression)); switch (outputFormat) { case PROTOJSON -> System.out.println( JsonFormat.printer().includingDefaultValueFields().print(extendedExpression)); @@ -92,8 +98,6 @@ public Integer call() throws Exception { case BINARY -> extendedExpression.writeTo(System.out); } } else { // by default Isthmus image are parsing SQL Query - logger.debug(sql); - logger.debug(String.valueOf(createStatements)); SqlToSubstrait converter = new SqlToSubstrait(featureBoard); Plan plan = converter.execute(sql, createStatements); switch (outputFormat) { From 123d85917db515f20861652f718566f63e42a99f Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Sun, 4 Feb 2024 19:38:29 -0500 Subject: [PATCH 31/35] feat: enable support for N expression + refacto Isthmus CLI --- .../substrait/isthmus/IsthmusEntryPoint.java | 56 +++++++++++------- .../isthmus/SqlExpressionToSubstrait.java | 57 +++++++++++++++---- .../SimpleExtendedExpressionsTest.java | 5 +- isthmus/src/test/script/smoke.sh | 5 +- 4 files changed, 89 insertions(+), 34 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java b/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java index 1eefa8032..80bcd627b 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java +++ b/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java @@ -5,14 +5,18 @@ import static picocli.CommandLine.Parameters; import com.google.common.annotations.VisibleForTesting; +import com.google.protobuf.Message; import com.google.protobuf.TextFormat; import com.google.protobuf.util.JsonFormat; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.SubstraitRelVisitor.CrossJoinPolicy; import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; +import java.io.IOException; +import java.util.Arrays; import java.util.List; import java.util.concurrent.Callable; +import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.validate.SqlConformanceEnum; import picocli.CommandLine; @@ -83,33 +87,45 @@ public static void main(String... args) { System.exit(exitCode); } + private FeatureBoard featureBoard; + @Override public Integer call() throws Exception { - FeatureBoard featureBoard = buildFeatureBoard(); - // Isthmus image is parsing SQL Expression if that argument is defined + this.featureBoard = buildFeatureBoard(); if (sqlExpression != null) { - SqlExpressionToSubstrait converter = - new SqlExpressionToSubstrait(featureBoard, SimpleExtension.loadDefaults()); - ExtendedExpression extendedExpression = converter.convert(sqlExpression, createStatements); - switch (outputFormat) { - case PROTOJSON -> System.out.println( - JsonFormat.printer().includingDefaultValueFields().print(extendedExpression)); - case PROTOTEXT -> TextFormat.printer().print(extendedExpression, System.out); - case BINARY -> extendedExpression.writeTo(System.out); - } - } else { // by default Isthmus image are parsing SQL Query - SqlToSubstrait converter = new SqlToSubstrait(featureBoard); - Plan plan = converter.execute(sql, createStatements); - switch (outputFormat) { - case PROTOJSON -> System.out.println( - JsonFormat.printer().includingDefaultValueFields().print(plan)); - case PROTOTEXT -> TextFormat.printer().print(plan, System.out); - case BINARY -> plan.writeTo(System.out); - } + handleSQLExpression(); + } else { + handleSQLPlan(); } return 0; } + private void handleSQLExpression() throws SqlParseException, IOException { + ExtendedExpression extendedExpression = createExpression(); + printExpression(extendedExpression); + } + + private void handleSQLPlan() throws SqlParseException, IOException { + SqlToSubstrait converter = new SqlToSubstrait(featureBoard); + Plan plan = converter.execute(sql, createStatements); + printExpression(plan); + } + + private ExtendedExpression createExpression() throws IOException, SqlParseException { + SqlExpressionToSubstrait converter = + new SqlExpressionToSubstrait(featureBoard, SimpleExtension.loadDefaults()); + return converter.convert(Arrays.asList(sqlExpression.split(",")), createStatements); + } + + private void printExpression(Message message) throws IOException { + switch (outputFormat) { + case PROTOJSON -> System.out.println( + JsonFormat.printer().includingDefaultValueFields().print(message)); + case PROTOTEXT -> TextFormat.printer().print(message, System.out); + case BINARY -> message.writeTo(System.out); + } + } + @VisibleForTesting FeatureBoard buildFeatureBoard() { return ImmutableFeatureBoard.builder() diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java index 4306d9603..ce96a0ac5 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -11,6 +11,7 @@ import io.substrait.type.NamedStruct; import io.substrait.type.Type; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; @@ -70,6 +71,25 @@ public ExtendedExpression convert(String sqlExpression, List createState result.nameToNodeMap()); } + /** + * Converts the given SQL expressions string to an {@link io.substrait.proto.ExtendedExpression } + * + * @param sqlExpressions a List of SQL expression + * @param createStatements table creation statements defining fields referenced by the expression + * @return a {@link io.substrait.proto.ExtendedExpression } + * @throws SqlParseException + */ + public ExtendedExpression convert(List sqlExpressions, List createStatements) + throws SqlParseException { + var result = registerCreateTablesForExtendedExpression(createStatements); + return executeInnerSQLExpressions( + sqlExpressions, + result.validator(), + result.catalogReader(), + result.nameToTypeMap(), + result.nameToNodeMap()); + } + private ExtendedExpression executeInnerSQLExpression( String sqlExpression, SqlValidator validator, @@ -77,20 +97,35 @@ private ExtendedExpression executeInnerSQLExpression( Map nameToTypeMap, Map nameToNodeMap) throws SqlParseException { - RexNode rexNode = - sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap); - NamedStruct namedStruct = toNamedStruct(nameToTypeMap); - - ImmutableExpressionReference expressionReference = - ImmutableExpressionReference.builder() - .expression(rexNode.accept(this.rexConverter)) - .addOutputNames("new-column") - .build(); + return executeInnerSQLExpressions( + Collections.singletonList(sqlExpression), + validator, + catalogReader, + nameToTypeMap, + nameToNodeMap); + } + private ExtendedExpression executeInnerSQLExpressions( + List sqlExpressions, + SqlValidator validator, + CalciteCatalogReader catalogReader, + Map nameToTypeMap, + Map nameToNodeMap) + throws SqlParseException { + int columnIndex = 1; List expressionReferences = new ArrayList<>(); - expressionReferences.add(expressionReference); - + RexNode rexNode; + for (String sqlExpression : sqlExpressions) { + rexNode = sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap); + ImmutableExpressionReference expressionReference = + ImmutableExpressionReference.builder() + .expression(rexNode.accept(this.rexConverter)) + .addOutputNames("column-" + columnIndex++) + .build(); + expressionReferences.add(expressionReference); + } + NamedStruct namedStruct = toNamedStruct(nameToTypeMap); ImmutableExtendedExpression.Builder extendedExpression = ImmutableExtendedExpression.builder() .referredExpressions(expressionReferences) diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java index 9b6afb457..5a622f769 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java @@ -20,8 +20,9 @@ private static Stream expressionTypeProvider() { Arguments.of("L_ORDERKEY + 10"), // ScalarFunctionExpressionProjection Arguments.of("L_ORDERKEY IN (10, 20)"), // ScalarFunctionExpressionIn Arguments.of("L_ORDERKEY is not null"), // ScalarFunctionExpressionIsNotNull - Arguments.of("L_ORDERKEY is null") // ScalarFunctionExpressionIsNull - ); + Arguments.of("L_ORDERKEY is null"), // ScalarFunctionExpressionIsNull + Arguments.of("L_ORDERKEY + 10", "L_ORDERKEY * 2"), + Arguments.of("L_ORDERKEY + 10", "L_ORDERKEY * 2", "L_ORDERKEY > 10")); } @ParameterizedTest diff --git a/isthmus/src/test/script/smoke.sh b/isthmus/src/test/script/smoke.sh index fbea5289a..76859204e 100755 --- a/isthmus/src/test/script/smoke.sh +++ b/isthmus/src/test/script/smoke.sh @@ -24,5 +24,8 @@ $CMD --expression 'l_suppkey' --create "${LINEITEM}" # SQL Expression - Filter expression $CMD --expression 'l_orderkey > 10' --create "${LINEITEM}" -# SQL Expression - Projection expression +# SQL Expression - Projection expression (column-1) $CMD --expression 'l_orderkey + 9888486986' --create "${LINEITEM}" + +# SQL Expression - 03 Projection expression (column-1, column-2, column-3) +$CMD --expression 'l_orderkey + 9888486986, l_orderkey * 2, l_orderkey > 10' --create "${LINEITEM}" From 10e52c7ca08d53b1ee8129577b9918414fda8ada Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Mon, 5 Feb 2024 12:08:06 -0500 Subject: [PATCH 32/35] feat: allow to execute more than one epression + testing custom separator --- isthmus/README.md | 5 +- .../substrait/isthmus/IsthmusEntryPoint.java | 44 +++++++---------- .../isthmus/SqlExpressionToSubstrait.java | 44 ++++++++--------- .../isthmus/ExtendedExpressionTestBase.java | 49 ++++++++++--------- .../SimpleExtendedExpressionsTest.java | 44 ++++++++++++++--- isthmus/src/test/script/smoke.sh | 3 ++ 6 files changed, 108 insertions(+), 81 deletions(-) diff --git a/isthmus/README.md b/isthmus/README.md index 37e0b2818..a0e4c8aad 100644 --- a/isthmus/README.md +++ b/isthmus/README.md @@ -29,7 +29,8 @@ isthmus 0.1 $ ./isthmus/build/graal/isthmus --help Usage: isthmus [-hmV] [--crossjoinpolicy=] - [-e=] [--outputformat=] + [-e=] [-es=] + [--outputformat=] [--sqlconformancemode=] [-c=]... [] Substrait Java Native Image for parsing SQL Query and SQL Expressions @@ -42,6 +43,8 @@ Substrait Java Native Image for parsing SQL Query and SQL Expressions KEEP_AS_CROSS_JOIN, CONVERT_TO_INNER_JOIN -e, --expression= The sql expression we should parse. + -es, --separator= + The separator for the sql expressions. -h, --help Show this help message and exit. -m, --multistatement Allow multiple statements terminated with a semicolon --outputformat= diff --git a/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java b/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java index 80bcd627b..50173de47 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java +++ b/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java @@ -13,10 +13,8 @@ import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import java.io.IOException; -import java.util.Arrays; import java.util.List; import java.util.concurrent.Callable; -import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.validate.SqlConformanceEnum; import picocli.CommandLine; @@ -34,6 +32,12 @@ public class IsthmusEntryPoint implements Callable { description = "The sql expression we should parse.") private String sqlExpression; + @Option( + names = {"-es", "--separator"}, + defaultValue = ",", + description = "The separator for the sql expressions.") + private String sqlExpressionSeparator; + @Option( names = {"-c", "--create"}, description = @@ -87,37 +91,25 @@ public static void main(String... args) { System.exit(exitCode); } - private FeatureBoard featureBoard; - @Override public Integer call() throws Exception { - this.featureBoard = buildFeatureBoard(); + FeatureBoard featureBoard = buildFeatureBoard(); + // Isthmus image is parsing SQL Expression if that argument is defined if (sqlExpression != null) { - handleSQLExpression(); - } else { - handleSQLPlan(); + SqlExpressionToSubstrait converter = + new SqlExpressionToSubstrait(featureBoard, SimpleExtension.loadDefaults()); + ExtendedExpression extendedExpression = + converter.convert(sqlExpression, sqlExpressionSeparator, createStatements); + printMessage(extendedExpression); + } else { // by default Isthmus image are parsing SQL Query + SqlToSubstrait converter = new SqlToSubstrait(featureBoard); + Plan plan = converter.execute(sql, createStatements); + printMessage(plan); } return 0; } - private void handleSQLExpression() throws SqlParseException, IOException { - ExtendedExpression extendedExpression = createExpression(); - printExpression(extendedExpression); - } - - private void handleSQLPlan() throws SqlParseException, IOException { - SqlToSubstrait converter = new SqlToSubstrait(featureBoard); - Plan plan = converter.execute(sql, createStatements); - printExpression(plan); - } - - private ExtendedExpression createExpression() throws IOException, SqlParseException { - SqlExpressionToSubstrait converter = - new SqlExpressionToSubstrait(featureBoard, SimpleExtension.loadDefaults()); - return converter.convert(Arrays.asList(sqlExpression.split(",")), createStatements); - } - - private void printExpression(Message message) throws IOException { + private void printMessage(Message message) throws IOException { switch (outputFormat) { case PROTOJSON -> System.out.println( JsonFormat.printer().includingDefaultValueFields().print(message)); diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java index ce96a0ac5..9a75a3382 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -11,7 +11,7 @@ import io.substrait.type.NamedStruct; import io.substrait.type.Type; import java.util.ArrayList; -import java.util.Collections; +import java.util.Arrays; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; @@ -62,13 +62,22 @@ private record Result( */ public ExtendedExpression convert(String sqlExpression, List createStatements) throws SqlParseException { - var result = registerCreateTablesForExtendedExpression(createStatements); - return executeInnerSQLExpression( - sqlExpression, - result.validator(), - result.catalogReader(), - result.nameToTypeMap(), - result.nameToNodeMap()); + return convert(sqlExpression, ",", createStatements); + } + + /** + * Converts the given SQL expression string to an {@link io.substrait.proto.ExtendedExpression } + * + * @param sqlExpression a SQL expression + * @param separator the separator for the sql expressions + * @param createStatements table creation statements defining fields referenced by the expression + * @return a {@link io.substrait.proto.ExtendedExpression } + * @throws SqlParseException + */ + public ExtendedExpression convert( + String sqlExpression, String separator, List createStatements) + throws SqlParseException { + return convert(Arrays.asList(sqlExpression.split(separator)), createStatements); } /** @@ -90,21 +99,6 @@ public ExtendedExpression convert(List sqlExpressions, List crea result.nameToNodeMap()); } - private ExtendedExpression executeInnerSQLExpression( - String sqlExpression, - SqlValidator validator, - CalciteCatalogReader catalogReader, - Map nameToTypeMap, - Map nameToNodeMap) - throws SqlParseException { - return executeInnerSQLExpressions( - Collections.singletonList(sqlExpression), - validator, - catalogReader, - nameToTypeMap, - nameToNodeMap); - } - private ExtendedExpression executeInnerSQLExpressions( List sqlExpressions, SqlValidator validator, @@ -117,7 +111,9 @@ private ExtendedExpression executeInnerSQLExpressions( expressionReferences = new ArrayList<>(); RexNode rexNode; for (String sqlExpression : sqlExpressions) { - rexNode = sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap); + rexNode = + sqlToRexNode( + sqlExpression.trim(), validator, catalogReader, nameToTypeMap, nameToNodeMap); ImmutableExpressionReference expressionReference = ImmutableExpressionReference.builder() .expression(rexNode.accept(this.rexConverter)) diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java index d47abcc77..fe29d6215 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java @@ -27,35 +27,42 @@ public static List tpchSchemaCreateStatements() throws IOException { return tpchSchemaCreateStatements("tpch/schema.sql"); } - protected ExtendedExpression assertProtoExtendedExpressionRoundtrip(String query) - throws IOException, SqlParseException { - return assertProtoExtendedExpressionRoundtrip(query, new SqlExpressionToSubstrait()); - } - - protected ExtendedExpression assertProtoExtendedExpressionRoundtrip( - String query, String schemaToLoad) throws IOException, SqlParseException { - return assertProtoExtendedExpressionRoundtrip( - query, new SqlExpressionToSubstrait(), schemaToLoad); + protected void assertProtoEEForExpressionsDefaultCommaSeparatorRoundtrip(String expressions) + throws SqlParseException, IOException { + // proto initial extended expression + ExtendedExpression extendedExpressionProtoInitial = + new SqlExpressionToSubstrait().convert(expressions, tpchSchemaCreateStatements()); + asserProtoExtendedExpression(extendedExpressionProtoInitial); } - protected ExtendedExpression assertProtoExtendedExpressionRoundtrip( - String query, SqlExpressionToSubstrait s) throws IOException, SqlParseException { - return assertProtoExtendedExpressionRoundtrip(query, s, tpchSchemaCreateStatements()); + protected void assertProtoEEForExpressionsDefaultCommaSeparatorErrorRoundtrip( + String expressions, String schemaToLoad) throws SqlParseException, IOException { + // proto initial extended expression + ExtendedExpression extendedExpressionProtoInitial = + new SqlExpressionToSubstrait() + .convert(expressions, tpchSchemaCreateStatements(schemaToLoad)); + asserProtoExtendedExpression(extendedExpressionProtoInitial); } - protected ExtendedExpression assertProtoExtendedExpressionRoundtrip( - String query, SqlExpressionToSubstrait s, String schemaToLoad) - throws IOException, SqlParseException { - return assertProtoExtendedExpressionRoundtrip( - query, s, tpchSchemaCreateStatements(schemaToLoad)); + protected void assertProtoEEForExpressionsCustomSeparatorRoundtrip( + String expressions, String separator) throws SqlParseException, IOException { + // proto initial extended expression + ExtendedExpression extendedExpressionProtoInitial = + new SqlExpressionToSubstrait() + .convert(expressions, separator, tpchSchemaCreateStatements()); + asserProtoExtendedExpression(extendedExpressionProtoInitial); } - protected ExtendedExpression assertProtoExtendedExpressionRoundtrip( - String query, SqlExpressionToSubstrait s, List creates) + protected void assertProtoEEForListExpressionRoundtrip(List expression) throws SqlParseException, IOException { // proto initial extended expression - ExtendedExpression extendedExpressionProtoInitial = s.convert(query, creates); + ExtendedExpression extendedExpressionProtoInitial = + new SqlExpressionToSubstrait().convert(expression, tpchSchemaCreateStatements()); + asserProtoExtendedExpression(extendedExpressionProtoInitial); + } + private static void asserProtoExtendedExpression( + ExtendedExpression extendedExpressionProtoInitial) throws IOException { // pojo final extended expression io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal = new ProtoExtendedExpressionConverter().from(extendedExpressionProtoInitial); @@ -66,7 +73,5 @@ protected ExtendedExpression assertProtoExtendedExpressionRoundtrip( // round-trip to validate extended expression proto initial equals to final Assertions.assertEquals(extendedExpressionProtoFinal, extendedExpressionProtoInitial); - - return extendedExpressionProtoInitial; } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java index 5a622f769..4c38b267b 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java @@ -4,8 +4,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import java.io.IOException; +import java.util.Arrays; +import java.util.List; import java.util.stream.Stream; import org.apache.calcite.sql.parser.SqlParseException; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @@ -18,30 +21,55 @@ private static Stream expressionTypeProvider() { Arguments.of("L_ORDERKEY"), // FieldReferenceExpression Arguments.of("L_ORDERKEY > 10"), // ScalarFunctionExpressionFilter Arguments.of("L_ORDERKEY + 10"), // ScalarFunctionExpressionProjection - Arguments.of("L_ORDERKEY IN (10, 20)"), // ScalarFunctionExpressionIn + Arguments.of("L_ORDERKEY IN (10)"), // ScalarFunctionExpressionIn Arguments.of("L_ORDERKEY is not null"), // ScalarFunctionExpressionIsNotNull - Arguments.of("L_ORDERKEY is null"), // ScalarFunctionExpressionIsNull - Arguments.of("L_ORDERKEY + 10", "L_ORDERKEY * 2"), - Arguments.of("L_ORDERKEY + 10", "L_ORDERKEY * 2", "L_ORDERKEY > 10")); + Arguments.of("L_ORDERKEY is null")); // ScalarFunctionExpressionIsNull } @ParameterizedTest @MethodSource("expressionTypeProvider") - public void testExtendedExpressionsRoundTrip(String sqlExpression) + public void testExtendedExpressionsCommaSeparatorRoundTrip(String sqlExpression) throws SqlParseException, IOException { - assertProtoExtendedExpressionRoundtrip(sqlExpression); + assertProtoEEForExpressionsDefaultCommaSeparatorRoundtrip( + sqlExpression); // comma-separator by default } @ParameterizedTest @MethodSource("expressionTypeProvider") - public void testExtendedExpressionsRoundTripDuplicateColumnIdentifier(String sqlExpression) { + public void testExtendedExpressionsDuplicateColumnIdentifierRoundTrip(String sqlExpression) { IllegalArgumentException illegalArgumentException = assertThrows( IllegalArgumentException.class, - () -> assertProtoExtendedExpressionRoundtrip(sqlExpression, "tpch/schema_error.sql")); + () -> + assertProtoEEForExpressionsDefaultCommaSeparatorErrorRoundtrip( + sqlExpression, "tpch/schema_error.sql")); assertTrue( illegalArgumentException .getMessage() .startsWith("There is no support for duplicate column names")); } + + @Test + public void testExtendedExpressionsCustomSeparatorRoundTrip() + throws SqlParseException, IOException { + String expressions = + "2#L_ORDERKEY#L_ORDERKEY > 10#L_ORDERKEY + 10#L_ORDERKEY IN (10, 20)#L_ORDERKEY is not null#L_ORDERKEY is null"; + String separator = "#"; + assertProtoEEForExpressionsCustomSeparatorRoundtrip(expressions, separator); + } + + @Test + public void testExtendedExpressionsListExpressionRoundTrip() + throws SqlParseException, IOException { + List expressions = + Arrays.asList( + "2", + "L_ORDERKEY", + "L_ORDERKEY > 10", + "L_ORDERKEY + 10", + "L_ORDERKEY IN (10, 20)", // the comma won't cause any problems + "L_ORDERKEY is not null", + "L_ORDERKEY is null"); + assertProtoEEForListExpressionRoundtrip(expressions); + } } diff --git a/isthmus/src/test/script/smoke.sh b/isthmus/src/test/script/smoke.sh index 76859204e..f64f7417d 100755 --- a/isthmus/src/test/script/smoke.sh +++ b/isthmus/src/test/script/smoke.sh @@ -29,3 +29,6 @@ $CMD --expression 'l_orderkey + 9888486986' --create "${LINEITEM}" # SQL Expression - 03 Projection expression (column-1, column-2, column-3) $CMD --expression 'l_orderkey + 9888486986, l_orderkey * 2, l_orderkey > 10' --create "${LINEITEM}" + +# SQL Expression - 03 Projection expression (column-1, column-2, column-3) with custom seprator +$CMD --expression 'l_orderkey + 9888486986 # l_orderkey * 2 # l_orderkey > 10' --create "${LINEITEM}" --separator "#" From 5cd32d66d1a531391f6da8766373143f93c6172b Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Mon, 5 Feb 2024 18:32:51 -0500 Subject: [PATCH 33/35] fix: delete separator of expressions. Migrate to use expression as String[] --- isthmus/README.md | 10 +++--- .../substrait/isthmus/IsthmusEntryPoint.java | 16 +++------- .../isthmus/SqlExpressionToSubstrait.java | 24 +++----------- .../isthmus/ExtendedExpressionTestBase.java | 11 +------ .../SimpleExtendedExpressionsTest.java | 32 +++++++------------ isthmus/src/test/script/smoke.sh | 5 +-- 6 files changed, 26 insertions(+), 72 deletions(-) diff --git a/isthmus/README.md b/isthmus/README.md index a0e4c8aad..212dd44af 100644 --- a/isthmus/README.md +++ b/isthmus/README.md @@ -29,10 +29,9 @@ isthmus 0.1 $ ./isthmus/build/graal/isthmus --help Usage: isthmus [-hmV] [--crossjoinpolicy=] - [-e=] [-es=] [--outputformat=] [--sqlconformancemode=] - [-c=]... [] + [-c=]... [-e=...]... [] Substrait Java Native Image for parsing SQL Query and SQL Expressions [] The sql we should parse. -c, --create= @@ -41,10 +40,8 @@ Substrait Java Native Image for parsing SQL Query and SQL Expressions --crossjoinpolicy= One of built-in Calcite SQL compatibility modes: KEEP_AS_CROSS_JOIN, CONVERT_TO_INNER_JOIN - -e, --expression= - The sql expression we should parse. - -es, --separator= - The separator for the sql expressions. + -e, --expression=... + One or more SQL expressions e.g. col + 1 -h, --help Show this help message and exit. -m, --multistatement Allow multiple statements terminated with a semicolon --outputformat= @@ -57,6 +54,7 @@ Substrait Java Native Image for parsing SQL Query and SQL Expressions ORACLE_12, STRICT_2003, PRAGMATIC_2003, PRESTO, SQL_SERVER_2008 -V, --version Print version information and exit. + ``` ## Example diff --git a/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java b/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java index 50173de47..739ba8c77 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java +++ b/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java @@ -29,14 +29,9 @@ public class IsthmusEntryPoint implements Callable { @Option( names = {"-e", "--expression"}, - description = "The sql expression we should parse.") - private String sqlExpression; - - @Option( - names = {"-es", "--separator"}, - defaultValue = ",", - description = "The separator for the sql expressions.") - private String sqlExpressionSeparator; + arity = "1..*", + description = "One or more SQL expressions e.g. col + 1") + private String[] sqlExpressions; @Option( names = {"-c", "--create"}, @@ -95,11 +90,10 @@ public static void main(String... args) { public Integer call() throws Exception { FeatureBoard featureBoard = buildFeatureBoard(); // Isthmus image is parsing SQL Expression if that argument is defined - if (sqlExpression != null) { + if (sqlExpressions != null) { SqlExpressionToSubstrait converter = new SqlExpressionToSubstrait(featureBoard, SimpleExtension.loadDefaults()); - ExtendedExpression extendedExpression = - converter.convert(sqlExpression, sqlExpressionSeparator, createStatements); + ExtendedExpression extendedExpression = converter.convert(sqlExpressions, createStatements); printMessage(extendedExpression); } else { // by default Isthmus image are parsing SQL Query SqlToSubstrait converter = new SqlToSubstrait(featureBoard); diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java index 9a75a3382..649a5039f 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -11,7 +11,6 @@ import io.substrait.type.NamedStruct; import io.substrait.type.Type; import java.util.ArrayList; -import java.util.Arrays; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; @@ -62,33 +61,18 @@ private record Result( */ public ExtendedExpression convert(String sqlExpression, List createStatements) throws SqlParseException { - return convert(sqlExpression, ",", createStatements); - } - - /** - * Converts the given SQL expression string to an {@link io.substrait.proto.ExtendedExpression } - * - * @param sqlExpression a SQL expression - * @param separator the separator for the sql expressions - * @param createStatements table creation statements defining fields referenced by the expression - * @return a {@link io.substrait.proto.ExtendedExpression } - * @throws SqlParseException - */ - public ExtendedExpression convert( - String sqlExpression, String separator, List createStatements) - throws SqlParseException { - return convert(Arrays.asList(sqlExpression.split(separator)), createStatements); + return convert(new String[] {sqlExpression}, createStatements); } /** * Converts the given SQL expressions string to an {@link io.substrait.proto.ExtendedExpression } * - * @param sqlExpressions a List of SQL expression + * @param sqlExpressions an array of SQL expression * @param createStatements table creation statements defining fields referenced by the expression * @return a {@link io.substrait.proto.ExtendedExpression } * @throws SqlParseException */ - public ExtendedExpression convert(List sqlExpressions, List createStatements) + public ExtendedExpression convert(String[] sqlExpressions, List createStatements) throws SqlParseException { var result = registerCreateTablesForExtendedExpression(createStatements); return executeInnerSQLExpressions( @@ -100,7 +84,7 @@ public ExtendedExpression convert(List sqlExpressions, List crea } private ExtendedExpression executeInnerSQLExpressions( - List sqlExpressions, + String[] sqlExpressions, SqlValidator validator, CalciteCatalogReader catalogReader, Map nameToTypeMap, diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java index fe29d6215..e5d9abf94 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java @@ -44,16 +44,7 @@ protected void assertProtoEEForExpressionsDefaultCommaSeparatorErrorRoundtrip( asserProtoExtendedExpression(extendedExpressionProtoInitial); } - protected void assertProtoEEForExpressionsCustomSeparatorRoundtrip( - String expressions, String separator) throws SqlParseException, IOException { - // proto initial extended expression - ExtendedExpression extendedExpressionProtoInitial = - new SqlExpressionToSubstrait() - .convert(expressions, separator, tpchSchemaCreateStatements()); - asserProtoExtendedExpression(extendedExpressionProtoInitial); - } - - protected void assertProtoEEForListExpressionRoundtrip(List expression) + protected void assertProtoEEForListExpressionRoundtrip(String[] expression) throws SqlParseException, IOException { // proto initial extended expression ExtendedExpression extendedExpressionProtoInitial = diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java index 4c38b267b..3c5e6f4fb 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java @@ -4,8 +4,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import java.io.IOException; -import java.util.Arrays; -import java.util.List; import java.util.stream.Stream; import org.apache.calcite.sql.parser.SqlParseException; import org.junit.jupiter.api.Test; @@ -21,7 +19,7 @@ private static Stream expressionTypeProvider() { Arguments.of("L_ORDERKEY"), // FieldReferenceExpression Arguments.of("L_ORDERKEY > 10"), // ScalarFunctionExpressionFilter Arguments.of("L_ORDERKEY + 10"), // ScalarFunctionExpressionProjection - Arguments.of("L_ORDERKEY IN (10)"), // ScalarFunctionExpressionIn + Arguments.of("L_ORDERKEY IN (10, 20)"), // ScalarFunctionExpressionIn Arguments.of("L_ORDERKEY is not null"), // ScalarFunctionExpressionIsNotNull Arguments.of("L_ORDERKEY is null")); // ScalarFunctionExpressionIsNull } @@ -49,27 +47,19 @@ public void testExtendedExpressionsDuplicateColumnIdentifierRoundTrip(String sql .startsWith("There is no support for duplicate column names")); } - @Test - public void testExtendedExpressionsCustomSeparatorRoundTrip() - throws SqlParseException, IOException { - String expressions = - "2#L_ORDERKEY#L_ORDERKEY > 10#L_ORDERKEY + 10#L_ORDERKEY IN (10, 20)#L_ORDERKEY is not null#L_ORDERKEY is null"; - String separator = "#"; - assertProtoEEForExpressionsCustomSeparatorRoundtrip(expressions, separator); - } - @Test public void testExtendedExpressionsListExpressionRoundTrip() throws SqlParseException, IOException { - List expressions = - Arrays.asList( - "2", - "L_ORDERKEY", - "L_ORDERKEY > 10", - "L_ORDERKEY + 10", - "L_ORDERKEY IN (10, 20)", // the comma won't cause any problems - "L_ORDERKEY is not null", - "L_ORDERKEY is null"); + String[] expressions = { + "2", + "L_ORDERKEY", + "L_ORDERKEY > 10", + "L_ORDERKEY + 10", + "L_ORDERKEY IN (10, 20)", + "L_ORDERKEY is not null", + "L_ORDERKEY is null" + }; + assertProtoEEForListExpressionRoundtrip(expressions); } } diff --git a/isthmus/src/test/script/smoke.sh b/isthmus/src/test/script/smoke.sh index f64f7417d..7ecbc2ec6 100755 --- a/isthmus/src/test/script/smoke.sh +++ b/isthmus/src/test/script/smoke.sh @@ -28,7 +28,4 @@ $CMD --expression 'l_orderkey > 10' --create "${LINEITEM}" $CMD --expression 'l_orderkey + 9888486986' --create "${LINEITEM}" # SQL Expression - 03 Projection expression (column-1, column-2, column-3) -$CMD --expression 'l_orderkey + 9888486986, l_orderkey * 2, l_orderkey > 10' --create "${LINEITEM}" - -# SQL Expression - 03 Projection expression (column-1, column-2, column-3) with custom seprator -$CMD --expression 'l_orderkey + 9888486986 # l_orderkey * 2 # l_orderkey > 10' --create "${LINEITEM}" --separator "#" +$CMD --expression 'l_orderkey + 9888486986' 'l_orderkey * 2' 'l_orderkey > 10' 'l_orderkey in (10, 20)' --create "${LINEITEM}" From 908d5e322ff7e718bae12c4d2dca7de4d652e425 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Mon, 5 Feb 2024 16:21:18 -0800 Subject: [PATCH 34/35] refactor: small PR changes --- isthmus/README.md | 7 +++---- .../substrait/isthmus/IsthmusEntryPoint.java | 4 ++-- .../isthmus/SqlExpressionToSubstrait.java | 6 +++--- .../isthmus/ExtendedExpressionTestBase.java | 19 +++++++++---------- .../SimpleExtendedExpressionsTest.java | 11 ++++------- 5 files changed, 21 insertions(+), 26 deletions(-) diff --git a/isthmus/README.md b/isthmus/README.md index 212dd44af..21d9da919 100644 --- a/isthmus/README.md +++ b/isthmus/README.md @@ -2,7 +2,7 @@ ## Overview -Substrait Isthmus is a Java library which enables serializing SQL to [Substrait Protobuf](https://substrait.io/serialization/binary_serialization/) and SQL Expression to [Extended Expression](https://substrait.io/expressions/extended_expression/) via +Substrait Isthmus is a Java library which enables serializing SQL queries to [Substrait Protobuf](https://substrait.io/serialization/binary_serialization/) and SQL expressions to [Extended Expressions](https://substrait.io/expressions/extended_expression/) via the Calcite SQL compiler. Optionally, you can leverage the Calcite RelNode to Substrait Plan translator as an IR translation. ## Build @@ -32,8 +32,8 @@ Usage: isthmus [-hmV] [--crossjoinpolicy=] [--outputformat=] [--sqlconformancemode=] [-c=]... [-e=...]... [] -Substrait Java Native Image for parsing SQL Query and SQL Expressions - [] The sql we should parse. +Convert SQL Queries and SQL Expressions to Substrait + [] A SQL query -c, --create= One or multiple create table statements e.g. CREATE TABLE T1(foo int, bar bigint) @@ -54,7 +54,6 @@ Substrait Java Native Image for parsing SQL Query and SQL Expressions ORACLE_12, STRICT_2003, PRAGMATIC_2003, PRESTO, SQL_SERVER_2008 -V, --version Print version information and exit. - ``` ## Example diff --git a/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java b/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java index 739ba8c77..eac6fbed3 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java +++ b/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java @@ -21,10 +21,10 @@ @Command( name = "isthmus", version = "isthmus 0.1", - description = "Substrait Java Native Image for parsing SQL Query and SQL Expressions", + description = "Convert SQL Queries and SQL Expressions to Substrait", mixinStandardHelpOptions = true) public class IsthmusEntryPoint implements Callable { - @Parameters(index = "0", arity = "0..1", description = "The sql we should parse.") + @Parameters(index = "0", arity = "0..1", description = "A SQL query") private String sql; @Option( diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java index 649a5039f..5932a0d35 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -52,7 +52,7 @@ private record Result( Map nameToNodeMap) {} /** - * Converts the given SQL expression string to an {@link io.substrait.proto.ExtendedExpression } + * Converts the given SQL expression to an {@link io.substrait.proto.ExtendedExpression } * * @param sqlExpression a SQL expression * @param createStatements table creation statements defining fields referenced by the expression @@ -65,9 +65,9 @@ public ExtendedExpression convert(String sqlExpression, List createState } /** - * Converts the given SQL expressions string to an {@link io.substrait.proto.ExtendedExpression } + * Converts the given SQL expressions to an {@link io.substrait.proto.ExtendedExpression } * - * @param sqlExpressions an array of SQL expression + * @param sqlExpressions an array of SQL expressions * @param createStatements table creation statements defining fields referenced by the expression * @return a {@link io.substrait.proto.ExtendedExpression } * @throws SqlParseException diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java index e5d9abf94..9be98e731 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java @@ -4,7 +4,6 @@ import com.google.common.io.Resources; import io.substrait.extendedexpression.ExtendedExpressionProtoConverter; import io.substrait.extendedexpression.ProtoExtendedExpressionConverter; -import io.substrait.proto.ExtendedExpression; import java.io.IOException; import java.util.Arrays; import java.util.List; @@ -27,39 +26,39 @@ public static List tpchSchemaCreateStatements() throws IOException { return tpchSchemaCreateStatements("tpch/schema.sql"); } - protected void assertProtoEEForExpressionsDefaultCommaSeparatorRoundtrip(String expressions) + protected void assertProtoExtendedExpressionRoundtrip(String expressions) throws SqlParseException, IOException { // proto initial extended expression - ExtendedExpression extendedExpressionProtoInitial = + io.substrait.proto.ExtendedExpression extendedExpressionProtoInitial = new SqlExpressionToSubstrait().convert(expressions, tpchSchemaCreateStatements()); asserProtoExtendedExpression(extendedExpressionProtoInitial); } - protected void assertProtoEEForExpressionsDefaultCommaSeparatorErrorRoundtrip( - String expressions, String schemaToLoad) throws SqlParseException, IOException { + protected void assertProtoExtendedExpressionRoundtrip(String expressions, String schemaToLoad) + throws SqlParseException, IOException { // proto initial extended expression - ExtendedExpression extendedExpressionProtoInitial = + io.substrait.proto.ExtendedExpression extendedExpressionProtoInitial = new SqlExpressionToSubstrait() .convert(expressions, tpchSchemaCreateStatements(schemaToLoad)); asserProtoExtendedExpression(extendedExpressionProtoInitial); } - protected void assertProtoEEForListExpressionRoundtrip(String[] expression) + protected void assertProtoExtendedExpressionRoundtrip(String[] expression) throws SqlParseException, IOException { // proto initial extended expression - ExtendedExpression extendedExpressionProtoInitial = + io.substrait.proto.ExtendedExpression extendedExpressionProtoInitial = new SqlExpressionToSubstrait().convert(expression, tpchSchemaCreateStatements()); asserProtoExtendedExpression(extendedExpressionProtoInitial); } private static void asserProtoExtendedExpression( - ExtendedExpression extendedExpressionProtoInitial) throws IOException { + io.substrait.proto.ExtendedExpression extendedExpressionProtoInitial) throws IOException { // pojo final extended expression io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal = new ProtoExtendedExpressionConverter().from(extendedExpressionProtoInitial); // proto final extended expression - ExtendedExpression extendedExpressionProtoFinal = + io.substrait.proto.ExtendedExpression extendedExpressionProtoFinal = new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoFinal); // round-trip to validate extended expression proto initial equals to final diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java index 3c5e6f4fb..1ff843d3e 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java @@ -26,10 +26,9 @@ private static Stream expressionTypeProvider() { @ParameterizedTest @MethodSource("expressionTypeProvider") - public void testExtendedExpressionsCommaSeparatorRoundTrip(String sqlExpression) + public void testExtendedExpressionsRoundTrip(String sqlExpression) throws SqlParseException, IOException { - assertProtoEEForExpressionsDefaultCommaSeparatorRoundtrip( - sqlExpression); // comma-separator by default + assertProtoExtendedExpressionRoundtrip(sqlExpression); // comma-separator by default } @ParameterizedTest @@ -38,9 +37,7 @@ public void testExtendedExpressionsDuplicateColumnIdentifierRoundTrip(String sql IllegalArgumentException illegalArgumentException = assertThrows( IllegalArgumentException.class, - () -> - assertProtoEEForExpressionsDefaultCommaSeparatorErrorRoundtrip( - sqlExpression, "tpch/schema_error.sql")); + () -> assertProtoExtendedExpressionRoundtrip(sqlExpression, "tpch/schema_error.sql")); assertTrue( illegalArgumentException .getMessage() @@ -60,6 +57,6 @@ public void testExtendedExpressionsListExpressionRoundTrip() "L_ORDERKEY is null" }; - assertProtoEEForListExpressionRoundtrip(expressions); + assertProtoExtendedExpressionRoundtrip(expressions); } } From 9041b7c95642ac9a565b4cf1b1aa17fdd374eba6 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Tue, 6 Feb 2024 11:20:33 -0800 Subject: [PATCH 35/35] test: remove reference to comma separator --- .../io/substrait/isthmus/SimpleExtendedExpressionsTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java index 1ff843d3e..b998bbac3 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java @@ -28,7 +28,7 @@ private static Stream expressionTypeProvider() { @MethodSource("expressionTypeProvider") public void testExtendedExpressionsRoundTrip(String sqlExpression) throws SqlParseException, IOException { - assertProtoExtendedExpressionRoundtrip(sqlExpression); // comma-separator by default + assertProtoExtendedExpressionRoundtrip(sqlExpression); } @ParameterizedTest