@@ -21,6 +21,7 @@
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.connector.source.LookupTableSource;
import org.apache.flink.table.functions.UserDefinedFunction;
import org.apache.flink.table.ml.PredictRuntimeProvider;
import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.Constant;
import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.FunctionParam;
import org.apache.flink.table.types.DataType;
@@ -38,24 +39,27 @@
import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getFieldTypes;
import static org.apache.flink.table.types.utils.TypeConversions.fromLogicalToDataType;

/** The {@link CallContext} of a {@link LookupTableSource} runtime function. */
/**
* The {@link CallContext} of a {@link LookupTableSource} or {@link PredictRuntimeProvider} runtime
* function.
*/
@Internal
public class LookupCallContext extends AbstractSqlCallContext {
public class FunctionCallContext extends AbstractSqlCallContext {

private final List<FunctionParam> lookupKeys;
private final List<FunctionParam> params;

private final List<DataType> argumentDataTypes;

private final DataType outputDataType;

public LookupCallContext(
public FunctionCallContext(
DataTypeFactory dataTypeFactory,
UserDefinedFunction function,
LogicalType inputType,
List<FunctionParam> lookupKeys,
LogicalType lookupType) {
List<FunctionParam> params,
LogicalType outputDataType) {
super(dataTypeFactory, function, generateInlineFunctionName(function), false);
this.lookupKeys = lookupKeys;
this.params = params;
this.argumentDataTypes =
new AbstractList<>() {
@Override
@@ -74,10 +78,10 @@ public DataType get(int index) {

@Override
public int size() {
return lookupKeys.size();
return params.size();
}
};
this.outputDataType = fromLogicalToDataType(lookupType);
this.outputDataType = fromLogicalToDataType(outputDataType);
}

@Override
@@ -118,6 +122,6 @@ public Optional<DataType> getOutputDataType() {
// --------------------------------------------------------------------------------------------

private FunctionParam getKey(int pos) {
return lookupKeys.get(pos);
return params.get(pos);
}
}
@@ -46,6 +46,7 @@
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
import org.apache.flink.table.planner.codegen.FilterCodeGenerator;
import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator;
import org.apache.flink.table.planner.codegen.LookupJoinCodeGenerator;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
@@ -459,7 +460,7 @@ protected StreamOperatorFactory<RowData> createAsyncLookupJoin(
.mapToObj(allLookupKeys::get)
.collect(Collectors.toList());

LookupJoinCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData, Object>>
FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData, Object>>
generatedFuncWithType =
LookupJoinCodeGenerator.generateAsyncLookupFunction(
config,
@@ -31,6 +31,7 @@
import org.apache.flink.table.functions.AsyncTableFunction;
import org.apache.flink.table.functions.UserDefinedFunctionHelper;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator;
import org.apache.flink.table.planner.codegen.LookupJoinCodeGenerator;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
@@ -353,7 +354,7 @@ private AsyncDeltaJoinRunner createAsyncDeltaJoinRunner(
.mapToObj(lookupKeys::get)
.collect(Collectors.toList());

LookupJoinCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData, Object>>
FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData, Object>>
lookupSideGeneratedFuncWithType =
LookupJoinCodeGenerator.generateAsyncLookupFunction(
config,
@@ -31,8 +31,6 @@
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.conversion.DataStructureConverter;
import org.apache.flink.table.data.conversion.DataStructureConverters;
import org.apache.flink.table.functions.AsyncPredictFunction;
import org.apache.flink.table.functions.PredictFunction;
import org.apache.flink.table.functions.UserDefinedFunction;
@@ -41,8 +39,8 @@
import org.apache.flink.table.ml.PredictRuntimeProvider;
import org.apache.flink.table.planner.calcite.FlinkContext;
import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
import org.apache.flink.table.planner.codegen.FilterCodeGenerator;
import org.apache.flink.table.planner.codegen.LookupJoinCodeGenerator;
import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator;
import org.apache.flink.table.planner.codegen.MLPredictCodeGenerator;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
@@ -55,18 +53,15 @@
import org.apache.flink.table.planner.plan.nodes.exec.spec.ModelSpec;
import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil;
import org.apache.flink.table.planner.plan.utils.FunctionCallUtil;
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
import org.apache.flink.table.runtime.collector.ListenableCollector;
import org.apache.flink.table.runtime.collector.TableFunctionResultFuture;
import org.apache.flink.table.runtime.functions.ml.ModelPredictRuntimeProviderContext;
import org.apache.flink.table.runtime.generated.GeneratedCollector;
import org.apache.flink.table.runtime.generated.GeneratedFunction;
import org.apache.flink.table.runtime.generated.GeneratedResultFuture;
import org.apache.flink.table.runtime.operators.join.lookup.AsyncLookupJoinRunner;
import org.apache.flink.table.runtime.operators.join.lookup.LookupJoinRunner;
import org.apache.flink.table.runtime.typeutils.InternalSerializers;
import org.apache.flink.table.runtime.operators.ml.AsyncMLPredictRunner;
import org.apache.flink.table.runtime.operators.ml.MLPredictRunner;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.util.Preconditions;

import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
@@ -75,7 +70,6 @@

import java.util.Collections;
import java.util.List;
import java.util.Optional;

/** Stream {@link ExecNode} for {@code ML_PREDICT}. */
@ExecNodeMetadata(
@@ -197,7 +191,7 @@ private Transformation<RowData> createModelPredict(
RowType resultRowType,
PredictFunction predictFunction) {
GeneratedFunction<FlatMapFunction<RowData, RowData>> generatedFetcher =
LookupJoinCodeGenerator.generateSyncLookupFunction(
MLPredictCodeGenerator.generateSyncPredictFunction(
config,
classLoader,
dataTypeFactory,
@@ -206,25 +200,15 @@
resultRowType,
mlPredictSpec.getFeatures(),
predictFunction,
"MLPredict",
modelSpec.getContextResolvedModel().getIdentifier().asSummaryString(),
Contributor:
Why do we change the function name to the model name?

Member Author:
The field here is used by UserDefinedFunctionHelper#createSpecializedFunction. When the model doesn't implement the required interface, an exception should be thrown to indicate which function is illegal. Therefore, the original implementation is not correct.

Contributor:
Sorry, I didn't get it. Which required interface? Is it to improve the error message?

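To illustrate the point under discussion, here is a minimal sketch. The helper name, signature, and exception text below are hypothetical and do not reproduce the real UserDefinedFunctionHelper#createSpecializedFunction API: the idea is only that whatever name is passed into specialization is what surfaces in the error, so passing the model identifier instead of a generic "MLPredict" tells the user which model lacks the required runtime interface.

// Hypothetical sketch only: the helper below is an assumption for illustration,
// not the actual Flink UserDefinedFunctionHelper API.
import java.util.function.Predicate;

public class SpecializationNameDemo {

    // Throws if the given function does not implement the expected runtime interface.
    static void createSpecializedFunction(
            String functionName, Object function, Predicate<Object> implementsRequiredInterface) {
        if (!implementsRequiredInterface.test(function)) {
            // The name passed here is all the user sees in the error, so a model
            // identifier is more helpful than a generic "MLPredict".
            throw new IllegalStateException(
                    "Function '"
                            + functionName
                            + "' does not implement the required runtime interface.");
        }
    }

    public static void main(String[] args) {
        Object modelFunction = new Object(); // stands in for a model's runtime function
        try {
            // Pass the model identifier (as the change above does) instead of "MLPredict".
            createSpecializedFunction("`my_catalog`.`my_db`.`my_model`", modelFunction, f -> false);
        } catch (IllegalStateException e) {
            // Prints: Function '`my_catalog`.`my_db`.`my_model`' does not implement the
            // required runtime interface.
            System.out.println(e.getMessage());
        }
    }
}
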
config.get(PipelineOptions.OBJECT_REUSE));
GeneratedCollector<ListenableCollector<RowData>> generatedCollector =
LookupJoinCodeGenerator.generateCollector(
MLPredictCodeGenerator.generateCollector(
new CodeGeneratorContext(config, classLoader),
inputRowType,
modelOutputType,
(RowType) getOutputType(),
JavaScalaConversionUtil.toScala(Optional.empty()),
JavaScalaConversionUtil.toScala(Optional.empty()),
true);
LookupJoinRunner mlPredictRunner =
new LookupJoinRunner(
generatedFetcher,
generatedCollector,
FilterCodeGenerator.generateFilterCondition(
config, classLoader, null, inputRowType),
false,
modelOutputType.getFieldCount());
(RowType) getOutputType());
MLPredictRunner mlPredictRunner = new MLPredictRunner(generatedFetcher, generatedCollector);
SimpleOperatorFactory<RowData> operatorFactory =
SimpleOperatorFactory.of(new ProcessOperator<>(mlPredictRunner));
return ExecNodeUtil.createOneInputTransformation(
@@ -246,9 +230,9 @@ private Transformation<RowData> createAsyncModelPredict(
RowType modelOutputType,
RowType resultRowType,
AsyncPredictFunction asyncPredictFunction) {
LookupJoinCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData, Object>>
FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData, Object>>
generatedFuncWithType =
LookupJoinCodeGenerator.generateAsyncLookupFunction(
MLPredictCodeGenerator.generateAsyncPredictFunction(
config,
classLoader,
dataTypeFactory,
@@ -257,29 +241,14 @@
resultRowType,
mlPredictSpec.getFeatures(),
asyncPredictFunction,
"AsyncMLPredict");

GeneratedResultFuture<TableFunctionResultFuture<RowData>> generatedResultFuture =
LookupJoinCodeGenerator.generateTableAsyncCollector(
config,
classLoader,
"TableFunctionResultFuture",
inputRowType,
modelOutputType,
JavaScalaConversionUtil.toScala(Optional.empty()));

DataStructureConverter<?, ?> fetcherConverter =
DataStructureConverters.getConverter(generatedFuncWithType.dataType());
modelSpec
.getContextResolvedModel()
.getIdentifier()
.asSummaryString());
AsyncFunction<RowData, RowData> asyncFunc =
new AsyncLookupJoinRunner(
generatedFuncWithType.tableFunc(),
(DataStructureConverter<RowData, Object>) fetcherConverter,
generatedResultFuture,
FilterCodeGenerator.generateFilterCondition(
config, classLoader, null, inputRowType),
InternalSerializers.create(modelOutputType),
false,
asyncOptions.asyncBufferCapacity);
new AsyncMLPredictRunner(
(GeneratedFunction) generatedFuncWithType.tableFunc(),
Preconditions.checkNotNull(asyncOptions).asyncBufferCapacity);
return ExecNodeUtil.createOneInputTransformation(
inputTransformation,
createTransformationMeta(ML_PREDICT_TRANSFORMATION, config),