diff --git a/docs/source/java/substrait.rst b/docs/source/java/substrait.rst index d8d49a96e88f8..c5857dcc23f75 100644 --- a/docs/source/java/substrait.rst +++ b/docs/source/java/substrait.rst @@ -113,31 +113,19 @@ This requires the substrait-java library. This Java program: - Loads a Parquet file containing the "nation" table from the TPC-H benchmark. +- Applies a filter: + - ``N_NATIONKEY > 18`` - Projects two new columns: - - ``N_NAME || ' - ' || N_COMMENT`` - ``N_REGIONKEY + 10`` -- Applies a filter: ``N_NATIONKEY > 18`` + - ``N_NAME || ' - ' || N_COMMENT`` + + .. code-block:: Java - import io.substrait.extension.ExtensionCollector; - import io.substrait.proto.Expression; - import io.substrait.proto.ExpressionReference; + import com.google.common.collect.ImmutableList; + import io.substrait.isthmus.SqlExpressionToSubstrait; import io.substrait.proto.ExtendedExpression; - import io.substrait.proto.FunctionArgument; - import io.substrait.proto.SimpleExtensionDeclaration; - import io.substrait.proto.SimpleExtensionURI; - import io.substrait.type.NamedStruct; - import io.substrait.type.Type; - import io.substrait.type.TypeCreator; - import io.substrait.type.proto.TypeProtoConverter; - import java.nio.ByteBuffer; - import java.util.ArrayList; - import java.util.Arrays; - import java.util.Base64; - import java.util.HashMap; - import java.util.List; - import java.util.Optional; import org.apache.arrow.dataset.file.FileFormat; import org.apache.arrow.dataset.file.FileSystemDatasetFactory; import org.apache.arrow.dataset.jni.NativeMemoryPool; @@ -148,297 +136,60 @@ This Java program: import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.ipc.ArrowReader; + import org.apache.calcite.sql.parser.SqlParseException; + + import java.nio.ByteBuffer; + import java.util.Base64; + import java.util.Optional; public class ClientSubstraitExtendedExpressionsCookbook { - public static void main(String[] args) throws Exception { - // project and filter dataset using extended expression definition - 03 Expressions: - // Expression 01 - CONCAT: N_NAME || ' - ' || N_COMMENT = col 1 || ' - ' || col 3 - // Expression 02 - ADD: N_REGIONKEY + 10 = col 1 + 10 - // Expression 03 - FILTER: N_NATIONKEY > 18 = col 3 > 18 + public static void main(String[] args) throws SqlParseException { projectAndFilterDataset(); } - public static void projectAndFilterDataset() { + private static void projectAndFilterDataset() throws SqlParseException { String uri = "file:///Users/data/tpch_parquet/nation.parquet"; - ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) - .columns(Optional.empty()) - .substraitFilter(getSubstraitExpressionFilter()) - .substraitProjection(getSubstraitExpressionProjection()) - .build(); - try ( - BufferAllocator allocator = new RootAllocator(); - DatasetFactory datasetFactory = new FileSystemDatasetFactory( - allocator, NativeMemoryPool.getDefault(), - FileFormat.PARQUET, uri); - Dataset dataset = datasetFactory.finish(); - Scanner scanner = dataset.newScan(options); - ArrowReader reader = scanner.scanBatches() - ) { + ScanOptions options = + new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitFilter(getByteBuffer(new String[]{"N_NATIONKEY > 18"})) + .substraitProjection(getByteBuffer(new String[]{"N_REGIONKEY + 10", + "N_NAME || CAST(' - ' as VARCHAR) || N_COMMENT"})) + .build(); + try (BufferAllocator allocator = new RootAllocator(); + DatasetFactory datasetFactory = + new FileSystemDatasetFactory( + allocator, NativeMemoryPool.getDefault(), FileFormat.PARQUET, uri); + Dataset dataset = datasetFactory.finish(); + Scanner scanner = dataset.newScan(options); + ArrowReader reader = scanner.scanBatches()) { while (reader.loadNextBatch()) { - System.out.println( - reader.getVectorSchemaRoot().contentToTSVString()); + System.out.println(reader.getVectorSchemaRoot().contentToTSVString()); } } catch (Exception e) { throw new RuntimeException(e); } } - private static ByteBuffer getSubstraitExpressionProjection() { - // Expression: N_REGIONKEY + 10 = col 3 + 10 - Expression.Builder selectionBuilderProjectOne = Expression.newBuilder(). - setSelection( - Expression.FieldReference.newBuilder(). - setDirectReference( - Expression.ReferenceSegment.newBuilder(). - setStructField( - Expression.ReferenceSegment.StructField.newBuilder().setField( - 2) - ) - ) - ); - Expression.Builder literalBuilderProjectOne = Expression.newBuilder() - .setLiteral( - Expression.Literal.newBuilder().setI32(10) - ); - io.substrait.proto.Type outputProjectOne = TypeCreator.NULLABLE.I32.accept( - new TypeProtoConverter(new ExtensionCollector())); - Expression.Builder expressionBuilderProjectOne = Expression. - newBuilder(). - setScalarFunction( - Expression. - ScalarFunction. - newBuilder(). - setFunctionReference(0). - setOutputType(outputProjectOne). - addArguments( - 0, - FunctionArgument.newBuilder().setValue( - selectionBuilderProjectOne) - ). - addArguments( - 1, - FunctionArgument.newBuilder().setValue( - literalBuilderProjectOne) - ) - ); - ExpressionReference.Builder expressionReferenceBuilderProjectOne = ExpressionReference.newBuilder(). - setExpression(expressionBuilderProjectOne) - .addOutputNames("ADD_TEN_TO_COLUMN_N_REGIONKEY"); - - // Expression: name || name = N_NAME || "-" || N_COMMENT = col 1 || col 3 - Expression.Builder selectionBuilderProjectTwo = Expression.newBuilder(). - setSelection( - Expression.FieldReference.newBuilder(). - setDirectReference( - Expression.ReferenceSegment.newBuilder(). - setStructField( - Expression.ReferenceSegment.StructField.newBuilder().setField( - 1) - ) - ) - ); - Expression.Builder selectionBuilderProjectTwoConcatLiteral = Expression.newBuilder() - .setLiteral( - Expression.Literal.newBuilder().setString(" - ") - ); - Expression.Builder selectionBuilderProjectOneToConcat = Expression.newBuilder(). - setSelection( - Expression.FieldReference.newBuilder(). - setDirectReference( - Expression.ReferenceSegment.newBuilder(). - setStructField( - Expression.ReferenceSegment.StructField.newBuilder().setField( - 3) - ) - ) - ); - io.substrait.proto.Type outputProjectTwo = TypeCreator.NULLABLE.STRING.accept( - new TypeProtoConverter(new ExtensionCollector())); - Expression.Builder expressionBuilderProjectTwo = Expression. - newBuilder(). - setScalarFunction( - Expression. - ScalarFunction. - newBuilder(). - setFunctionReference(1). - setOutputType(outputProjectTwo). - addArguments( - 0, - FunctionArgument.newBuilder().setValue( - selectionBuilderProjectTwo) - ). - addArguments( - 1, - FunctionArgument.newBuilder().setValue( - selectionBuilderProjectTwoConcatLiteral) - ). - addArguments( - 2, - FunctionArgument.newBuilder().setValue( - selectionBuilderProjectOneToConcat) - ) - ); - ExpressionReference.Builder expressionReferenceBuilderProjectTwo = ExpressionReference.newBuilder(). - setExpression(expressionBuilderProjectTwo) - .addOutputNames("CONCAT_COLUMNS_N_NAME_AND_N_COMMENT"); - - List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", - "N_REGIONKEY", "N_COMMENT"); - List dataTypes = Arrays.asList( - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING, - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING - ); - NamedStruct of = NamedStruct.of( - columnNames, - Type.Struct.builder().fields(dataTypes).nullable(false).build() - ); - // Extensions URI - HashMap extensionUris = new HashMap<>(); - extensionUris.put( - "key-001", - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(1) - .setUri("/functions_arithmetic.yaml") - .build() - ); - // Extensions - ArrayList extensions = new ArrayList<>(); - SimpleExtensionDeclaration extensionFunctionAdd = SimpleExtensionDeclaration.newBuilder() - .setExtensionFunction( - SimpleExtensionDeclaration.ExtensionFunction.newBuilder() - .setFunctionAnchor(0) - .setName("add:i32_i32") - .setExtensionUriReference(1)) - .build(); - SimpleExtensionDeclaration extensionFunctionGreaterThan = SimpleExtensionDeclaration.newBuilder() - .setExtensionFunction( - SimpleExtensionDeclaration.ExtensionFunction.newBuilder() - .setFunctionAnchor(1) - .setName("concat:vchar") - .setExtensionUriReference(2)) - .build(); - extensions.add(extensionFunctionAdd); - extensions.add(extensionFunctionGreaterThan); - // Extended Expression - ExtendedExpression.Builder extendedExpressionBuilder = - ExtendedExpression.newBuilder(). - addReferredExpr(0, - expressionReferenceBuilderProjectOne). - addReferredExpr(1, - expressionReferenceBuilderProjectTwo). - setBaseSchema(of.toProto(new TypeProtoConverter( - new ExtensionCollector()))); - extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); - extendedExpressionBuilder.addAllExtensions(extensions); - ExtendedExpression extendedExpression = extendedExpressionBuilder.build(); - byte[] extendedExpressions = Base64.getDecoder().decode( - Base64.getEncoder().encodeToString( - extendedExpression.toByteArray())); - ByteBuffer substraitExpressionProjection = ByteBuffer.allocateDirect( - extendedExpressions.length); - substraitExpressionProjection.put(extendedExpressions); - return substraitExpressionProjection; - } - - private static ByteBuffer getSubstraitExpressionFilter() { - // Expression: Filter: N_NATIONKEY > 18 = col 1 > 18 - Expression.Builder selectionBuilderFilterOne = Expression.newBuilder(). - setSelection( - Expression.FieldReference.newBuilder(). - setDirectReference( - Expression.ReferenceSegment.newBuilder(). - setStructField( - Expression.ReferenceSegment.StructField.newBuilder().setField( - 0) - ) - ) - ); - Expression.Builder literalBuilderFilterOne = Expression.newBuilder() - .setLiteral( - Expression.Literal.newBuilder().setI32(18) - ); - io.substrait.proto.Type outputFilterOne = TypeCreator.NULLABLE.BOOLEAN.accept( - new TypeProtoConverter(new ExtensionCollector())); - Expression.Builder expressionBuilderFilterOne = Expression. - newBuilder(). - setScalarFunction( - Expression. - ScalarFunction. - newBuilder(). - setFunctionReference(1). - setOutputType(outputFilterOne). - addArguments( - 0, - FunctionArgument.newBuilder().setValue( - selectionBuilderFilterOne) - ). - addArguments( - 1, - FunctionArgument.newBuilder().setValue( - literalBuilderFilterOne) - ) - ); - ExpressionReference.Builder expressionReferenceBuilderFilterOne = ExpressionReference.newBuilder(). - setExpression(expressionBuilderFilterOne) - .addOutputNames("COLUMN_N_NATIONKEY_GREATER_THAN_18"); - - List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", - "N_REGIONKEY", "N_COMMENT"); - List dataTypes = Arrays.asList( - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING, - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING - ); - NamedStruct of = NamedStruct.of( - columnNames, - Type.Struct.builder().fields(dataTypes).nullable(false).build() - ); - // Extensions URI - HashMap extensionUris = new HashMap<>(); - extensionUris.put( - "key-001", - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(1) - .setUri("/functions_comparison.yaml") - .build() - ); - // Extensions - ArrayList extensions = new ArrayList<>(); - SimpleExtensionDeclaration extensionFunctionLowerThan = SimpleExtensionDeclaration.newBuilder() - .setExtensionFunction( - SimpleExtensionDeclaration.ExtensionFunction.newBuilder() - .setFunctionAnchor(1) - .setName("gt:any_any") - .setExtensionUriReference(1)) - .build(); - extensions.add(extensionFunctionLowerThan); - // Extended Expression - ExtendedExpression.Builder extendedExpressionBuilder = - ExtendedExpression.newBuilder(). - addReferredExpr(0, - expressionReferenceBuilderFilterOne). - setBaseSchema(of.toProto(new TypeProtoConverter( - new ExtensionCollector()))); - extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); - extendedExpressionBuilder.addAllExtensions(extensions); - ExtendedExpression extendedExpression = extendedExpressionBuilder.build(); - byte[] extendedExpressions = Base64.getDecoder().decode( - Base64.getEncoder().encodeToString( - extendedExpression.toByteArray())); - ByteBuffer substraitExpressionFilter = ByteBuffer.allocateDirect( - extendedExpressions.length); - substraitExpressionFilter.put(extendedExpressions); - return substraitExpressionFilter; + private static ByteBuffer getByteBuffer(String[] sqlExpression) throws SqlParseException { + String schema = + "CREATE TABLE NATION (N_NATIONKEY INT NOT NULL, N_NAME VARCHAR, " + + "N_REGIONKEY INT NOT NULL, N_COMMENT VARCHAR)"; + SqlExpressionToSubstrait expressionToSubstrait = new SqlExpressionToSubstrait(); + ExtendedExpression expression = + expressionToSubstrait.convert(sqlExpression, ImmutableList.of(schema)); + byte[] expressionToByte = + Base64.getDecoder().decode(Base64.getEncoder().encodeToString(expression.toByteArray())); + ByteBuffer byteBuffer = ByteBuffer.allocateDirect(expressionToByte.length); + byteBuffer.put(expressionToByte); + return byteBuffer; } } .. code-block:: text - ADD_TEN_TO_COLUMN_N_REGIONKEY CONCAT_COLUMNS_N_NAME_AND_N_COMMENT + column-1 column-2 13 ROMANIA - ular asymptotes are about the furious multipliers. express dependencies nag above the ironically ironic account 14 SAUDI ARABIA - ts. silent requests haggle. closely express packages sleep across the blithely 12 VIETNAM - hely enticingly express accounts. even, final