3131import org .apache .flink .table .api .TableException ;
3232import org .apache .flink .table .catalog .DataTypeFactory ;
3333import org .apache .flink .table .data .RowData ;
34- import org .apache .flink .table .data .conversion .DataStructureConverter ;
35- import org .apache .flink .table .data .conversion .DataStructureConverters ;
3634import org .apache .flink .table .functions .AsyncPredictFunction ;
3735import org .apache .flink .table .functions .PredictFunction ;
3836import org .apache .flink .table .functions .UserDefinedFunction ;
4139import org .apache .flink .table .ml .PredictRuntimeProvider ;
4240import org .apache .flink .table .planner .calcite .FlinkContext ;
4341import org .apache .flink .table .planner .codegen .CodeGeneratorContext ;
44- import org .apache .flink .table .planner .codegen .FilterCodeGenerator ;
45- import org .apache .flink .table .planner .codegen .LookupJoinCodeGenerator ;
42+ import org .apache .flink .table .planner .codegen .FunctionCallCodeGenerator ;
43+ import org .apache .flink .table .planner .codegen .MLPredictCodeGenerator ;
4644import org .apache .flink .table .planner .delegation .PlannerBase ;
4745import org .apache .flink .table .planner .plan .nodes .exec .ExecNode ;
4846import org .apache .flink .table .planner .plan .nodes .exec .ExecNodeBase ;
5553import org .apache .flink .table .planner .plan .nodes .exec .spec .ModelSpec ;
5654import org .apache .flink .table .planner .plan .nodes .exec .utils .ExecNodeUtil ;
5755import org .apache .flink .table .planner .plan .utils .FunctionCallUtil ;
58- import org .apache .flink .table .planner .utils .JavaScalaConversionUtil ;
5956import org .apache .flink .table .runtime .collector .ListenableCollector ;
60- import org .apache .flink .table .runtime .collector .TableFunctionResultFuture ;
6157import org .apache .flink .table .runtime .functions .ml .ModelPredictRuntimeProviderContext ;
6258import org .apache .flink .table .runtime .generated .GeneratedCollector ;
6359import org .apache .flink .table .runtime .generated .GeneratedFunction ;
64- import org .apache .flink .table .runtime .generated .GeneratedResultFuture ;
65- import org .apache .flink .table .runtime .operators .join .lookup .AsyncLookupJoinRunner ;
66- import org .apache .flink .table .runtime .operators .join .lookup .LookupJoinRunner ;
67- import org .apache .flink .table .runtime .typeutils .InternalSerializers ;
60+ import org .apache .flink .table .runtime .operators .ml .AsyncMLPredictRunner ;
61+ import org .apache .flink .table .runtime .operators .ml .MLPredictRunner ;
6862import org .apache .flink .table .runtime .typeutils .InternalTypeInfo ;
6963import org .apache .flink .table .types .logical .RowType ;
64+ import org .apache .flink .util .Preconditions ;
7065
7166import org .apache .flink .shaded .jackson2 .com .fasterxml .jackson .annotation .JsonCreator ;
7267import org .apache .flink .shaded .jackson2 .com .fasterxml .jackson .annotation .JsonProperty ;
7570
7671import java .util .Collections ;
7772import java .util .List ;
78- import java .util .Optional ;
7973
8074/** Stream {@link ExecNode} for {@code ML_PREDICT}. */
8175@ ExecNodeMetadata (
@@ -197,7 +191,7 @@ private Transformation<RowData> createModelPredict(
197191 RowType resultRowType ,
198192 PredictFunction predictFunction ) {
199193 GeneratedFunction <FlatMapFunction <RowData , RowData >> generatedFetcher =
200- LookupJoinCodeGenerator . generateSyncLookupFunction (
194+ MLPredictCodeGenerator . generateSyncPredictFunction (
201195 config ,
202196 classLoader ,
203197 dataTypeFactory ,
@@ -206,25 +200,15 @@ private Transformation<RowData> createModelPredict(
206200 resultRowType ,
207201 mlPredictSpec .getFeatures (),
208202 predictFunction ,
209- "MLPredict" ,
203+ modelSpec . getContextResolvedModel (). getIdentifier (). asSummaryString () ,
210204 config .get (PipelineOptions .OBJECT_REUSE ));
211205 GeneratedCollector <ListenableCollector <RowData >> generatedCollector =
212- LookupJoinCodeGenerator .generateCollector (
206+ MLPredictCodeGenerator .generateCollector (
213207 new CodeGeneratorContext (config , classLoader ),
214208 inputRowType ,
215209 modelOutputType ,
216- (RowType ) getOutputType (),
217- JavaScalaConversionUtil .toScala (Optional .empty ()),
218- JavaScalaConversionUtil .toScala (Optional .empty ()),
219- true );
220- LookupJoinRunner mlPredictRunner =
221- new LookupJoinRunner (
222- generatedFetcher ,
223- generatedCollector ,
224- FilterCodeGenerator .generateFilterCondition (
225- config , classLoader , null , inputRowType ),
226- false ,
227- modelOutputType .getFieldCount ());
210+ (RowType ) getOutputType ());
211+ MLPredictRunner mlPredictRunner = new MLPredictRunner (generatedFetcher , generatedCollector );
228212 SimpleOperatorFactory <RowData > operatorFactory =
229213 SimpleOperatorFactory .of (new ProcessOperator <>(mlPredictRunner ));
230214 return ExecNodeUtil .createOneInputTransformation (
@@ -246,9 +230,9 @@ private Transformation<RowData> createAsyncModelPredict(
246230 RowType modelOutputType ,
247231 RowType resultRowType ,
248232 AsyncPredictFunction asyncPredictFunction ) {
249- LookupJoinCodeGenerator .GeneratedTableFunctionWithDataType <AsyncFunction <RowData , Object >>
233+ FunctionCallCodeGenerator .GeneratedTableFunctionWithDataType <AsyncFunction <RowData , Object >>
250234 generatedFuncWithType =
251- LookupJoinCodeGenerator . generateAsyncLookupFunction (
235+ MLPredictCodeGenerator . generateAsyncPredictFunction (
252236 config ,
253237 classLoader ,
254238 dataTypeFactory ,
@@ -257,29 +241,14 @@ private Transformation<RowData> createAsyncModelPredict(
257241 resultRowType ,
258242 mlPredictSpec .getFeatures (),
259243 asyncPredictFunction ,
260- "AsyncMLPredict" );
261-
262- GeneratedResultFuture <TableFunctionResultFuture <RowData >> generatedResultFuture =
263- LookupJoinCodeGenerator .generateTableAsyncCollector (
264- config ,
265- classLoader ,
266- "TableFunctionResultFuture" ,
267- inputRowType ,
268- modelOutputType ,
269- JavaScalaConversionUtil .toScala (Optional .empty ()));
270-
271- DataStructureConverter <?, ?> fetcherConverter =
272- DataStructureConverters .getConverter (generatedFuncWithType .dataType ());
244+ modelSpec
245+ .getContextResolvedModel ()
246+ .getIdentifier ()
247+ .asSummaryString ());
273248 AsyncFunction <RowData , RowData > asyncFunc =
274- new AsyncLookupJoinRunner (
275- generatedFuncWithType .tableFunc (),
276- (DataStructureConverter <RowData , Object >) fetcherConverter ,
277- generatedResultFuture ,
278- FilterCodeGenerator .generateFilterCondition (
279- config , classLoader , null , inputRowType ),
280- InternalSerializers .create (modelOutputType ),
281- false ,
282- asyncOptions .asyncBufferCapacity );
249+ new AsyncMLPredictRunner (
250+ (GeneratedFunction ) generatedFuncWithType .tableFunc (),
251+ Preconditions .checkNotNull (asyncOptions ).asyncBufferCapacity );
283252 return ExecNodeUtil .createOneInputTransformation (
284253 inputTransformation ,
285254 createTransformationMeta (ML_PREDICT_TRANSFORMATION , config ),
0 commit comments