From 9db64637882dbf47d50f8d244b14f41643992bb0 Mon Sep 17 00:00:00 2001 From: Shengkai <1059623455@qq.com> Date: Thu, 25 Sep 2025 14:49:07 +0800 Subject: [PATCH 1/2] [FLINK-38424][planner] Support to parse VECTOR_SEARCH function --- .../sql/validate/SqlValidatorImpl.java | 18 +- .../functions/sql/FlinkSqlOperatorTable.java | 5 + .../sql/ml/SqlVectorSearchTableFunction.java | 239 ++++++++++++++++++ .../functions/utils/SqlValidatorUtils.java | 22 +- .../sql/VectorSearchTableFunctionTest.java | 212 ++++++++++++++++ .../sql/VectorSearchTableFunctionTest.xml | 141 +++++++++++ 6 files changed, 628 insertions(+), 9 deletions(-) create mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlVectorSearchTableFunction.java create mode 100644 flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java create mode 100644 flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.xml diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java index 624fb7e6d7bb6..0b2872b57f351 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java @@ -17,6 +17,7 @@ package org.apache.calcite.sql.validate; import org.apache.flink.table.planner.calcite.FlinkSqlCallBinding; +import org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -177,7 +178,7 @@ * *

 * <p>Lines 2571 ~ 2588, CALCITE-7217, should be removed after upgrading Calcite to 1.41.0.
 *
- * <p>Lines 3895 ~ 3899, 6574 ~ 6580 Flink improves Optimize the retrieval of sub-operands in
+ * <p>Lines 3840 ~ 3844, 6511 ~ 6517 Flink improves Optimize the retrieval of sub-operands in
 * SqlCall when using NamedParameters at {@link SqlValidatorImpl#checkRollUp}.
 *
 * <p>
Lines 5315 ~ 5321, FLINK-24352 Add null check for temporal table check on SqlSnapshot. @@ -2614,6 +2615,21 @@ private SqlNode registerFrom( scopes.put(node, getSelectScope(call1.operand(0))); return newNode; } + + // Related to CALCITE-4077 + // ----- FLINK MODIFICATION BEGIN ----- + FlinkSqlCallBinding binding = + new FlinkSqlCallBinding(this, getEmptyScope(), call1); + if (op instanceof SqlVectorSearchTableFunction + && binding.operand(0) + .isA( + new HashSet<>( + Collections.singletonList(SqlKind.SELECT)))) { + SqlValidatorScope scope = getSelectScope((SqlSelect) binding.operand(0)); + scopes.put(node, scope); + return newNode; + } + // ----- FLINK MODIFICATION END ----- } // Put the usingScope which can be a JoinScope // or a SelectScope, in order to see the left items diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java index 5ddbebd98c907..d5cc43493b35e 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java @@ -22,6 +22,8 @@ import org.apache.flink.table.planner.calcite.FlinkTypeFactory; import org.apache.flink.table.planner.functions.sql.internal.SqlAuxiliaryGroupAggFunction; import org.apache.flink.table.planner.functions.sql.ml.SqlMLEvaluateTableFunction; +import org.apache.flink.table.planner.functions.sql.ml.SqlMLPredictTableFunction; +import org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction; import org.apache.flink.table.planner.plan.type.FlinkReturnTypes; import org.apache.flink.table.planner.plan.type.NumericExceptFirstOperandChecker; @@ -1328,6 +1330,9 @@ public List getAuxiliaryFunctions() { // MODEL TABLE FUNCTIONS public static final SqlFunction ML_EVALUATE = new SqlMLEvaluateTableFunction(); + // SEARCH FUNCTIONS + public static final SqlFunction VECTOR_SEARCH = new SqlVectorSearchTableFunction(); + // Catalog Functions public static final SqlFunction CURRENT_DATABASE = BuiltInSqlFunction.newBuilder() diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlVectorSearchTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlVectorSearchTableFunction.java new file mode 100644 index 0000000000000..a655efdf9f072 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlVectorSearchTableFunction.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.functions.sql.ml; + +import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils; +import org.apache.flink.table.types.logical.ArrayType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.LogicalTypeRoot; +import org.apache.flink.table.types.logical.utils.LogicalTypeCasts; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeFieldImpl; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlCallBinding; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperandCountRange; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.SqlTableFunction; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlOperandCountRanges; +import org.apache.calcite.sql.type.SqlOperandMetadata; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlNameMatcher; +import org.apache.calcite.util.Util; +import org.checkerframework.checker.nullness.qual.Nullable; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import static org.apache.flink.table.planner.calcite.FlinkTypeFactory.toLogicalType; + +/** + * {@link SqlVectorSearchTableFunction} implements an operator for search. + * + *

It accepts four parameters, as illustrated in the example below:
 *
 *   1. SEARCH_TABLE: the table to search
 *   2. COLUMN_TO_SEARCH: a descriptor naming exactly one column of the search table; the
 *      column type must be an array of FLOAT or DOUBLE
 *   3. COLUMN_TO_QUERY: the query column from the left table, whose type must be castable to
 *      the search column type
 *   4. TOP_K: an INTEGER literal limiting the number of returned rows
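 *
 * <p>A sketch of a call (the table and column names are illustrative and mirror the tests added
 * in this patch):
 *
 * <pre>{@code
 * SELECT * FROM QueryTable, LATERAL TABLE(
 *   VECTOR_SEARCH(TABLE VectorTable, DESCRIPTOR(`g`), QueryTable.d, 10))
 * }</pre>
 *
 * <p>The result row type is the search table's row type with an appended DOUBLE column named
 * "score"; the name gets an index suffix if it collides with an existing field.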
+ */ +public class SqlVectorSearchTableFunction extends SqlFunction implements SqlTableFunction { + + private static final String PARAM_SEARCH_TABLE = "SEARCH_TABLE"; + private static final String PARAM_COLUMN_TO_SEARCH = "COLUMN_TO_SEARCH"; + private static final String PARAM_COLUMN_TO_QUERY = "COLUMN_TO_QUERY"; + private static final String PARAM_TOP_K = "TOP_K"; + + private static final String OUTPUT_SCORE = "score"; + + public SqlVectorSearchTableFunction() { + super( + "VECTOR_SEARCH", + SqlKind.OTHER_FUNCTION, + ReturnTypes.CURSOR, + null, + new OperandMetadataImpl(), + SqlFunctionCategory.SYSTEM); + } + + @Override + public SqlReturnTypeInference getRowTypeInference() { + return new SqlReturnTypeInference() { + @Override + public @Nullable RelDataType inferReturnType(SqlOperatorBinding opBinding) { + final RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); + final RelDataType inputRowType = opBinding.getOperandType(0); + + return typeFactory + .builder() + .kind(inputRowType.getStructKind()) + .addAll(inputRowType.getFieldList()) + .addAll( + SqlValidatorUtils.makeOutputUnique( + inputRowType.getFieldList(), + Collections.singletonList( + new RelDataTypeFieldImpl( + OUTPUT_SCORE, + 0, + typeFactory.createSqlType( + SqlTypeName.DOUBLE))))) + .build(); + } + }; + } + + @Override + public boolean argumentMustBeScalar(int ordinal) { + return ordinal != 0; + } + + private static class OperandMetadataImpl implements SqlOperandMetadata { + + private static final List PARAMETERS = + Collections.unmodifiableList( + Arrays.asList( + PARAM_SEARCH_TABLE, + PARAM_COLUMN_TO_SEARCH, + PARAM_COLUMN_TO_QUERY, + PARAM_TOP_K)); + + @Override + public List paramTypes(RelDataTypeFactory relDataTypeFactory) { + return Collections.nCopies( + PARAMETERS.size(), relDataTypeFactory.createSqlType(SqlTypeName.ANY)); + } + + @Override + public List paramNames() { + return PARAMETERS; + } + + @Override + public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { + // check vector table contains descriptor columns + if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 1)) { + return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse( + callBinding, throwOnFailure); + } + + List operands = callBinding.operands(); + // check descriptor has one column + SqlCall descriptor = (SqlCall) operands.get(1); + List descriptorCols = descriptor.getOperandList(); + if (descriptorCols.size() != 1) { + return SqlValidatorUtils.throwExceptionOrReturnFalse( + Optional.of( + new ValidationException( + String.format( + "Expect parameter COLUMN_TO_SEARCH for VECTOR_SEARCH only contains one column, but multiple columns are found in operand %s.", + descriptor))), + throwOnFailure); + } + + // check descriptor type is ARRAY or ARRAY + RelDataType searchTableType = callBinding.getOperandType(0); + SqlNameMatcher matcher = callBinding.getValidator().getCatalogReader().nameMatcher(); + SqlIdentifier columnName = (SqlIdentifier) descriptorCols.get(0); + String descriptorColName = + columnName.isSimple() ? 
columnName.getSimple() : Util.last(columnName.names); + int index = matcher.indexOf(searchTableType.getFieldNames(), descriptorColName); + RelDataType targetType = searchTableType.getFieldList().get(index).getType(); + LogicalType targetLogicalType = toLogicalType(targetType); + + if (!(targetLogicalType.is(LogicalTypeRoot.ARRAY) + && ((ArrayType) (targetLogicalType)) + .getElementType() + .isAnyOf(LogicalTypeRoot.FLOAT, LogicalTypeRoot.DOUBLE))) { + return SqlValidatorUtils.throwExceptionOrReturnFalse( + Optional.of( + new ValidationException( + String.format( + "Expect search column `%s` type is ARRAY or ARRAY, but its type is %s.", + columnName, targetType))), + throwOnFailure); + } + + // check query type is ARRAY or ARRAY + LogicalType sourceLogicalType = toLogicalType(callBinding.getOperandType(2)); + if (!LogicalTypeCasts.supportsImplicitCast(sourceLogicalType, targetLogicalType)) { + return SqlValidatorUtils.throwExceptionOrReturnFalse( + Optional.of( + new ValidationException( + String.format( + "Can not cast the query column type %s to target type %s. Please keep the query column type is same to the search column type.", + sourceLogicalType, targetType))), + throwOnFailure); + } + + // check topK is literal + LogicalType topKType = toLogicalType(callBinding.getOperandType(3)); + if (!operands.get(3).getKind().equals(SqlKind.LITERAL) + || !topKType.is(LogicalTypeRoot.INTEGER)) { + return SqlValidatorUtils.throwExceptionOrReturnFalse( + Optional.of( + new ValidationException( + String.format( + "Expect parameter topK is integer literal in VECTOR_SEARCH, but it is %s with type %s.", + operands.get(3), topKType))), + throwOnFailure); + } + + return true; + } + + @Override + public SqlOperandCountRange getOperandCountRange() { + return SqlOperandCountRanges.between(4, 4); + } + + @Override + public String getAllowedSignatures(SqlOperator op, String opName) { + return opName + "(TABLE table_name, DESCRIPTOR(query_column), search_column, top_k)"; + } + + @Override + public Consistency getConsistency() { + return Consistency.NONE; + } + + @Override + public boolean isOptional(int i) { + return false; + } + + @Override + public boolean isFixedParameters() { + return true; + } + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/utils/SqlValidatorUtils.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/utils/SqlValidatorUtils.java index 42e381a606290..66b58499e0949 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/utils/SqlValidatorUtils.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/utils/SqlValidatorUtils.java @@ -160,27 +160,33 @@ private static void adjustTypeForMultisetConstructor( /** * Make output field names unique from input field names by appending index. For example, Input * has field names {@code a, b, c} and output has field names {@code b, c, d}. After calling - * this function, new output field names will be {@code b0, c0, d}. Duplicate names are not - * checked inside input and output itself. + * this function, new output field names will be {@code b0, c0, d}. + * + *

We assume that input fields in the input parameter are uniquely named, just as the output + * fields in the output parameter are. * * @param input Input fields * @param output Output fields - * @return + * @return output fields with unique names. */ public static List makeOutputUnique( List input, List output) { - final Set inputFieldNames = new HashSet<>(); + final Set uniqueNames = new HashSet<>(); for (RelDataTypeField field : input) { - inputFieldNames.add(field.getName()); + uniqueNames.add(field.getName()); } List result = new ArrayList<>(); for (RelDataTypeField field : output) { String fieldName = field.getName(); - if (inputFieldNames.contains(fieldName)) { - fieldName += "0"; // Append index to make it unique + int count = 0; + String candidate = fieldName; + while (uniqueNames.contains(candidate)) { + candidate = fieldName + count; + count++; } - result.add(new RelDataTypeFieldImpl(fieldName, field.getIndex(), field.getType())); + uniqueNames.add(candidate); + result.add(new RelDataTypeFieldImpl(candidate, field.getIndex(), field.getType())); } return result; } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java new file mode 100644 index 0000000000000..818abc6b6d6dd --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.stream.sql; + +import org.apache.flink.core.testutils.FlinkAssertions; +import org.apache.flink.table.api.TableConfig; +import org.apache.flink.table.api.TableException; +import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction; +import org.apache.flink.table.planner.utils.TableTestBase; +import org.apache.flink.table.planner.utils.TableTestUtil; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Test for {@link SqlVectorSearchTableFunction}. 
*/ +public class VectorSearchTableFunctionTest extends TableTestBase { + + private TableTestUtil util; + + @BeforeEach + public void setup() { + util = streamTestUtil(TableConfig.getDefault()); + + // Create test table + util.tableEnv() + .executeSql( + "CREATE TABLE QueryTable (\n" + + " a INT,\n" + + " b BIGINT,\n" + + " c STRING,\n" + + " d ARRAY,\n" + + " rowtime TIMESTAMP(3),\n" + + " proctime as PROCTIME(),\n" + + " WATERMARK FOR rowtime AS rowtime - INTERVAL '1' SECOND\n" + + ") with (\n" + + " 'connector' = 'values'\n" + + ")"); + + util.tableEnv() + .executeSql( + "CREATE TABLE VectorTable (\n" + + " e INT,\n" + + " f BIGINT,\n" + + " g ARRAY\n" + + ") with (\n" + + " 'connector' = 'values'\n" + + ")"); + } + + @Test + void testSimple() { + String sql = + "SELECT * FROM QueryTable, LATERAL TABLE(\n" + + "VECTOR_SEARCH(\n" + + " TABLE VectorTable, DESCRIPTOR(`g`), QueryTable.d, 10" + + ")\n" + + ")"; + util.verifyRelPlan(sql); + } + + @Test + void testLiteralValue() { + String sql = + "SELECT * FROM LATERAL TABLE(VECTOR_SEARCH(TABLE VectorTable, DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .satisfies( + FlinkAssertions.anyCauseMatches( + TableException.class, + "FlinkLogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0), DESCRIPTOR(_UTF-16LE'g'), ARRAY(1.5:DECIMAL(2, 1), 2.0:DECIMAL(2, 1)), 10)], rowType=[RecordType(INTEGER e, BIGINT f, FLOAT ARRAY g, DOUBLE score)])\n" + + "+- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, VectorTable]], fields=[e, f, g])")); + } + + @Test + void testNamedArgument() { + String sql = + "SELECT * FROM QueryTable, LATERAL TABLE(\n" + + "VECTOR_SEARCH(\n" + + " SEARCH_TABLE => TABLE VectorTable,\n" + + " COLUMN_TO_QUERY => QueryTable.d,\n" + + " COLUMN_TO_SEARCH => DESCRIPTOR(`g`),\n" + + " TOP_K => 10" + + " )\n" + + ")"; + util.verifyRelPlan(sql); + } + + @Test + void testOutOfOrderNamedArgument() { + String sql = + "SELECT * FROM QueryTable, LATERAL TABLE(\n" + + "VECTOR_SEARCH(\n" + + " COLUMN_TO_QUERY => QueryTable.d,\n" + + " COLUMN_TO_SEARCH => DESCRIPTOR(`g`),\n" + + " TOP_K => 10,\n" + + " SEARCH_TABLE => TABLE VectorTable\n" + + " )\n" + + ")"; + util.verifyRelPlan(sql); + } + + @Test + void testNameConflicts() { + util.tableEnv() + .executeSql( + "CREATE TABLE NameConflictTable(\n" + + " a INT,\n" + + " score ARRAY,\n" + + " score0 ARRAY,\n" + + " score1 ARRAY\n" + + ") WITH (\n" + + " 'connector' = 'values'\n" + + ")"); + util.verifyRelPlan( + "SELECT * FROM QueryTable, LATERAL TABLE(\n" + + "VECTOR_SEARCH(\n" + + " TABLE NameConflictTable, DESCRIPTOR(`score`), QueryTable.d, 10))"); + } + + @Test + void testDescriptorTypeIsNotExpected() { + String sql = + "SELECT * FROM QueryTable, LATERAL TABLE(\n" + + "VECTOR_SEARCH(\n" + + " TABLE VectorTable, DESCRIPTOR(`f`), QueryTable.d, 10" + + ")\n" + + ")"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .satisfies( + FlinkAssertions.anyCauseMatches( + ValidationException.class, + "Expect search column `f` type is ARRAY or ARRAY, but its type is BIGINT.")); + } + + @Test + void testDescriptorContainsMultipleColumns() { + String sql = + "SELECT * FROM QueryTable, LATERAL TABLE(\n" + + "VECTOR_SEARCH(\n" + + " TABLE VectorTable, DESCRIPTOR(`f`, `g`), QueryTable.d, 10" + + ")\n" + + ")"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .satisfies( + FlinkAssertions.anyCauseMatches( + ValidationException.class, + "Expect parameter COLUMN_TO_SEARCH for VECTOR_SEARCH only contains one column, but multiple columns 
are found in operand DESCRIPTOR(`f`, `g`).")); + } + + @Test + void testQueryColumnIsNotArray() { + String sql = + "SELECT * FROM QueryTable, LATERAL TABLE(\n" + + "VECTOR_SEARCH(\n" + + " TABLE VectorTable, DESCRIPTOR(`g`), QueryTable.c, 10" + + ")\n" + + ")"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .satisfies( + FlinkAssertions.anyCauseMatches( + ValidationException.class, + "Can not cast the query column type STRING to target type FLOAT ARRAY. Please keep the query column type is same to the search column type.")); + } + + @Test + void testIllegalTopKValue1() { + String sql = + "SELECT * FROM QueryTable, LATERAL TABLE(\n" + + "VECTOR_SEARCH(\n" + + " TABLE VectorTable, DESCRIPTOR(`g`), QueryTable.d, 10.0" + + ")\n" + + ")"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .satisfies( + FlinkAssertions.anyCauseMatches( + ValidationException.class, + "Expect parameter topK is integer literal in VECTOR_SEARCH, but it is 10.0 with type DECIMAL(3, 1) NOT NULL.")); + } + + @Test + void testIllegalTopKValue2() { + String sql = + "SELECT * FROM QueryTable, LATERAL TABLE(\n" + + "VECTOR_SEARCH(\n" + + " TABLE VectorTable, DESCRIPTOR(`g`), QueryTable.d, QueryTable.a" + + ")\n" + + ")"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .satisfies( + FlinkAssertions.anyCauseMatches( + ValidationException.class, + "Expect parameter topK is integer literal in VECTOR_SEARCH, but it is QueryTable.a with type INT.")); + } +} diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.xml new file mode 100644 index 0000000000000..8aca81dc52d4c --- /dev/null +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.xml @@ -0,0 +1,141 @@ + + + + + + + + + + + + + + + + + TABLE VectorTable, + COLUMN_TO_QUERY => QueryTable.d, + COLUMN_TO_SEARCH => DESCRIPTOR(`g`), + TOP_K => 10 ) +)]]> + + + + + + + + + + + + + + + + + + + + + + QueryTable.d, + COLUMN_TO_SEARCH => DESCRIPTOR(`g`), + TOP_K => 10, + SEARCH_TABLE => TABLE VectorTable + ) +)]]> + + + + + + + + + From 06c61b8c363ff4cb53e992428fa716a3c008bb6f Mon Sep 17 00:00:00 2001 From: Shengkai <1059623455@qq.com> Date: Tue, 14 Oct 2025 20:51:26 +0800 Subject: [PATCH 2/2] rebase and address comments --- .../sql/validate/SqlValidatorImpl.java | 23 +++++++++++++------ .../functions/sql/FlinkSqlOperatorTable.java | 1 - .../sql/VectorSearchTableFunctionTest.java | 12 ++++++++++ 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java index 0b2872b57f351..89317e352f6a4 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java @@ -16,6 +16,7 @@ */ package org.apache.calcite.sql.validate; +import org.apache.flink.table.api.ValidationException; import org.apache.flink.table.planner.calcite.FlinkSqlCallBinding; import org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction; @@ -178,10 +179,12 @@ * *

 * <p>Lines 2571 ~ 2588, CALCITE-7217, should be removed after upgrading Calcite to 1.41.0.
 *
- * <p>Lines 3840 ~ 3844, 6511 ~ 6517 Flink improves Optimize the retrieval of sub-operands in
+ * <p>Lines 2618 ~ 2631, set the correct scope for VECTOR_SEARCH.
+ *
+ * <p>Lines 3920 ~ 3925, 6599 ~ 6606 Flink improves Optimize the retrieval of sub-operands in
 * SqlCall when using NamedParameters at {@link SqlValidatorImpl#checkRollUp}.
 *
- * <p>Lines 5315 ~ 5321, FLINK-24352 Add null check for temporal table check on SqlSnapshot.
+ * <p>
Lines 5340 ~ 5347, FLINK-24352 Add null check for temporal table check on SqlSnapshot. */ public class SqlValidatorImpl implements SqlValidatorWithHints { // ~ Static fields/initializers --------------------------------------------- @@ -2571,6 +2574,10 @@ private SqlNode registerFrom( case LATERAL: // ----- FLINK MODIFICATION BEGIN ----- SqlBasicCall sbc = (SqlBasicCall) node; + // Put the usingScope which is a JoinScope, + // in order to make visible the left items + // of the JOIN tree. + scopes.put(node, usingScope); registerFrom( parentScope, usingScope, @@ -2581,10 +2588,6 @@ private SqlNode registerFrom( extendList, forceNullable, true); - // Put the usingScope which is a JoinScope, - // in order to make visible the left items - // of the JOIN tree. - scopes.put(node, usingScope); return sbc; // ----- FLINK MODIFICATION END ----- @@ -2625,8 +2628,14 @@ private SqlNode registerFrom( .isA( new HashSet<>( Collections.singletonList(SqlKind.SELECT)))) { + boolean queryColumnIsNotLiteral = + binding.operand(2).getKind() != SqlKind.LITERAL; + if (!queryColumnIsNotLiteral && !lateral) { + throw new ValidationException( + "The query column is not literal, please use LATERAL TABLE to run VECTOR_SEARCH."); + } SqlValidatorScope scope = getSelectScope((SqlSelect) binding.operand(0)); - scopes.put(node, scope); + scopes.put(enclosingNode, scope); return newNode; } // ----- FLINK MODIFICATION END ----- diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java index d5cc43493b35e..4469b37642056 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java @@ -22,7 +22,6 @@ import org.apache.flink.table.planner.calcite.FlinkTypeFactory; import org.apache.flink.table.planner.functions.sql.internal.SqlAuxiliaryGroupAggFunction; import org.apache.flink.table.planner.functions.sql.ml.SqlMLEvaluateTableFunction; -import org.apache.flink.table.planner.functions.sql.ml.SqlMLPredictTableFunction; import org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction; import org.apache.flink.table.planner.plan.type.FlinkReturnTypes; import org.apache.flink.table.planner.plan.type.NumericExceptFirstOperandChecker; diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java index 818abc6b6d6dd..5d85e6b88eb23 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java @@ -89,6 +89,18 @@ void testLiteralValue() { + "+- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, VectorTable]], fields=[e, f, g])")); } + @Test + void testLiteralValueWithoutLateralKeyword() { + String sql = + "SELECT * FROM TABLE(VECTOR_SEARCH(TABLE VectorTable, DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .satisfies( + FlinkAssertions.anyCauseMatches( + 
TableException.class, + "FlinkLogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0), DESCRIPTOR(_UTF-16LE'g'), ARRAY(1.5:DECIMAL(2, 1), 2.0:DECIMAL(2, 1)), 10)], rowType=[RecordType(INTEGER e, BIGINT f, FLOAT ARRAY g, DOUBLE score)])\n" + + "+- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, VectorTable]], fields=[e, f, g])")); + } + @Test void testNamedArgument() { String sql =