Skip to content

Commit ecdbd65

Browse files
committed
address comments
1 parent 40f3fb6 commit ecdbd65

File tree

4 files changed

+17
-149
lines changed

4 files changed

+17
-149
lines changed

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/FunctionCallContext.java

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.apache.flink.table.catalog.DataTypeFactory;
2222
import org.apache.flink.table.connector.source.LookupTableSource;
2323
import org.apache.flink.table.functions.UserDefinedFunction;
24+
import org.apache.flink.table.ml.PredictRuntimeProvider;
2425
import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.Constant;
2526
import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.FunctionParam;
2627
import org.apache.flink.table.types.DataType;
@@ -38,11 +39,14 @@
3839
import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getFieldTypes;
3940
import static org.apache.flink.table.types.utils.TypeConversions.fromLogicalToDataType;
4041

41-
/** The {@link CallContext} of a {@link LookupTableSource} runtime function. */
42+
/**
43+
* The {@link CallContext} of {@link LookupTableSource}, {@link PredictRuntimeProvider} runtime
44+
* function.
45+
*/
4246
@Internal
4347
public class FunctionCallContext extends AbstractSqlCallContext {
4448

45-
private final List<FunctionParam> lookupKeys;
49+
private final List<FunctionParam> params;
4650

4751
private final List<DataType> argumentDataTypes;
4852

@@ -52,10 +56,10 @@ public FunctionCallContext(
5256
DataTypeFactory dataTypeFactory,
5357
UserDefinedFunction function,
5458
LogicalType inputType,
55-
List<FunctionParam> lookupKeys,
56-
LogicalType lookupType) {
59+
List<FunctionParam> params,
60+
LogicalType outputDataType) {
5761
super(dataTypeFactory, function, generateInlineFunctionName(function), false);
58-
this.lookupKeys = lookupKeys;
62+
this.params = params;
5963
this.argumentDataTypes =
6064
new AbstractList<>() {
6165
@Override
@@ -74,10 +78,10 @@ public DataType get(int index) {
7478

7579
@Override
7680
public int size() {
77-
return lookupKeys.size();
81+
return params.size();
7882
}
7983
};
80-
this.outputDataType = fromLogicalToDataType(lookupType);
84+
this.outputDataType = fromLogicalToDataType(outputDataType);
8185
}
8286

8387
@Override
@@ -118,6 +122,6 @@ public Optional<DataType> getOutputDataType() {
118122
// --------------------------------------------------------------------------------------------
119123

120124
private FunctionParam getKey(int pos) {
121-
return lookupKeys.get(pos);
125+
return params.get(pos);
122126
}
123127
}

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,7 @@ protected ProcessFunction<RowData, RowData> createSyncLookupJoinFunction(
624624
RowType rightRowType =
625625
getRightOutputRowType(projectionOutputRelDataType, tableSourceRowType);
626626
GeneratedCollector<ListenableCollector<RowData>> generatedCollector =
627-
LookupJoinCodeGenerator.generateCollector(
627+
FunctionCallCodeGenerator.generateCollector(
628628
new CodeGeneratorContext(config, classLoader),
629629
inputRowType,
630630
rightRowType,

flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCallCodeGenerator.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ object FunctionCallCodeGenerator {
206206
nullableInput = false,
207207
fieldCopy)
208208
case _ =>
209-
throw new CodeGenException("Invalid lookup key.")
209+
throw new CodeGenException("Invalid parameters.")
210210
}
211211
}
212212

flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala

Lines changed: 3 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -17,27 +17,23 @@
1717
*/
1818
package org.apache.flink.table.planner.codegen
1919

20-
import org.apache.flink.api.common.functions.{FlatMapFunction, Function, OpenContext}
20+
import org.apache.flink.api.common.functions.{FlatMapFunction, OpenContext}
2121
import org.apache.flink.configuration.ReadableConfig
2222
import org.apache.flink.streaming.api.functions.async.AsyncFunction
2323
import org.apache.flink.table.api.ValidationException
2424
import org.apache.flink.table.catalog.DataTypeFactory
2525
import org.apache.flink.table.connector.source.{LookupTableSource, ScanTableSource}
2626
import org.apache.flink.table.data.{GenericRowData, RowData}
2727
import org.apache.flink.table.data.utils.JoinedRowData
28-
import org.apache.flink.table.functions.{AsyncLookupFunction, AsyncPredictFunction, AsyncTableFunction, LookupFunction, PredictFunction, TableFunction, UserDefinedFunction, UserDefinedFunctionHelper}
28+
import org.apache.flink.table.functions.{AsyncLookupFunction, AsyncTableFunction, LookupFunction, TableFunction, UserDefinedFunction}
2929
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
3030
import org.apache.flink.table.planner.codegen.CodeGenUtils._
3131
import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType
32-
import org.apache.flink.table.planner.codegen.GenerateUtils._
3332
import org.apache.flink.table.planner.codegen.Indenter.toISC
34-
import org.apache.flink.table.planner.codegen.LookupJoinCodeGenerator.generateCallWithDataType
3533
import org.apache.flink.table.planner.codegen.calls.BridgingFunctionGenUtil
3634
import org.apache.flink.table.planner.codegen.calls.BridgingFunctionGenUtil.verifyFunctionAwareImplementation
37-
import org.apache.flink.table.planner.delegation.PlannerBase
3835
import org.apache.flink.table.planner.functions.inference.FunctionCallContext
39-
import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.{Constant, FieldRef, FunctionParam}
40-
import org.apache.flink.table.planner.plan.utils.RexLiteralUtil
36+
import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.FunctionParam
4137
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala
4238
import org.apache.flink.table.runtime.collector.{ListenableCollector, TableFunctionResultFuture}
4339
import org.apache.flink.table.runtime.collector.ListenableCollector.CollectListener
@@ -238,138 +234,6 @@ object LookupJoinCodeGenerator {
238234
}
239235
}
240236

241-
/**
242-
* Generates collector for temporal join ([[Collector]])
243-
*
244-
* Differs from CommonCorrelate.generateCollector which has no real condition because of
245-
* FLINK-7865, here we should deal with outer join type when real conditions filtered result.
246-
*/
247-
def generateCollector(
248-
ctx: CodeGeneratorContext,
249-
inputRowType: RowType,
250-
rightRowType: RowType,
251-
resultRowType: RowType,
252-
condition: Option[RexNode],
253-
pojoFieldMapping: Option[Array[Int]],
254-
retainHeader: Boolean = true): GeneratedCollector[ListenableCollector[RowData]] = {
255-
256-
val inputTerm = DEFAULT_INPUT1_TERM
257-
val rightInputTerm = DEFAULT_INPUT2_TERM
258-
259-
val exprGenerator = new ExprCodeGenerator(ctx, nullableInput = false)
260-
.bindInput(rightRowType, inputTerm = rightInputTerm, inputFieldMapping = pojoFieldMapping)
261-
262-
val rightResultExpr =
263-
exprGenerator.generateConverterResultExpression(rightRowType, classOf[GenericRowData])
264-
265-
val joinedRowTerm = CodeGenUtils.newName(ctx, "joinedRow")
266-
ctx.addReusableOutputRecord(resultRowType, classOf[JoinedRowData], joinedRowTerm)
267-
268-
val header = if (retainHeader) {
269-
s"$joinedRowTerm.setRowKind($inputTerm.getRowKind());"
270-
} else {
271-
""
272-
}
273-
274-
val body =
275-
s"""
276-
|${rightResultExpr.code}
277-
|$joinedRowTerm.replace($inputTerm, ${rightResultExpr.resultTerm});
278-
|$header
279-
|outputResult($joinedRowTerm);
280-
""".stripMargin
281-
282-
val collectorCode = if (condition.isEmpty) {
283-
body
284-
} else {
285-
286-
val filterGenerator = new ExprCodeGenerator(ctx, nullableInput = false)
287-
.bindInput(inputRowType, inputTerm)
288-
.bindSecondInput(rightRowType, rightInputTerm, pojoFieldMapping)
289-
val filterCondition = filterGenerator.generateExpression(condition.get)
290-
291-
s"""
292-
|${filterCondition.code}
293-
|if (${filterCondition.resultTerm}) {
294-
| $body
295-
|}
296-
|""".stripMargin
297-
}
298-
299-
generateTableFunctionCollectorForJoinTable(
300-
ctx,
301-
"JoinTableFuncCollector",
302-
collectorCode,
303-
inputRowType,
304-
rightRowType,
305-
inputTerm = inputTerm,
306-
collectedTerm = rightInputTerm)
307-
}
308-
309-
/**
310-
* The only differences against CollectorCodeGenerator.generateTableFunctionCollector is
311-
* "super.collect" call is binding with collect join row in "body" code
312-
*/
313-
private def generateTableFunctionCollectorForJoinTable(
314-
ctx: CodeGeneratorContext,
315-
name: String,
316-
bodyCode: String,
317-
inputType: RowType,
318-
collectedType: RowType,
319-
inputTerm: String = DEFAULT_INPUT1_TERM,
320-
collectedTerm: String = DEFAULT_INPUT2_TERM)
321-
: GeneratedCollector[ListenableCollector[RowData]] = {
322-
323-
val funcName = newName(ctx, name)
324-
val input1TypeClass = boxedTypeTermForType(inputType)
325-
val input2TypeClass = boxedTypeTermForType(collectedType)
326-
327-
val funcCode =
328-
s"""
329-
public class $funcName extends ${classOf[ListenableCollector[_]].getCanonicalName} {
330-
331-
${ctx.reuseMemberCode()}
332-
333-
public $funcName(Object[] references) throws Exception {
334-
${ctx.reuseInitCode()}
335-
}
336-
337-
@Override
338-
public void open(${className[OpenContext]} openContext) throws Exception {
339-
${ctx.reuseOpenCode()}
340-
}
341-
342-
@Override
343-
public void collect(Object record) throws Exception {
344-
$input1TypeClass $inputTerm = ($input1TypeClass) getInput();
345-
$input2TypeClass $collectedTerm = ($input2TypeClass) record;
346-
347-
// callback only when collectListener exists, equivalent to:
348-
// getCollectListener().ifPresent(
349-
// listener -> ((CollectListener) listener).onCollect(record));
350-
// TODO we should update code splitter's grammar file to accept lambda expressions.
351-
352-
if (getCollectListener().isPresent()) {
353-
((${classOf[CollectListener[_]].getCanonicalName}) getCollectListener().get())
354-
.onCollect(record);
355-
}
356-
357-
${ctx.reuseLocalVariableCode()}
358-
${ctx.reuseInputUnboxingCode()}
359-
${ctx.reusePerRecordCode()}
360-
$bodyCode
361-
}
362-
363-
@Override
364-
public void close() throws Exception {
365-
${ctx.reuseCloseCode()}
366-
}
367-
}
368-
""".stripMargin
369-
370-
new GeneratedCollector(funcName, funcCode, ctx.references.toArray, ctx.tableConfig)
371-
}
372-
373237
/**
374238
* Generates a [[TableFunctionResultFuture]] that can be passed to Java compiler.
375239
*

0 commit comments

Comments
 (0)