Skip to content

Commit 8b6e69e

Browse files
authored
[FLINK-38435][table] Refactor codegen and runner for MLPredict and LookupJoin (#27041)
1 parent 83db8f7 commit 8b6e69e

File tree

14 files changed: 987 additions and 388 deletions.
Lines changed: 14 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
import org.apache.flink.table.catalog.DataTypeFactory;
2222
import org.apache.flink.table.connector.source.LookupTableSource;
2323
import org.apache.flink.table.functions.UserDefinedFunction;
24+
import org.apache.flink.table.ml.PredictRuntimeProvider;
2425
import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.Constant;
2526
import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.FunctionParam;
2627
import org.apache.flink.table.types.DataType;
@@ -38,24 +39,27 @@
3839
import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getFieldTypes;
3940
import static org.apache.flink.table.types.utils.TypeConversions.fromLogicalToDataType;
4041

41-
/** The {@link CallContext} of a {@link LookupTableSource} runtime function. */
42+
/**
43+
* The {@link CallContext} of a {@link LookupTableSource} or {@link PredictRuntimeProvider} runtime
44+
* function.
45+
*/
4246
@Internal
43-
public class LookupCallContext extends AbstractSqlCallContext {
47+
public class FunctionCallContext extends AbstractSqlCallContext {
4448

45-
private final List<FunctionParam> lookupKeys;
49+
private final List<FunctionParam> params;
4650

4751
private final List<DataType> argumentDataTypes;
4852

4953
private final DataType outputDataType;
5054

51-
public LookupCallContext(
55+
public FunctionCallContext(
5256
DataTypeFactory dataTypeFactory,
5357
UserDefinedFunction function,
5458
LogicalType inputType,
55-
List<FunctionParam> lookupKeys,
56-
LogicalType lookupType) {
59+
List<FunctionParam> params,
60+
LogicalType outputDataType) {
5761
super(dataTypeFactory, function, generateInlineFunctionName(function), false);
58-
this.lookupKeys = lookupKeys;
62+
this.params = params;
5963
this.argumentDataTypes =
6064
new AbstractList<>() {
6165
@Override
@@ -74,10 +78,10 @@ public DataType get(int index) {
7478

7579
@Override
7680
public int size() {
77-
return lookupKeys.size();
81+
return params.size();
7882
}
7983
};
80-
this.outputDataType = fromLogicalToDataType(lookupType);
84+
this.outputDataType = fromLogicalToDataType(outputDataType);
8185
}
8286

8387
@Override
@@ -118,6 +122,6 @@ public Optional<DataType> getOutputDataType() {
118122
// --------------------------------------------------------------------------------------------
119123

120124
private FunctionParam getKey(int pos) {
121-
return lookupKeys.get(pos);
125+
return params.get(pos);
122126
}
123127
}

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@
4646
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
4747
import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
4848
import org.apache.flink.table.planner.codegen.FilterCodeGenerator;
49+
import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator;
4950
import org.apache.flink.table.planner.codegen.LookupJoinCodeGenerator;
5051
import org.apache.flink.table.planner.delegation.PlannerBase;
5152
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
@@ -459,7 +460,7 @@ protected StreamOperatorFactory<RowData> createAsyncLookupJoin(
459460
.mapToObj(allLookupKeys::get)
460461
.collect(Collectors.toList());
461462

462-
LookupJoinCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData, Object>>
463+
FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData, Object>>
463464
generatedFuncWithType =
464465
LookupJoinCodeGenerator.generateAsyncLookupFunction(
465466
config,

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
import org.apache.flink.table.functions.AsyncTableFunction;
3232
import org.apache.flink.table.functions.UserDefinedFunctionHelper;
3333
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
34+
import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator;
3435
import org.apache.flink.table.planner.codegen.LookupJoinCodeGenerator;
3536
import org.apache.flink.table.planner.delegation.PlannerBase;
3637
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
@@ -353,7 +354,7 @@ private AsyncDeltaJoinRunner createAsyncDeltaJoinRunner(
353354
.mapToObj(lookupKeys::get)
354355
.collect(Collectors.toList());
355356

356-
LookupJoinCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData, Object>>
357+
FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData, Object>>
357358
lookupSideGeneratedFuncWithType =
358359
LookupJoinCodeGenerator.generateAsyncLookupFunction(
359360
config,

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecMLPredictTableFunction.java

Lines changed: 19 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -31,8 +31,6 @@
3131
import org.apache.flink.table.api.TableException;
3232
import org.apache.flink.table.catalog.DataTypeFactory;
3333
import org.apache.flink.table.data.RowData;
34-
import org.apache.flink.table.data.conversion.DataStructureConverter;
35-
import org.apache.flink.table.data.conversion.DataStructureConverters;
3634
import org.apache.flink.table.functions.AsyncPredictFunction;
3735
import org.apache.flink.table.functions.PredictFunction;
3836
import org.apache.flink.table.functions.UserDefinedFunction;
@@ -41,8 +39,8 @@
4139
import org.apache.flink.table.ml.PredictRuntimeProvider;
4240
import org.apache.flink.table.planner.calcite.FlinkContext;
4341
import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
44-
import org.apache.flink.table.planner.codegen.FilterCodeGenerator;
45-
import org.apache.flink.table.planner.codegen.LookupJoinCodeGenerator;
42+
import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator;
43+
import org.apache.flink.table.planner.codegen.MLPredictCodeGenerator;
4644
import org.apache.flink.table.planner.delegation.PlannerBase;
4745
import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
4846
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
@@ -55,18 +53,15 @@
5553
import org.apache.flink.table.planner.plan.nodes.exec.spec.ModelSpec;
5654
import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil;
5755
import org.apache.flink.table.planner.plan.utils.FunctionCallUtil;
58-
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
5956
import org.apache.flink.table.runtime.collector.ListenableCollector;
60-
import org.apache.flink.table.runtime.collector.TableFunctionResultFuture;
6157
import org.apache.flink.table.runtime.functions.ml.ModelPredictRuntimeProviderContext;
6258
import org.apache.flink.table.runtime.generated.GeneratedCollector;
6359
import org.apache.flink.table.runtime.generated.GeneratedFunction;
64-
import org.apache.flink.table.runtime.generated.GeneratedResultFuture;
65-
import org.apache.flink.table.runtime.operators.join.lookup.AsyncLookupJoinRunner;
66-
import org.apache.flink.table.runtime.operators.join.lookup.LookupJoinRunner;
67-
import org.apache.flink.table.runtime.typeutils.InternalSerializers;
60+
import org.apache.flink.table.runtime.operators.ml.AsyncMLPredictRunner;
61+
import org.apache.flink.table.runtime.operators.ml.MLPredictRunner;
6862
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
6963
import org.apache.flink.table.types.logical.RowType;
64+
import org.apache.flink.util.Preconditions;
7065

7166
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
7267
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
@@ -75,7 +70,6 @@
7570

7671
import java.util.Collections;
7772
import java.util.List;
78-
import java.util.Optional;
7973

8074
/** Stream {@link ExecNode} for {@code ML_PREDICT}. */
8175
@ExecNodeMetadata(
@@ -197,7 +191,7 @@ private Transformation<RowData> createModelPredict(
197191
RowType resultRowType,
198192
PredictFunction predictFunction) {
199193
GeneratedFunction<FlatMapFunction<RowData, RowData>> generatedFetcher =
200-
LookupJoinCodeGenerator.generateSyncLookupFunction(
194+
MLPredictCodeGenerator.generateSyncPredictFunction(
201195
config,
202196
classLoader,
203197
dataTypeFactory,
@@ -206,25 +200,15 @@ private Transformation<RowData> createModelPredict(
206200
resultRowType,
207201
mlPredictSpec.getFeatures(),
208202
predictFunction,
209-
"MLPredict",
203+
modelSpec.getContextResolvedModel().getIdentifier().asSummaryString(),
210204
config.get(PipelineOptions.OBJECT_REUSE));
211205
GeneratedCollector<ListenableCollector<RowData>> generatedCollector =
212-
LookupJoinCodeGenerator.generateCollector(
206+
MLPredictCodeGenerator.generateCollector(
213207
new CodeGeneratorContext(config, classLoader),
214208
inputRowType,
215209
modelOutputType,
216-
(RowType) getOutputType(),
217-
JavaScalaConversionUtil.toScala(Optional.empty()),
218-
JavaScalaConversionUtil.toScala(Optional.empty()),
219-
true);
220-
LookupJoinRunner mlPredictRunner =
221-
new LookupJoinRunner(
222-
generatedFetcher,
223-
generatedCollector,
224-
FilterCodeGenerator.generateFilterCondition(
225-
config, classLoader, null, inputRowType),
226-
false,
227-
modelOutputType.getFieldCount());
210+
(RowType) getOutputType());
211+
MLPredictRunner mlPredictRunner = new MLPredictRunner(generatedFetcher, generatedCollector);
228212
SimpleOperatorFactory<RowData> operatorFactory =
229213
SimpleOperatorFactory.of(new ProcessOperator<>(mlPredictRunner));
230214
return ExecNodeUtil.createOneInputTransformation(
@@ -246,9 +230,9 @@ private Transformation<RowData> createAsyncModelPredict(
246230
RowType modelOutputType,
247231
RowType resultRowType,
248232
AsyncPredictFunction asyncPredictFunction) {
249-
LookupJoinCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData, Object>>
233+
FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData, Object>>
250234
generatedFuncWithType =
251-
LookupJoinCodeGenerator.generateAsyncLookupFunction(
235+
MLPredictCodeGenerator.generateAsyncPredictFunction(
252236
config,
253237
classLoader,
254238
dataTypeFactory,
@@ -257,29 +241,14 @@ private Transformation<RowData> createAsyncModelPredict(
257241
resultRowType,
258242
mlPredictSpec.getFeatures(),
259243
asyncPredictFunction,
260-
"AsyncMLPredict");
261-
262-
GeneratedResultFuture<TableFunctionResultFuture<RowData>> generatedResultFuture =
263-
LookupJoinCodeGenerator.generateTableAsyncCollector(
264-
config,
265-
classLoader,
266-
"TableFunctionResultFuture",
267-
inputRowType,
268-
modelOutputType,
269-
JavaScalaConversionUtil.toScala(Optional.empty()));
270-
271-
DataStructureConverter<?, ?> fetcherConverter =
272-
DataStructureConverters.getConverter(generatedFuncWithType.dataType());
244+
modelSpec
245+
.getContextResolvedModel()
246+
.getIdentifier()
247+
.asSummaryString());
273248
AsyncFunction<RowData, RowData> asyncFunc =
274-
new AsyncLookupJoinRunner(
275-
generatedFuncWithType.tableFunc(),
276-
(DataStructureConverter<RowData, Object>) fetcherConverter,
277-
generatedResultFuture,
278-
FilterCodeGenerator.generateFilterCondition(
279-
config, classLoader, null, inputRowType),
280-
InternalSerializers.create(modelOutputType),
281-
false,
282-
asyncOptions.asyncBufferCapacity);
249+
new AsyncMLPredictRunner(
250+
(GeneratedFunction) generatedFuncWithType.tableFunc(),
251+
Preconditions.checkNotNull(asyncOptions).asyncBufferCapacity);
283252
return ExecNodeUtil.createOneInputTransformation(
284253
inputTransformation,
285254
createTransformationMeta(ML_PREDICT_TRANSFORMATION, config),

0 commit comments

Comments
 (0)