diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/LookupCallContext.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/FunctionCallContext.java similarity index 87% rename from flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/LookupCallContext.java rename to flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/FunctionCallContext.java index ce655466c0f32..562012f225b48 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/LookupCallContext.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/FunctionCallContext.java @@ -21,6 +21,7 @@ import org.apache.flink.table.catalog.DataTypeFactory; import org.apache.flink.table.connector.source.LookupTableSource; import org.apache.flink.table.functions.UserDefinedFunction; +import org.apache.flink.table.ml.PredictRuntimeProvider; import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.Constant; import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.FunctionParam; import org.apache.flink.table.types.DataType; @@ -38,24 +39,27 @@ import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getFieldTypes; import static org.apache.flink.table.types.utils.TypeConversions.fromLogicalToDataType; -/** The {@link CallContext} of a {@link LookupTableSource} runtime function. */ +/** + * The {@link CallContext} of {@link LookupTableSource}, {@link PredictRuntimeProvider} runtime + * function. + */ @Internal -public class LookupCallContext extends AbstractSqlCallContext { +public class FunctionCallContext extends AbstractSqlCallContext { - private final List lookupKeys; + private final List params; private final List argumentDataTypes; private final DataType outputDataType; - public LookupCallContext( + public FunctionCallContext( DataTypeFactory dataTypeFactory, UserDefinedFunction function, LogicalType inputType, - List lookupKeys, - LogicalType lookupType) { + List params, + LogicalType outputDataType) { super(dataTypeFactory, function, generateInlineFunctionName(function), false); - this.lookupKeys = lookupKeys; + this.params = params; this.argumentDataTypes = new AbstractList<>() { @Override @@ -74,10 +78,10 @@ public DataType get(int index) { @Override public int size() { - return lookupKeys.size(); + return params.size(); } }; - this.outputDataType = fromLogicalToDataType(lookupType); + this.outputDataType = fromLogicalToDataType(outputDataType); } @Override @@ -118,6 +122,6 @@ public Optional getOutputDataType() { // -------------------------------------------------------------------------------------------- private FunctionParam getKey(int pos) { - return lookupKeys.get(pos); + return params.get(pos); } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java index bbf4486c5ed44..05b07458a10ed 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java @@ -46,6 +46,7 @@ import org.apache.flink.table.planner.calcite.FlinkTypeFactory; import org.apache.flink.table.planner.codegen.CodeGeneratorContext; import org.apache.flink.table.planner.codegen.FilterCodeGenerator; +import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator; import org.apache.flink.table.planner.codegen.LookupJoinCodeGenerator; import org.apache.flink.table.planner.delegation.PlannerBase; import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge; @@ -459,7 +460,7 @@ protected StreamOperatorFactory createAsyncLookupJoin( .mapToObj(allLookupKeys::get) .collect(Collectors.toList()); - LookupJoinCodeGenerator.GeneratedTableFunctionWithDataType> + FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType> generatedFuncWithType = LookupJoinCodeGenerator.generateAsyncLookupFunction( config, diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java index 39a48443e0d45..6d368618c9130 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java @@ -31,6 +31,7 @@ import org.apache.flink.table.functions.AsyncTableFunction; import org.apache.flink.table.functions.UserDefinedFunctionHelper; import org.apache.flink.table.planner.calcite.FlinkTypeFactory; +import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator; import org.apache.flink.table.planner.codegen.LookupJoinCodeGenerator; import org.apache.flink.table.planner.delegation.PlannerBase; import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge; @@ -353,7 +354,7 @@ private AsyncDeltaJoinRunner createAsyncDeltaJoinRunner( .mapToObj(lookupKeys::get) .collect(Collectors.toList()); - LookupJoinCodeGenerator.GeneratedTableFunctionWithDataType> + FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType> lookupSideGeneratedFuncWithType = LookupJoinCodeGenerator.generateAsyncLookupFunction( config, diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecMLPredictTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecMLPredictTableFunction.java index 0ff85e23031af..6d1fe8a8cee39 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecMLPredictTableFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecMLPredictTableFunction.java @@ -31,8 +31,6 @@ import org.apache.flink.table.api.TableException; import org.apache.flink.table.catalog.DataTypeFactory; import org.apache.flink.table.data.RowData; -import org.apache.flink.table.data.conversion.DataStructureConverter; -import org.apache.flink.table.data.conversion.DataStructureConverters; import org.apache.flink.table.functions.AsyncPredictFunction; import org.apache.flink.table.functions.PredictFunction; import org.apache.flink.table.functions.UserDefinedFunction; @@ -41,8 +39,8 @@ import org.apache.flink.table.ml.PredictRuntimeProvider; import org.apache.flink.table.planner.calcite.FlinkContext; import org.apache.flink.table.planner.codegen.CodeGeneratorContext; -import org.apache.flink.table.planner.codegen.FilterCodeGenerator; -import org.apache.flink.table.planner.codegen.LookupJoinCodeGenerator; +import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator; +import org.apache.flink.table.planner.codegen.MLPredictCodeGenerator; import org.apache.flink.table.planner.delegation.PlannerBase; import org.apache.flink.table.planner.plan.nodes.exec.ExecNode; import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase; @@ -55,18 +53,15 @@ import org.apache.flink.table.planner.plan.nodes.exec.spec.ModelSpec; import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil; import org.apache.flink.table.planner.plan.utils.FunctionCallUtil; -import org.apache.flink.table.planner.utils.JavaScalaConversionUtil; import org.apache.flink.table.runtime.collector.ListenableCollector; -import org.apache.flink.table.runtime.collector.TableFunctionResultFuture; import org.apache.flink.table.runtime.functions.ml.ModelPredictRuntimeProviderContext; import org.apache.flink.table.runtime.generated.GeneratedCollector; import org.apache.flink.table.runtime.generated.GeneratedFunction; -import org.apache.flink.table.runtime.generated.GeneratedResultFuture; -import org.apache.flink.table.runtime.operators.join.lookup.AsyncLookupJoinRunner; -import org.apache.flink.table.runtime.operators.join.lookup.LookupJoinRunner; -import org.apache.flink.table.runtime.typeutils.InternalSerializers; +import org.apache.flink.table.runtime.operators.ml.AsyncMLPredictRunner; +import org.apache.flink.table.runtime.operators.ml.MLPredictRunner; import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.util.Preconditions; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty; @@ -75,7 +70,6 @@ import java.util.Collections; import java.util.List; -import java.util.Optional; /** Stream {@link ExecNode} for {@code ML_PREDICT}. */ @ExecNodeMetadata( @@ -197,7 +191,7 @@ private Transformation createModelPredict( RowType resultRowType, PredictFunction predictFunction) { GeneratedFunction> generatedFetcher = - LookupJoinCodeGenerator.generateSyncLookupFunction( + MLPredictCodeGenerator.generateSyncPredictFunction( config, classLoader, dataTypeFactory, @@ -206,25 +200,15 @@ private Transformation createModelPredict( resultRowType, mlPredictSpec.getFeatures(), predictFunction, - "MLPredict", + modelSpec.getContextResolvedModel().getIdentifier().asSummaryString(), config.get(PipelineOptions.OBJECT_REUSE)); GeneratedCollector> generatedCollector = - LookupJoinCodeGenerator.generateCollector( + MLPredictCodeGenerator.generateCollector( new CodeGeneratorContext(config, classLoader), inputRowType, modelOutputType, - (RowType) getOutputType(), - JavaScalaConversionUtil.toScala(Optional.empty()), - JavaScalaConversionUtil.toScala(Optional.empty()), - true); - LookupJoinRunner mlPredictRunner = - new LookupJoinRunner( - generatedFetcher, - generatedCollector, - FilterCodeGenerator.generateFilterCondition( - config, classLoader, null, inputRowType), - false, - modelOutputType.getFieldCount()); + (RowType) getOutputType()); + MLPredictRunner mlPredictRunner = new MLPredictRunner(generatedFetcher, generatedCollector); SimpleOperatorFactory operatorFactory = SimpleOperatorFactory.of(new ProcessOperator<>(mlPredictRunner)); return ExecNodeUtil.createOneInputTransformation( @@ -246,9 +230,9 @@ private Transformation createAsyncModelPredict( RowType modelOutputType, RowType resultRowType, AsyncPredictFunction asyncPredictFunction) { - LookupJoinCodeGenerator.GeneratedTableFunctionWithDataType> + FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType> generatedFuncWithType = - LookupJoinCodeGenerator.generateAsyncLookupFunction( + MLPredictCodeGenerator.generateAsyncPredictFunction( config, classLoader, dataTypeFactory, @@ -257,29 +241,14 @@ private Transformation createAsyncModelPredict( resultRowType, mlPredictSpec.getFeatures(), asyncPredictFunction, - "AsyncMLPredict"); - - GeneratedResultFuture> generatedResultFuture = - LookupJoinCodeGenerator.generateTableAsyncCollector( - config, - classLoader, - "TableFunctionResultFuture", - inputRowType, - modelOutputType, - JavaScalaConversionUtil.toScala(Optional.empty())); - - DataStructureConverter fetcherConverter = - DataStructureConverters.getConverter(generatedFuncWithType.dataType()); + modelSpec + .getContextResolvedModel() + .getIdentifier() + .asSummaryString()); AsyncFunction asyncFunc = - new AsyncLookupJoinRunner( - generatedFuncWithType.tableFunc(), - (DataStructureConverter) fetcherConverter, - generatedResultFuture, - FilterCodeGenerator.generateFilterCondition( - config, classLoader, null, inputRowType), - InternalSerializers.create(modelOutputType), - false, - asyncOptions.asyncBufferCapacity); + new AsyncMLPredictRunner( + (GeneratedFunction) generatedFuncWithType.tableFunc(), + Preconditions.checkNotNull(asyncOptions).asyncBufferCapacity); return ExecNodeUtil.createOneInputTransformation( inputTransformation, createTransformationMeta(ML_PREDICT_TRANSFORMATION, config), diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCallCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCallCodeGenerator.scala new file mode 100644 index 0000000000000..438f14c21e38b --- /dev/null +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCallCodeGenerator.scala @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.planner.codegen + +import org.apache.flink.api.common.functions.{FlatMapFunction, Function, OpenContext} +import org.apache.flink.configuration.ReadableConfig +import org.apache.flink.streaming.api.functions.async.AsyncFunction +import org.apache.flink.table.catalog.DataTypeFactory +import org.apache.flink.table.data.{GenericRowData, RowData} +import org.apache.flink.table.data.utils.JoinedRowData +import org.apache.flink.table.functions.{AsyncTableFunction, TableFunction, UserDefinedFunction, UserDefinedFunctionHelper} +import org.apache.flink.table.planner.codegen.CodeGenUtils.{boxedTypeTermForType, className, newName, DEFAULT_COLLECTOR_TERM, DEFAULT_INPUT1_TERM, DEFAULT_INPUT2_TERM} +import org.apache.flink.table.planner.codegen.GenerateUtils.{generateInputAccess, generateLiteral} +import org.apache.flink.table.planner.delegation.PlannerBase +import org.apache.flink.table.planner.functions.inference.FunctionCallContext +import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.{Constant, FieldRef, FunctionParam} +import org.apache.flink.table.planner.plan.utils.RexLiteralUtil +import org.apache.flink.table.runtime.collector.ListenableCollector +import org.apache.flink.table.runtime.collector.ListenableCollector.CollectListener +import org.apache.flink.table.runtime.generated.{GeneratedCollector, GeneratedFunction} +import org.apache.flink.table.types.DataType +import org.apache.flink.table.types.logical.{LogicalType, RowType} +import org.apache.flink.util.Collector + +import org.apache.calcite.rex.RexNode + +import java.util + +import scala.collection.JavaConverters._ + +object FunctionCallCodeGenerator { + + case class GeneratedTableFunctionWithDataType[F <: Function]( + tableFunc: GeneratedFunction[F], + dataType: DataType) + + /** Generates a sync function ([[TableFunction]]) call. */ + def generateSyncFunctionCall( + tableConfig: ReadableConfig, + classLoader: ClassLoader, + dataTypeFactory: DataTypeFactory, + inputType: LogicalType, + functionOutputType: LogicalType, + collectorOutputType: LogicalType, + parameters: util.List[FunctionParam], + syncFunctionDefinition: TableFunction[_], + inferCall: ( + CodeGeneratorContext, + FunctionCallContext, + UserDefinedFunction, + Seq[GeneratedExpression]) => (GeneratedExpression, DataType), + functionName: String, + generateClassName: String, + fieldCopy: Boolean): GeneratedTableFunctionWithDataType[FlatMapFunction[RowData, RowData]] = { + + val bodyCode: GeneratedExpression => String = call => { + val resultCollectorTerm = call.resultTerm + s""" + |$resultCollectorTerm.setCollector($DEFAULT_COLLECTOR_TERM); + |${call.code} + |""".stripMargin + } + + generateFunctionCall( + classOf[FlatMapFunction[RowData, RowData]], + tableConfig, + classLoader, + dataTypeFactory, + inputType, + functionOutputType, + collectorOutputType, + parameters, + syncFunctionDefinition, + inferCall, + functionName, + generateClassName, + fieldCopy, + bodyCode + ) + } + + /** Generates an async function ([[AsyncTableFunction]]) call. */ + def generateAsyncFunctionCall( + tableConfig: ReadableConfig, + classLoader: ClassLoader, + dataTypeFactory: DataTypeFactory, + inputType: LogicalType, + functionOutputType: LogicalType, + collectorOutputType: LogicalType, + parameters: util.List[FunctionParam], + asyncFunctionDefinition: AsyncTableFunction[_], + generateCallWithDataType: ( + CodeGeneratorContext, + FunctionCallContext, + UserDefinedFunction, + Seq[GeneratedExpression]) => (GeneratedExpression, DataType), + functionName: String, + generateClassName: String + ): GeneratedTableFunctionWithDataType[AsyncFunction[RowData, AnyRef]] = { + generateFunctionCall( + classOf[AsyncFunction[RowData, AnyRef]], + tableConfig, + classLoader, + dataTypeFactory, + inputType, + functionOutputType, + collectorOutputType, + parameters, + asyncFunctionDefinition, + generateCallWithDataType, + functionName, + generateClassName, + fieldCopy = true, + _.code + ) + } + + private def generateFunctionCall[F <: Function]( + generatedClass: Class[F], + tableConfig: ReadableConfig, + classLoader: ClassLoader, + dataTypeFactory: DataTypeFactory, + inputType: LogicalType, + functionOutputType: LogicalType, + collectorOutputType: LogicalType, + parameters: util.List[FunctionParam], + functionDefinition: UserDefinedFunction, + generateCallWithDataType: ( + CodeGeneratorContext, + FunctionCallContext, + UserDefinedFunction, + Seq[GeneratedExpression]) => (GeneratedExpression, DataType), + functionName: String, + generateClassName: String, + fieldCopy: Boolean, + bodyCode: GeneratedExpression => String): GeneratedTableFunctionWithDataType[F] = { + + val callContext = + new FunctionCallContext( + dataTypeFactory, + functionDefinition, + inputType, + parameters, + functionOutputType) + + // create the final UDF for runtime + val udf = UserDefinedFunctionHelper.createSpecializedFunction( + functionName, + functionDefinition, + callContext, + classOf[PlannerBase].getClassLoader, + tableConfig, + // no need to support expression evaluation at this point + null + ) + + val ctx = new CodeGeneratorContext(tableConfig, classLoader) + val operands = prepareOperands(ctx, inputType, parameters, fieldCopy) + + val callWithDataType: (GeneratedExpression, DataType) = + generateCallWithDataType(ctx, callContext, udf, operands) + + val function = FunctionCodeGenerator.generateFunction( + ctx, + generateClassName, + generatedClass, + bodyCode(callWithDataType._1), + collectorOutputType, + inputType) + + GeneratedTableFunctionWithDataType(function, callWithDataType._2) + } + + private def prepareOperands( + ctx: CodeGeneratorContext, + inputType: LogicalType, + parameters: util.List[FunctionParam], + fieldCopy: Boolean): Seq[GeneratedExpression] = { + + parameters.asScala + .map { + case constantKey: Constant => + val res = RexLiteralUtil.toFlinkInternalValue(constantKey.literal) + generateLiteral(ctx, res.f0, res.f1) + case fieldKey: FieldRef => + generateInputAccess( + ctx, + inputType, + DEFAULT_INPUT1_TERM, + fieldKey.index, + nullableInput = false, + fieldCopy) + case _ => + throw new CodeGenException("Invalid parameters.") + } + } + + /** + * Generates collector for join ([[Collector]]) + * + * Differs from CommonCorrelate.generateCollector which has no real condition because of + * FLINK-7865, here we should deal with outer join type when real conditions filtered result. + */ + def generateCollector( + ctx: CodeGeneratorContext, + inputRowType: RowType, + rightRowType: RowType, + resultRowType: RowType, + condition: Option[RexNode], + pojoFieldMapping: Option[Array[Int]], + retainHeader: Boolean = true): GeneratedCollector[ListenableCollector[RowData]] = { + + val inputTerm = DEFAULT_INPUT1_TERM + val rightInputTerm = DEFAULT_INPUT2_TERM + + val exprGenerator = new ExprCodeGenerator(ctx, nullableInput = false) + .bindInput(rightRowType, inputTerm = rightInputTerm, inputFieldMapping = pojoFieldMapping) + + val rightResultExpr = + exprGenerator.generateConverterResultExpression(rightRowType, classOf[GenericRowData]) + + val joinedRowTerm = CodeGenUtils.newName(ctx, "joinedRow") + ctx.addReusableOutputRecord(resultRowType, classOf[JoinedRowData], joinedRowTerm) + + val header = if (retainHeader) { + s"$joinedRowTerm.setRowKind($inputTerm.getRowKind());" + } else { + "" + } + + val body = + s""" + |${rightResultExpr.code} + |$joinedRowTerm.replace($inputTerm, ${rightResultExpr.resultTerm}); + |$header + |outputResult($joinedRowTerm); + """.stripMargin + + val collectorCode = if (condition.isEmpty) { + body + } else { + + val filterGenerator = new ExprCodeGenerator(ctx, nullableInput = false) + .bindInput(inputRowType, inputTerm) + .bindSecondInput(rightRowType, rightInputTerm, pojoFieldMapping) + val filterCondition = filterGenerator.generateExpression(condition.get) + + s""" + |${filterCondition.code} + |if (${filterCondition.resultTerm}) { + | $body + |} + |""".stripMargin + } + + generateTableFunctionCollectorForJoinTable( + ctx, + "JoinTableFuncCollector", + collectorCode, + inputRowType, + rightRowType, + inputTerm = inputTerm, + collectedTerm = rightInputTerm) + } + + /** + * The only differences against CollectorCodeGenerator.generateTableFunctionCollector is + * "super.collect" call is binding with collect join row in "body" code + */ + private def generateTableFunctionCollectorForJoinTable( + ctx: CodeGeneratorContext, + name: String, + bodyCode: String, + inputType: RowType, + collectedType: RowType, + inputTerm: String = DEFAULT_INPUT1_TERM, + collectedTerm: String = DEFAULT_INPUT2_TERM) + : GeneratedCollector[ListenableCollector[RowData]] = { + + val funcName = newName(ctx, name) + val input1TypeClass = boxedTypeTermForType(inputType) + val input2TypeClass = boxedTypeTermForType(collectedType) + + val funcCode = + s""" + public class $funcName extends ${classOf[ListenableCollector[_]].getCanonicalName} { + + ${ctx.reuseMemberCode()} + + public $funcName(Object[] references) throws Exception { + ${ctx.reuseInitCode()} + } + + @Override + public void open(${className[OpenContext]} openContext) throws Exception { + ${ctx.reuseOpenCode()} + } + + @Override + public void collect(Object record) throws Exception { + $input1TypeClass $inputTerm = ($input1TypeClass) getInput(); + $input2TypeClass $collectedTerm = ($input2TypeClass) record; + + // callback only when collectListener exists, equivalent to: + // getCollectListener().ifPresent( + // listener -> ((CollectListener) listener).onCollect(record)); + // TODO we should update code splitter's grammar file to accept lambda expressions. + + if (getCollectListener().isPresent()) { + ((${classOf[CollectListener[_]].getCanonicalName}) getCollectListener().get()) + .onCollect(record); + } + + ${ctx.reuseLocalVariableCode()} + ${ctx.reuseInputUnboxingCode()} + ${ctx.reusePerRecordCode()} + $bodyCode + } + + @Override + public void close() throws Exception { + ${ctx.reuseCloseCode()} + } + } + """.stripMargin + + new GeneratedCollector(funcName, funcCode, ctx.references.toArray, ctx.tableConfig) + } +} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala index e61783b68447d..2f3ab1708b869 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala @@ -17,7 +17,7 @@ */ package org.apache.flink.table.planner.codegen -import org.apache.flink.api.common.functions.{FlatMapFunction, Function, OpenContext} +import org.apache.flink.api.common.functions.{FlatMapFunction, OpenContext} import org.apache.flink.configuration.ReadableConfig import org.apache.flink.streaming.api.functions.async.AsyncFunction import org.apache.flink.table.api.ValidationException @@ -25,17 +25,15 @@ import org.apache.flink.table.catalog.DataTypeFactory import org.apache.flink.table.connector.source.{LookupTableSource, ScanTableSource} import org.apache.flink.table.data.{GenericRowData, RowData} import org.apache.flink.table.data.utils.JoinedRowData -import org.apache.flink.table.functions.{AsyncLookupFunction, AsyncPredictFunction, AsyncTableFunction, LookupFunction, PredictFunction, TableFunction, UserDefinedFunction, UserDefinedFunctionHelper} +import org.apache.flink.table.functions.{AsyncLookupFunction, AsyncTableFunction, LookupFunction, TableFunction, UserDefinedFunction} import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.codegen.CodeGenUtils._ -import org.apache.flink.table.planner.codegen.GenerateUtils._ +import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType import org.apache.flink.table.planner.codegen.Indenter.toISC import org.apache.flink.table.planner.codegen.calls.BridgingFunctionGenUtil import org.apache.flink.table.planner.codegen.calls.BridgingFunctionGenUtil.verifyFunctionAwareImplementation -import org.apache.flink.table.planner.delegation.PlannerBase -import org.apache.flink.table.planner.functions.inference.LookupCallContext -import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.{Constant, FieldRef, FunctionParam} -import org.apache.flink.table.planner.plan.utils.RexLiteralUtil +import org.apache.flink.table.planner.functions.inference.FunctionCallContext +import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.FunctionParam import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala import org.apache.flink.table.runtime.collector.{ListenableCollector, TableFunctionResultFuture} import org.apache.flink.table.runtime.collector.ListenableCollector.CollectListener @@ -57,10 +55,6 @@ import scala.collection.JavaConverters._ object LookupJoinCodeGenerator { - case class GeneratedTableFunctionWithDataType[F <: Function]( - tableFunc: GeneratedFunction[F], - dataType: DataType) - private val ARRAY_LIST = className[util.ArrayList[_]] /** Generates a lookup function ([[TableFunction]]) */ @@ -76,29 +70,26 @@ object LookupJoinCodeGenerator { functionName: String, fieldCopy: Boolean): GeneratedFunction[FlatMapFunction[RowData, RowData]] = { - val bodyCode: GeneratedExpression => String = call => { - val resultCollectorTerm = call.resultTerm - s""" - |$resultCollectorTerm.setCollector($DEFAULT_COLLECTOR_TERM); - |${call.code} - |""".stripMargin - } - - generateLookupFunction( - classOf[FlatMapFunction[RowData, RowData]], - tableConfig, - classLoader, - dataTypeFactory, - inputType, - tableSourceType, - returnType, - lookupKeys, - classOf[TableFunction[_]], - syncLookupFunction, - functionName, - fieldCopy, - bodyCode - ).tableFunc + FunctionCallCodeGenerator + .generateSyncFunctionCall( + tableConfig, + classLoader, + dataTypeFactory, + inputType, + tableSourceType, + returnType, + lookupKeys, + syncLookupFunction, + generateCallWithDataType( + dataTypeFactory, + functionName, + tableSourceType, + classOf[TableFunction[_]]), + functionName, + "LookupFunction", + fieldCopy + ) + .tableFunc } /** Generates a async lookup function ([[AsyncTableFunction]]) */ @@ -112,9 +103,7 @@ object LookupJoinCodeGenerator { lookupKeys: util.List[FunctionParam], asyncLookupFunction: AsyncTableFunction[_], functionName: String): GeneratedTableFunctionWithDataType[AsyncFunction[RowData, AnyRef]] = { - - generateLookupFunction( - classOf[AsyncFunction[RowData, AnyRef]], + FunctionCallCodeGenerator.generateAsyncFunctionCall( tableConfig, classLoader, dataTypeFactory, @@ -122,98 +111,66 @@ object LookupJoinCodeGenerator { tableSourceType, returnType, lookupKeys, - classOf[AsyncTableFunction[_]], asyncLookupFunction, + generateCallWithDataType( + dataTypeFactory, + functionName, + tableSourceType, + classOf[AsyncTableFunction[_]]), functionName, - fieldCopy = true, // always copy input field because of async buffer - _.code + "AsyncLookupFunction" ) } - private def generateLookupFunction[F <: Function]( - generatedClass: Class[F], - tableConfig: ReadableConfig, - classLoader: ClassLoader, + private def generateCallWithDataType( dataTypeFactory: DataTypeFactory, - inputType: LogicalType, - tableSourceType: LogicalType, - returnType: LogicalType, - lookupKeys: util.List[FunctionParam], - lookupFunctionBase: Class[_], - lookupFunction: UserDefinedFunction, functionName: String, - fieldCopy: Boolean, - bodyCode: GeneratedExpression => String): GeneratedTableFunctionWithDataType[F] = { - - val callContext = - new LookupCallContext(dataTypeFactory, lookupFunction, inputType, lookupKeys, tableSourceType) - - // create the final UDF for runtime - val udf = UserDefinedFunctionHelper.createSpecializedFunction( - functionName, - lookupFunction, - callContext, - classOf[PlannerBase].getClassLoader, - tableConfig, - // no need to support expression evaluation at this point - null) - - val inference = - createLookupTypeInference(dataTypeFactory, callContext, lookupFunctionBase, udf, functionName) - - val ctx = new CodeGeneratorContext(tableConfig, classLoader) - val operands = prepareOperands(ctx, inputType, lookupKeys, fieldCopy) - - // TODO: filter all records when there are any nulls on the join key, because - // "IS NOT DISTINCT FROM" is not supported yet. - // Note: AsyncPredictFunction or PredictFunction does not use Lookup Syntax. - val skipIfArgsNull = !lookupFunction.isInstanceOf[PredictFunction] && !lookupFunction - .isInstanceOf[AsyncPredictFunction] - - val callWithDataType = BridgingFunctionGenUtil.generateFunctionAwareCallWithDataType( - ctx, - operands, - tableSourceType, - inference, - callContext, - udf, - functionName, - skipIfArgsNull = skipIfArgsNull - ) - - val function = FunctionCodeGenerator.generateFunction( - ctx, - "LookupFunction", - generatedClass, - bodyCode(callWithDataType._1), - returnType, - inputType) - - GeneratedTableFunctionWithDataType(function, callWithDataType._2) - } - - private def prepareOperands( + tableSourceType: LogicalType, + baseClass: Class[_] + ) = ( ctx: CodeGeneratorContext, - inputType: LogicalType, - lookupKeys: util.List[FunctionParam], - fieldCopy: Boolean): Seq[GeneratedExpression] = { + callContext: FunctionCallContext, + udf: UserDefinedFunction, + operands: Seq[GeneratedExpression]) => { + def inferCallWithDataType( + ctx: CodeGeneratorContext, + callContext: FunctionCallContext, + udf: UserDefinedFunction, + operands: Seq[GeneratedExpression], + legacy: Boolean, + e: Exception = null): (GeneratedExpression, DataType) = { + val inference = createLookupTypeInference( + dataTypeFactory, + callContext, + baseClass, + udf, + functionName, + legacy, + e) + + // TODO: filter all records when there is any nulls on the join key, because + // "IS NOT DISTINCT FROM" is not supported yet. + val callWithDataType = BridgingFunctionGenUtil.generateFunctionAwareCallWithDataType( + ctx, + operands, + tableSourceType, + inference, + callContext, + udf, + functionName, + skipIfArgsNull = true + ) + callWithDataType + } - lookupKeys.asScala - .map { - case constantKey: Constant => - val res = RexLiteralUtil.toFlinkInternalValue(constantKey.literal) - generateLiteral(ctx, res.f0, res.f1) - case fieldKey: FieldRef => - generateInputAccess( - ctx, - inputType, - DEFAULT_INPUT1_TERM, - fieldKey.index, - nullableInput = false, - fieldCopy) - case _ => - throw new CodeGenException("Invalid lookup key.") - } + try { + // user provided type inference has precedence + // this ensures that all functions work in the same way + inferCallWithDataType(ctx, callContext, udf, operands, legacy = false) + } catch { + case e: Exception => + inferCallWithDataType(ctx, callContext, udf, operands, legacy = true, e) + } } /** @@ -225,66 +182,58 @@ object LookupJoinCodeGenerator { */ private def createLookupTypeInference( dataTypeFactory: DataTypeFactory, - callContext: LookupCallContext, + callContext: FunctionCallContext, baseClass: Class[_], udf: UserDefinedFunction, - functionName: String): TypeInference = { + functionName: String, + legacy: Boolean, + e: Exception): TypeInference = { - try { + if (!legacy) { // user provided type inference has precedence // this ensures that all functions work in the same way udf.getTypeInference(dataTypeFactory) - } catch { - case e: Exception => - // for convenience, we assume internal or default external data structures - // of expected logical types - val defaultArgDataTypes = callContext.getArgumentDataTypes.asScala - val defaultOutputDataType = callContext.getOutputDataType.get() - - val outputClass = - if ( - udf.isInstanceOf[LookupFunction] || udf.isInstanceOf[AsyncLookupFunction] || udf - .isInstanceOf[PredictFunction] || udf.isInstanceOf[AsyncPredictFunction] - ) { - Some(classOf[RowData]) - } else { - toScala(extractSimpleGeneric(baseClass, udf.getClass, 0)) - } - val (argDataTypes, outputDataType) = outputClass match { - case Some(c) if c == classOf[Row] => - (defaultArgDataTypes, defaultOutputDataType) - case Some(c) if c == classOf[RowData] => - val internalArgDataTypes = defaultArgDataTypes - .map(dt => transform(dt, TypeTransformations.TO_INTERNAL_CLASS)) - val internalOutputDataType = - transform(defaultOutputDataType, TypeTransformations.TO_INTERNAL_CLASS) - (internalArgDataTypes, internalOutputDataType) - case _ => - throw new ValidationException( - s"Could not determine a type inference for lookup function '$functionName'. " + - s"Lookup functions support regular type inference. However, for convenience, the " + - s"output class can simply be a ${classOf[Row].getSimpleName} or " + - s"${classOf[RowData].getSimpleName} class in which case the input and output " + - s"types are derived from the table's schema with default conversion.", - e) + } else { + // for convenience, we assume internal or default external data structures + // of expected logical types + val defaultArgDataTypes = callContext.getArgumentDataTypes.asScala + val defaultOutputDataType = callContext.getOutputDataType.get() + + val outputClass = + if (udf.isInstanceOf[LookupFunction] || udf.isInstanceOf[AsyncLookupFunction]) { + Some(classOf[RowData]) + } else { + toScala(extractSimpleGeneric(baseClass, udf.getClass, 0)) } + val (argDataTypes, outputDataType) = outputClass match { + case Some(c) if c == classOf[Row] => + (defaultArgDataTypes, defaultOutputDataType) + case Some(c) if c == classOf[RowData] => + val internalArgDataTypes = defaultArgDataTypes + .map(dt => transform(dt, TypeTransformations.TO_INTERNAL_CLASS)) + val internalOutputDataType = + transform(defaultOutputDataType, TypeTransformations.TO_INTERNAL_CLASS) + (internalArgDataTypes, internalOutputDataType) + case _ => + throw new ValidationException( + s"Could not determine a type inference for lookup function '$functionName'. " + + s"Lookup functions support regular type inference. However, for convenience, the " + + s"output class can simply be a ${classOf[Row].getSimpleName} or " + + s"${classOf[RowData].getSimpleName} class in which case the input and output " + + s"types are derived from the table's schema with default conversion.", + e) + } - verifyFunctionAwareImplementation(argDataTypes, outputDataType, udf, functionName) + verifyFunctionAwareImplementation(argDataTypes, outputDataType, udf, functionName) - TypeInference - .newBuilder() - .typedArguments(argDataTypes.asJava) - .outputTypeStrategy(TypeStrategies.explicit(outputDataType)) - .build() + TypeInference + .newBuilder() + .typedArguments(argDataTypes.asJava) + .outputTypeStrategy(TypeStrategies.explicit(outputDataType)) + .build() } } - /** - * Generates collector for temporal join ([[Collector]]) - * - * Differs from CommonCorrelate.generateCollector which has no real condition because of - * FLINK-7865, here we should deal with outer join type when real conditions filtered result. - */ def generateCollector( ctx: CodeGeneratorContext, inputRowType: RowType, @@ -293,122 +242,14 @@ object LookupJoinCodeGenerator { condition: Option[RexNode], pojoFieldMapping: Option[Array[Int]], retainHeader: Boolean = true): GeneratedCollector[ListenableCollector[RowData]] = { - - val inputTerm = DEFAULT_INPUT1_TERM - val rightInputTerm = DEFAULT_INPUT2_TERM - - val exprGenerator = new ExprCodeGenerator(ctx, nullableInput = false) - .bindInput(rightRowType, inputTerm = rightInputTerm, inputFieldMapping = pojoFieldMapping) - - val rightResultExpr = - exprGenerator.generateConverterResultExpression(rightRowType, classOf[GenericRowData]) - - val joinedRowTerm = CodeGenUtils.newName(ctx, "joinedRow") - ctx.addReusableOutputRecord(resultRowType, classOf[JoinedRowData], joinedRowTerm) - - val header = if (retainHeader) { - s"$joinedRowTerm.setRowKind($inputTerm.getRowKind());" - } else { - "" - } - - val body = - s""" - |${rightResultExpr.code} - |$joinedRowTerm.replace($inputTerm, ${rightResultExpr.resultTerm}); - |$header - |outputResult($joinedRowTerm); - """.stripMargin - - val collectorCode = if (condition.isEmpty) { - body - } else { - - val filterGenerator = new ExprCodeGenerator(ctx, nullableInput = false) - .bindInput(inputRowType, inputTerm) - .bindSecondInput(rightRowType, rightInputTerm, pojoFieldMapping) - val filterCondition = filterGenerator.generateExpression(condition.get) - - s""" - |${filterCondition.code} - |if (${filterCondition.resultTerm}) { - | $body - |} - |""".stripMargin - } - - generateTableFunctionCollectorForJoinTable( + FunctionCallCodeGenerator.generateCollector( ctx, - "JoinTableFuncCollector", - collectorCode, inputRowType, rightRowType, - inputTerm = inputTerm, - collectedTerm = rightInputTerm) - } - - /** - * The only differences against CollectorCodeGenerator.generateTableFunctionCollector is - * "super.collect" call is binding with collect join row in "body" code - */ - private def generateTableFunctionCollectorForJoinTable( - ctx: CodeGeneratorContext, - name: String, - bodyCode: String, - inputType: RowType, - collectedType: RowType, - inputTerm: String = DEFAULT_INPUT1_TERM, - collectedTerm: String = DEFAULT_INPUT2_TERM) - : GeneratedCollector[ListenableCollector[RowData]] = { - - val funcName = newName(ctx, name) - val input1TypeClass = boxedTypeTermForType(inputType) - val input2TypeClass = boxedTypeTermForType(collectedType) - - val funcCode = - s""" - public class $funcName extends ${classOf[ListenableCollector[_]].getCanonicalName} { - - ${ctx.reuseMemberCode()} - - public $funcName(Object[] references) throws Exception { - ${ctx.reuseInitCode()} - } - - @Override - public void open(${className[OpenContext]} openContext) throws Exception { - ${ctx.reuseOpenCode()} - } - - @Override - public void collect(Object record) throws Exception { - $input1TypeClass $inputTerm = ($input1TypeClass) getInput(); - $input2TypeClass $collectedTerm = ($input2TypeClass) record; - - // callback only when collectListener exists, equivalent to: - // getCollectListener().ifPresent( - // listener -> ((CollectListener) listener).onCollect(record)); - // TODO we should update code splitter's grammar file to accept lambda expressions. - - if (getCollectListener().isPresent()) { - ((${classOf[CollectListener[_]].getCanonicalName}) getCollectListener().get()) - .onCollect(record); - } - - ${ctx.reuseLocalVariableCode()} - ${ctx.reuseInputUnboxingCode()} - ${ctx.reusePerRecordCode()} - $bodyCode - } - - @Override - public void close() throws Exception { - ${ctx.reuseCloseCode()} - } - } - """.stripMargin - - new GeneratedCollector(funcName, funcCode, ctx.references.toArray, ctx.tableConfig) + resultRowType, + condition, + pojoFieldMapping, + retainHeader) } /** diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/MLPredictCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/MLPredictCodeGenerator.scala new file mode 100644 index 0000000000000..68d322f8275dd --- /dev/null +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/MLPredictCodeGenerator.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.planner.codegen + +import org.apache.flink.api.common.functions.FlatMapFunction +import org.apache.flink.configuration.ReadableConfig +import org.apache.flink.streaming.api.functions.async.AsyncFunction +import org.apache.flink.table.catalog.DataTypeFactory +import org.apache.flink.table.data.RowData +import org.apache.flink.table.functions.{AsyncTableFunction, TableFunction, UserDefinedFunction} +import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType +import org.apache.flink.table.planner.codegen.calls.BridgingFunctionGenUtil +import org.apache.flink.table.planner.functions.inference.FunctionCallContext +import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.FunctionParam +import org.apache.flink.table.runtime.collector.ListenableCollector +import org.apache.flink.table.runtime.generated.{GeneratedCollector, GeneratedFunction} +import org.apache.flink.table.types.inference.{TypeInference, TypeStrategies, TypeTransformations} +import org.apache.flink.table.types.logical.{LogicalType, RowType} +import org.apache.flink.table.types.utils.DataTypeUtils.transform + +import java.util + +import scala.collection.JavaConverters._ + +object MLPredictCodeGenerator { + + /** Generates a predict function ([[TableFunction]]) */ + def generateSyncPredictFunction( + tableConfig: ReadableConfig, + classLoader: ClassLoader, + dataTypeFactory: DataTypeFactory, + inputType: LogicalType, + predictFunctionOutputType: LogicalType, + collectorOutputType: LogicalType, + features: util.List[FunctionParam], + syncPredictFunction: TableFunction[_], + functionName: String, + fieldCopy: Boolean + ): GeneratedFunction[FlatMapFunction[RowData, RowData]] = { + FunctionCallCodeGenerator + .generateSyncFunctionCall( + tableConfig, + classLoader, + dataTypeFactory, + inputType, + predictFunctionOutputType, + collectorOutputType, + features, + syncPredictFunction, + generateCallWithDataType(functionName, predictFunctionOutputType), + functionName, + "PredictFunction", + fieldCopy + ) + .tableFunc + } + + /** Generates a async predict function ([[AsyncTableFunction]]) */ + def generateAsyncPredictFunction( + tableConfig: ReadableConfig, + classLoader: ClassLoader, + dataTypeFactory: DataTypeFactory, + inputType: LogicalType, + predictFunctionOutputType: LogicalType, + collectorOutputType: LogicalType, + features: util.List[FunctionParam], + asyncPredictFunction: AsyncTableFunction[_], + functionName: String): GeneratedTableFunctionWithDataType[AsyncFunction[RowData, AnyRef]] = { + FunctionCallCodeGenerator.generateAsyncFunctionCall( + tableConfig, + classLoader, + dataTypeFactory, + inputType, + predictFunctionOutputType, + collectorOutputType, + features, + asyncPredictFunction, + generateCallWithDataType(functionName, predictFunctionOutputType), + functionName, + "AsyncPredictFunction" + ) + } + + /** Generate a collector to collect to join the input row and predicted results. */ + def generateCollector( + ctx: CodeGeneratorContext, + inputRowType: RowType, + predictFunctionOutputType: RowType, + collectorOutputType: RowType + ): GeneratedCollector[ListenableCollector[RowData]] = { + FunctionCallCodeGenerator.generateCollector( + ctx, + inputRowType, + predictFunctionOutputType, + collectorOutputType, + Option.empty, + Option.empty + ) + } + + private def generateCallWithDataType( + functionName: String, + modelOutputType: LogicalType + ) = ( + ctx: CodeGeneratorContext, + callContext: FunctionCallContext, + udf: UserDefinedFunction, + operands: Seq[GeneratedExpression]) => { + val inference = TypeInference + .newBuilder() + .typedArguments( + callContext.getArgumentDataTypes.asScala + .map(dt => transform(dt, TypeTransformations.TO_INTERNAL_CLASS)) + .asJava) + .outputTypeStrategy(TypeStrategies.explicit( + transform(callContext.getOutputDataType.get(), TypeTransformations.TO_INTERNAL_CLASS))) + .build() + BridgingFunctionGenUtil.generateFunctionAwareCallWithDataType( + ctx, + operands, + modelOutputType, + inference, + callContext, + udf, + functionName, + skipIfArgsNull = false + ) + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/AbstractAsyncFunctionRunner.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/AbstractAsyncFunctionRunner.java new file mode 100644 index 0000000000000..10e4defcfb441 --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/AbstractAsyncFunctionRunner.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.operators; + +import org.apache.flink.api.common.functions.OpenContext; +import org.apache.flink.api.common.functions.util.FunctionUtils; +import org.apache.flink.streaming.api.functions.async.AsyncFunction; +import org.apache.flink.streaming.api.functions.async.RichAsyncFunction; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.functions.AsyncLookupFunction; +import org.apache.flink.table.functions.AsyncPredictFunction; +import org.apache.flink.table.runtime.generated.GeneratedFunction; + +/** + * Base function runner for specialized table function, e.g. {@link AsyncLookupFunction} or {@link + * AsyncPredictFunction}. + */ +public abstract class AbstractAsyncFunctionRunner extends RichAsyncFunction { + + protected final GeneratedFunction> generatedFetcher; + + protected transient AsyncFunction fetcher; + + public AbstractAsyncFunctionRunner( + GeneratedFunction> generatedFetcher) { + this.generatedFetcher = generatedFetcher; + } + + @Override + public void open(OpenContext openContext) throws Exception { + super.open(openContext); + fetcher = generatedFetcher.newInstance(getRuntimeContext().getUserCodeClassLoader()); + FunctionUtils.setFunctionRuntimeContext(fetcher, getRuntimeContext()); + FunctionUtils.openFunction(fetcher, openContext); + } + + @Override + public void close() throws Exception { + super.close(); + if (fetcher != null) { + FunctionUtils.closeFunction(fetcher); + } + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/AbstractFunctionRunner.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/AbstractFunctionRunner.java new file mode 100644 index 0000000000000..126aeba220453 --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/AbstractFunctionRunner.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.operators; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.OpenContext; +import org.apache.flink.api.common.functions.util.FunctionUtils; +import org.apache.flink.streaming.api.functions.ProcessFunction; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.functions.LookupFunction; +import org.apache.flink.table.runtime.generated.GeneratedFunction; + +/** + * Base function runner for specialized table function, e.g. {@link LookupFunction} or {@link + * ProcessFunction}. + */ +public abstract class AbstractFunctionRunner extends ProcessFunction { + + private final GeneratedFunction> generatedFetcher; + + protected transient FlatMapFunction fetcher; + + public AbstractFunctionRunner( + GeneratedFunction> generatedFetcher) { + this.generatedFetcher = generatedFetcher; + } + + @Override + public void open(OpenContext openContext) throws Exception { + super.open(openContext); + this.fetcher = generatedFetcher.newInstance(getRuntimeContext().getUserCodeClassLoader()); + + FunctionUtils.setFunctionRuntimeContext(fetcher, getRuntimeContext()); + FunctionUtils.openFunction(fetcher, openContext); + } + + @Override + public void close() throws Exception { + if (fetcher != null) { + FunctionUtils.closeFunction(fetcher); + } + super.close(); + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java index 456f07bcd962c..7cc3a7273f7c6 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java @@ -18,37 +18,23 @@ package org.apache.flink.table.runtime.operators.calc.async; -import org.apache.flink.api.common.functions.OpenContext; -import org.apache.flink.api.common.functions.util.FunctionUtils; import org.apache.flink.streaming.api.functions.async.AsyncFunction; import org.apache.flink.streaming.api.functions.async.ResultFuture; -import org.apache.flink.streaming.api.functions.async.RichAsyncFunction; import org.apache.flink.table.data.RowData; import org.apache.flink.table.runtime.generated.GeneratedFunction; +import org.apache.flink.table.runtime.operators.AbstractAsyncFunctionRunner; /** * Async function runner for {@link org.apache.flink.table.functions.AsyncScalarFunction}, which * takes the generated function, instantiates it, and then calls its lifecycle methods. */ -public class AsyncFunctionRunner extends RichAsyncFunction { +public class AsyncFunctionRunner extends AbstractAsyncFunctionRunner { private static final long serialVersionUID = -7198305381139008806L; - private final GeneratedFunction> generatedFetcher; - - private transient AsyncFunction fetcher; - public AsyncFunctionRunner( GeneratedFunction> generatedFetcher) { - this.generatedFetcher = generatedFetcher; - } - - @Override - public void open(OpenContext openContext) throws Exception { - super.open(openContext); - fetcher = generatedFetcher.newInstance(getRuntimeContext().getUserCodeClassLoader()); - FunctionUtils.setFunctionRuntimeContext(fetcher, getRuntimeContext()); - FunctionUtils.openFunction(fetcher, openContext); + super(generatedFetcher); } @Override @@ -59,10 +45,4 @@ public void asyncInvoke(RowData input, ResultFuture resultFuture) { resultFuture.completeExceptionally(t); } } - - @Override - public void close() throws Exception { - super.close(); - FunctionUtils.closeFunction(fetcher); - } } diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/AsyncLookupJoinRunner.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/AsyncLookupJoinRunner.java index 310f38d5489fa..6b34c99f4a766 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/AsyncLookupJoinRunner.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/AsyncLookupJoinRunner.java @@ -26,7 +26,6 @@ import org.apache.flink.streaming.api.functions.async.AsyncFunction; import org.apache.flink.streaming.api.functions.async.CollectionSupplier; import org.apache.flink.streaming.api.functions.async.ResultFuture; -import org.apache.flink.streaming.api.functions.async.RichAsyncFunction; import org.apache.flink.table.data.GenericRowData; import org.apache.flink.table.data.RowData; import org.apache.flink.table.data.conversion.DataStructureConverter; @@ -35,6 +34,7 @@ import org.apache.flink.table.runtime.generated.FilterCondition; import org.apache.flink.table.runtime.generated.GeneratedFunction; import org.apache.flink.table.runtime.generated.GeneratedResultFuture; +import org.apache.flink.table.runtime.operators.AbstractAsyncFunctionRunner; import org.apache.flink.table.runtime.typeutils.RowDataSerializer; import java.util.ArrayList; @@ -45,10 +45,9 @@ import java.util.concurrent.BlockingQueue; /** The async join runner to lookup the dimension table. */ -public class AsyncLookupJoinRunner extends RichAsyncFunction { +public class AsyncLookupJoinRunner extends AbstractAsyncFunctionRunner { private static final long serialVersionUID = -6664660022391632480L; - private final GeneratedFunction> generatedFetcher; private final DataStructureConverter fetcherConverter; private final GeneratedResultFuture> generatedResultFuture; private final GeneratedFunction generatedPreFilterCondition; @@ -56,8 +55,6 @@ public class AsyncLookupJoinRunner extends RichAsyncFunction { private final boolean isLeftOuterJoin; private final int asyncBufferCapacity; - private transient AsyncFunction fetcher; - protected final RowDataSerializer rightRowSerializer; /** @@ -83,7 +80,7 @@ public AsyncLookupJoinRunner( RowDataSerializer rightRowSerializer, boolean isLeftOuterJoin, int asyncBufferCapacity) { - this.generatedFetcher = generatedFetcher; + super(generatedFetcher); this.fetcherConverter = fetcherConverter; this.generatedResultFuture = generatedResultFuture; this.generatedPreFilterCondition = generatedPreFilterCondition; @@ -96,11 +93,9 @@ public AsyncLookupJoinRunner( public void open(OpenContext openContext) throws Exception { super.open(openContext); ClassLoader cl = getRuntimeContext().getUserCodeClassLoader(); - this.fetcher = generatedFetcher.newInstance(cl); this.preFilterCondition = generatedPreFilterCondition.newInstance(cl); FunctionUtils.setFunctionRuntimeContext(fetcher, getRuntimeContext()); FunctionUtils.setFunctionRuntimeContext(preFilterCondition, getRuntimeContext()); - FunctionUtils.openFunction(fetcher, openContext); FunctionUtils.openFunction(preFilterCondition, openContext); // try to compile the generated ResultFuture, fail fast if the code is corrupt. @@ -152,9 +147,6 @@ public TableFunctionResultFuture createFetcherResultFuture(Configuratio @Override public void close() throws Exception { super.close(); - if (fetcher != null) { - FunctionUtils.closeFunction(fetcher); - } if (preFilterCondition != null) { FunctionUtils.closeFunction(preFilterCondition); } diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/LookupJoinRunner.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/LookupJoinRunner.java index c91a9feb7bfb8..4ecb0b53f5980 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/LookupJoinRunner.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/LookupJoinRunner.java @@ -21,7 +21,6 @@ import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.api.common.functions.OpenContext; import org.apache.flink.api.common.functions.util.FunctionUtils; -import org.apache.flink.streaming.api.functions.ProcessFunction; import org.apache.flink.table.data.GenericRowData; import org.apache.flink.table.data.RowData; import org.apache.flink.table.data.utils.JoinedRowData; @@ -29,20 +28,19 @@ import org.apache.flink.table.runtime.generated.FilterCondition; import org.apache.flink.table.runtime.generated.GeneratedCollector; import org.apache.flink.table.runtime.generated.GeneratedFunction; +import org.apache.flink.table.runtime.operators.AbstractFunctionRunner; import org.apache.flink.util.Collector; /** The join runner to lookup the dimension table. */ -public class LookupJoinRunner extends ProcessFunction { +public class LookupJoinRunner extends AbstractFunctionRunner { private static final long serialVersionUID = -4521543015709964733L; - private final GeneratedFunction> generatedFetcher; private final GeneratedCollector> generatedCollector; private final GeneratedFunction generatedPreFilterCondition; protected final boolean isLeftOuterJoin; protected final int tableFieldsCount; - private transient FlatMapFunction fetcher; protected transient ListenableCollector collector; protected transient JoinedRowData outRow; protected transient FilterCondition preFilterCondition; @@ -54,7 +52,7 @@ public LookupJoinRunner( GeneratedFunction generatedPreFilterCondition, boolean isLeftOuterJoin, int tableFieldsCount) { - this.generatedFetcher = generatedFetcher; + super(generatedFetcher); this.generatedCollector = generatedCollector; this.generatedPreFilterCondition = generatedPreFilterCondition; this.isLeftOuterJoin = isLeftOuterJoin; @@ -64,17 +62,14 @@ public LookupJoinRunner( @Override public void open(OpenContext openContext) throws Exception { super.open(openContext); - this.fetcher = generatedFetcher.newInstance(getRuntimeContext().getUserCodeClassLoader()); this.collector = generatedCollector.newInstance(getRuntimeContext().getUserCodeClassLoader()); this.preFilterCondition = generatedPreFilterCondition.newInstance( getRuntimeContext().getUserCodeClassLoader()); - FunctionUtils.setFunctionRuntimeContext(fetcher, getRuntimeContext()); FunctionUtils.setFunctionRuntimeContext(collector, getRuntimeContext()); FunctionUtils.setFunctionRuntimeContext(preFilterCondition, getRuntimeContext()); - FunctionUtils.openFunction(fetcher, openContext); FunctionUtils.openFunction(collector, openContext); FunctionUtils.openFunction(preFilterCondition, openContext); @@ -124,9 +119,6 @@ public Collector getFetcherCollector() { @Override public void close() throws Exception { - if (fetcher != null) { - FunctionUtils.closeFunction(fetcher); - } if (collector != null) { FunctionUtils.closeFunction(collector); } diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/ml/AsyncMLPredictRunner.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/ml/AsyncMLPredictRunner.java new file mode 100644 index 0000000000000..aa77eedaeaffb --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/ml/AsyncMLPredictRunner.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.operators.ml; + +import org.apache.flink.api.common.functions.OpenContext; +import org.apache.flink.metrics.MetricGroup; +import org.apache.flink.streaming.api.functions.async.AsyncFunction; +import org.apache.flink.streaming.api.functions.async.CollectionSupplier; +import org.apache.flink.streaming.api.functions.async.ResultFuture; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.utils.JoinedRowData; +import org.apache.flink.table.functions.AsyncPredictFunction; +import org.apache.flink.table.runtime.generated.GeneratedFunction; +import org.apache.flink.table.runtime.operators.AbstractAsyncFunctionRunner; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; + +/** + * Async function runner for {@link AsyncPredictFunction}, which takes the generated function, + * instantiates it, and then calls its lifecycle methods. + */ +public class AsyncMLPredictRunner extends AbstractAsyncFunctionRunner { + + private final int asyncBufferCapacity; + + /** + * Buffers {@link ResultFuture} to avoid newInstance cost when processing elements every time. + * We use {@link BlockingQueue} to make sure the head {@link ResultFuture}s are available. + */ + private transient BlockingQueue resultFutureBuffer; + + public AsyncMLPredictRunner( + GeneratedFunction> generatedFetcher, + int asyncBufferCapacity) { + super(generatedFetcher); + this.asyncBufferCapacity = asyncBufferCapacity; + } + + @Override + public void open(OpenContext openContext) throws Exception { + super.open(openContext); + this.resultFutureBuffer = new ArrayBlockingQueue<>(asyncBufferCapacity + 1); + for (int i = 0; i < asyncBufferCapacity + 1; i++) { + JoinedRowResultFuture rf = new JoinedRowResultFuture(resultFutureBuffer); + // add will throw exception immediately if the queue is full which should never happen + resultFutureBuffer.add(rf); + } + registerMetric(getRuntimeContext().getMetricGroup()); + } + + @Override + public void asyncInvoke(RowData input, ResultFuture resultFuture) throws Exception { + try { + JoinedRowResultFuture buffer = resultFutureBuffer.take(); + buffer.reset(input, resultFuture); + fetcher.asyncInvoke(input, buffer); + } catch (Throwable t) { + resultFuture.completeExceptionally(t); + } + } + + private void registerMetric(MetricGroup metricGroup) { + metricGroup.gauge( + "ai_queue_length", () -> asyncBufferCapacity + 1 - resultFutureBuffer.size()); + metricGroup.gauge("ai_queue_capacity", () -> asyncBufferCapacity); + metricGroup.gauge( + "ai_queue_usage_ratio", + () -> + 1.0 + * (asyncBufferCapacity + 1 - resultFutureBuffer.size()) + / asyncBufferCapacity); + } + + private static final class JoinedRowResultFuture implements ResultFuture { + + private final BlockingQueue resultFutureBuffer; + + private ResultFuture realOutput; + private RowData leftRow; + + public JoinedRowResultFuture(BlockingQueue resultFutureBuffer) { + this.resultFutureBuffer = resultFutureBuffer; + } + + public void reset(RowData row, ResultFuture realOutput) { + this.realOutput = realOutput; + this.leftRow = row; + } + + @Override + public void complete(Collection result) { + List outRows = new ArrayList<>(); + for (RowData rightRow : result) { + RowData outRow = new JoinedRowData(leftRow.getRowKind(), leftRow, rightRow); + outRows.add(outRow); + } + realOutput.complete(outRows); + + try { + // put this collector to the queue to avoid this collector is used + // again before outRows in the collector is not consumed. + resultFutureBuffer.put(this); + } catch (InterruptedException e) { + completeExceptionally(e); + } + } + + @Override + public void completeExceptionally(Throwable error) { + realOutput.completeExceptionally(error); + } + + @Override + public void complete(CollectionSupplier supplier) { + throw new UnsupportedOperationException(); + } + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/ml/MLPredictRunner.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/ml/MLPredictRunner.java new file mode 100644 index 0000000000000..a8a486911f84b --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/ml/MLPredictRunner.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.operators.ml; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.OpenContext; +import org.apache.flink.api.common.functions.util.FunctionUtils; +import org.apache.flink.streaming.api.functions.ProcessFunction; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.functions.PredictFunction; +import org.apache.flink.table.runtime.collector.ListenableCollector; +import org.apache.flink.table.runtime.generated.GeneratedCollector; +import org.apache.flink.table.runtime.generated.GeneratedFunction; +import org.apache.flink.table.runtime.operators.AbstractFunctionRunner; +import org.apache.flink.util.Collector; + +/** + * Function runner for {@link PredictFunction}, which takes the generated function, instantiates it, + * and then calls its lifecycle methods. + */ +public class MLPredictRunner extends AbstractFunctionRunner { + + private final GeneratedCollector> generatedCollector; + + protected transient ListenableCollector collector; + + public MLPredictRunner( + GeneratedFunction> generatedFetcher, + GeneratedCollector> generatedCollector) { + super(generatedFetcher); + this.generatedCollector = generatedCollector; + } + + @Override + public void open(OpenContext openContext) throws Exception { + super.open(openContext); + + this.collector = + generatedCollector.newInstance(getRuntimeContext().getUserCodeClassLoader()); + FunctionUtils.setFunctionRuntimeContext(collector, getRuntimeContext()); + FunctionUtils.openFunction(collector, openContext); + } + + @Override + public void processElement( + RowData in, ProcessFunction.Context ctx, Collector out) + throws Exception { + prepareCollector(in, out); + fetcher.flatMap(in, collector); + } + + public void prepareCollector(RowData in, Collector out) { + collector.setCollector(out); + collector.setInput(in); + collector.reset(); + } +}